1
0
mirror of https://github.com/MontFerret/ferret.git synced 2025-04-17 12:06:17 +02:00

Bugfix/#399 navigation (#432)

* Refactored networking

* Some work

* Added event loop

* Renamed EventHandler to Handler

* wip

* Removed console logs

* Added DOMManager

* Refactored frame managment

* Fixes

* Fixed concurrency issues

* Fixed unit tests

* Improved EventLoop api

* Some fixes

* Refactored event loop.

* Improved logic of initial page load

* Cleaned up

* Fixed linting issues

* Fixed dom.Manager.Close

* SOme works

* Fixes

* Removed fmt.Println statements

* Refactored WaitForNavigation

* Removed filter for e2e tests

* Made Cookies Measurable

* Made Cookies KeyedCollection

* Fixes after code review

* Updated e2e tests for iframes

* Fixed iframe lookup in e2e tests

* Added comments
This commit is contained in:
Tim Voronov 2019-12-24 18:47:21 -05:00 committed by GitHub
parent 98b367722b
commit fe7b45df6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 2829 additions and 1705 deletions

View File

@ -58,8 +58,8 @@ jobs:
- stage: e2e - stage: e2e
go: stable go: stable
before_script: before_script:
- docker pull microbox/chromium-headless:75.0.3765.1 - docker pull microbox/chromium-headless:77.0.3844.0
- docker run -d -p 9222:9222 microbox/chromium-headless:75.0.3765.1 - docker run -d -p 9222:9222 microbox/chromium-headless:77.0.3844.0
- docker ps - docker ps
script: script:
- make e2e - make e2e

View File

@ -1,7 +1,14 @@
LET url = @dynamic + "?redirect=/iframe" LET url = @dynamic + "?redirect=/iframe"
LET page = DOCUMENT(url, { driver: 'cdp' }) LET page = DOCUMENT(url, { driver: 'cdp' })
LET doc = page.frames[1] LET frames = (
FOR f IN page.frames
FILTER f.name == "nested"
LIMIT 1
RETURN f
)
LET doc = FIRST(frames)
LET expectedP = TRUE LET expectedP = TRUE
LET actualP = ELEMENT_EXISTS(doc, '.text-center') LET actualP = ELEMENT_EXISTS(doc, '.text-center')

View File

@ -1,6 +1,12 @@
LET url = @dynamic + "?redirect=/iframe&src=/events" LET url = @dynamic + "?redirect=/iframe&src=/events"
LET page = DOCUMENT(url, { driver: 'cdp' }) LET page = DOCUMENT(url, { driver: 'cdp' })
LET doc = page.frames[1] LET frames = (
FOR f IN page.frames
FILTER f.name == "nested"
LIMIT 1
RETURN f
)
LET doc = FIRST(frames)
WAIT_ELEMENT(doc, "#page-events") WAIT_ELEMENT(doc, "#page-events")

View File

@ -1,6 +1,13 @@
LET url = @dynamic + "?redirect=/iframe&src=/forms" LET url = @dynamic + "?redirect=/iframe&src=/forms"
LET page = DOCUMENT(url, true) LET page = DOCUMENT(url, true)
LET doc = page.frames[1]
LET frames = (
FOR f IN page.frames
FILTER f.name == "nested"
LIMIT 1
RETURN f
)
LET doc = FIRST(frames)
WAIT_ELEMENT(doc, "form") WAIT_ELEMENT(doc, "form")

View File

@ -1,6 +1,12 @@
LET url = @dynamic + "?redirect=/iframe&src=/events" LET url = @dynamic + "?redirect=/iframe&src=/events"
LET page = DOCUMENT(url, true) LET page = DOCUMENT(url, true)
LET doc = page.frames[1] LET frames = (
FOR f IN page.frames
FILTER f.name == "nested"
LIMIT 1
RETURN f
)
LET doc = FIRST(frames)
WAIT_ELEMENT(doc, "#page-events") WAIT_ELEMENT(doc, "#page-events")

13
examples/redirects.fql Normal file
View File

@ -0,0 +1,13 @@
LET doc = DOCUMENT("http://waos.ovh/redirect.html", {
driver: 'cdp',
viewport: {
width: 1920,
height: 1080
}
})
CLICK(doc, '.click')
WAIT_NAVIGATION(doc)
RETURN ELEMENT(doc, '.title')

View File

@ -1 +0,0 @@
package cdp

View File

