1
0
mirror of https://github.com/mgechev/revive.git synced 2025-07-15 01:04:40 +02:00

refactor: moves code related to AST from rule.utils into astutils package (#1380)

Modifications summary:

* Moves AST-related functions from rule/utils.go to astutils/ast_utils.go (+ modifies function calls)
* Renames some of these AST-related functions
* Avoids instantiating a printer config at each call to astutils.GoFmt
* Uses astutils.IsIdent and astutils.IsPkgDotName when possible
This commit is contained in:
chavacava
2025-05-26 13:18:38 +02:00
committed by GitHub
parent 87b146c60e
commit 92f28cb5e1
35 changed files with 164 additions and 148 deletions

View File

@ -2,8 +2,12 @@
package astutils
import (
"bytes"
"fmt"
"go/ast"
"go/printer"
"go/token"
"regexp"
"slices"
)
@ -78,9 +82,80 @@ func getFieldTypeName(typ ast.Expr) string {
}
}
// IsStringLiteral returns true if the given expression is a string literal, false otherwise
// IsStringLiteral returns true if the given expression is a string literal, false otherwise.
func IsStringLiteral(e ast.Expr) bool {
sl, ok := e.(*ast.BasicLit)
return ok && sl.Kind == token.STRING
}
// IsCgoExported returns true if the given function declaration is exported as Cgo function, false otherwise.
func IsCgoExported(f *ast.FuncDecl) bool {
if f.Recv != nil || f.Doc == nil {
return false
}
cgoExport := regexp.MustCompile(fmt.Sprintf("(?m)^//export %s$", regexp.QuoteMeta(f.Name.Name)))
for _, c := range f.Doc.List {
if cgoExport.MatchString(c.Text) {
return true
}
}
return false
}
// IsIdent returns true if the given expression is the identifier with name ident, false otherwise.
func IsIdent(expr ast.Expr, ident string) bool {
id, ok := expr.(*ast.Ident)
return ok && id.Name == ident
}
// IsPkgDotName returns true if the given expression is a selector expression of the form <pkg>.<name>, false otherwise.
func IsPkgDotName(expr ast.Expr, pkg, name string) bool {
sel, ok := expr.(*ast.SelectorExpr)
return ok && IsIdent(sel.X, pkg) && IsIdent(sel.Sel, name)
}
// PickNodes yields a list of nodes by picking them from a sub-ast with root node n.
// Nodes are selected by applying the selector function
func PickNodes(n ast.Node, selector func(n ast.Node) bool) []ast.Node {
var result []ast.Node
if n == nil {
return result
}
onSelect := func(n ast.Node) {
result = append(result, n)
}
p := picker{selector: selector, onSelect: onSelect}
ast.Walk(p, n)
return result
}
type picker struct {
selector func(n ast.Node) bool
onSelect func(n ast.Node)
}
func (p picker) Visit(node ast.Node) ast.Visitor {
if p.selector == nil {
return nil
}
if p.selector(node) {
p.onSelect(node)
}
return p
}
var gofmtConfig = &printer.Config{Tabwidth: 8}
// GoFmt returns a string representation of an AST subtree.
func GoFmt(x any) string {
buf := bytes.Buffer{}
fs := token.NewFileSet()
gofmtConfig.Fprint(&buf, fs, x)
return buf.String()
}

View File

