diff --git a/linter/file.go b/linter/file.go index 0fddaa9..27a9131 100644 --- a/linter/file.go +++ b/linter/file.go @@ -1,9 +1,12 @@ package linter import ( + "bytes" "go/ast" "go/parser" + "go/printer" "go/token" + "go/types" "math" "regexp" "strings" @@ -12,7 +15,7 @@ import ( // File abstraction used for representing files. type File struct { Name string - pkg *Package + Pkg *Package content []byte ast *ast.File } @@ -26,14 +29,14 @@ func NewFile(name string, content []byte, pkg *Package) (*File, error) { return &File{ Name: name, content: content, - pkg: pkg, + Pkg: pkg, ast: f, }, nil } // ToPosition returns line and column for given position. func (f *File) ToPosition(pos token.Pos) token.Position { - return f.pkg.fset.Position(pos) + return f.Pkg.fset.Position(pos) } // GetAST returns the AST of the file @@ -41,6 +44,44 @@ func (f *File) GetAST() *ast.File { return f.ast } +// Render renters a node. +func (f *File) Render(x interface{}) string { + var buf bytes.Buffer + if err := printer.Fprint(&buf, f.Pkg.fset, x); err != nil { + panic(err) + } + return buf.String() +} + +var basicTypeKinds = map[types.BasicKind]string{ + types.UntypedBool: "bool", + types.UntypedInt: "int", + types.UntypedRune: "rune", + types.UntypedFloat: "float64", + types.UntypedComplex: "complex128", + types.UntypedString: "string", +} + +// IsUntypedConst reports whether expr is an untyped constant, +// and indicates what its default type is. +// scope may be nil. +func (f *File) IsUntypedConst(expr ast.Expr) (defType string, ok bool) { + // Re-evaluate expr outside of its context to see if it's untyped. + // (An expr evaluated within, for example, an assignment context will get the type of the LHS.) + exprStr := f.Render(expr) + tv, err := types.Eval(f.Pkg.fset, f.Pkg.TypesPkg, expr.Pos(), exprStr) + if err != nil { + return "", false + } + if b, ok := tv.Type.(*types.Basic); ok { + if dt, ok := basicTypeKinds[b.Kind()]; ok { + return dt, true + } + } + + return "", false +} + func (f *File) isMain() bool { if f.GetAST().Name.Name == "main" { return true diff --git a/linter/package.go b/linter/package.go index 946ea57..1f0eee5 100644 --- a/linter/package.go +++ b/linter/package.go @@ -76,6 +76,14 @@ func (p *Package) TypeCheck() error { return err } +// TypeOf returns the type of an expression. +func (p *Package) TypeOf(expr ast.Expr) types.Type { + if p.TypesInfo == nil { + return nil + } + return p.TypesInfo.TypeOf(expr) +} + func (p *Package) lint(rules []Rule, config RulesConfig) []Failure { var failures []Failure p.TypeCheck() diff --git a/linter/rule.go b/linter/rule.go index 38141f9..31023c2 100644 --- a/linter/rule.go +++ b/linter/rule.go @@ -25,6 +25,7 @@ type FailurePosition struct { type Failure struct { Failure string RuleName string + Category string Type FailureType Position FailurePosition Node ast.Node diff --git a/rule/ranges.go b/rule/ranges.go index b3960f9..a10646c 100644 --- a/rule/ranges.go +++ b/rule/ranges.go @@ -18,9 +18,8 @@ func (r *LintRangesRule) Apply(file *linter.File, arguments linter.Arguments) [] failures = append(failures, failure) } - astFile := file.GetAST() - w := &lintRanges{astFile, onFailure} - ast.Walk(w, astFile) + w := &lintRanges{file, onFailure} + ast.Walk(w, file.GetAST()) return failures } @@ -30,7 +29,7 @@ func (r *LintRangesRule) Name() string { } type lintRanges struct { - file *ast.File + file *linter.File onFailure func(linter.Failure) } @@ -49,7 +48,7 @@ func (w *lintRanges) Visit(node ast.Node) ast.Visitor { } w.onFailure(linter.Failure{ - Failure: fmt.Sprintf("should omit 2nd value from range; this loop is equivalent to `for %s %s range ...`", render(rs.Key), rs.Tok), + Failure: fmt.Sprintf("should omit 2nd value from range; this loop is equivalent to `for %s %s range ...`", w.file.Render(rs.Key), rs.Tok), Confidence: 1, Node: rs.Value, }) diff --git a/rule/var-declarations.go b/rule/var-declarations.go index 1570c19..f0e7f39 100644 --- a/rule/var-declarations.go +++ b/rule/var-declarations.go @@ -1,11 +1,10 @@ package rule import ( - "bytes" "fmt" "go/ast" - "go/printer" "go/token" + "go/types" "github.com/mgechev/revive/linter" ) @@ -75,18 +74,45 @@ func (w *lintVarDeclarations) Visit(node ast.Node) ast.Visitor { w.onFailure(linter.Failure{ Confidence: 0.9, Node: rhs, - Failure: fmt.Sprintf("should drop = %s from declaration of var %s; it is the zero value", render(rhs), v.Names[0]), + Failure: fmt.Sprintf("should drop = %s from declaration of var %s; it is the zero value", w.file.Render(rhs), v.Names[0]), }) return nil } + lhsTyp := w.file.Pkg.TypeOf(v.Type) + rhsTyp := w.file.Pkg.TypeOf(rhs) + + if !validType(lhsTyp) || !validType(rhsTyp) { + // Type checking failed (often due to missing imports). + return nil + } + + if !types.Identical(lhsTyp, rhsTyp) { + // Assignment to a different type is not redundant. + return nil + } + + // The next three conditions are for suppressing the warning in situations + // where we were unable to typecheck. + + // If the LHS type is an interface, don't warn, since it is probably a + // concrete type on the RHS. Note that our feeble lexical check here + // will only pick up interface{} and other literal interface types; + // that covers most of the cases we care to exclude right now. + if _, ok := v.Type.(*ast.InterfaceType); ok { + return nil + } + // If the RHS is an untyped const, only warn if the LHS type is its default type. + if defType, ok := w.file.IsUntypedConst(rhs); ok && !isIdent(v.Type, defType) { + return nil + } + + w.onFailure(linter.Failure{ + Category: "type-inference", + Confidence: 0.8, + Node: v.Type, + Failure: fmt.Sprintf("should omit type %s from declaration of var %s; it will be inferred from the right-hand side", w.file.Render(v.Type), v.Names[0]), + }) + return nil } return w } - -func render(x interface{}) string { - var buf bytes.Buffer - if err := printer.Fprint(&buf, token.NewFileSet(), x); err != nil { - panic(err) - } - return buf.String() -}