@ -1,4 +1,4 @@
package cdp package dom
import ( import (
"context" "context"
@ -23,22 +23,20 @@ import (
) )
type HTMLDocument struct { type HTMLDocument struct {
logger *zerolog.Logger logger *zerolog.Logger
client *cdp.Client client *cdp.Client
events *events.EventBroker dom *Manager
input *input.Manager input *input.Manager
exec *eval.ExecutionContext exec *eval.ExecutionContext
frames page.FrameTree frameTree page.FrameTree
element *HTMLElement element *HTMLElement
parent *HTMLDocument
children *common.LazyValue
} }
func LoadRootHTMLDocument( func LoadRootHTMLDocument(
ctx context.Context, ctx context.Context,
logger *zerolog.Logger, logger *zerolog.Logger,
client *cdp.Client, client *cdp.Client,
events *events.EventBroker, domManager *Manager,
mouse *input.Mouse, mouse *input.Mouse,
keyboard *input.Keyboard, keyboard *input.Keyboard,
) (*HTMLDocument, error) { ) (*HTMLDocument, error) {
@ -64,13 +62,12 @@ func LoadRootHTMLDocument(
ctx, ctx,
logger, logger,
client, client,
events, domManager,
mouse, mouse,
keyboard, keyboard,
gdRepl.Root, gdRepl.Root,
ftRepl.FrameTree, ftRepl.FrameTree,
worldRepl.ExecutionContextID, worldRepl.ExecutionContextID,
nil,
) )
} }
@ -78,22 +75,21 @@ func LoadHTMLDocument(
ctx context.Context, ctx context.Context,
logger *zerolog.Logger, logger *zerolog.Logger,
client *cdp.Client, client *cdp.Client,
events *events.EventBroker, domManager *Manager,
mouse *input.Mouse, mouse *input.Mouse,
keyboard *input.Keyboard, keyboard *input.Keyboard,
node dom.Node, node dom.Node,
tree page.FrameTree, frameTree page.FrameTree,
execID runtime.ExecutionContextID, execID runtime.ExecutionContextID,
parent *HTMLDocument,
) (*HTMLDocument, error) { ) (*HTMLDocument, error) {
exec := eval.NewExecutionContext(client, tree.Frame, execID) exec := eval.NewExecutionContext(client, frameTree.Frame, execID)
inputManager := input.NewManager(client, exec, keyboard, mouse) inputManager := input.NewManager(client, exec, keyboard, mouse)
rootElement, err := LoadHTMLElement( rootElement, err := LoadHTMLElement(
ctx, ctx,
logger, logger,
client, client,
events, domManager,
inputManager, inputManager,
exec, exec,
node.NodeID, node.NodeID,
@ -106,35 +102,31 @@ func LoadHTMLDocument(
return NewHTMLDocument( return NewHTMLDocument(
logger, logger,
client, client,
events, domManager,
inputManager, inputManager,
exec, exec,
rootElement, rootElement,
tree, frameTree,
parent,
), nil ), nil
} }
func NewHTMLDocument( func NewHTMLDocument(
logger *zerolog.Logger, logger *zerolog.Logger,
client *cdp.Client, client *cdp.Client,
events *events.EventBroker, domManager *Manager,
input *input.Manager, input *input.Manager,
exec *eval.ExecutionContext, exec *eval.ExecutionContext,
rootElement *HTMLElement, rootElement *HTMLElement,
frames page.FrameTree, frames page.FrameTree,
parent *HTMLDocument,
) *HTMLDocument { ) *HTMLDocument {
doc := new(HTMLDocument) doc := new(HTMLDocument)
doc.logger = logger doc.logger = logger
doc.client = client doc.client = client
doc.events = events doc.dom = domManager
doc.input = input doc.input = input
doc.exec = exec doc.exec = exec
doc.element = rootElement doc.element = rootElement
doc.frames = frames doc.frameTree = frames
doc.parent = parent
doc.children = common.NewLazyValue(doc.loadChildren)
return doc return doc
} }
@ -148,7 +140,7 @@ func (doc *HTMLDocument) Type() core.Type {
} }
func (doc *HTMLDocument) String() string { func (doc *HTMLDocument) String() string {
return doc.frames.Frame.URL return doc.frameTree.Frame.URL
} }
func (doc *HTMLDocument) Unwrap() interface{} { func (doc *HTMLDocument) Unwrap() interface{} {
@ -160,8 +152,8 @@ func (doc *HTMLDocument) Hash() uint64 {
h.Write([]byte(doc.Type().String())) h.Write([]byte(doc.Type().String()))
h.Write([]byte(":")) h.Write([]byte(":"))
h.Write([]byte(doc.frames.Frame.ID)) h.Write([]byte(doc.frameTree.Frame.ID))
h.Write([]byte(doc.frames.Frame.URL)) h.Write([]byte(doc.frameTree.Frame.URL))
return h.Sum64() return h.Sum64()
} }
@ -175,7 +167,7 @@ func (doc *HTMLDocument) Compare(other core.Value) int64 {
case drivers.HTMLDocumentType: case drivers.HTMLDocumentType:
other := other.(drivers.HTMLDocument) other := other.(drivers.HTMLDocument)
return values.NewString(doc.frames.Frame.URL).Compare(other.GetURL()) return values.NewString(doc.frameTree.Frame.URL).Compare(other.GetURL())
default: default:
return drivers.Compare(doc.Type(), other.Type()) return drivers.Compare(doc.Type(), other.Type())
} }
@ -194,41 +186,11 @@ func (doc *HTMLDocument) SetIn(ctx context.Context, path []core.Value, value cor
} }
func (doc *HTMLDocument) Close() error { func (doc *HTMLDocument) Close() error {
errs := make([]error, 0, 5) return doc.element.Close()
}
if doc.children.Ready() { func (doc *HTMLDocument) Frame() page.FrameTree {
val, err := doc.children.Read(context.Background()) return doc.frameTree
if err == nil {
arr := val.(*values.Array)
arr.ForEach(func(value core.Value, _ int) bool {
doc := value.(drivers.HTMLDocument)
err := doc.Close()
if err != nil {
errs = append(errs, errors.Wrapf(err, "failed to close nested document: %s", doc.GetURL()))
}
return true
})
} else {
errs = append(errs, err)
}
}
err := doc.element.Close()
if err != nil {
errs = append(errs, err)
}
if len(errs) == 0 {
return nil
}
return core.Errors(errs...)
} }
func (doc *HTMLDocument) IsDetached() values.Boolean { func (doc *HTMLDocument) IsDetached() values.Boolean {
@ -280,25 +242,37 @@ func (doc *HTMLDocument) GetTitle() values.String {
} }
func (doc *HTMLDocument) GetName() values.String { func (doc *HTMLDocument) GetName() values.String {
if doc.frames.Frame.Name != nil { if doc.frameTree.Frame.Name != nil {
return values.NewString(*doc.frames.Frame.Name) return values.NewString(*doc.frameTree.Frame.Name)
} }
return values.EmptyString return values.EmptyString
} }
func (doc *HTMLDocument) GetParentDocument() drivers.HTMLDocument { func (doc *HTMLDocument) GetParentDocument(ctx context.Context) (drivers.HTMLDocument, error) {
return doc.parent if doc.frameTree.Frame.ParentID == nil {
return nil, nil
}
return doc.dom.GetFrameNode(ctx, *doc.frameTree.Frame.ParentID)
} }
func (doc *HTMLDocument) GetChildDocuments(ctx context.Context) (*values.Array, error) { func (doc *HTMLDocument) GetChildDocuments(ctx context.Context) (*values.Array, error) {
children, err := doc.children.Read(ctx) arr := values.NewArray(len(doc.frameTree.ChildFrames))
if err != nil { for _, childFrame := range doc.frameTree.ChildFrames {
return values.NewArray(0), errors.Wrap(err, "failed to load child documents") frame, err := doc.dom.GetFrameNode(ctx, childFrame.Frame.ID)
if err != nil {
return nil, err
}
if frame != nil {
arr.Push(frame)
}
} }
return children.Copy().(*values.Array), nil return arr, nil
} }
func (doc *HTMLDocument) XPath(ctx context.Context, expression values.String) (core.Value, error) { func (doc *HTMLDocument) XPath(ctx context.Context, expression values.String) (core.Value, error) {
@ -314,7 +288,7 @@ func (doc *HTMLDocument) GetElement() drivers.HTMLElement {
} }
func (doc *HTMLDocument) GetURL() values.String { func (doc *HTMLDocument) GetURL() values.String {
return values.NewString(doc.frames.Frame.URL) return values.NewString(doc.frameTree.Frame.URL)
} }
func (doc *HTMLDocument) MoveMouseByXY(ctx context.Context, x, y values.Float) error { func (doc *HTMLDocument) MoveMouseByXY(ctx context.Context, x, y values.Float) error {
@ -484,48 +458,13 @@ func (doc *HTMLDocument) ScrollByXY(ctx context.Context, x, y values.Float) erro
return doc.input.ScrollByXY(ctx, float64(x), float64(y)) return doc.input.ScrollByXY(ctx, float64(x), float64(y))
} }
func (doc *HTMLDocument) loadChildren(ctx context.Context) (value core.Value, e error) {
children := values.NewArray(len(doc.frames.ChildFrames))
if len(doc.frames.ChildFrames) > 0 {
for _, cf := range doc.frames.ChildFrames {
cfNode, cfExecID, err := resolveFrame(ctx, doc.client, cf.Frame)
if err != nil {
return nil, errors.Wrap(err, "failed to resolve frame node")
}
cfDocument, err := LoadHTMLDocument(
ctx,
doc.logger,
doc.client,
doc.events,
doc.input.Mouse(),
doc.input.Keyboard(),
cfNode,
cf,
cfExecID,
doc,
)
if err != nil {
return nil, errors.Wrap(err, "failed to load frame document")
}
children.Push(cfDocument)
}
}
return children, nil
}
func (doc *HTMLDocument) logError(err error) *zerolog.Event { func (doc *HTMLDocument) logError(err error) *zerolog.Event {
return doc.logger. return doc.logger.
Error(). Error().
Timestamp(). Timestamp().
Str("url", string(doc.frames.Frame.URL)). Str("url", doc.frameTree.Frame.URL).
Str("securityOrigin", string(doc.frames.Frame.SecurityOrigin)). Str("securityOrigin", doc.frameTree.Frame.SecurityOrigin).
Str("mimeType", string(doc.frames.Frame.MimeType)). Str("mimeType", doc.frameTree.Frame.MimeType).
Str("frameID", string(doc.frames.Frame.ID)). Str("frameID", string(doc.frameTree.Frame.ID)).
Err(err) Err(err)
} }

View File

@ -0,0 +1 @@
package dom

View File

@ -1,4 +1,4 @@
package cdp package dom
import ( import (
"context" "context"
@ -36,11 +36,20 @@ type (
ObjectID runtime.RemoteObjectID ObjectID runtime.RemoteObjectID
} }
elementListeners struct {
pageReload events.ListenerID
attrModified events.ListenerID
attrRemoved events.ListenerID
childNodeCountUpdated events.ListenerID
childNodeInserted events.ListenerID
childNodeRemoved events.ListenerID
}
HTMLElement struct { HTMLElement struct {
mu sync.Mutex mu sync.Mutex
logger *zerolog.Logger logger *zerolog.Logger
client *cdp.Client client *cdp.Client
events *events.EventBroker dom *Manager
input *input.Manager input *input.Manager
exec *eval.ExecutionContext exec *eval.ExecutionContext
connected values.Boolean connected values.Boolean
@ -53,6 +62,7 @@ type (
style *common.LazyValue style *common.LazyValue
children []HTMLElementIdentity children []HTMLElementIdentity
loadedChildren *common.LazyValue loadedChildren *common.LazyValue
listeners *elementListeners
} }
) )
@ -60,7 +70,7 @@ func LoadHTMLElement(
ctx context.Context, ctx context.Context,
logger *zerolog.Logger, logger *zerolog.Logger,
client *cdp.Client, client *cdp.Client,
broker *events.EventBroker, domManager *Manager,
input *input.Manager, input *input.Manager,
exec *eval.ExecutionContext, exec *eval.ExecutionContext,
nodeID dom.NodeID, nodeID dom.NodeID,
@ -86,7 +96,7 @@ func LoadHTMLElement(
ctx, ctx,
logger, logger,
client, client,
broker, domManager,
input, input,
exec, exec,
HTMLElementIdentity{ HTMLElementIdentity{
@ -100,7 +110,7 @@ func LoadHTMLElementWithID(
ctx context.Context, ctx context.Context,
logger *zerolog.Logger, logger *zerolog.Logger,
client *cdp.Client, client *cdp.Client,
broker *events.EventBroker, domManager *Manager,
input *input.Manager, input *input.Manager,
exec *eval.ExecutionContext, exec *eval.ExecutionContext,
id HTMLElementIdentity, id HTMLElementIdentity,
@ -120,7 +130,7 @@ func LoadHTMLElementWithID(
return NewHTMLElement( return NewHTMLElement(
logger, logger,
client, client,
broker, domManager,
input, input,
exec, exec,
id, id,
@ -133,7 +143,7 @@ func LoadHTMLElementWithID(
func NewHTMLElement( func NewHTMLElement(
logger *zerolog.Logger, logger *zerolog.Logger,
client *cdp.Client, client *cdp.Client,
broker *events.EventBroker, domManager *Manager,
input *input.Manager, input *input.Manager,
exec *eval.ExecutionContext, exec *eval.ExecutionContext,
id HTMLElementIdentity, id HTMLElementIdentity,
@ -144,7 +154,7 @@ func NewHTMLElement(
el := new(HTMLElement) el := new(HTMLElement)
el.logger = logger el.logger = logger
el.client = client el.client = client
el.events = broker el.dom = domManager
el.input = input el.input = input
el.exec = exec el.exec = exec
el.connected = values.True el.connected = values.True
@ -157,13 +167,14 @@ func NewHTMLElement(
el.style = common.NewLazyValue(el.parseStyle) el.style = common.NewLazyValue(el.parseStyle)
el.loadedChildren = common.NewLazyValue(el.loadChildren) el.loadedChildren = common.NewLazyValue(el.loadChildren)
el.children = children el.children = children
el.listeners = &elementListeners{
broker.AddEventListener(events.EventReload, el.handlePageReload) pageReload: domManager.AddDocumentUpdatedListener(el.handlePageReload),
broker.AddEventListener(events.EventAttrModified, el.handleAttrModified) attrModified: domManager.AddAttrModifiedListener(el.handleAttrModified),
broker.AddEventListener(events.EventAttrRemoved, el.handleAttrRemoved) attrRemoved: domManager.AddAttrRemovedListener(el.handleAttrRemoved),
broker.AddEventListener(events.EventChildNodeCountUpdated, el.handleChildrenCountChanged) childNodeCountUpdated: domManager.AddChildNodeCountUpdatedListener(el.handleChildrenCountChanged),
broker.AddEventListener(events.EventChildNodeInserted, el.handleChildInserted) childNodeInserted: domManager.AddChildNodeInsertedListener(el.handleChildInserted),
broker.AddEventListener(events.EventChildNodeRemoved, el.handleChildRemoved) childNodeRemoved: domManager.AddChildNodeRemovedListener(el.handleChildRemoved),
}
return el return el
} }
@ -178,12 +189,13 @@ func (el *HTMLElement) Close() error {
} }
el.connected = values.False el.connected = values.False
el.events.RemoveEventListener(events.EventReload, el.handlePageReload)
el.events.RemoveEventListener(events.EventAttrModified, el.handleAttrModified) el.dom.RemoveReloadListener(el.listeners.pageReload)
el.events.RemoveEventListener(events.EventAttrRemoved, el.handleAttrRemoved) el.dom.RemoveAttrModifiedListener(el.listeners.attrModified)
el.events.RemoveEventListener(events.EventChildNodeCountUpdated, el.handleChildrenCountChanged) el.dom.RemoveAttrRemovedListener(el.listeners.attrRemoved)
el.events.RemoveEventListener(events.EventChildNodeInserted, el.handleChildInserted) el.dom.RemoveChildNodeCountUpdatedListener(el.listeners.childNodeCountUpdated)
el.events.RemoveEventListener(events.EventChildNodeRemoved, el.handleChildRemoved) el.dom.RemoveChildNodeInsertedListener(el.listeners.childNodeInserted)
el.dom.RemoveChildNodeRemovedListener(el.listeners.childNodeRemoved)
return nil return nil
} }
@ -472,7 +484,15 @@ func (el *HTMLElement) QuerySelector(ctx context.Context, selector values.String
return values.None, nil return values.None, nil
} }
res, err := LoadHTMLElement(ctx, el.logger, el.client, el.events, el.input, el.exec, found.NodeID) res, err := LoadHTMLElement(
ctx,
el.logger,
el.client,
el.dom,
el.input,
el.exec,
found.NodeID,
)
if err != nil { if err != nil {
return values.None, nil return values.None, nil
@ -504,7 +524,15 @@ func (el *HTMLElement) QuerySelectorAll(ctx context.Context, selector values.Str
continue continue
} }
childEl, err := LoadHTMLElement(ctx, el.logger, el.client, el.events, el.input, el.exec, id) childEl, err := LoadHTMLElement(
ctx,
el.logger,
el.client,
el.dom,
el.input,
el.exec,
id,
)
if err != nil { if err != nil {
// close elements that are already loaded, but won't be used because of the error // close elements that are already loaded, but won't be used because of the error
@ -609,7 +637,7 @@ func (el *HTMLElement) XPath(ctx context.Context, expression values.String) (res
ctx, ctx,
el.logger, el.logger,
el.client, el.client,
el.events, el.dom,
el.input, el.input,
el.exec, el.exec,
HTMLElementIdentity{ HTMLElementIdentity{
@ -641,7 +669,7 @@ func (el *HTMLElement) XPath(ctx context.Context, expression values.String) (res
ctx, ctx,
el.logger, el.logger,
el.client, el.client,
el.events, el.dom,
el.input, el.input,
el.exec, el.exec,
HTMLElementIdentity{ HTMLElementIdentity{
@ -1155,7 +1183,7 @@ func (el *HTMLElement) loadChildren(ctx context.Context) (core.Value, error) {
ctx, ctx,
el.logger, el.logger,
el.client, el.client,
el.events, el.dom,
el.input, el.input,
el.exec, el.exec,
childID.NodeID, childID.NodeID,
@ -1191,20 +1219,13 @@ func (el *HTMLElement) parseStyle(ctx context.Context) (core.Value, error) {
return common.DeserializeStyles(value.(values.String)) return common.DeserializeStyles(value.(values.String))
} }
func (el *HTMLElement) handlePageReload(_ context.Context, _ interface{}) { func (el *HTMLElement) handlePageReload(_ context.Context) {
el.Close() el.Close()
} }
func (el *HTMLElement) handleAttrModified(ctx context.Context, message interface{}) { func (el *HTMLElement) handleAttrModified(ctx context.Context, nodeID dom.NodeID, name, value string) {
reply, ok := message.(*dom.AttributeModifiedReply)
// well....
if !ok {
return
}
// it's not for this el // it's not for this el
if reply.NodeID != el.id.NodeID { if nodeID != el.id.NodeID {
return return
} }
@ -1225,7 +1246,7 @@ func (el *HTMLElement) handleAttrModified(ctx context.Context, message interface
return return
} }
if reply.Name == "style" { if name == "style" {
el.style.Reset() el.style.Reset()
} }
@ -1235,20 +1256,13 @@ func (el *HTMLElement) handleAttrModified(ctx context.Context, message interface
return return
} }
attrs.Set(values.NewString(reply.Name), values.NewString(reply.Value)) attrs.Set(values.NewString(name), values.NewString(value))
}) })
} }
func (el *HTMLElement) handleAttrRemoved(ctx context.Context, message interface{}) { func (el *HTMLElement) handleAttrRemoved(ctx context.Context, nodeID dom.NodeID, name string) {
reply, ok := message.(*dom.AttributeRemovedReply)
// well....
if !ok {
return
}
// it's not for this el // it's not for this el
if reply.NodeID != el.id.NodeID { if nodeID != el.id.NodeID {
return return
} }
@ -1269,7 +1283,7 @@ func (el *HTMLElement) handleAttrRemoved(ctx context.Context, message interface{
return return
} }
if reply.Name == "style" { if name == "style" {
el.style.Reset() el.style.Reset()
} }
@ -1279,18 +1293,12 @@ func (el *HTMLElement) handleAttrRemoved(ctx context.Context, message interface{
return return
} }
attrs.Remove(values.NewString(reply.Name)) attrs.Remove(values.NewString(name))
}) })
} }
func (el *HTMLElement) handleChildrenCountChanged(ctx context.Context, message interface{}) { func (el *HTMLElement) handleChildrenCountChanged(ctx context.Context, nodeID dom.NodeID, _ int) {
reply, ok := message.(*dom.ChildNodeCountUpdatedReply) if nodeID != el.id.NodeID {
if !ok {
return
}
if reply.NodeID != el.id.NodeID {
return return
} }
@ -1315,20 +1323,14 @@ func (el *HTMLElement) handleChildrenCountChanged(ctx context.Context, message i
el.children = createChildrenArray(node.Node.Children) el.children = createChildrenArray(node.Node.Children)
} }
func (el *HTMLElement) handleChildInserted(ctx context.Context, message interface{}) { func (el *HTMLElement) handleChildInserted(ctx context.Context, parentNodeID, prevNodeID dom.NodeID, node dom.Node) {
reply, ok := message.(*dom.ChildNodeInsertedReply) if parentNodeID != el.id.NodeID {
if !ok {
return
}
if reply.ParentNodeID != el.id.NodeID {
return return
} }
targetIDx := -1 targetIDx := -1
prevID := reply.PreviousNodeID prevID := prevNodeID
nextID := reply.Node.NodeID nextID := node.NodeID
if el.IsDetached() { if el.IsDetached() {
return return
@ -1349,7 +1351,7 @@ func (el *HTMLElement) handleChildInserted(ctx context.Context, message interfac
} }
nextIdentity := HTMLElementIdentity{ nextIdentity := HTMLElementIdentity{
NodeID: reply.Node.NodeID, NodeID: nextID,
} }
arr := el.children arr := el.children
@ -1361,7 +1363,15 @@ func (el *HTMLElement) handleChildInserted(ctx context.Context, message interfac
el.loadedChildren.Mutate(ctx, func(v core.Value, _ error) { el.loadedChildren.Mutate(ctx, func(v core.Value, _ error) {
loadedArr := v.(*values.Array) loadedArr := v.(*values.Array)
loadedEl, err := LoadHTMLElement(ctx, el.logger, el.client, el.events, el.input, el.exec, nextID) loadedEl, err := LoadHTMLElement(
ctx,
el.logger,
el.client,
el.dom,
el.input,
el.exec,
nextID,
)
if err != nil { if err != nil {
el.logError(err).Msg("failed to load an inserted element") el.logError(err).Msg("failed to load an inserted element")
@ -1376,19 +1386,13 @@ func (el *HTMLElement) handleChildInserted(ctx context.Context, message interfac
}) })
} }
func (el *HTMLElement) handleChildRemoved(ctx context.Context, message interface{}) { func (el *HTMLElement) handleChildRemoved(ctx context.Context, nodeID, prevNodeID dom.NodeID) {
reply, ok := message.(*dom.ChildNodeRemovedReply) if nodeID != el.id.NodeID {
if !ok {
return
}
if reply.ParentNodeID != el.id.NodeID {
return return
} }
targetIDx := -1 targetIDx := -1
targetID := reply.NodeID targetID := prevNodeID
if el.IsDetached() { if el.IsDetached() {
return return

View File

@ -0,0 +1,287 @@
package dom
import (
"bytes"
"context"
"encoding/json"
"errors"
"golang.org/x/net/html"
"strings"
"time"
"github.com/MontFerret/ferret/pkg/drivers/cdp/eval"
"github.com/MontFerret/ferret/pkg/drivers/cdp/templates"
"github.com/MontFerret/ferret/pkg/drivers/common"
"github.com/MontFerret/ferret/pkg/runtime/values"
"github.com/PuerkitoBio/goquery"
"github.com/mafredri/cdp"
"github.com/mafredri/cdp/protocol/dom"
"github.com/mafredri/cdp/protocol/page"
"github.com/mafredri/cdp/protocol/runtime"
)
var emptyExpires = time.Time{}
// parseAttrs is a helper function that parses a given interleaved array of node attribute names and values,
// and returns an object that represents attribute keys and values.
func parseAttrs(attrs []string) *values.Object {
var attr values.String
res := values.NewObject()
for _, el := range attrs {
el = strings.TrimSpace(el)
str := values.NewString(el)
if common.IsAttribute(el) {
attr = str
res.Set(str, values.EmptyString)
} else {
current, ok := res.Get(attr)
if ok {
if current.String() != "" {
res.Set(attr, current.(values.String).Concat(values.SpaceString).Concat(str))
} else {
res.Set(attr, str)
}
}
}
}
return res
}
func setInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerHTML values.String) error {
var objID *runtime.RemoteObjectID
if id.ObjectID != "" {
objID = &id.ObjectID
} else {
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID))
if err != nil {
return err
}
if repl.Object.ObjectID == nil {
return errors.New("unable to resolve node")
}
objID = repl.Object.ObjectID
}
b, err := json.Marshal(innerHTML.String())
if err != nil {
return err
}
err = exec.EvalWithArguments(ctx, templates.SetInnerHTML(),
runtime.CallArgument{
ObjectID: objID,
},
runtime.CallArgument{
Value: json.RawMessage(b),
},
)
return err
}
func getInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) {
// not a document
if nodeType != html.DocumentNode {
var objID runtime.RemoteObjectID
if id.ObjectID != "" {
objID = id.ObjectID
} else {
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID))
if err != nil {
return "", err
}
if repl.Object.ObjectID == nil {
return "", errors.New("unable to resolve node")
}
objID = *repl.Object.ObjectID
}
res, err := exec.ReadProperty(ctx, objID, "innerHTML")
if err != nil {
return "", err
}
return values.NewString(res.String()), nil
}
repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerHTML")
if err != nil {
return "", err
}
return values.NewString(repl.String()), nil
}
func setInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerText values.String) error {
var objID *runtime.RemoteObjectID
if id.ObjectID != "" {
objID = &id.ObjectID
} else {
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID))
if err != nil {
return err
}
if repl.Object.ObjectID == nil {
return errors.New("unable to resolve node")
}
objID = repl.Object.ObjectID
}
b, err := json.Marshal(innerText.String())
if err != nil {
return err
}
err = exec.EvalWithArguments(ctx, templates.SetInnerText(),
runtime.CallArgument{
ObjectID: objID,
},
runtime.CallArgument{
Value: json.RawMessage(b),
},
)
return err
}
func getInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) {
// not a document
if nodeType != html.DocumentNode {
var objID runtime.RemoteObjectID
if id.ObjectID != "" {
objID = id.ObjectID
} else {
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID))
if err != nil {
return "", err
}
if repl.Object.ObjectID == nil {
return "", errors.New("unable to resolve node")
}
objID = *repl.Object.ObjectID
}
res, err := exec.ReadProperty(ctx, objID, "innerText")
if err != nil {
return "", err
}
return values.NewString(res.String()), err
}
repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerText")
if err != nil {
return "", err
}
return values.NewString(repl.String()), nil
}
func parseInnerText(innerHTML string) (values.String, error) {
buff := bytes.NewBuffer([]byte(innerHTML))
parsed, err := goquery.NewDocumentFromReader(buff)
if err != nil {
return values.EmptyString, err
}
return values.NewString(parsed.Text()), nil
}
func createChildrenArray(nodes []dom.Node) []HTMLElementIdentity {
children := make([]HTMLElementIdentity, len(nodes))
for idx, child := range nodes {
child := child
children[idx] = HTMLElementIdentity{
NodeID: child.NodeID,
}
}
return children
}
func resolveFrame(ctx context.Context, client *cdp.Client, frameID page.FrameID) (dom.Node, runtime.ExecutionContextID, error) {
worldRepl, err := client.Page.CreateIsolatedWorld(ctx, page.NewCreateIsolatedWorldArgs(frameID))
if err != nil {
return dom.Node{}, -1, err
}
evalRes, err := client.Runtime.Evaluate(
ctx,
runtime.NewEvaluateArgs(eval.PrepareEval("return document")).
SetContextID(worldRepl.ExecutionContextID),
)
if err != nil {
return dom.Node{}, -1, err
}
if evalRes.ExceptionDetails != nil {
exception := *evalRes.ExceptionDetails
return dom.Node{}, -1, errors.New(exception.Text)
}
if evalRes.Result.ObjectID == nil {
return dom.Node{}, -1, errors.New("failed to resolve frame document")
}
req, err := client.DOM.RequestNode(ctx, dom.NewRequestNodeArgs(*evalRes.Result.ObjectID))
if err != nil {
return dom.Node{}, -1, err
}
if req.NodeID == 0 {
return dom.Node{}, -1, errors.New("framed document is resolved with empty node id")
}
desc, err := client.DOM.DescribeNode(
ctx,
dom.
NewDescribeNodeArgs().
SetNodeID(req.NodeID).
SetDepth(1),
)
if err != nil {
return dom.Node{}, -1, err
}
// Returned node, by some reason, does not contain the NodeID
// So, we have to set it manually
desc.Node.NodeID = req.NodeID
return desc.Node, worldRepl.ExecutionContextID, nil
}

View File

@ -0,0 +1,511 @@
package dom
import (
"context"
"github.com/MontFerret/ferret/pkg/drivers/cdp/input"
"github.com/MontFerret/ferret/pkg/drivers/common"
"github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/values"
"github.com/mafredri/cdp/protocol/page"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"io"
"sync"
"github.com/mafredri/cdp"
"github.com/mafredri/cdp/protocol/dom"
"github.com/mafredri/cdp/rpcc"
"github.com/MontFerret/ferret/pkg/drivers/cdp/events"
)
var (
eventDocumentUpdated = events.New("doc_updated")
eventAttrModified = events.New("attr_modified")
eventAttrRemoved = events.New("attr_removed")
eventChildNodeCountUpdated = events.New("child_count_updated")
eventChildNodeInserted = events.New("child_inserted")
eventChildNodeRemoved = events.New("child_removed")
)
type (
DocumentUpdatedListener func(ctx context.Context)
AttrModifiedListener func(ctx context.Context, nodeID dom.NodeID, name, value string)
AttrRemovedListener func(ctx context.Context, nodeID dom.NodeID, name string)
ChildNodeCountUpdatedListener func(ctx context.Context, nodeID dom.NodeID, count int)
ChildNodeInsertedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID, node dom.Node)
ChildNodeRemovedListener func(ctx context.Context, nodeID, previousNodeID dom.NodeID)
Frame struct {
tree page.FrameTree
node *HTMLDocument
ready bool
}
Manager struct {
mu sync.Mutex
logger *zerolog.Logger
client *cdp.Client
events *events.Loop
mouse *input.Mouse
keyboard *input.Keyboard
mainFrame page.FrameID
frames map[page.FrameID]Frame
cancel context.CancelFunc
}
)
// a dirty workaround to let pass the vet test
func createContext() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
}
func New(
logger *zerolog.Logger,
client *cdp.Client,
eventLoop *events.Loop,
mouse *input.Mouse,
keyboard *input.Keyboard,
) (manager *Manager, err error) {
ctx, cancel := createContext()
closers := make([]io.Closer, 0, 10)
defer func() {
if err != nil {
common.CloseAll(logger, closers, "failed to close a DOM event stream")
}
}()
onContentReady, err := client.Page.DOMContentEventFired(ctx)
if err != nil {
return nil, err
}
closers = append(closers, onContentReady)
onDocUpdated, err := client.DOM.DocumentUpdated(ctx)
if err != nil {
return nil, err
}
closers = append(closers, onDocUpdated)
onAttrModified, err := client.DOM.AttributeModified(ctx)
if err != nil {
return nil, err
}
closers = append(closers, onAttrModified)
onAttrRemoved, err := client.DOM.AttributeRemoved(ctx)
if err != nil {
return nil, err
}
closers = append(closers, onAttrRemoved)
onChildCountUpdated, err := client.DOM.ChildNodeCountUpdated(ctx)
if err != nil {
return nil, err
}
closers = append(closers, onChildCountUpdated)
onChildNodeInserted, err := client.DOM.ChildNodeInserted(ctx)
if err != nil {
return nil, err
}
closers = append(closers, onChildNodeInserted)
onChildNodeRemoved, err := client.DOM.ChildNodeRemoved(ctx)
if err != nil {
return nil, err
}
closers = append(closers, onChildNodeRemoved)
eventLoop.AddSource(events.NewSource(eventDocumentUpdated, onDocUpdated, func(stream rpcc.Stream) (i interface{}, e error) {
return stream.(dom.DocumentUpdatedClient).Recv()
}))
eventLoop.AddSource(events.NewSource(eventAttrModified, onAttrModified, func(stream rpcc.Stream) (i interface{}, e error) {
return stream.(dom.AttributeModifiedClient).Recv()
}))
eventLoop.AddSource(events.NewSource(eventAttrRemoved, onAttrRemoved, func(stream rpcc.Stream) (i interface{}, e error) {
return stream.(dom.AttributeRemovedClient).Recv()
}))
eventLoop.AddSource(events.NewSource(eventChildNodeCountUpdated, onChildCountUpdated, func(stream rpcc.Stream) (i interface{}, e error) {
return stream.(dom.ChildNodeCountUpdatedClient).Recv()
}))
eventLoop.AddSource(events.NewSource(eventChildNodeInserted, onChildNodeInserted, func(stream rpcc.Stream) (i interface{}, e error) {
return stream.(dom.ChildNodeInsertedClient).Recv()
}))
eventLoop.AddSource(events.NewSource(eventChildNodeRemoved, onChildNodeRemoved, func(stream rpcc.Stream) (i interface{}, e error) {
return stream.(dom.ChildNodeRemovedClient).Recv()
}))
manager = new(Manager)
manager.logger = logger
manager.client = client
manager.events = eventLoop
manager.mouse = mouse
manager.keyboard = keyboard
manager.frames = make(map[page.FrameID]Frame)
manager.cancel = cancel
return manager, nil
}
func (m *Manager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
m.cancel()
m.cancel = nil
}
errs := make([]error, 0, len(m.frames))
for _, f := range m.frames {
// if initialized
if f.node != nil {
if err := f.node.Close(); err != nil {
errs = append(errs, err)
}
}
}
if len(errs) > 0 {
return core.Errors(errs...)
}
return nil
}
func (m *Manager) GetMainFrame() *HTMLDocument {
m.mu.Lock()
defer m.mu.Unlock()
if m.mainFrame == "" {
return nil
}
mainFrame, exists := m.frames[m.mainFrame]
if exists {
return mainFrame.node
}
return nil
}
func (m *Manager) SetMainFrame(doc *HTMLDocument) {
m.mu.Lock()
defer m.mu.Unlock()
if m.mainFrame != "" {
if err := m.removeFrameRecursivelyInternal(m.mainFrame); err != nil {
m.logger.Error().Err(err).Msg("failed to close previous main frame")
}
}
m.mainFrame = doc.frameTree.Frame.ID
m.addPreloadedFrame(doc)
}
func (m *Manager) AddFrame(frame page.FrameTree) {
m.mu.Lock()
defer m.mu.Unlock()
m.addFrameInternal(frame)
}
func (m *Manager) RemoveFrame(frameID page.FrameID) error {
m.mu.Lock()
defer m.mu.Unlock()
return m.removeFrameInternal(frameID)
}
func (m *Manager) RemoveFrameRecursively(frameID page.FrameID) error {
m.mu.Lock()
defer m.mu.Unlock()
return m.removeFrameRecursivelyInternal(frameID)
}
func (m *Manager) RemoveFramesByParentID(parentFrameID page.FrameID) error {
m.mu.Lock()
defer m.mu.Unlock()
frame, found := m.frames[parentFrameID]
if !found {
return errors.New("frame not found")
}
for _, child := range frame.tree.ChildFrames {
if err := m.removeFrameRecursivelyInternal(child.Frame.ID); err != nil {
return err
}
}
return nil
}
func (m *Manager) GetFrameNode(ctx context.Context, frameID page.FrameID) (*HTMLDocument, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.getFrameInternal(ctx, frameID)
}
func (m *Manager) GetFrameTree(_ context.Context, frameID page.FrameID) (page.FrameTree, error) {
m.mu.Lock()
defer m.mu.Unlock()
frame, found := m.frames[frameID]
if !found {
return page.FrameTree{}, core.ErrNotFound
}
return frame.tree, nil
}
func (m *Manager) GetFrameNodes(ctx context.Context) (*values.Array, error) {
m.mu.Lock()
defer m.mu.Unlock()
arr := values.NewArray(len(m.frames))
for _, f := range m.frames {
doc, err := m.getFrameInternal(ctx, f.tree.Frame.ID)
if err != nil {
return nil, err
}
arr.Push(doc)
}
return arr, nil
}
func (m *Manager) AddDocumentUpdatedListener(listener DocumentUpdatedListener) events.ListenerID {
return m.events.AddListener(eventDocumentUpdated, func(ctx context.Context, _ interface{}) bool {
listener(ctx)
return true
})
}
func (m *Manager) RemoveReloadListener(listenerID events.ListenerID) {
m.events.RemoveListener(eventDocumentUpdated, listenerID)
}
func (m *Manager) AddAttrModifiedListener(listener AttrModifiedListener) events.ListenerID {
return m.events.AddListener(eventAttrModified, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.AttributeModifiedReply)
listener(ctx, reply.NodeID, reply.Name, reply.Value)
return true
})
}
func (m *Manager) RemoveAttrModifiedListener(listenerID events.ListenerID) {
m.events.RemoveListener(eventAttrModified, listenerID)
}
func (m *Manager) AddAttrRemovedListener(listener AttrRemovedListener) events.ListenerID {
return m.events.AddListener(eventAttrRemoved, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.AttributeRemovedReply)
listener(ctx, reply.NodeID, reply.Name)
return true
})
}
func (m *Manager) RemoveAttrRemovedListener(listenerID events.ListenerID) {
m.events.RemoveListener(eventAttrRemoved, listenerID)
}
func (m *Manager) AddChildNodeCountUpdatedListener(listener ChildNodeCountUpdatedListener) events.ListenerID {
return m.events.AddListener(eventChildNodeCountUpdated, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.ChildNodeCountUpdatedReply)
listener(ctx, reply.NodeID, reply.ChildNodeCount)
return true
})
}
func (m *Manager) RemoveChildNodeCountUpdatedListener(listenerID events.ListenerID) {
m.events.RemoveListener(eventChildNodeCountUpdated, listenerID)
}
func (m *Manager) AddChildNodeInsertedListener(listener ChildNodeInsertedListener) events.ListenerID {
return m.events.AddListener(eventChildNodeInserted, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.ChildNodeInsertedReply)
listener(ctx, reply.ParentNodeID, reply.PreviousNodeID, reply.Node)
return true
})
}
func (m *Manager) RemoveChildNodeInsertedListener(listenerID events.ListenerID) {
m.events.RemoveListener(eventChildNodeInserted, listenerID)
}
func (m *Manager) AddChildNodeRemovedListener(listener ChildNodeRemovedListener) events.ListenerID {
return m.events.AddListener(eventChildNodeRemoved, func(ctx context.Context, message interface{}) bool {
reply := message.(*dom.ChildNodeRemovedReply)
listener(ctx, reply.ParentNodeID, reply.NodeID)
return true
})
}
func (m *Manager) RemoveChildNodeRemovedListener(listenerID events.ListenerID) {
m.events.RemoveListener(eventChildNodeRemoved, listenerID)
}
func (m *Manager) WaitForDOMReady(ctx context.Context) error {
onContentReady, err := m.client.Page.DOMContentEventFired(ctx)
if err != nil {
return err
}
defer func() {
if err := onContentReady.Close(); err != nil {
m.logger.Error().Err(err).Msg("failed to close DOM content ready stream event")
}
}()
_, err = onContentReady.Recv()
return err
}
func (m *Manager) addFrameInternal(frame page.FrameTree) {
m.frames[frame.Frame.ID] = Frame{
tree: frame,
node: nil,
}
for _, child := range frame.ChildFrames {
m.addFrameInternal(child)
}
}
func (m *Manager) addPreloadedFrame(doc *HTMLDocument) {
m.frames[doc.frameTree.Frame.ID] = Frame{
tree: doc.frameTree,
node: doc,
}
for _, child := range doc.frameTree.ChildFrames {
m.addFrameInternal(child)
}
}
func (m *Manager) getFrameInternal(ctx context.Context, frameID page.FrameID) (*HTMLDocument, error) {
frame, found := m.frames[frameID]
if !found {
return nil, core.ErrNotFound
}
// frame is initialized
if frame.node != nil {
return frame.node, nil
}
// the frames is not loaded yet
node, execID, err := resolveFrame(ctx, m.client, frameID)
if err != nil {
return nil, errors.Wrap(err, "failed to resolve frame node")
}
doc, err := LoadHTMLDocument(
ctx,
m.logger,
m.client,
m,
m.mouse,
m.keyboard,
node,
frame.tree,
execID,
)
if err != nil {
return nil, errors.Wrap(err, "failed to load frame document")
}
frame.node = doc
return doc, nil
}
func (m *Manager) removeFrameInternal(frameID page.FrameID) error {
current, exists := m.frames[frameID]
if !exists {
return core.Error(core.ErrNotFound, "frame")
}
delete(m.frames, frameID)
if frameID == m.mainFrame {
m.mainFrame = ""
}
if current.node == nil {
return nil
}
return current.node.Close()
}
func (m *Manager) removeFrameRecursivelyInternal(frameID page.FrameID) error {
parent, exists := m.frames[frameID]
if !exists {
return core.Error(core.ErrNotFound, "frame")
}
for _, child := range parent.tree.ChildFrames {
if err := m.removeFrameRecursivelyInternal(child.Frame.ID); err != nil {
return err
}
}
return m.removeFrameInternal(frameID)
}

