diff --git a/core/call_list.go b/core/call_list.go index 1e45513..9ace433 100644 --- a/core/call_list.go +++ b/core/call_list.go @@ -13,16 +13,13 @@ package core -type set map[string]bool +import "go/ast" -type calls struct { - matchAny bool - functions set -} +type set map[string]bool /// CallList is used to check for usage of specific packages /// and functions. -type CallList map[string]*calls +type CallList map[string]set /// NewCallList creates a new empty CallList func NewCallList() CallList { @@ -30,36 +27,39 @@ func NewCallList() CallList { } /// NewCallListFor createse a call list using the package path -func NewCallListFor(pkg string, funcs ...string) CallList { +func NewCallListFor(selector string, idents ...string) CallList { c := NewCallList() - if len(funcs) == 0 { - c[pkg] = &calls{true, make(set)} - } else { - for _, fn := range funcs { - c.Add(pkg, fn) - } + c[selector] = make(set) + for _, ident := range idents { + c.Add(selector, ident) } return c } -/// Add a new package and function to the call list -func (c CallList) Add(pkg, fn string) { - if cl, ok := c[pkg]; ok { - if cl.matchAny { - cl.matchAny = false - } - } else { - c[pkg] = &calls{false, make(set)} +/// Add a selector and call to the call list +func (c CallList) Add(selector, ident string) { + if _, ok := c[selector]; !ok { + c[selector] = make(set) } - c[pkg].functions[fn] = true + c[selector][ident] = true } /// Contains returns true if the package and function are /// members of this call list. -func (c CallList) Contains(pkg, fn string) bool { - if funcs, ok := c[pkg]; ok { - _, ok = funcs.functions[fn] - return ok || funcs.matchAny +func (c CallList) Contains(selector, ident string) bool { + if idents, ok := c[selector]; ok { + _, found := idents[ident] + return found } return false } + +/// ContainsCallExpr resolves the call expression name and type +/// or package and determines if it exists within the CallList +func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool { + selector, ident, err := GetCallInfo(n, ctx) + if err != nil { + return false + } + return c.Contains(selector, ident) +} diff --git a/core/call_list_test.go b/core/call_list_test.go new file mode 100644 index 0000000..aa4f67c --- /dev/null +++ b/core/call_list_test.go @@ -0,0 +1,58 @@ +package core + +import ( + "go/ast" + "testing" +) + +type callListRule struct { + MetaData + callList CallList + matched int +} + +func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) { + if r.callList.ContainsCallExpr(n, c) { + r.matched += 1 + } + return nil, nil +} + +func TestCallListContainsCallExpr(t *testing.T) { + config := map[string]interface{}{"ignoreNosec": false} + analyzer := NewAnalyzer(config, nil) + rule := &callListRule{ + MetaData: MetaData{ + Severity: Low, + Confidence: Low, + What: "A dummy rule", + }, + callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"), + matched: 0, + } + analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)}) + source := ` + package main + import ( + "bytes" + "fmt" + ) + func main() { + var b bytes.Buffer + b.Write([]byte("Hello ")) + fmt.Fprintf(&b, "world!") + }` + + analyzer.ProcessSource("dummy.go", source) + if rule.matched != 1 { + t.Errorf("Expected to match a bytes.Buffer.Write call") + } +} + +func TestCallListContains(t *testing.T) { + callList := NewCallList() + callList.Add("fmt", "Printf") + if !callList.Contains("fmt", "Printf") { + t.Errorf("Expected call list to contain fmt.Printf") + } +} diff --git a/core/helpers.go b/core/helpers.go index ee1bddf..79a1617 100644 --- a/core/helpers.go +++ b/core/helpers.go @@ -69,18 +69,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a importName = alias } - switch node := n.(type) { - case *ast.CallExpr: - switch fn := node.Fun.(type) { - case *ast.SelectorExpr: - switch expr := fn.X.(type) { - case *ast.Ident: - if expr.Name == importName { - for _, name := range names { - if fn.Sel.Name == name { - return node, true - } - } + if callExpr, ok := n.(*ast.CallExpr); ok { + packageName, callName, err := GetCallInfo(callExpr, c) + if err != nil { + return nil, false + } + if packageName == importName { + for _, name := range names { + if callName == name { + return callExpr, true } } } @@ -95,19 +92,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a // node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write") // func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) { - switch callExpr := n.(type) { - case *ast.CallExpr: - switch fn := callExpr.Fun.(type) { - case *ast.SelectorExpr: - switch expr := fn.X.(type) { - case *ast.Ident: - t := ctx.Info.TypeOf(expr) - if t != nil && t.String() == requiredType { - for _, call := range calls { - if fn.Sel.Name == call { - return callExpr, true - } - } + if callExpr, ok := n.(*ast.CallExpr); ok { + typeName, callName, err := GetCallInfo(callExpr, ctx) + if err != nil { + return nil, false + } + if typeName == requiredType { + for _, call := range calls { + if call == callName { + return callExpr, true } } } @@ -171,3 +164,28 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) { } return nil, nil } + +// GetCallInfo returns the package or type and name associated with a +// call expression. +func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) { + switch node := n.(type) { + case *ast.CallExpr: + switch fn := node.Fun.(type) { + case *ast.SelectorExpr: + switch expr := fn.X.(type) { + case *ast.Ident: + if expr.Obj != nil && expr.Obj.Kind == ast.Var { + t := ctx.Info.TypeOf(expr) + if t != nil { + return t.String(), fn.Sel.Name, nil + } else { + return "undefined", fn.Sel.Name, fmt.Errorf("missing type info") + } + } else { + return expr.Name, fn.Sel.Name, nil + } + } + } + } + return "", "", fmt.Errorf("unable to determine call info") +} diff --git a/core/helpers_test.go b/core/helpers_test.go index beb7afb..89648e7 100644 --- a/core/helpers_test.go +++ b/core/helpers_test.go @@ -56,4 +56,16 @@ func TestMatchCallByType(t *testing.T) { if rule.matched != 1 || len(rule.callExpr) != 1 { t.Errorf("Expected to match a bytes.Buffer.Write call") } + + typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context) + if err != nil { + t.Errorf("Unable to resolve call info: %v\n", err) + } + if typeName != "bytes.Buffer" { + t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName) + } + if callName != "Write" { + t.Errorf("Expected: %s, Got: %s\n", "Write", callName) + } + }