diff --git a/pkg/integration/components/text_matcher.go b/pkg/integration/components/text_matcher.go index a5b987646..8d8077922 100644 --- a/pkg/integration/components/text_matcher.go +++ b/pkg/integration/components/text_matcher.go @@ -9,6 +9,8 @@ import ( ) type TextMatcher struct { + // If you add or change a field here, be sure to update the copy + // code in checkIsSelected() *Matcher[string] } @@ -95,8 +97,8 @@ func (self *TextMatcher) IsSelected() *TextMatcher { // if the matcher has an `IsSelected` rule, it returns true, along with the matcher after that rule has been removed func (self *TextMatcher) checkIsSelected() (bool, *TextMatcher) { // copying into a new matcher in case we want to re-use the original later - newMatcher := &TextMatcher{} - *newMatcher = *self + newMatcher := &TextMatcher{Matcher: &Matcher[string]{}} + *newMatcher.Matcher = *self.Matcher check := lo.ContainsBy(newMatcher.rules, func(rule matcherRule[string]) bool { return rule.name == IS_SELECTED_RULE_NAME }) diff --git a/pkg/integration/components/view_driver.go b/pkg/integration/components/view_driver.go index 437e647be..6160d6b2e 100644 --- a/pkg/integration/components/view_driver.go +++ b/pkg/integration/components/view_driver.go @@ -211,29 +211,63 @@ func (self *ViewDriver) validateVisibleLineCount(matchers []*TextMatcher) { func (self *ViewDriver) assertLines(offset int, matchers ...*TextMatcher) *ViewDriver { view := self.getView() + var expectedStartIdx, expectedEndIdx int + foundSelectionStart := false + foundSelectionEnd := false + expectedSelectedLines := []string{} + for matcherIndex, matcher := range matchers { lineIdx := matcherIndex + offset + checkIsSelected, matcher := matcher.checkIsSelected() + if checkIsSelected { + if foundSelectionEnd { + self.t.fail("The IsSelected matcher can only be used on a contiguous range of lines.") + } + if !foundSelectionStart { + expectedStartIdx = lineIdx + foundSelectionStart = true + } + expectedSelectedLines = append(expectedSelectedLines, matcher.name()) + expectedEndIdx = lineIdx + } else if foundSelectionStart { + foundSelectionEnd = true + } + } + + for matcherIndex, matcher := range matchers { + lineIdx := matcherIndex + offset + expectSelected, matcher := matcher.checkIsSelected() + self.t.matchString(matcher, fmt.Sprintf("Unexpected content in view '%s'.", view.Name()), func() string { return view.BufferLines()[lineIdx] }, ) - if checkIsSelected { + // If any of the matchers care about the selection, we need to + // assert on the selection for each matcher. + if foundSelectionStart { self.t.assertWithRetries(func() (bool, string) { startIdx, endIdx := self.getSelectedRange() - if lineIdx < startIdx || lineIdx > endIdx { - if startIdx == endIdx { - return false, fmt.Sprintf("Unexpected selected line index in view '%s'. Expected %d, got %d", view.Name(), lineIdx, startIdx) - } else { - lines := self.getSelectedLines() - return false, fmt.Sprintf("Unexpected selected line index in view '%s'. Expected line %d to be in range %d to %d. Selected lines:\n---\n%s\n---\n\nExpected line: '%s'", view.Name(), lineIdx, startIdx, endIdx, strings.Join(lines, "\n"), matcher.name()) - } + selected := lineIdx >= startIdx && lineIdx <= endIdx + + if (selected && expectSelected) || (!selected && !expectSelected) { + return true, "" } - return true, "" + + lines := self.getSelectedLines() + + return false, fmt.Sprintf( + "Unexpected selection in view '%s'. Expected %s to be selected but got %s.\nExpected selected lines:\n---\n%s\n---\n\nActual selected lines:\n---\n%s\n---\n", + view.Name(), + formatLineRange(startIdx, endIdx), + formatLineRange(expectedStartIdx, expectedEndIdx), + strings.Join(lines, "\n"), + strings.Join(expectedSelectedLines, "\n"), + ) }) } } @@ -241,6 +275,14 @@ func (self *ViewDriver) assertLines(offset int, matchers ...*TextMatcher) *ViewD return self } +func formatLineRange(from int, to int) string { + if from == to { + return "line " + fmt.Sprintf("%d", from) + } + + return "lines " + fmt.Sprintf("%d-%d", from, to) +} + // asserts on the content of the view i.e. the stuff within the view's frame. func (self *ViewDriver) Content(matcher *TextMatcher) *ViewDriver { self.t.matchString(matcher, fmt.Sprintf("%s: Unexpected content.", self.context), diff --git a/pkg/integration/tests/test_list.go b/pkg/integration/tests/test_list.go index 3e34549af..7f5cc23fa 100644 --- a/pkg/integration/tests/test_list.go +++ b/pkg/integration/tests/test_list.go @@ -161,6 +161,7 @@ var tests = []*components.IntegrationTest{ interactive_rebase.EditTheConflCommit, interactive_rebase.FixupFirstCommit, interactive_rebase.FixupSecondCommit, + interactive_rebase.MidRebaseRangeSelect, interactive_rebase.Move, interactive_rebase.MoveInRebase, interactive_rebase.MoveWithCustomCommentChar,