diff --git a/pkg/gui/context/list_context_trait.go b/pkg/gui/context/list_context_trait.go index df178e567..cdc20a85b 100644 --- a/pkg/gui/context/list_context_trait.go +++ b/pkg/gui/context/list_context_trait.go @@ -29,7 +29,8 @@ func (self *ListContextTrait) FocusLine() { self.c.AfterLayout(func() error { oldOrigin, _ := self.GetViewTrait().ViewPortYBounds() - self.GetViewTrait().FocusPoint(self.list.GetSelectedLineIdx()) + self.GetViewTrait().FocusPoint( + self.ModelIndexToViewIndex(self.list.GetSelectedLineIdx())) // If FocusPoint() caused the view to scroll (because the selected line // was out of view before), we need to rerender the view port again. @@ -84,7 +85,7 @@ func (self *ListContextTrait) HandleFocusLost(opts types.OnFocusLostOpts) error // OnFocus assumes that the content of the context has already been rendered to the view. OnRender is the function which actually renders the content to the view func (self *ListContextTrait) HandleRender() error { self.list.RefreshSelectedIdx() - content := self.renderLines(0, self.list.Len()) + content := self.renderLines(-1, -1) self.GetViewTrait().SetContent(content) self.c.Render() self.setFooter() diff --git a/pkg/gui/context/list_renderer.go b/pkg/gui/context/list_renderer.go index cca5a6c77..f29407055 100644 --- a/pkg/gui/context/list_renderer.go +++ b/pkg/gui/context/list_renderer.go @@ -5,26 +5,120 @@ import ( "github.com/jesseduffield/lazygit/pkg/gui/types" "github.com/jesseduffield/lazygit/pkg/utils" + "github.com/samber/lo" + "golang.org/x/exp/slices" ) +type NonModelItem struct { + // Where in the model this should be inserted + Index int + // Content to render + Content string + // The column from which to render the item + Column int +} + type ListRenderer struct { - list types.IList + list types.IList + // Function to get the display strings for each model item in the given + // range. startIdx and endIdx are model indices. For each model item, return + // an array of strings, one for each column; the list renderer will take + // care of aligning the columns appropriately. getDisplayStrings func(startIdx int, endIdx int) [][]string // Alignment for each column. If nil, the default is left alignment getColumnAlignments func() []utils.Alignment + // Function to insert non-model items (e.g. section headers). If nil, no + // such items are inserted + getNonModelItems func() []*NonModelItem + + // The remaining fields are private and shouldn't be initialized by clients + numNonModelItems int + viewIndicesByModelIndex []int + modelIndicesByViewIndex []int } func (self *ListRenderer) GetList() types.IList { return self.list } +func (self *ListRenderer) ModelIndexToViewIndex(modelIndex int) int { + modelIndex = lo.Clamp(modelIndex, 0, self.list.Len()) + if self.viewIndicesByModelIndex != nil { + return self.viewIndicesByModelIndex[modelIndex] + } + + return modelIndex +} + +func (self *ListRenderer) ViewIndexToModelIndex(viewIndex int) int { + viewIndex = utils.Clamp(viewIndex, 0, self.list.Len()+self.numNonModelItems) + if self.modelIndicesByViewIndex != nil { + return self.modelIndicesByViewIndex[viewIndex] + } + + return viewIndex +} + +// startIdx and endIdx are view indices, not model indices. If you want to +// render the whole list, pass -1 for both. func (self *ListRenderer) renderLines(startIdx int, endIdx int) string { var columnAlignments []utils.Alignment if self.getColumnAlignments != nil { columnAlignments = self.getColumnAlignments() } - lines, _ := utils.RenderDisplayStrings( - self.getDisplayStrings(startIdx, utils.Min(endIdx, self.list.Len())), + nonModelItems := []*NonModelItem{} + self.numNonModelItems = 0 + if self.getNonModelItems != nil { + nonModelItems = self.getNonModelItems() + self.prepareConversionArrays(nonModelItems) + } + startModelIdx := 0 + if startIdx == -1 { + startIdx = 0 + } else { + startModelIdx = self.ViewIndexToModelIndex(startIdx) + } + endModelIdx := self.list.Len() + if endIdx == -1 { + endIdx = endModelIdx + len(nonModelItems) + } else { + endModelIdx = self.ViewIndexToModelIndex(endIdx) + } + lines, columnPositions := utils.RenderDisplayStrings( + self.getDisplayStrings(startModelIdx, endModelIdx), columnAlignments) + lines = self.insertNonModelItems(nonModelItems, endIdx, startIdx, lines, columnPositions) return strings.Join(lines, "\n") } + +func (self *ListRenderer) prepareConversionArrays(nonModelItems []*NonModelItem) { + self.numNonModelItems = len(nonModelItems) + self.viewIndicesByModelIndex = lo.Range(self.list.Len() + 1) + self.modelIndicesByViewIndex = lo.Range(self.list.Len() + 1) + offset := 0 + for _, item := range nonModelItems { + for i := item.Index; i <= self.list.Len(); i++ { + self.viewIndicesByModelIndex[i]++ + } + self.modelIndicesByViewIndex = slices.Insert( + self.modelIndicesByViewIndex, item.Index+offset, self.modelIndicesByViewIndex[item.Index+offset]) + offset++ + } +} + +func (self *ListRenderer) insertNonModelItems( + nonModelItems []*NonModelItem, endIdx int, startIdx int, lines []string, columnPositions []int, +) []string { + offset := 0 + for _, item := range nonModelItems { + if item.Index+offset >= endIdx { + break + } + if item.Index+offset >= startIdx { + padding := strings.Repeat(" ", columnPositions[item.Column]) + lines = slices.Insert(lines, item.Index+offset-startIdx, padding+item.Content) + } + offset++ + } + return lines +} diff --git a/pkg/gui/context/list_renderer_test.go b/pkg/gui/context/list_renderer_test.go index 97b64f3dc..98e3f60aa 100644 --- a/pkg/gui/context/list_renderer_test.go +++ b/pkg/gui/context/list_renderer_test.go @@ -1,6 +1,7 @@ package context import ( + "fmt" "strings" "testing" @@ -10,11 +11,12 @@ import ( func TestListRenderer_renderLines(t *testing.T) { scenarios := []struct { - name string - modelStrings []string - startIdx int - endIdx int - expectedOutput string + name string + modelStrings []string + nonModelIndices []int + startIdx int + endIdx int + expectedOutput string }{ { name: "Render whole list", @@ -52,16 +54,79 @@ func TestListRenderer_renderLines(t *testing.T) { expectedOutput: ` c`, }, + { + name: "Whole list with section headers", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{1, 3}, + startIdx: 0, + endIdx: 5, + expectedOutput: ` + a + --- 1 (0) --- + b + c + --- 3 (1) ---`, + }, + { + name: "Multiple consecutive headers", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{0, 0, 2, 2, 2}, + startIdx: 0, + endIdx: 8, + expectedOutput: ` + --- 0 (0) --- + --- 0 (1) --- + a + b + --- 2 (2) --- + --- 2 (3) --- + --- 2 (4) --- + c`, + }, + { + name: "Partial list with headers, beginning", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{1, 3}, + startIdx: 0, + endIdx: 3, + expectedOutput: ` + a + --- 1 (0) --- + b`, + }, + { + name: "Partial list with headers, end (beyond end index)", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{1, 3}, + startIdx: 2, + endIdx: 7, + expectedOutput: ` + b + c + --- 3 (1) ---`, + }, } for _, s := range scenarios { t.Run(s.name, func(t *testing.T) { viewModel := NewListViewModel[string](func() []string { return s.modelStrings }) + var getNonModelItems func() []*NonModelItem + if s.nonModelIndices != nil { + getNonModelItems = func() []*NonModelItem { + return lo.Map(s.nonModelIndices, func(modelIndex int, nonModelIndex int) *NonModelItem { + return &NonModelItem{ + Index: modelIndex, + Content: fmt.Sprintf("--- %d (%d) ---", modelIndex, nonModelIndex), + } + }) + } + } self := &ListRenderer{ list: viewModel, getDisplayStrings: func(startIdx int, endIdx int) [][]string { return lo.Map(s.modelStrings[startIdx:endIdx], func(s string, _ int) []string { return []string{s} }) }, + getNonModelItems: getNonModelItems, } expectedOutput := strings.Join(lo.Map( @@ -72,3 +137,120 @@ func TestListRenderer_renderLines(t *testing.T) { }) } } + +func TestListRenderer_ModelIndexToViewIndex_and_back(t *testing.T) { + scenarios := []struct { + name string + numModelItems int + nonModelIndices []int + + modelIndices []int + expectedViewIndices []int + + viewIndices []int + expectedModelIndices []int + }{ + { + name: "no headers (no getNonModelItems provided)", + numModelItems: 3, + nonModelIndices: nil, // no get + + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{0, 0, 1, 2, 3, 3}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedModelIndices: []int{0, 0, 1, 2, 3, 3}, + }, + { + name: "no headers (getNonModelItems returns zero items)", + numModelItems: 3, + nonModelIndices: []int{}, + + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{0, 0, 1, 2, 3, 3}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedModelIndices: []int{0, 0, 1, 2, 3, 3}, + }, + { + name: "basic", + numModelItems: 3, + nonModelIndices: []int{1, 2}, + + /* + 0: model 0 + 1: --- header 0 --- + 2: model 1 + 3: --- header 1 --- + 4: model 2 + */ + + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{0, 0, 2, 4, 5, 5}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4, 5, 6}, + expectedModelIndices: []int{0, 0, 1, 1, 2, 2, 3, 3}, + }, + { + name: "consecutive section headers", + numModelItems: 3, + nonModelIndices: []int{0, 0, 2, 2, 2, 3, 3}, + + /* + 0: --- header 0 --- + 1: --- header 1 --- + 2: model 0 + 3: model 1 + 4: --- header 2 --- + 5: --- header 3 --- + 6: --- header 4 --- + 7: model 2 + 8: --- header 5 --- + 9: --- header 6 --- + */ + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{2, 2, 3, 7, 10, 10}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + expectedModelIndices: []int{0, 0, 0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3}, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + // Expect lists of equal length for each test: + assert.Equal(t, len(s.modelIndices), len(s.expectedViewIndices)) + assert.Equal(t, len(s.viewIndices), len(s.expectedModelIndices)) + + modelInts := lo.Range(s.numModelItems) + viewModel := NewListViewModel[int](func() []int { return modelInts }) + var getNonModelItems func() []*NonModelItem + if s.nonModelIndices != nil { + getNonModelItems = func() []*NonModelItem { + return lo.Map(s.nonModelIndices, func(modelIndex int, _ int) *NonModelItem { + return &NonModelItem{Index: modelIndex, Content: ""} + }) + } + } + self := &ListRenderer{ + list: viewModel, + getDisplayStrings: func(startIdx int, endIdx int) [][]string { + return lo.Map(modelInts[startIdx:endIdx], + func(i int, _ int) []string { return []string{fmt.Sprint(i)} }) + }, + getNonModelItems: getNonModelItems, + } + + // Need to render first so that it knows the non-model items + self.renderLines(-1, -1) + + for i := 0; i < len(s.modelIndices); i++ { + assert.Equal(t, s.expectedViewIndices[i], self.ModelIndexToViewIndex(s.modelIndices[i])) + } + + for i := 0; i < len(s.viewIndices); i++ { + assert.Equal(t, s.expectedModelIndices[i], self.ViewIndexToModelIndex(s.viewIndices[i])) + } + }) + } +} diff --git a/pkg/gui/controllers/list_controller.go b/pkg/gui/controllers/list_controller.go index 6094561f4..025561993 100644 --- a/pkg/gui/controllers/list_controller.go +++ b/pkg/gui/controllers/list_controller.go @@ -83,9 +83,11 @@ func (self *ListController) handleLineChange(change int) error { // we're not constantly re-rendering the main view. if before != after { if change == -1 { - checkScrollUp(self.context.GetViewTrait(), self.c.UserConfig, before, after) + checkScrollUp(self.context.GetViewTrait(), self.c.UserConfig, + self.context.ModelIndexToViewIndex(before), self.context.ModelIndexToViewIndex(after)) } else if change == 1 { - checkScrollDown(self.context.GetViewTrait(), self.c.UserConfig, before, after) + checkScrollDown(self.context.GetViewTrait(), self.c.UserConfig, + self.context.ModelIndexToViewIndex(before), self.context.ModelIndexToViewIndex(after)) } return self.context.HandleFocus(types.OnFocusOpts{}) @@ -112,7 +114,7 @@ func (self *ListController) HandleGotoBottom() error { func (self *ListController) HandleClick(opts gocui.ViewMouseBindingOpts) error { prevSelectedLineIdx := self.context.GetList().GetSelectedLineIdx() - newSelectedLineIdx := opts.Y + newSelectedLineIdx := self.context.ViewIndexToModelIndex(opts.Y) alreadyFocused := self.isFocused() if err := self.pushContextIfNotFocused(); err != nil { diff --git a/pkg/gui/types/context.go b/pkg/gui/types/context.go index dca5b042c..7aa07056e 100644 --- a/pkg/gui/types/context.go +++ b/pkg/gui/types/context.go @@ -124,6 +124,8 @@ type IListContext interface { GetSelectedItemId() string GetList() IList + ViewIndexToModelIndex(int) int + ModelIndexToViewIndex(int) int FocusLine() IsListContext() // used for type switch diff --git a/pkg/integration/components/view_driver.go b/pkg/integration/components/view_driver.go index 6778dd8dd..b5e985155 100644 --- a/pkg/integration/components/view_driver.go +++ b/pkg/integration/components/view_driver.go @@ -456,13 +456,13 @@ func (self *ViewDriver) NavigateToLine(matcher *TextMatcher) *ViewDriver { self.IsFocused() view := self.getView() + lines := view.BufferLines() var matchIndex int self.t.assertWithRetries(func() (bool, string) { matchIndex = -1 var matches []string - lines := view.BufferLines() // first we look for a duplicate on the current screen. We won't bother looking beyond that though. for i, line := range lines { ok, _ := matcher.test(line) @@ -486,19 +486,38 @@ func (self *ViewDriver) NavigateToLine(matcher *TextMatcher) *ViewDriver { return self } if selectedLineIdx == matchIndex { - self.SelectedLine(matcher) - } else if selectedLineIdx < matchIndex { - for i := selectedLineIdx; i < matchIndex; i++ { - self.SelectNextItem() - } - self.SelectedLine(matcher) - } else { - for i := selectedLineIdx; i > matchIndex; i-- { - self.SelectPreviousItem() - } - self.SelectedLine(matcher) + return self.SelectedLine(matcher) } + // At this point we can't just take the difference of selected and matched + // index and press up or down arrow this many times. The reason is that + // there might be section headers between those lines, and these will be + // skipped when pressing up or down arrow. So we must keep pressing the + // arrow key in a loop, and check after each one whether we now reached the + // target line. + var maxNumKeyPresses int + var keyPress func() + if selectedLineIdx < matchIndex { + maxNumKeyPresses = matchIndex - selectedLineIdx + keyPress = func() { self.SelectNextItem() } + } else { + maxNumKeyPresses = selectedLineIdx - matchIndex + keyPress = func() { self.SelectPreviousItem() } + } + + for i := 0; i < maxNumKeyPresses; i++ { + keyPress() + idx, err := self.getSelectedLineIdx() + if err != nil { + self.t.fail(err.Error()) + return self + } + if ok, _ := matcher.test(lines[idx]); ok { + return self + } + } + + self.t.fail(fmt.Sprintf("Could not navigate to item matching: %s. Lines:\n%s", matcher.name(), strings.Join(lines, "\n"))) return self }