mirror of
https://github.com/mgechev/revive.git
synced 2025-03-03 14:52:54 +02:00
chore: Improve sortables detection (#1151)
Co-authored-by: chavacava <salvador.cavadini@gmail.com>
This commit is contained in:
parent
72b91f0188
commit
cb74ccbf44
60
internal/astutils/ast_utils.go
Normal file
60
internal/astutils/ast_utils.go
Normal file
@ -0,0 +1,60 @@
|
||||
package astutils
|
||||
|
||||
import "go/ast"
|
||||
|
||||
// FuncSignatureIs returns true if the given func decl satisfies a signature characterized
|
||||
// by the given name, parameters types and return types; false otherwise.
|
||||
//
|
||||
// Example: to check if a function declaration has the signature Foo(int, string) (bool,error)
|
||||
// call to FuncSignatureIs(funcDecl,"Foo",[]string{"int","string"},[]string{"bool","error"})
|
||||
func FuncSignatureIs(funcDecl *ast.FuncDecl, wantName string, wantParametersTypes, wantResultsTypes []string) bool {
|
||||
if wantName != funcDecl.Name.String() {
|
||||
return false // func name doesn't match expected one
|
||||
}
|
||||
|
||||
funcParametersTypes := getTypeNames(funcDecl.Type.Params)
|
||||
if len(wantParametersTypes) != len(funcParametersTypes) {
|
||||
return false // func has not the expected number of parameters
|
||||
}
|
||||
|
||||
funcResultsTypes := getTypeNames(funcDecl.Type.Results)
|
||||
if len(wantResultsTypes) != len(funcResultsTypes) {
|
||||
return false // func has not the expected number of return values
|
||||
}
|
||||
|
||||
for i, wantType := range wantParametersTypes {
|
||||
if wantType != funcParametersTypes[i] {
|
||||
return false // type of a func's parameter does not match the type of the corresponding expected parameter
|
||||
}
|
||||
}
|
||||
|
||||
for i, wantType := range wantResultsTypes {
|
||||
if wantType != funcResultsTypes[i] {
|
||||
return false // type of a func's return value does not match the type of the corresponding expected return value
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getTypeNames(fields *ast.FieldList) []string {
|
||||
result := []string{}
|
||||
|
||||
if fields == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, field := range fields.List {
|
||||
typeName := field.Type.(*ast.Ident).Name
|
||||
if field.Names == nil { // unnamed field
|
||||
result = append(result, typeName)
|
||||
continue
|
||||
}
|
||||
|
||||
for range field.Names { // add one type name for each field name
|
||||
result = append(result, typeName)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package lint
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"go/ast"
|
||||
"go/importer"
|
||||
"go/token"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/mgechev/revive/internal/astutils"
|
||||
"github.com/mgechev/revive/internal/typeparams"
|
||||
)
|
||||
|
||||
@ -31,7 +33,6 @@ type Package struct {
|
||||
var (
|
||||
trueValue = 1
|
||||
falseValue = 2
|
||||
notSet = 3
|
||||
|
||||
go121 = goversion.Must(goversion.NewVersion("1.21"))
|
||||
go122 = goversion.Must(goversion.NewVersion("1.22"))
|
||||
@ -111,6 +112,11 @@ func (p *Package) TypeCheck() error {
|
||||
astFiles = append(astFiles, f.AST)
|
||||
}
|
||||
|
||||
if anyFile == nil {
|
||||
// this is unlikely to happen, but technically guarantees anyFile to not be nil
|
||||
return errors.New("no ast.File found")
|
||||
}
|
||||
|
||||
typesPkg, err := check(config, anyFile.AST.Name.Name, p.fset, astFiles, info)
|
||||
|
||||
// Remember the typechecking info, even if config.Check failed,
|
||||
@ -135,7 +141,7 @@ func check(config *types.Config, n string, fset *token.FileSet, astFiles []*ast.
|
||||
return config.Check(n, fset, astFiles, info)
|
||||
}
|
||||
|
||||
// TypeOf returns the type of an expression.
|
||||
// TypeOf returns the type of expression.
|
||||
func (p *Package) TypeOf(expr ast.Expr) types.Type {
|
||||
if p.typesInfo == nil {
|
||||
return nil
|
||||
@ -143,39 +149,32 @@ func (p *Package) TypeOf(expr ast.Expr) types.Type {
|
||||
return p.typesInfo.TypeOf(expr)
|
||||
}
|
||||
|
||||
type walker struct {
|
||||
nmap map[string]int
|
||||
has map[string]int
|
||||
}
|
||||
type sortableMethodsFlags int
|
||||
|
||||
func (w *walker) Visit(n ast.Node) ast.Visitor {
|
||||
fn, ok := n.(*ast.FuncDecl)
|
||||
if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 {
|
||||
return w
|
||||
}
|
||||
// TODO(dsymonds): We could check the signature to be more precise.
|
||||
recv := typeparams.ReceiverType(fn)
|
||||
if i, ok := w.nmap[fn.Name.Name]; ok {
|
||||
w.has[recv] |= i
|
||||
}
|
||||
return w
|
||||
}
|
||||
// flags for sortable interface methods.
|
||||
const (
|
||||
bfLen sortableMethodsFlags = 1 << iota
|
||||
bfLess
|
||||
bfSwap
|
||||
)
|
||||
|
||||
func (p *Package) scanSortable() {
|
||||
p.sortable = map[string]bool{}
|
||||
|
||||
// bitfield for which methods exist on each type.
|
||||
const (
|
||||
bfLen = 1 << iota
|
||||
bfLess
|
||||
bfSwap
|
||||
)
|
||||
nmap := map[string]int{"Len": bfLen, "Less": bfLess, "Swap": bfSwap}
|
||||
has := map[string]int{}
|
||||
sortableFlags := map[string]sortableMethodsFlags{}
|
||||
for _, f := range p.files {
|
||||
ast.Walk(&walker{nmap, has}, f.AST)
|
||||
for _, decl := range f.AST.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
isAMethodDeclaration := ok && fn.Recv != nil && len(fn.Recv.List) != 0
|
||||
if !isAMethodDeclaration {
|
||||
continue
|
||||
}
|
||||
|
||||
recvType := typeparams.ReceiverType(fn)
|
||||
sortableFlags[recvType] |= getSortableMethodFlagForFunction(fn)
|
||||
}
|
||||
}
|
||||
for typ, ms := range has {
|
||||
|
||||
p.sortable = make(map[string]bool, len(sortableFlags))
|
||||
for typ, ms := range sortableFlags {
|
||||
if ms == bfLen|bfLess|bfSwap {
|
||||
p.sortable[typ] = true
|
||||
}
|
||||
@ -204,3 +203,16 @@ func (p *Package) IsAtLeastGo121() bool {
|
||||
func (p *Package) IsAtLeastGo122() bool {
|
||||
return p.goVersion.GreaterThanOrEqual(go122)
|
||||
}
|
||||
|
||||
func getSortableMethodFlagForFunction(fn *ast.FuncDecl) sortableMethodsFlags {
|
||||
switch {
|
||||
case astutils.FuncSignatureIs(fn, "Len", []string{}, []string{"int"}):
|
||||
return bfLen
|
||||
case astutils.FuncSignatureIs(fn, "Less", []string{"int", "int"}, []string{"bool"}):
|
||||
return bfLess
|
||||
case astutils.FuncSignatureIs(fn, "Swap", []string{"int", "int"}, []string{}):
|
||||
return bfSwap
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
18
testdata/golint/sort.go
vendored
18
testdata/golint/sort.go
vendored
@ -18,3 +18,21 @@ func (u U) Less(i, j int) bool { return u[i] < u[j] }
|
||||
func (u U) Swap(i, j int) { u[i], u[j] = u[j], u[i] }
|
||||
|
||||
func (u U) Other() {} // MATCH /exported method U.Other should have comment or be unexported/
|
||||
|
||||
// V is ...
|
||||
type V []int
|
||||
|
||||
func (v V) Len() (result int) { return len(w) }
|
||||
func (v V) Less(i int, j int) (result bool) { return w[i] < w[j] }
|
||||
func (v V) Swap(i int, j int) { v[i], v[j] = v[j], v[i] }
|
||||
|
||||
// W is ...
|
||||
type W []int
|
||||
|
||||
func (w W) Swap(i int, j int) {} // MATCH /exported method W.Swap should have comment or be unexported/
|
||||
|
||||
// Vv is ...
|
||||
type Vv []int
|
||||
|
||||
func (vv Vv) Len() (result int) { return len(w) } // MATCH /exported method Vv.Len should have comment or be unexported/
|
||||
func (vv Vv) Less(i int, j int) (result bool) { return w[i] < w[j] } // MATCH /exported method Vv.Less should have comment or be unexported/
|
||||
|
Loading…
x
Reference in New Issue
Block a user