1
0
mirror of https://github.com/go-task/task.git synced 2025-01-22 05:10:17 +02:00
task/internal/deepcopy/deepcopy.go
Pete Davison 5e9851f42f
Update minimum go version (#1758)
* feat: update minimum version to 1.22

* refactor: use int range iterator

* refactor: loop variables

* refactor: replace slicesext.FirstNonZero with cmp.Or

* refactor: use slices.Concat instead of append

* fix: unused param

* fix: linting
2024-08-14 08:37:05 -05:00

142 lines
3.5 KiB
Go

package deepcopy
import (
"reflect"
)
// Copier is implemented by types that know how to produce a deep copy of
// themselves. Slice and Map call DeepCopy on every element that satisfies
// Copier for their element type instead of assigning the element directly.
type Copier[T any] interface {
// DeepCopy returns a copy of the receiver as a value of type T.
DeepCopy() T
}
// Slice returns a new slice holding a copy of every element of orig.
//
// Elements that implement Copier[T] are duplicated through their DeepCopy
// method; all other elements are copied by plain assignment. A nil input
// yields a nil result, while an empty non-nil input yields an empty non-nil
// slice.
func Slice[T any](orig []T) []T {
	if orig == nil {
		return nil
	}
	dup := make([]T, len(orig))
	for idx, elem := range orig {
		switch typed := any(elem).(type) {
		case Copier[T]:
			dup[idx] = typed.DeepCopy()
		default:
			dup[idx] = elem
		}
	}
	return dup
}
// Map returns a new map with the same keys as orig and a copy of each value.
//
// Values that implement Copier[V] are duplicated through their DeepCopy
// method; every other value is copied by plain assignment. Keys are reused
// as-is. A nil input yields a nil result.
func Map[K comparable, V any](orig map[K]V) map[K]V {
	if orig == nil {
		return nil
	}
	dup := make(map[K]V, len(orig))
	for key, val := range orig {
		deep, isCopier := any(val).(Copier[V])
		if !isCopier {
			dup[key] = val
			continue
		}
		dup[key] = deep.DeepCopy()
	}
	return dup
}
// TraverseStringsFunc runs the given function on every string in the given
// value by traversing it recursively. If the given value is a string, the
// function will run on a copy of the string and return it. If the value is a
// struct, map or a slice, the function will recursively call itself for each
// field or element of the struct, map or slice until all strings inside the
// struct or slice are replaced.
func TraverseStringsFunc[T any](v T, fn func(v string) (string, error)) (T, error) {
original := reflect.ValueOf(v)
if original.Kind() == reflect.Invalid || !original.IsValid() {
return v, nil
}
copy := reflect.New(original.Type()).Elem()
var traverseFunc func(copy, v reflect.Value) error
traverseFunc = func(copy, v reflect.Value) error {
switch v.Kind() {
case reflect.Ptr:
// Unwrap the pointer
originalValue := v.Elem()
// If the pointer is nil, do nothing
if !originalValue.IsValid() {
return nil
}
// Create an empty copy from the original value's type
copy.Set(reflect.New(originalValue.Type()))
// Unwrap the newly created pointer and call traverseFunc recursively
if err := traverseFunc(copy.Elem(), originalValue); err != nil {
return err
}
case reflect.Interface:
// Unwrap the interface
originalValue := v.Elem()
if !originalValue.IsValid() {
return nil
}
// Create an empty copy from the original value's type
copyValue := reflect.New(originalValue.Type()).Elem()
// Unwrap the newly created pointer and call traverseFunc recursively
if err := traverseFunc(copyValue, originalValue); err != nil {
return err
}
copy.Set(copyValue)
case reflect.Struct:
// Loop over each field and call traverseFunc recursively
for i := range v.NumField() {
if err := traverseFunc(copy.Field(i), v.Field(i)); err != nil {
return err
}
}
case reflect.Slice:
// Create an empty copy from the original value's type
copy.Set(reflect.MakeSlice(v.Type(), v.Len(), v.Cap()))
// Loop over each element and call traverseFunc recursively
for i := range v.Len() {
if err := traverseFunc(copy.Index(i), v.Index(i)); err != nil {
return err
}
}
case reflect.Map:
// Create an empty copy from the original value's type
copy.Set(reflect.MakeMap(v.Type()))
// Loop over each key
for _, key := range v.MapKeys() {
// Create a copy of each map index
originalValue := v.MapIndex(key)
if originalValue.IsNil() {
continue
}
copyValue := reflect.New(originalValue.Type()).Elem()
// Call traverseFunc recursively
if err := traverseFunc(copyValue, originalValue); err != nil {
return err
}
copy.SetMapIndex(key, copyValue)
}
case reflect.String:
rv, err := fn(v.String())
if err != nil {
return err
}
copy.Set(reflect.ValueOf(rv))
default:
copy.Set(v)
}
return nil
}
if err := traverseFunc(copy, original); err != nil {
return v, err
}
return copy.Interface().(T), nil
}