View File

@ -1,289 +0,0 @@
package events
import (
"context"
"reflect"
"sync"
"github.com/mafredri/cdp/protocol/dom"
"github.com/mafredri/cdp/protocol/page"
"github.com/MontFerret/ferret/pkg/runtime/core"
)
type (
Event int
EventListener func(ctx context.Context, message interface{})
EventBroker struct {
mu sync.Mutex
listeners map[Event][]EventListener
cancel context.CancelFunc
onLoad page.LoadEventFiredClient
onReload dom.DocumentUpdatedClient
onAttrModified dom.AttributeModifiedClient
onAttrRemoved dom.AttributeRemovedClient
onChildNodeCountUpdated dom.ChildNodeCountUpdatedClient
onChildNodeInserted dom.ChildNodeInsertedClient
onChildNodeRemoved dom.ChildNodeRemovedClient
}
)
const (
//revive:disable-next-line:var-declaration
EventError = Event(iota)
EventLoad
EventReload
EventAttrModified
EventAttrRemoved
EventChildNodeCountUpdated
EventChildNodeInserted
EventChildNodeRemoved
)
func NewEventBroker(
onLoad page.LoadEventFiredClient,
onReload dom.DocumentUpdatedClient,
onAttrModified dom.AttributeModifiedClient,
onAttrRemoved dom.AttributeRemovedClient,
onChildNodeCountUpdated dom.ChildNodeCountUpdatedClient,
onChildNodeInserted dom.ChildNodeInsertedClient,
onChildNodeRemoved dom.ChildNodeRemovedClient,
) *EventBroker {
broker := new(EventBroker)
broker.listeners = make(map[Event][]EventListener)
broker.onLoad = onLoad
broker.onReload = onReload
broker.onAttrModified = onAttrModified
broker.onAttrRemoved = onAttrRemoved
broker.onChildNodeCountUpdated = onChildNodeCountUpdated
broker.onChildNodeInserted = onChildNodeInserted
broker.onChildNodeRemoved = onChildNodeRemoved
return broker
}
func (broker *EventBroker) AddEventListener(event Event, listener EventListener) {
broker.mu.Lock()
defer broker.mu.Unlock()
listeners, ok := broker.listeners[event]
if !ok {
listeners = make([]EventListener, 0, 5)
}
broker.listeners[event] = append(listeners, listener)
}
func (broker *EventBroker) RemoveEventListener(event Event, listener EventListener) {
broker.mu.Lock()
defer broker.mu.Unlock()
idx := -1
listeners, ok := broker.listeners[event]
if !ok {
return
}
listenerPointer := reflect.ValueOf(listener).Pointer()
for i, l := range listeners {
itemPointer := reflect.ValueOf(l).Pointer()
if itemPointer == listenerPointer {
idx = i
break
}
}
if idx < 0 {
return
}
var modifiedListeners []EventListener
if len(listeners) > 1 {
modifiedListeners = append(listeners[:idx], listeners[idx+1:]...)
} else {
modifiedListeners = make([]EventListener, 0, 5)
}
broker.listeners[event] = modifiedListeners
}
func (broker *EventBroker) ListenerCount(event Event) int {
broker.mu.Lock()
defer broker.mu.Unlock()
listeners, ok := broker.listeners[event]
if !ok {
return 0
}
return len(listeners)
}
func (broker *EventBroker) Start() error {
broker.mu.Lock()
defer broker.mu.Unlock()
if broker.cancel != nil {
return core.Error(core.ErrInvalidOperation, "broker is already started")
}
ctx, cancel := context.WithCancel(context.Background())
broker.cancel = cancel
go broker.runLoop(ctx)
return nil
}
func (broker *EventBroker) Stop() error {
broker.mu.Lock()
defer broker.mu.Unlock()
if broker.cancel == nil {
return core.Error(core.ErrInvalidOperation, "broker is already stopped")
}
broker.cancel()
broker.cancel = nil
return nil
}
func (broker *EventBroker) Close() error {
broker.mu.Lock()
defer broker.mu.Unlock()
if broker.cancel != nil {
broker.cancel()
broker.cancel = nil
}
broker.onLoad.Close()
broker.onReload.Close()
broker.onAttrModified.Close()
broker.onAttrRemoved.Close()
broker.onChildNodeCountUpdated.Close()
broker.onChildNodeInserted.Close()
broker.onChildNodeRemoved.Close()
return nil
}
func (broker *EventBroker) StopAndClose() error {
err := broker.Stop()
if err != nil {
return err
}
return broker.Close()
}
func (broker *EventBroker) runLoop(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-broker.onLoad.Ready():
if ctxDone(ctx) {
return
}
reply, err := broker.onLoad.Recv()
broker.emit(ctx, EventLoad, reply, err)
case <-broker.onReload.Ready():
if ctxDone(ctx) {
return
}
reply, err := broker.onReload.Recv()
broker.emit(ctx, EventReload, reply, err)
case <-broker.onAttrModified.Ready():
if ctxDone(ctx) {
return
}
reply, err := broker.onAttrModified.Recv()
broker.emit(ctx, EventAttrModified, reply, err)
case <-broker.onAttrRemoved.Ready():
if ctxDone(ctx) {
return
}
reply, err := broker.onAttrRemoved.Recv()
broker.emit(ctx, EventAttrRemoved, reply, err)
case <-broker.onChildNodeCountUpdated.Ready():
if ctxDone(ctx) {
return
}
reply, err := broker.onChildNodeCountUpdated.Recv()
broker.emit(ctx, EventChildNodeCountUpdated, reply, err)
case <-broker.onChildNodeInserted.Ready():
if ctxDone(ctx) {
return
}
reply, err := broker.onChildNodeInserted.Recv()
broker.emit(ctx, EventChildNodeInserted, reply, err)
case <-broker.onChildNodeRemoved.Ready():
if ctxDone(ctx) {
return
}
reply, err := broker.onChildNodeRemoved.Recv()
broker.emit(ctx, EventChildNodeRemoved, reply, err)
}
}
}
func ctxDone(ctx context.Context) bool {
return ctx.Err() == context.Canceled
}
func (broker *EventBroker) emit(ctx context.Context, event Event, message interface{}, err error) {
if err != nil {
event = EventError
message = err
}
broker.mu.Lock()
listeners, ok := broker.listeners[event]
if !ok {
broker.mu.Unlock()
return
}
snapshot := make([]EventListener, len(listeners))
copy(snapshot, listeners)
broker.mu.Unlock()
for _, listener := range snapshot {
select {
case <-ctx.Done():
return
default:
listener(ctx, message)
}
}
}

