1
0
mirror of https://github.com/mgechev/revive.git synced 2025-10-30 23:37:49 +02:00

feature: detect identical-branches in switch statements (#1448)

This commit is contained in:
chavacava
2025-07-30 15:55:03 +02:00
committed by GitHub
parent 9fc7dc7d77
commit 68ac5514f5
5 changed files with 169 additions and 73 deletions

View File

@@ -3,6 +3,8 @@ package astutils
import (
"bytes"
"crypto/md5"
"encoding/hex"
"fmt"
"go/ast"
"go/printer"
@@ -159,3 +161,13 @@ func GoFmt(x any) string {
gofmtConfig.Fprint(&buf, fs, x)
return buf.String()
}
// NodeHash yields the MD5 hash of the given AST node.
func NodeHash(node ast.Node) string {
hasher := func(in string) string {
binHash := md5.Sum([]byte(in))
return hex.EncodeToString(binHash[:])
}
str := GoFmt(node)
return hasher(str)
}

View File

@@ -54,9 +54,7 @@ func (k BranchKind) Branch() Branch { return Branch{BranchKind: k} }
// String returns a brief string representation.
func (k BranchKind) String() string {
switch k {
case Empty:
return ""
case Regular:
case Empty, Regular:
return ""
case Return:
return "return"

View File

@@ -1,10 +1,9 @@
package rule
import (
"crypto/md5"
"encoding/hex"
"fmt"
"go/ast"
"go/token"
"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/lint"
@@ -67,15 +66,67 @@ func (w *lintIdenticalBranches) Visit(node ast.Node) ast.Visitor {
return w
case *ast.SwitchStmt:
// TODO later
return w
if n.Tag == nil {
return w // do not lint untagged switches (order of case evaluation might be important)
}
w.lintSwitch(n)
return nil // switch branches already analyzed
default:
return w
}
}
func (*lintIdenticalBranches) isIfElseIf(n *ast.IfStmt) bool {
_, ok := n.Else.(*ast.IfStmt)
func (w *lintIdenticalBranches) lintSwitch(switchStmt *ast.SwitchStmt) {
doesFallthrough := func(stmts []ast.Stmt) bool {
if len(stmts) == 0 {
return false
}
ft, ok := stmts[len(stmts)-1].(*ast.BranchStmt)
return ok && ft.Tok == token.FALLTHROUGH
}
hashes := map[string]int{} // map hash(branch code) -> branch line
for _, cc := range switchStmt.Body.List {
caseClause := cc.(*ast.CaseClause)
if doesFallthrough(caseClause.Body) {
continue // skip fallthrough branches
}
branch := &ast.BlockStmt{
List: caseClause.Body,
}
hash := astutils.NodeHash(branch)
branchLine := w.file.ToPosition(caseClause.Pos()).Line
if matchLine, ok := hashes[hash]; ok {
w.newFailure(
switchStmt,
fmt.Sprintf(`"switch" with identical branches (lines %d and %d)`, matchLine, branchLine),
1.0,
)
}
hashes[hash] = branchLine
w.walkBranch(branch)
}
}
// walkBranch analyzes the given branch.
func (w *lintIdenticalBranches) walkBranch(branch ast.Stmt) {
if branch == nil {
return
}
walker := &lintIdenticalBranches{
onFailure: w.onFailure,
file: w.file,
}
ast.Walk(walker, branch)
}
func (*lintIdenticalBranches) isIfElseIf(node *ast.IfStmt) bool {
_, ok := node.Else.(*ast.IfStmt)
return ok
}
@@ -105,6 +156,15 @@ func (*lintIdenticalBranches) identicalBranches(body, elseBranch *ast.BlockStmt)
return bodyStr == elseStr
}
func (w *lintIdenticalBranches) newFailure(node ast.Node, msg string, confidence float64) {
w.onFailure(lint.Failure{
Confidence: confidence,
Node: node,
Category: lint.FailureCategoryLogic,
Failure: msg,
})
}
type lintIfChainIdenticalBranches struct {
file *lint.File // only necessary to retrieve the line number of branches
onFailure func(lint.Failure)
@@ -139,7 +199,7 @@ func (w *lintIfChainIdenticalBranches) Visit(node ast.Node) ast.Visitor {
}
// recursively analyze the then-branch
w.walkBranch(n.Body)
w.rootWalker.walkBranch(n.Body)
if n.Init == nil { // only check if without initialization to avoid false positives
w.addBranch(n.Body)
@@ -154,14 +214,13 @@ func (w *lintIfChainIdenticalBranches) Visit(node ast.Node) ast.Visitor {
w.Visit(chainedIf)
} else {
w.addBranch(n.Else)
w.walkBranch(n.Else)
w.rootWalker.walkBranch(n.Else)
}
}
identicalBranches := w.identicalBranches(w.branches)
for _, branchPair := range identicalBranches {
branchLines := w.getStmtLines(branchPair)
msg := fmt.Sprintf(`"if...else if" chain with identical branches (lines %v)`, branchLines)
msg := fmt.Sprintf(`"if...else if" chain with identical branches (lines %d and %d)`, branchPair[0], branchPair[1])
confidence := 1.0
if w.hasComplexCondition {
confidence = 0.8
@@ -173,31 +232,6 @@ func (w *lintIfChainIdenticalBranches) Visit(node ast.Node) ast.Visitor {
return nil
}
// getStmtLines yields the start line number of the given statements.
func (w *lintIfChainIdenticalBranches) getStmtLines(stmts []ast.Stmt) []int {
result := []int{}
for _, stmt := range stmts {
pos := w.file.ToPosition(stmt.Pos())
result = append(result, pos.Line)
}
return result
}
// walkBranch analyzes the given branch.
func (w *lintIfChainIdenticalBranches) walkBranch(branch ast.Stmt) {
if branch == nil {
return
}
walker := &lintIfChainIdenticalBranches{
onFailure: w.onFailure,
file: w.file,
rootWalker: w.rootWalker,
}
ast.Walk(walker, branch)
}
// isComplexCondition returns true if the given expression is "complex", false otherwise.
// An expression is considered complex if it has a function call.
func (*lintIfChainIdenticalBranches) isComplexCondition(expr ast.Expr) bool {
@@ -209,38 +243,23 @@ func (*lintIfChainIdenticalBranches) isComplexCondition(expr ast.Expr) bool {
return len(calls) > 0
}
// identicalBranches yields pairs of identical branches from the given branches.
func (*lintIfChainIdenticalBranches) identicalBranches(branches []ast.Stmt) [][]ast.Stmt {
result := [][]ast.Stmt{}
// identicalBranches yields pairs of (line numbers) of identical branches from the given branches.
func (w *lintIfChainIdenticalBranches) identicalBranches(branches []ast.Stmt) [][]int {
result := [][]int{}
if len(branches) < 2 {
return result // only one branch to compare thus we return
}
hasher := func(in string) string {
binHash := md5.Sum([]byte(in))
return hex.EncodeToString(binHash[:])
}
hashes := map[string]ast.Stmt{}
hashes := map[string]int{} // branch code hash -> branch line
for _, branch := range branches {
str := astutils.GoFmt(branch)
hash := hasher(str)
hash := astutils.NodeHash(branch)
branchLine := w.file.ToPosition(branch.Pos()).Line
if match, ok := hashes[hash]; ok {
result = append(result, []ast.Stmt{match, branch})
result = append(result, []int{match, branchLine})
}
hashes[hash] = branch
hashes[hash] = branchLine
}
return result
}
func (w *lintIdenticalBranches) newFailure(node ast.Node, msg string, confidence float64) {
w.onFailure(lint.Failure{
Confidence: confidence,
Node: node,
Category: lint.FailureCategoryLogic,
Failure: msg,
})
}

View File

@@ -480,8 +480,8 @@ func checkURLTag(checkCtx *checkContext, tag *structtag.Tag, _ ast.Expr) (messag
var delimiter = ""
for _, opt := range tag.Options {
switch opt {
case "int", "omitempty", "numbered", "brackets":
case "unix", "unixmilli", "unixnano": // TODO : check that the field is of type time.Time
case "int", "omitempty", "numbered", "brackets",
"unix", "unixmilli", "unixnano": // TODO : check that the field is of type time.Time
case "comma", "semicolon", "space":
if delimiter == "" {
delimiter = opt
@@ -597,9 +597,7 @@ func typeValueMatch(t ast.Expr, val string) bool {
case "int":
_, err := strconv.ParseInt(val, 10, 64)
typeMatches = err == nil
case "string":
case "nil":
default:
default: // "string", "nil", ...
// unchecked type
}

View File

@@ -38,7 +38,7 @@ func identicalBranches() {
println("else")
}
if true { // MATCH /"if...else if" chain with identical branches (lines [41 49])/
if true { // MATCH /"if...else if" chain with identical branches (lines 41 and 49)/
print("something")
} else if true {
print("something else")
@@ -50,7 +50,7 @@ func identicalBranches() {
print("something")
}
if true { // MATCH /"if...else if" chain with identical branches (lines [53 59])/
if true { // MATCH /"if...else if" chain with identical branches (lines 53 and 59)/
print("something")
} else if true {
print("something else")
@@ -66,7 +66,7 @@ func identicalBranches() {
print("something")
} else if true {
print("something else")
if true { // MATCH /"if...else if" chain with identical branches (lines [69 71])/
if true { // MATCH /"if...else if" chain with identical branches (lines 69 and 71)/
print("something")
} else if false {
print("something")
@@ -92,10 +92,10 @@ func identicalBranches() {
} else if d {
bar()
}
// MATCH:86 /"if...else if" chain with identical branches (lines [86 90])/
// MATCH:86 /"if...else if" chain with identical branches (lines [88 92])/
// MATCH:86 /"if...else if" chain with identical branches (lines 86 and 90)/
// MATCH:86 /"if...else if" chain with identical branches (lines 88 and 92)/
if createFile() { // json:{"MATCH": "\"if...else if\" chain with identical branches (lines [98 102])","Confidence": 0.8}
if createFile() { // json:{"MATCH": "\"if...else if\" chain with identical branches (lines 98 and 102)","Confidence": 0.8}
doSomething()
} else if !delete() {
return new("cannot delete file")
@@ -106,11 +106,80 @@ func identicalBranches() {
}
// Test confidence is reset
if a { // json:{"MATCH": "\"if...else if\" chain with identical branches (lines [109 111])","Confidence": 1}
if a { // json:{"MATCH": "\"if...else if\" chain with identical branches (lines 109 and 111)","Confidence": 1}
foo()
} else if b {
foo()
} else {
bar()
}
switch a { // MATCH /"switch" with identical branches (lines 119 and 123)/
// expected values
case 1:
foo()
case 2:
bar()
case 3:
foo()
default:
return newError("blah")
}
// MATCH:131 /"switch" with identical branches (lines 133 and 137)/
// MATCH:131 /"switch" with identical branches (lines 135 and 139)/
switch a {
// expected values
case 1:
foo()
case 2:
bar()
case 3:
foo()
default:
bar()
}
switch a { // MATCH /"switch" with identical branches (lines 145 and 147)/
// expected values
case 1:
foo()
case 3:
foo()
default:
if true { // MATCH /"if...else if" chain with identical branches (lines 150 and 152)/
something()
} else if true {
something()
} else {
if true { // MATCH /both branches of the if are identical/
print("identical")
} else {
print("identical")
}
}
}
// Skip untagged switch
switch {
case a > b:
foo()
default:
foo()
}
// Do not warn on fallthrough
switch a {
case 1:
foo()
fallthrough
case 2:
fallthrough
case 3:
foo()
case 4:
fallthrough
default:
bar()
}
}