1
0
mirror of https://github.com/IBM/fp-go.git synced 2025-11-23 22:14:53 +02:00

fix: improve lens generation

Signed-off-by: Dr. Carsten Leue <carsten.leue@de.ibm.com>
This commit is contained in:
Dr. Carsten Leue
2025-11-12 17:28:20 +01:00
parent ca813b673c
commit 6f7ec0768d
5 changed files with 435 additions and 19 deletions

View File

@@ -64,6 +64,7 @@ type fieldInfo struct {
TypeName string
BaseType string // TypeName without leading * for pointer types
IsOptional bool // true if field is a pointer or has json omitempty tag
IsComparable bool // true if the type is comparable (can use ==)
}
// templateData holds data for template rendering
@@ -127,12 +128,19 @@ func Make{{.Name}}RefLenses() {{.Name}}RefLenses {
func(s *{{$.Name}}) O.Option[{{.TypeName}}] { return iso{{.Name}}.Get(s.{{.Name}}) },
func(s *{{$.Name}}, v O.Option[{{.TypeName}}]) *{{$.Name}} { s.{{.Name}} = iso{{.Name}}.ReverseGet(v); return s },
),
{{- else}}
{{- if .IsComparable}}
{{.Name}}: L.MakeLensStrict(
func(s *{{$.Name}}) {{.TypeName}} { return s.{{.Name}} },
func(s *{{$.Name}}, v {{.TypeName}}) *{{$.Name}} { s.{{.Name}} = v; return s },
),
{{- else}}
{{.Name}}: L.MakeLensRef(
func(s *{{$.Name}}) {{.TypeName}} { return s.{{.Name}} },
func(s *{{$.Name}}, v {{.TypeName}}) *{{$.Name}} { s.{{.Name}} = v; return s },
),
{{- end}}
{{- end}}
{{- end}}
}
}
@@ -257,6 +265,111 @@ func isPointerType(expr ast.Expr) bool {
return ok
}
// isComparableType checks if a type expression represents a comparable type.
// Comparable types in Go include:
// - Basic types (bool, numeric types, string)
// - Pointer types
// - Channel types
// - Interface types
// - Structs where all fields are comparable
// - Arrays where the element type is comparable
//
// Non-comparable types include:
// - Slices
// - Maps
// - Functions
func isComparableType(expr ast.Expr) bool {
switch t := expr.(type) {
case *ast.Ident:
// Basic types and named types
// We assume named types are comparable unless they're known non-comparable types
name := t.Name
// Known non-comparable built-in types
if name == "error" {
// error is an interface, which is comparable
return true
}
// Most basic types and named types are comparable
// We can't determine if a custom type is comparable without type checking,
// so we assume it is (conservative approach)
return true
case *ast.StarExpr:
// Pointer types are always comparable
return true
case *ast.ArrayType:
// Arrays are comparable if their element type is comparable
if t.Len == nil {
// This is a slice (no length), slices are not comparable
return false
}
// Fixed-size array, check element type
return isComparableType(t.Elt)
case *ast.MapType:
// Maps are not comparable
return false
case *ast.FuncType:
// Functions are not comparable
return false
case *ast.InterfaceType:
// Interface types are comparable
return true
case *ast.StructType:
// Structs are comparable if all fields are comparable
// We can't easily determine this without full type information,
// so we conservatively return false for struct literals
return false
case *ast.SelectorExpr:
// Qualified identifier (e.g., pkg.Type)
// We can't determine comparability without type information
// Check for known non-comparable types from standard library
if ident, ok := t.X.(*ast.Ident); ok {
pkgName := ident.Name
typeName := t.Sel.Name
// Check for known non-comparable types
if pkgName == "context" && typeName == "Context" {
// context.Context is an interface, which is comparable
return true
}
// For other qualified types, we assume they're comparable
// This is a conservative approach
}
return true
case *ast.IndexExpr, *ast.IndexListExpr:
// Generic types - we can't determine comparability without type information
// For common generic types, we can make educated guesses
var baseExpr ast.Expr
if idx, ok := t.(*ast.IndexExpr); ok {
baseExpr = idx.X
} else if idxList, ok := t.(*ast.IndexListExpr); ok {
baseExpr = idxList.X
}
if sel, ok := baseExpr.(*ast.SelectorExpr); ok {
if ident, ok := sel.X.(*ast.Ident); ok {
pkgName := ident.Name
typeName := sel.Sel.Name
// Check for known non-comparable generic types
if pkgName == "option" && typeName == "Option" {
// Option types are not comparable (they contain a slice internally)
return false
}
if pkgName == "either" && typeName == "Either" {
// Either types are not comparable
return false
}
}
}
// For other generic types, conservatively assume not comparable
return false
case *ast.ChanType:
// Channel types are comparable
return true
default:
// Unknown type, conservatively assume not comparable
return false
}
}
// parseFile parses a Go file and extracts structs with lens annotations
func parseFile(filename string) ([]structInfo, string, error) {
fset := token.NewFileSet()
@@ -331,6 +444,7 @@ func parseFile(filename string) ([]structInfo, string, error) {
typeName := getTypeName(field.Type)
isOptional := false
baseType := typeName
isComparable := false
// Check if field is optional:
// 1. Pointer types are always optional
@@ -344,6 +458,12 @@ func parseFile(filename string) ([]structInfo, string, error) {
isOptional = true
}
// Check if the type is comparable (for non-optional fields)
// For optional fields, we don't need to check since they use LensO
if !isOptional {
isComparable = isComparableType(field.Type)
}
// Extract imports from this field's type
fieldImports := make(map[string]string)
extractImports(field.Type, fieldImports)
@@ -360,6 +480,7 @@ func parseFile(filename string) ([]structInfo, string, error) {
TypeName: typeName,
BaseType: baseType,
IsOptional: isOptional,
IsComparable: isComparable,
})
}
}