View File

@ -1,322 +0,0 @@
package events_test
import (
"context"
"github.com/MontFerret/ferret/pkg/drivers/cdp/events"
"github.com/mafredri/cdp/protocol/dom"
"github.com/mafredri/cdp/protocol/page"
. "github.com/smartystreets/goconvey/convey"
"golang.org/x/sync/errgroup"
"sync/atomic"
"testing"
"time"
)
type (
TestEventStream struct {
ready chan struct{}
message chan interface{}
}
TestLoadEventFiredClient struct {
*TestEventStream
}
TestDocumentUpdatedClient struct {
*TestEventStream
}
TestAttributeModifiedClient struct {
*TestEventStream
}
TestAttributeRemovedClient struct {
*TestEventStream
}
TestChildNodeCountUpdatedClient struct {
*TestEventStream
}
TestChildNodeInsertedClient struct {
*TestEventStream
}
TestChildNodeRemovedClient struct {
*TestEventStream
}
TestBroker struct {
*events.EventBroker
OnLoad *TestLoadEventFiredClient
OnReload *TestDocumentUpdatedClient
OnAttrMod *TestAttributeModifiedClient
OnAttrRem *TestAttributeRemovedClient
OnChildNodeCount *TestChildNodeCountUpdatedClient
OnChildNodeIns *TestChildNodeInsertedClient
OnChildNodeRem *TestChildNodeRemovedClient
}
)
func NewTestEventStream() *TestEventStream {
es := new(TestEventStream)
es.ready = make(chan struct{})
es.message = make(chan interface{})
return es
}
func (es *TestEventStream) Ready() <-chan struct{} {
return es.ready
}
func (es *TestEventStream) RecvMsg(i interface{}) error {
// NOT IMPLEMENTED
return nil
}
func (es *TestEventStream) Close() error {
close(es.message)
close(es.ready)
return nil
}
func (es *TestEventStream) Emit(msg interface{}) {
es.ready <- struct{}{}
es.message <- msg
}
func (es *TestLoadEventFiredClient) Recv() (*page.LoadEventFiredReply, error) {
r := <-es.message
reply := r.(*page.LoadEventFiredReply)
return reply, nil
}
func (es *TestLoadEventFiredClient) EmitDefault() {
es.TestEventStream.Emit(&page.LoadEventFiredReply{})
}
func (es *TestDocumentUpdatedClient) Recv() (*dom.DocumentUpdatedReply, error) {
r := <-es.message
reply := r.(*dom.DocumentUpdatedReply)
return reply, nil
}
func (es *TestAttributeModifiedClient) Recv() (*dom.AttributeModifiedReply, error) {
r := <-es.message
reply := r.(*dom.AttributeModifiedReply)
return reply, nil
}
func (es *TestAttributeRemovedClient) Recv() (*dom.AttributeRemovedReply, error) {
r := <-es.message
reply := r.(*dom.AttributeRemovedReply)
return reply, nil
}
func (es *TestChildNodeCountUpdatedClient) Recv() (*dom.ChildNodeCountUpdatedReply, error) {
r := <-es.message
reply := r.(*dom.ChildNodeCountUpdatedReply)
return reply, nil
}
func (es *TestChildNodeInsertedClient) Recv() (*dom.ChildNodeInsertedReply, error) {
r := <-es.message
reply := r.(*dom.ChildNodeInsertedReply)
return reply, nil
}
func (es *TestChildNodeRemovedClient) Recv() (*dom.ChildNodeRemovedReply, error) {
r := <-es.message
reply := r.(*dom.ChildNodeRemovedReply)
return reply, nil
}
func NewTestEventBroker() *TestBroker {
onLoad := &TestLoadEventFiredClient{NewTestEventStream()}
onReload := &TestDocumentUpdatedClient{NewTestEventStream()}
onAttrMod := &TestAttributeModifiedClient{NewTestEventStream()}
onAttrRem := &TestAttributeRemovedClient{NewTestEventStream()}
onChildCount := &TestChildNodeCountUpdatedClient{NewTestEventStream()}
onChildIns := &TestChildNodeInsertedClient{NewTestEventStream()}
onChildRem := &TestChildNodeRemovedClient{NewTestEventStream()}
b := events.NewEventBroker(
onLoad,
onReload,
onAttrMod,
onAttrRem,
onChildCount,
onChildIns,
onChildRem,
)
return &TestBroker{
b,
onLoad,
onReload,
onAttrMod,
onAttrRem,
onChildCount,
onChildIns,
onChildRem,
}
}
func StressTest(h func() error, count int) error {
var err error
for i := 0; i < count; i++ {
err = h()
if err != nil {
return err
}
}
return nil
}
func StressTestAsync(h func() error, count int) error {
var gr errgroup.Group
for i := 0; i < count; i++ {
gr.Go(h)
}
return gr.Wait()
}
func TestEventBroker(t *testing.T) {
Convey(".AddEventListener", t, func() {
Convey("Should add a new listener when not started", func() {
b := NewTestEventBroker()
StressTest(func() error {
b.AddEventListener(events.EventLoad, func(ctx context.Context, message interface{}) {})
return nil
}, 500)
})
Convey("Should add a new listener when started", func() {
b := NewTestEventBroker()
b.Start()
defer b.Stop()
StressTest(func() error {
b.AddEventListener(events.EventLoad, func(ctx context.Context, message interface{}) {})
return nil
}, 500)
})
})
Convey(".RemoveEventListener", t, func() {
Convey("Should remove a listener when not started", func() {
b := NewTestEventBroker()
StressTest(func() error {
listener := func(ctx context.Context, message interface{}) {}
b.AddEventListener(events.EventLoad, listener)
b.RemoveEventListener(events.EventLoad, listener)
So(b.ListenerCount(events.EventLoad), ShouldEqual, 0)
return nil
}, 500)
})
Convey("Should add a new listener when started", func() {
b := NewTestEventBroker()
b.Start()
defer b.Stop()
StressTest(func() error {
listener := func(ctx context.Context, message interface{}) {}
b.AddEventListener(events.EventLoad, listener)
StressTestAsync(func() error {
b.OnLoad.EmitDefault()
return nil
}, 250)
b.RemoveEventListener(events.EventLoad, listener)
So(b.ListenerCount(events.EventLoad), ShouldEqual, 0)
return nil
}, 250)
})
Convey("Should not call listener once it was removed", func() {
b := NewTestEventBroker()
b.Start()
defer b.Stop()
counter := 0
var listener events.EventListener
listener = func(ctx context.Context, message interface{}) {
counter++
b.RemoveEventListener(events.EventLoad, listener)
}
b.AddEventListener(events.EventLoad, listener)
b.OnLoad.Emit(&page.LoadEventFiredReply{})
time.Sleep(time.Duration(10) * time.Millisecond)
StressTestAsync(func() error {
b.OnLoad.Emit(&page.LoadEventFiredReply{})
return nil
}, 250)
So(b.ListenerCount(events.EventLoad), ShouldEqual, 0)
So(counter, ShouldEqual, 1)
})
})
Convey(".Stop", t, func() {
Convey("Should stop emitting events", func() {
b := NewTestEventBroker()
b.Start()
var counter int64
b.AddEventListener(events.EventLoad, func(ctx context.Context, message interface{}) {
atomic.AddInt64(&counter, 1)
b.Stop()
})
b.OnLoad.EmitDefault()
time.Sleep(time.Duration(5) * time.Millisecond)
go func() {
b.OnLoad.EmitDefault()
}()
go func() {
b.OnLoad.EmitDefault()
}()
time.Sleep(time.Duration(5) * time.Millisecond)
So(atomic.LoadInt64(&counter), ShouldEqual, 1)
})
})
}

View File

@ -2,125 +2,17 @@ package events
import ( import (
"context" "context"
"hash/fnv"
"github.com/mafredri/cdp"
"github.com/mafredri/cdp/protocol/dom"
"github.com/mafredri/cdp/protocol/page"
"github.com/pkg/errors"
) )
func WaitForLoadEvent(ctx context.Context, client *cdp.Client) error { func New(name string) ID {
loadEventFired, err := client.Page.LoadEventFired(ctx) h := fnv.New32a()
if err != nil { h.Write([]byte(name))
return errors.Wrap(err, "failed to create load event hook")
}
_, err = loadEventFired.Recv() return ID(h.Sum32())
if err != nil {
return err
}
return loadEventFired.Close()
} }
func CreateEventBroker(client *cdp.Client) (*EventBroker, error) { func isCtxDone(ctx context.Context) bool {
var err error return ctx.Err() == context.Canceled
var onLoad page.LoadEventFiredClient
var onReload dom.DocumentUpdatedClient
var onAttrModified dom.AttributeModifiedClient
var onAttrRemoved dom.AttributeRemovedClient
var onChildCountUpdated dom.ChildNodeCountUpdatedClient
var onChildNodeInserted dom.ChildNodeInsertedClient
var onChildNodeRemoved dom.ChildNodeRemovedClient
ctx := context.Background()
onLoad, err = client.Page.LoadEventFired(ctx)
if err != nil {
return nil, err
}
onReload, err = client.DOM.DocumentUpdated(ctx)
if err != nil {
onLoad.Close()
return nil, err
}
onAttrModified, err = client.DOM.AttributeModified(ctx)
if err != nil {
onLoad.Close()
onReload.Close()
return nil, err
}
onAttrRemoved, err = client.DOM.AttributeRemoved(ctx)
if err != nil {
onLoad.Close()
onReload.Close()
onAttrModified.Close()
return nil, err
}
onChildCountUpdated, err = client.DOM.ChildNodeCountUpdated(ctx)
if err != nil {
onLoad.Close()
onReload.Close()
onAttrModified.Close()
onAttrRemoved.Close()
return nil, err
}
onChildNodeInserted, err = client.DOM.ChildNodeInserted(ctx)
if err != nil {
onLoad.Close()
onReload.Close()
onAttrModified.Close()
onAttrRemoved.Close()
onChildCountUpdated.Close()
return nil, err
}
onChildNodeRemoved, err = client.DOM.ChildNodeRemoved(ctx)
if err != nil {
onLoad.Close()
onReload.Close()
onAttrModified.Close()
onAttrRemoved.Close()
onChildCountUpdated.Close()
onChildNodeInserted.Close()
return nil, err
}
broker := NewEventBroker(
onLoad,
onReload,
onAttrModified,
onAttrRemoved,
onChildCountUpdated,
onChildNodeInserted,
onChildNodeRemoved,
)
err = broker.Start()
if err != nil {
onLoad.Close()
onReload.Close()
onAttrModified.Close()
onAttrRemoved.Close()
onChildCountUpdated.Close()
onChildNodeInserted.Close()
onChildNodeRemoved.Close()
return nil, err
}
return broker, nil
} }

View File

@ -0,0 +1,38 @@
package events
import "context"
type (
// Handler represents a function that is called when a particular event occurs
// Returned boolean value indicates whether the handler needs to be called again
// False value indicated that it needs to be removed and never called again
Handler func(ctx context.Context, message interface{}) bool
// ListenerID is an internal listener ID that can be used to unsubscribe from a particular event
ListenerID int
// Listener is an internal listener representation
Listener struct {
ID ListenerID
EventID ID
Handler Handler
}
)
// Always returns a handler wrapper that always gets executed by an event loop
func Always(fn func(ctx context.Context, message interface{})) Handler {
return func(ctx context.Context, message interface{}) bool {
fn(ctx, message)
return true
}
}
// Once returns a handler wrapper that gets executed only once by an event loop
func Once(fn func(ctx context.Context, message interface{})) Handler {
return func(ctx context.Context, message interface{}) bool {
fn(ctx, message)
return false
}
}

View File

@ -0,0 +1,74 @@
package events
import "sync"
type ListenerCollection struct {
mu sync.Mutex
values map[ID]map[ListenerID]Listener
}
func NewListenerCollection() *ListenerCollection {
lc := new(ListenerCollection)
lc.values = make(map[ID]map[ListenerID]Listener)
return lc
}
func (lc *ListenerCollection) Size(eventID ID) int {
lc.mu.Lock()
defer lc.mu.Unlock()
bucket, exists := lc.values[eventID]
if !exists {
return 0
}
return len(bucket)
}
func (lc *ListenerCollection) Add(listener Listener) {
lc.mu.Lock()
defer lc.mu.Unlock()
bucket, exists := lc.values[listener.EventID]
if !exists {
bucket = make(map[ListenerID]Listener)
lc.values[listener.EventID] = bucket
}
bucket[listener.ID] = listener
}
func (lc *ListenerCollection) Remove(eventID ID, listenerID ListenerID) {
lc.mu.Lock()
defer lc.mu.Unlock()
bucket, exists := lc.values[eventID]
if !exists {
return
}
delete(bucket, listenerID)
}
func (lc *ListenerCollection) Values(eventID ID) []Listener {
lc.mu.Lock()
defer lc.mu.Unlock()
bucket, exists := lc.values[eventID]
if !exists {
return []Listener{}
}
snapshot := make([]Listener, 0, len(bucket))
for _, listener := range bucket {
snapshot = append(snapshot, listener)
}
return snapshot
}

View File

