1
0
mirror of https://github.com/MontFerret/ferret.git synced 2025-07-05 00:49:00 +02:00

Bugfix/#597 headers panic (#598)

* Remodeled HTTPHeaders

* Remodeled HTTPCookies
This commit is contained in:
Tim Voronov
2021-03-26 12:01:00 -04:00
committed by GitHub
parent b3adea3622
commit d55bce325c
34 changed files with 724 additions and 345 deletions

View File

@ -15,7 +15,7 @@ COOKIE_DEL(doc, COOKIE_GET(doc, "x-e2e"), "x-e2e-2")
LET cookie1 = COOKIE_GET(doc, "x-e2e") LET cookie1 = COOKIE_GET(doc, "x-e2e")
LET cookie2 = COOKIE_GET(doc, "x-e2e-2") LET cookie2 = COOKIE_GET(doc, "x-e2e-2")
T::EQ(cookie1, "none") T::EQ(cookie1, NONE)
T::EQ(cookie2, "none") T::EQ(cookie2, NONE)
RETURN NONE RETURN NONE

View File

@ -7,6 +7,6 @@ LET cookiesPath = LENGTH(doc.cookies) > 0 ? "ok" : "false"
LET cookie = COOKIE_GET(doc, "x-ferret") LET cookie = COOKIE_GET(doc, "x-ferret")
LET expected = "ok e2e" LET expected = "ok e2e"
T::LEN(doc.cookies T::LEN(doc.cookies, 1)
RETURN T::EQ(cookiesPath + " " + cookie.value, expected) RETURN T::EQ(cookiesPath + " " + cookie.value, expected)

View File

@ -7,8 +7,9 @@ LET doc = DOCUMENT(url, {
}] }]
}) })
LET cookiesPath = LENGTH(doc.cookies) > 0 ? "ok" : "false"
LET cookie = COOKIE_GET(doc, "x-e2e") LET cookie = COOKIE_GET(doc, "x-e2e")
LET expected = "ok test"
RETURN T::EQ(cookiesPath + " " + cookie.value, expected) T::NOT::NONE(cookie)
T::EQ(cookie.value, "test")
RETURN NONE

5
examples/headers.fql Normal file
View File

@ -0,0 +1,5 @@
// Example: open a page with a custom Proxy-Authorization header.
// Each header value is written as an array of strings.
LET proxy_header = {"Proxy-Authorization": ["Basic e40b7d5eff464a4fb51efed2d1a19a24"]}
// The header object is handed to DOCUMENT through the `headers` option.
LET doc = DOCUMENT("https://google.com", { headers: proxy_header})
RETURN doc

View File