View File

@@ -168,6 +168,91 @@ func TestIsPointerType(t *testing.T) {
}
}
func TestIsComparableType(t *testing.T) {
tests := []struct {
name string
code string
expected bool
}{
{
name: "basic type - string",
code: "type T struct { F string }",
expected: true,
},
{
name: "basic type - int",
code: "type T struct { F int }",
expected: true,
},
{
name: "basic type - bool",
code: "type T struct { F bool }",
expected: true,
},
{
name: "pointer type",
code: "type T struct { F *string }",
expected: true,
},
{
name: "slice type - not comparable",
code: "type T struct { F []string }",
expected: false,
},
{
name: "map type - not comparable",
code: "type T struct { F map[string]int }",
expected: false,
},
{
name: "array type - comparable if element is",
code: "type T struct { F [5]int }",
expected: true,
},
{
name: "interface type",
code: "type T struct { F interface{} }",
expected: true,
},
{
name: "channel type",
code: "type T struct { F chan int }",
expected: true,
},
{
name: "function type - not comparable",
code: "type T struct { F func() }",
expected: false,
},
{
name: "struct literal - conservatively not comparable",
code: "type T struct { F struct{ X int } }",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "", "package test\n"+tt.code, 0)
require.NoError(t, err)
var fieldType ast.Expr
ast.Inspect(file, func(n ast.Node) bool {
if field, ok := n.(*ast.Field); ok && len(field.Names) > 0 {
fieldType = field.Type
return false
}
return true
})
require.NotNil(t, fieldType)
result := isComparableType(fieldType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestHasOmitEmpty(t *testing.T) {
tests := []struct {
name string
@@ -337,6 +422,171 @@ type Config struct {
assert.False(t, config.Fields[4].IsOptional, "Required field without omitempty should not be optional")
}
func TestParseFileWithComparableTypes(t *testing.T) {
// Create a temporary test file
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
testCode := `package testpkg
// fp-go:Lens
type TypeTest struct {
Name string
Age int
Pointer *string
Slice []string
Map map[string]int
Channel chan int
}
`
err := os.WriteFile(testFile, []byte(testCode), 0644)
require.NoError(t, err)
// Parse the file
structs, pkg, err := parseFile(testFile)
require.NoError(t, err)
// Verify results
assert.Equal(t, "testpkg", pkg)
assert.Len(t, structs, 1)
// Check TypeTest struct
typeTest := structs[0]
assert.Equal(t, "TypeTest", typeTest.Name)
assert.Len(t, typeTest.Fields, 6)
// Name - string is comparable
assert.Equal(t, "Name", typeTest.Fields[0].Name)
assert.Equal(t, "string", typeTest.Fields[0].TypeName)
assert.False(t, typeTest.Fields[0].IsOptional)
assert.True(t, typeTest.Fields[0].IsComparable, "string should be comparable")
// Age - int is comparable
assert.Equal(t, "Age", typeTest.Fields[1].Name)
assert.Equal(t, "int", typeTest.Fields[1].TypeName)
assert.False(t, typeTest.Fields[1].IsOptional)
assert.True(t, typeTest.Fields[1].IsComparable, "int should be comparable")
// Pointer - pointer is optional, IsComparable not checked for optional fields
assert.Equal(t, "Pointer", typeTest.Fields[2].Name)
assert.Equal(t, "*string", typeTest.Fields[2].TypeName)
assert.True(t, typeTest.Fields[2].IsOptional)
assert.False(t, typeTest.Fields[2].IsComparable, "IsComparable not set for optional fields")
// Slice - not comparable
assert.Equal(t, "Slice", typeTest.Fields[3].Name)
assert.Equal(t, "[]string", typeTest.Fields[3].TypeName)
assert.False(t, typeTest.Fields[3].IsOptional)
assert.False(t, typeTest.Fields[3].IsComparable, "slice should not be comparable")
// Map - not comparable
assert.Equal(t, "Map", typeTest.Fields[4].Name)
assert.Equal(t, "map[string]int", typeTest.Fields[4].TypeName)
assert.False(t, typeTest.Fields[4].IsOptional)
assert.False(t, typeTest.Fields[4].IsComparable, "map should not be comparable")
// Channel - comparable (note: getTypeName returns "any" for channel types, but isComparableType correctly identifies them)
assert.Equal(t, "Channel", typeTest.Fields[5].Name)
assert.Equal(t, "any", typeTest.Fields[5].TypeName) // getTypeName doesn't handle chan types specifically
assert.False(t, typeTest.Fields[5].IsOptional)
assert.True(t, typeTest.Fields[5].IsComparable, "channel should be comparable")
}
func TestLensRefTemplatesWithComparable(t *testing.T) {
s := structInfo{
Name: "TestStruct",
Fields: []fieldInfo{
{Name: "Name", TypeName: "string", IsOptional: false, IsComparable: true},
{Name: "Age", TypeName: "int", IsOptional: false, IsComparable: true},
{Name: "Data", TypeName: "[]byte", IsOptional: false, IsComparable: false},
{Name: "Pointer", TypeName: "*string", IsOptional: true, IsComparable: false},
},
}
// Test constructor template for RefLenses
var constructorBuf bytes.Buffer
err := constructorTmpl.Execute(&constructorBuf, s)
require.NoError(t, err)
constructorStr := constructorBuf.String()
// Check that MakeLensStrict is used for comparable types in RefLenses
assert.Contains(t, constructorStr, "func MakeTestStructRefLenses() TestStructRefLenses")
// Name field - comparable, should use MakeLensStrict
assert.Contains(t, constructorStr, "Name: L.MakeLensStrict(",
"comparable field Name should use MakeLensStrict in RefLenses")
// Age field - comparable, should use MakeLensStrict
assert.Contains(t, constructorStr, "Age: L.MakeLensStrict(",
"comparable field Age should use MakeLensStrict in RefLenses")
// Data field - not comparable, should use MakeLensRef
assert.Contains(t, constructorStr, "Data: L.MakeLensRef(",
"non-comparable field Data should use MakeLensRef in RefLenses")
// Pointer field - optional, should use MakeLensRef
assert.Contains(t, constructorStr, "Pointer: L.MakeLensRef(",
"optional field Pointer should use MakeLensRef in RefLenses")
}
func TestGenerateLensHelpersWithComparable(t *testing.T) {
// Create a temporary directory with test files
tmpDir := t.TempDir()
testCode := `package testpkg
// fp-go:Lens
type TestStruct struct {
Name string
Count int
Data []byte
}
`
testFile := filepath.Join(tmpDir, "test.go")
err := os.WriteFile(testFile, []byte(testCode), 0644)
require.NoError(t, err)
// Generate lens code
outputFile := "gen.go"
err = generateLensHelpers(tmpDir, outputFile, false)
require.NoError(t, err)
// Verify the generated file exists
genPath := filepath.Join(tmpDir, outputFile)
_, err = os.Stat(genPath)
require.NoError(t, err)
// Read and verify the generated content
content, err := os.ReadFile(genPath)
require.NoError(t, err)
contentStr := string(content)
// Check for expected content in RefLenses
assert.Contains(t, contentStr, "MakeTestStructRefLenses")
// Name and Count are comparable, should use MakeLensStrict
assert.Contains(t, contentStr, "L.MakeLensStrict",
"comparable fields should use MakeLensStrict in RefLenses")
// Data is not comparable (slice), should use MakeLensRef
assert.Contains(t, contentStr, "L.MakeLensRef",
"non-comparable fields should use MakeLensRef in RefLenses")
// Verify the pattern appears for Name field (comparable)
namePattern := "Name: L.MakeLensStrict("
assert.Contains(t, contentStr, namePattern,
"Name field should use MakeLensStrict")
// Verify the pattern appears for Data field (not comparable)
dataPattern := "Data: L.MakeLensRef("
assert.Contains(t, contentStr, dataPattern,
"Data field should use MakeLensRef")
}
func TestGenerateLensHelpers(t *testing.T) {
// Create a temporary directory with test files
tmpDir := t.TempDir()

View File

@@ -17,6 +17,8 @@ package lens
import "github.com/IBM/fp-go/v2/optics/lens/option"
//go:generate go run ../../main.go lens --dir . --filename gen_lens.go
// fp-go:Lens
type Person struct {
Name string

View File

@@ -153,3 +153,46 @@ func TestLensComposition(t *testing.T) {
assert.Equal(t, 55, olderCEO.CEO.Age)
assert.Equal(t, 50, company.CEO.Age) // Original unchanged
}
func TestPersonRefLensesIdempotent(t *testing.T) {
// Create a person pointer
person := &Person{
Name: "Alice",
Age: 30,
Email: "alice@example.com",
}
// Create ref lenses
refLenses := MakePersonRefLenses()
// Test that setting the same value returns the identical pointer (idempotent)
// This works because Name, Age, and Email use MakeLensStrict which has equality optimization
// Test Name field - setting same value should return same pointer
sameName := refLenses.Name.Set("Alice")(person)
assert.Same(t, person, sameName, "Setting Name to same value should return identical pointer")
// Test Age field - setting same value should return same pointer
sameAge := refLenses.Age.Set(30)(person)
assert.Same(t, person, sameAge, "Setting Age to same value should return identical pointer")
// Test Email field - setting same value should return same pointer
sameEmail := refLenses.Email.Set("alice@example.com")(person)
assert.Same(t, person, sameEmail, "Setting Email to same value should return identical pointer")
// Test that setting a different value creates a new pointer
differentName := refLenses.Name.Set("Bob")(person)
assert.NotSame(t, person, differentName, "Setting Name to different value should return new pointer")
assert.Equal(t, "Bob", differentName.Name)
assert.Equal(t, "Alice", person.Name, "Original should be unchanged")
differentAge := refLenses.Age.Set(31)(person)
assert.NotSame(t, person, differentAge, "Setting Age to different value should return new pointer")
assert.Equal(t, 31, differentAge.Age)
assert.Equal(t, 30, person.Age, "Original should be unchanged")
differentEmail := refLenses.Email.Set("bob@example.com")(person)
assert.NotSame(t, person, differentEmail, "Setting Email to different value should return new pointer")
assert.Equal(t, "bob@example.com", differentEmail.Email)
assert.Equal(t, "alice@example.com", person.Email, "Original should be unchanged")
}

View File

@@ -2,7 +2,7 @@ package lens
// Code generated by go generate; DO NOT EDIT.
// This file was generated by robots at
// 2025-11-07 16:52:17.4935733 +0100 CET m=+0.003883901
// 2025-11-12 17:16:40.1431921 +0100 CET m=+0.003694701
import (
L "github.com/IBM/fp-go/v2/optics/lens"
@@ -55,15 +55,15 @@ func MakePersonLenses() PersonLenses {
func MakePersonRefLenses() PersonRefLenses {
isoPhone := I.FromZero[*string]()
return PersonRefLenses{
Name: L.MakeLensRef(
Name: L.MakeLensStrict(
func(s *Person) string { return s.Name },
func(s *Person, v string) *Person { s.Name = v; return s },
),
Age: L.MakeLensRef(
Age: L.MakeLensStrict(
func(s *Person) int { return s.Age },
func(s *Person, v int) *Person { s.Age = v; return s },
),
Email: L.MakeLensRef(
Email: L.MakeLensStrict(
func(s *Person) string { return s.Email },
func(s *Person, v string) *Person { s.Email = v; return s },
),
@@ -123,19 +123,19 @@ func MakeAddressLenses() AddressLenses {
func MakeAddressRefLenses() AddressRefLenses {
isoState := I.FromZero[*string]()
return AddressRefLenses{
Street: L.MakeLensRef(
Street: L.MakeLensStrict(
func(s *Address) string { return s.Street },
func(s *Address, v string) *Address { s.Street = v; return s },
),
City: L.MakeLensRef(
City: L.MakeLensStrict(
func(s *Address) string { return s.City },
func(s *Address, v string) *Address { s.City = v; return s },
),
ZipCode: L.MakeLensRef(
ZipCode: L.MakeLensStrict(
func(s *Address) string { return s.ZipCode },
func(s *Address, v string) *Address { s.ZipCode = v; return s },
),
Country: L.MakeLensRef(
Country: L.MakeLensStrict(
func(s *Address) string { return s.Country },
func(s *Address, v string) *Address { s.Country = v; return s },
),
@@ -189,15 +189,15 @@ func MakeCompanyLenses() CompanyLenses {
func MakeCompanyRefLenses() CompanyRefLenses {
isoWebsite := I.FromZero[*string]()
return CompanyRefLenses{
Name: L.MakeLensRef(
Name: L.MakeLensStrict(
func(s *Company) string { return s.Name },
func(s *Company, v string) *Company { s.Name = v; return s },
),
Address: L.MakeLensRef(
Address: L.MakeLensStrict(
func(s *Company) Address { return s.Address },
func(s *Company, v Address) *Company { s.Address = v; return s },
),
CEO: L.MakeLensRef(
CEO: L.MakeLensStrict(
func(s *Company) Person { return s.CEO },
func(s *Company, v Person) *Company { s.CEO = v; return s },
),