mirror of
				https://github.com/securego/gosec.git
				synced 2025-10-30 23:47:56 +02:00 
			
		
		
		
	Extend helpers and call list
- Update call list to work directly with call expression - Add call list test cases - Extend helpers to add GetCallInfo to resolve call name and package or type if it's a var. - Add test cases to ensure correct behaviour
This commit is contained in:
		| @@ -13,16 +13,13 @@ | ||||
|  | ||||
| package core | ||||
|  | ||||
| type set map[string]bool | ||||
| import "go/ast" | ||||
|  | ||||
| type calls struct { | ||||
| 	matchAny  bool | ||||
| 	functions set | ||||
| } | ||||
| type set map[string]bool | ||||
|  | ||||
| /// CallList is used to check for usage of specific packages | ||||
| /// and functions. | ||||
| type CallList map[string]*calls | ||||
| type CallList map[string]set | ||||
|  | ||||
| /// NewCallList creates a new empty CallList | ||||
| func NewCallList() CallList { | ||||
| @@ -30,36 +27,39 @@ func NewCallList() CallList { | ||||
| } | ||||
|  | ||||
| /// NewCallListFor createse a call list using the package path | ||||
| func NewCallListFor(pkg string, funcs ...string) CallList { | ||||
| func NewCallListFor(selector string, idents ...string) CallList { | ||||
| 	c := NewCallList() | ||||
| 	if len(funcs) == 0 { | ||||
| 		c[pkg] = &calls{true, make(set)} | ||||
| 	} else { | ||||
| 		for _, fn := range funcs { | ||||
| 			c.Add(pkg, fn) | ||||
| 		} | ||||
| 	c[selector] = make(set) | ||||
| 	for _, ident := range idents { | ||||
| 		c.Add(selector, ident) | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| /// Add a new package and function to the call list | ||||
| func (c CallList) Add(pkg, fn string) { | ||||
| 	if cl, ok := c[pkg]; ok { | ||||
| 		if cl.matchAny { | ||||
| 			cl.matchAny = false | ||||
| 		} | ||||
| 	} else { | ||||
| 		c[pkg] = &calls{false, make(set)} | ||||
| /// Add a selector and call to the call list | ||||
| func (c CallList) Add(selector, ident string) { | ||||
| 	if _, ok := c[selector]; !ok { | ||||
| 		c[selector] = make(set) | ||||
| 	} | ||||
| 	c[pkg].functions[fn] = true | ||||
| 	c[selector][ident] = true | ||||
| } | ||||
|  | ||||
| /// Contains returns true if the package and function are | ||||
| /// members of this call list. | ||||
| func (c CallList) Contains(pkg, fn string) bool { | ||||
| 	if funcs, ok := c[pkg]; ok { | ||||
| 		_, ok = funcs.functions[fn] | ||||
| 		return ok || funcs.matchAny | ||||
| func (c CallList) Contains(selector, ident string) bool { | ||||
| 	if idents, ok := c[selector]; ok { | ||||
| 		_, found := idents[ident] | ||||
| 		return found | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| /// ContainsCallExpr resolves the call expression name and type | ||||
| /// or package and determines if it exists within the CallList | ||||
| func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool { | ||||
| 	selector, ident, err := GetCallInfo(n, ctx) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	return c.Contains(selector, ident) | ||||
| } | ||||
|   | ||||
							
								
								
									
										58
									
								
								core/call_list_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								core/call_list_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| package core | ||||
|  | ||||
| import ( | ||||
| 	"go/ast" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| type callListRule struct { | ||||
| 	MetaData | ||||
| 	callList CallList | ||||
| 	matched  int | ||||
| } | ||||
|  | ||||
| func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) { | ||||
| 	if r.callList.ContainsCallExpr(n, c) { | ||||
| 		r.matched += 1 | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func TestCallListContainsCallExpr(t *testing.T) { | ||||
| 	config := map[string]interface{}{"ignoreNosec": false} | ||||
| 	analyzer := NewAnalyzer(config, nil) | ||||
| 	rule := &callListRule{ | ||||
| 		MetaData: MetaData{ | ||||
| 			Severity:   Low, | ||||
| 			Confidence: Low, | ||||
| 			What:       "A dummy rule", | ||||
| 		}, | ||||
| 		callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"), | ||||
| 		matched:  0, | ||||
| 	} | ||||
| 	analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)}) | ||||
| 	source := ` | ||||
| 	package main | ||||
| 	import ( | ||||
| 		"bytes" | ||||
| 		"fmt" | ||||
| 	) | ||||
| 	func main() { | ||||
| 		var b bytes.Buffer | ||||
| 		b.Write([]byte("Hello ")) | ||||
| 		fmt.Fprintf(&b, "world!") | ||||
| 	}` | ||||
|  | ||||
| 	analyzer.ProcessSource("dummy.go", source) | ||||
| 	if rule.matched != 1 { | ||||
| 		t.Errorf("Expected to match a bytes.Buffer.Write call") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCallListContains(t *testing.T) { | ||||
| 	callList := NewCallList() | ||||
| 	callList.Add("fmt", "Printf") | ||||
| 	if !callList.Contains("fmt", "Printf") { | ||||
| 		t.Errorf("Expected call list to contain fmt.Printf") | ||||
| 	} | ||||
| } | ||||
| @@ -69,18 +69,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a | ||||
| 		importName = alias | ||||
| 	} | ||||
|  | ||||
| 	switch node := n.(type) { | ||||
| 	case *ast.CallExpr: | ||||
| 		switch fn := node.Fun.(type) { | ||||
| 		case *ast.SelectorExpr: | ||||
| 			switch expr := fn.X.(type) { | ||||
| 			case *ast.Ident: | ||||
| 				if expr.Name == importName { | ||||
| 					for _, name := range names { | ||||
| 						if fn.Sel.Name == name { | ||||
| 							return node, true | ||||
| 						} | ||||
| 					} | ||||
| 	if callExpr, ok := n.(*ast.CallExpr); ok { | ||||
| 		packageName, callName, err := GetCallInfo(callExpr, c) | ||||
| 		if err != nil { | ||||
| 			return nil, false | ||||
| 		} | ||||
| 		if packageName == importName { | ||||
| 			for _, name := range names { | ||||
| 				if callName == name { | ||||
| 					return callExpr, true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| @@ -95,19 +92,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a | ||||
| // 	node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write") | ||||
| // | ||||
| func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) { | ||||
| 	switch callExpr := n.(type) { | ||||
| 	case *ast.CallExpr: | ||||
| 		switch fn := callExpr.Fun.(type) { | ||||
| 		case *ast.SelectorExpr: | ||||
| 			switch expr := fn.X.(type) { | ||||
| 			case *ast.Ident: | ||||
| 				t := ctx.Info.TypeOf(expr) | ||||
| 				if t != nil && t.String() == requiredType { | ||||
| 					for _, call := range calls { | ||||
| 						if fn.Sel.Name == call { | ||||
| 							return callExpr, true | ||||
| 						} | ||||
| 					} | ||||
| 	if callExpr, ok := n.(*ast.CallExpr); ok { | ||||
| 		typeName, callName, err := GetCallInfo(callExpr, ctx) | ||||
| 		if err != nil { | ||||
| 			return nil, false | ||||
| 		} | ||||
| 		if typeName == requiredType { | ||||
| 			for _, call := range calls { | ||||
| 				if call == callName { | ||||
| 					return callExpr, true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| @@ -171,3 +164,28 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) { | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| // GetCallInfo returns the package or type and name  associated with a | ||||
| // call expression. | ||||
| func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) { | ||||
| 	switch node := n.(type) { | ||||
| 	case *ast.CallExpr: | ||||
| 		switch fn := node.Fun.(type) { | ||||
| 		case *ast.SelectorExpr: | ||||
| 			switch expr := fn.X.(type) { | ||||
| 			case *ast.Ident: | ||||
| 				if expr.Obj != nil && expr.Obj.Kind == ast.Var { | ||||
| 					t := ctx.Info.TypeOf(expr) | ||||
| 					if t != nil { | ||||
| 						return t.String(), fn.Sel.Name, nil | ||||
| 					} else { | ||||
| 						return "undefined", fn.Sel.Name, fmt.Errorf("missing type info") | ||||
| 					} | ||||
| 				} else { | ||||
| 					return expr.Name, fn.Sel.Name, nil | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return "", "", fmt.Errorf("unable to determine call info") | ||||
| } | ||||
|   | ||||
| @@ -56,4 +56,16 @@ func TestMatchCallByType(t *testing.T) { | ||||
| 	if rule.matched != 1 || len(rule.callExpr) != 1 { | ||||
| 		t.Errorf("Expected to match a bytes.Buffer.Write call") | ||||
| 	} | ||||
|  | ||||
| 	typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Unable to resolve call info: %v\n", err) | ||||
| 	} | ||||
| 	if typeName != "bytes.Buffer" { | ||||
| 		t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName) | ||||
| 	} | ||||
| 	if callName != "Write" { | ||||
| 		t.Errorf("Expected: %s, Got: %s\n", "Write", callName) | ||||
| 	} | ||||
|  | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user