diff --git a/helpers.go b/helpers.go index 40dc8e9..5e90f40 100644 --- a/helpers.go +++ b/helpers.go @@ -226,6 +226,27 @@ func GetIdentStringValues(ident *ast.Ident) []string { return values } +// GetBinaryExprOperands returns all operands of a binary expression by traversing +// the expression tree +func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node { + var traverse func(be *ast.BinaryExpr) + result := []ast.Node{} + traverse = func(be *ast.BinaryExpr) { + if lhs, ok := be.X.(*ast.BinaryExpr); ok { + traverse(lhs) + } else { + result = append(result, be.X) + } + if rhs, ok := be.Y.(*ast.BinaryExpr); ok { + traverse(rhs) + } else { + result = append(result, be.Y) + } + } + traverse(be) + return result +} + // GetImportedName returns the name used for the package within the // code. It will resolve aliases and ignores initialization only imports. func GetImportedName(path string, ctx *Context) (string, bool) { diff --git a/helpers_test.go b/helpers_test.go index 8fc7019..b13325a 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -229,4 +229,68 @@ var _ = Describe("Helpers", func() { Expect(result).Should(HaveKeyWithValue("fmt", "Println")) }) }) + Context("when getting binary expression operands", func() { + It("should return all operands of a binary experssion", func() { + pkg := testutils.NewTestPackage() + defer pkg.Close() + pkg.AddFile("main.go", ` + package main + + import( + "fmt" + ) + + func main() { + be := "test1" + "test2" + fmt.Println(be) + } + `) + ctx := pkg.CreateContext("main.go") + var be *ast.BinaryExpr + visitor := testutils.NewMockVisitor() + visitor.Context = ctx + visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool { + if expr, ok := n.(*ast.BinaryExpr); ok { + be = expr + } + return true + } + ast.Walk(visitor, ctx.Root) + + operands := gosec.GetBinaryExprOperands(be) + Expect(len(operands)).Should(Equal(2)) + }) + It("should return all operands of complex binary experssion", func() { + pkg := testutils.NewTestPackage() + defer pkg.Close() + pkg.AddFile("main.go", ` + package main + + import( + "fmt" + ) + + func main() { + be := "test1" + "test2" + "test3" + "test4" + fmt.Println(be) + } + `) + ctx := pkg.CreateContext("main.go") + var be *ast.BinaryExpr + visitor := testutils.NewMockVisitor() + visitor.Context = ctx + visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool { + if expr, ok := n.(*ast.BinaryExpr); ok { + if be == nil { + be = expr + } + } + return true + } + ast.Walk(visitor, ctx.Root) + + operands := gosec.GetBinaryExprOperands(be) + Expect(len(operands)).Should(Equal(4)) + }) + }) }) diff --git a/rules/sql.go b/rules/sql.go index 48671d5..127dec5 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -16,7 +16,6 @@ package rules import ( "go/ast" - "go/token" "regexp" "strings" @@ -82,20 +81,19 @@ func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gose } if be, ok := query.(*ast.BinaryExpr); ok { - // Skip all operations which aren't concatenation - if be.Op != token.ADD { - return nil, nil - } - if start, ok := be.X.(*ast.BasicLit); ok { + operands := gosec.GetBinaryExprOperands(be) + if start, ok := operands[0].(*ast.BasicLit); ok { if str, e := gosec.GetString(start); e == nil { if !s.MatchPatterns(str) { return nil, nil } - if _, ok := be.Y.(*ast.BasicLit); ok { - return nil, nil // string cat OK + } + for _, op := range operands[1:] { + if _, ok := op.(*ast.BasicLit); ok { + continue } - if second, ok := be.Y.(*ast.Ident); ok && s.checkObject(second, ctx) { - return nil, nil + if op, ok := op.(*ast.Ident); ok && s.checkObject(op, ctx) { + continue } return gosec.NewIssue(ctx, be, s.ID(), s.What, s.Severity, s.Confidence), nil } diff --git a/testutils/source.go b/testutils/source.go index bc98258..fcf9838 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -1026,6 +1026,23 @@ func main(){ panic(err) } }`}, 1, gosec.NewConfig()}, {[]string{` +// multiple string concatenation +package main +import ( + "database/sql" + "os" +) +func main(){ + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + rows, err := db.Query("SELECT * FROM foo" + "WHERE name = " + os.Args[1]) + if err != nil { + panic(err) + } + defer rows.Close() +}`}, 1, gosec.NewConfig()}, {[]string{` // false positive package main import (