1
0
mirror of https://github.com/MontFerret/ferret.git synced 2024-12-14 11:23:02 +02:00

Added extra locks

This commit is contained in:
Tim Voronov 2020-08-25 01:25:07 -04:00
parent e283722d37
commit b28506af1f
5 changed files with 90 additions and 23 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/pkg/errors"
"github.com/rs/zerolog"
"io"
"sync"
)
var (
@ -36,6 +37,7 @@ type (
ChildNodeRemovedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID)
Manager struct {
mu sync.RWMutex
logger *zerolog.Logger
client *cdp.Client
events *events.Loop
@ -180,6 +182,9 @@ func (m *Manager) Close() error {
}
func (m *Manager) GetMainFrame() *HTMLDocument {
m.mu.RLock()
defer m.mu.RUnlock()
mainFrameID := m.mainFrame.Get()
if mainFrameID == "" {
@ -196,6 +201,9 @@ func (m *Manager) GetMainFrame() *HTMLDocument {
}
func (m *Manager) SetMainFrame(doc *HTMLDocument) {
m.mu.Lock()
defer m.mu.Unlock()
mainFrameID := m.mainFrame.Get()
if mainFrameID != "" {
@ -210,18 +218,30 @@ func (m *Manager) SetMainFrame(doc *HTMLDocument) {
}
func (m *Manager) AddFrame(frame page.FrameTree) {
m.mu.RLock()
defer m.mu.RUnlock()
m.addFrameInternal(frame)
}
func (m *Manager) RemoveFrame(frameID page.FrameID) error {
m.mu.RLock()
defer m.mu.RUnlock()
return m.removeFrameInternal(frameID)
}
func (m *Manager) RemoveFrameRecursively(frameID page.FrameID) error {
m.mu.RLock()
defer m.mu.RUnlock()
return m.removeFrameRecursivelyInternal(frameID)
}
func (m *Manager) RemoveFramesByParentID(parentFrameID page.FrameID) error {
m.mu.RLock()
defer m.mu.RUnlock()
frame, found := m.frames.Get(parentFrameID)
if !found {
@ -242,6 +262,9 @@ func (m *Manager) GetFrameNode(ctx context.Context, frameID page.FrameID) (*HTML
}
func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.FrameTree, error) {
m.mu.RLock()
defer m.mu.RUnlock()
frame, found := m.frames.Get(frameID)
if !found {
@ -252,6 +275,9 @@ func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.Fr
}
func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) {
m.mu.RLock()
defer m.mu.RUnlock()
arr := values.NewArray(m.frames.Length())
for _, f := range m.frames.ToSlice() {
@ -268,6 +294,9 @@ func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) {
}
func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) events.ListenerID {
m.mu.RLock()
defer m.mu.RUnlock()
return m.events.AddListener(eventDocumentUpdated, func(ctx context.Context, _ interface{}) bool {
listener(ctx)
@ -276,10 +305,16 @@ func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) e
}
func (m *Manager) RemoveReloadListener(listenerID events.ListenerID) {
m.mu.RLock()
defer m.mu.RUnlock()
m.events.RemoveListener(eventDocumentUpdated, listenerID)
}
func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListener) events.ListenerID {
m.mu.RLock()
defer m.mu.RUnlock()
return m.events.AddListener(eventChildNodeInserted, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.ChildNodeInsertedReply)
@ -290,10 +325,16 @@ func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListene
}
func (m *Manager) RemoveChildNodeInsertedListener(listenerID events.ListenerID) {
m.mu.RLock()
defer m.mu.RUnlock()
m.events.RemoveListener(eventChildNodeInserted, listenerID)
}
func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener) events.ListenerID {
m.mu.RLock()
defer m.mu.RUnlock()
return m.events.AddListener(eventChildNodeRemoved, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.ChildNodeRemovedReply)
@ -304,6 +345,9 @@ func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener)
}
func (m *Manager) RemoveChildNodeRemovedListener(listenerID events.ListenerID) {
m.mu.RLock()
defer m.mu.RUnlock()
m.events.RemoveListener(eventChildNodeRemoved, listenerID)
}

View File

@ -3,7 +3,7 @@ package events
import "sync"
type ListenerCollection struct {
mu sync.Mutex
mu sync.RWMutex
values map[ID]map[ListenerID]Listener
}
@ -15,8 +15,8 @@ func NewListenerCollection() *ListenerCollection {
}
func (lc *ListenerCollection) Size(eventID ID) int {
lc.mu.Lock()
defer lc.mu.Unlock()
lc.mu.RLock()
defer lc.mu.RUnlock()
bucket, exists := lc.values[eventID]
@ -55,8 +55,8 @@ func (lc *ListenerCollection) Remove(eventID ID, listenerID ListenerID) {
}
func (lc *ListenerCollection) Values(eventID ID) []Listener {
lc.mu.Lock()
defer lc.mu.Unlock()
lc.mu.RLock()
defer lc.mu.RUnlock()
bucket, exists := lc.values[eventID]

View File

@ -3,9 +3,11 @@ package events
import (
"context"
"math/rand"
"sync"
)
type Loop struct {
mu sync.RWMutex
sources *SourceCollection
listeners *ListenerCollection
}
@ -23,14 +25,23 @@ func (loop *Loop) Run(ctx context.Context) {
}
func (loop *Loop) AddSource(source Source) {
loop.mu.RLock()
defer loop.mu.RUnlock()
loop.sources.Add(source)
}
func (loop *Loop) RemoveSource(source Source) {
loop.mu.RLock()
defer loop.mu.RUnlock()
loop.sources.Remove(source)
}
func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID {
loop.mu.RLock()
defer loop.mu.RUnlock()
listener := Listener{
ID: ListenerID(rand.Int()),
EventID: eventID,
@ -43,6 +54,9 @@ func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID {
}
func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) {
loop.mu.RLock()
defer loop.mu.RUnlock()
loop.listeners.Remove(eventID, listenerID)
}
@ -50,7 +64,8 @@ func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) {
// It constantly iterates over each event source.
// Additionally to that, on each iteration it checks the command channel in order to perform add/remove listener/source operations.
func (loop *Loop) run(ctx context.Context) {
size := loop.sources.Size()
sources := loop.sources
size := sources.Size()
counter := -1
// in case event array is empty
@ -62,14 +77,14 @@ func (loop *Loop) run(ctx context.Context) {
if counter >= size {
// reset the counter
size = loop.sources.Size()
size = sources.Size()
counter = 0
}
var source Source
if size > 0 {
found, err := loop.sources.Get(counter)
found, err := sources.Get(counter)
if err == nil {
source = found
@ -106,7 +121,9 @@ func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err
message = err
}
loop.mu.RLock()
snapshot := loop.listeners.Values(eventID)
loop.mu.RUnlock()
for _, listener := range snapshot {
select {
@ -115,7 +132,9 @@ func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err
default:
// if returned false, it means the loops should call the handler anymore
if !listener.Handler(ctx, message) {
loop.mu.RLock()
loop.listeners.Remove(eventID, listener.ID)
loop.mu.RUnlock()
}
}
}

View File

@ -7,7 +7,7 @@ import (
)
type SourceCollection struct {
mu sync.Mutex
mu sync.RWMutex
values []Source
}
@ -44,15 +44,15 @@ func (sc *SourceCollection) Close() error {
}
func (sc *SourceCollection) Size() int {
sc.mu.Lock()
defer sc.mu.Unlock()
sc.mu.RLock()
defer sc.mu.RUnlock()
return len(sc.values)
}
func (sc *SourceCollection) Get(idx int) (Source, error) {
sc.mu.Lock()
defer sc.mu.Unlock()
sc.mu.RLock()
defer sc.mu.RUnlock()
if len(sc.values) <= idx {
return nil, core.ErrNotFound

View File

@ -521,6 +521,9 @@ func (p *HTMLPage) NavigateForward(ctx context.Context, skip values.Int) (values
}
func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.String) error {
p.mu.Lock()
defer p.mu.Unlock()
pattern, err := p.urlToRegexp(targetURL)
if err != nil {
@ -535,6 +538,9 @@ func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.Strin
}
func (p *HTMLPage) WaitForFrameNavigation(ctx context.Context, frame drivers.HTMLDocument, targetURL values.String) error {
p.mu.Lock()
defer p.mu.Unlock()
current := p.dom.GetMainFrame()
doc, ok := frame.(*dom.HTMLDocument)
@ -562,10 +568,6 @@ func (p *HTMLPage) WaitForFrameNavigation(ctx context.Context, frame drivers.HTM
return err
}
//if isMain {
//
//}
return p.reloadMainFrame(ctx)
}
@ -586,6 +588,12 @@ func (p *HTMLPage) urlToRegexp(targetURL values.String) (*regexp.Regexp, error)
func (p *HTMLPage) reloadMainFrame(ctx context.Context) error {
prev := p.dom.GetMainFrame()
if prev != nil {
if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil {
p.logger.Error().Err(err).Msg("failed to remove main frame")
}
}
next, err := dom.LoadRootHTMLDocument(
ctx,
p.logger,
@ -596,13 +604,9 @@ func (p *HTMLPage) reloadMainFrame(ctx context.Context) error {
)
if err != nil {
return err
}
p.logger.Error().Err(err).Msg("failed to load a new root document")
if prev != nil {
if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil {
p.logger.Error().Err(err).Msg("failed to remove main frame")
}
return err
}
p.dom.SetMainFrame(next)