2025-11-07 16:15:16 +01:00
// Copyright (c) 2023 - 2025 IBM Corp.
// All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cli
import (
	"bytes"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
	"log"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"text/template"

	C "github.com/urfave/cli/v2"
)
// keyLensDir is the CLI flag name selecting the directory to scan.
const keyLensDir = "dir"

// keyVerbose is the CLI flag name that enables verbose logging.
const keyVerbose = "verbose"

// lensAnnotation is the marker that must appear in a struct's doc comment for
// lenses to be generated for it.
const lensAnnotation = "fp-go:Lens"
// flagLensDir selects the directory whose Go files are scanned; it defaults to
// the current directory.
var flagLensDir = &C.StringFlag{
	Name:  keyLensDir,
	Value: ".",
	Usage: "Directory to scan for Go files",
}

// flagVerbose turns on progress logging. It is also available as -v and is off
// by default (the zero value of Value).
var flagVerbose = &C.BoolFlag{
	Name:    keyVerbose,
	Aliases: []string{"v"},
	Usage:   "Enable verbose output",
}
// structInfo describes one annotated struct for which lens code is generated.
type structInfo struct {
	// Name is the struct's type name.
	Name string
	// Fields lists the fields that receive a lens.
	Fields []fieldInfo
	// Imports maps each import path referenced by the field types to the
	// package name used to qualify it.
	Imports map[string]string
}
// fieldInfo describes a single struct field for which a lens is generated.
type fieldInfo struct {
	// Name is the field name.
	Name string
	// TypeName is the field's type as Go source text, e.g. "*string".
	TypeName string
	// BaseType is TypeName with the leading "*" removed for pointer types;
	// for all other types it equals TypeName.
	BaseType string
	// IsOptional is true for pointer fields and for fields whose json tag has
	// the omitempty option; such fields get an optional lens (LensO).
	IsOptional bool
	// IsComparable is true when the type is (heuristically) comparable with ==.
	IsComparable bool
}
// templateData groups a package name with the annotated structs rendered for
// that package.
type templateData struct {
	// PackageName is the name of the Go package the code is generated for.
	PackageName string
	// Structs are the annotated structs to render.
	Structs []structInfo
}
// lensStructTemplate is executed once per structInfo (see the struct template
// execution in generateLensHelpers) and renders two struct types:
//   - <Name>Lenses, holding one lens per field whose source is the struct value;
//   - <Name>RefLenses, holding one lens per field whose source is *<Name>.
//
// Fields with IsOptional set are typed as LO.LensO[..., TypeName]; all other
// fields are typed as L.Lens[..., TypeName]. The text between the backquotes
// is template source and is emitted verbatim apart from the {{...}} actions.
const lensStructTemplate = `
// {{.Name}}Lenses provides lenses for accessing fields of {{.Name}}
type { { . Name } } Lenses struct {
{ { - range . Fields } }
{ { . Name } } { { if . IsOptional } } LO . LensO [ { { $ . Name } } , { { . TypeName } } ] { { else } } L . Lens [ { { $ . Name } } , { { . TypeName } } ] { { end } }
{ { - end } }
}
// {{.Name}}RefLenses provides lenses for accessing fields of {{.Name}} via a reference to {{.Name}}
type { { . Name } } RefLenses struct {
{ { - range . Fields } }
{ { . Name } } { { if . IsOptional } } LO . LensO [ * { { $ . Name } } , { { . TypeName } } ] { { else } } L . Lens [ * { { $ . Name } } , { { . TypeName } } ] { { end } }
{ { - end } }
}
`
// lensConstructorTemplate is executed once per structInfo and renders the two
// constructors matching the types from lensStructTemplate.
//
// Make<Name>Lenses builds value-based lenses with L.MakeLens. For every
// IsOptional field it first declares iso<Field> := I.FromZero[TypeName]() and
// uses its Get/ReverseGet to convert between the raw field and O.Option.
// NOTE(review): presumably FromZero maps the zero value to None; confirm
// against the optics/iso/option package.
//
// Make<Name>RefLenses builds pointer-based lenses whose setters assign the
// field on *s and return the same pointer:
//   - IsComparable fields use L.MakeLensStrict;
//   - all other fields use L.MakeLensRef;
//   - optional fields are additionally lifted with
//     LO.FromIso[*Name](I.FromZero[TypeName]()).
//
// NOTE(review): if I.FromZero constrains its type parameter to comparable,
// optional fields of non-comparable type (e.g. a slice with omitempty) would
// produce code that does not compile. Verify.
const lensConstructorTemplate = `
// Make{{.Name}}Lenses creates a new {{.Name}}Lenses with lenses for all fields
func Make { { . Name } } Lenses ( ) { { . Name } } Lenses {
{ { - range . Fields } }
{ { - if . IsOptional } }
2025-11-07 17:31:27 +01:00
iso { { . Name } } := I . FromZero [ { { . TypeName } } ] ( )
2025-11-07 16:15:16 +01:00
{ { - end } }
{ { - end } }
return { { . Name } } Lenses {
{ { - range . Fields } }
{ { - if . IsOptional } }
{ { . Name } } : L . MakeLens (
2025-11-07 17:31:27 +01:00
func ( s { { $ . Name } } ) O . Option [ { { . TypeName } } ] { return iso { { . Name } } . Get ( s . { { . Name } } ) } ,
func ( s { { $ . Name } } , v O . Option [ { { . TypeName } } ] ) { { $ . Name } } { s . { { . Name } } = iso { { . Name } } . ReverseGet ( v ) ; return s } ,
2025-11-07 16:15:16 +01:00
) ,
{ { - else } }
{ { . Name } } : L . MakeLens (
func ( s { { $ . Name } } ) { { . TypeName } } { return s . { { . Name } } } ,
func ( s { { $ . Name } } , v { { . TypeName } } ) { { $ . Name } } { s . { { . Name } } = v ; return s } ,
) ,
{ { - end } }
{ { - end } }
}
}
// Make{{.Name}}RefLenses creates a new {{.Name}}RefLenses with lenses for all fields
func Make { { . Name } } RefLenses ( ) { { . Name } } RefLenses {
return { { . Name } } RefLenses {
{ { - range . Fields } }
{ { - if . IsOptional } }
2025-11-12 18:23:57 +01:00
{ { - if . IsComparable } }
{ { . Name } } : LO . FromIso [ * { { $ . Name } } ] ( I . FromZero [ { { . TypeName } } ] ( ) ) ( L . MakeLensStrict (
func ( s * { { $ . Name } } ) { { . TypeName } } { return s . { { . Name } } } ,
func ( s * { { $ . Name } } , v { { . TypeName } } ) * { { $ . Name } } { s . { { . Name } } = v ; return s } ,
) ) ,
{ { - else } }
{ { . Name } } : LO . FromIso [ * { { $ . Name } } ] ( I . FromZero [ { { . TypeName } } ] ( ) ) ( L . MakeLensRef (
func ( s * { { $ . Name } } ) { { . TypeName } } { return s . { { . Name } } } ,
func ( s * { { $ . Name } } , v { { . TypeName } } ) * { { $ . Name } } { s . { { . Name } } = v ; return s } ,
) ) ,
{ { - end } }
2025-11-12 17:28:20 +01:00
{ { - else } }
{ { - if . IsComparable } }
{ { . Name } } : L . MakeLensStrict (
func ( s * { { $ . Name } } ) { { . TypeName } } { return s . { { . Name } } } ,
func ( s * { { $ . Name } } , v { { . TypeName } } ) * { { $ . Name } } { s . { { . Name } } = v ; return s } ,
) ,
2025-11-07 16:15:16 +01:00
{ { - else } }
{ { . Name } } : L . MakeLensRef (
func ( s * { { $ . Name } } ) { { . TypeName } } { return s . { { . Name } } } ,
func ( s * { { $ . Name } } , v { { . TypeName } } ) * { { $ . Name } } { s . { { . Name } } = v ; return s } ,
) ,
{ { - end } }
2025-11-12 17:28:20 +01:00
{ { - end } }
2025-11-07 16:15:16 +01:00
{ { - end } }
}
}
`
// structTmpl and constructorTmpl hold the parsed lensStructTemplate and
// lensConstructorTemplate. They are parsed once while the package initializes.
// A template that fails to parse is a programming error, so template.Must
// panics with the parse error.
var (
	structTmpl      = template.Must(template.New("struct").Parse(lensStructTemplate))
	constructorTmpl = template.Must(template.New("constructor").Parse(lensConstructorTemplate))
)
// hasLensAnnotation reports whether any comment in doc contains the
// lensAnnotation marker. A nil comment group never carries the annotation.
func hasLensAnnotation(doc *ast.CommentGroup) bool {
	found := false
	if doc != nil {
		for i := 0; i < len(doc.List) && !found; i++ {
			found = strings.Contains(doc.List[i].Text, lensAnnotation)
		}
	}
	return found
}
// getTypeName renders a field type expression back to Go source text. Examples:
// "*string", "[3]int", "map[string]int", "option.Option[string]",
// "Pair[string, int]", "chan int" and "func(ctx context.Context) error".
//
// Rendering is delegated to go/types.ExprString, which covers every type
// expression the parser can produce: fixed-size arrays keep their length, and
// channel, function, inline struct and interface types are spelled out. A
// hand-written subset would have to approximate those forms (e.g. as "any"),
// which yields generated lens code that does not compile. A nil expression
// renders as "any".
func getTypeName(expr ast.Expr) string {
	if expr == nil {
		return "any"
	}
	return types.ExprString(expr)
}
// extractImports records in imports every package qualifier used anywhere in
// the type expression expr. Both key and value are the qualifier name, e.g.
// option.Option[time.Time] adds "option" and "time". Callers resolve these
// names to import paths.
//
// The whole expression tree is walked. Qualifiers nested in function
// signatures, channel element types, inline struct/interface types and array
// lengths are therefore found as well; the rendered type name (getTypeName)
// contains them too. Names that are not packages are harmless here because
// callers only keep names that match a file import. A nil expr is a no-op.
func extractImports(expr ast.Expr, imports map[string]string) {
	if expr == nil {
		// ast.Inspect panics on a nil node.
		return
	}
	ast.Inspect(expr, func(n ast.Node) bool {
		if sel, ok := n.(*ast.SelectorExpr); ok {
			if pkg, ok := sel.X.(*ast.Ident); ok {
				imports[pkg.Name] = pkg.Name
			}
		}
		return true
	})
}
// hasOmitEmpty reports whether a struct field tag carries the omitempty option
// on its json key. Examples that match: `json:"name,omitempty"` and
// `json:",omitempty"`.
//
// As in encoding/json, the first comma-separated element of the json tag is the
// field name, not an option. So `json:"omitempty"` (a field renamed to
// "omitempty") does not count. Both raw (`...`) and interpreted ("...") tag
// literals are accepted. Options are compared after trimming surrounding
// spaces. A nil tag or an unparsable literal yields false.
func hasOmitEmpty(tag *ast.BasicLit) bool {
	if tag == nil {
		return false
	}
	tagValue, err := strconv.Unquote(tag.Value)
	if err != nil {
		return false
	}
	jsonTag := reflect.StructTag(tagValue).Get("json")
	// Drop the name element; only what follows the first comma are options.
	_, opts, found := strings.Cut(jsonTag, ",")
	if !found {
		return false
	}
	for _, opt := range strings.Split(opts, ",") {
		if strings.TrimSpace(opt) == "omitempty" {
			return true
		}
	}
	return false
}
// isPointerType reports whether expr is written as a pointer type, i.e. *T.
func isPointerType(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.StarExpr:
		return true
	default:
		return false
	}
}
2025-11-12 17:28:20 +01:00
// isComparableType is a syntactic heuristic deciding whether a field type can
// be compared with ==. It only looks at the AST, with no type checking, so
// named types cannot be resolved.
//
// Treated as comparable:
//   - plain identifiers (basic types and named types, assumed comparable);
//   - qualified identifiers such as pkg.Type (assumed comparable);
//   - pointer, channel and interface types;
//   - fixed-size arrays whose element type is comparable by this same rule.
//
// Treated as not comparable:
//   - slices, maps and function types;
//   - inline struct literals (their fields are not inspected);
//   - generic instantiations such as option.Option[T] or Map[K, V];
//   - any other expression form.
func isComparableType(expr ast.Expr) bool {
	switch t := expr.(type) {
	case *ast.Ident, *ast.SelectorExpr, *ast.StarExpr, *ast.InterfaceType, *ast.ChanType:
		return true
	case *ast.ArrayType:
		// A missing length means a slice, which is never comparable.
		return t.Len != nil && isComparableType(t.Elt)
	default:
		return false
	}
}
2025-11-07 16:15:16 +01:00
// parseFile parses a Go file and extracts structs with lens annotations
func parseFile ( filename string ) ( [ ] structInfo , string , error ) {
fset := token . NewFileSet ( )
node , err := parser . ParseFile ( fset , filename , nil , parser . ParseComments )
if err != nil {
return nil , "" , err
}
var structs [ ] structInfo
packageName := node . Name . Name
// Build import map: package name -> import path
fileImports := make ( map [ string ] string )
for _ , imp := range node . Imports {
path := strings . Trim ( imp . Path . Value , ` " ` )
var name string
if imp . Name != nil {
name = imp . Name . Name
} else {
// Extract package name from path (last component)
parts := strings . Split ( path , "/" )
name = parts [ len ( parts ) - 1 ]
}
fileImports [ name ] = path
}
// First pass: collect all GenDecls with their doc comments
declMap := make ( map [ * ast . TypeSpec ] * ast . CommentGroup )
ast . Inspect ( node , func ( n ast . Node ) bool {
if gd , ok := n . ( * ast . GenDecl ) ; ok {
for _ , spec := range gd . Specs {
if ts , ok := spec . ( * ast . TypeSpec ) ; ok {
declMap [ ts ] = gd . Doc
}
}
}
return true
} )
// Second pass: process type specs
ast . Inspect ( node , func ( n ast . Node ) bool {
// Look for type declarations
typeSpec , ok := n . ( * ast . TypeSpec )
if ! ok {
return true
}
// Check if it's a struct type
structType , ok := typeSpec . Type . ( * ast . StructType )
if ! ok {
return true
}
// Get the doc comment from our map
doc := declMap [ typeSpec ]
if ! hasLensAnnotation ( doc ) {
return true
}
// Extract field information and collect imports
var fields [ ] fieldInfo
structImports := make ( map [ string ] string )
for _ , field := range structType . Fields . List {
if len ( field . Names ) == 0 {
// Embedded field, skip for now
continue
}
for _ , name := range field . Names {
// Only export lenses for exported fields
if name . IsExported ( ) {
typeName := getTypeName ( field . Type )
isOptional := false
baseType := typeName
2025-11-12 17:28:20 +01:00
isComparable := false
2025-11-07 16:15:16 +01:00
2025-11-07 17:31:27 +01:00
// Check if field is optional:
// 1. Pointer types are always optional
// 2. Non-pointer types with json omitempty tag are optional
2025-11-07 16:15:16 +01:00
if isPointerType ( field . Type ) {
isOptional = true
// Strip leading * for base type
baseType = strings . TrimPrefix ( typeName , "*" )
2025-11-07 17:31:27 +01:00
} else if hasOmitEmpty ( field . Tag ) {
// Non-pointer type with omitempty is also optional
isOptional = true
2025-11-07 16:15:16 +01:00
}
2025-11-12 17:28:20 +01:00
// Check if the type is comparable (for non-optional fields)
// For optional fields, we don't need to check since they use LensO
2025-11-12 18:23:57 +01:00
isComparable = isComparableType ( field . Type )
2025-11-12 17:28:20 +01:00
2025-11-07 16:15:16 +01:00
// Extract imports from this field's type
fieldImports := make ( map [ string ] string )
extractImports ( field . Type , fieldImports )
// Resolve package names to full import paths
for pkgName := range fieldImports {
if importPath , ok := fileImports [ pkgName ] ; ok {
structImports [ importPath ] = pkgName
}
}
fields = append ( fields , fieldInfo {
2025-11-12 17:28:20 +01:00
Name : name . Name ,
TypeName : typeName ,
BaseType : baseType ,
IsOptional : isOptional ,
IsComparable : isComparable ,
2025-11-07 16:15:16 +01:00
} )
}
}
}
if len ( fields ) > 0 {
structs = append ( structs , structInfo {
Name : typeSpec . Name . Name ,
Fields : fields ,
Imports : structImports ,
} )
}
return true
} )
return structs , packageName , nil
}
// generateLensHelpers scans a directory for Go files and generates lens code
func generateLensHelpers ( dir , filename string , verbose bool ) error {
// Get absolute path
absDir , err := filepath . Abs ( dir )
if err != nil {
return err
}
if verbose {
log . Printf ( "Scanning directory: %s" , absDir )
}
// Find all Go files in the directory
files , err := filepath . Glob ( filepath . Join ( absDir , "*.go" ) )
if err != nil {
return err
}
if verbose {
log . Printf ( "Found %d Go files" , len ( files ) )
}
// Parse all files and collect structs
var allStructs [ ] structInfo
var packageName string
for _ , file := range files {
// Skip generated files and test files
if strings . HasSuffix ( file , "_test.go" ) || strings . Contains ( file , "gen.go" ) {
if verbose {
log . Printf ( "Skipping file: %s" , filepath . Base ( file ) )
}
continue
}
if verbose {
log . Printf ( "Parsing file: %s" , filepath . Base ( file ) )
}
structs , pkg , err := parseFile ( file )
if err != nil {
log . Printf ( "Warning: failed to parse %s: %v" , file , err )
continue
}
if verbose && len ( structs ) > 0 {
log . Printf ( "Found %d annotated struct(s) in %s" , len ( structs ) , filepath . Base ( file ) )
for _ , s := range structs {
log . Printf ( " - %s (%d fields)" , s . Name , len ( s . Fields ) )
}
}
if packageName == "" {
packageName = pkg
}
allStructs = append ( allStructs , structs ... )
}
if len ( allStructs ) == 0 {
log . Printf ( "No structs with %s annotation found in %s" , lensAnnotation , absDir )
return nil
}
// Collect all unique imports from all structs
allImports := make ( map [ string ] string ) // import path -> alias
for _ , s := range allStructs {
for importPath , alias := range s . Imports {
allImports [ importPath ] = alias
}
}
// Create output file
outPath := filepath . Join ( absDir , filename )
f , err := os . Create ( filepath . Clean ( outPath ) )
if err != nil {
return err
}
defer f . Close ( )
log . Printf ( "Generating lens code in [%s] for package [%s] with [%d] structs ..." , outPath , packageName , len ( allStructs ) )
// Write header
writePackage ( f , packageName )
// Write imports
f . WriteString ( "import (\n" )
// Standard fp-go imports always needed
f . WriteString ( "\tL \"github.com/IBM/fp-go/v2/optics/lens\"\n" )
f . WriteString ( "\tLO \"github.com/IBM/fp-go/v2/optics/lens/option\"\n" )
f . WriteString ( "\tO \"github.com/IBM/fp-go/v2/option\"\n" )
2025-11-07 17:31:27 +01:00
f . WriteString ( "\tI \"github.com/IBM/fp-go/v2/optics/iso/option\"\n" )
2025-11-07 16:15:16 +01:00
// Add additional imports collected from field types
for importPath , alias := range allImports {
f . WriteString ( "\t" + alias + " \"" + importPath + "\"\n" )
}
f . WriteString ( ")\n" )
// Generate lens code for each struct using templates
for _ , s := range allStructs {
var buf bytes . Buffer
// Generate struct type
if err := structTmpl . Execute ( & buf , s ) ; err != nil {
return err
}
// Generate constructor
if err := constructorTmpl . Execute ( & buf , s ) ; err != nil {
return err
}
// Write to file
if _ , err := f . Write ( buf . Bytes ( ) ) ; err != nil {
return err
}
}
return nil
}
// LensCommand creates the CLI command for lens generation
func LensCommand ( ) * C . Command {
return & C . Command {
Name : "lens" ,
Usage : "generate lens code for annotated structs" ,
2025-11-07 17:31:27 +01:00
Description : "Scans Go files for structs annotated with 'fp-go:Lens' and generates lens types. Pointer types and non-pointer types with json omitempty tag generate LensO (optional lens)." ,
2025-11-07 16:15:16 +01:00
Flags : [ ] C . Flag {
flagLensDir ,
flagFilename ,
flagVerbose ,
} ,
Action : func ( ctx * C . Context ) error {
return generateLensHelpers (
ctx . String ( keyLensDir ) ,
ctx . String ( keyFilename ) ,
ctx . Bool ( keyVerbose ) ,
)
} ,
}
}