@ -0,0 +1,165 @@
package events
import (
"context"
"math/rand"
"sync"
)
type Loop struct {
mu sync.Mutex
cancel context.CancelFunc
sources *SourceCollection
listeners *ListenerCollection
}
func NewLoop() *Loop {
loop := new(Loop)
loop.sources = NewSourceCollection()
loop.listeners = NewListenerCollection()
return loop
}
func (loop *Loop) Start() *Loop {
loop.mu.Lock()
defer loop.mu.Unlock()
if loop.cancel != nil {
return loop
}
loopCtx, cancel := context.WithCancel(context.Background())
loop.cancel = cancel
go loop.run(loopCtx)
return loop
}
func (loop *Loop) Stop() *Loop {
loop.mu.Lock()
defer loop.mu.Unlock()
if loop.cancel == nil {
return loop
}
loop.cancel()
loop.cancel = nil
return loop
}
func (loop *Loop) Close() error {
loop.mu.Lock()
defer loop.mu.Unlock()
if loop.cancel != nil {
loop.cancel()
loop.cancel = nil
}
return loop.sources.Close()
}
func (loop *Loop) AddSource(source Source) {
loop.sources.Add(source)
}
func (loop *Loop) RemoveSource(source Source) {
loop.sources.Remove(source)
}
func (loop *Loop) AddListener(eventID ID, handler Handler) ListenerID {
listener := Listener{
ID: ListenerID(rand.Int()),
EventID: eventID,
Handler: handler,
}
loop.listeners.Add(listener)
return listener.ID
}
func (loop *Loop) RemoveListener(eventID ID, listenerID ListenerID) {
loop.listeners.Remove(eventID, listenerID)
}
// run starts running an event loop.
// It constantly iterates over each event source.
// Additionally to that, on each iteration it checks the command channel in order to perform add/remove listener/source operations.
func (loop *Loop) run(ctx context.Context) {
size := loop.sources.Size()
counter := -1
// in case event array is empty
// we use this mock noop event source to simplify the logic
noop := newNoopSource()
for {
counter++
if counter >= size {
// reset the counter
size = loop.sources.Size()
counter = 0
}
var source Source
if size > 0 {
found, err := loop.sources.Get(counter)
if err == nil {
source = found
} else {
// might be removed
source = noop
// force to reset counter
counter = size
}
} else {
source = noop
}
// commands have higher priority
select {
case <-ctx.Done():
return
case <-source.Ready():
if isCtxDone(ctx) {
return
}
event, err := source.Recv()
loop.emit(ctx, event.ID, event.Data, err)
default:
continue
}
}
}
func (loop *Loop) emit(ctx context.Context, eventID ID, message interface{}, err error) {
if err != nil {
eventID = Error
message = err
}
snapshot := loop.listeners.Values(eventID)
for _, listener := range snapshot {
select {
case <-ctx.Done():
return
default:
// if returned false, it means the loops should call the handler anymore
if !listener.Handler(ctx, message) {
loop.listeners.Remove(eventID, listener.ID)
}
}
}
}

View File

@ -0,0 +1,426 @@
package events_test
import (
"context"
"github.com/MontFerret/ferret/pkg/drivers/cdp/events"
"github.com/mafredri/cdp/protocol/dom"
"github.com/mafredri/cdp/protocol/page"
"github.com/mafredri/cdp/rpcc"
. "github.com/smartystreets/goconvey/convey"
"sync"
"testing"
"time"
)
type (
TestEventStream struct {
ready chan struct{}
message chan interface{}
}
TestLoadEventFiredClient struct {
*TestEventStream
}
TestDocumentUpdatedClient struct {
*TestEventStream
}
TestAttributeModifiedClient struct {
*TestEventStream
}
TestAttributeRemovedClient struct {
*TestEventStream
}
TestChildNodeCountUpdatedClient struct {
*TestEventStream
}
TestChildNodeInsertedClient struct {
*TestEventStream
}
TestChildNodeRemovedClient struct {
*TestEventStream
}
)
var TestEvent = events.New("test_event")
func NewTestEventStream() *TestEventStream {
es := new(TestEventStream)
es.ready = make(chan struct{})
es.message = make(chan interface{})
return es
}
func (es *TestEventStream) Ready() <-chan struct{} {
return es.ready
}
func (es *TestEventStream) RecvMsg(i interface{}) error {
// NOT IMPLEMENTED
return nil
}
func (es *TestEventStream) Close() error {
close(es.message)
close(es.ready)
return nil
}
func (es *TestEventStream) Emit(msg interface{}) {
es.ready <- struct{}{}
es.message <- msg
}
func (es *TestLoadEventFiredClient) Recv() (*page.LoadEventFiredReply, error) {
r := <-es.message
reply := r.(*page.LoadEventFiredReply)
return reply, nil
}
func (es *TestLoadEventFiredClient) EmitDefault() {
es.TestEventStream.Emit(&page.LoadEventFiredReply{})
}
func (es *TestDocumentUpdatedClient) Recv() (*dom.DocumentUpdatedReply, error) {
r := <-es.message
reply := r.(*dom.DocumentUpdatedReply)
return reply, nil
}
func (es *TestAttributeModifiedClient) Recv() (*dom.AttributeModifiedReply, error) {
r := <-es.message
reply := r.(*dom.AttributeModifiedReply)
return reply, nil
}
func (es *TestAttributeRemovedClient) Recv() (*dom.AttributeRemovedReply, error) {
r := <-es.message
reply := r.(*dom.AttributeRemovedReply)
return reply, nil
}
func (es *TestChildNodeCountUpdatedClient) Recv() (*dom.ChildNodeCountUpdatedReply, error) {
r := <-es.message
reply := r.(*dom.ChildNodeCountUpdatedReply)
return reply, nil
}
func (es *TestChildNodeInsertedClient) Recv() (*dom.ChildNodeInsertedReply, error) {
r := <-es.message
reply := r.(*dom.ChildNodeInsertedReply)
return reply, nil
}
func (es *TestChildNodeRemovedClient) Recv() (*dom.ChildNodeRemovedReply, error) {
r := <-es.message
reply := r.(*dom.ChildNodeRemovedReply)
return reply, nil
}
func wait() {
time.Sleep(time.Duration(50) * time.Millisecond)
}
type Counter struct {
mu sync.Mutex
value int64
}
func NewCounter() *Counter {
return new(Counter)
}
func (c *Counter) Increase() *Counter {
c.mu.Lock()
defer c.mu.Unlock()
c.value++
return c
}
func (c *Counter) Decrease() *Counter {
c.mu.Lock()
defer c.mu.Unlock()
c.value--
return c
}
func (c *Counter) Value() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.value
}
func TestLoop(t *testing.T) {
Convey(".AddListener", t, func() {
Convey("Should add a new listener", func() {
loop := events.NewLoop()
counter := NewCounter()
onLoad := &TestLoadEventFiredClient{NewTestEventStream()}
src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) {
return onLoad.Recv()
})
loop.AddSource(src)
loop.Start()
defer loop.Stop()
onLoad.EmitDefault()
wait()
So(counter.Value(), ShouldEqual, 0)
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
counter.Increase()
}))
wait()
onLoad.EmitDefault()
wait()
So(counter.Value(), ShouldEqual, 1)
})
})
Convey(".RemoveListener", t, func() {
Convey("Should remove a listener", func() {
Convey("Should add a new listener", func() {
loop := events.NewLoop()
counter := NewCounter()
onLoad := &TestLoadEventFiredClient{NewTestEventStream()}
src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) {
return onLoad.Recv()
})
loop.AddSource(src)
id := loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
counter.Increase()
}))
loop.Start()
defer loop.Stop()
onLoad.EmitDefault()
wait()
So(counter.Value(), ShouldEqual, 1)
wait()
loop.RemoveListener(TestEvent, id)
wait()
onLoad.EmitDefault()
wait()
So(counter.Value(), ShouldEqual, 1)
})
})
})
Convey(".AddSource", t, func() {
Convey("Should add a new event source when not started", func() {
loop := events.NewLoop()
counter := NewCounter()
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
counter.Increase()
}))
loop.Start()
defer loop.Stop()
onLoad := &TestLoadEventFiredClient{NewTestEventStream()}
go func() {
onLoad.EmitDefault()
}()
wait()
So(counter.Value(), ShouldEqual, 0)
src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) {
return onLoad.Recv()
})
loop.AddSource(src)
wait()
So(counter.Value(), ShouldEqual, 1)
})
})
Convey(".RemoveSource", t, func() {
Convey("Should remove a source", func() {
loop := events.NewLoop()
counter := NewCounter()
loop.Start()
defer loop.Stop()
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
counter.Increase()
}))
onLoad := &TestLoadEventFiredClient{NewTestEventStream()}
src := events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) {
return onLoad.Recv()
})
loop.AddSource(src)
wait()
onLoad.EmitDefault()
wait()
So(counter.Value(), ShouldEqual, 1)
loop.RemoveSource(src)
wait()
go func() {
onLoad.EmitDefault()
}()
wait()
So(counter.Value(), ShouldEqual, 1)
})
})
Convey("Should not call listener once it was removed", t, func() {
loop := events.NewLoop()
onEvent := make(chan struct{})
counter := NewCounter()
id := loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
counter.Increase()
onEvent <- struct{}{}
}))
go func() {
<-onEvent
loop.RemoveListener(TestEvent, id)
}()
onLoad := &TestLoadEventFiredClient{NewTestEventStream()}
loop.AddSource(events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) {
return onLoad.Recv()
}))
loop.Start()
defer loop.Stop()
time.Sleep(time.Duration(100) * time.Millisecond)
onLoad.Emit(&page.LoadEventFiredReply{})
time.Sleep(time.Duration(10) * time.Millisecond)
So(counter.Value(), ShouldEqual, 1)
})
}
func BenchmarkLoop_AddListenerSync(b *testing.B) {
loop := events.NewLoop()
for n := 0; n < b.N; n++ {
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {}))
}
}
func BenchmarkLoop_AddListenerAsync(b *testing.B) {
loop := events.NewLoop()
loop.Start()
defer loop.Stop()
for n := 0; n < b.N; n++ {
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {}))
}
}
func BenchmarkLoop_AddListenerAsync2(b *testing.B) {
loop := events.NewLoop()
loop.Start()
defer loop.Stop()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {}))
}
})
}
func BenchmarkLoop_Start(b *testing.B) {
loop := events.NewLoop()
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
}))
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
}))
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
}))
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
}))
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
}))
loop.AddListener(TestEvent, events.Always(func(ctx context.Context, message interface{}) {
}))
onLoad := &TestLoadEventFiredClient{NewTestEventStream()}
loop.AddSource(events.NewSource(TestEvent, onLoad, func(_ rpcc.Stream) (i interface{}, e error) {
return onLoad.Recv()
}))
loop.Start()
defer loop.Stop()
for n := 0; n < b.N; n++ {
onLoad.Emit(&page.LoadEventFiredReply{})
}
}

View File

@ -0,0 +1,31 @@
package events
import (
"github.com/MontFerret/ferret/pkg/runtime/core"
)
type noopEvent struct {
c chan struct{}
}
func newNoopSource() Source {
return noopEvent{
c: make(chan struct{}),
}
}
func (n noopEvent) Ready() <-chan struct{} {
return n.c
}
func (n noopEvent) RecvMsg(_ interface{}) error {
return core.ErrNotSupported
}
func (n noopEvent) Close() error {
return nil
}
func (n noopEvent) Recv() (Event, error) {
return Event{}, core.ErrNotSupported
}

View File

@ -0,0 +1,74 @@
package events
import (
"github.com/mafredri/cdp/rpcc"
)
type (
// ID represents a unique event ID
ID int
// Event represents a system event that is returned from an event source
Event struct {
ID ID
Data interface{}
}
// Source represents a custom source of system events
Source interface {
rpcc.Stream
Recv() (Event, error)
}
// GenericSource represents a helper struct for generating custom event sources
GenericSource struct {
eventID ID
stream rpcc.Stream
recv func(stream rpcc.Stream) (interface{}, error)
}
)
var (
Error = New("error")
)
// NewSource create a new custom event source
// eventID - is a unique event ID
// stream - is a custom event stream
// recv - is a value conversion function
func NewSource(
eventID ID,
stream rpcc.Stream,
recv func(stream rpcc.Stream) (interface{}, error),
) Source {
return &GenericSource{eventID, stream, recv}
}
func (src *GenericSource) EventID() ID {
return src.eventID
}
func (src *GenericSource) Ready() <-chan struct{} {
return src.stream.Ready()
}
func (src *GenericSource) RecvMsg(m interface{}) error {
return src.stream.RecvMsg(m)
}
func (src *GenericSource) Close() error {
return src.stream.Close()
}
func (src *GenericSource) Recv() (Event, error) {
data, err := src.recv(src.stream)
if err != nil {
return Event{}, err
}
return Event{
ID: src.eventID,
Data: data,
}, nil
}

View File

