diff --git a/internal/astutils/ast_utils.go b/internal/astutils/ast_utils.go index 0a34604..cc79887 100644 --- a/internal/astutils/ast_utils.go +++ b/internal/astutils/ast_utils.go @@ -3,6 +3,7 @@ package astutils import ( "go/ast" + "slices" ) // FuncSignatureIs returns true if the given func decl satisfies a signature characterized @@ -15,38 +16,33 @@ func FuncSignatureIs(funcDecl *ast.FuncDecl, wantName string, wantParametersType return false // func name doesn't match expected one } - funcParametersTypes := getTypeNames(funcDecl.Type.Params) - if len(wantParametersTypes) != len(funcParametersTypes) { - return false // func has not the expected number of parameters + funcResultsTypes := GetTypeNames(funcDecl.Type.Results) + if !slices.Equal(wantResultsTypes, funcResultsTypes) { + return false // func has not the expected return values } - funcResultsTypes := getTypeNames(funcDecl.Type.Results) - if len(wantResultsTypes) != len(funcResultsTypes) { - return false // func has not the expected number of return values - } - - for i, wantType := range wantParametersTypes { - if wantType != funcParametersTypes[i] { - return false // type of a func's parameter does not match the type of the corresponding expected parameter - } - } - - for i, wantType := range wantResultsTypes { - if wantType != funcResultsTypes[i] { - return false // type of a func's return value does not match the type of the corresponding expected return value - } - } - - return true + // Name and return values are those we expected, + // the final result depends on parameters being what we want. + return funcParametersSignatureIs(funcDecl, wantParametersTypes) } -func getTypeNames(fields *ast.FieldList) []string { - result := []string{} +// funcParametersSignatureIs returns true if the function has parameters of the given type and order, +// false otherwise +func funcParametersSignatureIs(funcDecl *ast.FuncDecl, wantParametersTypes []string) bool { + funcParametersTypes := GetTypeNames(funcDecl.Type.Params) + return slices.Equal(wantParametersTypes, funcParametersTypes) +} + +// GetTypeNames yields an slice with the string representation of the types of given fields. +// It yields nil if the field list is nil. +func GetTypeNames(fields *ast.FieldList) []string { if fields == nil { - return result + return nil } + result := []string{} + for _, field := range fields.List { typeName := getFieldTypeName(field.Type) if field.Names == nil { // unnamed field @@ -67,7 +63,7 @@ func getFieldTypeName(typ ast.Expr) string { case *ast.Ident: return f.Name case *ast.SelectorExpr: - return f.Sel.Name + "." + getFieldTypeName(f.X) + return getFieldTypeName(f.X) + "." + getFieldTypeName(f.Sel) case *ast.StarExpr: return "*" + getFieldTypeName(f.X) case *ast.IndexExpr: diff --git a/rule/get_return.go b/rule/get_return.go index cf58a68..a6230e0 100644 --- a/rule/get_return.go +++ b/rule/get_return.go @@ -5,6 +5,7 @@ import ( "go/ast" "strings" + "github.com/mgechev/revive/internal/astutils" "github.com/mgechev/revive/lint" ) @@ -29,6 +30,10 @@ func (*GetReturnRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure { continue } + if isHTTPHandler(fd.Type.Params) { + continue // the Get prefix in the function name refers to HTTP GET + } + failures = append(failures, lint.Failure{ Confidence: 0.8, Node: fd, @@ -69,3 +74,15 @@ func isGetter(name string) bool { func hasResults(rs *ast.FieldList) bool { return rs != nil && len(rs.List) > 0 } + +// isHTTPHandler returns true if the given params match with the signature of an HTTP handler, false otherwise +// A params list is considered to be an HTTP handler if the first two parameters are +// http.ResponseWriter, *http.Request in that order. +func isHTTPHandler(params *ast.FieldList) bool { + typeNames := astutils.GetTypeNames(params) + if len(typeNames) < 2 { + return false + } + + return typeNames[0] == "http.ResponseWriter" && typeNames[1] == "*http.Request" +} diff --git a/testdata/get_return.go b/testdata/get_return.go index a279ed7..e316daf 100644 --- a/testdata/get_return.go +++ b/testdata/get_return.go @@ -1,5 +1,7 @@ package fixtures +import "net/http" + func getfoo() { } @@ -27,3 +29,10 @@ func (t *t) GetSaz(a string, b int) { // MATCH /function 'GetSaz' seems to be a func GetQux(a string, b int, c int, d string, e int64) { // MATCH /function 'GetQux' seems to be a getter but it does not return any result/ } + +// non-regression test issue #1323 +func (b *t) GetInfo(w http.ResponseWriter, r *http.Request) {} + +func GetSomething(w http.ResponseWriter, r *http.Request, p int) {} + +func GetSomethingElse(p int, w http.ResponseWriter, r *http.Request) {} // MATCH /function 'GetSomethingElse' seems to be a getter but it does not return any result/