@ -5,6 +5,7 @@ import (
"go/token"
"go/types"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -76,9 +77,9 @@ func (w atomic) Visit(node ast.Node) ast.Visitor {
broken := false
if uarg, ok := arg.(*ast.UnaryExpr); ok && uarg.Op == token.AND {
broken = gofmt(left) == gofmt(uarg.X)
broken = astutils.GoFmt(left) == astutils.GoFmt(uarg.X)
} else if star, ok := left.(*ast.StarExpr); ok {
broken = gofmt(star.X) == gofmt(arg)
broken = astutils.GoFmt(star.X) == astutils.GoFmt(arg)
}
if broken {

View File

@ -6,6 +6,7 @@ import (
"strings"
"sync"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -190,7 +191,7 @@ func (w *lintConfusingNames) Visit(n ast.Node) ast.Visitor {
// Exclude naming warnings for functions that are exported to C but
// not exported in the Go API.
// See https://github.com/golang/lint/issues/144.
if ast.IsExported(v.Name.Name) || !isCgoExported(v) {
if ast.IsExported(v.Name.Name) || !astutils.IsCgoExported(v) {
checkMethodName(getStructName(v.Recv), v.Name, w)
}
case *ast.TypeSpec:

View File

@ -3,6 +3,7 @@ package rule
import (
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -28,7 +29,7 @@ func (*ConfusingResultsRule) Apply(file *lint.File, _ lint.Arguments) []lint.Fai
lastType := ""
for _, result := range funcDecl.Type.Results.List {
resultTypeName := gofmt(result.Type)
resultTypeName := astutils.GoFmt(result.Type)
if resultTypeName == lastType {
failures = append(failures, lint.Failure{

View File

@ -4,6 +4,7 @@ import (
"go/ast"
"go/token"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -40,7 +41,7 @@ func (w *lintConstantLogicalExpr) Visit(node ast.Node) ast.Visitor {
return w
}
subExpressionsAreNotEqual := gofmt(n.X) != gofmt(n.Y)
subExpressionsAreNotEqual := astutils.GoFmt(n.X) != astutils.GoFmt(n.Y)
if subExpressionsAreNotEqual {
return w // nothing to say
}

View File

@ -5,6 +5,7 @@ import (
"go/ast"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -28,7 +29,7 @@ func (r *ContextAsArgumentRule) Apply(file *lint.File, _ lint.Arguments) []lint.
// Flag any that show up after the first.
isCtxStillAllowed := true
for _, arg := range fnArgs {
argIsCtx := isPkgDot(arg.Type, "context", "Context")
argIsCtx := astutils.IsPkgDotName(arg.Type, "context", "Context")
if argIsCtx && !isCtxStillAllowed {
failures = append(failures, lint.Failure{
Node: arg,
@ -40,7 +41,7 @@ func (r *ContextAsArgumentRule) Apply(file *lint.File, _ lint.Arguments) []lint.
break // only flag one
}
typeName := gofmt(arg.Type)
typeName := astutils.GoFmt(arg.Type)
// a parameter of type context.Context is still allowed if the current arg type is in the allow types LookUpTable
_, isCtxStillAllowed = r.allowTypes[typeName]
}

View File

@ -5,6 +5,7 @@ import (
"go/ast"
"go/types"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -51,15 +52,7 @@ func (w lintContextKeyTypes) Visit(n ast.Node) ast.Visitor {
func checkContextKeyType(w lintContextKeyTypes, x *ast.CallExpr) {
f := w.file
sel, ok := x.Fun.(*ast.SelectorExpr)
if !ok {
return
}
pkg, ok := sel.X.(*ast.Ident)
if !ok || pkg.Name != "context" {
return
}
if sel.Sel.Name != "WithValue" {
if !astutils.IsPkgDotName(x.Fun, "context", "WithValue") {
return
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -111,7 +112,7 @@ func (w lintFunctionForDataRaces) Visit(node ast.Node) ast.Visitor {
return ok
}
ids := pick(funcLit.Body, selectIDs)
ids := astutils.PickNodes(funcLit.Body, selectIDs)
for _, id := range ids {
id := id.(*ast.Ident)
_, isRangeID := w.rangeIDs[id.Obj]

View File

@ -4,6 +4,7 @@ import (
"fmt"
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -106,7 +107,7 @@ func (w lintDeferRule) Visit(node ast.Node) ast.Visitor {
w.newFailure("return in a defer function has no effect", n, 1.0, lint.FailureCategoryLogic, deferOptionReturn)
}
case *ast.CallExpr:
isCallToRecover := isIdent(n.Fun, "recover")
isCallToRecover := astutils.IsIdent(n.Fun, "recover")
switch {
case !w.inADefer && isCallToRecover:
// func fn() { recover() }
@ -122,7 +123,7 @@ func (w lintDeferRule) Visit(node ast.Node) ast.Visitor {
}
return nil // no need to analyze the arguments of the function call
case *ast.DeferStmt:
if isIdent(n.Call.Fun, "recover") {
if astutils.IsIdent(n.Call.Fun, "recover") {
// defer recover()
//
// confidence is not truly 1 because this could be in a correctly-deferred func,

View File

@ -4,6 +4,7 @@ import (
"fmt"
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -101,8 +102,7 @@ func (r *EnforceMapStyleRule) Apply(file *lint.File, _ lint.Arguments) []lint.Fa
return true
}
ident, ok := v.Fun.(*ast.Ident)
if !ok || ident.Name != "make" {
if !astutils.IsIdent(v.Fun, "make") {
return true
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -130,8 +131,8 @@ func (r *EnforceRepeatedArgTypeStyleRule) Apply(file *lint.File, _ lint.Argument
if fn.Type.Params != nil {
var prevType ast.Expr
for _, field := range fn.Type.Params.List {
prevTypeStr := gofmt(prevType)
currentTypeStr := gofmt(field.Type)
prevTypeStr := astutils.GoFmt(prevType)
currentTypeStr := astutils.GoFmt(field.Type)
if currentTypeStr == prevTypeStr {
failures = append(failures, lint.Failure{
Confidence: 1,
@ -163,8 +164,8 @@ func (r *EnforceRepeatedArgTypeStyleRule) Apply(file *lint.File, _ lint.Argument
if fn.Type.Results != nil {
var prevType ast.Expr
for _, field := range fn.Type.Results.List {
prevTypeStr := gofmt(prevType)
currentTypeStr := gofmt(field.Type)
prevTypeStr := astutils.GoFmt(prevType)
currentTypeStr := astutils.GoFmt(field.Type)
if field.Names != nil && currentTypeStr == prevTypeStr {
failures = append(failures, lint.Failure{
Confidence: 1,

View File

@ -4,6 +4,7 @@ import (
"fmt"
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -117,8 +118,7 @@ func (r *EnforceSliceStyleRule) Apply(file *lint.File, _ lint.Arguments) []lint.
return true
}
ident, ok := v.Fun.(*ast.Ident)
if !ok || ident.Name != "make" {
if !astutils.IsIdent(v.Fun, "make") {
return true
}

View File

@ -6,6 +6,7 @@ import (
"go/token"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -56,7 +57,7 @@ func (w lintErrors) Visit(_ ast.Node) ast.Visitor {
if !ok {
continue
}
if !isPkgDot(ce.Fun, "errors", "New") && !isPkgDot(ce.Fun, "fmt", "Errorf") {
if !astutils.IsPkgDotName(ce.Fun, "errors", "New") && !astutils.IsPkgDotName(ce.Fun, "fmt", "Errorf") {
continue
}

View File

@ -3,6 +3,7 @@ package rule
import (
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -21,7 +22,7 @@ func (*ErrorReturnRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure
}
funcResults := funcDecl.Type.Results.List
isLastResultError := isIdent(funcResults[len(funcResults)-1].Type, "error")
isLastResultError := astutils.IsIdent(funcResults[len(funcResults)-1].Type, "error")
if isLastResultError {
continue
}
@ -29,7 +30,7 @@ func (*ErrorReturnRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure
// An error return parameter should be the last parameter.
// Flag any error parameters found before the last.
for _, r := range funcResults[:len(funcResults)-1] {
if isIdent(r.Type, "error") {
if astutils.IsIdent(r.Type, "error") {
failures = append(failures, lint.Failure{
Category: lint.FailureCategoryStyle,
Confidence: 0.9,

View File

@ -6,6 +6,7 @@ import (
"regexp"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -47,7 +48,7 @@ func (w lintErrorf) Visit(n ast.Node) ast.Visitor {
if !ok || len(ce.Args) != 1 {
return w
}
isErrorsNew := isPkgDot(ce.Fun, "errors", "New")
isErrorsNew := astutils.IsPkgDotName(ce.Fun, "errors", "New")
var isTestingError bool
se, ok := ce.Fun.(*ast.SelectorExpr)
if ok && se.Sel.Name == "Error" {
@ -60,7 +61,7 @@ func (w lintErrorf) Visit(n ast.Node) ast.Visitor {
}
arg := ce.Args[0]
ce, ok = arg.(*ast.CallExpr)
if !ok || !isPkgDot(ce.Fun, "fmt", "Sprintf") {
if !ok || !astutils.IsPkgDotName(ce.Fun, "fmt", "Sprintf") {
return w
}
errorfPrefix := "fmt"

View File

@ -4,6 +4,7 @@ import (
"fmt"
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -26,7 +27,7 @@ func (*FlagParamRule) Apply(file *lint.File, _ lint.Arguments) []lint.Failure {
boolParams := map[string]struct{}{}
for _, param := range fd.Type.Params.List {
if !isIdent(param.Type, "bool") {
if !astutils.IsIdent(param.Type, "bool") {
continue
}
@ -77,7 +78,7 @@ func (w conditionVisitor) Visit(node ast.Node) ast.Visitor {
return w.idents[ident.Name] == struct{}{}
}
uses := pick(ifStmt.Cond, findUsesOfIdents)
uses := astutils.PickNodes(ifStmt.Cond, findUsesOfIdents)
if len(uses) < 1 {
return w

View File

@ -3,6 +3,7 @@ package rule
import (
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -64,12 +65,12 @@ func (*lintIdenticalBranches) identicalBranches(branches []*ast.BlockStmt) bool
return false // only one branch to compare thus we return
}
referenceBranch := gofmt(branches[0])
referenceBranch := astutils.GoFmt(branches[0])
referenceBranchSize := len(branches[0].List)
for i := 1; i < len(branches); i++ {
currentBranch := branches[i]
currentBranchSize := len(currentBranch.List)
if currentBranchSize != referenceBranchSize || gofmt(currentBranch) != referenceBranch {
if currentBranchSize != referenceBranchSize || astutils.GoFmt(currentBranch) != referenceBranch {
return false
}
}

View File

@ -5,6 +5,7 @@ import (
"go/token"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -116,7 +117,7 @@ func (r *ModifiesValRecRule) findReturnReceiverStatements(receiverName string, t
return false
}
return pick(target, finder)
return astutils.PickNodes(target, finder)
}
func (r *ModifiesValRecRule) mustSkip(receiver *ast.Field, pkg *lint.Package) bool {
@ -179,5 +180,5 @@ func (r *ModifiesValRecRule) getReceiverModifications(receiverName string, funcB
return false
}
return pick(funcBody, receiverModificationFinder)
return astutils.PickNodes(funcBody, receiverModificationFinder)
}

View File

@ -5,6 +5,7 @@ import (
"go/ast"
"go/token"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -63,20 +64,20 @@ func (w lintOptimizeOperandsOrderExpr) Visit(node ast.Node) ast.Visitor {
}
// check if the left sub-expression contains a function call
nodes := pick(binExpr.X, isCaller)
nodes := astutils.PickNodes(binExpr.X, isCaller)
if len(nodes) < 1 {
return w
}
// check if the right sub-expression does not contain a function call
nodes = pick(binExpr.Y, isCaller)
nodes = astutils.PickNodes(binExpr.Y, isCaller)
if len(nodes) > 0 {
return w
}
newExpr := ast.BinaryExpr{X: binExpr.Y, Y: binExpr.X, Op: binExpr.Op}
w.onFailure(lint.Failure{
Failure: fmt.Sprintf("for better performance '%v' might be rewritten as '%v'", gofmt(binExpr), gofmt(&newExpr)),
Failure: fmt.Sprintf("for better performance '%v' might be rewritten as '%v'", astutils.GoFmt(binExpr), astutils.GoFmt(&newExpr)),
Node: node,
Category: lint.FailureCategoryOptimization,
Confidence: 0.3,

View File

@ -5,6 +5,7 @@ import (
"go/ast"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -43,7 +44,7 @@ func (w *lintRanges) Visit(node ast.Node) ast.Visitor {
// for x = range m { ... }
return w // single var form
}
if !isIdent(rs.Value, "_") {
if !astutils.IsIdent(rs.Value, "_") {
// for ?, y = range m { ... }
return w
}

View File

@ -8,6 +8,7 @@ import (
"strings"
"github.com/fatih/structtag"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -390,7 +391,7 @@ func checkCompoundPropertiesOption(key, value string, fieldType ast.Expr, seenOp
return msgTypeMismatch, false
}
case "layout":
if gofmt(fieldType) != "time.Time" {
if astutils.GoFmt(fieldType) != "time.Time" {
return "layout option is only applicable to fields of type time.Time", false
}
}

View File

@ -8,6 +8,7 @@ import (
"strconv"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
"github.com/mgechev/revive/logging"
)
@ -61,7 +62,7 @@ func (w lintTimeDate) Visit(n ast.Node) ast.Visitor {
if !ok || len(ce.Args) != timeDateArity {
return w
}
if !isPkgDot(ce.Fun, "time", "Date") {
if !astutils.IsPkgDotName(ce.Fun, "time", "Date") {
return w
}

View File

@ -5,6 +5,7 @@ import (
"go/ast"
"go/token"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -66,7 +67,7 @@ func (l *lintTimeEqual) Visit(node ast.Node) ast.Visitor {
Category: lint.FailureCategoryTime,
Confidence: 1,
Node: node,
Failure: fmt.Sprintf("use %s%s.Equal(%s) instead of %q operator", negateStr, gofmt(expr.X), gofmt(expr.Y), expr.Op),
Failure: fmt.Sprintf("use %s%s.Equal(%s) instead of %q operator", negateStr, astutils.GoFmt(expr.X), astutils.GoFmt(expr.Y), expr.Op),
})
return l

View File

@ -5,6 +5,7 @@ import (
"fmt"
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -70,12 +71,7 @@ type lintUncheckedTypeAssertion struct {
}
func isIgnored(e ast.Expr) bool {
ident, ok := e.(*ast.Ident)
if !ok {
return false
}
return ident.Name == "_"
return astutils.IsIdent(e, "_")
}
func isTypeSwitch(e *ast.TypeAssertExpr) bool {
@ -177,7 +173,7 @@ func (w *lintUncheckedTypeAssertion) Visit(node ast.Node) ast.Visitor {
}
func (w *lintUncheckedTypeAssertion) addFailure(n *ast.TypeAssertExpr, why string) {
s := fmt.Sprintf("type cast result is unchecked in %v - %s", gofmt(n), why)
s := fmt.Sprintf("type cast result is unchecked in %v - %s", astutils.GoFmt(n), why)
w.onFailure(lint.Failure{
Category: lint.FailureCategoryBadPractice,
Confidence: 1,

View File

@ -3,6 +3,7 @@ package rule
import (
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -174,7 +175,7 @@ func (*lintUnconditionalRecursionRule) hasControlExit(node ast.Node) bool {
case *ast.ReturnStmt:
return true
case *ast.CallExpr:
if isIdent(n.Fun, "panic") {
if astutils.IsIdent(n.Fun, "panic") {
return true
}
se, ok := n.Fun.(*ast.SelectorExpr)
@ -197,5 +198,5 @@ func (*lintUnconditionalRecursionRule) hasControlExit(node ast.Node) bool {
return false
}
return len(pick(node, isExit)) != 0
return len(astutils.PickNodes(node, isExit)) != 0
}

View File

@ -8,6 +8,7 @@ import (
"regexp"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -122,7 +123,7 @@ func (w *lintUnhandledErrors) addFailure(n *ast.CallExpr) {
func (w *lintUnhandledErrors) funcName(call *ast.CallExpr) string {
fn, ok := w.getFunc(call)
if !ok {
return gofmt(call.Fun)
return astutils.GoFmt(call.Fun)
}
name := fn.FullName()

View File

@ -94,7 +94,7 @@ func (w lintUnnecessaryFormat) Visit(n ast.Node) ast.Visitor {
return w
}
funcName := gofmt(ce.Fun)
funcName := astutils.GoFmt(ce.Fun)
spec, ok := formattingFuncs[funcName]
if !ok {
return w
@ -110,7 +110,7 @@ func (w lintUnnecessaryFormat) Visit(n ast.Node) ast.Visitor {
return w
}
format := gofmt(arg)
format := astutils.GoFmt(arg)
if strings.Contains(format, `%`) {
return w

View File

@ -5,6 +5,7 @@ import (
"go/ast"
"regexp"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -122,7 +123,7 @@ func (w lintUnusedParamRule) Visit(node ast.Node) ast.Visitor {
return false
}
_ = pick(funcBody, fselect)
_ = astutils.PickNodes(funcBody, fselect)
for _, p := range funcType.Params.List {
for _, n := range p.Names {

View File

@ -5,6 +5,7 @@ import (
"go/ast"
"regexp"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -82,7 +83,7 @@ func (r *UnusedReceiverRule) Apply(file *lint.File, _ lint.Arguments) []lint.Fai
return isAnID && ident.Obj == recID.Obj
}
receiverUses := pick(funcDecl.Body, selectReceiverUses)
receiverUses := astutils.PickNodes(funcDecl.Body, selectReceiverUses)
if len(receiverUses) > 0 {
continue // the receiver is referenced in the func body

View File

@ -3,6 +3,7 @@ package rule
import (
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -39,7 +40,7 @@ func (w lintFmtErrorf) Visit(n ast.Node) ast.Visitor {
return w // not a function call
}
isFmtErrorf := isPkgDot(funcCall.Fun, "fmt", "Errorf")
isFmtErrorf := astutils.IsPkgDotName(funcCall.Fun, "fmt", "Errorf")
if !isFmtErrorf {
return w // not a call to fmt.Errorf
}

View File

@ -1,10 +1,7 @@
package rule
import (
"bytes"
"fmt"
"go/ast"
"go/printer"
"go/token"
"regexp"
"strings"
@ -26,31 +23,6 @@ var exitFunctions = map[string]map[string]bool{
},
}
func isCgoExported(f *ast.FuncDecl) bool {
if f.Recv != nil || f.Doc == nil {
return false
}
cgoExport := regexp.MustCompile(fmt.Sprintf("(?m)^//export %s$", regexp.QuoteMeta(f.Name.Name)))
for _, c := range f.Doc.List {
if cgoExport.MatchString(c.Text) {
return true
}
}
return false
}
func isIdent(expr ast.Expr, ident string) bool {
id, ok := expr.(*ast.Ident)
return ok && id.Name == ident
}
// isPkgDot checks if the expression is <pkg>.<name>
func isPkgDot(expr ast.Expr, pkg, name string) bool {
sel, ok := expr.(*ast.SelectorExpr)
return ok && isIdent(sel.X, pkg) && isIdent(sel.Sel, name)
}
func srcLine(src []byte, p token.Position) string {
// Run to end of line in both directions if not at line start/end.
lo, hi := p.Offset, p.Offset+1
@ -63,48 +35,6 @@ func srcLine(src []byte, p token.Position) string {
return string(src[lo:hi])
}
// pick yields a list of nodes by picking them from a sub-ast with root node n.
// Nodes are selected by applying the fselect function
func pick(n ast.Node, fselect func(n ast.Node) bool) []ast.Node {
var result []ast.Node
if n == nil {
return result
}
onSelect := func(n ast.Node) {
result = append(result, n)
}
p := picker{fselect: fselect, onSelect: onSelect}
ast.Walk(p, n)
return result
}
type picker struct {
fselect func(n ast.Node) bool
onSelect func(n ast.Node)
}
func (p picker) Visit(node ast.Node) ast.Visitor {
if p.fselect == nil {
return nil
}
if p.fselect(node) {
p.onSelect(node)
}
return p
}
// gofmt returns a string representation of an AST subtree.
func gofmt(x any) string {
buf := bytes.Buffer{}
fs := token.NewFileSet()
printer.Fprint(&buf, fs, x)
return buf.String()
}
// checkNumberOfArguments fails if the given number of arguments is not, at least, the expected one
func checkNumberOfArguments(expected int, args lint.Arguments, ruleName string) error {
if len(args) < expected {

View File

@ -7,6 +7,7 @@ import (
"go/types"
"strings"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -79,14 +80,14 @@ func (w *lintVarDeclarations) Visit(node ast.Node) ast.Visitor {
rhs := v.Values[0]
// An underscore var appears in a common idiom for compile-time interface satisfaction,
// as in "var _ Interface = (*Concrete)(nil)".
if isIdent(v.Names[0], "_") {
if astutils.IsIdent(v.Names[0], "_") {
return nil
}
// If the RHS is a isZero value, suggest dropping it.
isZero := false
if lit, ok := rhs.(*ast.BasicLit); ok {
isZero = isZeroValue(lit.Value, v.Type)
} else if isIdent(rhs, "nil") {
} else if astutils.IsIdent(rhs, "nil") {
isZero = true
}
if isZero {
@ -122,7 +123,7 @@ func (w *lintVarDeclarations) Visit(node ast.Node) ast.Visitor {
return nil
}
// If the RHS is an untyped const, only warn if the LHS type is its default type.
if defType, ok := w.file.IsUntypedConst(rhs); ok && !isIdent(v.Type, defType) {
if defType, ok := w.file.IsUntypedConst(rhs); ok && !astutils.IsIdent(v.Type, defType) {
return nil
}

View File

@ -8,6 +8,7 @@ import (
"strings"
"sync"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -267,7 +268,7 @@ func (w *lintNames) Visit(n ast.Node) ast.Visitor {
// Exclude naming warnings for functions that are exported to C but
// not exported in the Go API.
// See https://github.com/golang/lint/issues/144.
if ast.IsExported(v.Name.Name) || !isCgoExported(v) {
if ast.IsExported(v.Name.Name) || !astutils.IsCgoExported(v) {
w.check(v.Name, thing)
}

View File

@ -3,6 +3,7 @@ package rule
import (
"go/ast"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
)
@ -40,7 +41,7 @@ func (w lintWaitGroupByValueRule) Visit(node ast.Node) ast.Visitor {
// Check all function parameters
for _, field := range fd.Type.Params.List {
if !w.isWaitGroup(field.Type) {
if !astutils.IsPkgDotName(field.Type, "sync", "WaitGroup") {
continue
}
@ -53,14 +54,3 @@ func (w lintWaitGroupByValueRule) Visit(node ast.Node) ast.Visitor {
return nil // skip visiting function body
}
func (lintWaitGroupByValueRule) isWaitGroup(ft ast.Expr) bool {
se, ok := ft.(*ast.SelectorExpr)
if !ok {
return false
}
x, _ := se.X.(*ast.Ident)
sel := se.Sel.Name
return x.Name == "sync" && sel == "WaitGroup"
}

View File

@ -1,5 +1,11 @@
package fixtures
import (
ast "go/ast"
"github.com/mgechev/revive/lint"
)
func foo(a, b, c, d int) {
switch n := node.(type) { // MATCH /switch with only one case can be replaced by an if-then/
case *ast.SwitchStmt:
@ -7,7 +13,7 @@ func foo(a, b, c, d int) {
_, ok := n.(*ast.CaseClause)
return ok
}
cases := pick(n.Body, caseSelector, nil)
cases := astutils.PickNodes(n.Body, caseSelector, nil)
if len(cases) == 1 {
cs, ok := cases[0].(*ast.CaseClause)
if ok && len(cs.List) == 1 {