diff --git a/pkg/integration/README.md b/pkg/integration/README.md index ba2365403..b2aa2ddf9 100644 --- a/pkg/integration/README.md +++ b/pkg/integration/README.md @@ -37,7 +37,7 @@ If you find yourself doing something frequently in a test, consider making it a There are three ways to invoke a test: -1. go run cmd/integration_test/main.go cli [...] +1. go run cmd/integration_test/main.go cli [...] 2. go run cmd/integration_test/main.go tui 3. go test pkg/integration/clients/go_test.go diff --git a/pkg/integration/clients/cli.go b/pkg/integration/clients/cli.go index 79bb96c4f..bb8a8d0b9 100644 --- a/pkg/integration/clients/cli.go +++ b/pkg/integration/clients/cli.go @@ -44,10 +44,11 @@ func runAndPrintError(test *components.IntegrationTest, f func() error) { } func getTestsToRun(testNames []string) []*components.IntegrationTest { + allIntegrationTests := tests.GetTests() var testsToRun []*components.IntegrationTest if len(testNames) == 0 { - return tests.Tests + return allIntegrationTests } testNames = slices.Map(testNames, func(name string) string { @@ -61,7 +62,7 @@ func getTestsToRun(testNames []string) []*components.IntegrationTest { outer: for _, testName := range testNames { // check if our given test name actually exists - for _, test := range tests.Tests { + for _, test := range allIntegrationTests { if test.Name() == testName { testsToRun = append(testsToRun, test) continue outer diff --git a/pkg/integration/clients/go_test.go b/pkg/integration/clients/go_test.go index d52cd409a..9fceecd40 100644 --- a/pkg/integration/clients/go_test.go +++ b/pkg/integration/clients/go_test.go @@ -29,7 +29,7 @@ func TestIntegration(t *testing.T) { testNumber := 0 err := components.RunTests( - tests.Tests, + tests.GetTests(), t.Logf, runCmdHeadless, func(test *components.IntegrationTest, f func() error) { diff --git a/pkg/integration/clients/injector/main.go b/pkg/integration/clients/injector/main.go index 263dba5da..37c76fe3e 100644 --- a/pkg/integration/clients/injector/main.go +++ b/pkg/integration/clients/injector/main.go @@ -52,7 +52,8 @@ func getIntegrationTest() integrationTypes.IntegrationTest { )) } - for _, candidateTest := range tests.Tests { + allTests := tests.GetTests() + for _, candidateTest := range allTests { if candidateTest.Name() == integrationTestName { return candidateTest } diff --git a/pkg/integration/clients/tui.go b/pkg/integration/clients/tui.go index 707e482ca..716d1abe8 100644 --- a/pkg/integration/clients/tui.go +++ b/pkg/integration/clients/tui.go @@ -168,7 +168,7 @@ func RunTUI() { return err } - app.filteredTests = tests.Tests + app.filteredTests = app.allTests app.renderTests() app.editorView.TextArea.Clear() app.editorView.Clear() @@ -204,6 +204,7 @@ func RunTUI() { } type app struct { + allTests []*components.IntegrationTest filteredTests []*components.IntegrationTest itemIdx int testDir string @@ -214,7 +215,7 @@ type app struct { } func newApp(testDir string) *app { - return &app{testDir: testDir} + return &app{testDir: testDir, allTests: tests.GetTests()} } func (self *app) getCurrentTest() *components.IntegrationTest { @@ -226,7 +227,7 @@ func (self *app) getCurrentTest() *components.IntegrationTest { } func (self *app) loadTests() { - self.filteredTests = tests.Tests + self.filteredTests = self.allTests self.adjustCursor() } @@ -237,9 +238,9 @@ func (self *app) adjustCursor() { func (self *app) filterWithString(needle string) { if needle == "" { - self.filteredTests = tests.Tests + self.filteredTests = self.allTests } else { - self.filteredTests = slices.Filter(tests.Tests, func(test *components.IntegrationTest) bool { + self.filteredTests = slices.Filter(self.allTests, func(test *components.IntegrationTest) bool { return strings.Contains(test.Name(), needle) }) } diff --git a/pkg/integration/components/assert.go b/pkg/integration/components/assert.go index dcfd00615..ea67273a9 100644 --- a/pkg/integration/components/assert.go +++ b/pkg/integration/components/assert.go @@ -21,12 +21,12 @@ func NewAssert(gui integrationTypes.GuiDriver) *Assert { } // for making assertions on string values -type matcher[T any] struct { - testFn func(T) (bool, string) +type matcher struct { + testFn func(string) (bool, string) prefix string } -func (self *matcher[T]) test(value T) (bool, string) { +func (self *matcher) test(value string) (bool, string) { ok, message := self.testFn(value) if ok { return true, "" @@ -39,20 +39,20 @@ func (self *matcher[T]) test(value T) (bool, string) { return false, message } -func (self *matcher[T]) context(prefix string) *matcher[T] { +func (self *matcher) context(prefix string) *matcher { self.prefix = prefix return self } -func Contains(target string) *matcher[string] { - return &matcher[string]{testFn: func(value string) (bool, string) { +func Contains(target string) *matcher { + return &matcher{testFn: func(value string) (bool, string) { return strings.Contains(value, target), fmt.Sprintf("Expected '%s' to contain '%s'", value, target) }} } -func Equals[T constraints.Ordered](target T) *matcher[T] { - return &matcher[T]{testFn: func(value T) (bool, string) { +func Equals[T constraints.Ordered](target string) *matcher { + return &matcher{testFn: func(value string) (bool, string) { return target == value, fmt.Sprintf("Expected '%T' to equal '%T'", value, target) }} } @@ -79,7 +79,7 @@ func (self *Assert) CommitCount(expectedCount int) { }) } -func (self *Assert) MatchHeadCommitMessage(matcher *matcher[string]) { +func (self *Assert) MatchHeadCommitMessage(matcher *matcher) { self.assertWithRetries(func() (bool, string) { return len(self.gui.Model().Commits) == 0, "Expected at least one commit to be present" }) @@ -113,7 +113,7 @@ func (self *Assert) InListContext() { }) } -func (self *Assert) MatchSelectedLine(matcher *matcher[string]) { +func (self *Assert) MatchSelectedLine(matcher *matcher) { self.matchString(matcher, "Unexpected selected line.", func() string { return self.gui.CurrentContext().GetView().SelectedLine() @@ -149,7 +149,7 @@ func (self *Assert) InMenu() { }) } -func (self *Assert) MatchCurrentViewTitle(matcher *matcher[string]) { +func (self *Assert) MatchCurrentViewTitle(matcher *matcher) { self.matchString(matcher, "Unexpected current view title.", func() string { return self.gui.CurrentContext().GetView().Title @@ -157,7 +157,7 @@ func (self *Assert) MatchCurrentViewTitle(matcher *matcher[string]) { ) } -func (self *Assert) MatchMainViewContent(matcher *matcher[string]) { +func (self *Assert) MatchMainViewContent(matcher *matcher) { self.matchString(matcher, "Unexpected main view content.", func() string { return self.gui.MainView().Buffer() @@ -165,7 +165,7 @@ func (self *Assert) MatchMainViewContent(matcher *matcher[string]) { ) } -func (self *Assert) MatchSecondaryViewContent(matcher *matcher[string]) { +func (self *Assert) MatchSecondaryViewContent(matcher *matcher) { self.matchString(matcher, "Unexpected secondary view title.", func() string { return self.gui.SecondaryView().Buffer() @@ -173,7 +173,7 @@ func (self *Assert) MatchSecondaryViewContent(matcher *matcher[string]) { ) } -func (self *Assert) matchString(matcher *matcher[string], context string, getValue func() string) { +func (self *Assert) matchString(matcher *matcher, context string, getValue func() string) { self.assertWithRetries(func() (bool, string) { value := getValue() return matcher.context(context).test(value) diff --git a/pkg/integration/components/test.go b/pkg/integration/components/test.go index a5973c07a..3cc6a7641 100644 --- a/pkg/integration/components/test.go +++ b/pkg/integration/components/test.go @@ -53,7 +53,7 @@ func NewIntegrationTest(args NewIntegrationTestArgs) *IntegrationTest { if args.Description != unitTestDescription { // this panics if we're in a unit test for our integration tests, // so we're using "test test" as a sentinel value - name = testNameFromFilePath() + name = testNameFromCurrentFilePath() } return &IntegrationTest{ @@ -106,8 +106,12 @@ func (self *IntegrationTest) Run(gui integrationTypes.GuiDriver) { } } -func testNameFromFilePath() string { +func testNameFromCurrentFilePath() string { path := utils.FilePath(3) + return TestNameFromFilePath(path) +} + +func TestNameFromFilePath(path string) string { name := strings.Split(path, "integration/tests/")[1] return name[:len(name)-len(".go")] diff --git a/pkg/integration/tests/tests.go b/pkg/integration/tests/tests.go index c681dd3b7..bbcb5c1d1 100644 --- a/pkg/integration/tests/tests.go +++ b/pkg/integration/tests/tests.go @@ -1,17 +1,25 @@ package tests import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/jesseduffield/generics/set" + "github.com/jesseduffield/generics/slices" "github.com/jesseduffield/lazygit/pkg/integration/components" "github.com/jesseduffield/lazygit/pkg/integration/tests/branch" "github.com/jesseduffield/lazygit/pkg/integration/tests/commit" "github.com/jesseduffield/lazygit/pkg/integration/tests/custom_commands" "github.com/jesseduffield/lazygit/pkg/integration/tests/interactive_rebase" + "github.com/jesseduffield/lazygit/pkg/utils" ) // Here is where we lists the actual tests that will run. When you create a new test, // be sure to add it to this list. -var Tests = []*components.IntegrationTest{ +var tests = []*components.IntegrationTest{ commit.Commit, commit.NewBranch, branch.Suggestions, @@ -19,3 +27,47 @@ var Tests = []*components.IntegrationTest{ custom_commands.Basic, custom_commands.MultiplePrompts, } + +func GetTests() []*components.IntegrationTest { + // first we ensure that each test in this directory has actually been added to the above list. + testCount := 0 + + testNamesSet := set.NewFromSlice(slices.Map( + tests, + func(test *components.IntegrationTest) string { + return test.Name() + }, + )) + + missingTestNames := []string{} + + if err := filepath.Walk(filepath.Join(utils.GetLazygitRootDirectory(), "pkg/integration/tests"), func(path string, info os.FileInfo, err error) error { + if !info.IsDir() && strings.HasSuffix(path, ".go") { + // ignoring this current file + if filepath.Base(path) == "tests.go" { + return nil + } + + nameFromPath := components.TestNameFromFilePath(path) + if !testNamesSet.Includes(nameFromPath) { + missingTestNames = append(missingTestNames, nameFromPath) + } + testCount++ + } + return nil + }); err != nil { + panic(fmt.Sprintf("failed to walk tests: %v", err)) + } + + if len(missingTestNames) > 0 { + panic(fmt.Sprintf("The following tests are missing from the list of tests: %s. You need to add them to `pkg/integration/tests/tests.go`.", strings.Join(missingTestNames, ", "))) + } + + if testCount > len(tests) { + panic("you have not added all of the tests to the tests list in `pkg/integration/tests/tests.go`") + } else if testCount < len(tests) { + panic("There are more tests in `pkg/integration/tests/tests.go` than there are test files in the tests directory. Ensure that you only have one test per file and you haven't included the same test twice in the tests list.") + } + + return tests +}