@ -36,7 +36,7 @@ type Driver struct {
func NewDriver(opts ...Option) *Driver { func NewDriver(opts ...Option) *Driver {
drv := new(Driver) drv := new(Driver)
drv.options = newOptions(opts) drv.options = NewOptions(opts)
drv.dev = devtool.New(drv.options.Address) drv.dev = devtool.New(drv.options.Address)
return drv return drv
@ -137,43 +137,11 @@ func (drv *Driver) createConnection(ctx context.Context, keepCookies bool) (*rpc
} }
func (drv *Driver) setDefaultParams(params drivers.Params) drivers.Params { func (drv *Driver) setDefaultParams(params drivers.Params) drivers.Params {
if params.UserAgent == "" {
params.UserAgent = drv.options.UserAgent
}
if params.Viewport == nil { if params.Viewport == nil {
params.Viewport = defaultViewport params.Viewport = defaultViewport
} }
if drv.options.Headers != nil && params.Headers == nil { return drivers.SetDefaultParams(drv.options.Options, params)
params.Headers = make(drivers.HTTPHeaders)
}
// set default headers
for k, v := range drv.options.Headers {
_, exists := params.Headers[k]
// do not override user's set values
if !exists {
params.Headers[k] = v
}
}
if drv.options.Cookies != nil && params.Cookies == nil {
params.Cookies = make(drivers.HTTPCookies)
}
// set default cookies
for k, v := range drv.options.Cookies {
_, exists := params.Cookies[k]
// do not override user's set values
if !exists {
params.Cookies[k] = v
}
}
return params
} }
func (drv *Driver) init(ctx context.Context) error { func (drv *Driver) init(ctx context.Context) error {

View File

@ -51,12 +51,6 @@ func enableFeatures(ctx context.Context, client *cdp.Client, params drivers.Para
func() error { func() error {
ua := common.GetUserAgent(params.UserAgent) ua := common.GetUserAgent(params.UserAgent)
//logger.
// Debug().
// Timestamp().
// Str("user-agent", ua).
// Msg("using User-Agent")
// do not use custom user agent // do not use custom user agent
if ua == "" { if ua == "" {
return nil return nil

View File

@ -34,7 +34,7 @@ type (
mu sync.Mutex mu sync.Mutex
logger *zerolog.Logger logger *zerolog.Logger
client *cdp.Client client *cdp.Client
headers drivers.HTTPHeaders headers *drivers.HTTPHeaders
eventLoop *events.Loop eventLoop *events.Loop
cancel context.CancelFunc cancel context.CancelFunc
responseListenerID events.ListenerID responseListenerID events.ListenerID
@ -53,12 +53,12 @@ func New(
m := new(Manager) m := new(Manager)
m.logger = logger m.logger = logger
m.client = client m.client = client
m.headers = make(drivers.HTTPHeaders) m.headers = drivers.NewHTTPHeaders()
m.eventLoop = events.NewLoop() m.eventLoop = events.NewLoop()
m.cancel = cancel m.cancel = cancel
m.response = new(sync.Map) m.response = new(sync.Map)
if len(options.Cookies) > 0 { if options.Cookies != nil && len(options.Cookies) > 0 {
for url, cookies := range options.Cookies { for url, cookies := range options.Cookies {
if err := m.setCookiesInternal(ctx, url, cookies); err != nil { if err := m.setCookiesInternal(ctx, url, cookies); err != nil {
return nil, err return nil, err
@ -66,7 +66,7 @@ func New(
} }
} }
if len(options.Headers) > 0 { if options.Headers != nil && options.Headers.Length() > 0 {
if err := m.setHeadersInternal(ctx, options.Headers); err != nil { if err := m.setHeadersInternal(ctx, options.Headers); err != nil {
return nil, err return nil, err
} }
@ -104,7 +104,7 @@ func New(
m.responseListenerID = m.eventLoop.AddListener(responseReceived, m.onResponse) m.responseListenerID = m.eventLoop.AddListener(responseReceived, m.onResponse)
if len(options.Filter.Patterns) > 0 { if options.Filter != nil && len(options.Filter.Patterns) > 0 {
el2 := events.NewLoop() el2 := events.NewLoop()
err = m.client.Fetch.Enable(ctx, toFetchArgs(options.Filter.Patterns)) err = m.client.Fetch.Enable(ctx, toFetchArgs(options.Filter.Patterns))
@ -147,87 +147,100 @@ func (m *Manager) Close() error {
return nil return nil
} }
func (m *Manager) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) { func (m *Manager) GetCookies(ctx context.Context) (*drivers.HTTPCookies, error) {
repl, err := m.client.Network.GetAllCookies(ctx) repl, err := m.client.Network.GetAllCookies(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get cookies") return nil, errors.Wrap(err, "failed to get cookies")
} }
cookies := make(drivers.HTTPCookies) cookies := drivers.NewHTTPCookies()
if repl.Cookies == nil { if repl.Cookies == nil {
return cookies, nil return cookies, nil
} }
for _, c := range repl.Cookies { for _, c := range repl.Cookies {
cookies[c.Name] = toDriverCookie(c) cookies.Set(toDriverCookie(c))
} }
return cookies, nil return cookies, nil
} }
func (m *Manager) SetCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error { func (m *Manager) SetCookies(ctx context.Context, url string, cookies *drivers.HTTPCookies) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
return m.setCookiesInternal(ctx, url, cookies) return m.setCookiesInternal(ctx, url, cookies)
} }
func (m *Manager) setCookiesInternal(ctx context.Context, url string, cookies drivers.HTTPCookies) error { func (m *Manager) setCookiesInternal(ctx context.Context, url string, cookies *drivers.HTTPCookies) error {
if len(cookies) == 0 { if cookies == nil {
return errors.Wrap(core.ErrMissedArgument, "cookies")
}
if cookies.Length() == 0 {
return nil return nil
} }
params := make([]network.CookieParam, 0, len(cookies)) params := make([]network.CookieParam, 0, cookies.Length())
for _, c := range cookies { cookies.ForEach(func(value drivers.HTTPCookie, _ values.String) bool {
params = append(params, fromDriverCookie(url, c)) params = append(params, fromDriverCookie(url, value))
}
return true
})
return m.client.Network.SetCookies(ctx, network.NewSetCookiesArgs(params)) return m.client.Network.SetCookies(ctx, network.NewSetCookiesArgs(params))
} }
func (m *Manager) DeleteCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error { func (m *Manager) DeleteCookies(ctx context.Context, url string, cookies *drivers.HTTPCookies) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if len(cookies) == 0 { if cookies == nil {
return errors.Wrap(core.ErrMissedArgument, "cookies")
}
if cookies.Length() == 0 {
return nil return nil
} }
var err error var err error
for _, c := range cookies { cookies.ForEach(func(value drivers.HTTPCookie, _ values.String) bool {
err = m.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(url, c)) err = m.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(url, value))
if err != nil { if err != nil {
break return false
} }
}
return true
})
return err return err
} }
func (m *Manager) GetHeaders(_ context.Context) (drivers.HTTPHeaders, error) { func (m *Manager) GetHeaders(_ context.Context) (*drivers.HTTPHeaders, error) {
copied := make(drivers.HTTPHeaders) m.mu.Lock()
defer m.mu.Unlock()
for k, v := range m.headers { if m.headers == nil {
copied[k] = v return drivers.NewHTTPHeaders(), nil
} }
return copied, nil return m.headers.Clone().(*drivers.HTTPHeaders), nil
} }
func (m *Manager) SetHeaders(ctx context.Context, headers drivers.HTTPHeaders) error { func (m *Manager) SetHeaders(ctx context.Context, headers *drivers.HTTPHeaders) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
return m.setHeadersInternal(ctx, headers) return m.setHeadersInternal(ctx, headers)
} }
func (m *Manager) setHeadersInternal(ctx context.Context, headers drivers.HTTPHeaders) error { func (m *Manager) setHeadersInternal(ctx context.Context, headers *drivers.HTTPHeaders) error {
if len(headers) == 0 { if headers.Length() == 0 {
return nil return nil
} }
@ -461,7 +474,7 @@ func (m *Manager) onResponse(_ context.Context, message interface{}) (out bool)
response := drivers.HTTPResponse{ response := drivers.HTTPResponse{
StatusCode: msg.Response.Status, StatusCode: msg.Response.Status,
Status: msg.Response.StatusText, Status: msg.Response.StatusText,
Headers: make(drivers.HTTPHeaders), Headers: drivers.NewHTTPHeaders(),
} }
deserialized := make(map[string]string) deserialized := make(map[string]string)

View File

@ -6,7 +6,7 @@ import (
) )
type ( type (
Cookies map[string]drivers.HTTPCookies Cookies map[string]*drivers.HTTPCookies
Filter struct { Filter struct {
Patterns []drivers.ResourceFilter Patterns []drivers.ResourceFilter
@ -14,8 +14,8 @@ type (
Options struct { Options struct {
Cookies Cookies Cookies Cookies
Headers drivers.HTTPHeaders Headers *drivers.HTTPHeaders
Filter Filter Filter *Filter
} }
) )

View File

@ -4,13 +4,9 @@ import "github.com/MontFerret/ferret/pkg/drivers"
type ( type (
Options struct { Options struct {
Name string *drivers.Options
Proxy string
UserAgent string
Address string Address string
KeepCookies bool KeepCookies bool
Headers drivers.HTTPHeaders
Cookies drivers.HTTPCookies
} }
Option func(opts *Options) Option func(opts *Options)
@ -18,8 +14,9 @@ type (
const DefaultAddress = "http://127.0.0.1:9222" const DefaultAddress = "http://127.0.0.1:9222"
func newOptions(setters []Option) *Options { func NewOptions(setters []Option) *Options {
opts := new(Options) opts := new(Options)
opts.Options = new(drivers.Options)
opts.Name = DriverName opts.Name = DriverName
opts.Address = DefaultAddress opts.Address = DefaultAddress
@ -40,13 +37,13 @@ func WithAddress(address string) Option {
func WithProxy(address string) Option { func WithProxy(address string) Option {
return func(opts *Options) { return func(opts *Options) {
opts.Proxy = address drivers.WithProxy(address)(opts.Options)
} }
} }
func WithUserAgent(value string) Option { func WithUserAgent(value string) Option {
return func(opts *Options) { return func(opts *Options) {
opts.UserAgent = value drivers.WithUserAgent(value)(opts.Options)
} }
} }
@ -58,50 +55,30 @@ func WithKeepCookies() Option {
func WithCustomName(name string) Option { func WithCustomName(name string) Option {
return func(opts *Options) { return func(opts *Options) {
opts.Name = name drivers.WithCustomName(name)(opts.Options)
} }
} }
func WithHeader(name string, value []string) Option { func WithHeader(name string, header []string) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Headers == nil { drivers.WithHeader(name, header)(opts.Options)
opts.Headers = make(drivers.HTTPHeaders)
}
opts.Headers[name] = value
} }
} }
func WithHeaders(headers drivers.HTTPHeaders) Option { func WithHeaders(headers *drivers.HTTPHeaders) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Headers == nil { drivers.WithHeaders(headers)(opts.Options)
opts.Headers = make(drivers.HTTPHeaders)
}
for k, v := range headers {
opts.Headers[k] = v
}
} }
} }
func WithCookie(cookie drivers.HTTPCookie) Option { func WithCookie(cookie drivers.HTTPCookie) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Cookies == nil { drivers.WithCookie(cookie)(opts.Options)
opts.Cookies = make(drivers.HTTPCookies)
}
opts.Cookies[cookie.Name] = cookie
} }
} }
func WithCookies(cookies []drivers.HTTPCookie) Option { func WithCookies(cookies []drivers.HTTPCookie) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Cookies == nil { drivers.WithCookies(cookies)(opts.Options)
opts.Cookies = make(drivers.HTTPCookies)
}
for _, c := range cookies {
opts.Cookies[c.Name] = c
}
} }
} }

View File

@ -0,0 +1,71 @@
package cdp_test
import (
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"github.com/MontFerret/ferret/pkg/drivers"
"github.com/MontFerret/ferret/pkg/drivers/cdp"
)
// TestNewOptions checks that cdp.NewOptions starts from the package defaults
// and that each option setter is reflected in the resulting options.
func TestNewOptions(t *testing.T) {
	Convey("Should create driver options with initial values", t, func() {
		defaults := cdp.NewOptions([]cdp.Option{})

		So(defaults.Options, ShouldNotBeNil)
		So(defaults.Name, ShouldEqual, cdp.DriverName)
		So(defaults.Address, ShouldEqual, cdp.DefaultAddress)
	})

	Convey("Should use setters to set values", t, func() {
		const (
			wantAddress = "0.0.0.0:9222"
			wantUA      = "Mozilla"
			wantProxy   = "https://proxy.com"
		)

		wantName := cdp.DriverName + "2"

		// Fields left out of the literals below keep their zero values.
		sessionCookie := drivers.HTTPCookie{
			Name:    "Session",
			Value:   "fsdfsdfs",
			Path:    "dfsdfsd",
			Domain:  "sfdsfs",
			Expires: time.Time{},
		}
		extraCookies := []drivers.HTTPCookie{
			{
				Name:    "Use",
				Value:   "Foos",
				Expires: time.Time{},
			},
		}
		extraHeaders := drivers.NewHTTPHeadersWith(map[string][]string{
			"x-correlation-id": {"232483833833839"},
		})

		setters := []cdp.Option{
			cdp.WithCustomName(wantName),
			cdp.WithAddress(wantAddress),
			cdp.WithUserAgent(wantUA),
			cdp.WithProxy(wantProxy),
			cdp.WithKeepCookies(),
			cdp.WithCookie(sessionCookie),
			cdp.WithCookies(extraCookies),
			cdp.WithHeader("Authorization", []string{"Bearer dfsd7f98sd9fsd9fsd"}),
			cdp.WithHeaders(extraHeaders),
		}

		configured := cdp.NewOptions(setters)

		So(configured.Options, ShouldNotBeNil)
		So(configured.Name, ShouldEqual, wantName)
		So(configured.Address, ShouldEqual, wantAddress)
		So(configured.UserAgent, ShouldEqual, wantUA)
		So(configured.Proxy, ShouldEqual, wantProxy)
		So(configured.KeepCookies, ShouldBeTrue)
		// One cookie from WithCookie plus one from WithCookies.
		So(configured.Cookies.Length(), ShouldEqual, 2)
		// One header from WithHeader plus one from WithHeaders.
		So(configured.Headers.Length(), ShouldEqual, 2)
	})
}

View File

@ -2,7 +2,6 @@ package cdp
import ( import (
"context" "context"
"github.com/MontFerret/ferret/pkg/drivers/cdp/templates"
"hash/fnv" "hash/fnv"
"io" "io"
"regexp" "regexp"
@ -18,6 +17,7 @@ import (
"github.com/MontFerret/ferret/pkg/drivers/cdp/dom" "github.com/MontFerret/ferret/pkg/drivers/cdp/dom"
"github.com/MontFerret/ferret/pkg/drivers/cdp/input" "github.com/MontFerret/ferret/pkg/drivers/cdp/input"
net "github.com/MontFerret/ferret/pkg/drivers/cdp/network" net "github.com/MontFerret/ferret/pkg/drivers/cdp/network"
"github.com/MontFerret/ferret/pkg/drivers/cdp/templates"
"github.com/MontFerret/ferret/pkg/drivers/common" "github.com/MontFerret/ferret/pkg/drivers/common"
"github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/logging" "github.com/MontFerret/ferret/pkg/runtime/logging"
@ -73,15 +73,13 @@ func LoadHTMLPage(
Headers: params.Headers, Headers: params.Headers,
} }
if len(params.Cookies) > 0 { if params.Cookies != nil && params.Cookies.Length() > 0 {
netOpts.Cookies = make(map[string]drivers.HTTPCookies) netOpts.Cookies = make(map[string]*drivers.HTTPCookies)
netOpts.Cookies[params.URL] = params.Cookies netOpts.Cookies[params.URL] = params.Cookies
} }
if params.Ignore != nil { if params.Ignore != nil && len(params.Ignore.Resources) > 0 {
if len(params.Ignore.Resources) > 0 { netOpts.Filter.Patterns = params.Ignore.Resources
netOpts.Filter.Patterns = params.Ignore.Resources
}
} }
netManager, err := net.New(logger, client, netOpts) netManager, err := net.New(logger, client, netOpts)
@ -358,21 +356,21 @@ func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, er
return frames.Get(idx), nil return frames.Get(idx), nil
} }
func (p *HTMLPage) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) { func (p *HTMLPage) GetCookies(ctx context.Context) (*drivers.HTTPCookies, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
return p.network.GetCookies(ctx) return p.network.GetCookies(ctx)
} }
func (p *HTMLPage) SetCookies(ctx context.Context, cookies drivers.HTTPCookies) error { func (p *HTMLPage) SetCookies(ctx context.Context, cookies *drivers.HTTPCookies) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
return p.network.SetCookies(ctx, p.getCurrentDocument().GetURL().String(), cookies) return p.network.SetCookies(ctx, p.getCurrentDocument().GetURL().String(), cookies)
} }
func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies drivers.HTTPCookies) error { func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies *drivers.HTTPCookies) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()

View File

@ -13,21 +13,27 @@ import (
"github.com/wI2L/jettison" "github.com/wI2L/jettison"
) )
type HTTPCookies map[string]HTTPCookie type HTTPCookies struct {
values map[string]HTTPCookie
func NewHTTPCookies() HTTPCookies {
return make(HTTPCookies)
} }
func (c HTTPCookies) MarshalJSON() ([]byte, error) { func NewHTTPCookies() *HTTPCookies {
return jettison.MarshalOpts(map[string]HTTPCookie(c), jettison.NoHTMLEscaping()) return NewHTTPCookiesWith(make(map[string]HTTPCookie))
} }
func (c HTTPCookies) Type() core.Type { func NewHTTPCookiesWith(values map[string]HTTPCookie) *HTTPCookies {
return &HTTPCookies{values}
}
func (c *HTTPCookies) MarshalJSON() ([]byte, error) {
return jettison.MarshalOpts(c.values, jettison.NoHTMLEscaping())
}
func (c *HTTPCookies) Type() core.Type {
return HTTPCookiesType return HTTPCookiesType
} }
func (c HTTPCookies) String() string { func (c *HTTPCookies) String() string {
j, err := c.MarshalJSON() j, err := c.MarshalJSON()
if err != nil { if err != nil {
@ -37,21 +43,21 @@ func (c HTTPCookies) String() string {
return string(j) return string(j)
} }
func (c HTTPCookies) Compare(other core.Value) int64 { func (c *HTTPCookies) Compare(other core.Value) int64 {
if other.Type() != HTTPCookiesType { if other.Type() != HTTPCookiesType {
return Compare(HTTPCookiesType, other.Type()) return Compare(HTTPCookiesType, other.Type())
} }
oc := other.(HTTPCookies) oc := other.(*HTTPCookies)
switch { switch {
case len(c) > len(oc): case len(c.values) > len(oc.values):
return 1 return 1
case len(c) < len(oc): case len(c.values) < len(oc.values):
return -1 return -1
} }
for name := range c { for name := range c.values {
cEl, cExists := c.Get(values.NewString(name)) cEl, cExists := c.Get(values.NewString(name))
if !cExists { if !cExists {
@ -74,20 +80,20 @@ func (c HTTPCookies) Compare(other core.Value) int64 {
return 0 return 0
} }
func (c HTTPCookies) Unwrap() interface{} { func (c *HTTPCookies) Unwrap() interface{} {
return map[string]HTTPCookie(c) return c.values
} }
func (c HTTPCookies) Hash() uint64 { func (c *HTTPCookies) Hash() uint64 {
hash := fnv.New64a() hash := fnv.New64a()
hash.Write([]byte(c.Type().String())) hash.Write([]byte(c.Type().String()))
hash.Write([]byte(":")) hash.Write([]byte(":"))
hash.Write([]byte("{")) hash.Write([]byte("{"))
keys := make([]string, 0, len(c)) keys := make([]string, 0, len(c.values))
for key := range c { for key := range c.values {
keys = append(keys, key) keys = append(keys, key)
} }
@ -100,7 +106,7 @@ func (c HTTPCookies) Hash() uint64 {
hash.Write([]byte(key)) hash.Write([]byte(key))
hash.Write([]byte(":")) hash.Write([]byte(":"))
el := c[key] el := c.values[key]
bytes := make([]byte, 8) bytes := make([]byte, 8)
binary.LittleEndian.PutUint64(bytes, el.Hash()) binary.LittleEndian.PutUint64(bytes, el.Hash())
@ -117,47 +123,59 @@ func (c HTTPCookies) Hash() uint64 {
return hash.Sum64() return hash.Sum64()
} }
func (c HTTPCookies) Copy() core.Value { func (c *HTTPCookies) Copy() core.Value {
copied := make(HTTPCookies) return NewHTTPCookiesWith(c.values)
}
for k, v := range c { func (c *HTTPCookies) Clone() core.Cloneable {
copied[k] = v clone := make(map[string]HTTPCookie)
for _, cookie := range c.values {
clone[cookie.Name] = cookie
} }
return copied return NewHTTPCookiesWith(clone)
} }
func (c HTTPCookies) Length() values.Int { func (c *HTTPCookies) Length() values.Int {
return values.NewInt(len(c)) return values.NewInt(len(c.values))
} }
func (c HTTPCookies) Keys() []values.String { func (c *HTTPCookies) Keys() []values.String {
keys := make([]values.String, 0, len(c)) result := make([]values.String, 0, len(c.values))
for k := range c { for k := range c.values {
keys = append(keys, values.NewString(k)) result = append(result, values.NewString(k))
} }
return keys return result
} }
func (c HTTPCookies) Get(key values.String) (core.Value, values.Boolean) { func (c *HTTPCookies) Values() []HTTPCookie {
value, found := c[key.String()] result := make([]HTTPCookie, 0, len(c.values))
for _, v := range c.values {
result = append(result, v)
}
return result
}
func (c *HTTPCookies) Get(key values.String) (HTTPCookie, values.Boolean) {
value, found := c.values[key.String()]
if found { if found {
return value, values.True return value, values.True
} }
return values.None, values.False return HTTPCookie{}, values.False
} }
func (c HTTPCookies) Set(key values.String, value core.Value) { func (c *HTTPCookies) Set(cookie HTTPCookie) {
if cookie, ok := value.(HTTPCookie); ok { c.values[cookie.Name] = cookie
c[key.String()] = cookie
}
} }
func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { func (c *HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, error) {
if len(path) == 0 { if len(path) == 0 {
return values.None, nil return values.None, nil
} }
@ -170,7 +188,7 @@ func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value,
return values.None, err return values.None, err
} }
cookie, found := c[segment.String()] cookie, found := c.values[segment.String()]
if found { if found {
if len(path) == 1 { if len(path) == 1 {
@ -182,3 +200,11 @@ func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value,
return values.None, nil return values.None, nil
} }
// ForEach calls predicate once for every stored cookie, passing the cookie and
// the key it is stored under. Iteration ends early as soon as predicate
// returns false. Cookies are visited in map iteration order.
func (c *HTTPCookies) ForEach(predicate func(value HTTPCookie, key values.String) bool) {
	for name, cookie := range c.values {
		if proceed := predicate(cookie, values.NewString(name)); !proceed {
			return
		}
	}
}

View File

@ -0,0 +1,65 @@
package drivers_test
import (
"fmt"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"github.com/wI2L/jettison"
"github.com/MontFerret/ferret/pkg/drivers"
)
// TestHTTPCookies checks that an HTTPCookies collection serializes to JSON,
// both through its own MarshalJSON method and through jettison.
func TestHTTPCookies(t *testing.T) {
	Convey("HTTPCookies", t, func() {
		Convey(".MarshalJSON", func() {
			Convey("Should serialize cookies", func() {
				expires := time.Now()

				session := drivers.HTTPCookie{
					Name:     "Session",
					Value:    "asdfg",
					Path:     "/",
					Domain:   "www.google.com",
					Expires:  expires,
					MaxAge:   0,
					Secure:   true,
					HTTPOnly: true,
					SameSite: drivers.SameSiteLaxMode,
				}
				cookies := drivers.NewHTTPCookiesWith(map[string]drivers.HTTPCookie{
					"Session": session,
				})

				serialized, err := cookies.MarshalJSON()

				// The expected document embeds the timestamp exactly as
				// time.Time renders itself in JSON.
				expiresJSON, expiresErr := expires.MarshalJSON()
				So(expiresErr, ShouldBeNil)

				expected := fmt.Sprintf(`{"Session":{"domain":"www.google.com","expires":%s,"http_only":true,"max_age":0,"name":"Session","path":"/","same_site":"Lax","secure":true,"value":"asdfg"}}`, string(expiresJSON))

				So(err, ShouldBeNil)
				So(string(serialized), ShouldEqual, expected)
			})

			Convey("Should set proper values", func() {
				cookies := drivers.NewHTTPCookies()

				// Fields left out of the literal keep their zero values.
				cookies.Set(drivers.HTTPCookie{
					Name:    "Authorization",
					Value:   "e40b7d5eff464a4fb51efed2d1a19a24",
					Path:    "/",
					Domain:  "www.google.com",
					Expires: time.Now(),
				})

				_, err := jettison.MarshalOpts(cookies, jettison.NoHTMLEscaping())

				So(err, ShouldBeNil)
			})
		})
	})
}

View File

@ -11,7 +11,7 @@ type (
ctxKey struct{} ctxKey struct{}
ctxValue struct { ctxValue struct {
opts *options opts *globalOptions
drivers map[string]Driver drivers map[string]Driver
} }
@ -23,7 +23,7 @@ type (
} }
) )
func WithContext(ctx context.Context, drv Driver, opts ...Option) context.Context { func WithContext(ctx context.Context, drv Driver, opts ...GlobalOption) context.Context {
ctx, value := resolveValue(ctx) ctx, value := resolveValue(ctx)
value.drivers[drv.Name()] = drv value.drivers[drv.Name()] = drv
@ -63,7 +63,7 @@ func resolveValue(ctx context.Context) (context.Context, *ctxValue) {
if !ok { if !ok {
value = &ctxValue{ value = &ctxValue{
opts: &options{}, opts: &globalOptions{},
drivers: make(map[string]Driver), drivers: make(map[string]Driver),
} }

View File

@ -17,40 +17,50 @@ import (
) )
// HTTPHeaders HTTP header object // HTTPHeaders HTTP header object
type HTTPHeaders map[string][]string type HTTPHeaders struct {
values map[string][]string
func NewHTTPHeaders(values map[string][]string) HTTPHeaders {
return HTTPHeaders(values)
} }
func (h HTTPHeaders) Type() core.Type { func NewHTTPHeaders() *HTTPHeaders {
return NewHTTPHeadersWith(make(map[string][]string))
}
func NewHTTPHeadersWith(values map[string][]string) *HTTPHeaders {
return &HTTPHeaders{values}
}
func (h *HTTPHeaders) Length() values.Int {
return values.NewInt(len(h.values))
}
func (h *HTTPHeaders) Type() core.Type {
return HTTPHeaderType return HTTPHeaderType
} }
func (h HTTPHeaders) String() string { func (h *HTTPHeaders) String() string {
var buf bytes.Buffer var buf bytes.Buffer
for k := range h { for k := range h.values {
buf.WriteString(fmt.Sprintf("%s=%s;", k, h.Get(k))) buf.WriteString(fmt.Sprintf("%s=%s;", k, h.Get(k)))
} }
return buf.String() return buf.String()
} }
func (h HTTPHeaders) Compare(other core.Value) int64 { func (h *HTTPHeaders) Compare(other core.Value) int64 {
if other.Type() != HTTPHeaderType { if other.Type() != HTTPHeaderType {
return Compare(HTTPHeaderType, other.Type()) return Compare(HTTPHeaderType, other.Type())
} }
oh := other.(HTTPHeaders) oh := other.(*HTTPHeaders)
if len(h) > len(oh) { if len(h.values) > len(oh.values) {
return 1 return 1
} else if len(h) < len(oh) { } else if len(h.values) < len(oh.values) {
return -1 return -1
} }
for k := range h { for k := range h.values {
c := strings.Compare(h.Get(k), oh.Get(k)) c := strings.Compare(h.Get(k), oh.Get(k))
if c != 0 { if c != 0 {
@ -61,20 +71,20 @@ func (h HTTPHeaders) Compare(other core.Value) int64 {
return 0 return 0
} }
func (h HTTPHeaders) Unwrap() interface{} { func (h *HTTPHeaders) Unwrap() interface{} {
return h return h.values
} }
func (h HTTPHeaders) Hash() uint64 { func (h *HTTPHeaders) Hash() uint64 {
hash := fnv.New64a() hash := fnv.New64a()
hash.Write([]byte(h.Type().String())) hash.Write([]byte(h.Type().String()))
hash.Write([]byte(":")) hash.Write([]byte(":"))
hash.Write([]byte("{")) hash.Write([]byte("{"))
keys := make([]string, 0, len(h)) keys := make([]string, 0, len(h.values))
for key := range h { for key := range h.values {
keys = append(keys, key) keys = append(keys, key)
} }
@ -101,18 +111,28 @@ func (h HTTPHeaders) Hash() uint64 {
return hash.Sum64() return hash.Sum64()
} }
func (h HTTPHeaders) Copy() core.Value { func (h *HTTPHeaders) Copy() core.Value {
return *(&h) return &HTTPHeaders{h.values}
} }
func (h HTTPHeaders) MarshalJSON() ([]byte, error) { func (h *HTTPHeaders) Clone() core.Cloneable {
cp := make(map[string][]string)
for k, v := range h.values {
cp[k] = v
}
return &HTTPHeaders{cp}
}
func (h *HTTPHeaders) MarshalJSON() ([]byte, error) {
headers := map[string]string{} headers := map[string]string{}
for key, val := range h { for key, val := range h.values {
headers[key] = strings.Join(val, ", ") headers[key] = strings.Join(val, ", ")
} }
out, err := jettison.MarshalOpts(headers, jettison.NoHTMLEscaping()) out, err := jettison.MarshalOpts(headers)
if err != nil { if err != nil {
return nil, err return nil, err
@ -121,15 +141,25 @@ func (h HTTPHeaders) MarshalJSON() ([]byte, error) {
return out, err return out, err
} }
func (h HTTPHeaders) Set(key, value string) { func (h *HTTPHeaders) Set(key, value string) {
textproto.MIMEHeader(h).Set(key, value) textproto.MIMEHeader(h.values).Set(key, value)
} }
func (h HTTPHeaders) Get(key string) string { func (h *HTTPHeaders) SetArr(key string, value []string) {
return textproto.MIMEHeader(h).Get(key) h.values[key] = value
} }
func (h HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, error) { func (h *HTTPHeaders) Get(key string) string {
_, found := h.values[key]
if !found {
return ""
}
return textproto.MIMEHeader(h.values).Get(key)
}
func (h *HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, error) {
if len(path) == 0 { if len(path) == 0 {
return values.None, nil return values.None, nil
} }
@ -144,3 +174,11 @@ func (h HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, er
return values.NewString(h.Get(segment.String())), nil return values.NewString(h.Get(segment.String())), nil
} }
// ForEach calls predicate once for every stored header, passing the header's
// values and its name. Iteration ends early as soon as predicate returns
// false. Headers are visited in map iteration order.
func (h *HTTPHeaders) ForEach(predicate func(value []string, key string) bool) {
	for name, vals := range h.values {
		if proceed := predicate(vals, name); !proceed {
			return
		}
	}
}

View File

@ -4,24 +4,35 @@ import (
"testing" "testing"
. "github.com/smartystreets/goconvey/convey" . "github.com/smartystreets/goconvey/convey"
"github.com/wI2L/jettison"
"github.com/MontFerret/ferret/pkg/drivers" "github.com/MontFerret/ferret/pkg/drivers"
) )
func TestHTTPHeader(t *testing.T) { func TestHTTPHeaders(t *testing.T) {
Convey("HTTPHeaders", t, func() { Convey("HTTPHeaders", t, func() {
Convey(".MarshalJSON", func() { Convey(".MarshalJSON", func() {
Convey("Should serialize header values", func() { Convey("Should serialize header values", func() {
headers := make(drivers.HTTPHeaders) headers := drivers.NewHTTPHeadersWith(map[string][]string{
"Content-Encoding": []string{"gzip"},
headers["Content-Encoding"] = []string{"gzip"} "Content-Type": []string{"text/html", "charset=utf-8"},
headers["Content-Type"] = []string{"text/html", "charset=utf-8"} })
out, err := headers.MarshalJSON() out, err := headers.MarshalJSON()
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(string(out), ShouldEqual, `{"Content-Encoding":"gzip","Content-Type":"text/html, charset=utf-8"}`) So(string(out), ShouldEqual, `{"Content-Encoding":"gzip","Content-Type":"text/html, charset=utf-8"}`)
}) })
Convey("Should set proper values", func() {
headers := drivers.NewHTTPHeaders()
headers.Set("Authorization", `["Basic e40b7d5eff464a4fb51efed2d1a19a24"]`)
_, err := jettison.MarshalOpts(headers, jettison.NoHTMLEscaping())
So(err, ShouldBeNil)
})
}) })
}) })
} }

View File

@ -2,6 +2,7 @@ package drivers
import ( import (
"github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/core"
"github.com/MontFerret/ferret/pkg/runtime/values"
) )
func ToPage(value core.Value) (HTMLPage, error) { func ToPage(value core.Value) (HTMLPage, error) {
@ -46,3 +47,48 @@ func ToElement(value core.Value) (HTMLElement, error) {
) )
} }
} }
// SetDefaultParams fills the gaps in params with the driver-level defaults
// held in opts and returns the resulting Params.
//
// Defaults never replace something the caller supplied:
//   - a default header is added only when params has no header stored under
//     the same name;
//   - a default cookie is added only when params has no cookie under the same
//     key;
//   - the default user agent is used only when params.UserAgent is empty.
//
// When params already carries headers or cookies, the defaults are added to
// those same objects. A nil opts leaves params untouched.
func SetDefaultParams(opts *Options, params Params) Params {
	// Nothing to merge, and opts must not be dereferenced.
	if opts == nil {
		return params
	}

	// set default headers
	if opts.Headers != nil {
		if params.Headers == nil {
			params.Headers = NewHTTPHeaders()
		}

		// Record which header names the caller supplied. Presence is decided
		// by name rather than by reading the value back, so that a header with
		// an empty value, or one stored under a non-canonical name, still
		// counts as set by the caller and is not overwritten.
		supplied := make(map[string]struct{})

		params.Headers.ForEach(func(_ []string, key string) bool {
			supplied[key] = struct{}{}

			return true
		})

		opts.Headers.ForEach(func(value []string, key string) bool {
			// do not override user's set values
			if _, exists := supplied[key]; !exists {
				params.Headers.SetArr(key, value)
			}

			return true
		})
	}

	// set default cookies
	if opts.Cookies != nil {
		if params.Cookies == nil {
			params.Cookies = NewHTTPCookies()
		}

		opts.Cookies.ForEach(func(value HTTPCookie, key values.String) bool {
			// do not override user's set values
			if _, exists := params.Cookies.Get(key); !exists {
				params.Cookies.Set(value)
			}

			return true
		})
	}

	// set default user agent
	if opts.UserAgent != "" && params.UserAgent == "" {
		params.UserAgent = opts.UserAgent
	}

	return params
}

View File

@ -0,0 +1,41 @@
package drivers_test
import (
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"github.com/MontFerret/ferret/pkg/drivers"
)
// TestSetDefaultParams checks that driver-level defaults stored in Options
// are copied into a Params value that does not define them itself.
func TestSetDefaultParams(t *testing.T) {
	Convey("Should take values from Options if not present in Params", t, func() {
		opts := &drivers.Options{
			Name:      "Test",
			UserAgent: "Mozilla",
			Headers: drivers.NewHTTPHeadersWith(map[string][]string{
				"Accept": {"application/json"},
			}),
			Cookies: drivers.NewHTTPCookiesWith(map[string]drivers.HTTPCookie{
				"Session": {
					Name:     "Session",
					Value:    "fsfsdfsd",
					Path:     "",
					Domain:   "",
					Expires:  time.Time{},
					MaxAge:   0,
					Secure:   false,
					HTTPOnly: false,
					SameSite: 0,
				},
			}),
		}

		params := drivers.SetDefaultParams(opts, drivers.Params{})

		So(params.UserAgent, ShouldEqual, opts.UserAgent)

		// A non-nil collection alone does not prove the defaults were copied,
		// so the number of entries is checked as well.
		So(params.Headers, ShouldNotBeNil)
		So(params.Headers.Length(), ShouldEqual, 1)

		So(params.Cookies, ShouldNotBeNil)
		So(params.Cookies.Length(), ShouldEqual, 1)
	})
}

View File

@ -25,7 +25,7 @@ type Driver struct {
func NewDriver(opts ...Option) *Driver { func NewDriver(opts ...Option) *Driver {
drv := new(Driver) drv := new(Driver)
drv.options = newOptions(opts) drv.options = NewOptions(opts)
drv.client = newHTTPClient(drv.options) drv.client = newHTTPClient(drv.options)
drv.client.Concurrency = drv.options.Concurrency drv.client.Concurrency = drv.options.Concurrency
@ -96,63 +96,11 @@ func (drv *Driver) Open(ctx context.Context, params drivers.Params) (drivers.HTM
req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache") req.Header.Set("Pragma", "no-cache")
if drv.options.Headers != nil && params.Headers == nil { params = drivers.SetDefaultParams(drv.options.Options, params)
params.Headers = make(drivers.HTTPHeaders)
}
// Set default headers
for k, v := range drv.options.Headers {
_, exists := params.Headers[k]
// do not override user's set values
if !exists {
params.Headers[k] = v
}
}
for k := range params.Headers {
req.Header.Add(k, params.Headers.Get(k))
logger.
Debug().
Timestamp().
Str("header", k).
Msg("set header")
}
if drv.options.Cookies != nil && params.Cookies == nil {
params.Cookies = make(drivers.HTTPCookies)
}
// set default cookies
for k, v := range drv.options.Cookies {
_, exists := params.Cookies[k]
// do not override user's set values
if !exists {
params.Cookies[k] = v
}
}
for _, c := range params.Cookies {
req.AddCookie(fromDriverCookie(c))
logger.
Debug().
Timestamp().
Str("cookie", c.Name).
Msg("set cookie")
}
req = req.WithContext(ctx) req = req.WithContext(ctx)
var ua string ua := common.GetUserAgent(params.UserAgent)
if params.UserAgent != "" {
ua = common.GetUserAgent(params.UserAgent)
} else {
ua = common.GetUserAgent(drv.options.UserAgent)
}
logger. logger.
Debug(). Debug().
@ -197,7 +145,7 @@ func (drv *Driver) Open(ctx context.Context, params drivers.Params) (drivers.HTM
r := drivers.HTTPResponse{ r := drivers.HTTPResponse{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Status: resp.Status, Status: resp.Status,
Headers: drivers.HTTPHeaders(resp.Header), Headers: drivers.NewHTTPHeadersWith(resp.Header),
} }
return NewHTMLPage(doc, params.URL, r, cookies) return NewHTMLPage(doc, params.URL, r, cookies)

View File

@ -2,6 +2,7 @@ package http
import ( import (
"crypto/tls" "crypto/tls"
"github.com/MontFerret/ferret/pkg/drivers"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
@ -25,14 +26,18 @@ func Test_newHTTPClientWithTransport(t *testing.T) {
{ {
name: "check transport exist with pester.New()", name: "check transport exist with pester.New()",
args: args{options: &Options{ args: args{options: &Options{
Proxy: "http://0.0.0.|", Options: &drivers.Options{
Proxy: "http://0.0.0.|",
},
HTTPTransport: httpTransport, HTTPTransport: httpTransport,
}}, }},
}, },
{ {
name: "check transport exist with pester.NewExtendedClient()", name: "check transport exist with pester.NewExtendedClient()",
args: args{options: &Options{ args: args{options: &Options{
Proxy: "http://0.0.0.0", Options: &drivers.Options{
Proxy: "http://0.0.0.0",
},
HTTPTransport: httpTransport, HTTPTransport: httpTransport,
}}, }},
}, },
@ -69,7 +74,9 @@ func Test_newHTTPClient(t *testing.T) {
convey.Convey("pester.New()", t, func() { convey.Convey("pester.New()", t, func() {
var ( var (
client = newHTTPClient(&Options{ client = newHTTPClient(&Options{
Proxy: "http://0.0.0.|", Options: &drivers.Options{
Proxy: "http://0.0.0.|",
},
}) })
rValue = reflect.ValueOf(client).Elem() rValue = reflect.ValueOf(client).Elem()
@ -85,7 +92,9 @@ func Test_newHTTPClient(t *testing.T) {
convey.Convey("pester.NewExtend()", t, func() { convey.Convey("pester.NewExtend()", t, func() {
var ( var (
client = newHTTPClient(&Options{ client = newHTTPClient(&Options{
Proxy: "http://0.0.0.0", Options: &drivers.Options{
Proxy: "http://0.0.0.0",
},
}) })
rValue = reflect.ValueOf(client).Elem() rValue = reflect.ValueOf(client).Elem()

View File

@ -54,8 +54,8 @@ func outerHTML(s *goquery.Selection) (string, error) {
return buf.String(), nil return buf.String(), nil
} }
func toDriverCookies(cookies []*HTTP.Cookie) (drivers.HTTPCookies, error) { func toDriverCookies(cookies []*HTTP.Cookie) (*drivers.HTTPCookies, error) {
res := make(drivers.HTTPCookies) res := drivers.NewHTTPCookies()
for _, c := range cookies { for _, c := range cookies {
dc, err := toDriverCookie(c) dc, err := toDriverCookie(c)
@ -64,7 +64,7 @@ func toDriverCookies(cookies []*HTTP.Cookie) (drivers.HTTPCookies, error) {
return nil, err return nil, err
} }
res[dc.Name] = dc res.Set(dc)
} }
return res, nil return res, nil

View File

@ -1,11 +1,17 @@
package http package http
import ( import (
"github.com/gobwas/glob"
stdhttp "net/http" stdhttp "net/http"
"github.com/MontFerret/ferret/pkg/drivers" "github.com/gobwas/glob"
"github.com/sethgrid/pester" "github.com/sethgrid/pester"
"github.com/MontFerret/ferret/pkg/drivers"
)
var (
DefaultConcurrency = 3
DefaultMaxRetries = 5
) )
type ( type (
@ -17,25 +23,22 @@ type (
} }
Options struct { Options struct {
Name string *drivers.Options
Backoff pester.BackoffStrategy Backoff pester.BackoffStrategy
MaxRetries int MaxRetries int
Concurrency int Concurrency int
Proxy string
UserAgent string
Headers drivers.HTTPHeaders
Cookies drivers.HTTPCookies
HTTPCodesFilter []compiledStatusCodeFilter HTTPCodesFilter []compiledStatusCodeFilter
HTTPTransport *stdhttp.Transport HTTPTransport *stdhttp.Transport
} }
) )
func newOptions(setters []Option) *Options { func NewOptions(setters []Option) *Options {
opts := new(Options) opts := new(Options)
opts.Options = new(drivers.Options)
opts.Name = DriverName opts.Name = DriverName
opts.Backoff = pester.ExponentialBackoff opts.Backoff = pester.ExponentialBackoff
opts.Concurrency = 3 opts.Concurrency = DefaultConcurrency
opts.MaxRetries = 5 opts.MaxRetries = DefaultMaxRetries
opts.HTTPCodesFilter = make([]compiledStatusCodeFilter, 0, 5) opts.HTTPCodesFilter = make([]compiledStatusCodeFilter, 0, 5)
for _, setter := range setters { for _, setter := range setters {
@ -77,63 +80,43 @@ func WithConcurrency(value int) Option {
func WithProxy(address string) Option { func WithProxy(address string) Option {
return func(opts *Options) { return func(opts *Options) {
opts.Proxy = address drivers.WithProxy(address)(opts.Options)
} }
} }
func WithUserAgent(value string) Option { func WithUserAgent(value string) Option {
return func(opts *Options) { return func(opts *Options) {
opts.UserAgent = value drivers.WithUserAgent(value)(opts.Options)
} }
} }
func WithCustomName(name string) Option { func WithCustomName(name string) Option {
return func(opts *Options) { return func(opts *Options) {
opts.Name = name drivers.WithCustomName(name)(opts.Options)
} }
} }
func WithHeader(name string, value []string) Option { func WithHeader(name string, value []string) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Headers == nil { drivers.WithHeader(name, value)(opts.Options)
opts.Headers = make(drivers.HTTPHeaders)
}
opts.Headers[name] = value
} }
} }
func WithHeaders(headers drivers.HTTPHeaders) Option { func WithHeaders(headers *drivers.HTTPHeaders) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Headers == nil { drivers.WithHeaders(headers)(opts.Options)
opts.Headers = make(drivers.HTTPHeaders)
}
for k, v := range headers {
opts.Headers[k] = v
}
} }
} }
func WithCookie(cookie drivers.HTTPCookie) Option { func WithCookie(cookie drivers.HTTPCookie) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Cookies == nil { drivers.WithCookie(cookie)(opts.Options)
opts.Cookies = make(drivers.HTTPCookies)
}
opts.Cookies[cookie.Name] = cookie
} }
} }
func WithCookies(cookies []drivers.HTTPCookie) Option { func WithCookies(cookies []drivers.HTTPCookie) Option {
return func(opts *Options) { return func(opts *Options) {
if opts.Cookies == nil { drivers.WithCookies(cookies)(opts.Options)
opts.Cookies = make(drivers.HTTPCookies)
}
for _, c := range cookies {
opts.Cookies[c.Name] = c
}
} }
} }

View File

@ -0,0 +1,85 @@
package http_test
import (
stdhttp "net/http"
"testing"
"time"
"github.com/sethgrid/pester"
. "github.com/smartystreets/goconvey/convey"
"github.com/MontFerret/ferret/pkg/drivers"
"github.com/MontFerret/ferret/pkg/drivers/http"
)
// TestNewOptions covers the two ways http.NewOptions is used: with no setters
// (package defaults) and with every available setter applied.
func TestNewOptions(t *testing.T) {
	Convey("Should create driver options with initial values", t, func() {
		defaults := http.NewOptions([]http.Option{})

		So(defaults.Options, ShouldNotBeNil)
		So(defaults.Name, ShouldEqual, http.DriverName)
		So(defaults.Backoff, ShouldEqual, pester.ExponentialBackoff)
		So(defaults.Concurrency, ShouldEqual, http.DefaultConcurrency)
		So(defaults.MaxRetries, ShouldEqual, http.DefaultMaxRetries)
		So(defaults.HTTPCodesFilter, ShouldHaveLength, 0)
	})

	Convey("Should use setters to set values", t, func() {
		var (
			wantName        = http.DriverName + "2"
			wantUA          = "Mozilla"
			wantProxy       = "https://proxy.com"
			wantMaxRetries  = 2
			wantConcurrency = 10
			wantTransport   = &stdhttp.Transport{}
		)

		// Fixtures for the single-cookie and multi-cookie setters.
		sessionCookie := drivers.HTTPCookie{
			Name:     "Session",
			Value:    "fsdfsdfs",
			Path:     "dfsdfsd",
			Domain:   "sfdsfs",
			Expires:  time.Time{},
			MaxAge:   0,
			Secure:   false,
			HTTPOnly: false,
			SameSite: 0,
		}
		extraCookies := []drivers.HTTPCookie{
			{
				Name:     "Use",
				Value:    "Foos",
				Path:     "",
				Domain:   "",
				Expires:  time.Time{},
				MaxAge:   0,
				Secure:   false,
				HTTPOnly: false,
				SameSite: 0,
			},
		}
		extraHeaders := drivers.NewHTTPHeadersWith(map[string][]string{
			"x-correlation-id": {"232483833833839"},
		})

		setters := []http.Option{
			http.WithCustomName(wantName),
			http.WithUserAgent(wantUA),
			http.WithProxy(wantProxy),
			http.WithCookie(sessionCookie),
			http.WithCookies(extraCookies),
			http.WithHeader("Authorization", []string{"Bearer dfsd7f98sd9fsd9fsd"}),
			http.WithHeaders(extraHeaders),
			http.WithDefaultBackoff(),
			http.WithMaxRetries(wantMaxRetries),
			http.WithConcurrency(wantConcurrency),
			http.WithAllowedHTTPCode(401),
			http.WithAllowedHTTPCodes([]int{403, 404}),
			http.WithCustomTransport(wantTransport),
		}

		got := http.NewOptions(setters)

		So(got.Options, ShouldNotBeNil)
		So(got.Name, ShouldEqual, wantName)
		So(got.UserAgent, ShouldEqual, wantUA)
		So(got.Proxy, ShouldEqual, wantProxy)
		So(got.Cookies.Length(), ShouldEqual, 2)
		So(got.Headers.Length(), ShouldEqual, 2)
		So(got.Backoff, ShouldEqual, pester.DefaultBackoff)
		So(got.MaxRetries, ShouldEqual, wantMaxRetries)
		So(got.Concurrency, ShouldEqual, wantConcurrency)
		So(got.HTTPCodesFilter, ShouldHaveLength, 3)
		So(got.HTTPTransport, ShouldEqual, wantTransport)
	})
}

View File

@ -14,7 +14,7 @@ import (
type HTMLPage struct { type HTMLPage struct {
document *HTMLDocument document *HTMLDocument
cookies drivers.HTTPCookies cookies *drivers.HTTPCookies
frames *values.Array frames *values.Array
response drivers.HTTPResponse response drivers.HTTPResponse
} }
@ -23,7 +23,7 @@ func NewHTMLPage(
qdoc *goquery.Document, qdoc *goquery.Document,
url string, url string,
response drivers.HTTPResponse, response drivers.HTTPResponse,
cookies drivers.HTTPCookies, cookies *drivers.HTTPCookies,
) (*HTMLPage, error) { ) (*HTMLPage, error) {
doc, err := NewRootHTMLDocument(qdoc, url) doc, err := NewRootHTMLDocument(qdoc, url)
@ -84,10 +84,10 @@ func (p *HTMLPage) Hash() uint64 {
} }
func (p *HTMLPage) Copy() core.Value { func (p *HTMLPage) Copy() core.Value {
cookies := make(drivers.HTTPCookies) var cookies *drivers.HTTPCookies
for k, v := range p.cookies { if p.cookies != nil {
cookies[k] = v cookies = p.cookies.Copy().(*drivers.HTTPCookies)
} }
page, err := NewHTMLPage( page, err := NewHTMLPage(
@ -168,11 +168,15 @@ func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, er
return p.frames.Get(idx), nil return p.frames.Get(idx), nil
} }
func (p *HTMLPage) GetCookies(_ context.Context) (drivers.HTTPCookies, error) { func (p *HTMLPage) GetCookies(_ context.Context) (*drivers.HTTPCookies, error) {
res := make(drivers.HTTPCookies) res := drivers.NewHTTPCookies()
for n, v := range p.cookies { if p.cookies != nil {
res[n] = v p.cookies.ForEach(func(value drivers.HTTPCookie, _ values.String) bool {
res.Set(value)
return true
})
} }
return res, nil return res, nil
@ -182,11 +186,11 @@ func (p *HTMLPage) GetResponse(_ context.Context) (drivers.HTTPResponse, error)
return p.response, nil return p.response, nil
} }
func (p *HTMLPage) SetCookies(_ context.Context, _ drivers.HTTPCookies) error { func (p *HTMLPage) SetCookies(_ context.Context, _ *drivers.HTTPCookies) error {
return core.ErrNotSupported return core.ErrNotSupported
} }
func (p *HTMLPage) DeleteCookies(_ context.Context, _ drivers.HTTPCookies) error { func (p *HTMLPage) DeleteCookies(_ context.Context, _ *drivers.HTTPCookies) error {
return core.ErrNotSupported return core.ErrNotSupported
} }

View File

@ -1,15 +1,89 @@
package drivers package drivers
type ( type (
options struct { globalOptions struct {
defaultDriver string defaultDriver string
} }
Option func(drv Driver, opts *options) GlobalOption func(drv Driver, opts *globalOptions)
Options struct {
Name string
Proxy string
UserAgent string
Headers *HTTPHeaders
Cookies *HTTPCookies
}
Option func(opts *Options)
) )
func AsDefault() Option { func AsDefault() GlobalOption {
return func(drv Driver, opts *options) { return func(drv Driver, opts *globalOptions) {
opts.defaultDriver = drv.Name() opts.defaultDriver = drv.Name()
} }
} }
// WithProxy returns an Option that stores address in Options.Proxy.
func WithProxy(address string) Option {
	setProxy := func(o *Options) {
		o.Proxy = address
	}

	return setProxy
}
// WithUserAgent returns an Option that stores value in Options.UserAgent.
func WithUserAgent(value string) Option {
	setUserAgent := func(o *Options) {
		o.UserAgent = value
	}

	return setUserAgent
}
// WithCustomName returns an Option that stores name in Options.Name.
func WithCustomName(name string) Option {
	setName := func(o *Options) {
		o.Name = name
	}

	return setName
}
// WithHeader returns an Option that stores value under name in
// Options.Headers, creating the header collection first when it is nil.
func WithHeader(name string, value []string) Option {
	return func(o *Options) {
		hdrs := o.Headers

		// Lazily allocate the collection on first use.
		if hdrs == nil {
			hdrs = NewHTTPHeaders()
			o.Headers = hdrs
		}

		hdrs.SetArr(name, value)
	}
}
// WithHeaders returns an Option that passes every key/value pair of headers
// to Options.Headers.SetArr, creating the header collection first when it is
// nil. A nil headers argument makes the Option a no-op.
func WithHeaders(headers *HTTPHeaders) Option {
	return func(opts *Options) {
		// Nothing to merge. Calling ForEach on a nil *HTTPHeaders is not
		// known to be safe, so bail out before touching it.
		if headers == nil {
			return
		}

		if opts.Headers == nil {
			opts.Headers = NewHTTPHeaders()
		}

		headers.ForEach(func(value []string, key string) bool {
			opts.Headers.SetArr(key, value)

			return true
		})
	}
}
// WithCookie returns an Option that adds cookie to Options.Cookies via Set,
// creating the cookie collection first when it is nil.
func WithCookie(cookie HTTPCookie) Option {
	return func(o *Options) {
		jar := o.Cookies

		// Lazily allocate the collection on first use.
		if jar == nil {
			jar = NewHTTPCookies()
			o.Cookies = jar
		}

		jar.Set(cookie)
	}
}
// WithCookies returns an Option that adds each element of cookies, in slice
// order, to Options.Cookies via Set. The cookie collection is created first
// when it is nil, even if cookies is empty.
func WithCookies(cookies []HTTPCookie) Option {
	return func(o *Options) {
		jar := o.Cookies

		// Lazily allocate the collection on first use.
		if jar == nil {
			jar = NewHTTPCookies()
			o.Cookies = jar
		}

		for i := 0; i < len(cookies); i++ {
			jar.Set(cookies[i])
		}
	}
}

View File

@ -28,8 +28,8 @@ type (
URL string URL string
UserAgent string UserAgent string
KeepCookies bool KeepCookies bool
Cookies HTTPCookies Cookies *HTTPCookies
Headers HTTPHeaders Headers *HTTPHeaders
Viewport *Viewport Viewport *Viewport
Ignore *Ignore Ignore *Ignore
} }
@ -37,8 +37,8 @@ type (
ParseParams struct { ParseParams struct {
Content []byte Content []byte
KeepCookies bool KeepCookies bool
Cookies HTTPCookies Cookies *HTTPCookies
Headers HTTPHeaders Headers *HTTPHeaders
Viewport *Viewport Viewport *Viewport
} }
) )

View File

@ -14,7 +14,7 @@ import (
type HTTPResponse struct { type HTTPResponse struct {
StatusCode int StatusCode int
Status string Status string
Headers HTTPHeaders Headers *HTTPHeaders
} }
func (resp *HTTPResponse) Type() core.Type { func (resp *HTTPResponse) Type() core.Type {
@ -60,9 +60,9 @@ func (resp *HTTPResponse) Hash() uint64 {
// responseMarshal is a structure that repeats HTTPResponse. It allows // responseMarshal is a structure that repeats HTTPResponse. It allows
// easily Marshal the HTTPResponse object. // easily Marshal the HTTPResponse object.
type responseMarshal struct { type responseMarshal struct {
StatusCode int `json:"status_code"` StatusCode int `json:"status_code"`
Status string `json:"status"` Status string `json:"status"`
Headers HTTPHeaders `json:"headers"` Headers *HTTPHeaders `json:"headers"`
} }
func (resp *HTTPResponse) MarshalJSON() ([]byte, error) { func (resp *HTTPResponse) MarshalJSON() ([]byte, error) {

View File

@ -196,11 +196,11 @@ type (
GetFrame(ctx context.Context, idx values.Int) (core.Value, error) GetFrame(ctx context.Context, idx values.Int) (core.Value, error)
GetCookies(ctx context.Context) (HTTPCookies, error) GetCookies(ctx context.Context) (*HTTPCookies, error)
SetCookies(ctx context.Context, cookies HTTPCookies) error SetCookies(ctx context.Context, cookies *HTTPCookies) error
DeleteCookies(ctx context.Context, cookies HTTPCookies) error DeleteCookies(ctx context.Context, cookies *HTTPCookies) error
GetResponse(ctx context.Context) (HTTPResponse, error) GetResponse(ctx context.Context) (HTTPResponse, error)

View File

@ -26,8 +26,8 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) {
} }
inputs := args[1:] inputs := args[1:]
var currentCookies drivers.HTTPCookies var currentCookies *drivers.HTTPCookies
cookies := make(drivers.HTTPCookies) cookies := drivers.NewHTTPCookies()
for _, c := range inputs { for _, c := range inputs {
switch cookie := c.(type) { switch cookie := c.(type) {
@ -42,14 +42,14 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) {
currentCookies = current currentCookies = current
} }
found, isFound := currentCookies[cookie.String()] found, isFound := currentCookies.Get(cookie)
if isFound { if isFound {
cookies[cookie.String()] = found cookies.Set(found)
} }
case drivers.HTTPCookie: case drivers.HTTPCookie:
cookies[cookie.Name] = cookie cookies.Set(cookie)
default: default:
return values.None, core.TypeError(c.Type(), types.String, drivers.HTTPCookieType) return values.None, core.TypeError(c.Type(), types.String, drivers.HTTPCookieType)
} }

View File

@ -40,7 +40,7 @@ func CookieGet(ctx context.Context, args ...core.Value) (core.Value, error) {
return values.None, err return values.None, err
} }
cookie, found := cookies[name.String()] cookie, found := cookies.Get(name)
if found { if found {
return cookie, nil return cookie, nil

View File

@ -24,7 +24,7 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) {
return values.None, err return values.None, err
} }
cookies := make(drivers.HTTPCookies) cookies := drivers.NewHTTPCookies()
for _, c := range args[1:] { for _, c := range args[1:] {
cookie, err := parseCookie(c) cookie, err := parseCookie(c)
@ -33,7 +33,7 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) {
return values.None, err return values.None, err
} }
cookies[cookie.Name] = cookie cookies.Set(cookie)
} }
return values.None, page.SetCookies(ctx, cookies) return values.None, page.SetCookies(ctx, cookies)

View File

@ -168,7 +168,7 @@ func newPageLoadParams(url values.String, arg core.Value) (PageLoadParams, error
res.Cookies = cookies res.Cookies = cookies
default: default:
res.Cookies = make(drivers.HTTPCookies) res.Cookies = drivers.NewHTTPCookies()
} }
} }
@ -220,9 +220,13 @@ func newPageLoadParams(url values.String, arg core.Value) (PageLoadParams, error
return res, nil return res, nil
} }
func parseCookieObject(obj *values.Object) (drivers.HTTPCookies, error) { func parseCookieObject(obj *values.Object) (*drivers.HTTPCookies, error) {
if obj == nil {
return nil, errors.Wrap(core.ErrMissedArgument, "cookies")
}
var err error var err error
res := make(drivers.HTTPCookies) res := drivers.NewHTTPCookies()
obj.ForEach(func(value core.Value, _ string) bool { obj.ForEach(func(value core.Value, _ string) bool {
cookie, e := parseCookie(value) cookie, e := parseCookie(value)
@ -233,7 +237,7 @@ func parseCookieObject(obj *values.Object) (drivers.HTTPCookies, error) {
return false return false
} }
res[cookie.Name] = cookie res.Set(cookie)
return true return true
}) })
@ -241,9 +245,13 @@ func parseCookieObject(obj *values.Object) (drivers.HTTPCookies, error) {
return res, err return res, err
} }
func parseCookieArray(arr *values.Array) (drivers.HTTPCookies, error) { func parseCookieArray(arr *values.Array) (*drivers.HTTPCookies, error) {
if arr == nil {
return nil, errors.Wrap(core.ErrMissedArgument, "cookies")
}
var err error var err error
res := make(drivers.HTTPCookies) res := drivers.NewHTTPCookies()
arr.ForEach(func(value core.Value, _ int) bool { arr.ForEach(func(value core.Value, _ int) bool {
cookie, e := parseCookie(value) cookie, e := parseCookie(value)
@ -254,7 +262,7 @@ func parseCookieArray(arr *values.Array) (drivers.HTTPCookies, error) {
return false return false
} }
res[cookie.Name] = cookie res.Set(cookie)
return true return true
}) })
@ -350,11 +358,25 @@ func parseCookie(value core.Value) (drivers.HTTPCookie, error) {
return cookie, err return cookie, err
} }
func parseHeader(headers *values.Object) drivers.HTTPHeaders { func parseHeader(headers *values.Object) *drivers.HTTPHeaders {
res := make(drivers.HTTPHeaders) res := drivers.NewHTTPHeaders()
headers.ForEach(func(value core.Value, key string) bool { headers.ForEach(func(value core.Value, key string) bool {
res.Set(key, value.String()) if value.Type() == types.Array {
value := value.(*values.Array)
keyValues := make([]string, 0, value.Length())
value.ForEach(func(v core.Value, idx int) bool {
keyValues = append(keyValues, v.String())
return true
})
res.SetArr(key, keyValues)
} else {
res.Set(key, value.String())
}
return true return true
}) })

View File

@ -132,7 +132,7 @@ func parseParseParams(content []byte, arg *values.Object) (ParseParams, error) {
res.Cookies = cookies res.Cookies = cookies
default: default:
res.Cookies = make(drivers.HTTPCookies) res.Cookies = drivers.NewHTTPCookies()
} }
} }