diff --git a/core/call_list.go b/core/call_list.go index 9ace433..2002024 100644 --- a/core/call_list.go +++ b/core/call_list.go @@ -13,7 +13,9 @@ package core -import "go/ast" +import ( + "go/ast" +) type set map[string]bool @@ -26,14 +28,11 @@ func NewCallList() CallList { return make(CallList) } -/// NewCallListFor createse a call list using the package path -func NewCallListFor(selector string, idents ...string) CallList { - c := NewCallList() - c[selector] = make(set) +/// AddAll will add several calls to the call list at once +func (c CallList) AddAll(selector string, idents ...string) { for _, ident := range idents { c.Add(selector, ident) } - return c } /// Add a selector and call to the call list @@ -61,5 +60,14 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool { if err != nil { return false } - return c.Contains(selector, ident) + // Try direct resolution + if c.Contains(selector, ident) { + return true + } + + // Also support explicit path + if path, ok := GetImportPath(selector, ctx); ok { + return c.Contains(path, ident) + } + return false } diff --git a/core/call_list_test.go b/core/call_list_test.go index aa4f67c..ef58293 100644 --- a/core/call_list_test.go +++ b/core/call_list_test.go @@ -21,13 +21,15 @@ func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) { func TestCallListContainsCallExpr(t *testing.T) { config := map[string]interface{}{"ignoreNosec": false} analyzer := NewAnalyzer(config, nil) + calls := NewCallList() + calls.AddAll("bytes.Buffer", "Write", "WriteTo") rule := &callListRule{ MetaData: MetaData{ Severity: Low, Confidence: Low, What: "A dummy rule", }, - callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"), + callList: calls, matched: 0, } analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)}) diff --git a/core/helpers.go b/core/helpers.go index 79a1617..d42ceca 100644 --- a/core/helpers.go +++ b/core/helpers.go @@ -56,25 +56,17 @@ func MatchCall(n ast.Node, r *regexp.Regexp) *ast.CallExpr { // func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*ast.CallExpr, bool) { - importName, imported := c.Imports.Imported[pkg] - if !imported { + importedName, found := GetImportedName(pkg, c) + if !found { return nil, false } - if _, initonly := c.Imports.InitOnly[pkg]; initonly { - return nil, false - } - - if alias, ok := c.Imports.Aliased[pkg]; ok { - importName = alias - } - if callExpr, ok := n.(*ast.CallExpr); ok { packageName, callName, err := GetCallInfo(callExpr, c) if err != nil { return nil, false } - if packageName == importName { + if packageName == importedName { for _, name := range names { if callName == name { return callExpr, true @@ -185,7 +177,38 @@ func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) { return expr.Name, fn.Sel.Name, nil } } + case *ast.Ident: + return ctx.Pkg.Name(), fn.Name, nil } } return "", "", fmt.Errorf("unable to determine call info") } + +// GetImportedName returns the name used for the package within the +// code. It will resolve aliases and ignores initalization only imports. +func GetImportedName(path string, ctx *Context) (string, bool) { + importName, imported := ctx.Imports.Imported[path] + if !imported { + return "", false + } + + if _, initonly := ctx.Imports.InitOnly[path]; initonly { + return "", false + } + + if alias, ok := ctx.Imports.Aliased[path]; ok { + importName = alias + } + return importName, true +} + +// GetImportPath resolves the full import path of an identifer based on +// the imports in the current context. +func GetImportPath(name string, ctx *Context) (string, bool) { + for path, _ := range ctx.Imports.Imported { + if imported, ok := GetImportedName(path, ctx); ok && imported == name { + return path, true + } + } + return "", false +} diff --git a/rules/errors.go b/rules/errors.go index 4490312..2bf61c9 100644 --- a/rules/errors.go +++ b/rules/errors.go @@ -15,47 +15,82 @@ package rules import ( + gas "github.com/GoASTScanner/gas/core" "go/ast" "go/types" - "reflect" - - gas "github.com/GoASTScanner/gas/core" ) type NoErrorCheck struct { gas.MetaData + whitelist gas.CallList } -func (r *NoErrorCheck) Match(n ast.Node, c *gas.Context) (gi *gas.Issue, err error) { - if node, ok := n.(*ast.AssignStmt); ok { - sel := reflect.TypeOf(&ast.CallExpr{}) - if call, ok := gas.SimpleSelect(node.Rhs[0], sel).(*ast.CallExpr); ok { - if t := c.Info.Types[call].Type; t != nil { - if typeVal, typeErr := t.(*types.Tuple); typeErr { - for i := 0; i < typeVal.Len(); i++ { - if typeVal.At(i).Type().String() == "error" { // TODO(tkelsey): is there a better way? - if id, ok := node.Lhs[i].(*ast.Ident); ok && id.Name == "_" { - return gas.NewIssue(c, n, r.What, r.Severity, r.Confidence), nil - } - } - } - } else if t.String() == "error" { // TODO(tkelsey): is there a better way? - if id, ok := node.Lhs[0].(*ast.Ident); ok && id.Name == "_" { - return gas.NewIssue(c, n, r.What, r.Severity, r.Confidence), nil - } +func returnsError(callExpr *ast.CallExpr, ctx *gas.Context) int { + if tv := ctx.Info.TypeOf(callExpr); tv != nil { + switch t := tv.(type) { + case *types.Tuple: + for pos := 0; pos < t.Len(); pos += 1 { + variable := t.At(pos) + if variable != nil && variable.Type().String() == "error" { + return pos } } + case *types.Named: + if t.String() == "error" { + return 0 + } + } + } + return -1 +} + +func (r *NoErrorCheck) Match(n ast.Node, ctx *gas.Context) (*gas.Issue, error) { + switch stmt := n.(type) { + case *ast.AssignStmt: + for _, expr := range stmt.Rhs { + if callExpr, ok := expr.(*ast.CallExpr); ok && !r.whitelist.ContainsCallExpr(callExpr, ctx) { + pos := returnsError(callExpr, ctx) + if pos < 0 || pos >= len(stmt.Lhs) { + return nil, nil + } + if id, ok := stmt.Lhs[pos].(*ast.Ident); ok && id.Name == "_" { + return gas.NewIssue(ctx, n, r.What, r.Severity, r.Confidence), nil + } + } + } + case *ast.ExprStmt: + if callExpr, ok := stmt.X.(*ast.CallExpr); ok && !r.whitelist.ContainsCallExpr(callExpr, ctx) { + pos := returnsError(callExpr, ctx) + if pos >= 0 { + return gas.NewIssue(ctx, n, r.What, r.Severity, r.Confidence), nil + } } } return nil, nil } func NewNoErrorCheck(conf map[string]interface{}) (gas.Rule, []ast.Node) { + + // TODO(gm) Come up with sensible defaults here. Or flip it to use a + // black list instead. + whitelist := gas.NewCallList() + whitelist.AddAll("bytes.Buffer", "Write", "WriteByte", "WriteRune", "WriteString") + whitelist.AddAll("fmt", "Print", "Printf", "Println") + whitelist.Add("io.PipeWriter", "CloseWithError") + + if configured, ok := conf["G104"]; ok { + if whitelisted, ok := configured.(map[string][]string); ok { + for key, val := range whitelisted { + whitelist.AddAll(key, val...) + } + } + } return &NoErrorCheck{ MetaData: gas.MetaData{ Severity: gas.Low, Confidence: gas.High, What: "Errors unhandled.", }, - }, []ast.Node{(*ast.AssignStmt)(nil)} + whitelist: whitelist, + }, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} } diff --git a/rules/errors_test.go b/rules/errors_test.go index 4ae502b..d4a07a0 100644 --- a/rules/errors_test.go +++ b/rules/errors_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -28,17 +28,17 @@ func TestErrorsMulti(t *testing.T) { issues := gasTestRunner( `package main - import ( - "fmt" - ) + import ( + "fmt" + ) - func test() (val int, err error) { - return 0, nil - } + func test() (val int, err error) { + return 0, nil + } - func main() { - v, _ := test() - }`, analyzer) + func main() { + v, _ := test() + }`, analyzer) checkTestResults(t, issues, 1, "Errors unhandled") } @@ -51,19 +51,30 @@ func TestErrorsSingle(t *testing.T) { issues := gasTestRunner( `package main - import ( - "fmt" - ) + import ( + "fmt" + ) - func test() (err error) { - return nil - } + func a() error { + return fmt.Errorf("This is an error") + } - func main() { - _ := test() - }`, analyzer) + func b() { + fmt.Println("b") + } - checkTestResults(t, issues, 1, "Errors unhandled") + func c() string { + return fmt.Sprintf("This isn't anything") + } + + func main() { + _ = a() + a() + b() + _ = c() + c() + }`, analyzer) + checkTestResults(t, issues, 2, "Errors unhandled") } func TestErrorsGood(t *testing.T) { @@ -74,17 +85,56 @@ func TestErrorsGood(t *testing.T) { issues := gasTestRunner( `package main - import ( - "fmt" - ) + import ( + "fmt" + ) - func test() err error { - return 0, nil - } + func test() err error { + return 0, nil + } - func main() { - e := test() - }`, analyzer) + func main() { + e := test() + }`, analyzer) checkTestResults(t, issues, 0, "") } + +func TestErrorsWhitelisted(t *testing.T) { + config := map[string]interface{}{ + "ignoreNosec": false, + "G104": map[string][]string{ + "compress/zlib": []string{"NewReader"}, + "io": []string{"Copy"}, + }, + } + analyzer := gas.NewAnalyzer(config, nil) + analyzer.AddRule(NewNoErrorCheck(config)) + source := `package main + import ( + "io" + "os" + "fmt" + "bytes" + "compress/zlib" + ) + + func a() error { + return fmt.Errorf("This is an error ok") + } + + func main() { + // Expect at least one failure + _ = a() + + var b bytes.Buffer + // Default whitelist + nbytes, _ := b.Write([]byte("Hello ")) + + // Whitelisted via configuration + r, _ := zlib.NewReader(&b) + io.Copy(os.Stdout, r) + }` + issues := gasTestRunner(source, analyzer) + checkTestResults(t, issues, 1, "Errors unhandled") +}