1
0
mirror of https://github.com/mgechev/revive.git synced 2024-12-04 10:24:49 +02:00

feat: optional extension to early-return rule (#1133) (#1138)

This commit is contained in:
Miles Delahunty 2024-11-28 18:51:33 +11:00 committed by GitHub
parent 777abc9c35
commit 7e1d35d8d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 311 additions and 107 deletions

View File

@ -348,12 +348,13 @@ if !cond {
_Configuration_: ([]string) rule flags. Available flags are:
* _preserveScope_: do not suggest refactorings that would increase variable scope
* _allowJump_: suggest a new jump (`return`, `continue` or `break` statement) if it could unnest multiple statements. By default, only relocation of _existing_ jumps (i.e. from the `else` clause) are suggested.
Example:
```toml
[rule.early-return]
arguments = ["preserveScope"]
arguments = ["preserveScope", "allowJump"]
```
## empty-block

View File

@ -4,8 +4,15 @@ package ifelse
// that would enlarge variable scope
const PreserveScope = "preserveScope"
// AllowJump is a configuration argument that permits early-return to
// suggest introducing a new jump (return, continue, etc) statement
// to reduce nesting. By default, suggestions only bring existing jumps
// earlier.
const AllowJump = "allowJump"
// Args contains arguments common to the early-return, indent-error-flow
// and superfluous-else rules (currently just preserveScope)
// and superfluous-else rules
type Args struct {
PreserveScope bool
AllowJump bool
}

View File

@ -9,8 +9,8 @@ import (
// Branch contains information about a branch within an if-else chain.
type Branch struct {
BranchKind
Call // The function called at the end for kind Panic or Exit.
HasDecls bool // The branch has one or more declarations (at the top level block)
Call // The function called at the end for kind Panic or Exit.
block []ast.Stmt
}
// BlockBranch gets the Branch of an ast.BlockStmt.
@ -21,7 +21,7 @@ func BlockBranch(block *ast.BlockStmt) Branch {
}
branch := StmtBranch(block.List[blockLen-1])
branch.HasDecls = hasDecls(block)
branch.block = block.List
return branch
}
@ -61,11 +61,14 @@ func StmtBranch(stmt ast.Stmt) Branch {
// String returns a brief string representation
func (b Branch) String() string {
switch b.BranchKind {
case Empty:
return "{ }"
case Regular:
return "{ ... }"
case Panic, Exit:
return fmt.Sprintf("... %v()", b.Call)
default:
return b.BranchKind.String()
return fmt.Sprintf("{ ... %v() }", b.Call)
}
return fmt.Sprintf("{ ... %v }", b.BranchKind)
}
// LongString returns a longer form string representation
@ -73,13 +76,13 @@ func (b Branch) LongString() string {
switch b.BranchKind {
case Panic, Exit:
return fmt.Sprintf("call to %v function", b.Call)
default:
return b.BranchKind.LongString()
}
return b.BranchKind.LongString()
}
func hasDecls(block *ast.BlockStmt) bool {
for _, stmt := range block.List {
// HasDecls returns whether the branch has any top-level declarations
func (b Branch) HasDecls() bool {
for _, stmt := range b.block {
switch stmt := stmt.(type) {
case *ast.DeclStmt:
return true
@ -91,3 +94,22 @@ func hasDecls(block *ast.BlockStmt) bool {
}
return false
}
// IsShort returns whether the branch is empty or consists of a single statement
func (b Branch) IsShort() bool {
switch len(b.block) {
case 0:
return true
case 1:
return isShortStmt(b.block[0])
}
return false
}
func isShortStmt(stmt ast.Stmt) bool {
switch stmt.(type) {
case *ast.BlockStmt, *ast.IfStmt, *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt, *ast.ForStmt, *ast.RangeStmt:
return false
}
return true
}

View File

@ -44,9 +44,8 @@ func (k BranchKind) Deviates() bool {
return false
case Return, Continue, Break, Goto, Panic, Exit:
return true
default:
panic("invalid kind")
}
panic("invalid kind")
}
// Branch returns a Branch with the given kind
@ -58,22 +57,21 @@ func (k BranchKind) String() string {
case Empty:
return ""
case Regular:
return "..."
return ""
case Return:
return "... return"
return "return"
case Continue:
return "... continue"
return "continue"
case Break:
return "... break"
return "break"
case Goto:
return "... goto"
return "goto"
case Panic:
return "... panic()"
return "panic()"
case Exit:
return "... os.Exit()"
default:
panic("invalid kind")
return "os.Exit()"
}
panic("invalid kind")
}
// LongString returns a longer form string representation
@ -95,7 +93,6 @@ func (k BranchKind) LongString() string {
return "a function call that panics"
case Exit:
return "a function call that exits the program"
default:
panic("invalid kind")
}
panic("invalid kind")
}

View File

@ -2,9 +2,11 @@ package ifelse
// Chain contains information about an if-else chain.
type Chain struct {
If Branch // what happens at the end of the "if" block
Else Branch // what happens at the end of the "else" block
HasInitializer bool // is there an "if"-initializer somewhere in the chain?
HasPriorNonDeviating bool // is there a prior "if" block that does NOT deviate control flow?
AtBlockEnd bool // whether the chain is placed at the end of the surrounding block
If Branch // what happens at the end of the "if" block
HasElse bool // is there an "else" block?
Else Branch // what happens at the end of the "else" block
HasInitializer bool // is there an "if"-initializer somewhere in the chain?
HasPriorNonDeviating bool // is there a prior "if" block that does NOT deviate control flow?
AtBlockEnd bool // whether the chain is placed at the end of the surrounding block
BlockEndKind BranchKind // control flow at end of surrounding block (e.g. "return" for function body)
}

View File

@ -42,10 +42,8 @@ func ExprCall(expr *ast.ExprStmt) (Call, bool) {
// String returns the function name with package qualifier (if any)
func (f Call) String() string {
switch {
case f.Pkg != "":
if f.Pkg != "" {
return fmt.Sprintf("%s.%s", f.Pkg, f.Name)
default:
return f.Name
}
return f.Name
}

View File

@ -7,10 +7,10 @@ import (
"github.com/mgechev/revive/lint"
)
// Rule is an interface for linters operating on if-else chains
type Rule interface {
CheckIfElse(chain Chain, args Args) (failMsg string)
}
// CheckFunc evaluates a rule against the given if-else chain and returns a message
// describing the proposed refactor, along with a indicator of whether such a refactor
// could be found.
type CheckFunc func(Chain, Args) (string, bool)
// Apply evaluates the given Rule on if-else chains found within the given AST,
// and returns the failures.
@ -28,11 +28,14 @@ type Rule interface {
//
// Only the block following "bar" is linted. This is because the rules that use this function
// do not presently have anything to say about earlier blocks in the chain.
func Apply(rule Rule, node ast.Node, target Target, args lint.Arguments) []lint.Failure {
v := &visitor{rule: rule, target: target}
func Apply(check CheckFunc, node ast.Node, target Target, args lint.Arguments) []lint.Failure {
v := &visitor{check: check, target: target}
for _, arg := range args {
if arg == PreserveScope {
switch arg {
case PreserveScope:
v.args.PreserveScope = true
case AllowJump:
v.args.AllowJump = true
}
}
ast.Walk(v, node)
@ -42,64 +45,99 @@ func Apply(rule Rule, node ast.Node, target Target, args lint.Arguments) []lint.
type visitor struct {
failures []lint.Failure
target Target
rule Rule
check CheckFunc
args Args
}
func (v *visitor) Visit(node ast.Node) ast.Visitor {
block, ok := node.(*ast.BlockStmt)
if !ok {
switch stmt := node.(type) {
case *ast.FuncDecl:
v.visitBody(stmt.Body, Return)
case *ast.FuncLit:
v.visitBody(stmt.Body, Return)
case *ast.ForStmt:
v.visitBody(stmt.Body, Continue)
case *ast.RangeStmt:
v.visitBody(stmt.Body, Continue)
case *ast.CaseClause:
v.visitBlock(stmt.Body, Break)
case *ast.BlockStmt:
v.visitBlock(stmt.List, Regular)
default:
return v
}
for i, stmt := range block.List {
if ifStmt, ok := stmt.(*ast.IfStmt); ok {
v.visitChain(ifStmt, Chain{AtBlockEnd: i == len(block.List)-1})
continue
}
ast.Walk(v, stmt)
}
return nil
}
func (v *visitor) visitChain(ifStmt *ast.IfStmt, chain Chain) {
// look for other if-else chains nested inside this if { } block
ast.Walk(v, ifStmt.Body)
if ifStmt.Else == nil {
// no else branch
return
func (v *visitor) visitBody(body *ast.BlockStmt, endKind BranchKind) {
if body != nil {
v.visitBlock(body.List, endKind)
}
}
func (v *visitor) visitBlock(stmts []ast.Stmt, endKind BranchKind) {
for i, stmt := range stmts {
ifStmt, ok := stmt.(*ast.IfStmt)
if !ok {
ast.Walk(v, stmt)
continue
}
var chain Chain
if i == len(stmts)-1 {
chain.AtBlockEnd = true
chain.BlockEndKind = endKind
}
v.visitIf(ifStmt, chain)
}
}
func (v *visitor) visitIf(ifStmt *ast.IfStmt, chain Chain) {
// look for other if-else chains nested inside this if { } block
v.visitBlock(ifStmt.Body.List, chain.BlockEndKind)
if as, ok := ifStmt.Init.(*ast.AssignStmt); ok && as.Tok == token.DEFINE {
chain.HasInitializer = true
}
chain.If = BlockBranch(ifStmt.Body)
if ifStmt.Else == nil {
if v.args.AllowJump {
v.checkRule(ifStmt, chain)
}
return
}
switch elseBlock := ifStmt.Else.(type) {
case *ast.IfStmt:
if !chain.If.Deviates() {
chain.HasPriorNonDeviating = true
}
v.visitChain(elseBlock, chain)
v.visitIf(elseBlock, chain)
case *ast.BlockStmt:
// look for other if-else chains nested inside this else { } block
ast.Walk(v, elseBlock)
v.visitBlock(elseBlock.List, chain.BlockEndKind)
chain.HasElse = true
chain.Else = BlockBranch(elseBlock)
if failMsg := v.rule.CheckIfElse(chain, v.args); failMsg != "" {
if chain.HasInitializer {
// if statement has a := initializer, so we might need to move the assignment
// onto its own line in case the body references it
failMsg += " (move short variable declaration to its own line if necessary)"
}
v.failures = append(v.failures, lint.Failure{
Confidence: 1,
Node: v.target.node(ifStmt),
Failure: failMsg,
})
}
v.checkRule(ifStmt, chain)
default:
panic("invalid node type for else")
panic("unexpected node type for else")
}
}
func (v *visitor) checkRule(ifStmt *ast.IfStmt, chain Chain) {
msg, found := v.check(chain, v.args)
if !found {
return // passed the check
}
if chain.HasInitializer {
// if statement has a := initializer, so we might need to move the assignment
// onto its own line in case the body references it
msg += " (move short variable declaration to its own line if necessary)"
}
v.failures = append(v.failures, lint.Failure{
Confidence: 1,
Node: v.target.node(ifStmt),
Failure: msg,
})
}

View File

@ -19,7 +19,6 @@ func (t Target) node(ifStmt *ast.IfStmt) ast.Node {
return ifStmt
case TargetElse:
return ifStmt.Else
default:
panic("bad target")
}
panic("bad target")
}

View File

@ -13,7 +13,7 @@ type EarlyReturnRule struct{}
// Apply applies the rule to given file.
func (e *EarlyReturnRule) Apply(file *lint.File, args lint.Arguments) []lint.Failure {
return ifelse.Apply(e, file.AST, ifelse.TargetIf, args)
return ifelse.Apply(e.checkIfElse, file.AST, ifelse.TargetIf, args)
}
// Name returns the rule name.
@ -21,31 +21,40 @@ func (*EarlyReturnRule) Name() string {
return "early-return"
}
// CheckIfElse evaluates the rule against an ifelse.Chain and returns a failure message if applicable.
func (*EarlyReturnRule) CheckIfElse(chain ifelse.Chain, args ifelse.Args) string {
if !chain.Else.Deviates() {
// this rule only applies if the else-block deviates control flow
return ""
func (*EarlyReturnRule) checkIfElse(chain ifelse.Chain, args ifelse.Args) (string, bool) {
if chain.HasElse {
if !chain.Else.BranchKind.Deviates() {
// this rule only applies if the else-block deviates control flow
return "", false
}
} else if !args.AllowJump || !chain.AtBlockEnd || !chain.BlockEndKind.Deviates() || chain.If.IsShort() {
// this kind of refactor requires introducing a new indented "return", "continue" or "break" statement,
// so ignore unless we are able to outdent multiple statements in exchange.
return "", false
}
if chain.HasPriorNonDeviating && !chain.If.IsEmpty() {
// if we de-indent this block then a previous branch
// might flow into it, affecting program behaviour
return ""
return "", false
}
if chain.If.Deviates() {
if chain.HasElse && chain.If.Deviates() {
// avoid overlapping with superfluous-else
return ""
return "", false
}
if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.If.HasDecls) {
if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.If.HasDecls()) {
// avoid increasing variable scope
return ""
return "", false
}
if !chain.HasElse {
return fmt.Sprintf("if c { ... } can be rewritten if !c { %v } ... to reduce nesting", chain.BlockEndKind), true
}
if chain.If.IsEmpty() {
return fmt.Sprintf("if c { } else { %[1]v } can be simplified to if !c { %[1]v }", chain.Else)
return fmt.Sprintf("if c { } else %[1]v can be simplified to if !c %[1]v", chain.Else), true
}
return fmt.Sprintf("if c { ... } else { %[1]v } can be simplified to if !c { %[1]v } ...", chain.Else)
return fmt.Sprintf("if c { ... } else %[1]v can be simplified to if !c %[1]v ...", chain.Else), true
}

View File

@ -10,7 +10,7 @@ type IndentErrorFlowRule struct{}
// Apply applies the rule to given file.
func (e *IndentErrorFlowRule) Apply(file *lint.File, args lint.Arguments) []lint.Failure {
return ifelse.Apply(e, file.AST, ifelse.TargetElse, args)
return ifelse.Apply(e.checkIfElse, file.AST, ifelse.TargetElse, args)
}
// Name returns the rule name.
@ -18,28 +18,31 @@ func (*IndentErrorFlowRule) Name() string {
return "indent-error-flow"
}
// CheckIfElse evaluates the rule against an ifelse.Chain and returns a failure message if applicable.
func (*IndentErrorFlowRule) CheckIfElse(chain ifelse.Chain, args ifelse.Args) string {
func (*IndentErrorFlowRule) checkIfElse(chain ifelse.Chain, args ifelse.Args) (string, bool) {
if !chain.HasElse {
return "", false
}
if !chain.If.Deviates() {
// this rule only applies if the if-block deviates control flow
return ""
return "", false
}
if chain.HasPriorNonDeviating {
// if we de-indent the "else" block then a previous branch
// might flow into it, affecting program behaviour
return ""
return "", false
}
if !chain.If.Returns() {
// avoid overlapping with superfluous-else
return ""
return "", false
}
if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls) {
if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls()) {
// avoid increasing variable scope
return ""
return "", false
}
return "if block ends with a return statement, so drop this else and outdent its block"
return "if block ends with a return statement, so drop this else and outdent its block", true
}

View File

@ -12,7 +12,7 @@ type SuperfluousElseRule struct{}
// Apply applies the rule to given file.
func (e *SuperfluousElseRule) Apply(file *lint.File, args lint.Arguments) []lint.Failure {
return ifelse.Apply(e, file.AST, ifelse.TargetElse, args)
return ifelse.Apply(e.checkIfElse, file.AST, ifelse.TargetElse, args)
}
// Name returns the rule name.
@ -20,28 +20,31 @@ func (*SuperfluousElseRule) Name() string {
return "superfluous-else"
}
// CheckIfElse evaluates the rule against an ifelse.Chain and returns a failure message if applicable.
func (*SuperfluousElseRule) CheckIfElse(chain ifelse.Chain, args ifelse.Args) string {
func (*SuperfluousElseRule) checkIfElse(chain ifelse.Chain, args ifelse.Args) (string, bool) {
if !chain.HasElse {
return "", false
}
if !chain.If.Deviates() {
// this rule only applies if the if-block deviates control flow
return ""
return "", false
}
if chain.HasPriorNonDeviating {
// if we de-indent the "else" block then a previous branch
// might flow into it, affecting program behaviour
return ""
return "", false
}
if chain.If.Returns() {
// avoid overlapping with indent-error-flow
return ""
return "", false
}
if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls) {
if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls()) {
// avoid increasing variable scope
return ""
return "", false
}
return fmt.Sprintf("if block ends with %v, so drop this else and outdent its block", chain.If.LongString())
return fmt.Sprintf("if block ends with %v, so drop this else and outdent its block", chain.If.LongString()), true
}

View File

@ -12,4 +12,5 @@ import (
func TestEarlyReturn(t *testing.T) {
testRule(t, "early_return", &rule.EarlyReturnRule{})
testRule(t, "early_return_scope", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{ifelse.PreserveScope}})
testRule(t, "early_return_jump", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{ifelse.AllowJump}})
}

View File

@ -132,4 +132,13 @@ func earlyRet() bool {
} else {
os.Exit(0)
}
for {
// inversion is not suggested here without allowJump option enabled
if cond {
println()
println()
println()
}
}
}

115
testdata/early_return_jump.go vendored Normal file
View File

@ -0,0 +1,115 @@
// Test data for the early-return rule with allowJump option enabled
package fixtures
func fn1() {
if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/
println()
println()
println()
}
}
func fn2() {
for {
if cond { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/
println()
println()
println()
}
}
}
func fn3() {
for {
// can't flip cond2 here because the cond1 branch would flow into it
if cond1 {
println()
} else if cond2 {
println()
println()
println()
}
}
}
func fn4() {
for {
// cond1 branch continues here so this is ok
if cond1 {
println()
continue
} else if cond2 { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/
println()
println()
println()
}
}
}
func fn5() {
for {
// no point flipping cond here we only unnest one statement and need to introduce one new nested statement (continue) to do it
if cond {
println()
}
}
}
func fn6() {
for {
if x, ok := foo(); ok { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting (move short variable declaration to its own line if necessary)/
println(x)
println(x)
println(x)
}
}
}
func fn7() {
for i := 0; i < 10; i++ {
if cond { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/
println()
println()
println()
}
}
}
func fn8() {
for range c {
if cond { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/
println()
println()
println()
}
}
}
func fn9() {
fn := func() {
if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/
println()
println()
println()
}
}
fn()
}
func fn10() {
switch {
case cond:
if foo() { //MATCH /if c { ... } can be rewritten if !c { break } ... to reduce nesting/
println()
println()
println()
}
default:
if bar() { //MATCH /if c { ... } can be rewritten if !c { break } ... to reduce nesting/
println()
println()
println()
}
}
}