From 64b7d3415a9c07f8802a8c5e4de40ab051de6eaf Mon Sep 17 00:00:00 2001 From: Pete Davison Date: Mon, 25 Mar 2024 19:05:21 +0000 Subject: [PATCH] feat: use timeout in RemoteExists function --- setup.go | 2 +- taskfile/node.go | 7 +++++-- taskfile/node_http.go | 17 +++++++++++++++-- taskfile/reader.go | 2 +- taskfile/taskfile.go | 5 +++-- 5 files changed, 25 insertions(+), 8 deletions(-) diff --git a/setup.go b/setup.go index bef2ad37..8bebdbc7 100644 --- a/setup.go +++ b/setup.go @@ -54,7 +54,7 @@ func (e *Executor) Setup() error { } func (e *Executor) getRootNode() (taskfile.Node, error) { - node, err := taskfile.NewRootNode(e.Logger, e.Entrypoint, e.Dir, e.Insecure) + node, err := taskfile.NewRootNode(e.Logger, e.Entrypoint, e.Dir, e.Insecure, e.Timeout) if err != nil { return nil, err } diff --git a/taskfile/node.go b/taskfile/node.go index b05b92cb..4aa22fa6 100644 --- a/taskfile/node.go +++ b/taskfile/node.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/go-task/task/v3/errors" "github.com/go-task/task/v3/internal/experiments" @@ -27,6 +28,7 @@ func NewRootNode( entrypoint string, dir string, insecure bool, + timeout time.Duration, ) (Node, error) { dir = getDefaultDir(entrypoint, dir) // Check if there is something to read on STDIN @@ -34,7 +36,7 @@ func NewRootNode( if (stat.Mode()&os.ModeCharDevice) == 0 && stat.Size() > 0 { return NewStdinNode(dir) } - return NewNode(l, entrypoint, dir, insecure) + return NewNode(l, entrypoint, dir, insecure, timeout) } func NewNode( @@ -42,13 +44,14 @@ func NewNode( entrypoint string, dir string, insecure bool, + timeout time.Duration, opts ...NodeOption, ) (Node, error) { var node Node var err error switch getScheme(entrypoint) { case "http", "https": - node, err = NewHTTPNode(l, entrypoint, dir, insecure, opts...) + node, err = NewHTTPNode(l, entrypoint, dir, insecure, timeout, opts...) default: // If no other scheme matches, we assume it's a file node, err = NewFileNode(l, entrypoint, dir, opts...) diff --git a/taskfile/node_http.go b/taskfile/node_http.go index d7dce27b..5931d10f 100644 --- a/taskfile/node_http.go +++ b/taskfile/node_http.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "path/filepath" + "time" "github.com/go-task/task/v3/errors" "github.com/go-task/task/v3/internal/execext" @@ -19,7 +20,14 @@ type HTTPNode struct { URL *url.URL } -func NewHTTPNode(l *logger.Logger, entrypoint, dir string, insecure bool, opts ...NodeOption) (*HTTPNode, error) { +func NewHTTPNode( + l *logger.Logger, + entrypoint string, + dir string, + insecure bool, + timeout time.Duration, + opts ...NodeOption, +) (*HTTPNode, error) { base := NewBaseNode(dir, opts...) url, err := url.Parse(entrypoint) if err != nil { @@ -28,10 +36,15 @@ func NewHTTPNode(l *logger.Logger, entrypoint, dir string, insecure bool, opts . if url.Scheme == "http" && !insecure { return nil, &errors.TaskfileNotSecureError{URI: entrypoint} } - url, err = RemoteExists(l, url) + ctx, cf := context.WithTimeout(context.Background(), timeout) + defer cf() + url, err = RemoteExists(ctx, l, url) if err != nil { return nil, err } + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return nil, &errors.TaskfileNetworkTimeoutError{URI: url.String(), Timeout: timeout} + } return &HTTPNode{ BaseNode: base, URL: url, diff --git a/taskfile/reader.go b/taskfile/reader.go index 42e64914..af04637c 100644 --- a/taskfile/reader.go +++ b/taskfile/reader.go @@ -74,7 +74,7 @@ func Read( return err } - includeReaderNode, err := NewNode(l, entrypoint, dir, insecure, + includeReaderNode, err := NewNode(l, entrypoint, dir, insecure, timeout, WithParent(node), WithOptional(include.Optional), ) diff --git a/taskfile/taskfile.go b/taskfile/taskfile.go index 3c006dae..499a79a5 100644 --- a/taskfile/taskfile.go +++ b/taskfile/taskfile.go @@ -1,6 +1,7 @@ package taskfile import ( + "context" "net/http" "net/url" "os" @@ -43,7 +44,7 @@ var ( // at the given URL with any of the default Taskfile files names. If any of // these match a file, the first matching path will be returned. If no files are // found, an error will be returned. -func RemoteExists(l *logger.Logger, u *url.URL) (*url.URL, error) { +func RemoteExists(ctx context.Context, l *logger.Logger, u *url.URL) (*url.URL, error) { // Create a new HEAD request for the given URL to check if the resource exists req, err := http.NewRequest("HEAD", u.String(), nil) if err != nil { @@ -51,7 +52,7 @@ func RemoteExists(l *logger.Logger, u *url.URL) (*url.URL, error) { } // Request the given URL - resp, err := http.DefaultClient.Do(req) + resp, err := http.DefaultClient.Do(req.WithContext(ctx)) if err != nil { return nil, errors.TaskfileFetchFailedError{URI: u.String()} }