1
0
mirror of https://github.com/MontFerret/ferret.git synced 2024-12-14 11:23:02 +02:00

Added extra locks

This commit is contained in:
Tim Voronov 2020-08-25 01:25:07 -04:00
parent e283722d37
commit b28506af1f
5 changed files with 90 additions and 23 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"io" "io"
"sync"
) )
var ( var (
@ -36,6 +37,7 @@ type (
ChildNodeRemovedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID) ChildNodeRemovedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID)
Manager struct { Manager struct {
mu sync.RWMutex
logger *zerolog.Logger logger *zerolog.Logger
client *cdp.Client client *cdp.Client
events *events.Loop events *events.Loop
@ -180,6 +182,9 @@ func (m *Manager) Close() error {
} }
func (m *Manager) GetMainFrame() *HTMLDocument { func (m *Manager) GetMainFrame() *HTMLDocument {
m.mu.RLock()
defer m.mu.RUnlock()
mainFrameID := m.mainFrame.Get() mainFrameID := m.mainFrame.Get()
if mainFrameID == "" { if mainFrameID == "" {
@ -196,6 +201,9 @@ func (m *Manager) GetMainFrame() *HTMLDocument {
} }
func (m *Manager) SetMainFrame(doc *HTMLDocument) { func (m *Manager) SetMainFrame(doc *HTMLDocument) {
m.mu.Lock()
defer m.mu.Unlock()
mainFrameID := m.mainFrame.Get() mainFrameID := m.mainFrame.Get()
if mainFrameID != "" { if mainFrameID != "" {
@ -210,18 +218,30 @@ func (m *Manager) SetMainFrame(doc *HTMLDocument) {
} }
func (m *Manager) AddFrame(frame page.FrameTree) { func (m *Manager) AddFrame(frame page.FrameTree) {
m.mu.RLock()
defer m.mu.RUnlock()
m.addFrameInternal(frame) m.addFrameInternal(frame)
} }
func (m *Manager) RemoveFrame(frameID page.FrameID) error { func (m *Manager) RemoveFrame(frameID page.FrameID) error {
m.mu.RLock()
defer m.mu.RUnlock()
return m.removeFrameInternal(frameID) return m.removeFrameInternal(frameID)
} }
func (m *Manager) RemoveFrameRecursively(frameID page.FrameID) error { func (m *Manager) RemoveFrameRecursively(frameID page.FrameID) error {
m.mu.RLock()
defer m.mu.RUnlock()
return m.removeFrameRecursivelyInternal(frameID) return m.removeFrameRecursivelyInternal(frameID)
} }
func (m *Manager) RemoveFramesByParentID(parentFrameID page.FrameID) error { func (m *Manager) RemoveFramesByParentID(parentFrameID page.FrameID) error {
m.mu.RLock()
defer m.mu.RUnlock()
frame, found := m.frames.Get(parentFrameID) frame, found := m.frames.Get(parentFrameID)
if !found { if !found {
@ -242,6 +262,9 @@ func (m *Manager) GetFrameNode(ctx context.Context, frameID page.FrameID) (*HTML
} }
func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.FrameTree, error) { func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.FrameTree, error) {
m.mu.RLock()
defer m.mu.RUnlock()
frame, found := m.frames.Get(frameID) frame, found := m.frames.Get(frameID)
if !found { if !found {
@ -252,6 +275,9 @@ func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.Fr
} }
func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) { func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) {
m.mu.RLock()
defer m.mu.RUnlock()
arr := values.NewArray(m.frames.Length()) arr := values.NewArray(m.frames.Length())
for _, f := range m.frames.ToSlice() { for _, f := range m.frames.ToSlice() {
@ -268,6 +294,9 @@ func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) {
} }
func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) events.ListenerID { func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) events.ListenerID {
m.mu.RLock()
defer m.mu.RUnlock()
return m.events.AddListener(eventDocumentUpdated, func(ctx context.Context, _ interface{}) bool { return m.events.AddListener(eventDocumentUpdated, func(ctx context.Context, _ interface{}) bool {
listener(ctx) listener(ctx)
@ -276,10 +305,16 @@ func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) e
} }
func (m *Manager) RemoveReloadListener(listenerID events.ListenerID) { func (m *Manager) RemoveReloadListener(listenerID events.ListenerID) {
m.mu.RLock()
defer m.mu.RUnlock()
m.events.RemoveListener(eventDocumentUpdated, listenerID) m.events.RemoveListener(eventDocumentUpdated, listenerID)
} }
func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListener) events.ListenerID { func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListener) events.ListenerID {
m.mu.RLock()
defer m.mu.RUnlock()
return m.events.AddListener(eventChildNodeInserted, func(ctx context.Context, message interface{}) bool { return m.events.AddListener(eventChildNodeInserted, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.ChildNodeInsertedReply) reply := message.(*dom.ChildNodeInsertedReply)
@ -290,10 +325,16 @@ func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListene
} }
func (m *Manager) RemoveChildNodeInsertedListener(listenerID events.ListenerID) { func (m *Manager) RemoveChildNodeInsertedListener(listenerID events.ListenerID) {
m.mu.RLock()
defer m.mu.RUnlock()
m.events.RemoveListener(eventChildNodeInserted, listenerID) m.events.RemoveListener(eventChildNodeInserted, listenerID)
} }
func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener) events.ListenerID { func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener) events.ListenerID {
m.mu.RLock()
defer m.mu.RUnlock()
return m.events.AddListener(eventChildNodeRemoved, func(ctx context.Context, message interface{}) bool { return m.events.AddListener(eventChildNodeRemoved, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.ChildNodeRemovedReply) reply := message.(*dom.ChildNodeRemovedReply)
@ -304,6 +345,9 @@ func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener)
} }
func (m *Manager) RemoveChildNodeRemovedListener(listenerID events.ListenerID) { func (m *Manager) RemoveChildNodeRemovedListener(listenerID events.ListenerID) {
m.mu.RLock()
defer m.mu.RUnlock()
m.events.RemoveListener(eventChildNodeRemoved, listenerID) m.events.RemoveListener(eventChildNodeRemoved, listenerID)
} }

View File

@ -3,7 +3,7 @@ package events
import "sync" import "sync"
type ListenerCollection struct { type ListenerCollection struct {
mu sync.Mutex mu sync.RWMutex
values map[ID]map[ListenerID]Listener values map[ID]map[ListenerID]Listener
} }
@ -15,8 +15,8 @@ func NewListenerCollection() *ListenerCollection {
} }
func (lc *ListenerCollection) Size(eventID ID) int { func (lc *ListenerCollection) Size(eventID ID) int {
lc.mu.Lock() lc.mu.RLock()
defer lc.mu.Unlock() defer lc.mu.RUnlock()
bucket, exists := lc.values[eventID] bucket, exists := lc.values[eventID]
@ -55,8 +55,8 @@ func (lc *ListenerCollection) Remove(eventID ID, listenerID ListenerID) {
} }
func (lc *ListenerCollection) Values(eventID ID) []Listener { func (lc *ListenerCollection) Values(eventID ID) []Listener {
lc.mu.Lock() lc.mu.RLock()
defer lc.mu.Unlock() defer lc.mu.RUnlock()
bucket, exists := lc.values[eventID] bucket, exists := lc.values[eventID]

View File

@ -3,9 +3,11 @@ package events
import ( import (
"context" "context"
"math/rand" "math/rand"
"sync"
) )
type Loop struct { type Loop struct {
mu sync.RWMutex
sources *SourceCollection sources *SourceCollection
listeners *ListenerCollection listeners *ListenerCollection
} }
@ -23,14 +25,23 @@ func (loop *Loop) Run(ctx context.Context) {
} }
func (loop *Loop) AddSource(source Source) { func (loop *Loop) AddSource(source Source) {
loop.mu.RLock()
defer loop.mu.RUnlock()
loop.sources.Add(source) loop.sources.Add(source)
} }
func (loop *Loop) RemoveSource(source Source) { func (loop *Loop) RemoveSource(source Source) {
loop.mu.RLock()
defer loop.mu.RUnlock()
loop.sources.Remove(source) loop.sources.Remove(source)
} }
func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID { func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID {
loop.mu.RLock()
defer loop.mu.RUnlock()
listener := Listener{ listener := Listener{
ID: ListenerID(rand.Int()), ID: ListenerID(rand.Int()),
EventID: eventID, EventID: eventID,
@ -43,6 +54,9 @@ func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID {
} }
func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) { func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) {
loop.mu.RLock()
defer loop.mu.RUnlock()
loop.listeners.Remove(eventID, listenerID) loop.listeners.Remove(eventID, listenerID)
} }
@ -50,7 +64,8 @@ func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) {
// It constantly iterates over each event source. // It constantly iterates over each event source.
// Additionally to that, on each iteration it checks the command channel in order to perform add/remove listener/source operations. // Additionally to that, on each iteration it checks the command channel in order to perform add/remove listener/source operations.
func (loop *Loop) run(ctx context.Context) { func (loop *Loop) run(ctx context.Context) {
size := loop.sources.Size() sources := loop.sources
size := sources.Size()
counter := -1 counter := -1
// in case event array is empty // in case event array is empty
@ -62,14 +77,14 @@ func (loop *Loop) run(ctx context.Context) {
if counter >= size { if counter >= size {
// reset the counter // reset the counter
size = loop.sources.Size() size = sources.Size()
counter = 0 counter = 0
} }
var source Source var source Source
if size > 0 { if size > 0 {
found, err := loop.sources.Get(counter) found, err := sources.Get(counter)
if err == nil { if err == nil {
source = found source = found
@ -106,7 +121,9 @@ func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err
message = err message = err
} }
loop.mu.RLock()
snapshot := loop.listeners.Values(eventID) snapshot := loop.listeners.Values(eventID)
loop.mu.RUnlock()
for _, listener := range snapshot { for _, listener := range snapshot {
select { select {
@ -115,7 +132,9 @@ func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err
default: default:
// if returned false, it means the loops should call the handler anymore // if returned false, it means the loops should call the handler anymore
if !listener.Handler(ctx, message) { if !listener.Handler(ctx, message) {
loop.mu.RLock()
loop.listeners.Remove(eventID, listener.ID) loop.listeners.Remove(eventID, listener.ID)
loop.mu.RUnlock()
} }
} }
} }

View File

@ -7,7 +7,7 @@ import (
) )
type SourceCollection struct { type SourceCollection struct {
mu sync.Mutex mu sync.RWMutex
values []Source values []Source
} }
@ -44,15 +44,15 @@ func (sc *SourceCollection) Close() error {
} }
func (sc *SourceCollection) Size() int { func (sc *SourceCollection) Size() int {
sc.mu.Lock() sc.mu.RLock()
defer sc.mu.Unlock() defer sc.mu.RUnlock()
return len(sc.values) return len(sc.values)
} }
func (sc *SourceCollection) Get(idx int) (Source, error) { func (sc *SourceCollection) Get(idx int) (Source, error) {
sc.mu.Lock() sc.mu.RLock()
defer sc.mu.Unlock() defer sc.mu.RUnlock()
if len(sc.values) <= idx { if len(sc.values) <= idx {
return nil, core.ErrNotFound return nil, core.ErrNotFound

View File

@ -521,6 +521,9 @@ func (p *HTMLPage) NavigateForward(ctx context.Context, skip values.Int) (values
} }
func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.String) error { func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.String) error {
p.mu.Lock()
defer p.mu.Unlock()
pattern, err := p.urlToRegexp(targetURL) pattern, err := p.urlToRegexp(targetURL)
if err != nil { if err != nil {
@ -535,6 +538,9 @@ func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.Strin
} }
func (p *HTMLPage) WaitForFrameNavigation(ctx context.Context, frame drivers.HTMLDocument, targetURL values.String) error { func (p *HTMLPage) WaitForFrameNavigation(ctx context.Context, frame drivers.HTMLDocument, targetURL values.String) error {
p.mu.Lock()
defer p.mu.Unlock()
current := p.dom.GetMainFrame() current := p.dom.GetMainFrame()
doc, ok := frame.(*dom.HTMLDocument) doc, ok := frame.(*dom.HTMLDocument)
@ -562,10 +568,6 @@ func (p *HTMLPage) WaitForFrameNavigation(ctx context.Context, frame drivers.HTM
return err return err
} }
//if isMain {
//
//}
return p.reloadMainFrame(ctx) return p.reloadMainFrame(ctx)
} }
@ -586,6 +588,12 @@ func (p *HTMLPage) urlToRegexp(targetURL values.String) (*regexp.Regexp, error)
func (p *HTMLPage) reloadMainFrame(ctx context.Context) error { func (p *HTMLPage) reloadMainFrame(ctx context.Context) error {
prev := p.dom.GetMainFrame() prev := p.dom.GetMainFrame()
if prev != nil {
if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil {
p.logger.Error().Err(err).Msg("failed to remove main frame")
}
}
next, err := dom.LoadRootHTMLDocument( next, err := dom.LoadRootHTMLDocument(
ctx, ctx,
p.logger, p.logger,
@ -596,13 +604,9 @@ func (p *HTMLPage) reloadMainFrame(ctx context.Context) error {
) )
if err != nil { if err != nil {
return err p.logger.Error().Err(err).Msg("failed to load a new root document")
}
if prev != nil { return err
if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil {
p.logger.Error().Err(err).Msg("failed to remove main frame")
}
} }
p.dom.SetMainFrame(next) p.dom.SetMainFrame(next)