package taskfile import ( "context" "fmt" "os" "sync" "time" "github.com/dominikbraun/graph" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v3" "github.com/go-task/task/v3/errors" "github.com/go-task/task/v3/internal/compiler" "github.com/go-task/task/v3/internal/filepathext" "github.com/go-task/task/v3/internal/logger" "github.com/go-task/task/v3/internal/templater" "github.com/go-task/task/v3/taskfile/ast" ) const ( taskfileUntrustedPrompt = `The task you are attempting to run depends on the remote Taskfile at %q. --- Make sure you trust the source of this Taskfile before continuing --- Continue?` taskfileChangedPrompt = `The Taskfile at %q has changed since you last used it! --- Make sure you trust the source of this Taskfile before continuing --- Continue?` ) // A Reader will recursively read Taskfiles from a given source using a directed // acyclic graph (DAG). type Reader struct { graph *ast.TaskfileGraph node Node insecure bool download bool offline bool timeout time.Duration tempDir string logger *logger.Logger promptMutex sync.Mutex } func NewReader( node Node, insecure bool, download bool, offline bool, timeout time.Duration, tempDir string, logger *logger.Logger, ) *Reader { return &Reader{ graph: ast.NewTaskfileGraph(), node: node, insecure: insecure, download: download, offline: offline, timeout: timeout, tempDir: tempDir, logger: logger, promptMutex: sync.Mutex{}, } } func (r *Reader) Read() (*ast.TaskfileGraph, error) { // Recursively loop through each Taskfile, adding vertices/edges to the graph if err := r.include(r.node); err != nil { return nil, err } return r.graph, nil } func (r *Reader) include(node Node) error { // Create a new vertex for the Taskfile vertex := &ast.TaskfileVertex{ URI: node.Location(), Taskfile: nil, } // Add the included Taskfile to the DAG // If the vertex already exists, we return early since its Taskfile has // already been read and its children explored if err := r.graph.AddVertex(vertex); err == graph.ErrVertexAlreadyExists { return nil } else if err != nil { return err } // Read and parse the Taskfile from the file and add it to the vertex var err error vertex.Taskfile, err = r.readNode(node) if err != nil { return err } // Create an error group to wait for all included Taskfiles to be read var g errgroup.Group // Loop over each included taskfile _ = vertex.Taskfile.Includes.Range(func(namespace string, include *ast.Include) error { vars := compiler.GetEnviron() vars.Merge(vertex.Taskfile.Vars, nil) // Start a goroutine to process each included Taskfile g.Go(func() error { cache := &templater.Cache{Vars: vars} include = &ast.Include{ Namespace: include.Namespace, Taskfile: templater.Replace(include.Taskfile, cache), Dir: templater.Replace(include.Dir, cache), Optional: include.Optional, Internal: include.Internal, Flatten: include.Flatten, Aliases: include.Aliases, AdvancedImport: include.AdvancedImport, Vars: include.Vars, } if err := cache.Err(); err != nil { return err } entrypoint, err := node.ResolveEntrypoint(include.Taskfile) if err != nil { return err } include.Dir, err = node.ResolveDir(include.Dir) if err != nil { return err } includeNode, err := NewNode(r.logger, entrypoint, include.Dir, r.insecure, r.timeout, WithParent(node), ) if err != nil { if include.Optional { return nil } return err } // Recurse into the included Taskfile if err := r.include(includeNode); err != nil { return err } // Create an edge between the Taskfiles r.graph.Lock() defer r.graph.Unlock() edge, err := r.graph.Edge(node.Location(), includeNode.Location()) if err == graph.ErrEdgeNotFound { // If the edge doesn't exist, create it err = r.graph.AddEdge( node.Location(), includeNode.Location(), graph.EdgeData([]*ast.Include{include}), graph.EdgeWeight(1), ) } else { // If the edge already exists edgeData := append(edge.Properties.Data.([]*ast.Include), include) err = r.graph.UpdateEdge( node.Location(), includeNode.Location(), graph.EdgeData(edgeData), graph.EdgeWeight(len(edgeData)), ) } if errors.Is(err, graph.ErrEdgeCreatesCycle) { return errors.TaskfileCycleError{ Source: node.Location(), Destination: includeNode.Location(), } } return err }) return nil }) // Wait for all the go routines to finish return g.Wait() } func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { b, err := r.loadNodeContent(node) if err != nil { return nil, err } var tf ast.Taskfile if err := yaml.Unmarshal(b, &tf); err != nil { // Decode the taskfile and add the file info the any errors taskfileInvalidErr := &errors.TaskfileDecodeError{} if errors.As(err, &taskfileInvalidErr) { return nil, taskfileInvalidErr.WithFileInfo(node.Location(), b, 2) } return nil, &errors.TaskfileInvalidError{URI: filepathext.TryAbsToRel(node.Location()), Err: err} } // Check that the Taskfile is set and has a schema version if tf.Version == nil { return nil, &errors.TaskfileVersionCheckError{URI: node.Location()} } // Set the taskfile/task's locations tf.Location = node.Location() for _, task := range tf.Tasks.Values() { // If the task is not defined, create a new one if task == nil { task = &ast.Task{} } // Set the location of the taskfile for each task if task.Location.Taskfile == "" { task.Location.Taskfile = tf.Location } } return &tf, nil } func (r *Reader) loadNodeContent(node Node) ([]byte, error) { if !node.Remote() { ctx, cf := context.WithTimeout(context.Background(), r.timeout) defer cf() return node.Read(ctx) } cache, err := NewCache(r.tempDir) if err != nil { return nil, err } if r.offline { // In offline mode try to use cached copy cached, err := cache.read(node) if errors.Is(err, os.ErrNotExist) { return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()} } else if err != nil { return nil, err } r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched cached copy\n", node.Location()) return cached, nil } ctx, cf := context.WithTimeout(context.Background(), r.timeout) defer cf() b, err := node.Read(ctx) if errors.Is(err, &errors.TaskfileNetworkTimeoutError{}) { // If we timed out then we likely have a network issue // If a download was requested, then we can't use a cached copy if r.download { return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout} } // Search for any cached copies cached, err := cache.read(node) if errors.Is(err, os.ErrNotExist) { return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true} } else if err != nil { return nil, err } r.logger.VerboseOutf(logger.Magenta, "task: [%s] Network timeout. Fetched cached copy\n", node.Location()) return cached, nil } else if err != nil { return nil, err } r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched remote copy\n", node.Location()) // Get the checksums checksum := checksum(b) cachedChecksum := cache.readChecksum(node) var prompt string if cachedChecksum == "" { // If the checksum doesn't exist, prompt the user to continue prompt = fmt.Sprintf(taskfileUntrustedPrompt, node.Location()) } else if checksum != cachedChecksum { // If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue prompt = fmt.Sprintf(taskfileChangedPrompt, node.Location()) } if prompt != "" { if err := func() error { r.promptMutex.Lock() defer r.promptMutex.Unlock() return r.logger.Prompt(logger.Yellow, prompt, "n", "y", "yes") }(); err != nil { return nil, &errors.TaskfileNotTrustedError{URI: node.Location()} } // Store the checksum if err := cache.writeChecksum(node, checksum); err != nil { return nil, err } // Cache the file r.logger.VerboseOutf(logger.Magenta, "task: [%s] Caching downloaded file\n", node.Location()) if err = cache.write(node, b); err != nil { return nil, err } } return b, nil }