From fe7b45df6e5069a6e3b3ce4f15b2129d8b591b07 Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Tue, 24 Dec 2019 18:47:21 -0500 Subject: [PATCH] Bugfix/#399 navigation (#432) * Refactored networking * Some work * Added event loop * Renamed EventHandler to Handler * wip * Removed console logs * Added DOMManager * Refactored frame managment * Fixes * Fixed concurrency issues * Fixed unit tests * Improved EventLoop api * Some fixes * Refactored event loop. * Improved logic of initial page load * Cleaned up * Fixed linting issues * Fixed dom.Manager.Close * SOme works * Fixes * Removed fmt.Println statements * Refactored WaitForNavigation * Removed filter for e2e tests * Made Cookies Measurable * Made Cookies KeyedCollection * Fixes after code review * Updated e2e tests for iframes * Fixed iframe lookup in e2e tests * Added comments --- .travis.yml | 4 +- .../dynamic/doc/iframes/element_exists.fql | 9 +- e2e/tests/dynamic/doc/iframes/hover.fql | 8 +- e2e/tests/dynamic/doc/iframes/input.fql | 9 +- e2e/tests/dynamic/doc/iframes/wait_class.fql | 8 +- examples/redirects.fql | 13 + pkg/drivers/cdp/document_test.go | 1 - pkg/drivers/cdp/{ => dom}/document.go | 165 ++---- pkg/drivers/cdp/dom/document_test.go | 1 + pkg/drivers/cdp/{ => dom}/element.go | 160 +++--- pkg/drivers/cdp/dom/helpers.go | 287 ++++++++++ pkg/drivers/cdp/dom/manager.go | 511 +++++++++++++++++ pkg/drivers/cdp/events/broker.go | 289 ---------- pkg/drivers/cdp/events/broker_test.go | 322 ----------- pkg/drivers/cdp/events/helpers.go | 122 +--- pkg/drivers/cdp/events/listener.go | 38 ++ pkg/drivers/cdp/events/listeners.go | 74 +++ pkg/drivers/cdp/events/loop.go | 165 ++++++ pkg/drivers/cdp/events/loop_test.go | 426 ++++++++++++++ pkg/drivers/cdp/events/noop.go | 31 + pkg/drivers/cdp/events/source.go | 74 +++ pkg/drivers/cdp/events/sources.go | 82 +++ pkg/drivers/cdp/helpers.go | 392 +++---------- pkg/drivers/cdp/network/events.go | 7 + pkg/drivers/cdp/network/helpers.go | 83 +++ pkg/drivers/cdp/network/manager.go | 340 +++++++++++ pkg/drivers/cdp/page.go | 532 +++++------------- pkg/drivers/common/errors.go | 14 +- pkg/drivers/common/getter.go | 27 +- pkg/drivers/cookie.go | 2 - pkg/drivers/cookies.go | 183 ++++++ pkg/drivers/http/document.go | 4 +- pkg/drivers/http/page.go | 22 +- pkg/drivers/type.go | 8 +- pkg/drivers/value.go | 14 +- pkg/runtime/core/errors.go | 4 +- pkg/runtime/program.go | 3 +- pkg/stdlib/html/cookie_del.go | 17 +- pkg/stdlib/html/cookie_get.go | 14 +- pkg/stdlib/html/cookie_set.go | 6 +- pkg/stdlib/html/wait_navigation.go | 63 ++- 41 files changed, 2829 insertions(+), 1705 deletions(-) create mode 100644 examples/redirects.fql delete mode 100644 pkg/drivers/cdp/document_test.go rename pkg/drivers/cdp/{ => dom}/document.go (76%) create mode 100644 pkg/drivers/cdp/dom/document_test.go rename pkg/drivers/cdp/{ => dom}/element.go (91%) create mode 100644 pkg/drivers/cdp/dom/helpers.go create mode 100644 pkg/drivers/cdp/dom/manager.go delete mode 100644 pkg/drivers/cdp/events/broker.go delete mode 100644 pkg/drivers/cdp/events/broker_test.go create mode 100644 pkg/drivers/cdp/events/listener.go create mode 100644 pkg/drivers/cdp/events/listeners.go create mode 100644 pkg/drivers/cdp/events/loop.go create mode 100644 pkg/drivers/cdp/events/loop_test.go create mode 100644 pkg/drivers/cdp/events/noop.go create mode 100644 pkg/drivers/cdp/events/source.go create mode 100644 pkg/drivers/cdp/events/sources.go create mode 100644 pkg/drivers/cdp/network/events.go create mode 100644 pkg/drivers/cdp/network/helpers.go create mode 100644 pkg/drivers/cdp/network/manager.go create mode 100644 pkg/drivers/cookies.go diff --git a/.travis.yml b/.travis.yml index 3c4b098a..cf1b1652 100644 --- a/.travis.yml +++ b/.travis.yml @@ -58,8 +58,8 @@ jobs: - stage: e2e go: stable before_script: - - docker pull microbox/chromium-headless:75.0.3765.1 - - docker run -d -p 9222:9222 microbox/chromium-headless:75.0.3765.1 + - docker pull microbox/chromium-headless:77.0.3844.0 + - docker run -d -p 9222:9222 microbox/chromium-headless:77.0.3844.0 - docker ps script: - make e2e diff --git a/e2e/tests/dynamic/doc/iframes/element_exists.fql b/e2e/tests/dynamic/doc/iframes/element_exists.fql index 0cdeee07..6763e74a 100644 --- a/e2e/tests/dynamic/doc/iframes/element_exists.fql +++ b/e2e/tests/dynamic/doc/iframes/element_exists.fql @@ -1,7 +1,14 @@ LET url = @dynamic + "?redirect=/iframe" LET page = DOCUMENT(url, { driver: 'cdp' }) -LET doc = page.frames[1] +LET frames = ( + FOR f IN page.frames + FILTER f.name == "nested" + LIMIT 1 + RETURN f +) + +LET doc = FIRST(frames) LET expectedP = TRUE LET actualP = ELEMENT_EXISTS(doc, '.text-center') diff --git a/e2e/tests/dynamic/doc/iframes/hover.fql b/e2e/tests/dynamic/doc/iframes/hover.fql index 1e5e83c4..9f71b6b3 100644 --- a/e2e/tests/dynamic/doc/iframes/hover.fql +++ b/e2e/tests/dynamic/doc/iframes/hover.fql @@ -1,6 +1,12 @@ LET url = @dynamic + "?redirect=/iframe&src=/events" LET page = DOCUMENT(url, { driver: 'cdp' }) -LET doc = page.frames[1] +LET frames = ( + FOR f IN page.frames + FILTER f.name == "nested" + LIMIT 1 + RETURN f +) +LET doc = FIRST(frames) WAIT_ELEMENT(doc, "#page-events") diff --git a/e2e/tests/dynamic/doc/iframes/input.fql b/e2e/tests/dynamic/doc/iframes/input.fql index ab0bd989..f4004bcd 100644 --- a/e2e/tests/dynamic/doc/iframes/input.fql +++ b/e2e/tests/dynamic/doc/iframes/input.fql @@ -1,6 +1,13 @@ LET url = @dynamic + "?redirect=/iframe&src=/forms" LET page = DOCUMENT(url, true) -LET doc = page.frames[1] + +LET frames = ( + FOR f IN page.frames + FILTER f.name == "nested" + LIMIT 1 + RETURN f +) +LET doc = FIRST(frames) WAIT_ELEMENT(doc, "form") diff --git a/e2e/tests/dynamic/doc/iframes/wait_class.fql b/e2e/tests/dynamic/doc/iframes/wait_class.fql index 38d8e84f..97d152f1 100644 --- a/e2e/tests/dynamic/doc/iframes/wait_class.fql +++ b/e2e/tests/dynamic/doc/iframes/wait_class.fql @@ -1,6 +1,12 @@ LET url = @dynamic + "?redirect=/iframe&src=/events" LET page = DOCUMENT(url, true) -LET doc = page.frames[1] +LET frames = ( + FOR f IN page.frames + FILTER f.name == "nested" + LIMIT 1 + RETURN f +) +LET doc = FIRST(frames) WAIT_ELEMENT(doc, "#page-events") diff --git a/examples/redirects.fql b/examples/redirects.fql new file mode 100644 index 00000000..96ebbfbb --- /dev/null +++ b/examples/redirects.fql @@ -0,0 +1,13 @@ +LET doc = DOCUMENT("http://waos.ovh/redirect.html", { + driver: 'cdp', + viewport: { + width: 1920, + height: 1080 + } +}) + +CLICK(doc, '.click') + +WAIT_NAVIGATION(doc) + +RETURN ELEMENT(doc, '.title') \ No newline at end of file diff --git a/pkg/drivers/cdp/document_test.go b/pkg/drivers/cdp/document_test.go deleted file mode 100644 index d0d13314..00000000 --- a/pkg/drivers/cdp/document_test.go +++ /dev/null @@ -1 +0,0 @@ -package cdp diff --git a/pkg/drivers/cdp/document.go b/pkg/drivers/cdp/dom/document.go similarity index 76% rename from pkg/drivers/cdp/document.go rename to pkg/drivers/cdp/dom/document.go index c915d6b9..bf82ced1 100644 --- a/pkg/drivers/cdp/document.go +++ b/pkg/drivers/cdp/dom/document.go @@ -1,4 +1,4 @@ -package cdp +package dom import ( "context" @@ -23,22 +23,20 @@ import ( ) type HTMLDocument struct { - logger *zerolog.Logger - client *cdp.Client - events *events.EventBroker - input *input.Manager - exec *eval.ExecutionContext - frames page.FrameTree - element *HTMLElement - parent *HTMLDocument - children *common.LazyValue + logger *zerolog.Logger + client *cdp.Client + dom *Manager + input *input.Manager + exec *eval.ExecutionContext + frameTree page.FrameTree + element *HTMLElement } func LoadRootHTMLDocument( ctx context.Context, logger *zerolog.Logger, client *cdp.Client, - events *events.EventBroker, + domManager *Manager, mouse *input.Mouse, keyboard *input.Keyboard, ) (*HTMLDocument, error) { @@ -64,13 +62,12 @@ func LoadRootHTMLDocument( ctx, logger, client, - events, + domManager, mouse, keyboard, gdRepl.Root, ftRepl.FrameTree, worldRepl.ExecutionContextID, - nil, ) } @@ -78,22 +75,21 @@ func LoadHTMLDocument( ctx context.Context, logger *zerolog.Logger, client *cdp.Client, - events *events.EventBroker, + domManager *Manager, mouse *input.Mouse, keyboard *input.Keyboard, node dom.Node, - tree page.FrameTree, + frameTree page.FrameTree, execID runtime.ExecutionContextID, - parent *HTMLDocument, ) (*HTMLDocument, error) { - exec := eval.NewExecutionContext(client, tree.Frame, execID) + exec := eval.NewExecutionContext(client, frameTree.Frame, execID) inputManager := input.NewManager(client, exec, keyboard, mouse) rootElement, err := LoadHTMLElement( ctx, logger, client, - events, + domManager, inputManager, exec, node.NodeID, @@ -106,35 +102,31 @@ func LoadHTMLDocument( return NewHTMLDocument( logger, client, - events, + domManager, inputManager, exec, rootElement, - tree, - parent, + frameTree, ), nil } func NewHTMLDocument( logger *zerolog.Logger, client *cdp.Client, - events *events.EventBroker, + domManager *Manager, input *input.Manager, exec *eval.ExecutionContext, rootElement *HTMLElement, frames page.FrameTree, - parent *HTMLDocument, ) *HTMLDocument { doc := new(HTMLDocument) doc.logger = logger doc.client = client - doc.events = events + doc.dom = domManager doc.input = input doc.exec = exec doc.element = rootElement - doc.frames = frames - doc.parent = parent - doc.children = common.NewLazyValue(doc.loadChildren) + doc.frameTree = frames return doc } @@ -148,7 +140,7 @@ func (doc *HTMLDocument) Type() core.Type { } func (doc *HTMLDocument) String() string { - return doc.frames.Frame.URL + return doc.frameTree.Frame.URL } func (doc *HTMLDocument) Unwrap() interface{} { @@ -160,8 +152,8 @@ func (doc *HTMLDocument) Hash() uint64 { h.Write([]byte(doc.Type().String())) h.Write([]byte(":")) - h.Write([]byte(doc.frames.Frame.ID)) - h.Write([]byte(doc.frames.Frame.URL)) + h.Write([]byte(doc.frameTree.Frame.ID)) + h.Write([]byte(doc.frameTree.Frame.URL)) return h.Sum64() } @@ -175,7 +167,7 @@ func (doc *HTMLDocument) Compare(other core.Value) int64 { case drivers.HTMLDocumentType: other := other.(drivers.HTMLDocument) - return values.NewString(doc.frames.Frame.URL).Compare(other.GetURL()) + return values.NewString(doc.frameTree.Frame.URL).Compare(other.GetURL()) default: return drivers.Compare(doc.Type(), other.Type()) } @@ -194,41 +186,11 @@ func (doc *HTMLDocument) SetIn(ctx context.Context, path []core.Value, value cor } func (doc *HTMLDocument) Close() error { - errs := make([]error, 0, 5) + return doc.element.Close() +} - if doc.children.Ready() { - val, err := doc.children.Read(context.Background()) - - if err == nil { - arr := val.(*values.Array) - - arr.ForEach(func(value core.Value, _ int) bool { - doc := value.(drivers.HTMLDocument) - - err := doc.Close() - - if err != nil { - errs = append(errs, errors.Wrapf(err, "failed to close nested document: %s", doc.GetURL())) - } - - return true - }) - } else { - errs = append(errs, err) - } - } - - err := doc.element.Close() - - if err != nil { - errs = append(errs, err) - } - - if len(errs) == 0 { - return nil - } - - return core.Errors(errs...) +func (doc *HTMLDocument) Frame() page.FrameTree { + return doc.frameTree } func (doc *HTMLDocument) IsDetached() values.Boolean { @@ -280,25 +242,37 @@ func (doc *HTMLDocument) GetTitle() values.String { } func (doc *HTMLDocument) GetName() values.String { - if doc.frames.Frame.Name != nil { - return values.NewString(*doc.frames.Frame.Name) + if doc.frameTree.Frame.Name != nil { + return values.NewString(*doc.frameTree.Frame.Name) } return values.EmptyString } -func (doc *HTMLDocument) GetParentDocument() drivers.HTMLDocument { - return doc.parent +func (doc *HTMLDocument) GetParentDocument(ctx context.Context) (drivers.HTMLDocument, error) { + if doc.frameTree.Frame.ParentID == nil { + return nil, nil + } + + return doc.dom.GetFrameNode(ctx, *doc.frameTree.Frame.ParentID) } func (doc *HTMLDocument) GetChildDocuments(ctx context.Context) (*values.Array, error) { - children, err := doc.children.Read(ctx) + arr := values.NewArray(len(doc.frameTree.ChildFrames)) - if err != nil { - return values.NewArray(0), errors.Wrap(err, "failed to load child documents") + for _, childFrame := range doc.frameTree.ChildFrames { + frame, err := doc.dom.GetFrameNode(ctx, childFrame.Frame.ID) + + if err != nil { + return nil, err + } + + if frame != nil { + arr.Push(frame) + } } - return children.Copy().(*values.Array), nil + return arr, nil } func (doc *HTMLDocument) XPath(ctx context.Context, expression values.String) (core.Value, error) { @@ -314,7 +288,7 @@ func (doc *HTMLDocument) GetElement() drivers.HTMLElement { } func (doc *HTMLDocument) GetURL() values.String { - return values.NewString(doc.frames.Frame.URL) + return values.NewString(doc.frameTree.Frame.URL) } func (doc *HTMLDocument) MoveMouseByXY(ctx context.Context, x, y values.Float) error { @@ -484,48 +458,13 @@ func (doc *HTMLDocument) ScrollByXY(ctx context.Context, x, y values.Float) erro return doc.input.ScrollByXY(ctx, float64(x), float64(y)) } -func (doc *HTMLDocument) loadChildren(ctx context.Context) (value core.Value, e error) { - children := values.NewArray(len(doc.frames.ChildFrames)) - - if len(doc.frames.ChildFrames) > 0 { - for _, cf := range doc.frames.ChildFrames { - cfNode, cfExecID, err := resolveFrame(ctx, doc.client, cf.Frame) - - if err != nil { - return nil, errors.Wrap(err, "failed to resolve frame node") - } - - cfDocument, err := LoadHTMLDocument( - ctx, - doc.logger, - doc.client, - doc.events, - doc.input.Mouse(), - doc.input.Keyboard(), - cfNode, - cf, - cfExecID, - doc, - ) - - if err != nil { - return nil, errors.Wrap(err, "failed to load frame document") - } - - children.Push(cfDocument) - } - } - - return children, nil -} - func (doc *HTMLDocument) logError(err error) *zerolog.Event { return doc.logger. Error(). Timestamp(). - Str("url", string(doc.frames.Frame.URL)). - Str("securityOrigin", string(doc.frames.Frame.SecurityOrigin)). - Str("mimeType", string(doc.frames.Frame.MimeType)). - Str("frameID", string(doc.frames.Frame.ID)). + Str("url", doc.frameTree.Frame.URL). + Str("securityOrigin", doc.frameTree.Frame.SecurityOrigin). + Str("mimeType", doc.frameTree.Frame.MimeType). + Str("frameID", string(doc.frameTree.Frame.ID)). Err(err) } diff --git a/pkg/drivers/cdp/dom/document_test.go b/pkg/drivers/cdp/dom/document_test.go new file mode 100644 index 00000000..56e941a4 --- /dev/null +++ b/pkg/drivers/cdp/dom/document_test.go @@ -0,0 +1 @@ +package dom diff --git a/pkg/drivers/cdp/element.go b/pkg/drivers/cdp/dom/element.go similarity index 91% rename from pkg/drivers/cdp/element.go rename to pkg/drivers/cdp/dom/element.go index 0be18442..9ea16ca8 100644 --- a/pkg/drivers/cdp/element.go +++ b/pkg/drivers/cdp/dom/element.go @@ -1,4 +1,4 @@ -package cdp +package dom import ( "context" @@ -36,11 +36,20 @@ type ( ObjectID runtime.RemoteObjectID } + elementListeners struct { + pageReload events.ListenerID + attrModified events.ListenerID + attrRemoved events.ListenerID + childNodeCountUpdated events.ListenerID + childNodeInserted events.ListenerID + childNodeRemoved events.ListenerID + } + HTMLElement struct { mu sync.Mutex logger *zerolog.Logger client *cdp.Client - events *events.EventBroker + dom *Manager input *input.Manager exec *eval.ExecutionContext connected values.Boolean @@ -53,6 +62,7 @@ type ( style *common.LazyValue children []HTMLElementIdentity loadedChildren *common.LazyValue + listeners *elementListeners } ) @@ -60,7 +70,7 @@ func LoadHTMLElement( ctx context.Context, logger *zerolog.Logger, client *cdp.Client, - broker *events.EventBroker, + domManager *Manager, input *input.Manager, exec *eval.ExecutionContext, nodeID dom.NodeID, @@ -86,7 +96,7 @@ func LoadHTMLElement( ctx, logger, client, - broker, + domManager, input, exec, HTMLElementIdentity{ @@ -100,7 +110,7 @@ func LoadHTMLElementWithID( ctx context.Context, logger *zerolog.Logger, client *cdp.Client, - broker *events.EventBroker, + domManager *Manager, input *input.Manager, exec *eval.ExecutionContext, id HTMLElementIdentity, @@ -120,7 +130,7 @@ func LoadHTMLElementWithID( return NewHTMLElement( logger, client, - broker, + domManager, input, exec, id, @@ -133,7 +143,7 @@ func LoadHTMLElementWithID( func NewHTMLElement( logger *zerolog.Logger, client *cdp.Client, - broker *events.EventBroker, + domManager *Manager, input *input.Manager, exec *eval.ExecutionContext, id HTMLElementIdentity, @@ -144,7 +154,7 @@ func NewHTMLElement( el := new(HTMLElement) el.logger = logger el.client = client - el.events = broker + el.dom = domManager el.input = input el.exec = exec el.connected = values.True @@ -157,13 +167,14 @@ func NewHTMLElement( el.style = common.NewLazyValue(el.parseStyle) el.loadedChildren = common.NewLazyValue(el.loadChildren) el.children = children - - broker.AddEventListener(events.EventReload, el.handlePageReload) - broker.AddEventListener(events.EventAttrModified, el.handleAttrModified) - broker.AddEventListener(events.EventAttrRemoved, el.handleAttrRemoved) - broker.AddEventListener(events.EventChildNodeCountUpdated, el.handleChildrenCountChanged) - broker.AddEventListener(events.EventChildNodeInserted, el.handleChildInserted) - broker.AddEventListener(events.EventChildNodeRemoved, el.handleChildRemoved) + el.listeners = &elementListeners{ + pageReload: domManager.AddDocumentUpdatedListener(el.handlePageReload), + attrModified: domManager.AddAttrModifiedListener(el.handleAttrModified), + attrRemoved: domManager.AddAttrRemovedListener(el.handleAttrRemoved), + childNodeCountUpdated: domManager.AddChildNodeCountUpdatedListener(el.handleChildrenCountChanged), + childNodeInserted: domManager.AddChildNodeInsertedListener(el.handleChildInserted), + childNodeRemoved: domManager.AddChildNodeRemovedListener(el.handleChildRemoved), + } return el } @@ -178,12 +189,13 @@ func (el *HTMLElement) Close() error { } el.connected = values.False - el.events.RemoveEventListener(events.EventReload, el.handlePageReload) - el.events.RemoveEventListener(events.EventAttrModified, el.handleAttrModified) - el.events.RemoveEventListener(events.EventAttrRemoved, el.handleAttrRemoved) - el.events.RemoveEventListener(events.EventChildNodeCountUpdated, el.handleChildrenCountChanged) - el.events.RemoveEventListener(events.EventChildNodeInserted, el.handleChildInserted) - el.events.RemoveEventListener(events.EventChildNodeRemoved, el.handleChildRemoved) + + el.dom.RemoveReloadListener(el.listeners.pageReload) + el.dom.RemoveAttrModifiedListener(el.listeners.attrModified) + el.dom.RemoveAttrRemovedListener(el.listeners.attrRemoved) + el.dom.RemoveChildNodeCountUpdatedListener(el.listeners.childNodeCountUpdated) + el.dom.RemoveChildNodeInsertedListener(el.listeners.childNodeInserted) + el.dom.RemoveChildNodeRemovedListener(el.listeners.childNodeRemoved) return nil } @@ -472,7 +484,15 @@ func (el *HTMLElement) QuerySelector(ctx context.Context, selector values.String return values.None, nil } - res, err := LoadHTMLElement(ctx, el.logger, el.client, el.events, el.input, el.exec, found.NodeID) + res, err := LoadHTMLElement( + ctx, + el.logger, + el.client, + el.dom, + el.input, + el.exec, + found.NodeID, + ) if err != nil { return values.None, nil @@ -504,7 +524,15 @@ func (el *HTMLElement) QuerySelectorAll(ctx context.Context, selector values.Str continue } - childEl, err := LoadHTMLElement(ctx, el.logger, el.client, el.events, el.input, el.exec, id) + childEl, err := LoadHTMLElement( + ctx, + el.logger, + el.client, + el.dom, + el.input, + el.exec, + id, + ) if err != nil { // close elements that are already loaded, but won't be used because of the error @@ -609,7 +637,7 @@ func (el *HTMLElement) XPath(ctx context.Context, expression values.String) (res ctx, el.logger, el.client, - el.events, + el.dom, el.input, el.exec, HTMLElementIdentity{ @@ -641,7 +669,7 @@ func (el *HTMLElement) XPath(ctx context.Context, expression values.String) (res ctx, el.logger, el.client, - el.events, + el.dom, el.input, el.exec, HTMLElementIdentity{ @@ -1155,7 +1183,7 @@ func (el *HTMLElement) loadChildren(ctx context.Context) (core.Value, error) { ctx, el.logger, el.client, - el.events, + el.dom, el.input, el.exec, childID.NodeID, @@ -1191,20 +1219,13 @@ func (el *HTMLElement) parseStyle(ctx context.Context) (core.Value, error) { return common.DeserializeStyles(value.(values.String)) } -func (el *HTMLElement) handlePageReload(_ context.Context, _ interface{}) { +func (el *HTMLElement) handlePageReload(_ context.Context) { el.Close() } -func (el *HTMLElement) handleAttrModified(ctx context.Context, message interface{}) { - reply, ok := message.(*dom.AttributeModifiedReply) - - // well.... - if !ok { - return - } - +func (el *HTMLElement) handleAttrModified(ctx context.Context, nodeID dom.NodeID, name, value string) { // it's not for this el - if reply.NodeID != el.id.NodeID { + if nodeID != el.id.NodeID { return } @@ -1225,7 +1246,7 @@ func (el *HTMLElement) handleAttrModified(ctx context.Context, message interface return } - if reply.Name == "style" { + if name == "style" { el.style.Reset() } @@ -1235,20 +1256,13 @@ func (el *HTMLElement) handleAttrModified(ctx context.Context, message interface return } - attrs.Set(values.NewString(reply.Name), values.NewString(reply.Value)) + attrs.Set(values.NewString(name), values.NewString(value)) }) } -func (el *HTMLElement) handleAttrRemoved(ctx context.Context, message interface{}) { - reply, ok := message.(*dom.AttributeRemovedReply) - - // well.... - if !ok { - return - } - +func (el *HTMLElement) handleAttrRemoved(ctx context.Context, nodeID dom.NodeID, name string) { // it's not for this el - if reply.NodeID != el.id.NodeID { + if nodeID != el.id.NodeID { return } @@ -1269,7 +1283,7 @@ func (el *HTMLElement) handleAttrRemoved(ctx context.Context, message interface{ return } - if reply.Name == "style" { + if name == "style" { el.style.Reset() } @@ -1279,18 +1293,12 @@ func (el *HTMLElement) handleAttrRemoved(ctx context.Context, message interface{ return } - attrs.Remove(values.NewString(reply.Name)) + attrs.Remove(values.NewString(name)) }) } -func (el *HTMLElement) handleChildrenCountChanged(ctx context.Context, message interface{}) { - reply, ok := message.(*dom.ChildNodeCountUpdatedReply) - - if !ok { - return - } - - if reply.NodeID != el.id.NodeID { +func (el *HTMLElement) handleChildrenCountChanged(ctx context.Context, nodeID dom.NodeID, _ int) { + if nodeID != el.id.NodeID { return } @@ -1315,20 +1323,14 @@ func (el *HTMLElement) handleChildrenCountChanged(ctx context.Context, message i el.children = createChildrenArray(node.Node.Children) } -func (el *HTMLElement) handleChildInserted(ctx context.Context, message interface{}) { - reply, ok := message.(*dom.ChildNodeInsertedReply) - - if !ok { - return - } - - if reply.ParentNodeID != el.id.NodeID { +func (el *HTMLElement) handleChildInserted(ctx context.Context, parentNodeID, prevNodeID dom.NodeID, node dom.Node) { + if parentNodeID != el.id.NodeID { return } targetIDx := -1 - prevID := reply.PreviousNodeID - nextID := reply.Node.NodeID + prevID := prevNodeID + nextID := node.NodeID if el.IsDetached() { return @@ -1349,7 +1351,7 @@ func (el *HTMLElement) handleChildInserted(ctx context.Context, message interfac } nextIdentity := HTMLElementIdentity{ - NodeID: reply.Node.NodeID, + NodeID: nextID, } arr := el.children @@ -1361,7 +1363,15 @@ func (el *HTMLElement) handleChildInserted(ctx context.Context, message interfac el.loadedChildren.Mutate(ctx, func(v core.Value, _ error) { loadedArr := v.(*values.Array) - loadedEl, err := LoadHTMLElement(ctx, el.logger, el.client, el.events, el.input, el.exec, nextID) + loadedEl, err := LoadHTMLElement( + ctx, + el.logger, + el.client, + el.dom, + el.input, + el.exec, + nextID, + ) if err != nil { el.logError(err).Msg("failed to load an inserted element") @@ -1376,19 +1386,13 @@ func (el *HTMLElement) handleChildInserted(ctx context.Context, message interfac }) } -func (el *HTMLElement) handleChildRemoved(ctx context.Context, message interface{}) { - reply, ok := message.(*dom.ChildNodeRemovedReply) - - if !ok { - return - } - - if reply.ParentNodeID != el.id.NodeID { +func (el *HTMLElement) handleChildRemoved(ctx context.Context, nodeID, prevNodeID dom.NodeID) { + if nodeID != el.id.NodeID { return } targetIDx := -1 - targetID := reply.NodeID + targetID := prevNodeID if el.IsDetached() { return diff --git a/pkg/drivers/cdp/dom/helpers.go b/pkg/drivers/cdp/dom/helpers.go new file mode 100644 index 00000000..7a5a8d9c --- /dev/null +++ b/pkg/drivers/cdp/dom/helpers.go @@ -0,0 +1,287 @@ +package dom + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "golang.org/x/net/html" + "strings" + "time" + + "github.com/MontFerret/ferret/pkg/drivers/cdp/eval" + "github.com/MontFerret/ferret/pkg/drivers/cdp/templates" + "github.com/MontFerret/ferret/pkg/drivers/common" + "github.com/MontFerret/ferret/pkg/runtime/values" + + "github.com/PuerkitoBio/goquery" + "github.com/mafredri/cdp" + "github.com/mafredri/cdp/protocol/dom" + "github.com/mafredri/cdp/protocol/page" + "github.com/mafredri/cdp/protocol/runtime" +) + +var emptyExpires = time.Time{} + +// parseAttrs is a helper function that parses a given interleaved array of node attribute names and values, +// and returns an object that represents attribute keys and values. +func parseAttrs(attrs []string) *values.Object { + var attr values.String + + res := values.NewObject() + + for _, el := range attrs { + el = strings.TrimSpace(el) + str := values.NewString(el) + + if common.IsAttribute(el) { + attr = str + res.Set(str, values.EmptyString) + } else { + current, ok := res.Get(attr) + + if ok { + if current.String() != "" { + res.Set(attr, current.(values.String).Concat(values.SpaceString).Concat(str)) + } else { + res.Set(attr, str) + } + } + } + } + + return res +} + +func setInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerHTML values.String) error { + var objID *runtime.RemoteObjectID + + if id.ObjectID != "" { + objID = &id.ObjectID + } else { + repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) + + if err != nil { + return err + } + + if repl.Object.ObjectID == nil { + return errors.New("unable to resolve node") + } + + objID = repl.Object.ObjectID + } + + b, err := json.Marshal(innerHTML.String()) + + if err != nil { + return err + } + + err = exec.EvalWithArguments(ctx, templates.SetInnerHTML(), + runtime.CallArgument{ + ObjectID: objID, + }, + runtime.CallArgument{ + Value: json.RawMessage(b), + }, + ) + + return err +} + +func getInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) { + // not a document + if nodeType != html.DocumentNode { + var objID runtime.RemoteObjectID + + if id.ObjectID != "" { + objID = id.ObjectID + } else { + repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) + + if err != nil { + return "", err + } + + if repl.Object.ObjectID == nil { + return "", errors.New("unable to resolve node") + } + + objID = *repl.Object.ObjectID + } + + res, err := exec.ReadProperty(ctx, objID, "innerHTML") + + if err != nil { + return "", err + } + + return values.NewString(res.String()), nil + } + + repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerHTML") + + if err != nil { + return "", err + } + + return values.NewString(repl.String()), nil +} + +func setInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerText values.String) error { + var objID *runtime.RemoteObjectID + + if id.ObjectID != "" { + objID = &id.ObjectID + } else { + repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) + + if err != nil { + return err + } + + if repl.Object.ObjectID == nil { + return errors.New("unable to resolve node") + } + + objID = repl.Object.ObjectID + } + + b, err := json.Marshal(innerText.String()) + + if err != nil { + return err + } + + err = exec.EvalWithArguments(ctx, templates.SetInnerText(), + runtime.CallArgument{ + ObjectID: objID, + }, + runtime.CallArgument{ + Value: json.RawMessage(b), + }, + ) + + return err +} + +func getInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) { + // not a document + if nodeType != html.DocumentNode { + var objID runtime.RemoteObjectID + + if id.ObjectID != "" { + objID = id.ObjectID + } else { + repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) + + if err != nil { + return "", err + } + + if repl.Object.ObjectID == nil { + return "", errors.New("unable to resolve node") + } + + objID = *repl.Object.ObjectID + } + + res, err := exec.ReadProperty(ctx, objID, "innerText") + + if err != nil { + return "", err + } + + return values.NewString(res.String()), err + } + + repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerText") + + if err != nil { + return "", err + } + + return values.NewString(repl.String()), nil +} + +func parseInnerText(innerHTML string) (values.String, error) { + buff := bytes.NewBuffer([]byte(innerHTML)) + + parsed, err := goquery.NewDocumentFromReader(buff) + + if err != nil { + return values.EmptyString, err + } + + return values.NewString(parsed.Text()), nil +} + +func createChildrenArray(nodes []dom.Node) []HTMLElementIdentity { + children := make([]HTMLElementIdentity, len(nodes)) + + for idx, child := range nodes { + child := child + children[idx] = HTMLElementIdentity{ + NodeID: child.NodeID, + } + } + + return children +} + +func resolveFrame(ctx context.Context, client *cdp.Client, frameID page.FrameID) (dom.Node, runtime.ExecutionContextID, error) { + worldRepl, err := client.Page.CreateIsolatedWorld(ctx, page.NewCreateIsolatedWorldArgs(frameID)) + + if err != nil { + return dom.Node{}, -1, err + } + + evalRes, err := client.Runtime.Evaluate( + ctx, + runtime.NewEvaluateArgs(eval.PrepareEval("return document")). + SetContextID(worldRepl.ExecutionContextID), + ) + + if err != nil { + return dom.Node{}, -1, err + } + + if evalRes.ExceptionDetails != nil { + exception := *evalRes.ExceptionDetails + + return dom.Node{}, -1, errors.New(exception.Text) + } + + if evalRes.Result.ObjectID == nil { + return dom.Node{}, -1, errors.New("failed to resolve frame document") + } + + req, err := client.DOM.RequestNode(ctx, dom.NewRequestNodeArgs(*evalRes.Result.ObjectID)) + + if err != nil { + return dom.Node{}, -1, err + } + + if req.NodeID == 0 { + return dom.Node{}, -1, errors.New("framed document is resolved with empty node id") + } + + desc, err := client.DOM.DescribeNode( + ctx, + dom. + NewDescribeNodeArgs(). + SetNodeID(req.NodeID). + SetDepth(1), + ) + + if err != nil { + return dom.Node{}, -1, err + } + + // Returned node, by some reason, does not contain the NodeID + // So, we have to set it manually + desc.Node.NodeID = req.NodeID + + return desc.Node, worldRepl.ExecutionContextID, nil +} diff --git a/pkg/drivers/cdp/dom/manager.go b/pkg/drivers/cdp/dom/manager.go new file mode 100644 index 00000000..2e677514 --- /dev/null +++ b/pkg/drivers/cdp/dom/manager.go @@ -0,0 +1,511 @@ +package dom + +import ( + "context" + "github.com/MontFerret/ferret/pkg/drivers/cdp/input" + "github.com/MontFerret/ferret/pkg/drivers/common" + "github.com/MontFerret/ferret/pkg/runtime/core" + "github.com/MontFerret/ferret/pkg/runtime/values" + "github.com/mafredri/cdp/protocol/page" + "github.com/pkg/errors" + "github.com/rs/zerolog" + "io" + "sync" + + "github.com/mafredri/cdp" + "github.com/mafredri/cdp/protocol/dom" + "github.com/mafredri/cdp/rpcc" + + "github.com/MontFerret/ferret/pkg/drivers/cdp/events" +) + +var ( + eventDocumentUpdated = events.New("doc_updated") + eventAttrModified = events.New("attr_modified") + eventAttrRemoved = events.New("attr_removed") + eventChildNodeCountUpdated = events.New("child_count_updated") + eventChildNodeInserted = events.New("child_inserted") + eventChildNodeRemoved = events.New("child_removed") +) + +type ( + DocumentUpdatedListener func(ctx context.Context) + + AttrModifiedListener func(ctx context.Context, nodeID dom.NodeID, name, value string) + + AttrRemovedListener func(ctx context.Context, nodeID dom.NodeID, name string) + + ChildNodeCountUpdatedListener func(ctx context.Context, nodeID dom.NodeID, count int) + + ChildNodeInsertedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID, node dom.Node) + + ChildNodeRemovedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID) + + Frame struct { + tree page.FrameTree + node *HTMLDocument + ready bool + } + + Manager struct { + mu sync.Mutex + logger *zerolog.Logger + client *cdp.Client + events *events.Loop + mouse *input.Mouse + keyboard *input.Keyboard + mainFrame page.FrameID + frames map[page.FrameID]Frame + cancel context.CancelFunc + } +) + +// a dirty workaround to let pass the vet test +func createContext() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) +} + +func New( + logger *zerolog.Logger, + client *cdp.Client, + eventLoop *events.Loop, + mouse *input.Mouse, + keyboard *input.Keyboard, +) (manager *Manager, err error) { + ctx, cancel := createContext() + + closers := make([]io.Closer, 0, 10) + + defer func() { + if err != nil { + common.CloseAll(logger, closers, "failed to close a DOM event stream") + } + }() + + onContentReady, err := client.Page.DOMContentEventFired(ctx) + + if err != nil { + return nil, err + } + + closers = append(closers, onContentReady) + + onDocUpdated, err := client.DOM.DocumentUpdated(ctx) + + if err != nil { + return nil, err + } + + closers = append(closers, onDocUpdated) + + onAttrModified, err := client.DOM.AttributeModified(ctx) + + if err != nil { + return nil, err + } + + closers = append(closers, onAttrModified) + + onAttrRemoved, err := client.DOM.AttributeRemoved(ctx) + + if err != nil { + return nil, err + } + + closers = append(closers, onAttrRemoved) + + onChildCountUpdated, err := client.DOM.ChildNodeCountUpdated(ctx) + + if err != nil { + return nil, err + } + + closers = append(closers, onChildCountUpdated) + + onChildNodeInserted, err := client.DOM.ChildNodeInserted(ctx) + + if err != nil { + return nil, err + } + + closers = append(closers, onChildNodeInserted) + + onChildNodeRemoved, err := client.DOM.ChildNodeRemoved(ctx) + + if err != nil { + return nil, err + } + + closers = append(closers, onChildNodeRemoved) + + eventLoop.AddSource(events.NewSource(eventDocumentUpdated, onDocUpdated, func(stream rpcc.Stream) (i interface{}, e error) { + return stream.(dom.DocumentUpdatedClient).Recv() + })) + + eventLoop.AddSource(events.NewSource(eventAttrModified, onAttrModified, func(stream rpcc.Stream) (i interface{}, e error) { + return stream.(dom.AttributeModifiedClient).Recv() + })) + + eventLoop.AddSource(events.NewSource(eventAttrRemoved, onAttrRemoved, func(stream rpcc.Stream) (i interface{}, e error) { + return stream.(dom.AttributeRemovedClient).Recv() + })) + + eventLoop.AddSource(events.NewSource(eventChildNodeCountUpdated, onChildCountUpdated, func(stream rpcc.Stream) (i interface{}, e error) { + return stream.(dom.ChildNodeCountUpdatedClient).Recv() + })) + + eventLoop.AddSource(events.NewSource(eventChildNodeInserted, onChildNodeInserted, func(stream rpcc.Stream) (i interface{}, e error) { + return stream.(dom.ChildNodeInsertedClient).Recv() + })) + + eventLoop.AddSource(events.NewSource(eventChildNodeRemoved, onChildNodeRemoved, func(stream rpcc.Stream) (i interface{}, e error) { + return stream.(dom.ChildNodeRemovedClient).Recv() + })) + + manager = new(Manager) + manager.logger = logger + manager.client = client + manager.events = eventLoop + manager.mouse = mouse + manager.keyboard = keyboard + manager.frames = make(map[page.FrameID]Frame) + manager.cancel = cancel + + return manager, nil +} + +func (m *Manager) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + m.cancel() + m.cancel = nil + } + + errs := make([]error, 0, len(m.frames)) + + for _, f := range m.frames { + // if initialized + if f.node != nil { + if err := f.node.Close(); err != nil { + errs = append(errs, err) + } + } + } + + if len(errs) > 0 { + return core.Errors(errs...) + } + + return nil +} + +func (m *Manager) GetMainFrame() *HTMLDocument { + m.mu.Lock() + defer m.mu.Unlock() + + if m.mainFrame == "" { + return nil + } + + mainFrame, exists := m.frames[m.mainFrame] + + if exists { + return mainFrame.node + } + + return nil +} + +func (m *Manager) SetMainFrame(doc *HTMLDocument) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.mainFrame != "" { + if err := m.removeFrameRecursivelyInternal(m.mainFrame); err != nil { + m.logger.Error().Err(err).Msg("failed to close previous main frame") + } + } + + m.mainFrame = doc.frameTree.Frame.ID + + m.addPreloadedFrame(doc) +} + +func (m *Manager) AddFrame(frame page.FrameTree) { + m.mu.Lock() + defer m.mu.Unlock() + + m.addFrameInternal(frame) +} + +func (m *Manager) RemoveFrame(frameID page.FrameID) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.removeFrameInternal(frameID) +} + +func (m *Manager) RemoveFrameRecursively(frameID page.FrameID) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.removeFrameRecursivelyInternal(frameID) +} + +func (m *Manager) RemoveFramesByParentID(parentFrameID page.FrameID) error { + m.mu.Lock() + defer m.mu.Unlock() + + frame, found := m.frames[parentFrameID] + + if !found { + return errors.New("frame not found") + } + + for _, child := range frame.tree.ChildFrames { + if err := m.removeFrameRecursivelyInternal(child.Frame.ID); err != nil { + return err + } + } + + return nil +} + +func (m *Manager) GetFrameNode(ctx context.Context, frameID page.FrameID) (*HTMLDocument, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.getFrameInternal(ctx, frameID) +} + +func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.FrameTree, error) { + m.mu.Lock() + defer m.mu.Unlock() + + frame, found := m.frames[frameID] + + if !found { + return page.FrameTree{}, core.ErrNotFound + } + + return frame.tree, nil +} + +func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) { + m.mu.Lock() + defer m.mu.Unlock() + + arr := values.NewArray(len(m.frames)) + + for _, f := range m.frames { + doc, err := m.getFrameInternal(ctx, f.tree.Frame.ID) + + if err != nil { + return nil, err + } + + arr.Push(doc) + } + + return arr, nil +} + +func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) events.ListenerID { + return m.events.AddListener(eventDocumentUpdated, func(ctx context.Context, _ interface{}) bool { + listener(ctx) + + return true + }) +} + +func (m *Manager) RemoveReloadListener(listenerID events.ListenerID) { + m.events.RemoveListener(eventDocumentUpdated, listenerID) +} + +func (m *Manager) AddAttrModifiedListener(listener AttrModifiedListener) events.ListenerID { + return m.events.AddListener(eventAttrModified, func(ctx context.Context, message interface{}) bool { + reply := message.(*dom.AttributeModifiedReply) + + listener(ctx, reply.NodeID, reply.Name, reply.Value) + + return true + }) +} + +func (m *Manager) RemoveAttrModifiedListener(listenerID events.ListenerID) { + m.events.RemoveListener(eventAttrModified, listenerID) +} + +func (m *Manager) AddAttrRemovedListener(listener AttrRemovedListener) events.ListenerID { + return m.events.AddListener(eventAttrRemoved, func(ctx context.Context, message interface{}) bool { + reply := message.(*dom.AttributeRemovedReply) + + listener(ctx, reply.NodeID, reply.Name) + + return true + }) +} + +func (m *Manager) RemoveAttrRemovedListener(listenerID events.ListenerID) { + m.events.RemoveListener(eventAttrRemoved, listenerID) +} + +func (m *Manager) AddChildNodeCountUpdatedListener(listener ChildNodeCountUpdatedListener) events.ListenerID { + return m.events.AddListener(eventChildNodeCountUpdated, func(ctx context.Context, message interface{}) bool { + reply := message.(*dom.ChildNodeCountUpdatedReply) + + listener(ctx, reply.NodeID, reply.ChildNodeCount) + + return true + }) +} + +func (m *Manager) RemoveChildNodeCountUpdatedListener(listenerID events.ListenerID) { + m.events.RemoveListener(eventChildNodeCountUpdated, listenerID) +} + +func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListener) events.ListenerID { + return m.events.AddListener(eventChildNodeInserted, func(ctx context.Context, message interface{}) bool { + reply := message.(*dom.ChildNodeInsertedReply) + + listener(ctx, reply.ParentNodeID, reply.PreviousNodeID, reply.Node) + + return true + }) +} + +func (m *Manager) RemoveChildNodeInsertedListener(listenerID events.ListenerID) { + m.events.RemoveListener(eventChildNodeInserted, listenerID) +} + +func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener) events.ListenerID { + return m.events.AddListener(eventChildNodeRemoved, func(ctx context.Context, message interface{}) bool { + reply := message.(*dom.ChildNodeRemovedReply) + + listener(ctx, reply.ParentNodeID, reply.NodeID) + + return true + }) +} + +func (m *Manager) RemoveChildNodeRemovedListener(listenerID events.ListenerID) { + m.events.RemoveListener(eventChildNodeRemoved, listenerID) +} + +func (m *Manager) WaitForDOMReady(ctx context.Context) error { + onContentReady, err := m.client.Page.DOMContentEventFired(ctx) + + if err != nil { + return err + } + + defer func() { + if err := onContentReady.Close(); err != nil { + m.logger.Error().Err(err).Msg("failed to close DOM content ready stream event") + } + }() + + _, err = onContentReady.Recv() + + return err +} + +func (m *Manager) addFrameInternal(frame page.FrameTree) { + m.frames[frame.Frame.ID] = Frame{ + tree: frame, + node: nil, + } + + for _, child := range frame.ChildFrames { + m.addFrameInternal(child) + } +} + +func (m *Manager) addPreloadedFrame(doc *HTMLDocument) { + m.frames[doc.frameTree.Frame.ID] = Frame{ + tree: doc.frameTree, + node: doc, + } + + for _, child := range doc.frameTree.ChildFrames { + m.addFrameInternal(child) + } +} + +func (m *Manager) getFrameInternal(ctx context.Context, frameID page.FrameID) (*HTMLDocument, error) { + frame, found := m.frames[frameID] + + if !found { + return nil, core.ErrNotFound + } + + // frame is initialized + if frame.node != nil { + return frame.node, nil + } + + // the frames is not loaded yet + node, execID, err := resolveFrame(ctx, m.client, frameID) + + if err != nil { + return nil, errors.Wrap(err, "failed to resolve frame node") + } + + doc, err := LoadHTMLDocument( + ctx, + m.logger, + m.client, + m, + m.mouse, + m.keyboard, + node, + frame.tree, + execID, + ) + + if err != nil { + return nil, errors.Wrap(err, "failed to load frame document") + } + + frame.node = doc + + return doc, nil +} + +func (m *Manager) removeFrameInternal(frameID page.FrameID) error { + current, exists := m.frames[frameID] + + if !exists { + return core.Error(core.ErrNotFound, "frame") + } + + delete(m.frames, frameID) + + if frameID == m.mainFrame { + m.mainFrame = "" + } + + if current.node == nil { + return nil + } + + return current.node.Close() +} + +func (m *Manager) removeFrameRecursivelyInternal(frameID page.FrameID) error { + parent, exists := m.frames[frameID] + + if !exists { + return core.Error(core.ErrNotFound, "frame") + } + + for _, child := range parent.tree.ChildFrames { + if err := m.removeFrameRecursivelyInternal(child.Frame.ID); err != nil { + return err + } + } + + return m.removeFrameInternal(frameID) +} diff --git a/pkg/drivers/cdp/events/broker.go b/pkg/drivers/cdp/events/broker.go deleted file mode 100644 index 475fbdfa..00000000 --- a/pkg/drivers/cdp/events/broker.go +++ /dev/null @@ -1,289 +0,0 @@ -package events - -import ( - "context" - "reflect" - "sync" - - "github.com/mafredri/cdp/protocol/dom" - "github.com/mafredri/cdp/protocol/page" - - "github.com/MontFerret/ferret/pkg/runtime/core" -) - -type ( - Event int - - EventListener func(ctx context.Context, message interface{}) - - EventBroker struct { - mu sync.Mutex - listeners map[Event][]EventListener - cancel context.CancelFunc - onLoad page.LoadEventFiredClient - onReload dom.DocumentUpdatedClient - onAttrModified dom.AttributeModifiedClient - onAttrRemoved dom.AttributeRemovedClient - onChildNodeCountUpdated dom.ChildNodeCountUpdatedClient - onChildNodeInserted dom.ChildNodeInsertedClient - onChildNodeRemoved dom.ChildNodeRemovedClient - } -) - -const ( - //revive:disable-next-line:var-declaration - EventError = Event(iota) - EventLoad - EventReload - EventAttrModified - EventAttrRemoved - EventChildNodeCountUpdated - EventChildNodeInserted - EventChildNodeRemoved -) - -func NewEventBroker( - onLoad page.LoadEventFiredClient, - onReload dom.DocumentUpdatedClient, - onAttrModified dom.AttributeModifiedClient, - onAttrRemoved dom.AttributeRemovedClient, - onChildNodeCountUpdated dom.ChildNodeCountUpdatedClient, - onChildNodeInserted dom.ChildNodeInsertedClient, - onChildNodeRemoved dom.ChildNodeRemovedClient, -) *EventBroker { - broker := new(EventBroker) - broker.listeners = make(map[Event][]EventListener) - broker.onLoad = onLoad - broker.onReload = onReload - broker.onAttrModified = onAttrModified - broker.onAttrRemoved = onAttrRemoved - broker.onChildNodeCountUpdated = onChildNodeCountUpdated - broker.onChildNodeInserted = onChildNodeInserted - broker.onChildNodeRemoved = onChildNodeRemoved - - return broker -} - -func (broker *EventBroker) AddEventListener(event Event, listener EventListener) { - broker.mu.Lock() - defer broker.mu.Unlock() - - listeners, ok := broker.listeners[event] - - if !ok { - listeners = make([]EventListener, 0, 5) - } - - broker.listeners[event] = append(listeners, listener) -} - -func (broker *EventBroker) RemoveEventListener(event Event, listener EventListener) { - broker.mu.Lock() - defer broker.mu.Unlock() - - idx := -1 - - listeners, ok := broker.listeners[event] - - if !ok { - return - } - - listenerPointer := reflect.ValueOf(listener).Pointer() - - for i, l := range listeners { - itemPointer := reflect.ValueOf(l).Pointer() - if itemPointer == listenerPointer { - idx = i - break - } - } - - if idx < 0 { - return - } - - var modifiedListeners []EventListener - - if len(listeners) > 1 { - modifiedListeners = append(listeners[:idx], listeners[idx+1:]...) - } else { - modifiedListeners = make([]EventListener, 0, 5) - } - - broker.listeners[event] = modifiedListeners -} - -func (broker *EventBroker) ListenerCount(event Event) int { - broker.mu.Lock() - defer broker.mu.Unlock() - - listeners, ok := broker.listeners[event] - - if !ok { - return 0 - } - - return len(listeners) -} - -func (broker *EventBroker) Start() error { - broker.mu.Lock() - defer broker.mu.Unlock() - - if broker.cancel != nil { - return core.Error(core.ErrInvalidOperation, "broker is already started") - } - - ctx, cancel := context.WithCancel(context.Background()) - - broker.cancel = cancel - - go broker.runLoop(ctx) - - return nil -} - -func (broker *EventBroker) Stop() error { - broker.mu.Lock() - defer broker.mu.Unlock() - - if broker.cancel == nil { - return core.Error(core.ErrInvalidOperation, "broker is already stopped") - } - - broker.cancel() - broker.cancel = nil - - return nil -} - -func (broker *EventBroker) Close() error { - broker.mu.Lock() - defer broker.mu.Unlock() - - if broker.cancel != nil { - broker.cancel() - broker.cancel = nil - } - - broker.onLoad.Close() - broker.onReload.Close() - broker.onAttrModified.Close() - broker.onAttrRemoved.Close() - broker.onChildNodeCountUpdated.Close() - broker.onChildNodeInserted.Close() - broker.onChildNodeRemoved.Close() - - return nil -} - -func (broker *EventBroker) StopAndClose() error { - err := broker.Stop() - - if err != nil { - return err - } - - return broker.Close() -} - -func (broker *EventBroker) runLoop(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - - case <-broker.onLoad.Ready(): - if ctxDone(ctx) { - return - } - - reply, err := broker.onLoad.Recv() - - broker.emit(ctx, EventLoad, reply, err) - case <-broker.onReload.Ready(): - if ctxDone(ctx) { - return - } - - reply, err := broker.onReload.Recv() - - broker.emit(ctx, EventReload, reply, err) - case <-broker.onAttrModified.Ready(): - if ctxDone(ctx) { - return - } - - reply, err := broker.onAttrModified.Recv() - - broker.emit(ctx, EventAttrModified, reply, err) - case <-broker.onAttrRemoved.Ready(): - if ctxDone(ctx) { - return - } - - reply, err := broker.onAttrRemoved.Recv() - - broker.emit(ctx, EventAttrRemoved, reply, err) - case <-broker.onChildNodeCountUpdated.Ready(): - if ctxDone(ctx) { - return - } - - reply, err := broker.onChildNodeCountUpdated.Recv() - - broker.emit(ctx, EventChildNodeCountUpdated, reply, err) - case <-broker.onChildNodeInserted.Ready(): - if ctxDone(ctx) { - return - } - - reply, err := broker.onChildNodeInserted.Recv() - - broker.emit(ctx, EventChildNodeInserted, reply, err) - case <-broker.onChildNodeRemoved.Ready(): - if ctxDone(ctx) { - return - } - - reply, err := broker.onChildNodeRemoved.Recv() - - broker.emit(ctx, EventChildNodeRemoved, reply, err) - } - } -} - -func ctxDone(ctx context.Context) bool { - return ctx.Err() == context.Canceled -} - -func (broker *EventBroker) emit(ctx context.Context, event Event, message interface{}, err error) { - if err != nil { - event = EventError - message = err - } - - broker.mu.Lock() - - listeners, ok := broker.listeners[event] - - if !ok { - broker.mu.Unlock() - return - } - - snapshot := make([]EventListener, len(listeners)) - copy(snapshot, listeners) - - broker.mu.Unlock() - - for _, listener := range snapshot { - select { - case <-ctx.Done(): - return - default: - listener(ctx, message) - } - } -} diff --git a/pkg/drivers/cdp/events/broker_test.go b/pkg/drivers/cdp/events/broker_test.go deleted file mode 100644 index a12e16f4..00000000 --- a/pkg/drivers/cdp/events/broker_test.go +++ /dev/null @@ -1,322 +0,0 @@ -package events_test - -import ( - "context" - "github.com/MontFerret/ferret/pkg/drivers/cdp/events" - "github.com/mafredri/cdp/protocol/dom" - "github.com/mafredri/cdp/protocol/page" - . "github.com/smartystreets/goconvey/convey" - "golang.org/x/sync/errgroup" - "sync/atomic" - "testing" - "time" -) - -type ( - TestEventStream struct { - ready chan struct{} - message chan interface{} - } - - TestLoadEventFiredClient struct { - *TestEventStream - } - - TestDocumentUpdatedClient struct { - *TestEventStream - } - - TestAttributeModifiedClient struct { - *TestEventStream - } - - TestAttributeRemovedClient struct { - *TestEventStream - } - - TestChildNodeCountUpdatedClient struct { - *TestEventStream - } - - TestChildNodeInsertedClient struct { - *TestEventStream - } - - TestChildNodeRemovedClient struct { - *TestEventStream - } - - TestBroker struct { - *events.EventBroker - OnLoad *TestLoadEventFiredClient - OnReload *TestDocumentUpdatedClient - OnAttrMod *TestAttributeModifiedClient - OnAttrRem *TestAttributeRemovedClient - OnChildNodeCount *TestChildNodeCountUpdatedClient - OnChildNodeIns *TestChildNodeInsertedClient - OnChildNodeRem *TestChildNodeRemovedClient - } -) - -func NewTestEventStream() *TestEventStream { - es := new(TestEventStream) - es.ready = make(chan struct{}) - es.message = make(chan interface{}) - return es -} - -func (es *TestEventStream) Ready() <-chan struct{} { - return es.ready -} - -func (es *TestEventStream) RecvMsg(i interface{}) error { - // NOT IMPLEMENTED - return nil -} - -func (es *TestEventStream) Close() error { - close(es.message) - close(es.ready) - return nil -} - -func (es *TestEventStream) Emit(msg interface{}) { - es.ready <- struct{}{} - es.message <- msg -} - -func (es *TestLoadEventFiredClient) Recv() (*page.LoadEventFiredReply, error) { - r := <-es.message - reply := r.(*page.LoadEventFiredReply) - - return reply, nil -} - -func (es *TestLoadEventFiredClient) EmitDefault() { - es.TestEventStream.Emit(&page.LoadEventFiredReply{}) -} - -func (es *TestDocumentUpdatedClient) Recv() (*dom.DocumentUpdatedReply, error) { - r := <-es.message - reply := r.(*dom.DocumentUpdatedReply) - - return reply, nil -} - -func (es *TestAttributeModifiedClient) Recv() (*dom.AttributeModifiedReply, error) { - r := <-es.message - reply := r.(*dom.AttributeModifiedReply) - - return reply, nil -} - -func (es *TestAttributeRemovedClient) Recv() (*dom.AttributeRemovedReply, error) { - r := <-es.message - reply := r.(*dom.AttributeRemovedReply) - - return reply, nil -} - -func (es *TestChildNodeCountUpdatedClient) Recv() (*dom.ChildNodeCountUpdatedReply, error) { - r := <-es.message - reply := r.(*dom.ChildNodeCountUpdatedReply) - - return reply, nil -} - -func (es *TestChildNodeInsertedClient) Recv() (*dom.ChildNodeInsertedReply, error) { - r := <-es.message - reply := r.(*dom.ChildNodeInsertedReply) - - return reply, nil -} - -func (es *TestChildNodeRemovedClient) Recv() (*dom.ChildNodeRemovedReply, error) { - r := <-es.message - reply := r.(*dom.ChildNodeRemovedReply) - - return reply, nil -} - -func NewTestEventBroker() *TestBroker { - onLoad := &TestLoadEventFiredClient{NewTestEventStream()} - onReload := &TestDocumentUpdatedClient{NewTestEventStream()} - onAttrMod := &TestAttributeModifiedClient{NewTestEventStream()} - onAttrRem := &TestAttributeRemovedClient{NewTestEventStream()} - onChildCount := &TestChildNodeCountUpdatedClient{NewTestEventStream()} - onChildIns := &TestChildNodeInsertedClient{NewTestEventStream()} - onChildRem := &TestChildNodeRemovedClient{NewTestEventStream()} - - b := events.NewEventBroker( - onLoad, - onReload, - onAttrMod, - onAttrRem, - onChildCount, - onChildIns, - onChildRem, - ) - - return &TestBroker{ - b, - onLoad, - onReload, - onAttrMod, - onAttrRem, - onChildCount, - onChildIns, - onChildRem, - } -} - -func StressTest(h func() error, count int) error { - var err error - - for i := 0; i < count; i++ { - err = h() - - if err != nil { - return err - } - } - - return nil -} - -func StressTestAsync(h func() error, count int) error { - var gr errgroup.Group - - for i := 0; i < count; i++ { - gr.Go(h) - } - - return gr.Wait() -} - -func TestEventBroker(t *testing.T) { - Convey(".AddEventListener", t, func() { - Convey("Should add a new listener when not started", func() { - b := NewTestEventBroker() - - StressTest(func() error { - b.AddEventListener(events.EventLoad, func(ctx context.Context, message interface{}) {}) - - return nil - }, 500) - }) - - Convey("Should add a new listener when started", func() { - b := NewTestEventBroker() - b.Start() - defer b.Stop() - - StressTest(func() error { - b.AddEventListener(events.EventLoad, func(ctx context.Context, message interface{}) {}) - - return nil - }, 500) - }) - }) - - Convey(".RemoveEventListener", t, func() { - Convey("Should remove a listener when not started", func() { - b := NewTestEventBroker() - - StressTest(func() error { - listener := func(ctx context.Context, message interface{}) {} - - b.AddEventListener(events.EventLoad, listener) - b.RemoveEventListener(events.EventLoad, listener) - - So(b.ListenerCount(events.EventLoad), ShouldEqual, 0) - - return nil - }, 500) - }) - - Convey("Should add a new listener when started", func() { - b := NewTestEventBroker() - b.Start() - defer b.Stop() - - StressTest(func() error { - listener := func(ctx context.Context, message interface{}) {} - - b.AddEventListener(events.EventLoad, listener) - - StressTestAsync(func() error { - b.OnLoad.EmitDefault() - - return nil - }, 250) - - b.RemoveEventListener(events.EventLoad, listener) - - So(b.ListenerCount(events.EventLoad), ShouldEqual, 0) - - return nil - }, 250) - - }) - - Convey("Should not call listener once it was removed", func() { - b := NewTestEventBroker() - b.Start() - defer b.Stop() - - counter := 0 - - var listener events.EventListener - - listener = func(ctx context.Context, message interface{}) { - counter++ - - b.RemoveEventListener(events.EventLoad, listener) - } - - b.AddEventListener(events.EventLoad, listener) - b.OnLoad.Emit(&page.LoadEventFiredReply{}) - - time.Sleep(time.Duration(10) * time.Millisecond) - - StressTestAsync(func() error { - b.OnLoad.Emit(&page.LoadEventFiredReply{}) - - return nil - }, 250) - - So(b.ListenerCount(events.EventLoad), ShouldEqual, 0) - So(counter, ShouldEqual, 1) - }) - }) - - Convey(".Stop", t, func() { - Convey("Should stop emitting events", func() { - b := NewTestEventBroker() - b.Start() - - var counter int64 - - b.AddEventListener(events.EventLoad, func(ctx context.Context, message interface{}) { - atomic.AddInt64(&counter, 1) - b.Stop() - }) - - b.OnLoad.EmitDefault() - - time.Sleep(time.Duration(5) * time.Millisecond) - - go func() { - b.OnLoad.EmitDefault() - }() - - go func() { - b.OnLoad.EmitDefault() - }() - - time.Sleep(time.Duration(5) * time.Millisecond) - - So(atomic.LoadInt64(&counter), ShouldEqual, 1) - }) - }) -} diff --git a/pkg/drivers/cdp/events/helpers.go b/pkg/drivers/cdp/events/helpers.go index 58e09770..c44b6439 100644 --- a/pkg/drivers/cdp/events/helpers.go +++ b/pkg/drivers/cdp/events/helpers.go @@ -2,125 +2,17 @@ package events import ( "context" - - "github.com/mafredri/cdp" - "github.com/mafredri/cdp/protocol/dom" - "github.com/mafredri/cdp/protocol/page" - "github.com/pkg/errors" + "hash/fnv" ) -func WaitForLoadEvent(ctx context.Context, client *cdp.Client) error { - loadEventFired, err := client.Page.LoadEventFired(ctx) +func New(name string) ID { + h := fnv.New32a() - if err != nil { - return errors.Wrap(err, "failed to create load event hook") - } + h.Write([]byte(name)) - _, err = loadEventFired.Recv() - - if err != nil { - return err - } - - return loadEventFired.Close() + return ID(h.Sum32()) } -func CreateEventBroker(client *cdp.Client) (*EventBroker, error) { - var err error - var onLoad page.LoadEventFiredClient - var onReload dom.DocumentUpdatedClient - var onAttrModified dom.AttributeModifiedClient - var onAttrRemoved dom.AttributeRemovedClient - var onChildCountUpdated dom.ChildNodeCountUpdatedClient - var onChildNodeInserted dom.ChildNodeInsertedClient - var onChildNodeRemoved dom.ChildNodeRemovedClient - ctx := context.Background() - - onLoad, err = client.Page.LoadEventFired(ctx) - - if err != nil { - return nil, err - } - - onReload, err = client.DOM.DocumentUpdated(ctx) - - if err != nil { - onLoad.Close() - return nil, err - } - - onAttrModified, err = client.DOM.AttributeModified(ctx) - - if err != nil { - onLoad.Close() - onReload.Close() - return nil, err - } - - onAttrRemoved, err = client.DOM.AttributeRemoved(ctx) - - if err != nil { - onLoad.Close() - onReload.Close() - onAttrModified.Close() - return nil, err - } - - onChildCountUpdated, err = client.DOM.ChildNodeCountUpdated(ctx) - - if err != nil { - onLoad.Close() - onReload.Close() - onAttrModified.Close() - onAttrRemoved.Close() - return nil, err - } - - onChildNodeInserted, err = client.DOM.ChildNodeInserted(ctx) - - if err != nil { - onLoad.Close() - onReload.Close() - onAttrModified.Close() - onAttrRemoved.Close() - onChildCountUpdated.Close() - return nil, err - } - - onChildNodeRemoved, err = client.DOM.ChildNodeRemoved(ctx) - - if err != nil { - onLoad.Close() - onReload.Close() - onAttrModified.Close() - onAttrRemoved.Close() - onChildCountUpdated.Close() - onChildNodeInserted.Close() - return nil, err - } - - broker := NewEventBroker( - onLoad, - onReload, - onAttrModified, - onAttrRemoved, - onChildCountUpdated, - onChildNodeInserted, - onChildNodeRemoved, - ) - - err = broker.Start() - - if err != nil { - onLoad.Close() - onReload.Close() - onAttrModified.Close() - onAttrRemoved.Close() - onChildCountUpdated.Close() - onChildNodeInserted.Close() - onChildNodeRemoved.Close() - return nil, err - } - - return broker, nil +func isCtxDone(ctx context.Context) bool { + return ctx.Err() == context.Canceled } diff --git a/pkg/drivers/cdp/events/listener.go b/pkg/drivers/cdp/events/listener.go new file mode 100644 index 00000000..e7748f77 --- /dev/null +++ b/pkg/drivers/cdp/events/listener.go @@ -0,0 +1,38 @@ +package events + +import "context" + +type ( + // Handler represents a function that is called when a particular event occurs + // Returned boolean value indicates whether the handler needs to be called again + // False value indicated that it needs to be removed and never called again + Handler func(ctx context.Context, message interface{}) bool + + // ListenerID is an internal listener ID that can be used to unsubscribe from a particular event + ListenerID int + + // Listener is an internal listener representation + Listener struct { + ID ListenerID + EventID ID + Handler Handler + } +) + +// Always returns a handler wrapper that always gets executed by an event loop +func Always(fn func(ctx context.Context, message interface{})) Handler { + return func(ctx context.Context, message interface{}) bool { + fn(ctx, message) + + return true + } +} + +// Once returns a handler wrapper that gets executed only once by an event loop +func Once(fn func(ctx context.Context, message interface{})) Handler { + return func(ctx context.Context, message interface{}) bool { + fn(ctx, message) + + return false + } +} diff --git a/pkg/drivers/cdp/events/listeners.go b/pkg/drivers/cdp/events/listeners.go new file mode 100644 index 00000000..ec19995d --- /dev/null +++ b/pkg/drivers/cdp/events/listeners.go @@ -0,0 +1,74 @@ +package events + +import "sync" + +type ListenerCollection struct { + mu sync.Mutex + values map[ID]map[ListenerID]Listener +} + +func NewListenerCollection() *ListenerCollection { + lc := new(ListenerCollection) + lc.values = make(map[ID]map[ListenerID]Listener) + + return lc +} + +func (lc *ListenerCollection) Size(eventID ID) int { + lc.mu.Lock() + defer lc.mu.Unlock() + + bucket, exists := lc.values[eventID] + + if !exists { + return 0 + } + + return len(bucket) +} + +func (lc *ListenerCollection) Add(listener Listener) { + lc.mu.Lock() + defer lc.mu.Unlock() + + bucket, exists := lc.values[listener.EventID] + + if !exists { + bucket = make(map[ListenerID]Listener) + lc.values[listener.EventID] = bucket + } + + bucket[listener.ID] = listener +} + +func (lc *ListenerCollection) Remove(eventID ID, listenerID ListenerID) { + lc.mu.Lock() + defer lc.mu.Unlock() + + bucket, exists := lc.values[eventID] + + if !exists { + return + } + + delete(bucket, listenerID) +} + +func (lc *ListenerCollection) Values(eventID ID) []Listener { + lc.mu.Lock() + defer lc.mu.Unlock() + + bucket, exists := lc.values[eventID] + + if !exists { + return []Listener{} + } + + snapshot := make([]Listener, 0, len(bucket)) + + for _, listener := range bucket { + snapshot = append(snapshot, listener) + } + + return snapshot +} diff --git a/pkg/drivers/cdp/events/loop.go b/pkg/drivers/cdp/events/loop.go new file mode 100644 index 00000000..0af193c9 --- /dev/null +++ b/pkg/drivers/cdp/events/loop.go @@ -0,0 +1,165 @@ +package events + +import ( + "context" + "math/rand" + "sync" +) + +type Loop struct { + mu sync.Mutex + cancel context.CancelFunc + sources *SourceCollection + listeners *ListenerCollection +} + +func NewLoop() *Loop { + loop := new(Loop) + loop.sources = NewSourceCollection() + loop.listeners = NewListenerCollection() + + return loop +} + +func (loop *Loop) Start() *Loop { + loop.mu.Lock() + defer loop.mu.Unlock() + + if loop.cancel != nil { + return loop + } + + loopCtx, cancel := context.WithCancel(context.Background()) + + loop.cancel = cancel + + go loop.run(loopCtx) + + return loop +} + +func (loop *Loop) Stop() *Loop { + loop.mu.Lock() + defer loop.mu.Unlock() + + if loop.cancel == nil { + return loop + } + + loop.cancel() + loop.cancel = nil + + return loop +} + +func (loop *Loop) Close() error { + loop.mu.Lock() + defer loop.mu.Unlock() + + if loop.cancel != nil { + loop.cancel() + loop.cancel = nil + } + + return loop.sources.Close() +} + +func (loop *Loop) AddSource(source Source) { + loop.sources.Add(source) +} + +func (loop *Loop) RemoveSource(source Source) { + loop.sources.Remove(source) +} + +func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID { + listener := Listener{ + ID: ListenerID(rand.Int()), + EventID: eventID, + Handler: handler, + } + + loop.listeners.Add(listener) + + return listener.ID +} + +func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) { + loop.listeners.Remove(eventID, listenerID) +} + +// run starts running an event loop. +// It constantly iterates over each event source. +// Additionally to that, on each iteration it checks the command channel in order to perform add/remove listener/source operations. +func (loop *Loop) run(ctx context.Context) { + size := loop.sources.Size() + counter := -1 + + // in case event array is empty + // we use this mock noop event source to simplify the logic + noop := newNoopSource() + + for { + counter++ + + if counter >= size { + // reset the counter + size = loop.sources.Size() + counter = 0 + } + + var source Source + + if size > 0 { + found, err := loop.sources.Get(counter) + + if err == nil { + source = found + } else { + // might be removed + source = noop + // force to reset counter + counter = size + } + } else { + source = noop + } + + // commands have higher priority + select { + case <-ctx.Done(): + return + case <-source.Ready(): + if isCtxDone(ctx) { + return + } + + event, err := source.Recv() + + loop.emit(ctx, event.ID, event.Data, err) + default: + continue + } + } +} + +func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err error) { + if err != nil { + eventID = Error + message = err + } + + snapshot := loop.listeners.Values(eventID) + + for _, listener := range snapshot { + select { + case <-ctx.Done(): + return + default: + // if returned false, it means the loops should call the handler anymore + if !listener.Handler(ctx, message) { + loop.listeners.Remove(eventID, listener.ID) + } + } + } +} diff --git a/pkg/drivers/cdp/events/loop_test.go b/pkg/drivers/cdp/events/loop_test.go new file mode 100644 index 00000000..368c2801 --- /dev/null +++ b/pkg/drivers/cdp/events/loop_test.go @@ -0,0 +1,426 @@ +package events_test + +import ( + "context" + "github.com/MontFerret/ferret/pkg/drivers/cdp/events" + "github.com/mafredri/cdp/protocol/dom" + "github.com/mafredri/cdp/protocol/page" + "github.com/mafredri/cdp/rpcc" + . "github.com/smartystreets/goconvey/convey" + "sync" + "testing" + "time" +) + +type ( + TestEventStream struct { + ready chan struct{} + message chan interface{} + } + + TestLoadEventFiredClient struct { + *TestEventStream + } + + TestDocumentUpdatedClient struct { + *TestEventStream + } + + TestAttributeModifiedClient struct { + *TestEventStream + } + + TestAttributeRemovedClient struct { + *TestEventStream + } + + TestChildNodeCountUpdatedClient struct { + *TestEventStream + } + + TestChildNodeInsertedClient struct { + *TestEventStream + } + + TestChildNodeRemovedClient struct { + *TestEventStream + } +) + +var TestEvent = events.New("test_event") + +func NewTestEventStream() *TestEventStream { + es := new(TestEventStream) + es.ready = make(chan struct{}) + es.message = make(chan interface{}) + return es +} + +func (es *TestEventStream) Ready() <-chan struct{} { + return es.ready +} + +func (es *TestEventStream) RecvMsg(i interface{}) error { + // NOT IMPLEMENTED + return nil +} + +func (es *TestEventStream) Close() error { + close(es.message) + close(es.ready) + return nil +} + +func (es *TestEventStream) Emit(msg interface{}) { + es.ready <- struct{}{} + es.message <- msg +} + +func (es *TestLoadEventFiredClient) Recv() (*page.LoadEventFiredReply, error) { + r := <-es.message + reply := r.(*page.LoadEventFiredReply) + + return reply, nil +} + +func (es *TestLoadEventFiredClient) EmitDefault() { + es.TestEventStream.Emit(&page.LoadEventFiredReply{}) +} + +func (es *TestDocumentUpdatedClient) Recv() (*dom.DocumentUpdatedReply, error) { + r := <-es.message + reply := r.(*dom.DocumentUpdatedReply) + + return reply, nil +} + +func (es *TestAttributeModifiedClient) Recv() (*dom.AttributeModifiedReply, error) { + r := <-es.message + reply := r.(*dom.AttributeModifiedReply) + + return reply, nil +} + +func (es *TestAttributeRemovedClient) Recv() (*dom.AttributeRemovedReply, error) { + r := <-es.message + reply := r.(*dom.AttributeRemovedReply) + + return reply, nil +} + +func (es *TestChildNodeCountUpdatedClient) Recv() (*dom.ChildNodeCountUpdatedReply, error) { + r := <-es.message + reply := r.(*dom.ChildNodeCountUpdatedReply) + + return reply, nil +} + +func (es *TestChildNodeInsertedClient) Recv() (*dom.ChildNodeInsertedReply, error) { + r := <-es.message + reply := r.(*dom.ChildNodeInsertedReply) + + return reply, nil +} + +func (es *TestChildNodeRemovedClient) Recv() (*dom.ChildNodeRemovedReply, error) { + r := <-es.message + reply := r.(*dom.ChildNodeRemovedReply) + + return reply, nil +} + +func wait() { + time.Sleep(time.Duration(50) * time.Millisecond) +} + +type Counter struct { + mu sync.Mutex + value int64 +} + +func NewCounter() *Counter { + return new(Counter) +} + +func (c *Counter) Increase() *Counter { + c.mu.Lock() + defer c.mu.Unlock() + + c.value++ + + return c +} + +func (c *Counter) Decrease() *Counter { + c.mu.Lock() + defer c.mu.Unlock() + + c.value-- + + return c +} + +func (c *Counter) Value() int64 { + c.mu.Lock() + defer c.mu.Unlock() + + return c.value +} + +func TestLoop(t *testing.T) { + Convey(".AddListener", t, func() { + Convey("Should add a new listener", func() { + loop := events.NewLoop() + counter := NewCounter() + + onLoad := &TestLoadEventFiredClient{NewTestEventStream()} + src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) { + return onLoad.Recv() + }) + + loop.AddSource(src) + + loop.Start() + defer loop.Stop() + + onLoad.EmitDefault() + + wait() + + So(counter.Value(), ShouldEqual, 0) + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + counter.Increase() + })) + + wait() + + onLoad.EmitDefault() + + wait() + + So(counter.Value(), ShouldEqual, 1) + }) + }) + + Convey(".RemoveListener", t, func() { + Convey("Should remove a listener", func() { + Convey("Should add a new listener", func() { + loop := events.NewLoop() + counter := NewCounter() + + onLoad := &TestLoadEventFiredClient{NewTestEventStream()} + src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) { + return onLoad.Recv() + }) + + loop.AddSource(src) + id := loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + counter.Increase() + })) + + loop.Start() + defer loop.Stop() + + onLoad.EmitDefault() + + wait() + + So(counter.Value(), ShouldEqual, 1) + + wait() + + loop.RemoveListener(TestEvent, id) + + wait() + + onLoad.EmitDefault() + + wait() + + So(counter.Value(), ShouldEqual, 1) + }) + }) + }) + + Convey(".AddSource", t, func() { + Convey("Should add a new event source when not started", func() { + loop := events.NewLoop() + counter := NewCounter() + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + counter.Increase() + })) + + loop.Start() + defer loop.Stop() + + onLoad := &TestLoadEventFiredClient{NewTestEventStream()} + + go func() { + onLoad.EmitDefault() + }() + + wait() + + So(counter.Value(), ShouldEqual, 0) + + src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) { + return onLoad.Recv() + }) + + loop.AddSource(src) + + wait() + + So(counter.Value(), ShouldEqual, 1) + }) + }) + + Convey(".RemoveSource", t, func() { + Convey("Should remove a source", func() { + loop := events.NewLoop() + counter := NewCounter() + + loop.Start() + defer loop.Stop() + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + counter.Increase() + })) + + onLoad := &TestLoadEventFiredClient{NewTestEventStream()} + src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) { + return onLoad.Recv() + }) + + loop.AddSource(src) + + wait() + + onLoad.EmitDefault() + + wait() + + So(counter.Value(), ShouldEqual, 1) + + loop.RemoveSource(src) + + wait() + + go func() { + onLoad.EmitDefault() + }() + + wait() + + So(counter.Value(), ShouldEqual, 1) + }) + }) + + Convey("Should not call listener once it was removed", t, func() { + loop := events.NewLoop() + onEvent := make(chan struct{}) + + counter := NewCounter() + + id := loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + counter.Increase() + + onEvent <- struct{}{} + })) + + go func() { + <-onEvent + + loop.RemoveListener(TestEvent, id) + }() + + onLoad := &TestLoadEventFiredClient{NewTestEventStream()} + + loop.AddSource(events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) { + return onLoad.Recv() + })) + + loop.Start() + defer loop.Stop() + + time.Sleep(time.Duration(100) * time.Millisecond) + + onLoad.Emit(&page.LoadEventFiredReply{}) + + time.Sleep(time.Duration(10) * time.Millisecond) + + So(counter.Value(), ShouldEqual, 1) + }) +} + +func BenchmarkLoop_AddListenerSync(b *testing.B) { + loop := events.NewLoop() + + for n := 0; n < b.N; n++ { + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {})) + } +} + +func BenchmarkLoop_AddListenerAsync(b *testing.B) { + loop := events.NewLoop() + loop.Start() + defer loop.Stop() + + for n := 0; n < b.N; n++ { + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {})) + } +} + +func BenchmarkLoop_AddListenerAsync2(b *testing.B) { + loop := events.NewLoop() + loop.Start() + defer loop.Stop() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {})) + } + }) +} + +func BenchmarkLoop_Start(b *testing.B) { + loop := events.NewLoop() + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + + })) + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + + })) + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + + })) + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + + })) + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + + })) + + loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) { + + })) + + onLoad := &TestLoadEventFiredClient{NewTestEventStream()} + + loop.AddSource(events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) { + return onLoad.Recv() + })) + + loop.Start() + defer loop.Stop() + + for n := 0; n < b.N; n++ { + onLoad.Emit(&page.LoadEventFiredReply{}) + } +} diff --git a/pkg/drivers/cdp/events/noop.go b/pkg/drivers/cdp/events/noop.go new file mode 100644 index 00000000..40786449 --- /dev/null +++ b/pkg/drivers/cdp/events/noop.go @@ -0,0 +1,31 @@ +package events + +import ( + "github.com/MontFerret/ferret/pkg/runtime/core" +) + +type noopEvent struct { + c chan struct{} +} + +func newNoopSource() Source { + return noopEvent{ + c: make(chan struct{}), + } +} + +func (n noopEvent) Ready() <-chan struct{} { + return n.c +} + +func (n noopEvent) RecvMsg(_ interface{}) error { + return core.ErrNotSupported +} + +func (n noopEvent) Close() error { + return nil +} + +func (n noopEvent) Recv() (Event, error) { + return Event{}, core.ErrNotSupported +} diff --git a/pkg/drivers/cdp/events/source.go b/pkg/drivers/cdp/events/source.go new file mode 100644 index 00000000..9c827340 --- /dev/null +++ b/pkg/drivers/cdp/events/source.go @@ -0,0 +1,74 @@ +package events + +import ( + "github.com/mafredri/cdp/rpcc" +) + +type ( + // ID represents a unique event ID + ID int + + // Event represents a system event that is returned from an event source + Event struct { + ID ID + Data interface{} + } + + // Source represents a custom source of system events + Source interface { + rpcc.Stream + Recv() (Event, error) + } + + // GenericSource represents a helper struct for generating custom event sources + GenericSource struct { + eventID ID + stream rpcc.Stream + recv func(stream rpcc.Stream) (interface{}, error) + } +) + +var ( + Error = New("error") +) + +// NewSource create a new custom event source +// eventID - is a unique event ID +// stream - is a custom event stream +// recv - is a value conversion function +func NewSource( + eventID ID, + stream rpcc.Stream, + recv func(stream rpcc.Stream) (interface{}, error), +) Source { + return &GenericSource{eventID, stream, recv} +} + +func (src *GenericSource) EventID() ID { + return src.eventID +} + +func (src *GenericSource) Ready() <-chan struct{} { + return src.stream.Ready() +} + +func (src *GenericSource) RecvMsg(m interface{}) error { + return src.stream.RecvMsg(m) +} + +func (src *GenericSource) Close() error { + return src.stream.Close() +} + +func (src *GenericSource) Recv() (Event, error) { + data, err := src.recv(src.stream) + + if err != nil { + return Event{}, err + } + + return Event{ + ID: src.eventID, + Data: data, + }, nil +} diff --git a/pkg/drivers/cdp/events/sources.go b/pkg/drivers/cdp/events/sources.go new file mode 100644 index 00000000..b6c7f6f7 --- /dev/null +++ b/pkg/drivers/cdp/events/sources.go @@ -0,0 +1,82 @@ +package events + +import ( + "github.com/MontFerret/ferret/pkg/runtime/core" + "sync" +) + +type SourceCollection struct { + mu sync.Mutex + values []Source +} + +func NewSourceCollection() *SourceCollection { + sc := new(SourceCollection) + sc.values = make([]Source, 0, 10) + + return sc +} + +func (sc *SourceCollection) Close() error { + sc.mu.Lock() + defer sc.mu.Unlock() + + errs := make([]error, 0, len(sc.values)) + + for _, e := range sc.values { + if err := e.Close(); err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return core.Errors(errs...) + } + + return nil +} + +func (sc *SourceCollection) Size() int { + sc.mu.Lock() + defer sc.mu.Unlock() + + return len(sc.values) +} + +func (sc *SourceCollection) Get(idx int) (Source, error) { + sc.mu.Lock() + defer sc.mu.Unlock() + + if len(sc.values) <= idx { + return nil, core.ErrNotFound + } + + return sc.values[idx], nil +} + +func (sc *SourceCollection) Add(source Source) { + sc.mu.Lock() + defer sc.mu.Unlock() + + sc.values = append(sc.values, source) +} + +func (sc *SourceCollection) Remove(source Source) bool { + sc.mu.Lock() + defer sc.mu.Unlock() + + idx := -1 + + for i, current := range sc.values { + if current == source { + idx = i + break + } + } + + if idx > -1 { + sc.values = append(sc.values[:idx], sc.values[idx+1:]...) + } + + return idx > -1 +} diff --git a/pkg/drivers/cdp/helpers.go b/pkg/drivers/cdp/helpers.go index abd3405d..cde02e13 100644 --- a/pkg/drivers/cdp/helpers.go +++ b/pkg/drivers/cdp/helpers.go @@ -1,31 +1,16 @@ package cdp import ( - "bytes" "context" - "encoding/json" - "errors" - "github.com/MontFerret/ferret/pkg/drivers/cdp/templates" - "golang.org/x/net/html" - "strings" - "time" - "github.com/MontFerret/ferret/pkg/drivers" - "github.com/MontFerret/ferret/pkg/drivers/cdp/eval" "github.com/MontFerret/ferret/pkg/drivers/common" - "github.com/MontFerret/ferret/pkg/runtime/values" - - "github.com/PuerkitoBio/goquery" "github.com/mafredri/cdp" - "github.com/mafredri/cdp/protocol/dom" + "github.com/mafredri/cdp/protocol/emulation" "github.com/mafredri/cdp/protocol/network" "github.com/mafredri/cdp/protocol/page" - "github.com/mafredri/cdp/protocol/runtime" "golang.org/x/sync/errgroup" ) -var emptyExpires = time.Time{} - type ( batchFunc = func() error ) @@ -40,336 +25,87 @@ func runBatch(funcs ...batchFunc) error { return eg.Wait() } -func parseAttrs(attrs []string) *values.Object { - var attr values.String - - res := values.NewObject() - - for _, el := range attrs { - el = strings.TrimSpace(el) - str := values.NewString(el) - - if common.IsAttribute(el) { - attr = str - res.Set(str, values.EmptyString) - } else { - current, ok := res.Get(attr) - - if ok { - if current.String() != "" { - res.Set(attr, current.(values.String).Concat(values.SpaceString).Concat(str)) - } else { - res.Set(attr, str) - } - } - } - } - - return res -} - -func setInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerHTML values.String) error { - var objID *runtime.RemoteObjectID - - if id.ObjectID != "" { - objID = &id.ObjectID - } else { - repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) - - if err != nil { - return err - } - - if repl.Object.ObjectID == nil { - return errors.New("unable to resolve node") - } - - objID = repl.Object.ObjectID - } - - b, err := json.Marshal(innerHTML.String()) - - if err != nil { +func enableFeatures(ctx context.Context, client *cdp.Client, params drivers.Params) error { + if err := client.Page.Enable(ctx); err != nil { return err } - err = exec.EvalWithArguments(ctx, templates.SetInnerHTML(), - runtime.CallArgument{ - ObjectID: objID, + return runBatch( + func() error { + return client.Page.SetLifecycleEventsEnabled( + ctx, + page.NewSetLifecycleEventsEnabledArgs(true), + ) }, - runtime.CallArgument{ - Value: json.RawMessage(b), + + func() error { + return client.DOM.Enable(ctx) }, - ) - return err -} + func() error { + return client.Runtime.Enable(ctx) + }, -func getInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) { - // not a document - if nodeType != html.DocumentNode { - var objID runtime.RemoteObjectID + func() error { + ua := common.GetUserAgent(params.UserAgent) - if id.ObjectID != "" { - objID = id.ObjectID - } else { - repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) + //logger. + // Debug(). + // Timestamp(). + // Str("user-agent", ua). + // Msg("using User-Agent") - if err != nil { - return "", err + // do not use custom user agent + if ua == "" { + return nil } - if repl.Object.ObjectID == nil { - return "", errors.New("unable to resolve node") + return client.Emulation.SetUserAgentOverride( + ctx, + emulation.NewSetUserAgentOverrideArgs(ua), + ) + }, + + func() error { + return client.Network.Enable(ctx, network.NewEnableArgs()) + }, + + func() error { + return client.Page.SetBypassCSP(ctx, page.NewSetBypassCSPArgs(true)) + }, + + func() error { + if params.Viewport == nil { + return nil } - objID = *repl.Object.ObjectID - } + orientation := emulation.ScreenOrientation{} - res, err := exec.ReadProperty(ctx, objID, "innerHTML") + if !params.Viewport.Landscape { + orientation.Type = "portraitPrimary" + orientation.Angle = 0 + } else { + orientation.Type = "landscapePrimary" + orientation.Angle = 90 + } - if err != nil { - return "", err - } + scaleFactor := params.Viewport.ScaleFactor - return values.NewString(res.String()), nil - } + if scaleFactor <= 0 { + scaleFactor = 1 + } - repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerHTML") + deviceArgs := emulation.NewSetDeviceMetricsOverrideArgs( + params.Viewport.Width, + params.Viewport.Height, + scaleFactor, + params.Viewport.Mobile, + ).SetScreenOrientation(orientation) - if err != nil { - return "", err - } - - return values.NewString(repl.String()), nil -} - -func setInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerText values.String) error { - var objID *runtime.RemoteObjectID - - if id.ObjectID != "" { - objID = &id.ObjectID - } else { - repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) - - if err != nil { - return err - } - - if repl.Object.ObjectID == nil { - return errors.New("unable to resolve node") - } - - objID = repl.Object.ObjectID - } - - b, err := json.Marshal(innerText.String()) - - if err != nil { - return err - } - - err = exec.EvalWithArguments(ctx, templates.SetInnerText(), - runtime.CallArgument{ - ObjectID: objID, - }, - runtime.CallArgument{ - Value: json.RawMessage(b), + return client.Emulation.SetDeviceMetricsOverride( + ctx, + deviceArgs, + ) }, ) - - return err -} - -func getInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) { - // not a document - if nodeType != html.DocumentNode { - var objID runtime.RemoteObjectID - - if id.ObjectID != "" { - objID = id.ObjectID - } else { - repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) - - if err != nil { - return "", err - } - - if repl.Object.ObjectID == nil { - return "", errors.New("unable to resolve node") - } - - objID = *repl.Object.ObjectID - } - - res, err := exec.ReadProperty(ctx, objID, "innerText") - - if err != nil { - return "", err - } - - return values.NewString(res.String()), err - } - - repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerText") - - if err != nil { - return "", err - } - - return values.NewString(repl.String()), nil -} - -func parseInnerText(innerHTML string) (values.String, error) { - buff := bytes.NewBuffer([]byte(innerHTML)) - - parsed, err := goquery.NewDocumentFromReader(buff) - - if err != nil { - return values.EmptyString, err - } - - return values.NewString(parsed.Text()), nil -} - -func createChildrenArray(nodes []dom.Node) []HTMLElementIdentity { - children := make([]HTMLElementIdentity, len(nodes)) - - for idx, child := range nodes { - child := child - children[idx] = HTMLElementIdentity{ - NodeID: child.NodeID, - } - } - - return children -} - -func fromDriverCookie(url string, cookie drivers.HTTPCookie) network.CookieParam { - sameSite := network.CookieSameSiteNotSet - - switch cookie.SameSite { - case drivers.SameSiteLaxMode: - sameSite = network.CookieSameSiteLax - case drivers.SameSiteStrictMode: - sameSite = network.CookieSameSiteStrict - } - - if cookie.Expires == emptyExpires { - cookie.Expires = time.Now().Add(time.Duration(24) + time.Hour) - } - - normalizedURL := normalizeCookieURL(url) - - return network.CookieParam{ - URL: &normalizedURL, - Name: cookie.Name, - Value: cookie.Value, - Secure: &cookie.Secure, - Path: &cookie.Path, - Domain: &cookie.Domain, - HTTPOnly: &cookie.HTTPOnly, - SameSite: sameSite, - Expires: network.TimeSinceEpoch(cookie.Expires.Unix()), - } -} - -func fromDriverCookieDelete(url string, cookie drivers.HTTPCookie) *network.DeleteCookiesArgs { - normalizedURL := normalizeCookieURL(url) - - return &network.DeleteCookiesArgs{ - URL: &normalizedURL, - Name: cookie.Name, - Path: &cookie.Path, - Domain: &cookie.Domain, - } -} - -func toDriverCookie(c network.Cookie) drivers.HTTPCookie { - sameSite := drivers.SameSiteDefaultMode - - switch c.SameSite { - case network.CookieSameSiteLax: - sameSite = drivers.SameSiteLaxMode - case network.CookieSameSiteStrict: - sameSite = drivers.SameSiteStrictMode - } - - return drivers.HTTPCookie{ - Name: c.Name, - Value: c.Value, - Path: c.Path, - Domain: c.Domain, - Expires: time.Unix(int64(c.Expires), 0), - SameSite: sameSite, - Secure: c.Secure, - HTTPOnly: c.HTTPOnly, - } -} - -func normalizeCookieURL(url string) string { - const httpPrefix = "http://" - const httpsPrefix = "https://" - - if strings.HasPrefix(url, httpPrefix) || strings.HasPrefix(url, httpsPrefix) { - return url - } - - return httpPrefix + url -} - -func resolveFrame(ctx context.Context, client *cdp.Client, frame page.Frame) (dom.Node, runtime.ExecutionContextID, error) { - worldRepl, err := client.Page.CreateIsolatedWorld(ctx, page.NewCreateIsolatedWorldArgs(frame.ID)) - - if err != nil { - return dom.Node{}, -1, err - } - - evalRes, err := client.Runtime.Evaluate( - ctx, - runtime.NewEvaluateArgs(eval.PrepareEval("return document")). - SetContextID(worldRepl.ExecutionContextID), - ) - - if err != nil { - return dom.Node{}, -1, err - } - - if evalRes.ExceptionDetails != nil { - exception := *evalRes.ExceptionDetails - - return dom.Node{}, -1, errors.New(exception.Text) - } - - if evalRes.Result.ObjectID == nil { - return dom.Node{}, -1, errors.New("failed to resolve frame document") - } - - req, err := client.DOM.RequestNode(ctx, dom.NewRequestNodeArgs(*evalRes.Result.ObjectID)) - - if err != nil { - return dom.Node{}, -1, err - } - - if req.NodeID == 0 { - return dom.Node{}, -1, errors.New("framed document is resolved with empty node id") - } - - desc, err := client.DOM.DescribeNode( - ctx, - dom. - NewDescribeNodeArgs(). - SetNodeID(req.NodeID). - SetDepth(1), - ) - - if err != nil { - return dom.Node{}, -1, err - } - - // Returned node, by some reason, does not contain the NodeID - // So, we have to set it manually - desc.Node.NodeID = req.NodeID - - return desc.Node, worldRepl.ExecutionContextID, nil } diff --git a/pkg/drivers/cdp/network/events.go b/pkg/drivers/cdp/network/events.go new file mode 100644 index 00000000..f3afd524 --- /dev/null +++ b/pkg/drivers/cdp/network/events.go @@ -0,0 +1,7 @@ +package network + +import "github.com/MontFerret/ferret/pkg/drivers/cdp/events" + +var ( + eventFrameLoad = events.New("frame_load") +) diff --git a/pkg/drivers/cdp/network/helpers.go b/pkg/drivers/cdp/network/helpers.go new file mode 100644 index 00000000..7f59a810 --- /dev/null +++ b/pkg/drivers/cdp/network/helpers.go @@ -0,0 +1,83 @@ +package network + +import ( + "github.com/MontFerret/ferret/pkg/drivers" + "github.com/mafredri/cdp/protocol/network" + "strings" + "time" +) + +var emptyExpires = time.Time{} + +func fromDriverCookie(url string, cookie drivers.HTTPCookie) network.CookieParam { + sameSite := network.CookieSameSiteNotSet + + switch cookie.SameSite { + case drivers.SameSiteLaxMode: + sameSite = network.CookieSameSiteLax + case drivers.SameSiteStrictMode: + sameSite = network.CookieSameSiteStrict + } + + if cookie.Expires == emptyExpires { + cookie.Expires = time.Now().Add(time.Duration(24) + time.Hour) + } + + normalizedURL := normalizeCookieURL(url) + + return network.CookieParam{ + URL: &normalizedURL, + Name: cookie.Name, + Value: cookie.Value, + Secure: &cookie.Secure, + Path: &cookie.Path, + Domain: &cookie.Domain, + HTTPOnly: &cookie.HTTPOnly, + SameSite: sameSite, + Expires: network.TimeSinceEpoch(cookie.Expires.Unix()), + } +} + +func fromDriverCookieDelete(url string, cookie drivers.HTTPCookie) *network.DeleteCookiesArgs { + normalizedURL := normalizeCookieURL(url) + + return &network.DeleteCookiesArgs{ + URL: &normalizedURL, + Name: cookie.Name, + Path: &cookie.Path, + Domain: &cookie.Domain, + } +} + +func toDriverCookie(c network.Cookie) drivers.HTTPCookie { + sameSite := drivers.SameSiteDefaultMode + + switch c.SameSite { + case network.CookieSameSiteLax: + sameSite = drivers.SameSiteLaxMode + case network.CookieSameSiteStrict: + sameSite = drivers.SameSiteStrictMode + } + + return drivers.HTTPCookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: time.Unix(int64(c.Expires), 0), + SameSite: sameSite, + Secure: c.Secure, + HTTPOnly: c.HTTPOnly, + } +} + +func normalizeCookieURL(url string) string { + const httpPrefix = "http://" + const httpsPrefix = "https://" + + if strings.HasPrefix(url, httpPrefix) || strings.HasPrefix(url, httpsPrefix) { + return url + } + + return httpPrefix + url +} diff --git a/pkg/drivers/cdp/network/manager.go b/pkg/drivers/cdp/network/manager.go new file mode 100644 index 00000000..3176bb32 --- /dev/null +++ b/pkg/drivers/cdp/network/manager.go @@ -0,0 +1,340 @@ +package network + +import ( + "context" + "encoding/json" + "regexp" + "sync" + + "github.com/mafredri/cdp" + "github.com/mafredri/cdp/protocol/network" + "github.com/mafredri/cdp/protocol/page" + "github.com/mafredri/cdp/rpcc" + "github.com/pkg/errors" + "github.com/rs/zerolog" + + "github.com/MontFerret/ferret/pkg/drivers" + "github.com/MontFerret/ferret/pkg/drivers/cdp/events" + "github.com/MontFerret/ferret/pkg/runtime/core" + "github.com/MontFerret/ferret/pkg/runtime/values" +) + +const BlankPageURL = "about:blank" + +type ( + FrameLoadedListener = func(ctx context.Context, frame page.Frame) + + Manager struct { + mu sync.Mutex + logger *zerolog.Logger + client *cdp.Client + headers drivers.HTTPHeaders + eventLoop *events.Loop + cancel context.CancelFunc + listeners []FrameLoadedListener + } +) + +func New( + logger *zerolog.Logger, + client *cdp.Client, + eventLoop *events.Loop, +) (*Manager, error) { + ctx, cancel := context.WithCancel(context.Background()) + + m := new(Manager) + m.logger = logger + m.client = client + m.headers = make(drivers.HTTPHeaders) + m.eventLoop = eventLoop + m.cancel = cancel + + frameNavigatedStream, err := m.client.Page.FrameNavigated(ctx) + + if err != nil { + return nil, err + } + + m.eventLoop.AddSource(events.NewSource(eventFrameLoad, frameNavigatedStream, func(stream rpcc.Stream) (interface{}, error) { + return stream.(page.FrameNavigatedClient).Recv() + })) + + return m, nil +} + +func (m *Manager) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + m.cancel() + m.cancel = nil + } + + return nil +} + +func (m *Manager) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) { + repl, err := m.client.Network.GetAllCookies(ctx) + + if err != nil { + return nil, errors.Wrap(err, "failed to get cookies") + } + + cookies := make(drivers.HTTPCookies) + + if repl.Cookies == nil { + return cookies, nil + } + + for _, c := range repl.Cookies { + cookies[c.Name] = toDriverCookie(c) + } + + return cookies, nil +} + +func (m *Manager) SetCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error { + m.mu.Lock() + defer m.mu.Unlock() + + if len(cookies) == 0 { + return nil + } + + params := make([]network.CookieParam, 0, len(cookies)) + + for _, c := range cookies { + params = append(params, fromDriverCookie(url, c)) + } + + return m.client.Network.SetCookies(ctx, network.NewSetCookiesArgs(params)) +} + +func (m *Manager) DeleteCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error { + m.mu.Lock() + defer m.mu.Unlock() + + if len(cookies) == 0 { + return nil + } + + var err error + + for _, c := range cookies { + err = m.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(url, c)) + + if err != nil { + break + } + } + + return err +} + +func (m *Manager) GetHeaders(_ context.Context) (drivers.HTTPHeaders, error) { + copied := make(drivers.HTTPHeaders) + + for k, v := range m.headers { + copied[k] = v + } + + return copied, nil +} + +func (m *Manager) SetHeaders(ctx context.Context, headers drivers.HTTPHeaders) error { + m.mu.Lock() + defer m.mu.Unlock() + + if len(headers) == 0 { + return nil + } + + m.headers = headers + + j, err := json.Marshal(headers) + + if err != nil { + return errors.Wrap(err, "failed to marshal headers") + } + + err = m.client.Network.SetExtraHTTPHeaders( + ctx, + network.NewSetExtraHTTPHeadersArgs(j), + ) + + if err != nil { + return errors.Wrap(err, "failed to set headers") + } + + return nil +} + +func (m *Manager) Navigate(ctx context.Context, url values.String) error { + m.mu.Lock() + defer m.mu.Unlock() + + if url == "" { + url = BlankPageURL + } + + urlStr := url.String() + + repl, err := m.client.Page.Navigate(ctx, page.NewNavigateArgs(urlStr)) + + if err != nil { + return err + } + + if repl.ErrorText != nil { + return errors.New(*repl.ErrorText) + } + + return m.WaitForNavigation(ctx, nil) +} + +func (m *Manager) NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error) { + m.mu.Lock() + defer m.mu.Unlock() + + history, err := m.client.Page.GetNavigationHistory(ctx) + + if err != nil { + return values.False, err + } + + length := len(history.Entries) + lastIndex := length - 1 + + // nowhere to go forward + if history.CurrentIndex == lastIndex { + return values.False, nil + } + + if skip < 1 { + skip = 1 + } + + to := int(skip) + history.CurrentIndex + + if to > lastIndex { + // TODO: Return error? + return values.False, nil + } + + entry := history.Entries[to] + err = m.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(entry.ID)) + + if err != nil { + return values.False, err + } + + err = m.WaitForNavigation(ctx, nil) + + if err != nil { + return values.False, err + } + + return values.True, nil +} + +func (m *Manager) NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error) { + m.mu.Lock() + defer m.mu.Unlock() + + history, err := m.client.Page.GetNavigationHistory(ctx) + + if err != nil { + return values.False, err + } + + // we are in the beginning + if history.CurrentIndex == 0 { + return values.False, nil + } + + if skip < 1 { + skip = 1 + } + + to := history.CurrentIndex - int(skip) + + if to < 0 { + // TODO: Return error? + return values.False, nil + } + + entry := history.Entries[to] + err = m.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(entry.ID)) + + if err != nil { + return values.False, err + } + + err = m.WaitForNavigation(ctx, nil) + + if err != nil { + return values.False, err + } + + return values.True, nil +} + +func (m *Manager) WaitForNavigation(ctx context.Context, pattern *regexp.Regexp) error { + return m.WaitForFrameNavigation(ctx, "", pattern) +} + +func (m *Manager) WaitForFrameNavigation(ctx context.Context, frameID page.FrameID, urlPattern *regexp.Regexp) error { + onEvent := make(chan struct{}) + + defer func() { + close(onEvent) + }() + + m.eventLoop.AddListener(eventFrameLoad, func(_ context.Context, message interface{}) bool { + repl := message.(*page.FrameNavigatedReply) + + var matched bool + + // if frameID is empty string or equals to the current one + if len(frameID) == 0 || repl.Frame.ID == frameID { + // if a URL pattern is provided + if urlPattern != nil { + matched = urlPattern.Match([]byte(repl.Frame.URL)) + } else { + // otherwise just notify + matched = true + } + } + + if matched { + if ctx.Err() == nil { + onEvent <- struct{}{} + } + } + + // if not matched - continue listening + return !matched + }) + + select { + case <-onEvent: + return nil + case <-ctx.Done(): + return core.ErrTimeout + } +} + +func (m *Manager) AddFrameLoadedListener(listener FrameLoadedListener) events.ListenerID { + return m.eventLoop.AddListener(eventFrameLoad, func(ctx context.Context, message interface{}) bool { + repl := message.(*page.FrameNavigatedReply) + + listener(ctx, repl.Frame) + + return true + }) +} + +func (m *Manager) RemoveFrameLoadedListener(id events.ListenerID) { + m.eventLoop.RemoveListener(eventFrameLoad, id) +} diff --git a/pkg/drivers/cdp/page.go b/pkg/drivers/cdp/page.go index 20d0fe71..b351c0a7 100644 --- a/pkg/drivers/cdp/page.go +++ b/pkg/drivers/cdp/page.go @@ -2,21 +2,22 @@ package cdp import ( "context" - "encoding/json" + "github.com/MontFerret/ferret/pkg/drivers/cdp/dom" + "github.com/pkg/errors" "hash/fnv" + "io" + "regexp" "sync" "github.com/mafredri/cdp" - "github.com/mafredri/cdp/protocol/emulation" - "github.com/mafredri/cdp/protocol/network" "github.com/mafredri/cdp/protocol/page" "github.com/mafredri/cdp/rpcc" - "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/MontFerret/ferret/pkg/drivers" "github.com/MontFerret/ferret/pkg/drivers/cdp/events" "github.com/MontFerret/ferret/pkg/drivers/cdp/input" + net "github.com/MontFerret/ferret/pkg/drivers/cdp/network" "github.com/MontFerret/ferret/pkg/drivers/common" "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/logging" @@ -29,26 +30,18 @@ type HTMLPage struct { logger *zerolog.Logger conn *rpcc.Conn client *cdp.Client - events *events.EventBroker + events *events.Loop + network *net.Manager + dom *dom.Manager mouse *input.Mouse keyboard *input.Keyboard - document *common.AtomicValue - frames *common.LazyValue -} - -func handleLoadError(logger *zerolog.Logger, client *cdp.Client) { - err := client.Page.Close(context.Background()) - - if err != nil { - logger.Warn().Timestamp().Err(err).Msg("failed to close document on load error") - } } func LoadHTMLPage( ctx context.Context, conn *rpcc.Conn, params drivers.Params, -) (*HTMLPage, error) { +) (p *HTMLPage, err error) { logger := logging.FromContext(ctx) if conn == nil { @@ -57,217 +50,104 @@ func LoadHTMLPage( client := cdp.NewClient(conn) - if err := client.Page.Enable(ctx); err != nil { + if err := enableFeatures(ctx, client, params); err != nil { return nil, err } - err := runBatch( - func() error { - return client.Page.SetLifecycleEventsEnabled( - ctx, - page.NewSetLifecycleEventsEnabledArgs(true), - ) - }, + closers := make([]io.Closer, 0, 2) - func() error { - return client.DOM.Enable(ctx) - }, + defer func() { + if err != nil { + common.CloseAll(logger, closers, "failed to close a Page resource") + } + }() - func() error { - return client.Runtime.Enable(ctx) - }, + eventLoop := events.NewLoop() + closers = append(closers, eventLoop) - func() error { - ua := common.GetUserAgent(params.UserAgent) + netManager, err := net.New(logger, client, eventLoop) - logger. - Debug(). - Timestamp(). - Str("user-agent", ua). - Msg("using User-Agent") + if err != nil { + return nil, err + } - // do not use custom user agent - if ua == "" { - return nil - } + err = netManager.SetCookies(ctx, params.URL, params.Cookies) - return client.Emulation.SetUserAgentOverride( - ctx, - emulation.NewSetUserAgentOverrideArgs(ua), - ) - }, + if err != nil { + return nil, err + } - func() error { - return client.Network.Enable(ctx, network.NewEnableArgs()) - }, + err = netManager.SetHeaders(ctx, params.Headers) - func() error { - return client.Page.SetBypassCSP(ctx, page.NewSetBypassCSPArgs(true)) - }, + if err != nil { + return nil, err + } - func() error { - if params.Viewport == nil { - return nil - } + eventLoop.Start() - orientation := emulation.ScreenOrientation{} + mouse := input.NewMouse(client) + keyboard := input.NewKeyboard(client) - if !params.Viewport.Landscape { - orientation.Type = "portraitPrimary" - orientation.Angle = 0 - } else { - orientation.Type = "landscapePrimary" - orientation.Angle = 90 - } - - scaleFactor := params.Viewport.ScaleFactor - - if scaleFactor <= 0 { - scaleFactor = 1 - } - - deviceArgs := emulation.NewSetDeviceMetricsOverrideArgs( - params.Viewport.Width, - params.Viewport.Height, - scaleFactor, - params.Viewport.Mobile, - ).SetScreenOrientation(orientation) - - return client.Emulation.SetDeviceMetricsOverride( - ctx, - deviceArgs, - ) - }, + domManager, err := dom.New( + logger, + client, + eventLoop, + mouse, + keyboard, ) if err != nil { return nil, err } - if len(params.Cookies) > 0 { - cookies := make([]network.CookieParam, 0, len(params.Cookies)) + closers = append(closers, domManager) - for _, c := range params.Cookies { - cookies = append(cookies, fromDriverCookie(params.URL, c)) - - logger. - Debug(). - Timestamp(). - Str("cookie", c.Name). - Msg("set cookie") - } - - err = client.Network.SetCookies( - ctx, - network.NewSetCookiesArgs(cookies), - ) - - if err != nil { - return nil, errors.Wrap(err, "failed to set cookies") - } - } - - if len(params.Headers) > 0 { - j, err := json.Marshal(params.Headers) - - if err != nil { - return nil, err - } - - for k := range params.Headers { - logger. - Debug(). - Timestamp(). - Str("header", k). - Msg("set header") - } - - err = client.Network.SetExtraHTTPHeaders( - ctx, - network.NewSetExtraHTTPHeadersArgs(network.Headers(j)), - ) - - if err != nil { - return nil, errors.Wrap(err, "failed to set headers") - } - } - - if params.URL != BlankPageURL && params.URL != "" { - repl, err := client.Page.Navigate(ctx, page.NewNavigateArgs(params.URL)) - - if err != nil { - handleLoadError(logger, client) - - return nil, errors.Wrap(err, "failed to load the page") - } - - if repl.ErrorText != nil { - handleLoadError(logger, client) - - return nil, errors.Wrapf(errors.New(*repl.ErrorText), "failed to load the page: %s", params.URL) - } - - err = events.WaitForLoadEvent(ctx, client) - - if err != nil { - handleLoadError(logger, client) - - return nil, errors.Wrap(err, "failed to load the page") - } - } - - broker, err := events.CreateEventBroker(client) - - if err != nil { - handleLoadError(logger, client) - return nil, errors.Wrap(err, "failed to create event events") - } - - mouse := input.NewMouse(client) - keyboard := input.NewKeyboard(client) - - doc, err := LoadRootHTMLDocument(ctx, logger, client, broker, mouse, keyboard) - - if err != nil { - broker.StopAndClose() - handleLoadError(logger, client) - - return nil, errors.Wrap(err, "failed to load root element") - } - - return NewHTMLPage( + p = NewHTMLPage( logger, conn, client, - broker, + eventLoop, + netManager, + domManager, mouse, keyboard, - doc, - ), nil + ) + + if params.URL != BlankPageURL && params.URL != "" { + err = p.Navigate(ctx, values.NewString(params.URL)) + } else { + err = p.loadMainFrame(ctx) + } + + if err != nil { + return p, err + } + + return p, nil } func NewHTMLPage( logger *zerolog.Logger, conn *rpcc.Conn, client *cdp.Client, - broker *events.EventBroker, + eventLoop *events.Loop, + netManager *net.Manager, + domManager *dom.Manager, mouse *input.Mouse, keyboard *input.Keyboard, - document *HTMLDocument, ) *HTMLPage { p := new(HTMLPage) p.closed = values.False p.logger = logger p.conn = conn p.client = client - p.events = broker + p.events = eventLoop + p.network = netManager + p.dom = domManager p.mouse = mouse p.keyboard = keyboard - p.document = common.NewAtomicValue(document) - p.frames = common.NewLazyValue(p.unfoldFrames) - broker.AddEventListener(events.EventLoad, p.handlePageLoad) - broker.AddEventListener(events.EventError, p.handleError) + eventLoop.AddListener(events.Error, events.Always(p.handleError)) return p } @@ -343,36 +223,36 @@ func (p *HTMLPage) Close() error { defer p.mu.Unlock() p.closed = values.True - err := p.events.Stop() doc := p.getCurrentDocument() + err := p.events.Stop().Close() if err != nil { p.logger.Warn(). Timestamp(). Str("url", doc.GetURL().String()). Err(err). - Msg("failed to stop event events") + Msg("failed to stop event loop") } - err = p.events.Close() + err = p.dom.Close() if err != nil { p.logger.Warn(). Timestamp(). Str("url", doc.GetURL().String()). Err(err). - Msg("failed to close event events") + Msg("failed to close dom manager") } - err = doc.Close() + err = p.network.Close() if err != nil { p.logger.Warn(). Timestamp(). Str("url", doc.GetURL().String()). Err(err). - Msg("failed to close root document") + Msg("failed to close network manager") } err = p.client.Page.Close(context.Background()) @@ -404,87 +284,48 @@ func (p *HTMLPage) GetMainFrame() drivers.HTMLDocument { } func (p *HTMLPage) GetFrames(ctx context.Context) (*values.Array, error) { - res, err := p.frames.Read(ctx) + p.mu.Lock() + defer p.mu.Unlock() - if err != nil { - return nil, err - } - - return res.(*values.Array).Clone().(*values.Array), nil + return p.dom.GetFrameNodes(ctx) } func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, error) { p.mu.Lock() defer p.mu.Unlock() - res, err := p.frames.Read(ctx) + frames, err := p.dom.GetFrameNodes(ctx) if err != nil { - return nil, err + return values.None, err } - return res.(*values.Array).Get(idx), nil + return frames.Get(idx), nil } -func (p *HTMLPage) GetCookies(ctx context.Context) (*values.Array, error) { +func (p *HTMLPage) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) { p.mu.Lock() defer p.mu.Unlock() - repl, err := p.client.Network.GetAllCookies(ctx) - - if err != nil { - return values.NewArray(0), err - } - - if repl.Cookies == nil { - return values.NewArray(0), nil - } - - cookies := values.NewArray(len(repl.Cookies)) - - for _, c := range repl.Cookies { - cookies.Push(toDriverCookie(c)) - } - - return cookies, nil + return p.network.GetCookies(ctx) } -func (p *HTMLPage) SetCookies(ctx context.Context, cookies ...drivers.HTTPCookie) error { +func (p *HTMLPage) SetCookies(ctx context.Context, cookies drivers.HTTPCookies) error { p.mu.Lock() defer p.mu.Unlock() - if len(cookies) == 0 { - return nil - } - - params := make([]network.CookieParam, 0, len(cookies)) - - for _, c := range cookies { - params = append(params, fromDriverCookie(p.getCurrentDocument().GetURL().String(), c)) - } - - return p.client.Network.SetCookies(ctx, network.NewSetCookiesArgs(params)) + return p.network.SetCookies(ctx, p.getCurrentDocument().GetURL().String(), cookies) } -func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies ...drivers.HTTPCookie) error { +func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies drivers.HTTPCookies) error { p.mu.Lock() defer p.mu.Unlock() - if len(cookies) == 0 { - return nil - } + return p.network.DeleteCookies(ctx, p.getCurrentDocument().GetURL().String(), cookies) +} - var err error - - for _, c := range cookies { - err = p.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(p.getCurrentDocument().GetURL().String(), c)) - - if err != nil { - break - } - } - - return err +func (p *HTMLPage) GetResponse(_ context.Context) (*drivers.HTTPResponse, error) { + return nil, core.ErrNotSupported } func (p *HTMLPage) PrintToPDF(ctx context.Context, params drivers.PDFParams) (values.Binary, error) { @@ -607,162 +448,107 @@ func (p *HTMLPage) Navigate(ctx context.Context, url values.String) error { p.mu.Lock() defer p.mu.Unlock() - if url == "" { - url = BlankPageURL - } - - repl, err := p.client.Page.Navigate(ctx, page.NewNavigateArgs(url.String())) - - if err != nil { + if err := p.network.Navigate(ctx, url); err != nil { return err } - if repl.ErrorText != nil { - return errors.New(*repl.ErrorText) - } - - return p.WaitForNavigation(ctx) + return p.reloadMainFrame(ctx) } func (p *HTMLPage) NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error) { p.mu.Lock() defer p.mu.Unlock() - history, err := p.client.Page.GetNavigationHistory(ctx) + ret, err := p.network.NavigateBack(ctx, skip) if err != nil { return values.False, err } - // we are in the beginning - if history.CurrentIndex == 0 { - return values.False, nil - } - - if skip < 1 { - skip = 1 - } - - to := history.CurrentIndex - int(skip) - - if to < 0 { - // TODO: Return error? - return values.False, nil - } - - prev := history.Entries[to] - err = p.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(prev.ID)) - - if err != nil { - return values.False, err - } - - err = p.WaitForNavigation(ctx) - - if err != nil { - return values.False, err - } - - return values.True, nil + return ret, p.reloadMainFrame(ctx) } func (p *HTMLPage) NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error) { p.mu.Lock() defer p.mu.Unlock() - history, err := p.client.Page.GetNavigationHistory(ctx) + ret, err := p.network.NavigateForward(ctx, skip) if err != nil { return values.False, err } - length := len(history.Entries) - lastIndex := length - 1 - - // nowhere to go forward - if history.CurrentIndex == lastIndex { - return values.False, nil - } - - if skip < 1 { - skip = 1 - } - - to := int(skip) + history.CurrentIndex - - if to > lastIndex { - // TODO: Return error? - return values.False, nil - } - - next := history.Entries[to] - err = p.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(next.ID)) - - if err != nil { - return values.False, err - } - - err = p.WaitForNavigation(ctx) - - if err != nil { - return values.False, err - } - - return values.True, nil + return ret, p.reloadMainFrame(ctx) } -func (p *HTMLPage) WaitForNavigation(ctx context.Context) error { - onEvent := make(chan struct{}) - var once sync.Once - listener := func(_ context.Context, _ interface{}) { - once.Do(func() { - close(onEvent) - }) +func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.String) error { + var pattern *regexp.Regexp + + if targetURL != "" { + r, err := regexp.Compile(targetURL.String()) + + if err != nil { + return errors.Wrap(err, "invalid URL pattern") + } + + pattern = r } - defer p.events.RemoveEventListener(events.EventLoad, listener) - - p.events.AddEventListener(events.EventLoad, listener) - - select { - case <-onEvent: - return nil - case <-ctx.Done(): - return core.ErrTimeout + if err := p.network.WaitForNavigation(ctx, pattern); err != nil { + return err } + + return p.reloadMainFrame(ctx) } -func (p *HTMLPage) handlePageLoad(ctx context.Context, _ interface{}) { - err := p.document.Write(func(current core.Value) (core.Value, error) { - nextDoc, err := LoadRootHTMLDocument(ctx, p.logger, p.client, p.events, p.mouse, p.keyboard) +func (p *HTMLPage) reloadMainFrame(ctx context.Context) error { + if err := p.dom.WaitForDOMReady(ctx); err != nil { + return err + } - if err != nil { - return values.None, err - } + prev := p.dom.GetMainFrame() - // close the prev document - currentDoc := current.(*HTMLDocument) - err = currentDoc.Close() - - if err != nil { - p.logger.Warn(). - Timestamp(). - Err(err). - Msgf("failed to close root document: %s", currentDoc.GetURL()) - } - - // reset all loaded frames - p.frames.Reset() - - return nextDoc, nil - }) + next, err := dom.LoadRootHTMLDocument( + ctx, + p.logger, + p.client, + p.dom, + p.mouse, + p.keyboard, + ) if err != nil { - p.logger.Warn(). - Timestamp(). - Err(err). - Msg("failed to load new root document after page load") + return err } + + if prev != nil { + if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil { + p.logger.Error().Err(err).Msg("failed to remove main frame") + } + } + + p.dom.SetMainFrame(next) + + return nil +} + +func (p *HTMLPage) loadMainFrame(ctx context.Context) error { + next, err := dom.LoadRootHTMLDocument( + ctx, + p.logger, + p.client, + p.dom, + p.mouse, + p.keyboard, + ) + + if err != nil { + return err + } + + p.dom.SetMainFrame(next) + + return nil } func (p *HTMLPage) handleError(_ context.Context, val interface{}) { @@ -778,22 +564,6 @@ func (p *HTMLPage) handleError(_ context.Context, val interface{}) { Msg("unexpected error") } -func (p *HTMLPage) getCurrentDocument() *HTMLDocument { - return p.document.Read().(*HTMLDocument) -} - -func (p *HTMLPage) unfoldFrames(ctx context.Context) (core.Value, error) { - res := values.NewArray(10) - - err := common.CollectFrames(ctx, res, p.getCurrentDocument()) - - if err != nil { - return nil, err - } - - return res, nil -} - -func (p *HTMLPage) GetResponse(_ context.Context) (*drivers.HTTPResponse, error) { - return nil, core.ErrNotSupported +func (p *HTMLPage) getCurrentDocument() *dom.HTMLDocument { + return p.dom.GetMainFrame() } diff --git a/pkg/drivers/common/errors.go b/pkg/drivers/common/errors.go index 663ee113..e6ad741e 100644 --- a/pkg/drivers/common/errors.go +++ b/pkg/drivers/common/errors.go @@ -1,8 +1,20 @@ package common -import "github.com/MontFerret/ferret/pkg/runtime/core" +import ( + "github.com/MontFerret/ferret/pkg/runtime/core" + "github.com/rs/zerolog" + "io" +) var ( ErrReadOnly = core.Error(core.ErrInvalidOperation, "read only") ErrInvalidPath = core.Error(core.ErrInvalidOperation, "invalid path") ) + +func CloseAll(logger *zerolog.Logger, closers []io.Closer, msg string) { + for _, closer := range closers { + if err := closer.Close(); err != nil { + logger.Error().Err(err).Msg(msg) + } + } +} diff --git a/pkg/drivers/common/getter.go b/pkg/drivers/common/getter.go index 2d136802..5a0dfe3e 100644 --- a/pkg/drivers/common/getter.go +++ b/pkg/drivers/common/getter.go @@ -63,22 +63,17 @@ func GetInPage(ctx context.Context, page drivers.HTMLPage, path []core.Value) (c case "url", "URL": return page.GetMainFrame().GetURL(), nil case "cookies": + cookies, err := page.GetCookies(ctx) + + if err != nil { + return values.None, err + } + if len(path) == 1 { - return page.GetCookies(ctx) + return cookies, nil } - switch idx := path[1].(type) { - case values.Int: - cookies, err := page.GetCookies(ctx) - - if err != nil { - return values.None, err - } - - return cookies.Get(idx), nil - default: - return values.None, core.TypeError(idx.Type(), types.Int) - } + return cookies.GetIn(ctx, path[1:]) case "isClosed": return page.IsClosed(), nil case "title": @@ -109,7 +104,11 @@ func GetInDocument(ctx context.Context, doc drivers.HTMLDocument, path []core.Va case "title": return doc.GetTitle(), nil case "parent": - parent := doc.GetParentDocument() + parent, err := doc.GetParentDocument(ctx) + + if err != nil { + return values.None, err + } if parent == nil { return values.None, nil diff --git a/pkg/drivers/cookie.go b/pkg/drivers/cookie.go index 7a65bac6..c2e642d3 100644 --- a/pkg/drivers/cookie.go +++ b/pkg/drivers/cookie.go @@ -30,8 +30,6 @@ type ( HTTPOnly bool SameSite SameSite } - - HTTPCookies map[string]HTTPCookie ) const ( diff --git a/pkg/drivers/cookies.go b/pkg/drivers/cookies.go new file mode 100644 index 00000000..2251d1d8 --- /dev/null +++ b/pkg/drivers/cookies.go @@ -0,0 +1,183 @@ +package drivers + +import ( + "context" + "encoding/binary" + "encoding/json" + "github.com/MontFerret/ferret/pkg/runtime/values" + "github.com/MontFerret/ferret/pkg/runtime/values/types" + "hash/fnv" + "sort" + + "github.com/MontFerret/ferret/pkg/runtime/core" +) + +type HTTPCookies map[string]HTTPCookie + +func NewHTTPCookies() HTTPCookies { + return make(HTTPCookies) +} + +func (c HTTPCookies) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]HTTPCookie(c)) +} + +func (c HTTPCookies) Type() core.Type { + return HTTPCookiesType +} + +func (c HTTPCookies) String() string { + j, err := c.MarshalJSON() + + if err != nil { + return "{}" + } + + return string(j) +} + +func (c HTTPCookies) Compare(other core.Value) int64 { + if other.Type() != HTTPCookiesType { + return Compare(HTTPCookiesType, other.Type()) + } + + oc := other.(HTTPCookies) + + switch { + case len(c) > len(oc): + return 1 + case len(c) < len(oc): + return -1 + } + + for name := range c { + cEl, cExists := c.Get(values.NewString(name)) + + if !cExists { + return -1 + } + + ocEl, ocExists := oc.Get(values.NewString(name)) + + if !ocExists { + return 1 + } + + c := cEl.Compare(ocEl) + + if c != 0 { + return c + } + } + + return 0 +} + +func (c HTTPCookies) Unwrap() interface{} { + return map[string]HTTPCookie(c) +} + +func (c HTTPCookies) Hash() uint64 { + hash := fnv.New64a() + + hash.Write([]byte(c.Type().String())) + hash.Write([]byte(":")) + hash.Write([]byte("{")) + + keys := make([]string, 0, len(c)) + + for key := range c { + keys = append(keys, key) + } + + // order does not really matter + // but it will give us a consistent hash sum + sort.Strings(keys) + endIndex := len(keys) - 1 + + for idx, key := range keys { + hash.Write([]byte(key)) + hash.Write([]byte(":")) + + el := c[key] + + bytes := make([]byte, 8) + binary.LittleEndian.PutUint64(bytes, el.Hash()) + + hash.Write(bytes) + + if idx != endIndex { + hash.Write([]byte(",")) + } + } + + hash.Write([]byte("}")) + + return hash.Sum64() +} + +func (c HTTPCookies) Copy() core.Value { + copied := make(HTTPCookies) + + for k, v := range c { + copied[k] = v + } + + return copied +} + +func (c HTTPCookies) Length() values.Int { + return values.NewInt(len(c)) +} + +func (c HTTPCookies) Keys() []values.String { + keys := make([]values.String, 0, len(c)) + + for k := range c { + keys = append(keys, values.NewString(k)) + } + + return keys +} + +func (c HTTPCookies) Get(key values.String) (core.Value, values.Boolean) { + value, found := c[key.String()] + + if found { + return value, values.True + } + + return values.None, values.False +} + +func (c HTTPCookies) Set(key values.String, value core.Value) { + if cookie, ok := value.(HTTPCookie); ok { + c[key.String()] = cookie + } +} + +func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { + if len(path) == 0 { + return values.None, nil + } + + segment := path[0] + + err := core.ValidateType(segment, types.String) + + if err != nil { + return values.None, err + } + + cookie, found := c[segment.String()] + + if found { + if len(path) == 1 { + return cookie, nil + } + + return values.GetIn(ctx, cookie, path[1:]) + } + + return values.None, nil +} diff --git a/pkg/drivers/http/document.go b/pkg/drivers/http/document.go index beb2ad37..1a98dd38 100644 --- a/pkg/drivers/http/document.go +++ b/pkg/drivers/http/document.go @@ -203,8 +203,8 @@ func (doc *HTMLDocument) GetName() values.String { return "" } -func (doc *HTMLDocument) GetParentDocument() drivers.HTMLDocument { - return doc.parent +func (doc *HTMLDocument) GetParentDocument(_ context.Context) (drivers.HTMLDocument, error) { + return doc.parent, nil } func (doc *HTMLDocument) ScrollTop(_ context.Context) error { diff --git a/pkg/drivers/http/page.go b/pkg/drivers/http/page.go index 58806b30..ba92a116 100644 --- a/pkg/drivers/http/page.go +++ b/pkg/drivers/http/page.go @@ -168,29 +168,25 @@ func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, er return p.frames.Get(idx), nil } -func (p *HTMLPage) GetCookies(_ context.Context) (*values.Array, error) { - if p.cookies == nil { - return values.NewArray(0), nil +func (p *HTMLPage) GetCookies(_ context.Context) (drivers.HTTPCookies, error) { + res := make(drivers.HTTPCookies) + + for n, v := range p.cookies { + res[n] = v } - arr := values.NewArray(len(p.cookies)) - - for _, c := range p.cookies { - arr.Push(c) - } - - return arr, nil + return res, nil } func (p *HTMLPage) GetResponse(_ context.Context) (*drivers.HTTPResponse, error) { return p.response, nil } -func (p *HTMLPage) SetCookies(_ context.Context, _ ...drivers.HTTPCookie) error { +func (p *HTMLPage) SetCookies(_ context.Context, _ drivers.HTTPCookies) error { return core.ErrNotSupported } -func (p *HTMLPage) DeleteCookies(_ context.Context, _ ...drivers.HTTPCookie) error { +func (p *HTMLPage) DeleteCookies(_ context.Context, _ drivers.HTTPCookies) error { return core.ErrNotSupported } @@ -202,7 +198,7 @@ func (p *HTMLPage) CaptureScreenshot(_ context.Context, _ drivers.ScreenshotPara return nil, core.ErrNotSupported } -func (p *HTMLPage) WaitForNavigation(_ context.Context) error { +func (p *HTMLPage) WaitForNavigation(_ context.Context, _ values.String) error { return core.ErrNotSupported } diff --git a/pkg/drivers/type.go b/pkg/drivers/type.go index aa282d95..d2cdcdae 100644 --- a/pkg/drivers/type.go +++ b/pkg/drivers/type.go @@ -6,6 +6,7 @@ var ( HTTPResponseType = core.NewType("HTTPResponse") HTTPHeaderType = core.NewType("HTTPHeaders") HTTPCookieType = core.NewType("HTTPCookie") + HTTPCookiesType = core.NewType("HTTPCookies") HTMLElementType = core.NewType("HTMLElement") HTMLDocumentType = core.NewType("HTMLDocument") HTMLPageType = core.NewType("HTMLPageType") @@ -15,9 +16,10 @@ var ( var typeComparisonTable = map[core.Type]uint64{ HTTPHeaderType: 0, HTTPCookieType: 1, - HTMLElementType: 2, - HTMLDocumentType: 3, - HTMLPageType: 4, + HTTPCookiesType: 2, + HTMLElementType: 3, + HTMLDocumentType: 4, + HTMLPageType: 5, } func Compare(first, second core.Type) int64 { diff --git a/pkg/drivers/value.go b/pkg/drivers/value.go index a5a1cb48..26bd7c34 100644 --- a/pkg/drivers/value.go +++ b/pkg/drivers/value.go @@ -143,7 +143,7 @@ type ( GetName() values.String - GetParentDocument() HTMLDocument + GetParentDocument(ctx context.Context) (HTMLDocument, error) GetChildDocuments(ctx context.Context) (*values.Array, error) @@ -192,25 +192,25 @@ type ( GetFrame(ctx context.Context, idx values.Int) (core.Value, error) - GetCookies(ctx context.Context) (*values.Array, error) + GetCookies(ctx context.Context) (HTTPCookies, error) - SetCookies(ctx context.Context, cookies ...HTTPCookie) error + SetCookies(ctx context.Context, cookies HTTPCookies) error - DeleteCookies(ctx context.Context, cookies ...HTTPCookie) error + DeleteCookies(ctx context.Context, cookies HTTPCookies) error + + GetResponse(ctx context.Context) (*HTTPResponse, error) PrintToPDF(ctx context.Context, params PDFParams) (values.Binary, error) CaptureScreenshot(ctx context.Context, params ScreenshotParams) (values.Binary, error) - WaitForNavigation(ctx context.Context) error + WaitForNavigation(ctx context.Context, targetURL values.String) error Navigate(ctx context.Context, url values.String) error NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error) NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error) - - GetResponse(ctx context.Context) (*HTTPResponse, error) } ) diff --git a/pkg/runtime/core/errors.go b/pkg/runtime/core/errors.go index 92cf0456..2c732e5d 100644 --- a/pkg/runtime/core/errors.go +++ b/pkg/runtime/core/errors.go @@ -61,7 +61,9 @@ func Errors(err ...error) error { message := "" for _, e := range err { - message += ": " + e.Error() + if e != nil { + message += ": " + e.Error() + } } return errors.New(message) diff --git a/pkg/runtime/program.go b/pkg/runtime/program.go index dd4391bd..d1efa725 100644 --- a/pkg/runtime/program.go +++ b/pkg/runtime/program.go @@ -5,10 +5,11 @@ import ( "runtime" "strings" + "github.com/pkg/errors" + "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/logging" "github.com/MontFerret/ferret/pkg/runtime/values" - "github.com/pkg/errors" ) type Program struct { diff --git a/pkg/stdlib/html/cookie_del.go b/pkg/stdlib/html/cookie_del.go index fb9563b8..ab489ffa 100644 --- a/pkg/stdlib/html/cookie_del.go +++ b/pkg/stdlib/html/cookie_del.go @@ -26,8 +26,8 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) { } inputs := args[1:] - var currentCookies *values.Array - cookies := make([]drivers.HTTPCookie, 0, len(inputs)) + var currentCookies drivers.HTTPCookies + cookies := make(drivers.HTTPCookies) for _, c := range inputs { switch cookie := c.(type) { @@ -42,23 +42,18 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) { currentCookies = current } - found, isFound := currentCookies.Find(func(value core.Value, _ int) bool { - cv := value.(drivers.HTTPCookie) - - return cv.Name == cookie.String() - }) + found, isFound := currentCookies[cookie.String()] if isFound { - cookies = append(cookies, found.(drivers.HTTPCookie)) + cookies[cookie.String()] = found } case drivers.HTTPCookie: - cookies = append(cookies, cookie) - + cookies[cookie.Name] = cookie default: return values.None, core.TypeError(c.Type(), types.String, drivers.HTTPCookieType) } } - return values.None, page.DeleteCookies(ctx, cookies...) + return values.None, page.DeleteCookies(ctx, cookies) } diff --git a/pkg/stdlib/html/cookie_get.go b/pkg/stdlib/html/cookie_get.go index abad5daf..60603011 100644 --- a/pkg/stdlib/html/cookie_get.go +++ b/pkg/stdlib/html/cookie_get.go @@ -39,15 +39,11 @@ func CookieGet(ctx context.Context, args ...core.Value) (core.Value, error) { return values.None, err } - found, _ := cookies.Find(func(value core.Value, _ int) bool { - cookie, ok := value.(drivers.HTTPCookie) + cookie, found := cookies[name.String()] - if !ok { - return ok - } + if found { + return cookie, nil + } - return cookie.Name == name.String() - }) - - return found, nil + return values.None, nil } diff --git a/pkg/stdlib/html/cookie_set.go b/pkg/stdlib/html/cookie_set.go index 8f4db730..3c17c09c 100644 --- a/pkg/stdlib/html/cookie_set.go +++ b/pkg/stdlib/html/cookie_set.go @@ -24,7 +24,7 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) { return values.None, err } - cookies := make([]drivers.HTTPCookie, 0, len(args)-1) + cookies := make(drivers.HTTPCookies) for _, c := range args[1:] { cookie, err := parseCookie(c) @@ -33,8 +33,8 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) { return values.None, err } - cookies = append(cookies, cookie) + cookies[cookie.Name] = cookie } - return values.None, page.SetCookies(ctx, cookies...) + return values.None, page.SetCookies(ctx, cookies) } diff --git a/pkg/stdlib/html/wait_navigation.go b/pkg/stdlib/html/wait_navigation.go index 9c2e2c50..aa6424f3 100644 --- a/pkg/stdlib/html/wait_navigation.go +++ b/pkg/stdlib/html/wait_navigation.go @@ -2,6 +2,7 @@ package html import ( "context" + "github.com/pkg/errors" "github.com/MontFerret/ferret/pkg/drivers" "github.com/MontFerret/ferret/pkg/runtime/core" @@ -9,6 +10,11 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/values/types" ) +type WaitNavigationParams struct { + TargetURL values.String + Timeout values.Int +} + // WAIT_NAVIGATION waits for a given page to navigate to a new url. // Stops the execution until the navigation ends or operation times out. // @param page (HTMLPage) - Target page. @@ -26,20 +32,67 @@ func WaitNavigation(ctx context.Context, args ...core.Value) (core.Value, error) return values.None, err } - timeout := values.NewInt(drivers.DefaultWaitTimeout) + var params WaitNavigationParams if len(args) > 1 { - err = core.ValidateType(args[1], types.Int) + p, err := parseWaitNavigationParams(args[1]) if err != nil { return values.None, err } - timeout = args[1].(values.Int) + params = p + } else { + params = defaultWaitNavigationParams() } - ctx, fn := waitTimeout(ctx, timeout) + ctx, fn := waitTimeout(ctx, params.Timeout) defer fn() - return values.None, doc.WaitForNavigation(ctx) + return values.None, doc.WaitForNavigation(ctx, params.TargetURL) +} + +func parseWaitNavigationParams(arg core.Value) (WaitNavigationParams, error) { + params := defaultWaitNavigationParams() + err := core.ValidateType(arg, types.Int, types.Object) + + if err != nil { + return params, err + } + + if arg.Type() == types.Int { + params.Timeout = arg.(values.Int) + + } else { + obj := arg.(*values.Object) + + if v, exists := obj.Get("timeout"); exists { + err := core.ValidateType(v, types.Int) + + if err != nil { + return params, errors.Wrap(err, "navigation parameters: timeout") + } + + params.Timeout = v.(values.Int) + } + + if v, exists := obj.Get("target"); exists { + err := core.ValidateType(v, types.String) + + if err != nil { + return params, errors.Wrap(err, "navigation parameters: url") + } + + params.TargetURL = v.(values.String) + } + } + + return params, nil +} + +func defaultWaitNavigationParams() WaitNavigationParams { + return WaitNavigationParams{ + TargetURL: "", + Timeout: values.NewInt(drivers.DefaultWaitTimeout), + } }