mirror of
https://github.com/go-task/task.git
synced 2025-11-25 22:32:55 +02:00
feat: remote taskfile improvements (cache/expiry) (#2176)
* feat: cache as node, RemoteNode and cache-first approach * feat: cache expiry * feat: pass ctx into reader methods instead of timeout * docs: updated remote taskfiles experiment doc * feat: use cache if download fails
This commit is contained in:
@@ -1,72 +0,0 @@
|
||||
package taskfile
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
func NewCache(dir string) (*Cache, error) {
|
||||
dir = filepath.Join(dir, "remote")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Cache{
|
||||
dir: dir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func checksum(b []byte) string {
|
||||
h := sha256.New()
|
||||
h.Write(b)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
func (c *Cache) write(node Node, b []byte) error {
|
||||
return os.WriteFile(c.cacheFilePath(node), b, 0o644)
|
||||
}
|
||||
|
||||
func (c *Cache) read(node Node) ([]byte, error) {
|
||||
return os.ReadFile(c.cacheFilePath(node))
|
||||
}
|
||||
|
||||
func (c *Cache) writeChecksum(node Node, checksum string) error {
|
||||
return os.WriteFile(c.checksumFilePath(node), []byte(checksum), 0o644)
|
||||
}
|
||||
|
||||
func (c *Cache) readChecksum(node Node) string {
|
||||
b, _ := os.ReadFile(c.checksumFilePath(node))
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (c *Cache) key(node Node) string {
|
||||
return strings.TrimRight(checksum([]byte(node.Location())), "=")
|
||||
}
|
||||
|
||||
func (c *Cache) cacheFilePath(node Node) string {
|
||||
return c.filePath(node, "yaml")
|
||||
}
|
||||
|
||||
func (c *Cache) checksumFilePath(node Node) string {
|
||||
return c.filePath(node, "checksum")
|
||||
}
|
||||
|
||||
func (c *Cache) filePath(node Node, suffix string) string {
|
||||
lastDir, filename := node.FilenameAndLastDir()
|
||||
prefix := filename
|
||||
// Means it's not "", nor "." nor "/", so it's a valid directory
|
||||
if len(lastDir) > 1 {
|
||||
prefix = fmt.Sprintf("%s-%s", lastDir, filename)
|
||||
}
|
||||
return filepath.Join(c.dir, fmt.Sprintf("%s.%s.%s", prefix, c.key(node), suffix))
|
||||
}
|
||||
|
||||
func (c *Cache) Clear() error {
|
||||
return os.RemoveAll(c.dir)
|
||||
}
|
||||
@@ -14,14 +14,18 @@ import (
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
Read(ctx context.Context) ([]byte, error)
|
||||
Read() ([]byte, error)
|
||||
Parent() Node
|
||||
Location() string
|
||||
Dir() string
|
||||
Remote() bool
|
||||
ResolveEntrypoint(entrypoint string) (string, error)
|
||||
ResolveDir(dir string) (string, error)
|
||||
FilenameAndLastDir() (string, string)
|
||||
}
|
||||
|
||||
type RemoteNode interface {
|
||||
Node
|
||||
ReadContext(ctx context.Context) ([]byte, error)
|
||||
CacheKey() string
|
||||
}
|
||||
|
||||
func NewRootNode(
|
||||
@@ -35,35 +39,35 @@ func NewRootNode(
|
||||
if entrypoint == "-" {
|
||||
return NewStdinNode(dir)
|
||||
}
|
||||
return NewNode(entrypoint, dir, insecure, timeout)
|
||||
return NewNode(entrypoint, dir, insecure)
|
||||
}
|
||||
|
||||
func NewNode(
|
||||
entrypoint string,
|
||||
dir string,
|
||||
insecure bool,
|
||||
timeout time.Duration,
|
||||
opts ...NodeOption,
|
||||
) (Node, error) {
|
||||
var node Node
|
||||
var err error
|
||||
|
||||
scheme, err := getScheme(entrypoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch scheme {
|
||||
case "git":
|
||||
node, err = NewGitNode(entrypoint, dir, insecure, opts...)
|
||||
case "http", "https":
|
||||
node, err = NewHTTPNode(entrypoint, dir, insecure, timeout, opts...)
|
||||
node, err = NewHTTPNode(entrypoint, dir, insecure, opts...)
|
||||
default:
|
||||
node, err = NewFileNode(entrypoint, dir, opts...)
|
||||
|
||||
}
|
||||
|
||||
if node.Remote() && !experiments.RemoteTaskfiles.Enabled() {
|
||||
if _, isRemote := node.(RemoteNode); isRemote && !experiments.RemoteTaskfiles.Enabled() {
|
||||
return nil, errors.New("task: Remote taskfiles are not enabled. You can read more about this experiment and how to enable it at https://taskfile.dev/experiments/remote-taskfiles")
|
||||
}
|
||||
|
||||
return node, err
|
||||
}
|
||||
|
||||
@@ -72,6 +76,7 @@ func getScheme(uri string) (string, error) {
|
||||
if u == nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.Split(u.Path, "//")[0], ".git") && (u.Scheme == "git" || u.Scheme == "ssh" || u.Scheme == "https" || u.Scheme == "http") {
|
||||
return "git", nil
|
||||
}
|
||||
@@ -79,6 +84,7 @@ func getScheme(uri string) (string, error) {
|
||||
if i := strings.Index(uri, "://"); i != -1 {
|
||||
return uri[:i], nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
|
||||
113
taskfile/node_cache.go
Normal file
113
taskfile/node_cache.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package taskfile
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
const remoteCacheDir = "remote"
|
||||
|
||||
type CacheNode struct {
|
||||
*BaseNode
|
||||
source RemoteNode
|
||||
}
|
||||
|
||||
func NewCacheNode(source RemoteNode, dir string) *CacheNode {
|
||||
return &CacheNode{
|
||||
BaseNode: &BaseNode{
|
||||
dir: filepath.Join(dir, remoteCacheDir),
|
||||
},
|
||||
source: source,
|
||||
}
|
||||
}
|
||||
|
||||
func (node *CacheNode) Read() ([]byte, error) {
|
||||
return os.ReadFile(node.Location())
|
||||
}
|
||||
|
||||
func (node *CacheNode) Write(data []byte) error {
|
||||
if err := node.CreateCacheDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(node.Location(), data, 0o644)
|
||||
}
|
||||
|
||||
func (node *CacheNode) ReadTimestamp() time.Time {
|
||||
b, err := os.ReadFile(node.timestampPath())
|
||||
if err != nil {
|
||||
return time.Time{}.UTC()
|
||||
}
|
||||
timestamp, err := time.Parse(time.RFC3339, string(b))
|
||||
if err != nil {
|
||||
return time.Time{}.UTC()
|
||||
}
|
||||
return timestamp.UTC()
|
||||
}
|
||||
|
||||
func (node *CacheNode) WriteTimestamp(t time.Time) error {
|
||||
if err := node.CreateCacheDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(node.timestampPath(), []byte(t.Format(time.RFC3339)), 0o644)
|
||||
}
|
||||
|
||||
func (node *CacheNode) ReadChecksum() string {
|
||||
b, _ := os.ReadFile(node.checksumPath())
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (node *CacheNode) WriteChecksum(checksum string) error {
|
||||
if err := node.CreateCacheDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(node.checksumPath(), []byte(checksum), 0o644)
|
||||
}
|
||||
|
||||
func (node *CacheNode) CreateCacheDir() error {
|
||||
if err := os.MkdirAll(node.dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (node *CacheNode) ChecksumPrompt(checksum string) string {
|
||||
cachedChecksum := node.ReadChecksum()
|
||||
switch {
|
||||
|
||||
// If the checksum doesn't exist, prompt the user to continue
|
||||
case cachedChecksum == "":
|
||||
return taskfileUntrustedPrompt
|
||||
|
||||
// If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue
|
||||
case cachedChecksum != checksum:
|
||||
return taskfileChangedPrompt
|
||||
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (node *CacheNode) Location() string {
|
||||
return node.filePath("yaml")
|
||||
}
|
||||
|
||||
func (node *CacheNode) checksumPath() string {
|
||||
return node.filePath("checksum")
|
||||
}
|
||||
|
||||
func (node *CacheNode) timestampPath() string {
|
||||
return node.filePath("timestamp")
|
||||
}
|
||||
|
||||
func (node *CacheNode) filePath(suffix string) string {
|
||||
return filepath.Join(node.dir, fmt.Sprintf("%s.%s", node.source.CacheKey(), suffix))
|
||||
}
|
||||
|
||||
func checksum(b []byte) string {
|
||||
h := sha256.New()
|
||||
h.Write(b)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package taskfile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -34,11 +33,7 @@ func (node *FileNode) Location() string {
|
||||
return node.Entrypoint
|
||||
}
|
||||
|
||||
func (node *FileNode) Remote() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (node *FileNode) Read(ctx context.Context) ([]byte, error) {
|
||||
func (node *FileNode) Read() ([]byte, error) {
|
||||
f, err := os.Open(node.Location())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -114,7 +109,3 @@ func (node *FileNode) ResolveDir(dir string) (string, error) {
|
||||
entrypointDir := filepath.Dir(node.Entrypoint)
|
||||
return filepathext.SmartJoin(entrypointDir, path), nil
|
||||
}
|
||||
|
||||
func (node *FileNode) FilenameAndLastDir() (string, string) {
|
||||
return "", filepath.Base(node.Entrypoint)
|
||||
}
|
||||
|
||||
@@ -71,7 +71,11 @@ func (node *GitNode) Remote() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (node *GitNode) Read(_ context.Context) ([]byte, error) {
|
||||
func (node *GitNode) Read() ([]byte, error) {
|
||||
return node.ReadContext(context.Background())
|
||||
}
|
||||
|
||||
func (node *GitNode) ReadContext(_ context.Context) ([]byte, error) {
|
||||
fs := memfs.New()
|
||||
storer := memory.NewStorage()
|
||||
_, err := git.Clone(storer, fs, &git.CloneOptions{
|
||||
@@ -121,6 +125,13 @@ func (node *GitNode) ResolveDir(dir string) (string, error) {
|
||||
return filepathext.SmartJoin(entrypointDir, path), nil
|
||||
}
|
||||
|
||||
func (node *GitNode) FilenameAndLastDir() (string, string) {
|
||||
return filepath.Base(node.path), filepath.Base(filepath.Dir(node.path))
|
||||
func (node *GitNode) CacheKey() string {
|
||||
checksum := strings.TrimRight(checksum([]byte(node.Location())), "=")
|
||||
prefix := filepath.Base(filepath.Dir(node.path))
|
||||
lastDir := filepath.Base(node.path)
|
||||
// Means it's not "", nor "." nor "/", so it's a valid directory
|
||||
if len(lastDir) > 1 {
|
||||
prefix = fmt.Sprintf("%s-%s", lastDir, prefix)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", prefix, checksum)
|
||||
}
|
||||
|
||||
@@ -62,24 +62,21 @@ func TestGitNode_httpsWithDir(t *testing.T) {
|
||||
assert.Equal(t, "https://github.com/foo/bar.git//directory/common.yml?ref=main", entrypoint)
|
||||
}
|
||||
|
||||
func TestGitNode_FilenameAndDir(t *testing.T) {
|
||||
func TestGitNode_CacheKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
node, err := NewGitNode("https://github.com/foo/bar.git//directory/Taskfile.yml?ref=main", "", false)
|
||||
assert.NoError(t, err)
|
||||
filename, dir := node.FilenameAndLastDir()
|
||||
assert.Equal(t, "Taskfile.yml", filename)
|
||||
assert.Equal(t, "directory", dir)
|
||||
key := node.CacheKey()
|
||||
assert.Equal(t, "Taskfile.yml-directory.f1ddddac425a538870230a3e38fc0cded4ec5da250797b6cab62c82477718fbb", key)
|
||||
|
||||
node, err = NewGitNode("https://github.com/foo/bar.git//Taskfile.yml?ref=main", "", false)
|
||||
assert.NoError(t, err)
|
||||
filename, dir = node.FilenameAndLastDir()
|
||||
assert.Equal(t, "Taskfile.yml", filename)
|
||||
assert.Equal(t, ".", dir)
|
||||
key = node.CacheKey()
|
||||
assert.Equal(t, "Taskfile.yml-..39d28c1ff36f973705ae188b991258bbabaffd6d60bcdde9693d157d00d5e3a4", key)
|
||||
|
||||
node, err = NewGitNode("https://github.com/foo/bar.git//multiple/directory/Taskfile.yml?ref=main", "", false)
|
||||
assert.NoError(t, err)
|
||||
filename, dir = node.FilenameAndLastDir()
|
||||
assert.Equal(t, "Taskfile.yml", filename)
|
||||
assert.Equal(t, "directory", dir)
|
||||
key = node.CacheKey()
|
||||
assert.Equal(t, "Taskfile.yml-directory.1b6d145e01406dcc6c0aa572e5a5d1333be1ccf2cae96d18296d725d86197d31", key)
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package taskfile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"time"
|
||||
"strings"
|
||||
|
||||
"github.com/go-task/task/v3/errors"
|
||||
"github.com/go-task/task/v3/internal/execext"
|
||||
@@ -18,14 +19,12 @@ type HTTPNode struct {
|
||||
*BaseNode
|
||||
URL *url.URL // stores url pointing actual remote file. (e.g. with Taskfile.yml)
|
||||
entrypoint string // stores entrypoint url. used for building graph vertices.
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewHTTPNode(
|
||||
entrypoint string,
|
||||
dir string,
|
||||
insecure bool,
|
||||
timeout time.Duration,
|
||||
opts ...NodeOption,
|
||||
) (*HTTPNode, error) {
|
||||
base := NewBaseNode(dir, opts...)
|
||||
@@ -41,7 +40,6 @@ func NewHTTPNode(
|
||||
BaseNode: base,
|
||||
URL: url,
|
||||
entrypoint: entrypoint,
|
||||
timeout: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -49,12 +47,12 @@ func (node *HTTPNode) Location() string {
|
||||
return node.entrypoint
|
||||
}
|
||||
|
||||
func (node *HTTPNode) Remote() bool {
|
||||
return true
|
||||
func (node *HTTPNode) Read() ([]byte, error) {
|
||||
return node.ReadContext(context.Background())
|
||||
}
|
||||
|
||||
func (node *HTTPNode) Read(ctx context.Context) ([]byte, error) {
|
||||
url, err := RemoteExists(ctx, node.URL, node.timeout)
|
||||
func (node *HTTPNode) ReadContext(ctx context.Context) ([]byte, error) {
|
||||
url, err := RemoteExists(ctx, node.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,8 +64,8 @@ func (node *HTTPNode) Read(ctx context.Context) ([]byte, error) {
|
||||
|
||||
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.URL.String(), Timeout: node.timeout}
|
||||
if ctx.Err() != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errors.TaskfileFetchFailedError{URI: node.URL.String()}
|
||||
}
|
||||
@@ -116,7 +114,14 @@ func (node *HTTPNode) ResolveDir(dir string) (string, error) {
|
||||
return filepathext.SmartJoin(parent, path), nil
|
||||
}
|
||||
|
||||
func (node *HTTPNode) FilenameAndLastDir() (string, string) {
|
||||
func (node *HTTPNode) CacheKey() string {
|
||||
checksum := strings.TrimRight(checksum([]byte(node.Location())), "=")
|
||||
dir, filename := filepath.Split(node.entrypoint)
|
||||
return filepath.Base(dir), filename
|
||||
lastDir := filepath.Base(dir)
|
||||
prefix := filename
|
||||
// Means it's not "", nor "." nor "/", so it's a valid directory
|
||||
if len(lastDir) > 1 {
|
||||
prefix = fmt.Sprintf("%s-%s", lastDir, filename)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", prefix, checksum)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package taskfile
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -30,7 +29,7 @@ func (node *StdinNode) Remote() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (node *StdinNode) Read(ctx context.Context) ([]byte, error) {
|
||||
func (node *StdinNode) Read() ([]byte, error) {
|
||||
var stdin []byte
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for scanner.Scan() {
|
||||
@@ -72,7 +71,3 @@ func (node *StdinNode) ResolveDir(dir string) (string, error) {
|
||||
|
||||
return filepathext.SmartJoin(node.Dir(), path), nil
|
||||
}
|
||||
|
||||
func (node *StdinNode) FilenameAndLastDir() (string, string) {
|
||||
return "", "__stdin__"
|
||||
}
|
||||
|
||||
@@ -39,15 +39,15 @@ type (
|
||||
// A Reader will recursively read Taskfiles from a given [Node] and build a
|
||||
// [ast.TaskfileGraph] from them.
|
||||
Reader struct {
|
||||
graph *ast.TaskfileGraph
|
||||
insecure bool
|
||||
download bool
|
||||
offline bool
|
||||
timeout time.Duration
|
||||
tempDir string
|
||||
debugFunc DebugFunc
|
||||
promptFunc PromptFunc
|
||||
promptMutex sync.Mutex
|
||||
graph *ast.TaskfileGraph
|
||||
insecure bool
|
||||
download bool
|
||||
offline bool
|
||||
tempDir string
|
||||
cacheExpiryDuration time.Duration
|
||||
debugFunc DebugFunc
|
||||
promptFunc PromptFunc
|
||||
promptMutex sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
@@ -55,15 +55,15 @@ type (
|
||||
// options.
|
||||
func NewReader(opts ...ReaderOption) *Reader {
|
||||
r := &Reader{
|
||||
graph: ast.NewTaskfileGraph(),
|
||||
insecure: false,
|
||||
download: false,
|
||||
offline: false,
|
||||
timeout: time.Second * 10,
|
||||
tempDir: os.TempDir(),
|
||||
debugFunc: nil,
|
||||
promptFunc: nil,
|
||||
promptMutex: sync.Mutex{},
|
||||
graph: ast.NewTaskfileGraph(),
|
||||
insecure: false,
|
||||
download: false,
|
||||
offline: false,
|
||||
tempDir: os.TempDir(),
|
||||
cacheExpiryDuration: 0,
|
||||
debugFunc: nil,
|
||||
promptFunc: nil,
|
||||
promptMutex: sync.Mutex{},
|
||||
}
|
||||
r.Options(opts...)
|
||||
return r
|
||||
@@ -119,20 +119,6 @@ func (o *offlineOption) ApplyToReader(r *Reader) {
|
||||
r.offline = o.offline
|
||||
}
|
||||
|
||||
// WithTimeout sets the [Reader]'s timeout for fetching remote taskfiles. By
|
||||
// default, the timeout is set to 10 seconds.
|
||||
func WithTimeout(timeout time.Duration) ReaderOption {
|
||||
return &timeoutOption{timeout: timeout}
|
||||
}
|
||||
|
||||
type timeoutOption struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func (o *timeoutOption) ApplyToReader(r *Reader) {
|
||||
r.timeout = o.timeout
|
||||
}
|
||||
|
||||
// WithTempDir sets the temporary directory that will be used by the [Reader].
|
||||
// By default, the reader uses [os.TempDir].
|
||||
func WithTempDir(tempDir string) ReaderOption {
|
||||
@@ -147,6 +133,20 @@ func (o *tempDirOption) ApplyToReader(r *Reader) {
|
||||
r.tempDir = o.tempDir
|
||||
}
|
||||
|
||||
// WithCacheExpiryDuration sets the duration after which the cache is considered
|
||||
// expired. By default, the cache is considered expired after 24 hours.
|
||||
func WithCacheExpiryDuration(duration time.Duration) ReaderOption {
|
||||
return &cacheExpiryDurationOption{duration: duration}
|
||||
}
|
||||
|
||||
type cacheExpiryDurationOption struct {
|
||||
duration time.Duration
|
||||
}
|
||||
|
||||
func (o *cacheExpiryDurationOption) ApplyToReader(r *Reader) {
|
||||
r.cacheExpiryDuration = o.duration
|
||||
}
|
||||
|
||||
// WithDebugFunc sets the debug function to be used by the [Reader]. If set,
|
||||
// this function will be called with debug messages. This can be useful if the
|
||||
// caller wants to log debug messages from the [Reader]. By default, no debug
|
||||
@@ -186,8 +186,8 @@ func (o *promptFuncOption) ApplyToReader(r *Reader) {
|
||||
// through any [ast.Includes] it finds, reading each included Taskfile and
|
||||
// building an [ast.TaskfileGraph] as it goes. If any errors occur, they will be
|
||||
// returned immediately.
|
||||
func (r *Reader) Read(node Node) (*ast.TaskfileGraph, error) {
|
||||
if err := r.include(node); err != nil {
|
||||
func (r *Reader) Read(ctx context.Context, node Node) (*ast.TaskfileGraph, error) {
|
||||
if err := r.include(ctx, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.graph, nil
|
||||
@@ -206,7 +206,7 @@ func (r *Reader) promptf(format string, a ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) include(node Node) error {
|
||||
func (r *Reader) include(ctx context.Context, node Node) error {
|
||||
// Create a new vertex for the Taskfile
|
||||
vertex := &ast.TaskfileVertex{
|
||||
URI: node.Location(),
|
||||
@@ -224,7 +224,7 @@ func (r *Reader) include(node Node) error {
|
||||
|
||||
// Read and parse the Taskfile from the file and add it to the vertex
|
||||
var err error
|
||||
vertex.Taskfile, err = r.readNode(node)
|
||||
vertex.Taskfile, err = r.readNode(ctx, node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -265,7 +265,7 @@ func (r *Reader) include(node Node) error {
|
||||
return err
|
||||
}
|
||||
|
||||
includeNode, err := NewNode(entrypoint, include.Dir, r.insecure, r.timeout,
|
||||
includeNode, err := NewNode(entrypoint, include.Dir, r.insecure,
|
||||
WithParent(node),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -276,7 +276,7 @@ func (r *Reader) include(node Node) error {
|
||||
}
|
||||
|
||||
// Recurse into the included Taskfile
|
||||
if err := r.include(includeNode); err != nil {
|
||||
if err := r.include(ctx, includeNode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -316,8 +316,8 @@ func (r *Reader) include(node Node) error {
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
|
||||
b, err := r.loadNodeContent(node)
|
||||
func (r *Reader) readNode(ctx context.Context, node Node) (*ast.Taskfile, error) {
|
||||
b, err := r.readNodeContent(ctx, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -358,72 +358,79 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
|
||||
return &tf, nil
|
||||
}
|
||||
|
||||
func (r *Reader) loadNodeContent(node Node) ([]byte, error) {
|
||||
if !node.Remote() {
|
||||
ctx, cf := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cf()
|
||||
return node.Read(ctx)
|
||||
func (r *Reader) readNodeContent(ctx context.Context, node Node) ([]byte, error) {
|
||||
if node, isRemote := node.(RemoteNode); isRemote {
|
||||
return r.readRemoteNodeContent(ctx, node)
|
||||
}
|
||||
return node.Read()
|
||||
}
|
||||
|
||||
func (r *Reader) readRemoteNodeContent(ctx context.Context, node RemoteNode) ([]byte, error) {
|
||||
cache := NewCacheNode(node, r.tempDir)
|
||||
now := time.Now().UTC()
|
||||
timestamp := cache.ReadTimestamp()
|
||||
expiry := timestamp.Add(r.cacheExpiryDuration)
|
||||
cacheValid := now.Before(expiry)
|
||||
var cacheFound bool
|
||||
|
||||
r.debugf("checking cache for %q in %q\n", node.Location(), cache.Location())
|
||||
cachedBytes, err := cache.Read()
|
||||
switch {
|
||||
// If the cache doesn't exist, we need to download the file
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
r.debugf("no cache found\n")
|
||||
// If we couldn't find a cached copy, and we are offline, we can't do anything
|
||||
if r.offline {
|
||||
return nil, &errors.TaskfileCacheNotFoundError{
|
||||
URI: node.Location(),
|
||||
}
|
||||
}
|
||||
|
||||
// If the cache is expired
|
||||
case !cacheValid:
|
||||
r.debugf("cache expired at %s\n", expiry.Format(time.RFC3339))
|
||||
cacheFound = true
|
||||
// If we can't fetch a fresh copy, we should use the cache anyway
|
||||
if r.offline {
|
||||
r.debugf("in offline mode, using expired cache\n")
|
||||
return cachedBytes, nil
|
||||
}
|
||||
|
||||
// Some other error
|
||||
case err != nil:
|
||||
return nil, err
|
||||
|
||||
// Found valid cache
|
||||
default:
|
||||
r.debugf("cache found\n")
|
||||
// Not being forced to redownload, return cache
|
||||
if !r.download {
|
||||
return cachedBytes, nil
|
||||
}
|
||||
cacheFound = true
|
||||
}
|
||||
|
||||
cache, err := NewCache(r.tempDir)
|
||||
// Try to read the remote file
|
||||
r.debugf("downloading remote file: %s\n", node.Location())
|
||||
downloadedBytes, err := node.ReadContext(ctx)
|
||||
if err != nil {
|
||||
// If the context timed out or was cancelled, but we found a cached version, use that
|
||||
if ctx.Err() != nil && cacheFound {
|
||||
if cacheValid {
|
||||
r.debugf("failed to fetch remote file: %s: using cache\n", ctx.Err().Error())
|
||||
} else {
|
||||
r.debugf("failed to fetch remote file: %s: using expired cache\n", ctx.Err().Error())
|
||||
}
|
||||
return cachedBytes, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if r.offline {
|
||||
// In offline mode try to use cached copy
|
||||
cached, err := cache.read(node)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.debugf("task: [%s] Fetched cached copy\n", node.Location())
|
||||
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
ctx, cf := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cf()
|
||||
|
||||
b, err := node.Read(ctx)
|
||||
if errors.Is(err, &errors.TaskfileNetworkTimeoutError{}) {
|
||||
// If we timed out then we likely have a network issue
|
||||
|
||||
// If a download was requested, then we can't use a cached copy
|
||||
if r.download {
|
||||
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout}
|
||||
}
|
||||
|
||||
// Search for any cached copies
|
||||
cached, err := cache.read(node)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.debugf("task: [%s] Network timeout. Fetched cached copy\n", node.Location())
|
||||
|
||||
return cached, nil
|
||||
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.debugf("task: [%s] Fetched remote copy\n", node.Location())
|
||||
|
||||
// Get the checksums
|
||||
checksum := checksum(b)
|
||||
cachedChecksum := cache.readChecksum(node)
|
||||
|
||||
var prompt string
|
||||
if cachedChecksum == "" {
|
||||
// If the checksum doesn't exist, prompt the user to continue
|
||||
prompt = taskfileUntrustedPrompt
|
||||
} else if checksum != cachedChecksum {
|
||||
// If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue
|
||||
prompt = taskfileChangedPrompt
|
||||
}
|
||||
r.debugf("found remote file at %q\n", node.Location())
|
||||
checksum := checksum(downloadedBytes)
|
||||
prompt := cache.ChecksumPrompt(checksum)
|
||||
|
||||
// Prompt the user if required
|
||||
if prompt != "" {
|
||||
if err := func() error {
|
||||
r.promptMutex.Lock()
|
||||
@@ -432,18 +439,23 @@ func (r *Reader) loadNodeContent(node Node) ([]byte, error) {
|
||||
}(); err != nil {
|
||||
return nil, &errors.TaskfileNotTrustedError{URI: node.Location()}
|
||||
}
|
||||
|
||||
// Store the checksum
|
||||
if err := cache.writeChecksum(node, checksum); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the file
|
||||
r.debugf("task: [%s] Caching downloaded file\n", node.Location())
|
||||
if err = cache.write(node, b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return b, nil
|
||||
// Store the checksum
|
||||
if err := cache.WriteChecksum(checksum); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the timestamp
|
||||
if err := cache.WriteTimestamp(now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the file
|
||||
r.debugf("caching %q to %q\n", node.Location(), cache.Location())
|
||||
if err = cache.Write(downloadedBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return downloadedBytes, nil
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@ package taskfile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-task/task/v3/errors"
|
||||
"github.com/go-task/task/v3/internal/filepathext"
|
||||
@@ -40,7 +40,7 @@ var (
|
||||
// at the given URL with any of the default Taskfile files names. If any of
|
||||
// these match a file, the first matching path will be returned. If no files are
|
||||
// found, an error will be returned.
|
||||
func RemoteExists(ctx context.Context, u *url.URL, timeout time.Duration) (*url.URL, error) {
|
||||
func RemoteExists(ctx context.Context, u *url.URL) (*url.URL, error) {
|
||||
// Create a new HEAD request for the given URL to check if the resource exists
|
||||
req, err := http.NewRequestWithContext(ctx, "HEAD", u.String(), nil)
|
||||
if err != nil {
|
||||
@@ -50,8 +50,8 @@ func RemoteExists(ctx context.Context, u *url.URL, timeout time.Duration) (*url.
|
||||
// Request the given URL
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
return nil, &errors.TaskfileNetworkTimeoutError{URI: u.String(), Timeout: timeout}
|
||||
if ctx.Err() != nil {
|
||||
return nil, fmt.Errorf("checking remote file: %w", ctx.Err())
|
||||
}
|
||||
return nil, errors.TaskfileFetchFailedError{URI: u.String()}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user