@ -0,0 +1,82 @@
package events
import (
"github.com/MontFerret/ferret/pkg/runtime/core"
"sync"
)
type SourceCollection struct {
mu sync.Mutex
values []Source
}
func NewSourceCollection() *SourceCollection {
sc := new(SourceCollection)
sc.values = make([]Source, 0, 10)
return sc
}
func (sc *SourceCollection) Close() error {
sc.mu.Lock()
defer sc.mu.Unlock()
errs := make([]error, 0, len(sc.values))
for _, e := range sc.values {
if err := e.Close(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return core.Errors(errs...)
}
return nil
}
func (sc *SourceCollection) Size() int {
sc.mu.Lock()
defer sc.mu.Unlock()
return len(sc.values)
}
func (sc *SourceCollection) Get(idx int) (Source, error) {
sc.mu.Lock()
defer sc.mu.Unlock()
if len(sc.values) <= idx {
return nil, core.ErrNotFound
}
return sc.values[idx], nil
}
func (sc *SourceCollection) Add(source Source) {
sc.mu.Lock()
defer sc.mu.Unlock()
sc.values = append(sc.values, source)
}
func (sc *SourceCollection) Remove(source Source) bool {
sc.mu.Lock()
defer sc.mu.Unlock()
idx := -1
for i, current := range sc.values {
if current == source {
idx = i
break
}
}
if idx > -1 {
sc.values = append(sc.values[:idx], sc.values[idx+1:]...)
}
return idx > -1
}

View File

@ -1,31 +1,16 @@
package cdp package cdp
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"github.com/MontFerret/ferret/pkg/drivers/cdp/templates"
"golang.org/x/net/html"
"strings"
"time"
"github.com/MontFerret/ferret/pkg/drivers" "github.com/MontFerret/ferret/pkg/drivers"
"github.com/MontFerret/ferret/pkg/drivers/cdp/eval"
"github.com/MontFerret/ferret/pkg/drivers/common" "github.com/MontFerret/ferret/pkg/drivers/common"
"github.com/MontFerret/ferret/pkg/runtime/values"
"github.com/PuerkitoBio/goquery"
"github.com/mafredri/cdp" "github.com/mafredri/cdp"
"github.com/mafredri/cdp/protocol/dom" "github.com/mafredri/cdp/protocol/emulation"
"github.com/mafredri/cdp/protocol/network" "github.com/mafredri/cdp/protocol/network"
"github.com/mafredri/cdp/protocol/page" "github.com/mafredri/cdp/protocol/page"
"github.com/mafredri/cdp/protocol/runtime"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
var emptyExpires = time.Time{}
type ( type (
batchFunc = func() error batchFunc = func() error
) )
@ -40,336 +25,87 @@ func runBatch(funcs ...batchFunc) error {
return eg.Wait() return eg.Wait()
} }
func parseAttrs(attrs []string) *values.Object { func enableFeatures(ctx context.Context, client *cdp.Client, params drivers.Params) error {
var attr values.String if err := client.Page.Enable(ctx); err != nil {
res := values.NewObject()
for _, el := range attrs {
el = strings.TrimSpace(el)
str := values.NewString(el)
if common.IsAttribute(el) {
attr = str
res.Set(str, values.EmptyString)
} else {
current, ok := res.Get(attr)
if ok {
if current.String() != "" {
res.Set(attr, current.(values.String).Concat(values.SpaceString).Concat(str))
} else {
res.Set(attr, str)
}
}
}
}
return res
}
func setInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerHTML values.String) error {
var objID *runtime.RemoteObjectID
if id.ObjectID != "" {
objID = &id.ObjectID
} else {
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID))
if err != nil {
return err
}
if repl.Object.ObjectID == nil {
return errors.New("unable to resolve node")
}
objID = repl.Object.ObjectID
}
b, err := json.Marshal(innerHTML.String())
if err != nil {
return err return err
} }
err = exec.EvalWithArguments(ctx, templates.SetInnerHTML(), return runBatch(
runtime.CallArgument{ func() error {
ObjectID: objID, return client.Page.SetLifecycleEventsEnabled(
ctx,
page.NewSetLifecycleEventsEnabledArgs(true),
)
}, },
runtime.CallArgument{
Value: json.RawMessage(b), func() error {
return client.DOM.Enable(ctx)
}, },
)
return err func() error {
} return client.Runtime.Enable(ctx)
},
func getInnerHTML(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) { func() error {
// not a document ua := common.GetUserAgent(params.UserAgent)
if nodeType != html.DocumentNode {
var objID runtime.RemoteObjectID
if id.ObjectID != "" { //logger.
objID = id.ObjectID // Debug().
} else { // Timestamp().
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID)) // Str("user-agent", ua).
// Msg("using User-Agent")
if err != nil { // do not use custom user agent
return "", err if ua == "" {
return nil
} }
if repl.Object.ObjectID == nil { return client.Emulation.SetUserAgentOverride(
return "", errors.New("unable to resolve node") ctx,
emulation.NewSetUserAgentOverrideArgs(ua),
)
},
func() error {
return client.Network.Enable(ctx, network.NewEnableArgs())
},
func() error {
return client.Page.SetBypassCSP(ctx, page.NewSetBypassCSPArgs(true))
},
func() error {
if params.Viewport == nil {
return nil
} }
objID = *repl.Object.ObjectID orientation := emulation.ScreenOrientation{}
}
res, err := exec.ReadProperty(ctx, objID, "innerHTML") if !params.Viewport.Landscape {
orientation.Type = "portraitPrimary"
orientation.Angle = 0
} else {
orientation.Type = "landscapePrimary"
orientation.Angle = 90
}
if err != nil { scaleFactor := params.Viewport.ScaleFactor
return "", err
}
return values.NewString(res.String()), nil if scaleFactor <= 0 {
} scaleFactor = 1
}
repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerHTML") deviceArgs := emulation.NewSetDeviceMetricsOverrideArgs(
params.Viewport.Width,
params.Viewport.Height,
scaleFactor,
params.Viewport.Mobile,
).SetScreenOrientation(orientation)
if err != nil { return client.Emulation.SetDeviceMetricsOverride(
return "", err ctx,
} deviceArgs,
)
return values.NewString(repl.String()), nil
}
func setInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, innerText values.String) error {
var objID *runtime.RemoteObjectID
if id.ObjectID != "" {
objID = &id.ObjectID
} else {
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID))
if err != nil {
return err
}
if repl.Object.ObjectID == nil {
return errors.New("unable to resolve node")
}
objID = repl.Object.ObjectID
}
b, err := json.Marshal(innerText.String())
if err != nil {
return err
}
err = exec.EvalWithArguments(ctx, templates.SetInnerText(),
runtime.CallArgument{
ObjectID: objID,
},
runtime.CallArgument{
Value: json.RawMessage(b),
}, },
) )
return err
}
func getInnerText(ctx context.Context, client *cdp.Client, exec *eval.ExecutionContext, id HTMLElementIdentity, nodeType html.NodeType) (values.String, error) {
// not a document
if nodeType != html.DocumentNode {
var objID runtime.RemoteObjectID
if id.ObjectID != "" {
objID = id.ObjectID
} else {
repl, err := client.DOM.ResolveNode(ctx, dom.NewResolveNodeArgs().SetNodeID(id.NodeID))
if err != nil {
return "", err
}
if repl.Object.ObjectID == nil {
return "", errors.New("unable to resolve node")
}
objID = *repl.Object.ObjectID
}
res, err := exec.ReadProperty(ctx, objID, "innerText")
if err != nil {
return "", err
}
return values.NewString(res.String()), err
}
repl, err := exec.EvalWithReturnValue(ctx, "return document.documentElement.innerText")
if err != nil {
return "", err
}
return values.NewString(repl.String()), nil
}
func parseInnerText(innerHTML string) (values.String, error) {
buff := bytes.NewBuffer([]byte(innerHTML))
parsed, err := goquery.NewDocumentFromReader(buff)
if err != nil {
return values.EmptyString, err
}
return values.NewString(parsed.Text()), nil
}
func createChildrenArray(nodes []dom.Node) []HTMLElementIdentity {
children := make([]HTMLElementIdentity, len(nodes))
for idx, child := range nodes {
child := child
children[idx] = HTMLElementIdentity{
NodeID: child.NodeID,
}
}
return children
}
func fromDriverCookie(url string, cookie drivers.HTTPCookie) network.CookieParam {
sameSite := network.CookieSameSiteNotSet
switch cookie.SameSite {
case drivers.SameSiteLaxMode:
sameSite = network.CookieSameSiteLax
case drivers.SameSiteStrictMode:
sameSite = network.CookieSameSiteStrict
}
if cookie.Expires == emptyExpires {
cookie.Expires = time.Now().Add(time.Duration(24) + time.Hour)
}
normalizedURL := normalizeCookieURL(url)
return network.CookieParam{
URL: &normalizedURL,
Name: cookie.Name,
Value: cookie.Value,
Secure: &cookie.Secure,
Path: &cookie.Path,
Domain: &cookie.Domain,
HTTPOnly: &cookie.HTTPOnly,
SameSite: sameSite,
Expires: network.TimeSinceEpoch(cookie.Expires.Unix()),
}
}
func fromDriverCookieDelete(url string, cookie drivers.HTTPCookie) *network.DeleteCookiesArgs {
normalizedURL := normalizeCookieURL(url)
return &network.DeleteCookiesArgs{
URL: &normalizedURL,
Name: cookie.Name,
Path: &cookie.Path,
Domain: &cookie.Domain,
}
}
func toDriverCookie(c network.Cookie) drivers.HTTPCookie {
sameSite := drivers.SameSiteDefaultMode
switch c.SameSite {
case network.CookieSameSiteLax:
sameSite = drivers.SameSiteLaxMode
case network.CookieSameSiteStrict:
sameSite = drivers.SameSiteStrictMode
}
return drivers.HTTPCookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: time.Unix(int64(c.Expires), 0),
SameSite: sameSite,
Secure: c.Secure,
HTTPOnly: c.HTTPOnly,
}
}
func normalizeCookieURL(url string) string {
const httpPrefix = "http://"
const httpsPrefix = "https://"
if strings.HasPrefix(url, httpPrefix) || strings.HasPrefix(url, httpsPrefix) {
return url
}
return httpPrefix + url
}
func resolveFrame(ctx context.Context, client *cdp.Client, frame page.Frame) (dom.Node, runtime.ExecutionContextID, error) {
worldRepl, err := client.Page.CreateIsolatedWorld(ctx, page.NewCreateIsolatedWorldArgs(frame.ID))
if err != nil {
return dom.Node{}, -1, err
}
evalRes, err := client.Runtime.Evaluate(
ctx,
runtime.NewEvaluateArgs(eval.PrepareEval("return document")).
SetContextID(worldRepl.ExecutionContextID),
)
if err != nil {
return dom.Node{}, -1, err
}
if evalRes.ExceptionDetails != nil {
exception := *evalRes.ExceptionDetails
return dom.Node{}, -1, errors.New(exception.Text)
}
if evalRes.Result.ObjectID == nil {
return dom.Node{}, -1, errors.New("failed to resolve frame document")
}
req, err := client.DOM.RequestNode(ctx, dom.NewRequestNodeArgs(*evalRes.Result.ObjectID))
if err != nil {
return dom.Node{}, -1, err
}
if req.NodeID == 0 {
return dom.Node{}, -1, errors.New("framed document is resolved with empty node id")
}
desc, err := client.DOM.DescribeNode(
ctx,
dom.
NewDescribeNodeArgs().
SetNodeID(req.NodeID).
SetDepth(1),
)
if err != nil {
return dom.Node{}, -1, err
}
// Returned node, by some reason, does not contain the NodeID
// So, we have to set it manually
desc.Node.NodeID = req.NodeID
return desc.Node, worldRepl.ExecutionContextID, nil
} }

View File

@ -0,0 +1,7 @@
package network
import "github.com/MontFerret/ferret/pkg/drivers/cdp/events"
var (
eventFrameLoad = events.New("frame_load")
)

View File

@ -0,0 +1,83 @@
package network
import (
"github.com/MontFerret/ferret/pkg/drivers"
"github.com/mafredri/cdp/protocol/network"
"strings"
"time"
)
var emptyExpires = time.Time{}
func fromDriverCookie(url string, cookie drivers.HTTPCookie) network.CookieParam {
sameSite := network.CookieSameSiteNotSet
switch cookie.SameSite {
case drivers.SameSiteLaxMode:
sameSite = network.CookieSameSiteLax
case drivers.SameSiteStrictMode:
sameSite = network.CookieSameSiteStrict
}
if cookie.Expires == emptyExpires {
cookie.Expires = time.Now().Add(time.Duration(24) + time.Hour)
}
normalizedURL := normalizeCookieURL(url)
return network.CookieParam{
URL: &normalizedURL,
Name: cookie.Name,
Value: cookie.Value,
Secure: &cookie.Secure,
Path: &cookie.Path,
Domain: &cookie.Domain,
HTTPOnly: &cookie.HTTPOnly,
SameSite: sameSite,
Expires: network.TimeSinceEpoch(cookie.Expires.Unix()),
}
}
func fromDriverCookieDelete(url string, cookie drivers.HTTPCookie) *network.DeleteCookiesArgs {
normalizedURL := normalizeCookieURL(url)
return &network.DeleteCookiesArgs{
URL: &normalizedURL,
Name: cookie.Name,
Path: &cookie.Path,
Domain: &cookie.Domain,
}
}
func toDriverCookie(c network.Cookie) drivers.HTTPCookie {
sameSite := drivers.SameSiteDefaultMode
switch c.SameSite {
case network.CookieSameSiteLax:
sameSite = drivers.SameSiteLaxMode
case network.CookieSameSiteStrict:
sameSite = drivers.SameSiteStrictMode
}
return drivers.HTTPCookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: time.Unix(int64(c.Expires), 0),
SameSite: sameSite,
Secure: c.Secure,
HTTPOnly: c.HTTPOnly,
}
}
func normalizeCookieURL(url string) string {
const httpPrefix = "http://"
const httpsPrefix = "https://"
if strings.HasPrefix(url, httpPrefix) || strings.HasPrefix(url, httpsPrefix) {
return url
}
return httpPrefix + url
}

View File

@ -0,0 +1,340 @@
package network
import (
"context"
"encoding/json"
"regexp"
"sync"
"github.com/mafredri/cdp"
"github.com/mafredri/cdp/protocol/network"
"github.com/mafredri/cdp/protocol/page"
"github.com/mafredri/cdp/rpcc"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/MontFerret/ferret/pkg/drivers"
"github.com/MontFerret/ferret/pkg/drivers/cdp/events"
"github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/values"
)
const BlankPageURL = "about:blank"
type (
FrameLoadedListener = func(ctx context.Context, frame page.Frame)
Manager struct {
mu sync.Mutex
logger *zerolog.Logger
client *cdp.Client
headers drivers.HTTPHeaders
eventLoop *events.Loop
cancel context.CancelFunc
listeners []FrameLoadedListener
}
)
func New(
logger *zerolog.Logger,
client *cdp.Client,
eventLoop *events.Loop,
) (*Manager, error) {
ctx, cancel := context.WithCancel(context.Background())
m := new(Manager)
m.logger = logger
m.client = client
m.headers = make(drivers.HTTPHeaders)
m.eventLoop = eventLoop
m.cancel = cancel
frameNavigatedStream, err := m.client.Page.FrameNavigated(ctx)
if err != nil {
return nil, err
}
m.eventLoop.AddSource(events.NewSource(eventFrameLoad, frameNavigatedStream, func(stream rpcc.Stream) (interface{}, error) {
return stream.(page.FrameNavigatedClient).Recv()
}))
return m, nil
}
func (m *Manager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
m.cancel()
m.cancel = nil
}
return nil
}
func (m *Manager) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) {
repl, err := m.client.Network.GetAllCookies(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get cookies")
}
cookies := make(drivers.HTTPCookies)
if repl.Cookies == nil {
return cookies, nil
}
for _, c := range repl.Cookies {
cookies[c.Name] = toDriverCookie(c)
}
return cookies, nil
}
func (m *Manager) SetCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error {
m.mu.Lock()
defer m.mu.Unlock()
if len(cookies) == 0 {
return nil
}
params := make([]network.CookieParam, 0, len(cookies))
for _, c := range cookies {
params = append(params, fromDriverCookie(url, c))
}
return m.client.Network.SetCookies(ctx, network.NewSetCookiesArgs(params))
}
func (m *Manager) DeleteCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error {
m.mu.Lock()
defer m.mu.Unlock()
if len(cookies) == 0 {
return nil
}
var err error
for _, c := range cookies {
err = m.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(url, c))
if err != nil {
break
}
}
return err
}
func (m *Manager) GetHeaders(_ context.Context) (drivers.HTTPHeaders, error) {
copied := make(drivers.HTTPHeaders)
for k, v := range m.headers {
copied[k] = v
}
return copied, nil
}
func (m *Manager) SetHeaders(ctx context.Context, headers drivers.HTTPHeaders) error {
m.mu.Lock()
defer m.mu.Unlock()
if len(headers) == 0 {
return nil
}
m.headers = headers
j, err := json.Marshal(headers)
if err != nil {
return errors.Wrap(err, "failed to marshal headers")
}
err = m.client.Network.SetExtraHTTPHeaders(
ctx,
network.NewSetExtraHTTPHeadersArgs(j),
)
if err != nil {
return errors.Wrap(err, "failed to set headers")
}
return nil
}
func (m *Manager) Navigate(ctx context.Context, url values.String) error {
m.mu.Lock()
defer m.mu.Unlock()
if url == "" {
url = BlankPageURL
}
urlStr := url.String()
repl, err := m.client.Page.Navigate(ctx, page.NewNavigateArgs(urlStr))
if err != nil {
return err
}
if repl.ErrorText != nil {
return errors.New(*repl.ErrorText)
}
return m.WaitForNavigation(ctx, nil)
}
func (m *Manager) NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error) {
m.mu.Lock()
defer m.mu.Unlock()
history, err := m.client.Page.GetNavigationHistory(ctx)
if err != nil {
return values.False, err
}
length := len(history.Entries)
lastIndex := length - 1
// nowhere to go forward
if history.CurrentIndex == lastIndex {
return values.False, nil
}
if skip < 1 {
skip = 1
}
to := int(skip) + history.CurrentIndex
if to > lastIndex {
// TODO: Return error?
return values.False, nil
}
entry := history.Entries[to]
err = m.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(entry.ID))
if err != nil {
return values.False, err
}
err = m.WaitForNavigation(ctx, nil)
if err != nil {
return values.False, err
}
return values.True, nil
}
func (m *Manager) NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error) {
m.mu.Lock()
defer m.mu.Unlock()
history, err := m.client.Page.GetNavigationHistory(ctx)
if err != nil {
return values.False, err
}
// we are in the beginning
if history.CurrentIndex == 0 {
return values.False, nil
}
if skip < 1 {
skip = 1
}
to := history.CurrentIndex - int(skip)
if to < 0 {
// TODO: Return error?
return values.False, nil
}
entry := history.Entries[to]
err = m.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(entry.ID))
if err != nil {
return values.False, err
}
err = m.WaitForNavigation(ctx, nil)
if err != nil {
return values.False, err
}
return values.True, nil
}
func (m *Manager) WaitForNavigation(ctx context.Context, pattern *regexp.Regexp) error {
return m.WaitForFrameNavigation(ctx, "", pattern)
}
func (m *Manager) WaitForFrameNavigation(ctx context.Context, frameID page.FrameID, urlPattern *regexp.Regexp) error {
onEvent := make(chan struct{})
defer func() {
close(onEvent)
}()
m.eventLoop.AddListener(eventFrameLoad, func(_ context.Context, message interface{}) bool {
repl := message.(*page.FrameNavigatedReply)
var matched bool
// if frameID is empty string or equals to the current one
if len(frameID) == 0 || repl.Frame.ID == frameID {
// if a URL pattern is provided
if urlPattern != nil {
matched = urlPattern.Match([]byte(repl.Frame.URL))
} else {
// otherwise just notify
matched = true
}
}
if matched {
if ctx.Err() == nil {
onEvent <- struct{}{}
}
}
// if not matched - continue listening
return !matched
})
select {
case <-onEvent:
return nil
case <-ctx.Done():
return core.ErrTimeout
}
}
func (m *Manager) AddFrameLoadedListener(listener FrameLoadedListener) events.ListenerID {
return m.eventLoop.AddListener(eventFrameLoad, func(ctx context.Context, message interface{}) bool {
repl := message.(*page.FrameNavigatedReply)
listener(ctx, repl.Frame)
return true
})
}
func (m *Manager) RemoveFrameLoadedListener(id events.ListenerID) {
m.eventLoop.RemoveListener(eventFrameLoad, id)
}

View File

