1
0
mirror of https://github.com/go-task/task.git synced 2025-06-04 23:38:05 +02:00

Move circular include logic to a separate function

This commit is contained in:
tylermmorton 2022-01-15 23:34:59 -05:00
parent 02e7ff27c7
commit c73a2c8f84
4 changed files with 33 additions and 25 deletions

View File

@ -77,29 +77,18 @@ func Taskfile(readerNode *ReaderNode) (*taskfile.Taskfile, error) {
path = filepath.Join(readerNode.Dir, path) path = filepath.Join(readerNode.Dir, path)
} }
// check for cyclic include references by walking up includeReaderNode := &ReaderNode{
// node tree of parents and comparing paths
var curNode = readerNode
for curNode.Parent != nil {
curNode = curNode.Parent
curPath := filepath.Join(curNode.Dir, curNode.Entrypoint)
if curPath == path {
return fmt.Errorf("include cycle detected between %s <--> %s",
curPath,
filepath.Join(readerNode.Dir, readerNode.Entrypoint),
)
}
}
// if we made it here then there is no cyclic include
readOpts := &ReaderNode{
Dir: filepath.Dir(path), Dir: filepath.Dir(path),
Entrypoint: filepath.Base(path), Entrypoint: filepath.Base(path),
Parent: readerNode, Parent: readerNode,
Optional: includedTask.Optional, Optional: includedTask.Optional,
} }
includedTaskfile, err := Taskfile(readOpts) if err := checkCircularIncludes(includeReaderNode); err != nil {
return err
}
includedTaskfile, err := Taskfile(includeReaderNode)
if err != nil { if err != nil {
if includedTask.Optional { if includedTask.Optional {
return nil return nil
@ -190,3 +179,25 @@ func exists(path string) (string, error) {
return "", fmt.Errorf(`task: No Taskfile found in "%s". Use "task --init" to create a new one`, path) return "", fmt.Errorf(`task: No Taskfile found in "%s". Use "task --init" to create a new one`, path)
} }
func checkCircularIncludes(node *ReaderNode) error {
if node == nil {
return errors.New("failed to check for include cycle: node was nil")
}
if node.Parent == nil {
return errors.New("failed to check for include cycle: node.Parent was nil")
}
var curNode = node
var basePath = filepath.Join(node.Dir, node.Entrypoint)
for curNode.Parent != nil {
curNode = curNode.Parent
curPath := filepath.Join(curNode.Dir, curNode.Entrypoint)
if curPath == basePath {
return fmt.Errorf("include cycle detected between %s <--> %s",
curPath,
filepath.Join(node.Parent.Dir, node.Parent.Entrypoint),
)
}
}
return nil
}

View File

@ -6,7 +6,4 @@ includes:
tasks: tasks:
default: default:
cmds: cmds:
- echo "called_dep" > called_dep.txt - task: one:two:default
level1:
cmds:
- echo "hello level 1"

View File

@ -4,6 +4,6 @@ includes:
'two': ./two/Taskfile.yml 'two': ./two/Taskfile.yml
tasks: tasks:
level2: level1:
cmds: cmds:
- echo "hello level 2" - echo "hello level 1"

View File

@ -1,6 +1,6 @@
version: '3' version: '3'
tasks: tasks:
level3: default:
cmds: cmds:
- echo "hello level 3" - echo "called_dep" > called_dep.txt