diff --git a/fixtures/context-arguments.go b/fixtures/context-arguments.go new file mode 100644 index 0000000..5580e60 --- /dev/null +++ b/fixtures/context-arguments.go @@ -0,0 +1,24 @@ +// Test that context.Context is the first arg to a function. + +// Package foo ... +package foo + +import ( + "context" +) + +// A proper context.Context location +func x(ctx context.Context) { // ok +} + +// A proper context.Context location +func x(ctx context.Context, s string) { // ok +} + +// An invalid context.Context location +func y(s string, ctx context.Context) { // MATCH /context.Context should be the first parameter of a function/ +} + +// An invalid context.Context location with more than 2 args +func y(s string, r int, ctx context.Context, x int) { // MATCH /context.Context should be the first parameter of a function/ +} diff --git a/fixtures/context-key-type.go b/fixtures/context-key-type.go new file mode 100644 index 0000000..768de69 --- /dev/null +++ b/fixtures/context-key-type.go @@ -0,0 +1,38 @@ +// Package contextkeytypes verifies that correct types are used as keys in +// calls to context.WithValue. +package contextkeytypes + +import ( + "context" + "fmt" +) + +type ctxKey struct{} + +func contextKeyTypeTests() { + fmt.Println() // not in package context + context.TODO() // wrong function + c := context.Background() // wrong function + context.WithValue(c, "foo", "bar") // MATCH /should not use basic type string as key in context.WithValue/ + context.WithValue(c, true, "bar") // MATCH /should not use basic type bool as key in context.WithValue/ + context.WithValue(c, 1, "bar") // MATCH /should not use basic type int as key in context.WithValue/ + context.WithValue(c, int8(1), "bar") // MATCH /should not use basic type int8 as key in context.WithValue/ + context.WithValue(c, int16(1), "bar") // MATCH /should not use basic type int16 as key in context.WithValue/ + context.WithValue(c, int32(1), "bar") // MATCH /should not use basic type int32 as key in context.WithValue/ + context.WithValue(c, rune(1), "bar") // MATCH /should not use basic type rune as key in context.WithValue/ + context.WithValue(c, int64(1), "bar") // MATCH /should not use basic type int64 as key in context.WithValue/ + context.WithValue(c, uint(1), "bar") // MATCH /should not use basic type uint as key in context.WithValue/ + context.WithValue(c, uint8(1), "bar") // MATCH /should not use basic type uint8 as key in context.WithValue/ + context.WithValue(c, byte(1), "bar") // MATCH /should not use basic type byte as key in context.WithValue/ + context.WithValue(c, uint16(1), "bar") // MATCH /should not use basic type uint16 as key in context.WithValue/ + context.WithValue(c, uint32(1), "bar") // MATCH /should not use basic type uint32 as key in context.WithValue/ + context.WithValue(c, uint64(1), "bar") // MATCH /should not use basic type uint64 as key in context.WithValue/ + context.WithValue(c, uintptr(1), "bar") // MATCH /should not use basic type uintptr as key in context.WithValue/ + context.WithValue(c, float32(1.0), "bar") // MATCH /should not use basic type float32 as key in context.WithValue/ + context.WithValue(c, float64(1.0), "bar") // MATCH /should not use basic type float64 as key in context.WithValue/ + context.WithValue(c, complex64(1i), "bar") // MATCH /should not use basic type complex64 as key in context.WithValue/ + context.WithValue(c, complex128(1i), "bar") // MATCH /should not use basic type complex128 as key in context.WithValue/ + context.WithValue(c, ctxKey{}, "bar") // ok + context.WithValue(c, &ctxKey{}, "bar") // ok + context.WithValue(c, invalid{}, "bar") // ok +} diff --git a/fixtures/error-return.go b/fixtures/error-return.go new file mode 100644 index 0000000..b6a0eda --- /dev/null +++ b/fixtures/error-return.go @@ -0,0 +1,43 @@ +// Test for returning errors. + +// Package foo ... +package foo + +// Returns nothing +func f() { // ok +} + +// Check for a single error return +func g() error { // ok + return nil +} + +// Check for a single other return type +func h() int { // ok + return 0 +} + +// Check for multiple return but error at end. +func i() (int, error) { // ok + return 0, nil +} + +// Check for multiple return but error at end with named variables. +func j() (x int, err error) { // ok + return 0, nil +} + +// Check for error in the wrong location on 2 types +func k() (error, int) { // MATCH /error should be the last type when returning multiple items/ + return nil, 0 +} + +// Check for error in the wrong location for > 2 types +func l() (int, error, int) { // MATCH /error should be the last type when returning multiple items/ + return 0, nil, 0 +} + +// Check for error in the wrong location with named variables. +func m() (x int, err error, y int) { // MATCH /error should be the last type when returning multiple items/ + return 0, nil, 0 +} diff --git a/fixtures/time-names.go b/fixtures/time-names.go new file mode 100644 index 0000000..fca8f97 --- /dev/null +++ b/fixtures/time-names.go @@ -0,0 +1,13 @@ +// Test of time suffixes. + +// Package foo ... +package foo + +import ( + "flag" + "time" +) + +var rpcTimeoutMsec = flag.Duration("rpc_timeout", 100*time.Millisecond, "some flag") // MATCH /var rpcTimeoutMsec is of type *time.Duration; don't use unit-specific suffix "Msec"/ + +var timeoutSecs = 5 * time.Second // MATCH /var timeoutSecs is of type time.Duration; don't use unit-specific suffix "Secs"/ diff --git a/fixtures/unexported-return.go b/fixtures/unexported-return.go new file mode 100644 index 0000000..9aa4efa --- /dev/null +++ b/fixtures/unexported-return.go @@ -0,0 +1,44 @@ +// Test for unexported return types. + +// Package foo ... +package foo + +type hidden struct{} + +// Exported returns a hidden type, which is annoying. +func Exported() hidden { // MATCH /exported func Exported returns unexported type foo.hidden, which can be annoying to use/ + return hidden{} +} + +// ExpErr returns a builtin type. +func ExpErr() error { // ok +} + +func (hidden) ExpOnHidden() hidden { // ok +} + +// T is another test type. +type T struct{} + +// MethodOnT returns a hidden type, which is annoying. +func (T) MethodOnT() hidden { // MATCH /exported method MethodOnT returns unexported type foo.hidden, which can be annoying to use/ + return hidden{} +} + +// ExpT returns a T. +func ExpT() T { // ok + return T{} +} + +func unexp() hidden { // ok + return hidden{} +} + +// This is slightly sneaky: we shadow the builtin "int" type. + +type int struct{} + +// ExportedIntReturner returns an unexported type from this package. +func ExportedIntReturner() int { // MATCH /exported func ExportedIntReturner returns unexported type foo.int, which can be annoying to use/ + return int{} +} diff --git a/rule/context-arguments.go b/rule/context-arguments.go new file mode 100644 index 0000000..2486c64 --- /dev/null +++ b/rule/context-arguments.go @@ -0,0 +1,61 @@ +package rule + +import ( + "go/ast" + + "github.com/mgechev/revive/lint" +) + +// ContextArgumentsRule lints given else constructs. +type ContextArgumentsRule struct{} + +// Apply applies the rule to given file. +func (r *ContextArgumentsRule) Apply(file *lint.File, arguments lint.Arguments) []lint.Failure { + var failures []lint.Failure + + fileAst := file.AST + walker := lintContextArguments{ + file: file, + fileAst: fileAst, + onFailure: func(failure lint.Failure) { + failures = append(failures, failure) + }, + } + + ast.Walk(walker, fileAst) + + return failures +} + +// Name returns the rule name. +func (r *ContextArgumentsRule) Name() string { + return "context-arguments" +} + +type lintContextArguments struct { + file *lint.File + fileAst *ast.File + onFailure func(lint.Failure) +} + +func (w lintContextArguments) Visit(n ast.Node) ast.Visitor { + fn, ok := n.(*ast.FuncDecl) + if !ok || len(fn.Type.Params.List) <= 1 { + return w + } + // A context.Context should be the first parameter of a function. + // Flag any that show up after the first. + for _, arg := range fn.Type.Params.List[1:] { + if isPkgDot(arg.Type, "context", "Context") { + w.onFailure(lint.Failure{ + Node: fn, + Category: "arg-order", + URL: "https://golang.org/pkg/context/", + Failure: "context.Context should be the first parameter of a function", + Confidence: 0.9, + }) + break // only flag one + } + } + return w +} diff --git a/rule/context-key-type.go b/rule/context-key-type.go new file mode 100644 index 0000000..629f0c5 --- /dev/null +++ b/rule/context-key-type.go @@ -0,0 +1,80 @@ +package rule + +import ( + "fmt" + "go/ast" + "go/types" + + "github.com/mgechev/revive/lint" +) + +// ContextKeyTypeRule lints given else constructs. +type ContextKeyTypeRule struct{} + +// Apply applies the rule to given file. +func (r *ContextKeyTypeRule) Apply(file *lint.File, arguments lint.Arguments) []lint.Failure { + var failures []lint.Failure + + fileAst := file.AST + walker := lintContextKeyTypes{ + file: file, + fileAst: fileAst, + onFailure: func(failure lint.Failure) { + failures = append(failures, failure) + }, + } + + ast.Walk(walker, fileAst) + + return failures +} + +// Name returns the rule name. +func (r *ContextKeyTypeRule) Name() string { + return "context-key-types" +} + +type lintContextKeyTypes struct { + file *lint.File + fileAst *ast.File + onFailure func(lint.Failure) +} + +func (w lintContextKeyTypes) Visit(n ast.Node) ast.Visitor { + switch n := n.(type) { + case *ast.CallExpr: + checkContextKeyType(w, n) + } + + return w +} + +func checkContextKeyType(w lintContextKeyTypes, x *ast.CallExpr) { + f := w.file + sel, ok := x.Fun.(*ast.SelectorExpr) + if !ok { + return + } + pkg, ok := sel.X.(*ast.Ident) + if !ok || pkg.Name != "context" { + return + } + if sel.Sel.Name != "WithValue" { + return + } + + // key is second argument to context.WithValue + if len(x.Args) != 3 { + return + } + key := f.Pkg.TypesInfo.Types[x.Args[1]] + + if ktyp, ok := key.Type.(*types.Basic); ok && ktyp.Kind() != types.Invalid { + w.onFailure(lint.Failure{ + Confidence: 1, + Node: x, + Category: "content", + Failure: fmt.Sprintf("should not use basic type %s as key in context.WithValue", key.Type), + }) + } +} diff --git a/rule/error-return.go b/rule/error-return.go new file mode 100644 index 0000000..47b9dd3 --- /dev/null +++ b/rule/error-return.go @@ -0,0 +1,64 @@ +package rule + +import ( + "go/ast" + + "github.com/mgechev/revive/lint" +) + +// ErrorReturnRule lints given else constructs. +type ErrorReturnRule struct{} + +// Apply applies the rule to given file. +func (r *ErrorReturnRule) Apply(file *lint.File, arguments lint.Arguments) []lint.Failure { + var failures []lint.Failure + + fileAst := file.AST + walker := lintErrorReturn{ + file: file, + fileAst: fileAst, + onFailure: func(failure lint.Failure) { + failures = append(failures, failure) + }, + } + + ast.Walk(walker, fileAst) + + return failures +} + +// Name returns the rule name. +func (r *ErrorReturnRule) Name() string { + return "error-return" +} + +type lintErrorReturn struct { + file *lint.File + fileAst *ast.File + onFailure func(lint.Failure) +} + +func (w lintErrorReturn) Visit(n ast.Node) ast.Visitor { + fn, ok := n.(*ast.FuncDecl) + if !ok || fn.Type.Results == nil { + return w + } + ret := fn.Type.Results.List + if len(ret) <= 1 { + return w + } + // An error return parameter should be the last parameter. + // Flag any error parameters found before the last. + for _, r := range ret[:len(ret)-1] { + if isIdent(r.Type, "error") { + w.onFailure(lint.Failure{ + Category: "arg-order", + Confidence: 0.9, + Node: fn, + Failure: "error should be the last type when returning multiple items", + }) + break // only flag one + } + } + return w +} diff --git a/rule/time-names.go b/rule/time-names.go new file mode 100644 index 0000000..14e73ba --- /dev/null +++ b/rule/time-names.go @@ -0,0 +1,91 @@ +package rule + +import ( + "fmt" + "go/ast" + "go/types" + "strings" + + "github.com/mgechev/revive/lint" +) + +// TimeNamesRule lints given else constructs. +type TimeNamesRule struct{} + +// Apply applies the rule to given file. +func (r *TimeNamesRule) Apply(file *lint.File, arguments lint.Arguments) []lint.Failure { + var failures []lint.Failure + + onFailure := func(failure lint.Failure) { + failures = append(failures, failure) + } + + w := &lintTimeNames{file, onFailure} + ast.Walk(w, file.AST) + return failures +} + +// Name returns the rule name. +func (r *TimeNamesRule) Name() string { + return "time-names" +} + +type lintTimeNames struct { + file *lint.File + onFailure func(lint.Failure) +} + +func (w *lintTimeNames) Visit(node ast.Node) ast.Visitor { + v, ok := node.(*ast.ValueSpec) + if !ok { + return w + } + for _, name := range v.Names { + origTyp := w.file.Pkg.TypeOf(name) + // Look for time.Duration or *time.Duration; + // the latter is common when using flag.Duration. + typ := origTyp + if pt, ok := typ.(*types.Pointer); ok { + typ = pt.Elem() + } + if !isNamedType(w.file.Pkg, typ, "time", "Duration") { + continue + } + suffix := "" + for _, suf := range timeSuffixes { + if strings.HasSuffix(name.Name, suf) { + suffix = suf + break + } + } + if suffix == "" { + continue + } + w.onFailure(lint.Failure{ + Category: "time", + Confidence: 0.9, + Node: v, + Failure: fmt.Sprintf("var %s is of type %v; don't use unit-specific suffix %q", name.Name, origTyp, suffix), + }) + } + return w +} + +// timeSuffixes is a list of name suffixes that imply a time unit. +// This is not an exhaustive list. +var timeSuffixes = []string{ + "Sec", "Secs", "Seconds", + "Msec", "Msecs", + "Milli", "Millis", "Milliseconds", + "Usec", "Usecs", "Microseconds", + "MS", "Ms", +} + +func isNamedType(p *lint.Package, typ types.Type, importPath, name string) bool { + n, ok := typ.(*types.Named) + if !ok { + return false + } + tn := n.Obj() + return tn != nil && tn.Pkg() != nil && tn.Pkg().Path() == importPath && tn.Name() == name +} diff --git a/rule/unexported-return.go b/rule/unexported-return.go new file mode 100644 index 0000000..623dcc3 --- /dev/null +++ b/rule/unexported-return.go @@ -0,0 +1,97 @@ +package rule + +import ( + "fmt" + "go/ast" + "go/types" + + "github.com/mgechev/revive/lint" +) + +// UnexportedReturnRule lints given else constructs. +type UnexportedReturnRule struct{} + +// Apply applies the rule to given file. +func (r *UnexportedReturnRule) Apply(file *lint.File, arguments lint.Arguments) []lint.Failure { + var failures []lint.Failure + + fileAst := file.AST + walker := lintUnexportedReturn{ + file: file, + fileAst: fileAst, + onFailure: func(failure lint.Failure) { + failures = append(failures, failure) + }, + } + + ast.Walk(walker, fileAst) + + return failures +} + +// Name returns the rule name. +func (r *UnexportedReturnRule) Name() string { + return "unexported-return" +} + +type lintUnexportedReturn struct { + file *lint.File + fileAst *ast.File + onFailure func(lint.Failure) +} + +func (w lintUnexportedReturn) Visit(n ast.Node) ast.Visitor { + fn, ok := n.(*ast.FuncDecl) + if !ok { + return w + } + if fn.Type.Results == nil { + return nil + } + if !fn.Name.IsExported() { + return nil + } + thing := "func" + if fn.Recv != nil && len(fn.Recv.List) > 0 { + thing = "method" + if !ast.IsExported(receiverType(fn)) { + // Don't report exported methods of unexported types, + // such as private implementations of sort.Interface. + return nil + } + } + for _, ret := range fn.Type.Results.List { + typ := w.file.Pkg.TypeOf(ret.Type) + if exportedType(typ) { + continue + } + w.onFailure(lint.Failure{ + Category: "unexported-type-in-api", + Node: ret.Type, + Confidence: 0.8, + Failure: fmt.Sprintf("exported %s %s returns unexported type %s, which can be annoying to use", + thing, fn.Name.Name, typ), + }) + break // only flag one + } + return nil +} + +// exportedType reports whether typ is an exported type. +// It is imprecise, and will err on the side of returning true, +// such as for composite types. +func exportedType(typ types.Type) bool { + switch T := typ.(type) { + case *types.Named: + // Builtin types have no package. + return T.Obj().Pkg() == nil || T.Obj().Exported() + case *types.Map: + return exportedType(T.Key()) && exportedType(T.Elem()) + case interface { + Elem() types.Type + }: // array, slice, pointer, chan + return exportedType(T.Elem()) + } + // Be conservative about other types, such as struct, interface, etc. + return true +} diff --git a/testutil/lint_test.go b/testutil/lint_test.go index 9770961..5898820 100644 --- a/testutil/lint_test.go +++ b/testutil/lint_test.go @@ -47,6 +47,11 @@ var rules = []lint.Rule{ &rule.ErrorStringsRule{}, &rule.ReceiverNameRule{}, &rule.IncrementDecrementRule{}, + &rule.ErrorReturnRule{}, + &rule.UnexportedReturnRule{}, + &rule.TimeNamesRule{}, + &rule.ContextKeyTypeRule{}, + &rule.ContextArgumentsRule{}, } func TestAll(t *testing.T) {