diff --git a/README.md b/README.md index 214915e..ecb88ab 100644 --- a/README.md +++ b/README.md @@ -271,6 +271,7 @@ List of all available rules. The rules ported from `golint` are left unchanged a | `redefines-builtin-id`| n/a | Warns on redefinitions of builtin identifiers | no | no | | `function-result-limit` | int | Specifies the maximum number of results a function can return | no | no | | `imports-blacklist` | []string | Disallows importing the specified packages | no | no | +| `range-loop-var` | n/a | Disallows incorrect uses of range loop variables in closures | no | no | ## Configurable rules diff --git a/config.go b/config.go index 05ea1c7..c4dffb4 100644 --- a/config.go +++ b/config.go @@ -67,6 +67,7 @@ var allRules = append([]lint.Rule{ &rule.ImportsBlacklistRule{}, &rule.FunctionResultsLimitRule{}, &rule.MaxPublicStructsRule{}, + &rule.RangeLoopVarRule{}, }, defaultRules...) var allFormatters = []lint.Formatter{ diff --git a/fixtures/range-loop-var.go b/fixtures/range-loop-var.go new file mode 100644 index 0000000..c4d668f --- /dev/null +++ b/fixtures/range-loop-var.go @@ -0,0 +1,24 @@ +package fixtures + +import "fmt" + +func foo() { + mySlice := []string{"A", "B", "C"} + for index, value := range mySlice { + go func() { + fmt.Printf("Index: %d\n", index) // MATCH /loop variable index captured by func literal/ + fmt.Printf("Value: %s\n", value) // MATCH /loop variable value captured by func literal/ + }() + } + + myDict := make(map[string]int) + myDict["A"] = 1 + myDict["B"] = 2 + myDict["C"] = 3 + for key, value := range myDict { + defer func() { + fmt.Printf("Index: %d\n", key) // MATCH /loop variable key captured by func literal/ + fmt.Printf("Value: %s\n", value) // MATCH /loop variable value captured by func literal/ + }() + } +} diff --git a/rule/range-loop-var.go b/rule/range-loop-var.go new file mode 100644 index 0000000..b371704 --- /dev/null +++ b/rule/range-loop-var.go @@ -0,0 +1,111 @@ +package rule + +import ( + "fmt" + "go/ast" + + "github.com/mgechev/revive/lint" +) + +// RangeLoopVarRule lints given else constructs. +type RangeLoopVarRule struct{} + +// Apply applies the rule to given file. +func (r *RangeLoopVarRule) Apply(file *lint.File, arguments lint.Arguments) []lint.Failure { + var failures []lint.Failure + + walker := RangeLoopVar{ + onFailure: func(failure lint.Failure) { + failures = append(failures, failure) + }, + } + + ast.Walk(walker, file.AST) + + return failures +} + +// Name returns the rule name. +func (r *RangeLoopVarRule) Name() string { + return "range-loop-var" +} + +type RangeLoopVar struct { + onFailure func(lint.Failure) +} + +func (w RangeLoopVar) Visit(node ast.Node) ast.Visitor { + + // Find the variables updated by the loop statement. + var vars []*ast.Ident + addVar := func(expr ast.Expr) { + if id, ok := expr.(*ast.Ident); ok { + vars = append(vars, id) + } + } + var body *ast.BlockStmt + switch n := node.(type) { + case *ast.RangeStmt: + body = n.Body + addVar(n.Key) + addVar(n.Value) + case *ast.ForStmt: + body = n.Body + switch post := n.Post.(type) { + case *ast.AssignStmt: + // e.g. for p = head; p != nil; p = p.next + for _, lhs := range post.Lhs { + addVar(lhs) + } + case *ast.IncDecStmt: + // e.g. for i := 0; i < n; i++ + addVar(post.X) + } + } + if vars == nil { + return w + } + + // Inspect a go or defer statement + // if it's the last one in the loop body. + // (We give up if there are following statements, + // because it's hard to prove go isn't followed by wait, + // or defer by return.) + if len(body.List) == 0 { + return w + } + var last *ast.CallExpr + switch s := body.List[len(body.List)-1].(type) { + case *ast.GoStmt: + last = s.Call + case *ast.DeferStmt: + last = s.Call + default: + return w + } + lit, ok := last.Fun.(*ast.FuncLit) + if !ok { + return w + } + ast.Inspect(lit.Body, func(n ast.Node) bool { + id, ok := n.(*ast.Ident) + if !ok || id.Obj == nil { + return true + } + if lit.Type == nil { + // Not referring to a variable (e.g. struct field name) + return true + } + for _, v := range vars { + if v.Obj == id.Obj { + w.onFailure(lint.Failure{ + Confidence: 1, + Failure: fmt.Sprintf("loop variable %v captured by func literal", id.Name), + Node: n, + }) + } + } + return true + }) + return w +} diff --git a/test/range-loop-var_test.go b/test/range-loop-var_test.go new file mode 100644 index 0000000..2536f10 --- /dev/null +++ b/test/range-loop-var_test.go @@ -0,0 +1,12 @@ +package test + +import ( + "testing" + + "github.com/mgechev/revive/lint" + "github.com/mgechev/revive/rule" +) + +func TestRangeLoopVar(t *testing.T) { + testRule(t, "range-loop-var", &rule.RangeLoopVarRule{}, &lint.RuleConfig{}) +} diff --git a/untyped.toml b/untyped.toml index 5c7ff44..6f95047 100644 --- a/untyped.toml +++ b/untyped.toml @@ -13,3 +13,4 @@ [rule.receiver-naming] [rule.indent-error-flow] [rule.empty-block] +[rule.range-loop-var]