diff --git a/.gitignore b/.gitignore
index 45b74c09..ada278bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,9 @@
 
 # Output of the go coverage tool, specifically when used with LiteIDE
 *.out
 
+# Graphviz files
+*.gv
+
 # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
 .glide/
diff --git a/go.mod b/go.mod
index ffedea45..4d26524e 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.21
 require (
 	github.com/Masterminds/semver/v3 v3.2.1
 	github.com/davecgh/go-spew v1.1.1
+	github.com/dominikbraun/graph v0.23.0
 	github.com/fatih/color v1.16.0
 	github.com/go-task/slim-sprig/v3 v3.0.0
 	github.com/joho/godotenv v1.5.1
diff --git a/go.sum b/go.sum
index 205e44e4..97e941ec 100644
--- a/go.sum
+++ b/go.sum
@@ -4,6 +4,8 @@ github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0=
 github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dominikbraun/graph v0.23.0 h1:TdZB4pPqCLFxYhdyMFb1TBdFxp8XLcJfTTBQucVPgCo=
+github.com/dominikbraun/graph v0.23.0/go.mod h1:yOjYyogZLY1LSG9E33JWZJiq5k83Qy2C6POAuiViluc=
 github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM=
 github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE=
 github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
diff --git a/setup.go b/setup.go
index 8bebdbc7..c690cbe9 100644
--- a/setup.go
+++ b/setup.go
@@ -63,8 +63,7 @@ func (e *Executor) getRootNode() (taskfile.Node, error) {
 }
 
 func (e *Executor) readTaskfile(node taskfile.Node) error {
-	var err error
-	e.Taskfile, err = taskfile.Read(
+	reader := taskfile.NewReader(
 		node,
 		e.Insecure,
 		e.Download,
@@ -73,9 +72,13 @@ func (e *Executor) readTaskfile(node taskfile.Node) error {
 		e.TempDir,
 		e.Logger,
 	)
+	graph, err := reader.Read()
 	if err != nil {
 		return err
 	}
+	if err := graph.Visualize("./taskfile-dag.gv"); err != nil {
+		return err
+	}
 	return nil
 }
 
diff --git a/taskfile/ast/graph.go b/taskfile/ast/graph.go
new file mode 100644
index 00000000..3e30faa6
--- /dev/null
+++ b/taskfile/ast/graph.go
@@ -0,0 +1,41 @@
+package ast
+
+import (
+	"os"
+
+	"github.com/dominikbraun/graph"
+	"github.com/dominikbraun/graph/draw"
+)
+
+type TaskfileGraph struct {
+	graph.Graph[string, *TaskfileVertex]
+}
+
+// A TaskfileVertex is a vertex on the Taskfile DAG.
+type TaskfileVertex struct {
+	URI      string
+	Taskfile *Taskfile
+}
+
+func taskfileHash(vertex *TaskfileVertex) string {
+	return vertex.URI
+}
+
+func NewTaskfileGraph() *TaskfileGraph {
+	return &TaskfileGraph{
+		graph.New(taskfileHash,
+			graph.Directed(),
+			graph.PreventCycles(),
+			graph.Rooted(),
+		),
+	}
+}
+
+func (r *TaskfileGraph) Visualize(filename string) error {
+	f, err := os.Create(filename)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	return draw.DOT(r.Graph, f)
+}
diff --git a/taskfile/reader.go b/taskfile/reader.go
index af04637c..b194ef08 100644
--- a/taskfile/reader.go
+++ b/taskfile/reader.go
@@ -6,6 +6,8 @@ import (
 	"os"
 	"time"
 
+	"github.com/dominikbraun/graph"
+	"golang.org/x/sync/errgroup"
 	"gopkg.in/yaml.v3"
 
 	"github.com/go-task/task/v3/errors"
@@ -24,32 +26,83 @@ Continue?`
 Continue?`
 )
 
-// Read reads a Taskfile for a given directory
-// Uses current dir when dir is left empty. Uses Taskfile.yml
-// or Taskfile.yaml when entrypoint is left empty
-func Read(
+// A Reader will recursively read Taskfiles from a given source using a directed
+// acyclic graph (DAG).
+type Reader struct {
+	graph    *ast.TaskfileGraph
+	node     Node
+	insecure bool
+	download bool
+	offline  bool
+	timeout  time.Duration
+	tempDir  string
+	logger   *logger.Logger
+}
+
+func NewReader(
 	node Node,
 	insecure bool,
 	download bool,
 	offline bool,
 	timeout time.Duration,
 	tempDir string,
-	l *logger.Logger,
-) (*ast.Taskfile, error) {
-	var _taskfile func(Node) (*ast.Taskfile, error)
-	_taskfile = func(node Node) (*ast.Taskfile, error) {
-		tf, err := readTaskfile(node, download, offline, timeout, tempDir, l)
-		if err != nil {
-			return nil, err
-		}
+	logger *logger.Logger,
+) *Reader {
+	return &Reader{
+		graph:    ast.NewTaskfileGraph(),
+		node:     node,
+		insecure: insecure,
+		download: download,
+		offline:  offline,
+		timeout:  timeout,
+		tempDir:  tempDir,
+		logger:   logger,
+	}
+}
 
-		// Check that the Taskfile is set and has a schema version
-		if tf == nil || tf.Version == nil {
-			return nil, &errors.TaskfileVersionCheckError{URI: node.Location()}
-		}
+func (r *Reader) Read() (*ast.TaskfileGraph, error) {
+	// Recursively loop through each Taskfile, adding vertices/edges to the graph
+	if err := r.include(r.node); err != nil {
+		return nil, err
+	}
 
-		err = tf.Includes.Range(func(namespace string, include ast.Include) error {
-			cache := &templater.Cache{Vars: tf.Vars}
+	return r.graph, nil
+}
+
+func (r *Reader) include(node Node) error {
+	// Create a new vertex for the Taskfile
+	vertex := &ast.TaskfileVertex{
+		URI:      node.Location(),
+		Taskfile: nil,
+	}
+
+	// Add the included Taskfile to the DAG
+	// If the vertex already exists, we return early since its Taskfile has
+	// already been read and its children explored
+	if err := r.graph.AddVertex(vertex); err == graph.ErrVertexAlreadyExists {
+		return nil
+	} else if err != nil {
+		return err
+	}
+
+	// Read and parse the Taskfile from the file and add it to the vertex
+	var err error
+	vertex.Taskfile, err = r.readNode(node)
+	if err != nil {
+		if node.Optional() {
+			return nil
+		}
+		return err
+	}
+
+	// Create an error group to wait for all included Taskfiles to be read
+	var g errgroup.Group
+
+	// Loop over each included taskfile
+	vertex.Taskfile.Includes.Range(func(namespace string, include ast.Include) error {
+		// Start a goroutine to process each included Taskfile
+		g.Go(func() error {
+			cache := &templater.Cache{Vars: vertex.Taskfile.Vars}
 			include = ast.Include{
 				Namespace: include.Namespace,
 				Taskfile:  templater.Replace(include.Taskfile, cache),
@@ -74,117 +127,53 @@ func Read(
 				return err
 			}
 
-			includeReaderNode, err := NewNode(l, entrypoint, dir, insecure, timeout,
+			includeNode, err := NewNode(r.logger, entrypoint, dir, r.insecure, r.timeout,
 				WithParent(node),
 				WithOptional(include.Optional),
 			)
 			if err != nil {
-				if include.Optional {
-					return nil
-				}
 				return err
 			}
 
-			if err := checkCircularIncludes(includeReaderNode); err != nil {
+			// Recurse into the included Taskfile
+			if err := r.include(includeNode); err != nil {
 				return err
 			}
 
-			includedTaskfile, err := _taskfile(includeReaderNode)
-			if err != nil {
-				if include.Optional {
-					return nil
-				}
-				return err
-			}
-
-			if len(includedTaskfile.Dotenv) > 0 {
-				return ErrIncludedTaskfilesCantHaveDotenvs
-			}
-
-			if include.AdvancedImport {
-				// nolint: errcheck
-				includedTaskfile.Vars.Range(func(k string, v ast.Var) error {
-					o := v
-					o.Dir = dir
-					includedTaskfile.Vars.Set(k, o)
-					return nil
-				})
-				// nolint: errcheck
-				includedTaskfile.Env.Range(func(k string, v ast.Var) error {
-					o := v
-					o.Dir = dir
-					includedTaskfile.Env.Set(k, o)
-					return nil
-				})
-
-				for _, task := range includedTaskfile.Tasks.Values() {
-					task.Dir = filepathext.SmartJoin(dir, task.Dir)
-					if task.IncludeVars == nil {
-						task.IncludeVars = &ast.Vars{}
-					}
-					task.IncludeVars.Merge(include.Vars)
-					task.IncludedTaskfileVars = includedTaskfile.Vars
-				}
-			}
-
-			if err = tf.Merge(includedTaskfile, &include); err != nil {
-				return err
-			}
-
-			return nil
+			// Create an edge between the Taskfiles
+			return r.graph.AddEdge(node.Location(), includeNode.Location())
 		})
-		if err != nil {
-			return nil, err
-		}
+		return nil
+	})
 
-		for _, task := range tf.Tasks.Values() {
-			// If the task is not defined, create a new one
-			if task == nil {
-				task = &ast.Task{}
-			}
-			// Set the location of the taskfile for each task
-			if task.Location.Taskfile == "" {
-				task.Location.Taskfile = tf.Location
-			}
-		}
-
-		return tf, nil
-	}
-	return _taskfile(node)
+	// Wait for all the goroutines to finish
+	return g.Wait()
 }
 
-func readTaskfile(
-	node Node,
-	download,
-	offline bool,
-	timeout time.Duration,
-	tempDir string,
-	l *logger.Logger,
-) (*ast.Taskfile, error) {
+func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
 	var b []byte
 	var err error
 	var cache *Cache
 
 	if node.Remote() {
-		cache, err = NewCache(tempDir)
+		cache, err = NewCache(r.tempDir)
 		if err != nil {
 			return nil, err
 		}
 	}
 
 	// If the file is remote and we're in offline mode, check if we have a cached copy
-	if node.Remote() && offline {
+	if node.Remote() && r.offline {
 		if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) {
 			return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()}
 		} else if err != nil {
 			return nil, err
 		}
-		l.VerboseOutf(logger.Magenta, "task: [%s] Fetched cached copy\n", node.Location())
-
+		r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched cached copy\n", node.Location())
 	} else {
 		downloaded := false
-		ctx, cf := context.WithTimeout(context.Background(), timeout)
+		ctx, cf := context.WithTimeout(context.Background(), r.timeout)
 		defer cf()
 
 		// Read the file
@@ -192,16 +181,16 @@ func readTaskfile(
 		// If we timed out then we likely have a network issue
 		if node.Remote() && errors.Is(ctx.Err(), context.DeadlineExceeded) {
 			// If a download was requested, then we can't use a cached copy
-			if download {
-				return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: timeout}
+			if r.download {
+				return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout}
 			}
 			// Search for any cached copies
 			if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) {
-				return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: timeout, CheckedCache: true}
+				return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true}
 			} else if err != nil {
 				return nil, err
 			}
-			l.VerboseOutf(logger.Magenta, "task: [%s] Network timeout. Fetched cached copy\n", node.Location())
+			r.logger.VerboseOutf(logger.Magenta, "task: [%s] Network timeout. Fetched cached copy\n", node.Location())
 		} else if err != nil {
 			return nil, err
 		} else {
@@ -210,7 +199,7 @@ func readTaskfile(
 
 	// If the node was remote, we need to check the checksum
 	if node.Remote() && downloaded {
-		l.VerboseOutf(logger.Magenta, "task: [%s] Fetched remote copy\n", node.Location())
+		r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched remote copy\n", node.Location())
 
 		// Get the checksums
 		checksum := checksum(b)
@@ -224,8 +213,8 @@ func readTaskfile(
 			// If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue
 			prompt = fmt.Sprintf(taskfileChangedPrompt, node.Location())
 		}
 		if prompt != "" {
-			if err := l.Prompt(logger.Yellow, prompt, "n", "y", "yes"); err != nil {
+			if err := r.logger.Prompt(logger.Yellow, prompt, "n", "y", "yes"); err != nil {
 				return nil, &errors.TaskfileNotTrustedError{URI: node.Location()}
 			}
 		}
@@ -237,7 +226,7 @@ func readTaskfile(
 			return nil, err
 		}
 		// Cache the file
-		l.VerboseOutf(logger.Magenta, "task: [%s] Caching downloaded file\n", node.Location())
+		r.logger.VerboseOutf(logger.Magenta, "task: [%s] Caching downloaded file\n", node.Location())
 		if err = cache.write(node, b); err != nil {
 			return nil, err
 		}
@@ -253,25 +242,3 @@ func readTaskfile(
 
 	return &t, nil
 }
-
-func checkCircularIncludes(node Node) error {
-	if node == nil {
-		return errors.New("task: failed to check for include cycle: node was nil")
-	}
-	if node.Parent() == nil {
-		return errors.New("task: failed to check for include cycle: node.Parent was nil")
-	}
-	curNode := node
-	location := node.Location()
-	for curNode.Parent() != nil {
-		curNode = curNode.Parent()
-		curLocation := curNode.Location()
-		if curLocation == location {
-			return fmt.Errorf("task: include cycle detected between %s <--> %s",
-				curLocation,
-				node.Parent().Location(),
-			)
-		}
-	}
-	return nil
-}
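
Reviewer note (not part of the diff): the removed `checkCircularIncludes` walk over node parents is effectively replaced by the `graph.PreventCycles()` trait set in `ast.NewTaskfileGraph`, which makes `AddEdge` fail when an include would form a cycle, and `Visualize` just renders the same graph as Graphviz DOT (hence the new `*.gv` ignore rule). Below is a minimal standalone sketch of that behaviour using the dominikbraun/graph API directly; the Taskfile paths are hypothetical and only stand in for `node.Location()` values.

```go
package main

import (
	"errors"
	"fmt"
	"os"

	"github.com/dominikbraun/graph"
	"github.com/dominikbraun/graph/draw"
)

func main() {
	// Same directed, cycle-preventing traits as ast.NewTaskfileGraph,
	// but keyed by plain strings instead of *TaskfileVertex.
	g := graph.New(graph.StringHash, graph.Directed(), graph.PreventCycles())

	// Vertices stand in for Taskfile locations (hypothetical paths).
	_ = g.AddVertex("Taskfile.yml")
	_ = g.AddVertex("included/Taskfile.yml")

	// Parent -> child edge, as Reader.include does after recursing.
	_ = g.AddEdge("Taskfile.yml", "included/Taskfile.yml")

	// An include cycle now surfaces as an AddEdge error instead of the
	// old manual walk up the node's parents.
	if err := g.AddEdge("included/Taskfile.yml", "Taskfile.yml"); errors.Is(err, graph.ErrEdgeCreatesCycle) {
		fmt.Println("cycle rejected:", err)
	}

	// Rough equivalent of TaskfileGraph.Visualize: dump the DAG as DOT.
	_ = draw.DOT(g, os.Stdout)
}
```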