diff --git a/pkg/config/app_config.go b/pkg/config/app_config.go index e7d884dbb..12b377f96 100644 --- a/pkg/config/app_config.go +++ b/pkg/config/app_config.go @@ -5,6 +5,7 @@ import ( "log" "os" "path/filepath" + "reflect" "strings" "time" @@ -237,7 +238,17 @@ func migrateUserConfig(path string, content []byte) ([]byte, error) { // A pure function helper for testing purposes func computeMigratedConfig(path string, content []byte) ([]byte, error) { - changedContent := content + var err error + var rootNode yaml.Node + err = yaml.Unmarshal(content, &rootNode) + if err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + var originalCopy yaml.Node + err = yaml.Unmarshal(content, &originalCopy) + if err != nil { + return nil, fmt.Errorf("failed to parse YAML, but only the second time!?!? How did that happen: %w", err) + } pathsToReplace := []struct { oldPath []string @@ -248,46 +259,52 @@ func computeMigratedConfig(path string, content []byte) ([]byte, error) { {[]string{"gui", "windowSize"}, "screenMode"}, } - var err error for _, pathToReplace := range pathsToReplace { - changedContent, err = yaml_utils.RenameYamlKey(changedContent, pathToReplace.oldPath, pathToReplace.newName) + err := yaml_utils.RenameYamlKey(&rootNode, pathToReplace.oldPath, pathToReplace.newName) if err != nil { return nil, fmt.Errorf("Couldn't migrate config file at `%s` for key %s: %s", path, strings.Join(pathToReplace.oldPath, "."), err) } } - changedContent, err = changeNullKeybindingsToDisabled(changedContent) + err = changeNullKeybindingsToDisabled(&rootNode) if err != nil { return nil, fmt.Errorf("Couldn't migrate config file at `%s`: %s", path, err) } - changedContent, err = changeElementToSequence(changedContent, []string{"git", "commitPrefix"}) + err = changeElementToSequence(&rootNode, []string{"git", "commitPrefix"}) if err != nil { return nil, fmt.Errorf("Couldn't migrate config file at `%s`: %s", path, err) } - changedContent, err = changeCommitPrefixesMap(changedContent) + err = changeCommitPrefixesMap(&rootNode) if err != nil { return nil, fmt.Errorf("Couldn't migrate config file at `%s`: %s", path, err) } + // Add more migrations here... - return changedContent, nil + if !reflect.DeepEqual(rootNode, originalCopy) { + newContent, err := yaml_utils.YamlMarshal(&rootNode) + if err != nil { + return nil, fmt.Errorf("Failed to remarsal!\n %w", err) + } + return newContent, nil + } else { + return content, nil + } } -func changeNullKeybindingsToDisabled(changedContent []byte) ([]byte, error) { - return yaml_utils.Walk(changedContent, func(node *yaml.Node, path string) bool { +func changeNullKeybindingsToDisabled(rootNode *yaml.Node) error { + return yaml_utils.Walk(rootNode, func(node *yaml.Node, path string) { if strings.HasPrefix(path, "keybinding.") && node.Kind == yaml.ScalarNode && node.Tag == "!!null" { node.Value = "" node.Tag = "!!str" - return true } - return false }) } -func changeElementToSequence(changedContent []byte, path []string) ([]byte, error) { - return yaml_utils.TransformNode(changedContent, path, func(node *yaml.Node) (bool, error) { +func changeElementToSequence(rootNode *yaml.Node, path []string) error { + return yaml_utils.TransformNode(rootNode, path, func(node *yaml.Node) error { if node.Kind == yaml.MappingNode { nodeContentCopy := node.Content node.Kind = yaml.SequenceNode @@ -298,15 +315,14 @@ func changeElementToSequence(changedContent []byte, path []string) ([]byte, erro Content: nodeContentCopy, }} - return true, nil + return nil } - return false, nil + return nil }) } -func changeCommitPrefixesMap(changedContent []byte) ([]byte, error) { - return yaml_utils.TransformNode(changedContent, []string{"git", "commitPrefixes"}, func(prefixesNode *yaml.Node) (bool, error) { - changedAnyNodes := false +func changeCommitPrefixesMap(rootNode *yaml.Node) error { + return yaml_utils.TransformNode(rootNode, []string{"git", "commitPrefixes"}, func(prefixesNode *yaml.Node) error { if prefixesNode.Kind == yaml.MappingNode { for _, contentNode := range prefixesNode.Content { if contentNode.Kind == yaml.MappingNode { @@ -318,11 +334,10 @@ func changeCommitPrefixesMap(changedContent []byte) ([]byte, error) { Kind: yaml.MappingNode, Content: nodeContentCopy, }} - changedAnyNodes = true } } } - return changedAnyNodes, nil + return nil }) } diff --git a/pkg/utils/yaml_utils/yaml_utils.go b/pkg/utils/yaml_utils/yaml_utils.go index 6e031413e..956805691 100644 --- a/pkg/utils/yaml_utils/yaml_utils.go +++ b/pkg/utils/yaml_utils/yaml_utils.go @@ -35,7 +35,7 @@ func UpdateYamlValue(yamlBytes []byte, path []string, value string) ([]byte, err } // Convert the updated YAML node back to YAML bytes. - updatedYAMLBytes, err := yamlMarshal(body) + updatedYAMLBytes, err := YamlMarshal(body) if err != nil { return nil, fmt.Errorf("failed to convert YAML node to bytes: %w", err) } @@ -100,147 +100,101 @@ func lookupKey(node *yaml.Node, key string) (*yaml.Node, *yaml.Node) { return nil, nil } -// Walks a yaml document to the specified path, and then applies the transformation to that node. -// -// The transform must return true if it made changes to the node. +// Walks a yaml document from the root node to the specified path, and then applies the transformation to that node. // If the requested path is not defined in the document, no changes are made to the document. -// -// If no changes are made, the original document is returned. -// If changes are made, a newly marshalled document is returned. (This may result in different indentation for all nodes) -func TransformNode(yamlBytes []byte, path []string, transform func(node *yaml.Node) (bool, error)) ([]byte, error) { - // Parse the YAML file. - var node yaml.Node - err := yaml.Unmarshal(yamlBytes, &node) - if err != nil { - return nil, fmt.Errorf("failed to parse YAML: %w", err) - } - +func TransformNode(rootNode *yaml.Node, path []string, transform func(node *yaml.Node) error) error { // Empty document: nothing to do. - if len(node.Content) == 0 { - return yamlBytes, nil + if len(rootNode.Content) == 0 { + return nil } - body := node.Content[0] + body := rootNode.Content[0] - if didTransform, err := transformNode(body, path, transform); err != nil || !didTransform { - return yamlBytes, err + if err := transformNode(body, path, transform); err != nil { + return err } - // Convert the updated YAML node back to YAML bytes. - updatedYAMLBytes, err := yamlMarshal(body) - if err != nil { - return nil, fmt.Errorf("failed to convert YAML node to bytes: %w", err) - } - - return updatedYAMLBytes, nil + return nil } // A recursive function to walk down the tree. See TransformNode for more details. -func transformNode(node *yaml.Node, path []string, transform func(node *yaml.Node) (bool, error)) (bool, error) { +func transformNode(node *yaml.Node, path []string, transform func(node *yaml.Node) error) error { if len(path) == 0 { return transform(node) } keyNode, valueNode := lookupKey(node, path[0]) if keyNode == nil { - return false, nil + return nil } return transformNode(valueNode, path[1:], transform) } -// takes a yaml document in bytes, a path to a key, and a new name for the key. +// Takes the root node of a yaml document, a path to a key, and a new name for the key. // Will rename the key to the new name if it exists, and do nothing otherwise. -func RenameYamlKey(yamlBytes []byte, path []string, newKey string) ([]byte, error) { - // Parse the YAML file. - var node yaml.Node - err := yaml.Unmarshal(yamlBytes, &node) - if err != nil { - return nil, fmt.Errorf("failed to parse YAML: %w for bytes %s", err, string(yamlBytes)) - } - +func RenameYamlKey(rootNode *yaml.Node, path []string, newKey string) error { // Empty document: nothing to do. - if len(node.Content) == 0 { - return yamlBytes, nil + if len(rootNode.Content) == 0 { + return nil } - body := node.Content[0] + body := rootNode.Content[0] - if didRename, err := renameYamlKey(body, path, newKey); err != nil || !didRename { - return yamlBytes, err + if err := renameYamlKey(body, path, newKey); err != nil { + return err } - // Convert the updated YAML node back to YAML bytes. - updatedYAMLBytes, err := yamlMarshal(body) - if err != nil { - return nil, fmt.Errorf("failed to convert YAML node to bytes: %w", err) - } - - return updatedYAMLBytes, nil + return nil } // Recursive function to rename the YAML key. -func renameYamlKey(node *yaml.Node, path []string, newKey string) (bool, error) { +func renameYamlKey(node *yaml.Node, path []string, newKey string) error { if node.Kind != yaml.MappingNode { - return false, errors.New("yaml node in path is not a dictionary") + return errors.New("yaml node in path is not a dictionary") } keyNode, valueNode := lookupKey(node, path[0]) if keyNode == nil { - return false, nil + return nil } // end of path reached: rename key if len(path) == 1 { // Check that new key doesn't exist yet if newKeyNode, _ := lookupKey(node, newKey); newKeyNode != nil { - return false, fmt.Errorf("new key `%s' already exists", newKey) + return fmt.Errorf("new key `%s' already exists", newKey) } keyNode.Value = newKey - return true, nil + return nil } return renameYamlKey(valueNode, path[1:], newKey) } // Traverses a yaml document, calling the callback function for each node. The -// callback is allowed to modify the node in place, in which case it should -// return true. The function returns the original yaml document if none of the -// callbacks returned true, and the modified document otherwise. -func Walk(yamlBytes []byte, callback func(node *yaml.Node, path string) bool) ([]byte, error) { - // Parse the YAML file. - var node yaml.Node - err := yaml.Unmarshal(yamlBytes, &node) - if err != nil { - return nil, fmt.Errorf("failed to parse YAML: %w", err) - } - +// callback is expected to modify the node in place +func Walk(rootNode *yaml.Node, callback func(node *yaml.Node, path string)) error { // Empty document: nothing to do. - if len(node.Content) == 0 { - return yamlBytes, nil + if len(rootNode.Content) == 0 { + return nil } - body := node.Content[0] + body := rootNode.Content[0] - if didChange, err := walk(body, "", callback); err != nil || !didChange { - return yamlBytes, err + if err := walk(body, "", callback); err != nil { + return err } - // Convert the updated YAML node back to YAML bytes. - updatedYAMLBytes, err := yamlMarshal(body) - if err != nil { - return nil, fmt.Errorf("failed to convert YAML node to bytes: %w", err) - } - - return updatedYAMLBytes, nil + return nil } -func walk(node *yaml.Node, path string, callback func(*yaml.Node, string) bool) (bool, error) { - didChange := callback(node, path) +func walk(node *yaml.Node, path string, callback func(*yaml.Node, string)) error { + callback(node, path) switch node.Kind { case yaml.DocumentNode: - return false, errors.New("Unexpected document node in the middle of a yaml tree") + return errors.New("Unexpected document node in the middle of a yaml tree") case yaml.MappingNode: for i := 0; i < len(node.Content); i += 2 { name := node.Content[i].Value @@ -251,31 +205,29 @@ func walk(node *yaml.Node, path string, callback func(*yaml.Node, string) bool) } else { childPath = fmt.Sprintf("%s.%s", path, name) } - didChangeChild, err := walk(childNode, childPath, callback) + err := walk(childNode, childPath, callback) if err != nil { - return false, err + return err } - didChange = didChange || didChangeChild } case yaml.SequenceNode: for i := 0; i < len(node.Content); i++ { childPath := fmt.Sprintf("%s[%d]", path, i) - didChangeChild, err := walk(node.Content[i], childPath, callback) + err := walk(node.Content[i], childPath, callback) if err != nil { - return false, err + return err } - didChange = didChange || didChangeChild } case yaml.ScalarNode: // nothing to do case yaml.AliasNode: - return false, errors.New("Alias nodes are not supported") + return errors.New("Alias nodes are not supported") } - return didChange, nil + return nil } -func yamlMarshal(node *yaml.Node) ([]byte, error) { +func YamlMarshal(node *yaml.Node) ([]byte, error) { var buffer bytes.Buffer encoder := yaml.NewEncoder(&buffer) encoder.SetIndent(2) diff --git a/pkg/utils/yaml_utils/yaml_utils_test.go b/pkg/utils/yaml_utils/yaml_utils_test.go index 21433d32d..e53bb1354 100644 --- a/pkg/utils/yaml_utils/yaml_utils_test.go +++ b/pkg/utils/yaml_utils/yaml_utils_test.go @@ -186,14 +186,16 @@ func TestRenameYamlKey(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - out, actualErr := RenameYamlKey([]byte(test.in), test.path, test.newKey) + node := unmarshalForTest(t, test.in) + actualErr := RenameYamlKey(&node, test.path, test.newKey) if test.expectedErr == "" { assert.NoError(t, actualErr) } else { assert.EqualError(t, actualErr, test.expectedErr) } + out := marshalForTest(t, &node) - assert.Equal(t, test.expectedOut, string(out)) + assert.Equal(t, test.expectedOut, out) }) } } @@ -238,10 +240,10 @@ func TestWalk_paths(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + node := unmarshalForTest(t, test.document) paths := []string{} - _, err := Walk([]byte(test.document), func(node *yaml.Node, path string) bool { + err := Walk(&node, func(node *yaml.Node, path string) { paths = append(paths, path) - return true }) assert.NoError(t, err) @@ -254,48 +256,41 @@ func TestWalk_inPlaceChanges(t *testing.T) { tests := []struct { name string in string - callback func(node *yaml.Node, path string) bool + callback func(node *yaml.Node, path string) expectedOut string }{ { - name: "no change", - in: "x: 5", - callback: func(node *yaml.Node, path string) bool { return false }, - expectedOut: "x: 5", + name: "no change", + in: "x: 5", + callback: func(node *yaml.Node, path string) {}, }, { name: "change value", in: "x: 5\ny: 3", - callback: func(node *yaml.Node, path string) bool { + callback: func(node *yaml.Node, path string) { if path == "x" { node.Value = "7" - return true } - return false }, expectedOut: "x: 7\ny: 3\n", }, { name: "change nested value", in: "x:\n y: 5", - callback: func(node *yaml.Node, path string) bool { + callback: func(node *yaml.Node, path string) { if path == "x.y" { node.Value = "7" - return true } - return false }, expectedOut: "x:\n y: 7\n", }, { name: "change array value", in: "x:\n - y: 5", - callback: func(node *yaml.Node, path string) bool { + callback: func(node *yaml.Node, path string) { if path == "x[0].y" { node.Value = "7" - return true } - return false }, expectedOut: "x:\n - y: 7\n", }, @@ -303,28 +298,34 @@ func TestWalk_inPlaceChanges(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - result, err := Walk([]byte(test.in), test.callback) - + node := unmarshalForTest(t, test.in) + err := Walk(&node, test.callback) assert.NoError(t, err) - assert.Equal(t, test.expectedOut, string(result)) + if test.expectedOut == "" { + unmodifiedOriginal := unmarshalForTest(t, test.in) + assert.Equal(t, unmodifiedOriginal, node) + } else { + result := marshalForTest(t, &node) + assert.Equal(t, test.expectedOut, result) + } }) } } func TestTransformNode(t *testing.T) { - transformIntValueToString := func(node *yaml.Node) (bool, error) { + transformIntValueToString := func(node *yaml.Node) error { if node.Kind == yaml.ScalarNode { if node.ShortTag() == "!!int" { node.Tag = "!!str" - return true, nil + return nil } else if node.ShortTag() == "!!str" { // We have already transformed it, - return false, nil + return nil } else { - return false, fmt.Errorf("Node was of bad type") + return fmt.Errorf("Node was of bad type") } } else { - return false, fmt.Errorf("Node was not a scalar") + return fmt.Errorf("Node was not a scalar") } } @@ -332,15 +333,14 @@ func TestTransformNode(t *testing.T) { name string in string path []string - transform func(node *yaml.Node) (bool, error) + transform func(node *yaml.Node) error expectedOut string }{ { - name: "Path not present", - in: "foo: 1", - path: []string{"bar"}, - transform: transformIntValueToString, - expectedOut: "foo: 1", + name: "Path not present", + in: "foo: 1", + path: []string{"bar"}, + transform: transformIntValueToString, }, { name: "Part of path present", @@ -349,9 +349,6 @@ foo: bar: 2`, path: []string{"foo", "baz"}, transform: transformIntValueToString, - expectedOut: ` -foo: - bar: 2`, }, { name: "Successfully Transforms to string", @@ -371,19 +368,42 @@ foo: bar: "2"`, path: []string{"foo", "bar"}, transform: transformIntValueToString, - expectedOut: ` -foo: - bar: "2"`, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - result, err := TransformNode([]byte(test.in), test.path, test.transform) + node := unmarshalForTest(t, test.in) + err := TransformNode(&node, test.path, test.transform) if err != nil { t.Fatal(err) } - assert.Equal(t, test.expectedOut, string(result)) + if test.expectedOut == "" { + unmodifiedOriginal := unmarshalForTest(t, test.in) + assert.Equal(t, unmodifiedOriginal, node) + } else { + result := marshalForTest(t, &node) + assert.Equal(t, test.expectedOut, result) + } }) } } + +func unmarshalForTest(t *testing.T, input string) yaml.Node { + t.Helper() + var node yaml.Node + err := yaml.Unmarshal([]byte(input), &node) + if err != nil { + t.Fatal(err) + } + return node +} + +func marshalForTest(t *testing.T, node *yaml.Node) string { + t.Helper() + result, err := YamlMarshal(node) + if err != nil { + t.Fatal(err) + } + return string(result) +}