From b28506af1f00a3582c2208da856b6b1e1646b300 Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Tue, 25 Aug 2020 01:25:07 -0400 Subject: [PATCH] Added extra locks --- pkg/drivers/cdp/dom/manager.go | 44 +++++++++++++++++++++++++++++ pkg/drivers/cdp/events/listeners.go | 10 +++---- pkg/drivers/cdp/events/loop.go | 25 ++++++++++++++-- pkg/drivers/cdp/events/sources.go | 10 +++---- pkg/drivers/cdp/page.go | 24 +++++++++------- 5 files changed, 90 insertions(+), 23 deletions(-) diff --git a/pkg/drivers/cdp/dom/manager.go b/pkg/drivers/cdp/dom/manager.go index 0a22b6cb..6c023e94 100644 --- a/pkg/drivers/cdp/dom/manager.go +++ b/pkg/drivers/cdp/dom/manager.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" "io" + "sync" ) var ( @@ -36,6 +37,7 @@ type ( ChildNodeRemovedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID) Manager struct { + mu sync.RWMutex logger *zerolog.Logger client *cdp.Client events *events.Loop @@ -180,6 +182,9 @@ func (m *Manager) Close() error { } func (m *Manager) GetMainFrame() *HTMLDocument { + m.mu.RLock() + defer m.mu.RUnlock() + mainFrameID := m.mainFrame.Get() if mainFrameID == "" { @@ -196,6 +201,9 @@ func (m *Manager) GetMainFrame() *HTMLDocument { } func (m *Manager) SetMainFrame(doc *HTMLDocument) { + m.mu.Lock() + defer m.mu.Unlock() + mainFrameID := m.mainFrame.Get() if mainFrameID != "" { @@ -210,18 +218,30 @@ func (m *Manager) SetMainFrame(doc *HTMLDocument) { } func (m *Manager) AddFrame(frame page.FrameTree) { + m.mu.RLock() + defer m.mu.RUnlock() + m.addFrameInternal(frame) } func (m *Manager) RemoveFrame(frameID page.FrameID) error { + m.mu.RLock() + defer m.mu.RUnlock() + return m.removeFrameInternal(frameID) } func (m *Manager) RemoveFrameRecursively(frameID page.FrameID) error { + m.mu.RLock() + defer m.mu.RUnlock() + return m.removeFrameRecursivelyInternal(frameID) } func (m *Manager) RemoveFramesByParentID(parentFrameID page.FrameID) error { + m.mu.RLock() + defer m.mu.RUnlock() + frame, found := m.frames.Get(parentFrameID) if !found { @@ -242,6 +262,9 @@ func (m *Manager) GetFrameNode(ctx context.Context, frameID page.FrameID) (*HTML } func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.FrameTree, error) { + m.mu.RLock() + defer m.mu.RUnlock() + frame, found := m.frames.Get(frameID) if !found { @@ -252,6 +275,9 @@ func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.Fr } func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) { + m.mu.RLock() + defer m.mu.RUnlock() + arr := values.NewArray(m.frames.Length()) for _, f := range m.frames.ToSlice() { @@ -268,6 +294,9 @@ func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) { } func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) events.ListenerID { + m.mu.RLock() + defer m.mu.RUnlock() + return m.events.AddListener(eventDocumentUpdated, func(ctx context.Context, _ interface{}) bool { listener(ctx) @@ -276,10 +305,16 @@ func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) e } func (m *Manager) RemoveReloadListener(listenerID events.ListenerID) { + m.mu.RLock() + defer m.mu.RUnlock() + m.events.RemoveListener(eventDocumentUpdated, listenerID) } func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListener) events.ListenerID { + m.mu.RLock() + defer m.mu.RUnlock() + return m.events.AddListener(eventChildNodeInserted, func(ctx context.Context, message interface{}) bool { reply := message.(*dom.ChildNodeInsertedReply) @@ -290,10 +325,16 @@ func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListene } func (m *Manager) RemoveChildNodeInsertedListener(listenerID events.ListenerID) { + m.mu.RLock() + defer m.mu.RUnlock() + m.events.RemoveListener(eventChildNodeInserted, listenerID) } func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener) events.ListenerID { + m.mu.RLock() + defer m.mu.RUnlock() + return m.events.AddListener(eventChildNodeRemoved, func(ctx context.Context, message interface{}) bool { reply := message.(*dom.ChildNodeRemovedReply) @@ -304,6 +345,9 @@ func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener) } func (m *Manager) RemoveChildNodeRemovedListener(listenerID events.ListenerID) { + m.mu.RLock() + defer m.mu.RUnlock() + m.events.RemoveListener(eventChildNodeRemoved, listenerID) } diff --git a/pkg/drivers/cdp/events/listeners.go b/pkg/drivers/cdp/events/listeners.go index ec19995d..b9b355e7 100644 --- a/pkg/drivers/cdp/events/listeners.go +++ b/pkg/drivers/cdp/events/listeners.go @@ -3,7 +3,7 @@ package events import "sync" type ListenerCollection struct { - mu sync.Mutex + mu sync.RWMutex values map[ID]map[ListenerID]Listener } @@ -15,8 +15,8 @@ func NewListenerCollection() *ListenerCollection { } func (lc *ListenerCollection) Size(eventID ID) int { - lc.mu.Lock() - defer lc.mu.Unlock() + lc.mu.RLock() + defer lc.mu.RUnlock() bucket, exists := lc.values[eventID] @@ -55,8 +55,8 @@ func (lc *ListenerCollection) Remove(eventID ID, listenerID ListenerID) { } func (lc *ListenerCollection) Values(eventID ID) []Listener { - lc.mu.Lock() - defer lc.mu.Unlock() + lc.mu.RLock() + defer lc.mu.RUnlock() bucket, exists := lc.values[eventID] diff --git a/pkg/drivers/cdp/events/loop.go b/pkg/drivers/cdp/events/loop.go index fed688ff..e94b4dea 100644 --- a/pkg/drivers/cdp/events/loop.go +++ b/pkg/drivers/cdp/events/loop.go @@ -3,9 +3,11 @@ package events import ( "context" "math/rand" + "sync" ) type Loop struct { + mu sync.RWMutex sources *SourceCollection listeners *ListenerCollection } @@ -23,14 +25,23 @@ func (loop *Loop) Run(ctx context.Context) { } func (loop *Loop) AddSource(source Source) { + loop.mu.RLock() + defer loop.mu.RUnlock() + loop.sources.Add(source) } func (loop *Loop) RemoveSource(source Source) { + loop.mu.RLock() + defer loop.mu.RUnlock() + loop.sources.Remove(source) } func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID { + loop.mu.RLock() + defer loop.mu.RUnlock() + listener := Listener{ ID: ListenerID(rand.Int()), EventID: eventID, @@ -43,6 +54,9 @@ func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID { } func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) { + loop.mu.RLock() + defer loop.mu.RUnlock() + loop.listeners.Remove(eventID, listenerID) } @@ -50,7 +64,8 @@ func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) { // It constantly iterates over each event source. // Additionally to that, on each iteration it checks the command channel in order to perform add/remove listener/source operations. func (loop *Loop) run(ctx context.Context) { - size := loop.sources.Size() + sources := loop.sources + size := sources.Size() counter := -1 // in case event array is empty @@ -62,14 +77,14 @@ func (loop *Loop) run(ctx context.Context) { if counter >= size { // reset the counter - size = loop.sources.Size() + size = sources.Size() counter = 0 } var source Source if size > 0 { - found, err := loop.sources.Get(counter) + found, err := sources.Get(counter) if err == nil { source = found @@ -106,7 +121,9 @@ func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err message = err } + loop.mu.RLock() snapshot := loop.listeners.Values(eventID) + loop.mu.RUnlock() for _, listener := range snapshot { select { @@ -115,7 +132,9 @@ func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err default: // if returned false, it means the loops should call the handler anymore if !listener.Handler(ctx, message) { + loop.mu.RLock() loop.listeners.Remove(eventID, listener.ID) + loop.mu.RUnlock() } } } diff --git a/pkg/drivers/cdp/events/sources.go b/pkg/drivers/cdp/events/sources.go index 12e4a890..f69a0f76 100644 --- a/pkg/drivers/cdp/events/sources.go +++ b/pkg/drivers/cdp/events/sources.go @@ -7,7 +7,7 @@ import ( ) type SourceCollection struct { - mu sync.Mutex + mu sync.RWMutex values []Source } @@ -44,15 +44,15 @@ func (sc *SourceCollection) Close() error { } func (sc *SourceCollection) Size() int { - sc.mu.Lock() - defer sc.mu.Unlock() + sc.mu.RLock() + defer sc.mu.RUnlock() return len(sc.values) } func (sc *SourceCollection) Get(idx int) (Source, error) { - sc.mu.Lock() - defer sc.mu.Unlock() + sc.mu.RLock() + defer sc.mu.RUnlock() if len(sc.values) <= idx { return nil, core.ErrNotFound diff --git a/pkg/drivers/cdp/page.go b/pkg/drivers/cdp/page.go index c0cb1aa7..f9d8e889 100644 --- a/pkg/drivers/cdp/page.go +++ b/pkg/drivers/cdp/page.go @@ -521,6 +521,9 @@ func (p *HTMLPage) NavigateForward(ctx context.Context, skip values.Int) (values } func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.String) error { + p.mu.Lock() + defer p.mu.Unlock() + pattern, err := p.urlToRegexp(targetURL) if err != nil { @@ -535,6 +538,9 @@ func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.Strin } func (p *HTMLPage) WaitForFrameNavigation(ctx context.Context, frame drivers.HTMLDocument, targetURL values.String) error { + p.mu.Lock() + defer p.mu.Unlock() + current := p.dom.GetMainFrame() doc, ok := frame.(*dom.HTMLDocument) @@ -562,10 +568,6 @@ func (p *HTMLPage) WaitForFrameNavigation(ctx context.Context, frame drivers.HTM return err } - //if isMain { - // - //} - return p.reloadMainFrame(ctx) } @@ -586,6 +588,12 @@ func (p *HTMLPage) urlToRegexp(targetURL values.String) (*regexp.Regexp, error) func (p *HTMLPage) reloadMainFrame(ctx context.Context) error { prev := p.dom.GetMainFrame() + if prev != nil { + if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil { + p.logger.Error().Err(err).Msg("failed to remove main frame") + } + } + next, err := dom.LoadRootHTMLDocument( ctx, p.logger, @@ -596,13 +604,9 @@ func (p *HTMLPage) reloadMainFrame(ctx context.Context) error { ) if err != nil { - return err - } + p.logger.Error().Err(err).Msg("failed to load a new root document") - if prev != nil { - if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil { - p.logger.Error().Err(err).Msg("failed to remove main frame") - } + return err } p.dom.SetMainFrame(next)