diff --git a/task_test.go b/task_test.go index 3411c5eb..15d6fd2d 100644 --- a/task_test.go +++ b/task_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/fs" + rand "math/rand/v2" "net/http" "net/http/httptest" "os" @@ -1047,6 +1048,107 @@ func TestIncludesMultiLevel(t *testing.T) { tt.Run(t) } +func TestIncludesRemote(t *testing.T) { + enableExperimentForTest(t, &experiments.RemoteTaskfiles, "1") + + dir := "testdata/includes_remote" + + srv := httptest.NewServer(http.FileServer(http.Dir(dir))) + defer srv.Close() + + tcs := []struct { + firstRemote string + secondRemote string + }{ + { + firstRemote: srv.URL + "/first/Taskfile.yml", + secondRemote: srv.URL + "/first/second/Taskfile.yml", + }, + { + firstRemote: srv.URL + "/first/Taskfile.yml", + secondRemote: "./second/Taskfile.yml", + }, + } + + tasks := []string{ + "first:write-file", + "first:second:write-file", + } + + for i, tc := range tcs { + t.Run(fmt.Sprint(i), func(t *testing.T) { + t.Setenv("FIRST_REMOTE_URL", tc.firstRemote) + t.Setenv("SECOND_REMOTE_URL", tc.secondRemote) + + var buff SyncBuffer + + executors := []struct { + name string + executor *task.Executor + }{ + { + name: "online, always download", + executor: &task.Executor{ + Dir: dir, + Stdout: &buff, + Stderr: &buff, + Timeout: time.Minute, + Insecure: true, + Logger: &logger.Logger{Stdout: &buff, Stderr: &buff, Verbose: true}, + + // Without caching + AssumeYes: true, + Download: true, + }, + }, + { + name: "offline, use cache", + executor: &task.Executor{ + Dir: dir, + Stdout: &buff, + Stderr: &buff, + Timeout: time.Minute, + Insecure: true, + Logger: &logger.Logger{Stdout: &buff, Stderr: &buff, Verbose: true}, + + // With caching + AssumeYes: false, + Download: false, + Offline: true, + }, + }, + } + + for j, e := range executors { + t.Run(fmt.Sprint(j), func(t *testing.T) { + require.NoError(t, e.executor.Setup()) + + for k, task := range tasks { + t.Run(task, func(t *testing.T) { + expectedContent := fmt.Sprint(rand.Int64()) + t.Setenv("CONTENT", expectedContent) + + outputFile := fmt.Sprintf("%d.%d.txt", i, k) + t.Setenv("OUTPUT_FILE", outputFile) + + path := filepath.Join(dir, outputFile) + require.NoError(t, os.RemoveAll(path)) + + require.NoError(t, e.executor.Run(context.Background(), &ast.Call{Task: task})) + + actualContent, err := os.ReadFile(path) + require.NoError(t, err) + assert.Equal(t, expectedContent, strings.TrimSpace(string(actualContent))) + }) + } + }) + } + + t.Log("\noutput:\n", buff.buf.String()) + }) + } +} + func TestIncludeCycle(t *testing.T) { const dir = "testdata/includes_cycle" diff --git a/taskfile/reader.go b/taskfile/reader.go index c64f35b1..69f75b1d 100644 --- a/taskfile/reader.go +++ b/taskfile/reader.go @@ -184,92 +184,9 @@ func (r *Reader) include(node Node) error { } func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { - var b []byte - var err error - var cache *Cache - - if node.Remote() { - cache, err = NewCache(r.tempDir) - if err != nil { - return nil, err - } - } - - // If the file is remote and we're in offline mode, check if we have a cached copy - if node.Remote() && r.offline { - if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) { - return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()} - } else if err != nil { - return nil, err - } - r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched cached copy\n", node.Location()) - } else { - - downloaded := false - ctx, cf := context.WithTimeout(context.Background(), r.timeout) - defer cf() - - // Read the file - b, err = node.Read(ctx) - var taskfileNetworkTimeoutError *errors.TaskfileNetworkTimeoutError - // If we timed out then we likely have a network issue - if node.Remote() && errors.As(err, &taskfileNetworkTimeoutError) { - // If a download was requested, then we can't use a cached copy - if r.download { - return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout} - } - // Search for any cached copies - if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) { - return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true} - } else if err != nil { - return nil, err - } - r.logger.VerboseOutf(logger.Magenta, "task: [%s] Network timeout. Fetched cached copy\n", node.Location()) - } else if err != nil { - return nil, err - } else { - downloaded = true - } - - // If the node was remote, we need to check the checksum - if node.Remote() && downloaded { - r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched remote copy\n", node.Location()) - - // Get the checksums - checksum := checksum(b) - cachedChecksum := cache.readChecksum(node) - - var prompt string - if cachedChecksum == "" { - // If the checksum doesn't exist, prompt the user to continue - prompt = fmt.Sprintf(taskfileUntrustedPrompt, node.Location()) - } else if checksum != cachedChecksum { - // If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue - prompt = fmt.Sprintf(taskfileChangedPrompt, node.Location()) - } - - if prompt != "" { - if err := func() error { - r.promptMutex.Lock() - defer r.promptMutex.Unlock() - return r.logger.Prompt(logger.Yellow, prompt, "n", "y", "yes") - }(); err != nil { - return nil, &errors.TaskfileNotTrustedError{URI: node.Location()} - } - } - // If the hash has changed (or is new) - if checksum != cachedChecksum { - // Store the checksum - if err := cache.writeChecksum(node, checksum); err != nil { - return nil, err - } - // Cache the file - r.logger.VerboseOutf(logger.Magenta, "task: [%s] Caching downloaded file\n", node.Location()) - if err = cache.write(node, b); err != nil { - return nil, err - } - } - } + b, err := r.loadNodeContent(node) + if err != nil { + return nil, err } var tf ast.Taskfile @@ -302,3 +219,93 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { return &tf, nil } + +func (r *Reader) loadNodeContent(node Node) ([]byte, error) { + if !node.Remote() { + ctx, cf := context.WithTimeout(context.Background(), r.timeout) + defer cf() + return node.Read(ctx) + } + + cache, err := NewCache(r.tempDir) + if err != nil { + return nil, err + } + + if r.offline { + // In offline mode try to use cached copy + cached, err := cache.read(node) + if errors.Is(err, os.ErrNotExist) { + return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()} + } else if err != nil { + return nil, err + } + r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched cached copy\n", node.Location()) + + return cached, nil + } + + ctx, cf := context.WithTimeout(context.Background(), r.timeout) + defer cf() + + b, err := node.Read(ctx) + if errors.Is(err, &errors.TaskfileNetworkTimeoutError{}) { + // If we timed out then we likely have a network issue + + // If a download was requested, then we can't use a cached copy + if r.download { + return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout} + } + + // Search for any cached copies + cached, err := cache.read(node) + if errors.Is(err, os.ErrNotExist) { + return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true} + } else if err != nil { + return nil, err + } + r.logger.VerboseOutf(logger.Magenta, "task: [%s] Network timeout. Fetched cached copy\n", node.Location()) + + return cached, nil + + } else if err != nil { + return nil, err + } + r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched remote copy\n", node.Location()) + + // Get the checksums + checksum := checksum(b) + cachedChecksum := cache.readChecksum(node) + + var prompt string + if cachedChecksum == "" { + // If the checksum doesn't exist, prompt the user to continue + prompt = fmt.Sprintf(taskfileUntrustedPrompt, node.Location()) + } else if checksum != cachedChecksum { + // If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue + prompt = fmt.Sprintf(taskfileChangedPrompt, node.Location()) + } + + if prompt != "" { + if err := func() error { + r.promptMutex.Lock() + defer r.promptMutex.Unlock() + return r.logger.Prompt(logger.Yellow, prompt, "n", "y", "yes") + }(); err != nil { + return nil, &errors.TaskfileNotTrustedError{URI: node.Location()} + } + + // Store the checksum + if err := cache.writeChecksum(node, checksum); err != nil { + return nil, err + } + + // Cache the file + r.logger.VerboseOutf(logger.Magenta, "task: [%s] Caching downloaded file\n", node.Location()) + if err = cache.write(node, b); err != nil { + return nil, err + } + } + + return b, nil +} diff --git a/testdata/includes_remote/.gitignore b/testdata/includes_remote/.gitignore new file mode 100644 index 00000000..2211df63 --- /dev/null +++ b/testdata/includes_remote/.gitignore @@ -0,0 +1 @@ +*.txt diff --git a/testdata/includes_remote/Taskfile.yml b/testdata/includes_remote/Taskfile.yml new file mode 100644 index 00000000..b5181022 --- /dev/null +++ b/testdata/includes_remote/Taskfile.yml @@ -0,0 +1,4 @@ +version: '3' + +includes: + first: "{{.FIRST_REMOTE_URL}}" diff --git a/testdata/includes_remote/first/Taskfile.yml b/testdata/includes_remote/first/Taskfile.yml new file mode 100644 index 00000000..d15e50e0 --- /dev/null +++ b/testdata/includes_remote/first/Taskfile.yml @@ -0,0 +1,11 @@ +version: '3' + +includes: + second: "{{.SECOND_REMOTE_URL}}" + +tasks: + write-file: + requires: + vars: [CONTENT, OUTPUT_FILE] + cmd: | + echo "{{.CONTENT}}" > "{{.OUTPUT_FILE}}" diff --git a/testdata/includes_remote/first/second/Taskfile.yml b/testdata/includes_remote/first/second/Taskfile.yml new file mode 100644 index 00000000..1b261e44 --- /dev/null +++ b/testdata/includes_remote/first/second/Taskfile.yml @@ -0,0 +1,8 @@ +version: '3' + +tasks: + write-file: + requires: + vars: [CONTENT, OUTPUT_FILE] + cmd: | + echo "{{.CONTENT}}" > "{{.OUTPUT_FILE}}"