@ -2,21 +2,22 @@ package cdp
import ( import (
"context" "context"
"encoding/json" "github.com/MontFerret/ferret/pkg/drivers/cdp/dom"
"github.com/pkg/errors"
"hash/fnv" "hash/fnv"
"io"
"regexp"
"sync" "sync"
"github.com/mafredri/cdp" "github.com/mafredri/cdp"
"github.com/mafredri/cdp/protocol/emulation"
"github.com/mafredri/cdp/protocol/network"
"github.com/mafredri/cdp/protocol/page" "github.com/mafredri/cdp/protocol/page"
"github.com/mafredri/cdp/rpcc" "github.com/mafredri/cdp/rpcc"
"github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/MontFerret/ferret/pkg/drivers" "github.com/MontFerret/ferret/pkg/drivers"
"github.com/MontFerret/ferret/pkg/drivers/cdp/events" "github.com/MontFerret/ferret/pkg/drivers/cdp/events"
"github.com/MontFerret/ferret/pkg/drivers/cdp/input" "github.com/MontFerret/ferret/pkg/drivers/cdp/input"
net "github.com/MontFerret/ferret/pkg/drivers/cdp/network"
"github.com/MontFerret/ferret/pkg/drivers/common" "github.com/MontFerret/ferret/pkg/drivers/common"
"github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/logging" "github.com/MontFerret/ferret/pkg/runtime/logging"
@ -29,26 +30,18 @@ type HTMLPage struct {
logger *zerolog.Logger logger *zerolog.Logger
conn *rpcc.Conn conn *rpcc.Conn
client *cdp.Client client *cdp.Client
events *events.EventBroker events *events.Loop
network *net.Manager
dom *dom.Manager
mouse *input.Mouse mouse *input.Mouse
keyboard *input.Keyboard keyboard *input.Keyboard
document *common.AtomicValue
frames *common.LazyValue
}
func handleLoadError(logger *zerolog.Logger, client *cdp.Client) {
err := client.Page.Close(context.Background())
if err != nil {
logger.Warn().Timestamp().Err(err).Msg("failed to close document on load error")
}
} }
func LoadHTMLPage( func LoadHTMLPage(
ctx context.Context, ctx context.Context,
conn *rpcc.Conn, conn *rpcc.Conn,
params drivers.Params, params drivers.Params,
) (*HTMLPage, error) { ) (p *HTMLPage, err error) {
logger := logging.FromContext(ctx) logger := logging.FromContext(ctx)
if conn == nil { if conn == nil {
@ -57,217 +50,104 @@ func LoadHTMLPage(
client := cdp.NewClient(conn) client := cdp.NewClient(conn)
if err := client.Page.Enable(ctx); err != nil { if err := enableFeatures(ctx, client, params); err != nil {
return nil, err return nil, err
} }
err := runBatch( closers := make([]io.Closer, 0, 2)
func() error {
return client.Page.SetLifecycleEventsEnabled(
ctx,
page.NewSetLifecycleEventsEnabledArgs(true),
)
},
func() error { defer func() {
return client.DOM.Enable(ctx) if err != nil {
}, common.CloseAll(logger, closers, "failed to close a Page resource")
}
}()
func() error { eventLoop := events.NewLoop()
return client.Runtime.Enable(ctx) closers = append(closers, eventLoop)
},
func() error { netManager, err := net.New(logger, client, eventLoop)
ua := common.GetUserAgent(params.UserAgent)
logger. if err != nil {
Debug(). return nil, err
Timestamp(). }
Str("user-agent", ua).
Msg("using User-Agent")
// do not use custom user agent err = netManager.SetCookies(ctx, params.URL, params.Cookies)
if ua == "" {
return nil
}
return client.Emulation.SetUserAgentOverride( if err != nil {
ctx, return nil, err
emulation.NewSetUserAgentOverrideArgs(ua), }
)
},
func() error { err = netManager.SetHeaders(ctx, params.Headers)
return client.Network.Enable(ctx, network.NewEnableArgs())
},
func() error { if err != nil {
return client.Page.SetBypassCSP(ctx, page.NewSetBypassCSPArgs(true)) return nil, err
}, }
func() error { eventLoop.Start()
if params.Viewport == nil {
return nil
}
orientation := emulation.ScreenOrientation{} mouse := input.NewMouse(client)
keyboard := input.NewKeyboard(client)
if !params.Viewport.Landscape { domManager, err := dom.New(
orientation.Type = "portraitPrimary" logger,
orientation.Angle = 0 client,
} else { eventLoop,
orientation.Type = "landscapePrimary" mouse,
orientation.Angle = 90 keyboard,
}
scaleFactor := params.Viewport.ScaleFactor
if scaleFactor <= 0 {
scaleFactor = 1
}
deviceArgs := emulation.NewSetDeviceMetricsOverrideArgs(
params.Viewport.Width,
params.Viewport.Height,
scaleFactor,
params.Viewport.Mobile,
).SetScreenOrientation(orientation)
return client.Emulation.SetDeviceMetricsOverride(
ctx,
deviceArgs,
)
},
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(params.Cookies) > 0 { closers = append(closers, domManager)
cookies := make([]network.CookieParam, 0, len(params.Cookies))
for _, c := range params.Cookies { p = NewHTMLPage(
cookies = append(cookies, fromDriverCookie(params.URL, c))
logger.
Debug().
Timestamp().
Str("cookie", c.Name).
Msg("set cookie")
}
err = client.Network.SetCookies(
ctx,
network.NewSetCookiesArgs(cookies),
)
if err != nil {
return nil, errors.Wrap(err, "failed to set cookies")
}
}
if len(params.Headers) > 0 {
j, err := json.Marshal(params.Headers)
if err != nil {
return nil, err
}
for k := range params.Headers {
logger.
Debug().
Timestamp().
Str("header", k).
Msg("set header")
}
err = client.Network.SetExtraHTTPHeaders(
ctx,
network.NewSetExtraHTTPHeadersArgs(network.Headers(j)),
)
if err != nil {
return nil, errors.Wrap(err, "failed to set headers")
}
}
if params.URL != BlankPageURL && params.URL != "" {
repl, err := client.Page.Navigate(ctx, page.NewNavigateArgs(params.URL))
if err != nil {
handleLoadError(logger, client)
return nil, errors.Wrap(err, "failed to load the page")
}
if repl.ErrorText != nil {
handleLoadError(logger, client)
return nil, errors.Wrapf(errors.New(*repl.ErrorText), "failed to load the page: %s", params.URL)
}
err = events.WaitForLoadEvent(ctx, client)
if err != nil {
handleLoadError(logger, client)
return nil, errors.Wrap(err, "failed to load the page")
}
}
broker, err := events.CreateEventBroker(client)
if err != nil {
handleLoadError(logger, client)
return nil, errors.Wrap(err, "failed to create event events")
}
mouse := input.NewMouse(client)
keyboard := input.NewKeyboard(client)
doc, err := LoadRootHTMLDocument(ctx, logger, client, broker, mouse, keyboard)
if err != nil {
broker.StopAndClose()
handleLoadError(logger, client)
return nil, errors.Wrap(err, "failed to load root element")
}
return NewHTMLPage(
logger, logger,
conn, conn,
client, client,
broker, eventLoop,
netManager,
domManager,
mouse, mouse,
keyboard, keyboard,
doc, )
), nil
if params.URL != BlankPageURL && params.URL != "" {
err = p.Navigate(ctx, values.NewString(params.URL))
} else {
err = p.loadMainFrame(ctx)
}
if err != nil {
return p, err
}
return p, nil
} }
func NewHTMLPage( func NewHTMLPage(
logger *zerolog.Logger, logger *zerolog.Logger,
conn *rpcc.Conn, conn *rpcc.Conn,
client *cdp.Client, client *cdp.Client,
broker *events.EventBroker, eventLoop *events.Loop,
netManager *net.Manager,
domManager *dom.Manager,
mouse *input.Mouse, mouse *input.Mouse,
keyboard *input.Keyboard, keyboard *input.Keyboard,
document *HTMLDocument,
) *HTMLPage { ) *HTMLPage {
p := new(HTMLPage) p := new(HTMLPage)
p.closed = values.False p.closed = values.False
p.logger = logger p.logger = logger
p.conn = conn p.conn = conn
p.client = client p.client = client
p.events = broker p.events = eventLoop
p.network = netManager
p.dom = domManager
p.mouse = mouse p.mouse = mouse
p.keyboard = keyboard p.keyboard = keyboard
p.document = common.NewAtomicValue(document)
p.frames = common.NewLazyValue(p.unfoldFrames)
broker.AddEventListener(events.EventLoad, p.handlePageLoad) eventLoop.AddListener(events.Error, events.Always(p.handleError))
broker.AddEventListener(events.EventError, p.handleError)
return p return p
} }
@ -343,36 +223,36 @@ func (p *HTMLPage) Close() error {
defer p.mu.Unlock() defer p.mu.Unlock()
p.closed = values.True p.closed = values.True
err := p.events.Stop()
doc := p.getCurrentDocument() doc := p.getCurrentDocument()
err := p.events.Stop().Close()
if err != nil { if err != nil {
p.logger.Warn(). p.logger.Warn().
Timestamp(). Timestamp().
Str("url", doc.GetURL().String()). Str("url", doc.GetURL().String()).
Err(err). Err(err).
Msg("failed to stop event events") Msg("failed to stop event loop")
} }
err = p.events.Close() err = p.dom.Close()
if err != nil { if err != nil {
p.logger.Warn(). p.logger.Warn().
Timestamp(). Timestamp().
Str("url", doc.GetURL().String()). Str("url", doc.GetURL().String()).
Err(err). Err(err).
Msg("failed to close event events") Msg("failed to close dom manager")
} }
err = doc.Close() err = p.network.Close()
if err != nil { if err != nil {
p.logger.Warn(). p.logger.Warn().
Timestamp(). Timestamp().
Str("url", doc.GetURL().String()). Str("url", doc.GetURL().String()).
Err(err). Err(err).
Msg("failed to close root document") Msg("failed to close network manager")
} }
err = p.client.Page.Close(context.Background()) err = p.client.Page.Close(context.Background())
@ -404,87 +284,48 @@ func (p *HTMLPage) GetMainFrame() drivers.HTMLDocument {
} }
func (p *HTMLPage) GetFrames(ctx context.Context) (*values.Array, error) { func (p *HTMLPage) GetFrames(ctx context.Context) (*values.Array, error) {
res, err := p.frames.Read(ctx) p.mu.Lock()
defer p.mu.Unlock()
if err != nil { return p.dom.GetFrameNodes(ctx)
return nil, err
}
return res.(*values.Array).Clone().(*values.Array), nil
} }
func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, error) { func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
res, err := p.frames.Read(ctx) frames, err := p.dom.GetFrameNodes(ctx)
if err != nil { if err != nil {
return nil, err return values.None, err
} }
return res.(*values.Array).Get(idx), nil return frames.Get(idx), nil
} }
func (p *HTMLPage) GetCookies(ctx context.Context) (*values.Array, error) { func (p *HTMLPage) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
repl, err := p.client.Network.GetAllCookies(ctx) return p.network.GetCookies(ctx)
if err != nil {
return values.NewArray(0), err
}
if repl.Cookies == nil {
return values.NewArray(0), nil
}
cookies := values.NewArray(len(repl.Cookies))
for _, c := range repl.Cookies {
cookies.Push(toDriverCookie(c))
}
return cookies, nil
} }
func (p *HTMLPage) SetCookies(ctx context.Context, cookies ...drivers.HTTPCookie) error { func (p *HTMLPage) SetCookies(ctx context.Context, cookies drivers.HTTPCookies) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
if len(cookies) == 0 { return p.network.SetCookies(ctx, p.getCurrentDocument().GetURL().String(), cookies)
return nil
}
params := make([]network.CookieParam, 0, len(cookies))
for _, c := range cookies {
params = append(params, fromDriverCookie(p.getCurrentDocument().GetURL().String(), c))
}
return p.client.Network.SetCookies(ctx, network.NewSetCookiesArgs(params))
} }
func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies ...drivers.HTTPCookie) error { func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies drivers.HTTPCookies) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
if len(cookies) == 0 { return p.network.DeleteCookies(ctx, p.getCurrentDocument().GetURL().String(), cookies)
return nil }
}
var err error func (p *HTMLPage) GetResponse(_ context.Context) (*drivers.HTTPResponse, error) {
return nil, core.ErrNotSupported
for _, c := range cookies {
err = p.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(p.getCurrentDocument().GetURL().String(), c))
if err != nil {
break
}
}
return err
} }
func (p *HTMLPage) PrintToPDF(ctx context.Context, params drivers.PDFParams) (values.Binary, error) { func (p *HTMLPage) PrintToPDF(ctx context.Context, params drivers.PDFParams) (values.Binary, error) {
@ -607,162 +448,107 @@ func (p *HTMLPage) Navigate(ctx context.Context, url values.String) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
if url == "" { if err := p.network.Navigate(ctx, url); err != nil {
url = BlankPageURL
}
repl, err := p.client.Page.Navigate(ctx, page.NewNavigateArgs(url.String()))
if err != nil {
return err return err
} }
if repl.ErrorText != nil { return p.reloadMainFrame(ctx)
return errors.New(*repl.ErrorText)
}
return p.WaitForNavigation(ctx)
} }
func (p *HTMLPage) NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error) { func (p *HTMLPage) NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
history, err := p.client.Page.GetNavigationHistory(ctx) ret, err := p.network.NavigateBack(ctx, skip)
if err != nil { if err != nil {
return values.False, err return values.False, err
} }
// we are in the beginning return ret, p.reloadMainFrame(ctx)
if history.CurrentIndex == 0 {
return values.False, nil
}
if skip < 1 {
skip = 1
}
to := history.CurrentIndex - int(skip)
if to < 0 {
// TODO: Return error?
return values.False, nil
}
prev := history.Entries[to]
err = p.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(prev.ID))
if err != nil {
return values.False, err
}
err = p.WaitForNavigation(ctx)
if err != nil {
return values.False, err
}
return values.True, nil
} }
func (p *HTMLPage) NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error) { func (p *HTMLPage) NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
history, err := p.client.Page.GetNavigationHistory(ctx) ret, err := p.network.NavigateForward(ctx, skip)
if err != nil { if err != nil {
return values.False, err return values.False, err
} }
length := len(history.Entries) return ret, p.reloadMainFrame(ctx)
lastIndex := length - 1
// nowhere to go forward
if history.CurrentIndex == lastIndex {
return values.False, nil
}
if skip < 1 {
skip = 1
}
to := int(skip) + history.CurrentIndex
if to > lastIndex {
// TODO: Return error?
return values.False, nil
}
next := history.Entries[to]
err = p.client.Page.NavigateToHistoryEntry(ctx, page.NewNavigateToHistoryEntryArgs(next.ID))
if err != nil {
return values.False, err
}
err = p.WaitForNavigation(ctx)
if err != nil {
return values.False, err
}
return values.True, nil
} }
func (p *HTMLPage) WaitForNavigation(ctx context.Context) error { func (p *HTMLPage) WaitForNavigation(ctx context.Context, targetURL values.String) error {
onEvent := make(chan struct{}) var pattern *regexp.Regexp
var once sync.Once
listener := func(_ context.Context, _ interface{}) { if targetURL != "" {
once.Do(func() { r, err := regexp.Compile(targetURL.String())
close(onEvent)
}) if err != nil {
return errors.Wrap(err, "invalid URL pattern")
}
pattern = r
} }
defer p.events.RemoveEventListener(events.EventLoad, listener) if err := p.network.WaitForNavigation(ctx, pattern); err != nil {
return err
p.events.AddEventListener(events.EventLoad, listener)
select {
case <-onEvent:
return nil
case <-ctx.Done():
return core.ErrTimeout
} }
return p.reloadMainFrame(ctx)
} }
func (p *HTMLPage) handlePageLoad(ctx context.Context, _ interface{}) { func (p *HTMLPage) reloadMainFrame(ctx context.Context) error {
err := p.document.Write(func(current core.Value) (core.Value, error) { if err := p.dom.WaitForDOMReady(ctx); err != nil {
nextDoc, err := LoadRootHTMLDocument(ctx, p.logger, p.client, p.events, p.mouse, p.keyboard) return err
}
if err != nil { prev := p.dom.GetMainFrame()
return values.None, err
}
// close the prev document next, err := dom.LoadRootHTMLDocument(
currentDoc := current.(*HTMLDocument) ctx,
err = currentDoc.Close() p.logger,
p.client,
if err != nil { p.dom,
p.logger.Warn(). p.mouse,
Timestamp(). p.keyboard,
Err(err). )
Msgf("failed to close root document: %s", currentDoc.GetURL())
}
// reset all loaded frames
p.frames.Reset()
return nextDoc, nil
})
if err != nil { if err != nil {
p.logger.Warn(). return err
Timestamp().
Err(err).
Msg("failed to load new root document after page load")
} }
if prev != nil {
if err := p.dom.RemoveFrameRecursively(prev.Frame().Frame.ID); err != nil {
p.logger.Error().Err(err).Msg("failed to remove main frame")
}
}
p.dom.SetMainFrame(next)
return nil
}
func (p *HTMLPage) loadMainFrame(ctx context.Context) error {
next, err := dom.LoadRootHTMLDocument(
ctx,
p.logger,
p.client,
p.dom,
p.mouse,
p.keyboard,
)
if err != nil {
return err
}
p.dom.SetMainFrame(next)
return nil
} }
func (p *HTMLPage) handleError(_ context.Context, val interface{}) { func (p *HTMLPage) handleError(_ context.Context, val interface{}) {
@ -778,22 +564,6 @@ func (p *HTMLPage) handleError(_ context.Context, val interface{}) {
Msg("unexpected error") Msg("unexpected error")
} }
func (p *HTMLPage) getCurrentDocument() *HTMLDocument { func (p *HTMLPage) getCurrentDocument() *dom.HTMLDocument {
return p.document.Read().(*HTMLDocument) return p.dom.GetMainFrame()
}
func (p *HTMLPage) unfoldFrames(ctx context.Context) (core.Value, error) {
res := values.NewArray(10)
err := common.CollectFrames(ctx, res, p.getCurrentDocument())
if err != nil {
return nil, err
}
return res, nil
}
func (p *HTMLPage) GetResponse(_ context.Context) (*drivers.HTTPResponse, error) {
return nil, core.ErrNotSupported
} }

View File

@ -1,8 +1,20 @@
package common package common
import "github.com/MontFerret/ferret/pkg/runtime/core" import (
"github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/rs/zerolog"
"io"
)
var ( var (
ErrReadOnly = core.Error(core.ErrInvalidOperation, "read only") ErrReadOnly = core.Error(core.ErrInvalidOperation, "read only")
ErrInvalidPath = core.Error(core.ErrInvalidOperation, "invalid path") ErrInvalidPath = core.Error(core.ErrInvalidOperation, "invalid path")
) )
func CloseAll(logger *zerolog.Logger, closers []io.Closer, msg string) {
for _, closer := range closers {
if err := closer.Close(); err != nil {
logger.Error().Err(err).Msg(msg)
}
}
}

View File

@ -63,22 +63,17 @@ func GetInPage(ctx context.Context, page drivers.HTMLPage, path []core.Value) (c
case "url", "URL": case "url", "URL":
return page.GetMainFrame().GetURL(), nil return page.GetMainFrame().GetURL(), nil
case "cookies": case "cookies":
cookies, err := page.GetCookies(ctx)
if err != nil {
return values.None, err
}
if len(path) == 1 { if len(path) == 1 {
return page.GetCookies(ctx) return cookies, nil
} }
switch idx := path[1].(type) { return cookies.GetIn(ctx, path[1:])
case values.Int:
cookies, err := page.GetCookies(ctx)
if err != nil {
return values.None, err
}
return cookies.Get(idx), nil
default:
return values.None, core.TypeError(idx.Type(), types.Int)
}
case "isClosed": case "isClosed":
return page.IsClosed(), nil return page.IsClosed(), nil
case "title": case "title":
@ -109,7 +104,11 @@ func GetInDocument(ctx context.Context, doc drivers.HTMLDocument, path []core.Va
case "title": case "title":
return doc.GetTitle(), nil return doc.GetTitle(), nil
case "parent": case "parent":
parent := doc.GetParentDocument() parent, err := doc.GetParentDocument(ctx)
if err != nil {
return values.None, err
}
if parent == nil { if parent == nil {
return values.None, nil return values.None, nil

View File

@ -30,8 +30,6 @@ type (
HTTPOnly bool HTTPOnly bool
SameSite SameSite SameSite SameSite
} }
HTTPCookies map[string]HTTPCookie
) )
const ( const (

183
pkg/drivers/cookies.go Normal file
View File

@ -0,0 +1,183 @@
package drivers
import (
"context"
"encoding/binary"
"encoding/json"
"github.com/MontFerret/ferret/pkg/runtime/values"
"github.com/MontFerret/ferret/pkg/runtime/values/types"
"hash/fnv"
"sort"
"github.com/MontFerret/ferret/pkg/runtime/core"
)
type HTTPCookies map[string]HTTPCookie
func NewHTTPCookies() HTTPCookies {
return make(HTTPCookies)
}
func (c HTTPCookies) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]HTTPCookie(c))
}
func (c HTTPCookies) Type() core.Type {
return HTTPCookiesType
}
func (c HTTPCookies) String() string {
j, err := c.MarshalJSON()
if err != nil {
return "{}"
}
return string(j)
}
func (c HTTPCookies) Compare(other core.Value) int64 {
if other.Type() != HTTPCookiesType {
return Compare(HTTPCookiesType, other.Type())
}
oc := other.(HTTPCookies)
switch {
case len(c) > len(oc):
return 1
case len(c) < len(oc):
return -1
}
for name := range c {
cEl, cExists := c.Get(values.NewString(name))
if !cExists {
return -1
}
ocEl, ocExists := oc.Get(values.NewString(name))
if !ocExists {
return 1
}
c := cEl.Compare(ocEl)
if c != 0 {
return c
}
}
return 0
}
func (c HTTPCookies) Unwrap() interface{} {
return map[string]HTTPCookie(c)
}
func (c HTTPCookies) Hash() uint64 {
hash := fnv.New64a()
hash.Write([]byte(c.Type().String()))
hash.Write([]byte(":"))
hash.Write([]byte("{"))
keys := make([]string, 0, len(c))
for key := range c {
keys = append(keys, key)
}
// order does not really matter
// but it will give us a consistent hash sum
sort.Strings(keys)
endIndex := len(keys) - 1
for idx, key := range keys {
hash.Write([]byte(key))
hash.Write([]byte(":"))
el := c[key]
bytes := make([]byte, 8)
binary.LittleEndian.PutUint64(bytes, el.Hash())
hash.Write(bytes)
if idx != endIndex {
hash.Write([]byte(","))
}
}
hash.Write([]byte("}"))
return hash.Sum64()
}
func (c HTTPCookies) Copy() core.Value {
copied := make(HTTPCookies)
for k, v := range c {
copied[k] = v
}
return copied
}
func (c HTTPCookies) Length() values.Int {
return values.NewInt(len(c))
}
func (c HTTPCookies) Keys() []values.String {
keys := make([]values.String, 0, len(c))
for k := range c {
keys = append(keys, values.NewString(k))
}
return keys
}
func (c HTTPCookies) Get(key values.String) (core.Value, values.Boolean) {
value, found := c[key.String()]
if found {
return value, values.True
}
return values.None, values.False
}
func (c HTTPCookies) Set(key values.String, value core.Value) {
if cookie, ok := value.(HTTPCookie); ok {
c[key.String()] = cookie
}
}
func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, error) {
if len(path) == 0 {
return values.None, nil
}
segment := path[0]
err := core.ValidateType(segment, types.String)
if err != nil {
return values.None, err
}
cookie, found := c[segment.String()]
if found {
if len(path) == 1 {
return cookie, nil
}
return values.GetIn(ctx, cookie, path[1:])
}
return values.None, nil
}

