1
0
mirror of https://github.com/go-task/task.git synced 2025-11-25 22:32:55 +02:00

feat: remote taskfile improvements (cache/expiry) (#2176)

* feat: cache as node, RemoteNode and cache-first approach

* feat: cache expiry

* feat: pass ctx into reader methods instead of timeout

* docs: updated remote taskfiles experiment doc

* feat: use cache if download fails
This commit is contained in:
Pete Davison
2025-04-19 12:12:08 +01:00
committed by GitHub
parent f47f237093
commit a84f09d45f
18 changed files with 579 additions and 353 deletions

View File

@@ -1,72 +0,0 @@
package taskfile
import (
"crypto/sha256"
"fmt"
"os"
"path/filepath"
"strings"
)
type Cache struct {
dir string
}
func NewCache(dir string) (*Cache, error) {
dir = filepath.Join(dir, "remote")
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
return &Cache{
dir: dir,
}, nil
}
func checksum(b []byte) string {
h := sha256.New()
h.Write(b)
return fmt.Sprintf("%x", h.Sum(nil))
}
func (c *Cache) write(node Node, b []byte) error {
return os.WriteFile(c.cacheFilePath(node), b, 0o644)
}
func (c *Cache) read(node Node) ([]byte, error) {
return os.ReadFile(c.cacheFilePath(node))
}
func (c *Cache) writeChecksum(node Node, checksum string) error {
return os.WriteFile(c.checksumFilePath(node), []byte(checksum), 0o644)
}
func (c *Cache) readChecksum(node Node) string {
b, _ := os.ReadFile(c.checksumFilePath(node))
return string(b)
}
func (c *Cache) key(node Node) string {
return strings.TrimRight(checksum([]byte(node.Location())), "=")
}
func (c *Cache) cacheFilePath(node Node) string {
return c.filePath(node, "yaml")
}
func (c *Cache) checksumFilePath(node Node) string {
return c.filePath(node, "checksum")
}
func (c *Cache) filePath(node Node, suffix string) string {
lastDir, filename := node.FilenameAndLastDir()
prefix := filename
// Means it's not "", nor "." nor "/", so it's a valid directory
if len(lastDir) > 1 {
prefix = fmt.Sprintf("%s-%s", lastDir, filename)
}
return filepath.Join(c.dir, fmt.Sprintf("%s.%s.%s", prefix, c.key(node), suffix))
}
func (c *Cache) Clear() error {
return os.RemoveAll(c.dir)
}

View File

@@ -14,14 +14,18 @@ import (
)
type Node interface {
Read(ctx context.Context) ([]byte, error)
Read() ([]byte, error)
Parent() Node
Location() string
Dir() string
Remote() bool
ResolveEntrypoint(entrypoint string) (string, error)
ResolveDir(dir string) (string, error)
FilenameAndLastDir() (string, string)
}
type RemoteNode interface {
Node
ReadContext(ctx context.Context) ([]byte, error)
CacheKey() string
}
func NewRootNode(
@@ -35,35 +39,35 @@ func NewRootNode(
if entrypoint == "-" {
return NewStdinNode(dir)
}
return NewNode(entrypoint, dir, insecure, timeout)
return NewNode(entrypoint, dir, insecure)
}
func NewNode(
entrypoint string,
dir string,
insecure bool,
timeout time.Duration,
opts ...NodeOption,
) (Node, error) {
var node Node
var err error
scheme, err := getScheme(entrypoint)
if err != nil {
return nil, err
}
switch scheme {
case "git":
node, err = NewGitNode(entrypoint, dir, insecure, opts...)
case "http", "https":
node, err = NewHTTPNode(entrypoint, dir, insecure, timeout, opts...)
node, err = NewHTTPNode(entrypoint, dir, insecure, opts...)
default:
node, err = NewFileNode(entrypoint, dir, opts...)
}
if node.Remote() && !experiments.RemoteTaskfiles.Enabled() {
if _, isRemote := node.(RemoteNode); isRemote && !experiments.RemoteTaskfiles.Enabled() {
return nil, errors.New("task: Remote taskfiles are not enabled. You can read more about this experiment and how to enable it at https://taskfile.dev/experiments/remote-taskfiles")
}
return node, err
}
@@ -72,6 +76,7 @@ func getScheme(uri string) (string, error) {
if u == nil {
return "", err
}
if strings.HasSuffix(strings.Split(u.Path, "//")[0], ".git") && (u.Scheme == "git" || u.Scheme == "ssh" || u.Scheme == "https" || u.Scheme == "http") {
return "git", nil
}
@@ -79,6 +84,7 @@ func getScheme(uri string) (string, error) {
if i := strings.Index(uri, "://"); i != -1 {
return uri[:i], nil
}
return "", nil
}

113
taskfile/node_cache.go Normal file
View File

@@ -0,0 +1,113 @@
package taskfile
import (
"crypto/sha256"
"fmt"
"os"
"path/filepath"
"time"
)
const remoteCacheDir = "remote"
type CacheNode struct {
*BaseNode
source RemoteNode
}
func NewCacheNode(source RemoteNode, dir string) *CacheNode {
return &CacheNode{
BaseNode: &BaseNode{
dir: filepath.Join(dir, remoteCacheDir),
},
source: source,
}
}
func (node *CacheNode) Read() ([]byte, error) {
return os.ReadFile(node.Location())
}
func (node *CacheNode) Write(data []byte) error {
if err := node.CreateCacheDir(); err != nil {
return err
}
return os.WriteFile(node.Location(), data, 0o644)
}
func (node *CacheNode) ReadTimestamp() time.Time {
b, err := os.ReadFile(node.timestampPath())
if err != nil {
return time.Time{}.UTC()
}
timestamp, err := time.Parse(time.RFC3339, string(b))
if err != nil {
return time.Time{}.UTC()
}
return timestamp.UTC()
}
func (node *CacheNode) WriteTimestamp(t time.Time) error {
if err := node.CreateCacheDir(); err != nil {
return err
}
return os.WriteFile(node.timestampPath(), []byte(t.Format(time.RFC3339)), 0o644)
}
func (node *CacheNode) ReadChecksum() string {
b, _ := os.ReadFile(node.checksumPath())
return string(b)
}
func (node *CacheNode) WriteChecksum(checksum string) error {
if err := node.CreateCacheDir(); err != nil {
return err
}
return os.WriteFile(node.checksumPath(), []byte(checksum), 0o644)
}
func (node *CacheNode) CreateCacheDir() error {
if err := os.MkdirAll(node.dir, 0o755); err != nil {
return err
}
return nil
}
func (node *CacheNode) ChecksumPrompt(checksum string) string {
cachedChecksum := node.ReadChecksum()
switch {
// If the checksum doesn't exist, prompt the user to continue
case cachedChecksum == "":
return taskfileUntrustedPrompt
// If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue
case cachedChecksum != checksum:
return taskfileChangedPrompt
default:
return ""
}
}
func (node *CacheNode) Location() string {
return node.filePath("yaml")
}
func (node *CacheNode) checksumPath() string {
return node.filePath("checksum")
}
func (node *CacheNode) timestampPath() string {
return node.filePath("timestamp")
}
func (node *CacheNode) filePath(suffix string) string {
return filepath.Join(node.dir, fmt.Sprintf("%s.%s", node.source.CacheKey(), suffix))
}
func checksum(b []byte) string {
h := sha256.New()
h.Write(b)
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -1,7 +1,6 @@
package taskfile
import (
"context"
"io"
"os"
"path/filepath"
@@ -34,11 +33,7 @@ func (node *FileNode) Location() string {
return node.Entrypoint
}
func (node *FileNode) Remote() bool {
return false
}
func (node *FileNode) Read(ctx context.Context) ([]byte, error) {
func (node *FileNode) Read() ([]byte, error) {
f, err := os.Open(node.Location())
if err != nil {
return nil, err
@@ -114,7 +109,3 @@ func (node *FileNode) ResolveDir(dir string) (string, error) {
entrypointDir := filepath.Dir(node.Entrypoint)
return filepathext.SmartJoin(entrypointDir, path), nil
}
func (node *FileNode) FilenameAndLastDir() (string, string) {
return "", filepath.Base(node.Entrypoint)
}

View File

@@ -71,7 +71,11 @@ func (node *GitNode) Remote() bool {
return true
}
func (node *GitNode) Read(_ context.Context) ([]byte, error) {
func (node *GitNode) Read() ([]byte, error) {
return node.ReadContext(context.Background())
}
func (node *GitNode) ReadContext(_ context.Context) ([]byte, error) {
fs := memfs.New()
storer := memory.NewStorage()
_, err := git.Clone(storer, fs, &git.CloneOptions{
@@ -121,6 +125,13 @@ func (node *GitNode) ResolveDir(dir string) (string, error) {
return filepathext.SmartJoin(entrypointDir, path), nil
}
func (node *GitNode) FilenameAndLastDir() (string, string) {
return filepath.Base(node.path), filepath.Base(filepath.Dir(node.path))
func (node *GitNode) CacheKey() string {
checksum := strings.TrimRight(checksum([]byte(node.Location())), "=")
prefix := filepath.Base(filepath.Dir(node.path))
lastDir := filepath.Base(node.path)
// Means it's not "", nor "." nor "/", so it's a valid directory
if len(lastDir) > 1 {
prefix = fmt.Sprintf("%s-%s", lastDir, prefix)
}
return fmt.Sprintf("%s.%s", prefix, checksum)
}

View File

@@ -62,24 +62,21 @@ func TestGitNode_httpsWithDir(t *testing.T) {
assert.Equal(t, "https://github.com/foo/bar.git//directory/common.yml?ref=main", entrypoint)
}
func TestGitNode_FilenameAndDir(t *testing.T) {
func TestGitNode_CacheKey(t *testing.T) {
t.Parallel()
node, err := NewGitNode("https://github.com/foo/bar.git//directory/Taskfile.yml?ref=main", "", false)
assert.NoError(t, err)
filename, dir := node.FilenameAndLastDir()
assert.Equal(t, "Taskfile.yml", filename)
assert.Equal(t, "directory", dir)
key := node.CacheKey()
assert.Equal(t, "Taskfile.yml-directory.f1ddddac425a538870230a3e38fc0cded4ec5da250797b6cab62c82477718fbb", key)
node, err = NewGitNode("https://github.com/foo/bar.git//Taskfile.yml?ref=main", "", false)
assert.NoError(t, err)
filename, dir = node.FilenameAndLastDir()
assert.Equal(t, "Taskfile.yml", filename)
assert.Equal(t, ".", dir)
key = node.CacheKey()
assert.Equal(t, "Taskfile.yml-..39d28c1ff36f973705ae188b991258bbabaffd6d60bcdde9693d157d00d5e3a4", key)
node, err = NewGitNode("https://github.com/foo/bar.git//multiple/directory/Taskfile.yml?ref=main", "", false)
assert.NoError(t, err)
filename, dir = node.FilenameAndLastDir()
assert.Equal(t, "Taskfile.yml", filename)
assert.Equal(t, "directory", dir)
key = node.CacheKey()
assert.Equal(t, "Taskfile.yml-directory.1b6d145e01406dcc6c0aa572e5a5d1333be1ccf2cae96d18296d725d86197d31", key)
}

View File

@@ -2,11 +2,12 @@ package taskfile
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"path/filepath"
"time"
"strings"
"github.com/go-task/task/v3/errors"
"github.com/go-task/task/v3/internal/execext"
@@ -18,14 +19,12 @@ type HTTPNode struct {
*BaseNode
URL *url.URL // stores url pointing actual remote file. (e.g. with Taskfile.yml)
entrypoint string // stores entrypoint url. used for building graph vertices.
timeout time.Duration
}
func NewHTTPNode(
entrypoint string,
dir string,
insecure bool,
timeout time.Duration,
opts ...NodeOption,
) (*HTTPNode, error) {
base := NewBaseNode(dir, opts...)
@@ -41,7 +40,6 @@ func NewHTTPNode(
BaseNode: base,
URL: url,
entrypoint: entrypoint,
timeout: timeout,
}, nil
}
@@ -49,12 +47,12 @@ func (node *HTTPNode) Location() string {
return node.entrypoint
}
func (node *HTTPNode) Remote() bool {
return true
func (node *HTTPNode) Read() ([]byte, error) {
return node.ReadContext(context.Background())
}
func (node *HTTPNode) Read(ctx context.Context) ([]byte, error) {
url, err := RemoteExists(ctx, node.URL, node.timeout)
func (node *HTTPNode) ReadContext(ctx context.Context) ([]byte, error) {
url, err := RemoteExists(ctx, node.URL)
if err != nil {
return nil, err
}
@@ -66,8 +64,8 @@ func (node *HTTPNode) Read(ctx context.Context) ([]byte, error) {
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.URL.String(), Timeout: node.timeout}
if ctx.Err() != nil {
return nil, err
}
return nil, errors.TaskfileFetchFailedError{URI: node.URL.String()}
}
@@ -116,7 +114,14 @@ func (node *HTTPNode) ResolveDir(dir string) (string, error) {
return filepathext.SmartJoin(parent, path), nil
}
func (node *HTTPNode) FilenameAndLastDir() (string, string) {
func (node *HTTPNode) CacheKey() string {
checksum := strings.TrimRight(checksum([]byte(node.Location())), "=")
dir, filename := filepath.Split(node.entrypoint)
return filepath.Base(dir), filename
lastDir := filepath.Base(dir)
prefix := filename
// Means it's not "", nor "." nor "/", so it's a valid directory
if len(lastDir) > 1 {
prefix = fmt.Sprintf("%s-%s", lastDir, filename)
}
return fmt.Sprintf("%s.%s", prefix, checksum)
}

View File

@@ -2,7 +2,6 @@ package taskfile
import (
"bufio"
"context"
"fmt"
"os"
"strings"
@@ -30,7 +29,7 @@ func (node *StdinNode) Remote() bool {
return false
}
func (node *StdinNode) Read(ctx context.Context) ([]byte, error) {
func (node *StdinNode) Read() ([]byte, error) {
var stdin []byte
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
@@ -72,7 +71,3 @@ func (node *StdinNode) ResolveDir(dir string) (string, error) {
return filepathext.SmartJoin(node.Dir(), path), nil
}
func (node *StdinNode) FilenameAndLastDir() (string, string) {
return "", "__stdin__"
}

View File

@@ -39,15 +39,15 @@ type (
// A Reader will recursively read Taskfiles from a given [Node] and build a
// [ast.TaskfileGraph] from them.
Reader struct {
graph *ast.TaskfileGraph
insecure bool
download bool
offline bool
timeout time.Duration
tempDir string
debugFunc DebugFunc
promptFunc PromptFunc
promptMutex sync.Mutex
graph *ast.TaskfileGraph
insecure bool
download bool
offline bool
tempDir string
cacheExpiryDuration time.Duration
debugFunc DebugFunc
promptFunc PromptFunc
promptMutex sync.Mutex
}
)
@@ -55,15 +55,15 @@ type (
// options.
func NewReader(opts ...ReaderOption) *Reader {
r := &Reader{
graph: ast.NewTaskfileGraph(),
insecure: false,
download: false,
offline: false,
timeout: time.Second * 10,
tempDir: os.TempDir(),
debugFunc: nil,
promptFunc: nil,
promptMutex: sync.Mutex{},
graph: ast.NewTaskfileGraph(),
insecure: false,
download: false,
offline: false,
tempDir: os.TempDir(),
cacheExpiryDuration: 0,
debugFunc: nil,
promptFunc: nil,
promptMutex: sync.Mutex{},
}
r.Options(opts...)
return r
@@ -119,20 +119,6 @@ func (o *offlineOption) ApplyToReader(r *Reader) {
r.offline = o.offline
}
// WithTimeout sets the [Reader]'s timeout for fetching remote taskfiles. By
// default, the timeout is set to 10 seconds.
func WithTimeout(timeout time.Duration) ReaderOption {
return &timeoutOption{timeout: timeout}
}
type timeoutOption struct {
timeout time.Duration
}
func (o *timeoutOption) ApplyToReader(r *Reader) {
r.timeout = o.timeout
}
// WithTempDir sets the temporary directory that will be used by the [Reader].
// By default, the reader uses [os.TempDir].
func WithTempDir(tempDir string) ReaderOption {
@@ -147,6 +133,20 @@ func (o *tempDirOption) ApplyToReader(r *Reader) {
r.tempDir = o.tempDir
}
// WithCacheExpiryDuration sets the duration after which the cache is considered
// expired. By default, the cache is considered expired after 24 hours.
func WithCacheExpiryDuration(duration time.Duration) ReaderOption {
return &cacheExpiryDurationOption{duration: duration}
}
type cacheExpiryDurationOption struct {
duration time.Duration
}
func (o *cacheExpiryDurationOption) ApplyToReader(r *Reader) {
r.cacheExpiryDuration = o.duration
}
// WithDebugFunc sets the debug function to be used by the [Reader]. If set,
// this function will be called with debug messages. This can be useful if the
// caller wants to log debug messages from the [Reader]. By default, no debug
@@ -186,8 +186,8 @@ func (o *promptFuncOption) ApplyToReader(r *Reader) {
// through any [ast.Includes] it finds, reading each included Taskfile and
// building an [ast.TaskfileGraph] as it goes. If any errors occur, they will be
// returned immediately.
func (r *Reader) Read(node Node) (*ast.TaskfileGraph, error) {
if err := r.include(node); err != nil {
func (r *Reader) Read(ctx context.Context, node Node) (*ast.TaskfileGraph, error) {
if err := r.include(ctx, node); err != nil {
return nil, err
}
return r.graph, nil
@@ -206,7 +206,7 @@ func (r *Reader) promptf(format string, a ...any) error {
return nil
}
func (r *Reader) include(node Node) error {
func (r *Reader) include(ctx context.Context, node Node) error {
// Create a new vertex for the Taskfile
vertex := &ast.TaskfileVertex{
URI: node.Location(),
@@ -224,7 +224,7 @@ func (r *Reader) include(node Node) error {
// Read and parse the Taskfile from the file and add it to the vertex
var err error
vertex.Taskfile, err = r.readNode(node)
vertex.Taskfile, err = r.readNode(ctx, node)
if err != nil {
return err
}
@@ -265,7 +265,7 @@ func (r *Reader) include(node Node) error {
return err
}
includeNode, err := NewNode(entrypoint, include.Dir, r.insecure, r.timeout,
includeNode, err := NewNode(entrypoint, include.Dir, r.insecure,
WithParent(node),
)
if err != nil {
@@ -276,7 +276,7 @@ func (r *Reader) include(node Node) error {
}
// Recurse into the included Taskfile
if err := r.include(includeNode); err != nil {
if err := r.include(ctx, includeNode); err != nil {
return err
}
@@ -316,8 +316,8 @@ func (r *Reader) include(node Node) error {
return g.Wait()
}
func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
b, err := r.loadNodeContent(node)
func (r *Reader) readNode(ctx context.Context, node Node) (*ast.Taskfile, error) {
b, err := r.readNodeContent(ctx, node)
if err != nil {
return nil, err
}
@@ -358,72 +358,79 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
return &tf, nil
}
func (r *Reader) loadNodeContent(node Node) ([]byte, error) {
if !node.Remote() {
ctx, cf := context.WithTimeout(context.Background(), r.timeout)
defer cf()
return node.Read(ctx)
func (r *Reader) readNodeContent(ctx context.Context, node Node) ([]byte, error) {
if node, isRemote := node.(RemoteNode); isRemote {
return r.readRemoteNodeContent(ctx, node)
}
return node.Read()
}
func (r *Reader) readRemoteNodeContent(ctx context.Context, node RemoteNode) ([]byte, error) {
cache := NewCacheNode(node, r.tempDir)
now := time.Now().UTC()
timestamp := cache.ReadTimestamp()
expiry := timestamp.Add(r.cacheExpiryDuration)
cacheValid := now.Before(expiry)
var cacheFound bool
r.debugf("checking cache for %q in %q\n", node.Location(), cache.Location())
cachedBytes, err := cache.Read()
switch {
// If the cache doesn't exist, we need to download the file
case errors.Is(err, os.ErrNotExist):
r.debugf("no cache found\n")
// If we couldn't find a cached copy, and we are offline, we can't do anything
if r.offline {
return nil, &errors.TaskfileCacheNotFoundError{
URI: node.Location(),
}
}
// If the cache is expired
case !cacheValid:
r.debugf("cache expired at %s\n", expiry.Format(time.RFC3339))
cacheFound = true
// If we can't fetch a fresh copy, we should use the cache anyway
if r.offline {
r.debugf("in offline mode, using expired cache\n")
return cachedBytes, nil
}
// Some other error
case err != nil:
return nil, err
// Found valid cache
default:
r.debugf("cache found\n")
// Not being forced to redownload, return cache
if !r.download {
return cachedBytes, nil
}
cacheFound = true
}
cache, err := NewCache(r.tempDir)
// Try to read the remote file
r.debugf("downloading remote file: %s\n", node.Location())
downloadedBytes, err := node.ReadContext(ctx)
if err != nil {
// If the context timed out or was cancelled, but we found a cached version, use that
if ctx.Err() != nil && cacheFound {
if cacheValid {
r.debugf("failed to fetch remote file: %s: using cache\n", ctx.Err().Error())
} else {
r.debugf("failed to fetch remote file: %s: using expired cache\n", ctx.Err().Error())
}
return cachedBytes, nil
}
return nil, err
}
if r.offline {
// In offline mode try to use cached copy
cached, err := cache.read(node)
if errors.Is(err, os.ErrNotExist) {
return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()}
} else if err != nil {
return nil, err
}
r.debugf("task: [%s] Fetched cached copy\n", node.Location())
return cached, nil
}
ctx, cf := context.WithTimeout(context.Background(), r.timeout)
defer cf()
b, err := node.Read(ctx)
if errors.Is(err, &errors.TaskfileNetworkTimeoutError{}) {
// If we timed out then we likely have a network issue
// If a download was requested, then we can't use a cached copy
if r.download {
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout}
}
// Search for any cached copies
cached, err := cache.read(node)
if errors.Is(err, os.ErrNotExist) {
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true}
} else if err != nil {
return nil, err
}
r.debugf("task: [%s] Network timeout. Fetched cached copy\n", node.Location())
return cached, nil
} else if err != nil {
return nil, err
}
r.debugf("task: [%s] Fetched remote copy\n", node.Location())
// Get the checksums
checksum := checksum(b)
cachedChecksum := cache.readChecksum(node)
var prompt string
if cachedChecksum == "" {
// If the checksum doesn't exist, prompt the user to continue
prompt = taskfileUntrustedPrompt
} else if checksum != cachedChecksum {
// If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue
prompt = taskfileChangedPrompt
}
r.debugf("found remote file at %q\n", node.Location())
checksum := checksum(downloadedBytes)
prompt := cache.ChecksumPrompt(checksum)
// Prompt the user if required
if prompt != "" {
if err := func() error {
r.promptMutex.Lock()
@@ -432,18 +439,23 @@ func (r *Reader) loadNodeContent(node Node) ([]byte, error) {
}(); err != nil {
return nil, &errors.TaskfileNotTrustedError{URI: node.Location()}
}
// Store the checksum
if err := cache.writeChecksum(node, checksum); err != nil {
return nil, err
}
// Cache the file
r.debugf("task: [%s] Caching downloaded file\n", node.Location())
if err = cache.write(node, b); err != nil {
return nil, err
}
}
return b, nil
// Store the checksum
if err := cache.WriteChecksum(checksum); err != nil {
return nil, err
}
// Store the timestamp
if err := cache.WriteTimestamp(now); err != nil {
return nil, err
}
// Cache the file
r.debugf("caching %q to %q\n", node.Location(), cache.Location())
if err = cache.Write(downloadedBytes); err != nil {
return nil, err
}
return downloadedBytes, nil
}

View File

@@ -2,13 +2,13 @@ package taskfile
import (
"context"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"time"
"github.com/go-task/task/v3/errors"
"github.com/go-task/task/v3/internal/filepathext"
@@ -40,7 +40,7 @@ var (
// at the given URL with any of the default Taskfile files names. If any of
// these match a file, the first matching path will be returned. If no files are
// found, an error will be returned.
func RemoteExists(ctx context.Context, u *url.URL, timeout time.Duration) (*url.URL, error) {
func RemoteExists(ctx context.Context, u *url.URL) (*url.URL, error) {
// Create a new HEAD request for the given URL to check if the resource exists
req, err := http.NewRequestWithContext(ctx, "HEAD", u.String(), nil)
if err != nil {
@@ -50,8 +50,8 @@ func RemoteExists(ctx context.Context, u *url.URL, timeout time.Duration) (*url.
// Request the given URL
resp, err := http.DefaultClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return nil, &errors.TaskfileNetworkTimeoutError{URI: u.String(), Timeout: timeout}
if ctx.Err() != nil {
return nil, fmt.Errorf("checking remote file: %w", ctx.Err())
}
return nil, errors.TaskfileFetchFailedError{URI: u.String()}
}