From 7e1d35d8d232a52f109896f277a5bd7e66df4784 Mon Sep 17 00:00:00 2001 From: Miles Delahunty <4904544+mdelah@users.noreply.github.com> Date: Thu, 28 Nov 2024 18:51:33 +1100 Subject: [PATCH] feat: optional extension to early-return rule (#1133) (#1138) --- RULES_DESCRIPTIONS.md | 3 +- internal/ifelse/args.go | 9 ++- internal/ifelse/branch.go | 42 +++++++++--- internal/ifelse/branch_kind.go | 23 +++---- internal/ifelse/chain.go | 12 ++-- internal/ifelse/func.go | 6 +- internal/ifelse/rule.go | 118 ++++++++++++++++++++++----------- internal/ifelse/target.go | 3 +- rule/early_return.go | 35 ++++++---- rule/indent_error_flow.go | 21 +++--- rule/superfluous_else.go | 21 +++--- test/early_return_test.go | 1 + testdata/early_return.go | 9 +++ testdata/early_return_jump.go | 115 ++++++++++++++++++++++++++++++++ 14 files changed, 311 insertions(+), 107 deletions(-) create mode 100644 testdata/early_return_jump.go diff --git a/RULES_DESCRIPTIONS.md b/RULES_DESCRIPTIONS.md index e084756..7627b83 100644 --- a/RULES_DESCRIPTIONS.md +++ b/RULES_DESCRIPTIONS.md @@ -348,12 +348,13 @@ if !cond { _Configuration_: ([]string) rule flags. Available flags are: * _preserveScope_: do not suggest refactorings that would increase variable scope +* _allowJump_: suggest a new jump (`return`, `continue` or `break` statement) if it could unnest multiple statements. By default, only relocation of _existing_ jumps (i.e. from the `else` clause) are suggested. Example: ```toml [rule.early-return] - arguments = ["preserveScope"] + arguments = ["preserveScope", "allowJump"] ``` ## empty-block diff --git a/internal/ifelse/args.go b/internal/ifelse/args.go index c6e647e..fc65b70 100644 --- a/internal/ifelse/args.go +++ b/internal/ifelse/args.go @@ -4,8 +4,15 @@ package ifelse // that would enlarge variable scope const PreserveScope = "preserveScope" +// AllowJump is a configuration argument that permits early-return to +// suggest introducing a new jump (return, continue, etc) statement +// to reduce nesting. By default, suggestions only bring existing jumps +// earlier. +const AllowJump = "allowJump" + // Args contains arguments common to the early-return, indent-error-flow -// and superfluous-else rules (currently just preserveScope) +// and superfluous-else rules type Args struct { PreserveScope bool + AllowJump bool } diff --git a/internal/ifelse/branch.go b/internal/ifelse/branch.go index 6e6036b..dfa744e 100644 --- a/internal/ifelse/branch.go +++ b/internal/ifelse/branch.go @@ -9,8 +9,8 @@ import ( // Branch contains information about a branch within an if-else chain. type Branch struct { BranchKind - Call // The function called at the end for kind Panic or Exit. - HasDecls bool // The branch has one or more declarations (at the top level block) + Call // The function called at the end for kind Panic or Exit. + block []ast.Stmt } // BlockBranch gets the Branch of an ast.BlockStmt. @@ -21,7 +21,7 @@ func BlockBranch(block *ast.BlockStmt) Branch { } branch := StmtBranch(block.List[blockLen-1]) - branch.HasDecls = hasDecls(block) + branch.block = block.List return branch } @@ -61,11 +61,14 @@ func StmtBranch(stmt ast.Stmt) Branch { // String returns a brief string representation func (b Branch) String() string { switch b.BranchKind { + case Empty: + return "{ }" + case Regular: + return "{ ... }" case Panic, Exit: - return fmt.Sprintf("... %v()", b.Call) - default: - return b.BranchKind.String() + return fmt.Sprintf("{ ... %v() }", b.Call) } + return fmt.Sprintf("{ ... %v }", b.BranchKind) } // LongString returns a longer form string representation @@ -73,13 +76,13 @@ func (b Branch) LongString() string { switch b.BranchKind { case Panic, Exit: return fmt.Sprintf("call to %v function", b.Call) - default: - return b.BranchKind.LongString() } + return b.BranchKind.LongString() } -func hasDecls(block *ast.BlockStmt) bool { - for _, stmt := range block.List { +// HasDecls returns whether the branch has any top-level declarations +func (b Branch) HasDecls() bool { + for _, stmt := range b.block { switch stmt := stmt.(type) { case *ast.DeclStmt: return true @@ -91,3 +94,22 @@ func hasDecls(block *ast.BlockStmt) bool { } return false } + +// IsShort returns whether the branch is empty or consists of a single statement +func (b Branch) IsShort() bool { + switch len(b.block) { + case 0: + return true + case 1: + return isShortStmt(b.block[0]) + } + return false +} + +func isShortStmt(stmt ast.Stmt) bool { + switch stmt.(type) { + case *ast.BlockStmt, *ast.IfStmt, *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt, *ast.ForStmt, *ast.RangeStmt: + return false + } + return true +} diff --git a/internal/ifelse/branch_kind.go b/internal/ifelse/branch_kind.go index 41601d1..75d3b0c 100644 --- a/internal/ifelse/branch_kind.go +++ b/internal/ifelse/branch_kind.go @@ -44,9 +44,8 @@ func (k BranchKind) Deviates() bool { return false case Return, Continue, Break, Goto, Panic, Exit: return true - default: - panic("invalid kind") } + panic("invalid kind") } // Branch returns a Branch with the given kind @@ -58,22 +57,21 @@ func (k BranchKind) String() string { case Empty: return "" case Regular: - return "..." + return "" case Return: - return "... return" + return "return" case Continue: - return "... continue" + return "continue" case Break: - return "... break" + return "break" case Goto: - return "... goto" + return "goto" case Panic: - return "... panic()" + return "panic()" case Exit: - return "... os.Exit()" - default: - panic("invalid kind") + return "os.Exit()" } + panic("invalid kind") } // LongString returns a longer form string representation @@ -95,7 +93,6 @@ func (k BranchKind) LongString() string { return "a function call that panics" case Exit: return "a function call that exits the program" - default: - panic("invalid kind") } + panic("invalid kind") } diff --git a/internal/ifelse/chain.go b/internal/ifelse/chain.go index 9891635..e3c8898 100644 --- a/internal/ifelse/chain.go +++ b/internal/ifelse/chain.go @@ -2,9 +2,11 @@ package ifelse // Chain contains information about an if-else chain. type Chain struct { - If Branch // what happens at the end of the "if" block - Else Branch // what happens at the end of the "else" block - HasInitializer bool // is there an "if"-initializer somewhere in the chain? - HasPriorNonDeviating bool // is there a prior "if" block that does NOT deviate control flow? - AtBlockEnd bool // whether the chain is placed at the end of the surrounding block + If Branch // what happens at the end of the "if" block + HasElse bool // is there an "else" block? + Else Branch // what happens at the end of the "else" block + HasInitializer bool // is there an "if"-initializer somewhere in the chain? + HasPriorNonDeviating bool // is there a prior "if" block that does NOT deviate control flow? + AtBlockEnd bool // whether the chain is placed at the end of the surrounding block + BlockEndKind BranchKind // control flow at end of surrounding block (e.g. "return" for function body) } diff --git a/internal/ifelse/func.go b/internal/ifelse/func.go index 7ba3519..45c78f0 100644 --- a/internal/ifelse/func.go +++ b/internal/ifelse/func.go @@ -42,10 +42,8 @@ func ExprCall(expr *ast.ExprStmt) (Call, bool) { // String returns the function name with package qualifier (if any) func (f Call) String() string { - switch { - case f.Pkg != "": + if f.Pkg != "" { return fmt.Sprintf("%s.%s", f.Pkg, f.Name) - default: - return f.Name } + return f.Name } diff --git a/internal/ifelse/rule.go b/internal/ifelse/rule.go index 07ad456..94f0221 100644 --- a/internal/ifelse/rule.go +++ b/internal/ifelse/rule.go @@ -7,10 +7,10 @@ import ( "github.com/mgechev/revive/lint" ) -// Rule is an interface for linters operating on if-else chains -type Rule interface { - CheckIfElse(chain Chain, args Args) (failMsg string) -} +// CheckFunc evaluates a rule against the given if-else chain and returns a message +// describing the proposed refactor, along with a indicator of whether such a refactor +// could be found. +type CheckFunc func(Chain, Args) (string, bool) // Apply evaluates the given Rule on if-else chains found within the given AST, // and returns the failures. @@ -28,11 +28,14 @@ type Rule interface { // // Only the block following "bar" is linted. This is because the rules that use this function // do not presently have anything to say about earlier blocks in the chain. -func Apply(rule Rule, node ast.Node, target Target, args lint.Arguments) []lint.Failure { - v := &visitor{rule: rule, target: target} +func Apply(check CheckFunc, node ast.Node, target Target, args lint.Arguments) []lint.Failure { + v := &visitor{check: check, target: target} for _, arg := range args { - if arg == PreserveScope { + switch arg { + case PreserveScope: v.args.PreserveScope = true + case AllowJump: + v.args.AllowJump = true } } ast.Walk(v, node) @@ -42,64 +45,99 @@ func Apply(rule Rule, node ast.Node, target Target, args lint.Arguments) []lint. type visitor struct { failures []lint.Failure target Target - rule Rule + check CheckFunc args Args } func (v *visitor) Visit(node ast.Node) ast.Visitor { - block, ok := node.(*ast.BlockStmt) - if !ok { + switch stmt := node.(type) { + case *ast.FuncDecl: + v.visitBody(stmt.Body, Return) + case *ast.FuncLit: + v.visitBody(stmt.Body, Return) + case *ast.ForStmt: + v.visitBody(stmt.Body, Continue) + case *ast.RangeStmt: + v.visitBody(stmt.Body, Continue) + case *ast.CaseClause: + v.visitBlock(stmt.Body, Break) + case *ast.BlockStmt: + v.visitBlock(stmt.List, Regular) + default: return v } - - for i, stmt := range block.List { - if ifStmt, ok := stmt.(*ast.IfStmt); ok { - v.visitChain(ifStmt, Chain{AtBlockEnd: i == len(block.List)-1}) - continue - } - ast.Walk(v, stmt) - } return nil } -func (v *visitor) visitChain(ifStmt *ast.IfStmt, chain Chain) { - // look for other if-else chains nested inside this if { } block - ast.Walk(v, ifStmt.Body) - - if ifStmt.Else == nil { - // no else branch - return +func (v *visitor) visitBody(body *ast.BlockStmt, endKind BranchKind) { + if body != nil { + v.visitBlock(body.List, endKind) } +} + +func (v *visitor) visitBlock(stmts []ast.Stmt, endKind BranchKind) { + for i, stmt := range stmts { + ifStmt, ok := stmt.(*ast.IfStmt) + if !ok { + ast.Walk(v, stmt) + continue + } + var chain Chain + if i == len(stmts)-1 { + chain.AtBlockEnd = true + chain.BlockEndKind = endKind + } + v.visitIf(ifStmt, chain) + } +} + +func (v *visitor) visitIf(ifStmt *ast.IfStmt, chain Chain) { + // look for other if-else chains nested inside this if { } block + v.visitBlock(ifStmt.Body.List, chain.BlockEndKind) if as, ok := ifStmt.Init.(*ast.AssignStmt); ok && as.Tok == token.DEFINE { chain.HasInitializer = true } chain.If = BlockBranch(ifStmt.Body) + if ifStmt.Else == nil { + if v.args.AllowJump { + v.checkRule(ifStmt, chain) + } + return + } + switch elseBlock := ifStmt.Else.(type) { case *ast.IfStmt: if !chain.If.Deviates() { chain.HasPriorNonDeviating = true } - v.visitChain(elseBlock, chain) + v.visitIf(elseBlock, chain) case *ast.BlockStmt: // look for other if-else chains nested inside this else { } block - ast.Walk(v, elseBlock) + v.visitBlock(elseBlock.List, chain.BlockEndKind) + chain.HasElse = true chain.Else = BlockBranch(elseBlock) - if failMsg := v.rule.CheckIfElse(chain, v.args); failMsg != "" { - if chain.HasInitializer { - // if statement has a := initializer, so we might need to move the assignment - // onto its own line in case the body references it - failMsg += " (move short variable declaration to its own line if necessary)" - } - v.failures = append(v.failures, lint.Failure{ - Confidence: 1, - Node: v.target.node(ifStmt), - Failure: failMsg, - }) - } + v.checkRule(ifStmt, chain) default: - panic("invalid node type for else") + panic("unexpected node type for else") } } + +func (v *visitor) checkRule(ifStmt *ast.IfStmt, chain Chain) { + msg, found := v.check(chain, v.args) + if !found { + return // passed the check + } + if chain.HasInitializer { + // if statement has a := initializer, so we might need to move the assignment + // onto its own line in case the body references it + msg += " (move short variable declaration to its own line if necessary)" + } + v.failures = append(v.failures, lint.Failure{ + Confidence: 1, + Node: v.target.node(ifStmt), + Failure: msg, + }) +} diff --git a/internal/ifelse/target.go b/internal/ifelse/target.go index 81ff1c3..63755ac 100644 --- a/internal/ifelse/target.go +++ b/internal/ifelse/target.go @@ -19,7 +19,6 @@ func (t Target) node(ifStmt *ast.IfStmt) ast.Node { return ifStmt case TargetElse: return ifStmt.Else - default: - panic("bad target") } + panic("bad target") } diff --git a/rule/early_return.go b/rule/early_return.go index 62d491f..c6c2321 100644 --- a/rule/early_return.go +++ b/rule/early_return.go @@ -13,7 +13,7 @@ type EarlyReturnRule struct{} // Apply applies the rule to given file. func (e *EarlyReturnRule) Apply(file *lint.File, args lint.Arguments) []lint.Failure { - return ifelse.Apply(e, file.AST, ifelse.TargetIf, args) + return ifelse.Apply(e.checkIfElse, file.AST, ifelse.TargetIf, args) } // Name returns the rule name. @@ -21,31 +21,40 @@ func (*EarlyReturnRule) Name() string { return "early-return" } -// CheckIfElse evaluates the rule against an ifelse.Chain and returns a failure message if applicable. -func (*EarlyReturnRule) CheckIfElse(chain ifelse.Chain, args ifelse.Args) string { - if !chain.Else.Deviates() { - // this rule only applies if the else-block deviates control flow - return "" +func (*EarlyReturnRule) checkIfElse(chain ifelse.Chain, args ifelse.Args) (string, bool) { + if chain.HasElse { + if !chain.Else.BranchKind.Deviates() { + // this rule only applies if the else-block deviates control flow + return "", false + } + } else if !args.AllowJump || !chain.AtBlockEnd || !chain.BlockEndKind.Deviates() || chain.If.IsShort() { + // this kind of refactor requires introducing a new indented "return", "continue" or "break" statement, + // so ignore unless we are able to outdent multiple statements in exchange. + return "", false } if chain.HasPriorNonDeviating && !chain.If.IsEmpty() { // if we de-indent this block then a previous branch // might flow into it, affecting program behaviour - return "" + return "", false } - if chain.If.Deviates() { + if chain.HasElse && chain.If.Deviates() { // avoid overlapping with superfluous-else - return "" + return "", false } - if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.If.HasDecls) { + if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.If.HasDecls()) { // avoid increasing variable scope - return "" + return "", false + } + + if !chain.HasElse { + return fmt.Sprintf("if c { ... } can be rewritten if !c { %v } ... to reduce nesting", chain.BlockEndKind), true } if chain.If.IsEmpty() { - return fmt.Sprintf("if c { } else { %[1]v } can be simplified to if !c { %[1]v }", chain.Else) + return fmt.Sprintf("if c { } else %[1]v can be simplified to if !c %[1]v", chain.Else), true } - return fmt.Sprintf("if c { ... } else { %[1]v } can be simplified to if !c { %[1]v } ...", chain.Else) + return fmt.Sprintf("if c { ... } else %[1]v can be simplified to if !c %[1]v ...", chain.Else), true } diff --git a/rule/indent_error_flow.go b/rule/indent_error_flow.go index ebc1e79..2abbfbf 100644 --- a/rule/indent_error_flow.go +++ b/rule/indent_error_flow.go @@ -10,7 +10,7 @@ type IndentErrorFlowRule struct{} // Apply applies the rule to given file. func (e *IndentErrorFlowRule) Apply(file *lint.File, args lint.Arguments) []lint.Failure { - return ifelse.Apply(e, file.AST, ifelse.TargetElse, args) + return ifelse.Apply(e.checkIfElse, file.AST, ifelse.TargetElse, args) } // Name returns the rule name. @@ -18,28 +18,31 @@ func (*IndentErrorFlowRule) Name() string { return "indent-error-flow" } -// CheckIfElse evaluates the rule against an ifelse.Chain and returns a failure message if applicable. -func (*IndentErrorFlowRule) CheckIfElse(chain ifelse.Chain, args ifelse.Args) string { +func (*IndentErrorFlowRule) checkIfElse(chain ifelse.Chain, args ifelse.Args) (string, bool) { + if !chain.HasElse { + return "", false + } + if !chain.If.Deviates() { // this rule only applies if the if-block deviates control flow - return "" + return "", false } if chain.HasPriorNonDeviating { // if we de-indent the "else" block then a previous branch // might flow into it, affecting program behaviour - return "" + return "", false } if !chain.If.Returns() { // avoid overlapping with superfluous-else - return "" + return "", false } - if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls) { + if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls()) { // avoid increasing variable scope - return "" + return "", false } - return "if block ends with a return statement, so drop this else and outdent its block" + return "if block ends with a return statement, so drop this else and outdent its block", true } diff --git a/rule/superfluous_else.go b/rule/superfluous_else.go index 18e8f3b..2e8cfeb 100644 --- a/rule/superfluous_else.go +++ b/rule/superfluous_else.go @@ -12,7 +12,7 @@ type SuperfluousElseRule struct{} // Apply applies the rule to given file. func (e *SuperfluousElseRule) Apply(file *lint.File, args lint.Arguments) []lint.Failure { - return ifelse.Apply(e, file.AST, ifelse.TargetElse, args) + return ifelse.Apply(e.checkIfElse, file.AST, ifelse.TargetElse, args) } // Name returns the rule name. @@ -20,28 +20,31 @@ func (*SuperfluousElseRule) Name() string { return "superfluous-else" } -// CheckIfElse evaluates the rule against an ifelse.Chain and returns a failure message if applicable. -func (*SuperfluousElseRule) CheckIfElse(chain ifelse.Chain, args ifelse.Args) string { +func (*SuperfluousElseRule) checkIfElse(chain ifelse.Chain, args ifelse.Args) (string, bool) { + if !chain.HasElse { + return "", false + } + if !chain.If.Deviates() { // this rule only applies if the if-block deviates control flow - return "" + return "", false } if chain.HasPriorNonDeviating { // if we de-indent the "else" block then a previous branch // might flow into it, affecting program behaviour - return "" + return "", false } if chain.If.Returns() { // avoid overlapping with indent-error-flow - return "" + return "", false } - if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls) { + if args.PreserveScope && !chain.AtBlockEnd && (chain.HasInitializer || chain.Else.HasDecls()) { // avoid increasing variable scope - return "" + return "", false } - return fmt.Sprintf("if block ends with %v, so drop this else and outdent its block", chain.If.LongString()) + return fmt.Sprintf("if block ends with %v, so drop this else and outdent its block", chain.If.LongString()), true } diff --git a/test/early_return_test.go b/test/early_return_test.go index b477c6d..493ee6f 100644 --- a/test/early_return_test.go +++ b/test/early_return_test.go @@ -12,4 +12,5 @@ import ( func TestEarlyReturn(t *testing.T) { testRule(t, "early_return", &rule.EarlyReturnRule{}) testRule(t, "early_return_scope", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{ifelse.PreserveScope}}) + testRule(t, "early_return_jump", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{ifelse.AllowJump}}) } diff --git a/testdata/early_return.go b/testdata/early_return.go index 15475b0..f25879c 100644 --- a/testdata/early_return.go +++ b/testdata/early_return.go @@ -132,4 +132,13 @@ func earlyRet() bool { } else { os.Exit(0) } + + for { + // inversion is not suggested here without allowJump option enabled + if cond { + println() + println() + println() + } + } } diff --git a/testdata/early_return_jump.go b/testdata/early_return_jump.go new file mode 100644 index 0000000..54d2715 --- /dev/null +++ b/testdata/early_return_jump.go @@ -0,0 +1,115 @@ +// Test data for the early-return rule with allowJump option enabled + +package fixtures + +func fn1() { + if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/ + println() + println() + println() + } +} + +func fn2() { + for { + if cond { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/ + println() + println() + println() + } + } +} + +func fn3() { + for { + // can't flip cond2 here because the cond1 branch would flow into it + if cond1 { + println() + } else if cond2 { + println() + println() + println() + } + } +} + +func fn4() { + for { + // cond1 branch continues here so this is ok + if cond1 { + println() + continue + } else if cond2 { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/ + println() + println() + println() + } + } +} + +func fn5() { + for { + // no point flipping cond here we only unnest one statement and need to introduce one new nested statement (continue) to do it + if cond { + println() + } + } +} + +func fn6() { + for { + if x, ok := foo(); ok { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting (move short variable declaration to its own line if necessary)/ + println(x) + println(x) + println(x) + } + } +} + +func fn7() { + for i := 0; i < 10; i++ { + if cond { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/ + println() + println() + println() + } + } +} + +func fn8() { + for range c { + if cond { //MATCH /if c { ... } can be rewritten if !c { continue } ... to reduce nesting/ + println() + println() + println() + } + } +} + +func fn9() { + fn := func() { + if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/ + println() + println() + println() + } + } + fn() +} + +func fn10() { + switch { + case cond: + if foo() { //MATCH /if c { ... } can be rewritten if !c { break } ... to reduce nesting/ + println() + println() + println() + } + default: + if bar() { //MATCH /if c { ... } can be rewritten if !c { break } ... to reduce nesting/ + println() + println() + println() + } + } +}