View File

@ -203,8 +203,8 @@ func (doc *HTMLDocument) GetName() values.String {
return "" return ""
} }
func (doc *HTMLDocument) GetParentDocument() drivers.HTMLDocument { func (doc *HTMLDocument) GetParentDocument(_ context.Context) (drivers.HTMLDocument, error) {
return doc.parent return doc.parent, nil
} }
func (doc *HTMLDocument) ScrollTop(_ context.Context) error { func (doc *HTMLDocument) ScrollTop(_ context.Context) error {

View File

@ -168,29 +168,25 @@ func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, er
return p.frames.Get(idx), nil return p.frames.Get(idx), nil
} }
func (p *HTMLPage) GetCookies(_ context.Context) (*values.Array, error) { func (p *HTMLPage) GetCookies(_ context.Context) (drivers.HTTPCookies, error) {
if p.cookies == nil { res := make(drivers.HTTPCookies)
return values.NewArray(0), nil
for n, v := range p.cookies {
res[n] = v
} }
arr := values.NewArray(len(p.cookies)) return res, nil
for _, c := range p.cookies {
arr.Push(c)
}
return arr, nil
} }
func (p *HTMLPage) GetResponse(_ context.Context) (*drivers.HTTPResponse, error) { func (p *HTMLPage) GetResponse(_ context.Context) (*drivers.HTTPResponse, error) {
return p.response, nil return p.response, nil
} }
func (p *HTMLPage) SetCookies(_ context.Context, _ ...drivers.HTTPCookie) error { func (p *HTMLPage) SetCookies(_ context.Context, _ drivers.HTTPCookies) error {
return core.ErrNotSupported return core.ErrNotSupported
} }
func (p *HTMLPage) DeleteCookies(_ context.Context, _ ...drivers.HTTPCookie) error { func (p *HTMLPage) DeleteCookies(_ context.Context, _ drivers.HTTPCookies) error {
return core.ErrNotSupported return core.ErrNotSupported
} }
@ -202,7 +198,7 @@ func (p *HTMLPage) CaptureScreenshot(_ context.Context, _ drivers.ScreenshotPara
return nil, core.ErrNotSupported return nil, core.ErrNotSupported
} }
func (p *HTMLPage) WaitForNavigation(_ context.Context) error { func (p *HTMLPage) WaitForNavigation(_ context.Context, _ values.String) error {
return core.ErrNotSupported return core.ErrNotSupported
} }

View File

@ -6,6 +6,7 @@ var (
HTTPResponseType = core.NewType("HTTPResponse") HTTPResponseType = core.NewType("HTTPResponse")
HTTPHeaderType = core.NewType("HTTPHeaders") HTTPHeaderType = core.NewType("HTTPHeaders")
HTTPCookieType = core.NewType("HTTPCookie") HTTPCookieType = core.NewType("HTTPCookie")
HTTPCookiesType = core.NewType("HTTPCookies")
HTMLElementType = core.NewType("HTMLElement") HTMLElementType = core.NewType("HTMLElement")
HTMLDocumentType = core.NewType("HTMLDocument") HTMLDocumentType = core.NewType("HTMLDocument")
HTMLPageType = core.NewType("HTMLPageType") HTMLPageType = core.NewType("HTMLPageType")
@ -15,9 +16,10 @@ var (
var typeComparisonTable = map[core.Type]uint64{ var typeComparisonTable = map[core.Type]uint64{
HTTPHeaderType: 0, HTTPHeaderType: 0,
HTTPCookieType: 1, HTTPCookieType: 1,
HTMLElementType: 2, HTTPCookiesType: 2,
HTMLDocumentType: 3, HTMLElementType: 3,
HTMLPageType: 4, HTMLDocumentType: 4,
HTMLPageType: 5,
} }
func Compare(first, second core.Type) int64 { func Compare(first, second core.Type) int64 {

View File

@ -143,7 +143,7 @@ type (
GetName() values.String GetName() values.String
GetParentDocument() HTMLDocument GetParentDocument(ctx context.Context) (HTMLDocument, error)
GetChildDocuments(ctx context.Context) (*values.Array, error) GetChildDocuments(ctx context.Context) (*values.Array, error)
@ -192,25 +192,25 @@ type (
GetFrame(ctx context.Context, idx values.Int) (core.Value, error) GetFrame(ctx context.Context, idx values.Int) (core.Value, error)
GetCookies(ctx context.Context) (*values.Array, error) GetCookies(ctx context.Context) (HTTPCookies, error)
SetCookies(ctx context.Context, cookies ...HTTPCookie) error SetCookies(ctx context.Context, cookies HTTPCookies) error
DeleteCookies(ctx context.Context, cookies ...HTTPCookie) error DeleteCookies(ctx context.Context, cookies HTTPCookies) error
GetResponse(ctx context.Context) (*HTTPResponse, error)
PrintToPDF(ctx context.Context, params PDFParams) (values.Binary, error) PrintToPDF(ctx context.Context, params PDFParams) (values.Binary, error)
CaptureScreenshot(ctx context.Context, params ScreenshotParams) (values.Binary, error) CaptureScreenshot(ctx context.Context, params ScreenshotParams) (values.Binary, error)
WaitForNavigation(ctx context.Context) error WaitForNavigation(ctx context.Context, targetURL values.String) error
Navigate(ctx context.Context, url values.String) error Navigate(ctx context.Context, url values.String) error
NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error) NavigateBack(ctx context.Context, skip values.Int) (values.Boolean, error)
NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error) NavigateForward(ctx context.Context, skip values.Int) (values.Boolean, error)
GetResponse(ctx context.Context) (*HTTPResponse, error)
} }
) )

View File

@ -61,7 +61,9 @@ func Errors(err ...error) error {
message := "" message := ""
for _, e := range err { for _, e := range err {
message += ": " + e.Error() if e != nil {
message += ": " + e.Error()
}
} }
return errors.New(message) return errors.New(message)

View File

@ -5,10 +5,11 @@ import (
"runtime" "runtime"
"strings" "strings"
"github.com/pkg/errors"
"github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/logging" "github.com/MontFerret/ferret/pkg/runtime/logging"
"github.com/MontFerret/ferret/pkg/runtime/values" "github.com/MontFerret/ferret/pkg/runtime/values"
"github.com/pkg/errors"
) )
type Program struct { type Program struct {

View File

@ -26,8 +26,8 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) {
} }
inputs := args[1:] inputs := args[1:]
var currentCookies *values.Array var currentCookies drivers.HTTPCookies
cookies := make([]drivers.HTTPCookie, 0, len(inputs)) cookies := make(drivers.HTTPCookies)
for _, c := range inputs { for _, c := range inputs {
switch cookie := c.(type) { switch cookie := c.(type) {
@ -42,23 +42,18 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) {
currentCookies = current currentCookies = current
} }
found, isFound := currentCookies.Find(func(value core.Value, _ int) bool { found, isFound := currentCookies[cookie.String()]
cv := value.(drivers.HTTPCookie)
return cv.Name == cookie.String()
})
if isFound { if isFound {
cookies = append(cookies, found.(drivers.HTTPCookie)) cookies[cookie.String()] = found
} }
case drivers.HTTPCookie: case drivers.HTTPCookie:
cookies = append(cookies, cookie) cookies[cookie.Name] = cookie
default: default:
return values.None, core.TypeError(c.Type(), types.String, drivers.HTTPCookieType) return values.None, core.TypeError(c.Type(), types.String, drivers.HTTPCookieType)
} }
} }
return values.None, page.DeleteCookies(ctx, cookies...) return values.None, page.DeleteCookies(ctx, cookies)
} }

View File

@ -39,15 +39,11 @@ func CookieGet(ctx context.Context, args ...core.Value) (core.Value, error) {
return values.None, err return values.None, err
} }
found, _ := cookies.Find(func(value core.Value, _ int) bool { cookie, found := cookies[name.String()]
cookie, ok := value.(drivers.HTTPCookie)
if !ok { if found {
return ok return cookie, nil
} }
return cookie.Name == name.String() return values.None, nil
})
return found, nil
} }

View File

@ -24,7 +24,7 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) {
return values.None, err return values.None, err
} }
cookies := make([]drivers.HTTPCookie, 0, len(args)-1) cookies := make(drivers.HTTPCookies)
for _, c := range args[1:] { for _, c := range args[1:] {
cookie, err := parseCookie(c) cookie, err := parseCookie(c)
@ -33,8 +33,8 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) {
return values.None, err return values.None, err
} }
cookies = append(cookies, cookie) cookies[cookie.Name] = cookie
} }
return values.None, page.SetCookies(ctx, cookies...) return values.None, page.SetCookies(ctx, cookies)
} }

View File

@ -2,6 +2,7 @@ package html
import ( import (
"context" "context"
"github.com/pkg/errors"
"github.com/MontFerret/ferret/pkg/drivers" "github.com/MontFerret/ferret/pkg/drivers"
"github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/core"
@ -9,6 +10,11 @@ import (
"github.com/MontFerret/ferret/pkg/runtime/values/types" "github.com/MontFerret/ferret/pkg/runtime/values/types"
) )
type WaitNavigationParams struct {
TargetURL values.String
Timeout values.Int
}
// WAIT_NAVIGATION waits for a given page to navigate to a new url. // WAIT_NAVIGATION waits for a given page to navigate to a new url.
// Stops the execution until the navigation ends or operation times out. // Stops the execution until the navigation ends or operation times out.
// @param page (HTMLPage) - Target page. // @param page (HTMLPage) - Target page.
@ -26,20 +32,67 @@ func WaitNavigation(ctx context.Context, args ...core.Value) (core.Value, error)
return values.None, err return values.None, err
} }
timeout := values.NewInt(drivers.DefaultWaitTimeout) var params WaitNavigationParams
if len(args) > 1 { if len(args) > 1 {
err = core.ValidateType(args[1], types.Int) p, err := parseWaitNavigationParams(args[1])
if err != nil { if err != nil {
return values.None, err return values.None, err
} }
timeout = args[1].(values.Int) params = p
} else {
params = defaultWaitNavigationParams()
} }
ctx, fn := waitTimeout(ctx, timeout) ctx, fn := waitTimeout(ctx, params.Timeout)
defer fn() defer fn()
return values.None, doc.WaitForNavigation(ctx) return values.None, doc.WaitForNavigation(ctx, params.TargetURL)
}
func parseWaitNavigationParams(arg core.Value) (WaitNavigationParams, error) {
params := defaultWaitNavigationParams()
err := core.ValidateType(arg, types.Int, types.Object)
if err != nil {
return params, err
}
if arg.Type() == types.Int {
params.Timeout = arg.(values.Int)
} else {
obj := arg.(*values.Object)
if v, exists := obj.Get("timeout"); exists {
err := core.ValidateType(v, types.Int)
if err != nil {
return params, errors.Wrap(err, "navigation parameters: timeout")
}
params.Timeout = v.(values.Int)
}
if v, exists := obj.Get("target"); exists {
err := core.ValidateType(v, types.String)
if err != nil {
return params, errors.Wrap(err, "navigation parameters: url")
}
params.TargetURL = v.(values.String)
}
}
return params, nil
}
func defaultWaitNavigationParams() WaitNavigationParams {
return WaitNavigationParams{
TargetURL: "",
Timeout: values.NewInt(drivers.DefaultWaitTimeout),
}
} }