mirror of
https://github.com/mgechev/revive.git
synced 2024-12-04 10:24:49 +02:00
143 lines
3.0 KiB
Go
143 lines
3.0 KiB
Go
package rule
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
|
|
"github.com/mgechev/revive/lint"
|
|
)
|
|
|
|
// DataRaceRule spots potential data races caused by goroutines that capture,
// by reference, range variables or named return values of the enclosing
// function.
type DataRaceRule struct{}
|
|
|
|
// Apply applies the rule to given file.
|
|
func (*DataRaceRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure {
|
|
var failures []lint.Failure
|
|
onFailure := func(failure lint.Failure) {
|
|
failures = append(failures, failure)
|
|
}
|
|
w := lintDataRaces{onFailure: onFailure}
|
|
|
|
ast.Walk(w, file.AST)
|
|
|
|
return failures
|
|
}
|
|
|
|
// Name returns the rule name.
|
|
func (*DataRaceRule) Name() string {
|
|
return "datarace"
|
|
}
|
|
|
|
type lintDataRaces struct {
|
|
onFailure func(failure lint.Failure)
|
|
}
|
|
|
|
func (w lintDataRaces) Visit(n ast.Node) ast.Visitor {
|
|
node, ok := n.(*ast.FuncDecl)
|
|
if !ok {
|
|
return w // not function declaration
|
|
}
|
|
if node.Body == nil {
|
|
return nil // empty body
|
|
}
|
|
|
|
results := node.Type.Results
|
|
|
|
returnIDs := map[*ast.Object]struct{}{}
|
|
if results != nil {
|
|
returnIDs = w.ExtractReturnIDs(results.List)
|
|
}
|
|
fl := &lintFunctionForDataRaces{onFailure: w.onFailure, returnIDs: returnIDs, rangeIDs: map[*ast.Object]struct{}{}}
|
|
ast.Walk(fl, node.Body)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (lintDataRaces) ExtractReturnIDs(fields []*ast.Field) map[*ast.Object]struct{} {
|
|
r := map[*ast.Object]struct{}{}
|
|
for _, f := range fields {
|
|
for _, id := range f.Names {
|
|
r[id.Obj] = struct{}{}
|
|
}
|
|
}
|
|
|
|
return r
|
|
}
|
|
|
|
type lintFunctionForDataRaces struct {
|
|
_ struct{}
|
|
onFailure func(failure lint.Failure)
|
|
returnIDs map[*ast.Object]struct{}
|
|
rangeIDs map[*ast.Object]struct{}
|
|
}
|
|
|
|
func (w lintFunctionForDataRaces) Visit(node ast.Node) ast.Visitor {
|
|
switch n := node.(type) {
|
|
case *ast.RangeStmt:
|
|
if n.Body == nil {
|
|
return nil
|
|
}
|
|
|
|
getIds := func(exprs ...ast.Expr) []*ast.Ident {
|
|
r := []*ast.Ident{}
|
|
for _, expr := range exprs {
|
|
if id, ok := expr.(*ast.Ident); ok {
|
|
r = append(r, id)
|
|
}
|
|
}
|
|
return r
|
|
}
|
|
|
|
ids := getIds(n.Key, n.Value)
|
|
for _, id := range ids {
|
|
w.rangeIDs[id.Obj] = struct{}{}
|
|
}
|
|
|
|
ast.Walk(w, n.Body)
|
|
|
|
for _, id := range ids {
|
|
delete(w.rangeIDs, id.Obj)
|
|
}
|
|
|
|
return nil // do not visit the body of the range, it has been already visited
|
|
case *ast.GoStmt:
|
|
f := n.Call.Fun
|
|
funcLit, ok := f.(*ast.FuncLit)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
selectIDs := func(n ast.Node) bool {
|
|
_, ok := n.(*ast.Ident)
|
|
return ok
|
|
}
|
|
|
|
ids := pick(funcLit.Body, selectIDs, nil)
|
|
for _, id := range ids {
|
|
id := id.(*ast.Ident)
|
|
_, isRangeID := w.rangeIDs[id.Obj]
|
|
_, isReturnID := w.returnIDs[id.Obj]
|
|
|
|
switch {
|
|
case isRangeID:
|
|
w.onFailure(lint.Failure{
|
|
Confidence: 1,
|
|
Node: id,
|
|
Category: "logic",
|
|
Failure: fmt.Sprintf("datarace: range value %s is captured (by-reference) in goroutine", id.Name),
|
|
})
|
|
case isReturnID:
|
|
w.onFailure(lint.Failure{
|
|
Confidence: 0.8,
|
|
Node: id,
|
|
Category: "logic",
|
|
Failure: fmt.Sprintf("potential datarace: return value %s is captured (by-reference) in goroutine", id.Name),
|
|
})
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return w
|
|
}
|