mirror of
				https://github.com/mgechev/revive.git
				synced 2025-10-30 23:37:49 +02:00 
			
		
		
		
	early-return: detect short deviated statements (#1396)
This commit is contained in:
		| @@ -102,6 +102,8 @@ func (b Branch) IsShort() bool { | ||||
| 		return true | ||||
| 	case 1: | ||||
| 		return isShortStmt(b.block[0]) | ||||
| 	case 2: | ||||
| 		return isShortStmt(b.block[1]) | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
							
								
								
									
										252
									
								
								internal/ifelse/branch_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										252
									
								
								internal/ifelse/branch_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,252 @@ | ||||
| package ifelse | ||||
|  | ||||
| import ( | ||||
| 	"go/ast" | ||||
| 	"go/token" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestBlockBranch(t *testing.T) { | ||||
| 	t.Run("empty", func(t *testing.T) { | ||||
| 		block := &ast.BlockStmt{List: []ast.Stmt{}} | ||||
| 		b := BlockBranch(block) | ||||
| 		if b.BranchKind != Empty { | ||||
| 			t.Errorf("want Empty branch, got %v", b.BranchKind) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("non empty", func(t *testing.T) { | ||||
| 		stmt := &ast.ReturnStmt{} | ||||
| 		block := &ast.BlockStmt{List: []ast.Stmt{stmt}} | ||||
| 		b := BlockBranch(block) | ||||
| 		if b.BranchKind != Return { | ||||
| 			t.Errorf("want Return branch, got %v", b.BranchKind) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestStmtBranch(t *testing.T) { | ||||
| 	cases := []struct { | ||||
| 		name string | ||||
| 		stmt ast.Stmt | ||||
| 		kind BranchKind | ||||
| 		call *Call | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "ReturnStmt", | ||||
| 			stmt: &ast.ReturnStmt{}, | ||||
| 			kind: Return, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "BreakStmt", | ||||
| 			stmt: &ast.BranchStmt{Tok: token.BREAK}, | ||||
| 			kind: Break, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ContinueStmt", | ||||
| 			stmt: &ast.BranchStmt{Tok: token.CONTINUE}, | ||||
| 			kind: Continue, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "GotoStmt", | ||||
| 			stmt: &ast.BranchStmt{Tok: token.GOTO}, | ||||
| 			kind: Goto, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "EmptyStmt", | ||||
| 			stmt: &ast.EmptyStmt{}, | ||||
| 			kind: Empty, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ExprStmt with DeviatingFunc (panic)", | ||||
| 			stmt: &ast.ExprStmt{ | ||||
| 				X: &ast.CallExpr{ | ||||
| 					Fun: &ast.Ident{Name: "panic"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			kind: Panic, | ||||
| 			call: &Call{Name: "panic"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ExprStmt with DeviatingFunc (os.Exit)", | ||||
| 			stmt: &ast.ExprStmt{ | ||||
| 				X: &ast.CallExpr{ | ||||
| 					Fun: &ast.SelectorExpr{ | ||||
| 						X:   &ast.Ident{Name: "os"}, | ||||
| 						Sel: &ast.Ident{Name: "Exit"}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			kind: Exit, | ||||
| 			call: &Call{Pkg: "os", Name: "Exit"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ExprStmt with non-deviating func", | ||||
| 			stmt: &ast.ExprStmt{ | ||||
| 				X: &ast.CallExpr{ | ||||
| 					Fun: &ast.Ident{Name: "foo"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			kind: Regular, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "LabeledStmt wrapping ReturnStmt", | ||||
| 			stmt: &ast.LabeledStmt{ | ||||
| 				Label: &ast.Ident{Name: "lbl"}, | ||||
| 				Stmt:  &ast.ReturnStmt{}, | ||||
| 			}, | ||||
| 			kind: Return, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "LabeledStmt wrapping ExprStmt", | ||||
| 			stmt: &ast.LabeledStmt{ | ||||
| 				Label: &ast.Ident{Name: "lbl"}, | ||||
| 				Stmt:  &ast.ExprStmt{X: &ast.CallExpr{Fun: &ast.Ident{Name: "foo"}}}, | ||||
| 			}, | ||||
| 			kind: Regular, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "BlockStmt with ReturnStmt", | ||||
| 			stmt: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{}}}, | ||||
| 			kind: Return, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "BlockStmt with ExprStmt", | ||||
| 			stmt: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: &ast.CallExpr{Fun: &ast.Ident{Name: "foo"}}}}}, | ||||
| 			kind: Regular, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, c := range cases { | ||||
| 		b := StmtBranch(c.stmt) | ||||
| 		if b.BranchKind != c.kind { | ||||
| 			t.Errorf("%s: want %v, got %v", c.name, c.kind, b.BranchKind) | ||||
| 		} | ||||
| 		if c.call != nil { | ||||
| 			if b.Call != *c.call { | ||||
| 				t.Errorf("%s: want Call %+v, got %+v", c.name, *c.call, b.Call) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestBranch_String_LongString(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name     string | ||||
| 		branch   Branch | ||||
| 		wantStr  string | ||||
| 		wantLong string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:     "Return branch", | ||||
| 			branch:   Branch{BranchKind: Return}, | ||||
| 			wantStr:  "{ ... return }", | ||||
| 			wantLong: "a return statement", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "Panic branch with Call", | ||||
| 			branch:   Branch{BranchKind: Panic, Call: Call{Name: "panic"}}, | ||||
| 			wantStr:  "{ ... panic() }", | ||||
| 			wantLong: "call to panic function", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "Exit branch with Call", | ||||
| 			branch:   Branch{BranchKind: Exit, Call: Call{Pkg: "os", Name: "Exit"}}, | ||||
| 			wantStr:  "{ ... os.Exit() }", | ||||
| 			wantLong: "call to os.Exit function", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "Empty branch", | ||||
| 			branch:   Branch{BranchKind: Empty}, | ||||
| 			wantStr:  "{ }", | ||||
| 			wantLong: "an empty block", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "Regular branch", | ||||
| 			branch:   Branch{BranchKind: Regular}, | ||||
| 			wantStr:  "{ ... }", | ||||
| 			wantLong: "a regular statement", | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		if got := tt.branch.String(); got != tt.wantStr { | ||||
| 			t.Errorf("%s: String() = %q, want %q", tt.name, got, tt.wantStr) | ||||
| 		} | ||||
| 		if got := tt.branch.LongString(); got != tt.wantLong { | ||||
| 			t.Errorf("%s: LongString() = %q, want %q", tt.name, got, tt.wantLong) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestBranch_HasDecls(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name  string | ||||
| 		block []ast.Stmt | ||||
| 		want  bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:  "DeclStmt", | ||||
| 			block: []ast.Stmt{&ast.DeclStmt{}}, | ||||
| 			want:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "AssignStmt with :=", | ||||
| 			block: []ast.Stmt{&ast.AssignStmt{Tok: token.DEFINE}}, | ||||
| 			want:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "ExprStmt", | ||||
| 			block: []ast.Stmt{&ast.ExprStmt{}}, | ||||
| 			want:  false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		b := Branch{block: tt.block} | ||||
| 		if got := b.HasDecls(); got != tt.want { | ||||
| 			t.Errorf("%s: want HasDecls to be %v, got %v", tt.name, tt.want, got) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestBranch_IsShort(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name  string | ||||
| 		block []ast.Stmt | ||||
| 		want  bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:  "nil block", | ||||
| 			block: nil, | ||||
| 			want:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "single ExprStmt", | ||||
| 			block: []ast.Stmt{&ast.ExprStmt{}}, | ||||
| 			want:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "single BlockStmt", | ||||
| 			block: []ast.Stmt{&ast.BlockStmt{}}, | ||||
| 			want:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "two short statements", | ||||
| 			block: []ast.Stmt{&ast.ExprStmt{}, &ast.ExprStmt{}}, | ||||
| 			want:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "second non-short statement", | ||||
| 			block: []ast.Stmt{&ast.ExprStmt{}, &ast.BlockStmt{}}, | ||||
| 			want:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "three statements (should return false)", | ||||
| 			block: []ast.Stmt{&ast.ExprStmt{}, &ast.ExprStmt{}, &ast.ExprStmt{}}, | ||||
| 			want:  false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		b := Branch{block: tt.block} | ||||
| 		if got := b.IsShort(); got != tt.want { | ||||
| 			t.Errorf("%s: want IsShort to be %v, got %v", tt.name, tt.want, got) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -13,4 +13,5 @@ func TestEarlyReturn(t *testing.T) { | ||||
| 	testRule(t, "early_return_scope", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"preserve-scope"}}) | ||||
| 	testRule(t, "early_return_jump", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"allowJump"}}) | ||||
| 	testRule(t, "early_return_jump", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"allow-jump"}}) | ||||
| 	testRule(t, "early_return_jump_scope", &rule.EarlyReturnRule{}, &lint.RuleConfig{Arguments: []any{"allow-jump", "preserve-scope"}}) | ||||
| } | ||||
|   | ||||
							
								
								
									
										84
									
								
								testdata/early_return_jump.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										84
									
								
								testdata/early_return_jump.go
									
									
									
									
										vendored
									
									
								
							| @@ -2,6 +2,14 @@ | ||||
|  | ||||
| package fixtures | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"log/slog" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| func fn1() { | ||||
| 	if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/ | ||||
| 		println() | ||||
| @@ -113,3 +121,79 @@ func fn10() { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn11() { | ||||
| 	if a() { | ||||
| 		println() | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn12() { | ||||
| 	if a() { | ||||
| 		println() | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn13() { | ||||
| 	if err := a(); err != nil { | ||||
| 		println() | ||||
| 		panic(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn14() { | ||||
| 	if err := a(); err != nil { | ||||
| 		println() | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn15() { | ||||
| 	if err := a(); err != nil { | ||||
| 		println() | ||||
| 		log.Panic(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn16() { | ||||
| 	if err := a(); err != nil { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting (move short variable declaration to its own line if necessary)/ | ||||
| 		println() | ||||
| 		println() | ||||
| 		log.Panic(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn17() { | ||||
| 	if err := a(); err != nil { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting (move short variable declaration to its own line if necessary)/ | ||||
| 		println() | ||||
| 		println() | ||||
| 		println() | ||||
| 		panic(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func MustEncode[T any](w http.ResponseWriter, status int, v T) { | ||||
| 	if err := Encode(w, status, v); err != nil { | ||||
| 		slog.Error("Error encoding response", "err", err) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *client) renewAuthInfo() { | ||||
| 	err := RenewLease(c.ctx, c, "auth", c.authCreds, func() (*hashiVault.Secret, error) { | ||||
| 		authInfo, err := c.auth(c.v) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("unable to renew auth info: %w", err) | ||||
| 		} | ||||
|  | ||||
| 		c.authCreds = authInfo | ||||
|  | ||||
| 		return authInfo, nil | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		slog.Error("unable to renew auth info", slog.String(loggingKeyError, err.Error())) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										31
									
								
								testdata/early_return_jump_scope.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								testdata/early_return_jump_scope.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| // Test data for the early-return rule with allowJump option enabled | ||||
|  | ||||
| package fixtures | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| func fn1() { | ||||
| 	if cond { //MATCH /if c { ... } can be rewritten if !c { return } ... to reduce nesting/ | ||||
| 		println() | ||||
| 		println() | ||||
| 		println() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn3() { | ||||
| 	if a() { | ||||
| 		println() | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func fn4() { | ||||
| 	// No initializer, match as normal | ||||
| 	if cond { //   MATCH /if c { ... } else { ... return } can be simplified to if !c { ... return } .../ | ||||
| 		fn2() | ||||
| 	} else { | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user