diff --git a/pkg/gui/context/branches_context.go b/pkg/gui/context/branches_context.go index e4806165f..ac1fae52c 100644 --- a/pkg/gui/context/branches_context.go +++ b/pkg/gui/context/branches_context.go @@ -24,7 +24,7 @@ func NewBranchesContext(c *ContextCommon) *BranchesContext { }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetBranchListDisplayStrings( viewModel.GetItems(), c.State().GetRepoState().GetScreenMode() != types.SCREEN_NORMAL, @@ -45,9 +45,11 @@ func NewBranchesContext(c *ContextCommon) *BranchesContext { Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } diff --git a/pkg/gui/context/commit_files_context.go b/pkg/gui/context/commit_files_context.go index 035230e9d..037554c91 100644 --- a/pkg/gui/context/commit_files_context.go +++ b/pkg/gui/context/commit_files_context.go @@ -28,7 +28,7 @@ func NewCommitFilesContext(c *ContextCommon) *CommitFilesContext { c.UserConfig.Gui.ShowFileTree, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { if viewModel.Len() == 0 { return [][]string{{style.FgRed.Sprint("(none)")}} } @@ -54,9 +54,11 @@ func NewCommitFilesContext(c *ContextCommon) *CommitFilesContext { Transient: true, }), ), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } diff --git a/pkg/gui/context/list_context_trait.go b/pkg/gui/context/list_context_trait.go index 900be019c..cdc20a85b 100644 --- a/pkg/gui/context/list_context_trait.go +++ b/pkg/gui/context/list_context_trait.go @@ -4,17 +4,13 @@ import ( "fmt" "github.com/jesseduffield/lazygit/pkg/gui/types" - "github.com/jesseduffield/lazygit/pkg/utils" ) type ListContextTrait struct { types.Context + ListRenderer - c *ContextCommon - list types.IList - getDisplayStrings func(startIdx int, length int) [][]string - // Alignment for each column. If nil, the default is left alignment - getColumnAlignments func() []utils.Alignment + c *ContextCommon // Some contexts, like the commit context, will highlight the path from the selected commit // to its parents, because it's ambiguous otherwise. For these, we need to refresh the viewport // so that we show the highlighted path. @@ -26,10 +22,6 @@ type ListContextTrait struct { func (self *ListContextTrait) IsListContext() {} -func (self *ListContextTrait) GetList() types.IList { - return self.list -} - func (self *ListContextTrait) FocusLine() { // Doing this at the end of the layout function because we need the view to be // resized before we focus the line, otherwise if we're in accordion mode @@ -37,7 +29,8 @@ func (self *ListContextTrait) FocusLine() { self.c.AfterLayout(func() error { oldOrigin, _ := self.GetViewTrait().ViewPortYBounds() - self.GetViewTrait().FocusPoint(self.list.GetSelectedLineIdx()) + self.GetViewTrait().FocusPoint( + self.ModelIndexToViewIndex(self.list.GetSelectedLineIdx())) // If FocusPoint() caused the view to scroll (because the selected line // was out of view before), we need to rerender the view port again. @@ -59,8 +52,7 @@ func (self *ListContextTrait) FocusLine() { func (self *ListContextTrait) refreshViewport() { startIdx, length := self.GetViewTrait().ViewPortYBounds() - displayStrings := self.getDisplayStrings(startIdx, length) - content := utils.RenderDisplayStrings(displayStrings, nil) + content := self.renderLines(startIdx, startIdx+length) self.GetViewTrait().SetViewPortContent(content) } @@ -93,14 +85,7 @@ func (self *ListContextTrait) HandleFocusLost(opts types.OnFocusLostOpts) error // OnFocus assumes that the content of the context has already been rendered to the view. OnRender is the function which actually renders the content to the view func (self *ListContextTrait) HandleRender() error { self.list.RefreshSelectedIdx() - var columnAlignments []utils.Alignment - if self.getColumnAlignments != nil { - columnAlignments = self.getColumnAlignments() - } - content := utils.RenderDisplayStrings( - self.getDisplayStrings(0, self.list.Len()), - columnAlignments, - ) + content := self.renderLines(-1, -1) self.GetViewTrait().SetContent(content) self.c.Render() self.setFooter() diff --git a/pkg/gui/context/list_renderer.go b/pkg/gui/context/list_renderer.go new file mode 100644 index 000000000..f29407055 --- /dev/null +++ b/pkg/gui/context/list_renderer.go @@ -0,0 +1,124 @@ +package context + +import ( + "strings" + + "github.com/jesseduffield/lazygit/pkg/gui/types" + "github.com/jesseduffield/lazygit/pkg/utils" + "github.com/samber/lo" + "golang.org/x/exp/slices" +) + +type NonModelItem struct { + // Where in the model this should be inserted + Index int + // Content to render + Content string + // The column from which to render the item + Column int +} + +type ListRenderer struct { + list types.IList + // Function to get the display strings for each model item in the given + // range. startIdx and endIdx are model indices. For each model item, return + // an array of strings, one for each column; the list renderer will take + // care of aligning the columns appropriately. + getDisplayStrings func(startIdx int, endIdx int) [][]string + // Alignment for each column. If nil, the default is left alignment + getColumnAlignments func() []utils.Alignment + // Function to insert non-model items (e.g. section headers). If nil, no + // such items are inserted + getNonModelItems func() []*NonModelItem + + // The remaining fields are private and shouldn't be initialized by clients + numNonModelItems int + viewIndicesByModelIndex []int + modelIndicesByViewIndex []int +} + +func (self *ListRenderer) GetList() types.IList { + return self.list +} + +func (self *ListRenderer) ModelIndexToViewIndex(modelIndex int) int { + modelIndex = lo.Clamp(modelIndex, 0, self.list.Len()) + if self.viewIndicesByModelIndex != nil { + return self.viewIndicesByModelIndex[modelIndex] + } + + return modelIndex +} + +func (self *ListRenderer) ViewIndexToModelIndex(viewIndex int) int { + viewIndex = utils.Clamp(viewIndex, 0, self.list.Len()+self.numNonModelItems) + if self.modelIndicesByViewIndex != nil { + return self.modelIndicesByViewIndex[viewIndex] + } + + return viewIndex +} + +// startIdx and endIdx are view indices, not model indices. If you want to +// render the whole list, pass -1 for both. +func (self *ListRenderer) renderLines(startIdx int, endIdx int) string { + var columnAlignments []utils.Alignment + if self.getColumnAlignments != nil { + columnAlignments = self.getColumnAlignments() + } + nonModelItems := []*NonModelItem{} + self.numNonModelItems = 0 + if self.getNonModelItems != nil { + nonModelItems = self.getNonModelItems() + self.prepareConversionArrays(nonModelItems) + } + startModelIdx := 0 + if startIdx == -1 { + startIdx = 0 + } else { + startModelIdx = self.ViewIndexToModelIndex(startIdx) + } + endModelIdx := self.list.Len() + if endIdx == -1 { + endIdx = endModelIdx + len(nonModelItems) + } else { + endModelIdx = self.ViewIndexToModelIndex(endIdx) + } + lines, columnPositions := utils.RenderDisplayStrings( + self.getDisplayStrings(startModelIdx, endModelIdx), + columnAlignments) + lines = self.insertNonModelItems(nonModelItems, endIdx, startIdx, lines, columnPositions) + return strings.Join(lines, "\n") +} + +func (self *ListRenderer) prepareConversionArrays(nonModelItems []*NonModelItem) { + self.numNonModelItems = len(nonModelItems) + self.viewIndicesByModelIndex = lo.Range(self.list.Len() + 1) + self.modelIndicesByViewIndex = lo.Range(self.list.Len() + 1) + offset := 0 + for _, item := range nonModelItems { + for i := item.Index; i <= self.list.Len(); i++ { + self.viewIndicesByModelIndex[i]++ + } + self.modelIndicesByViewIndex = slices.Insert( + self.modelIndicesByViewIndex, item.Index+offset, self.modelIndicesByViewIndex[item.Index+offset]) + offset++ + } +} + +func (self *ListRenderer) insertNonModelItems( + nonModelItems []*NonModelItem, endIdx int, startIdx int, lines []string, columnPositions []int, +) []string { + offset := 0 + for _, item := range nonModelItems { + if item.Index+offset >= endIdx { + break + } + if item.Index+offset >= startIdx { + padding := strings.Repeat(" ", columnPositions[item.Column]) + lines = slices.Insert(lines, item.Index+offset-startIdx, padding+item.Content) + } + offset++ + } + return lines +} diff --git a/pkg/gui/context/list_renderer_test.go b/pkg/gui/context/list_renderer_test.go new file mode 100644 index 000000000..98e3f60aa --- /dev/null +++ b/pkg/gui/context/list_renderer_test.go @@ -0,0 +1,256 @@ +package context + +import ( + "fmt" + "strings" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/assert" +) + +func TestListRenderer_renderLines(t *testing.T) { + scenarios := []struct { + name string + modelStrings []string + nonModelIndices []int + startIdx int + endIdx int + expectedOutput string + }{ + { + name: "Render whole list", + modelStrings: []string{"a", "b", "c"}, + startIdx: 0, + endIdx: 3, + expectedOutput: ` + a + b + c`, + }, + { + name: "Partial list, beginning", + modelStrings: []string{"a", "b", "c"}, + startIdx: 0, + endIdx: 2, + expectedOutput: ` + a + b`, + }, + { + name: "Partial list, end", + modelStrings: []string{"a", "b", "c"}, + startIdx: 1, + endIdx: 3, + expectedOutput: ` + b + c`, + }, + { + name: "Pass an endIdx greater than the model length", + modelStrings: []string{"a", "b", "c"}, + startIdx: 2, + endIdx: 5, + expectedOutput: ` + c`, + }, + { + name: "Whole list with section headers", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{1, 3}, + startIdx: 0, + endIdx: 5, + expectedOutput: ` + a + --- 1 (0) --- + b + c + --- 3 (1) ---`, + }, + { + name: "Multiple consecutive headers", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{0, 0, 2, 2, 2}, + startIdx: 0, + endIdx: 8, + expectedOutput: ` + --- 0 (0) --- + --- 0 (1) --- + a + b + --- 2 (2) --- + --- 2 (3) --- + --- 2 (4) --- + c`, + }, + { + name: "Partial list with headers, beginning", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{1, 3}, + startIdx: 0, + endIdx: 3, + expectedOutput: ` + a + --- 1 (0) --- + b`, + }, + { + name: "Partial list with headers, end (beyond end index)", + modelStrings: []string{"a", "b", "c"}, + nonModelIndices: []int{1, 3}, + startIdx: 2, + endIdx: 7, + expectedOutput: ` + b + c + --- 3 (1) ---`, + }, + } + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + viewModel := NewListViewModel[string](func() []string { return s.modelStrings }) + var getNonModelItems func() []*NonModelItem + if s.nonModelIndices != nil { + getNonModelItems = func() []*NonModelItem { + return lo.Map(s.nonModelIndices, func(modelIndex int, nonModelIndex int) *NonModelItem { + return &NonModelItem{ + Index: modelIndex, + Content: fmt.Sprintf("--- %d (%d) ---", modelIndex, nonModelIndex), + } + }) + } + } + self := &ListRenderer{ + list: viewModel, + getDisplayStrings: func(startIdx int, endIdx int) [][]string { + return lo.Map(s.modelStrings[startIdx:endIdx], + func(s string, _ int) []string { return []string{s} }) + }, + getNonModelItems: getNonModelItems, + } + + expectedOutput := strings.Join(lo.Map( + strings.Split(strings.TrimPrefix(s.expectedOutput, "\n"), "\n"), + func(line string, _ int) string { return strings.TrimSpace(line) }), "\n") + + assert.Equal(t, expectedOutput, self.renderLines(s.startIdx, s.endIdx)) + }) + } +} + +func TestListRenderer_ModelIndexToViewIndex_and_back(t *testing.T) { + scenarios := []struct { + name string + numModelItems int + nonModelIndices []int + + modelIndices []int + expectedViewIndices []int + + viewIndices []int + expectedModelIndices []int + }{ + { + name: "no headers (no getNonModelItems provided)", + numModelItems: 3, + nonModelIndices: nil, // no get + + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{0, 0, 1, 2, 3, 3}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedModelIndices: []int{0, 0, 1, 2, 3, 3}, + }, + { + name: "no headers (getNonModelItems returns zero items)", + numModelItems: 3, + nonModelIndices: []int{}, + + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{0, 0, 1, 2, 3, 3}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedModelIndices: []int{0, 0, 1, 2, 3, 3}, + }, + { + name: "basic", + numModelItems: 3, + nonModelIndices: []int{1, 2}, + + /* + 0: model 0 + 1: --- header 0 --- + 2: model 1 + 3: --- header 1 --- + 4: model 2 + */ + + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{0, 0, 2, 4, 5, 5}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4, 5, 6}, + expectedModelIndices: []int{0, 0, 1, 1, 2, 2, 3, 3}, + }, + { + name: "consecutive section headers", + numModelItems: 3, + nonModelIndices: []int{0, 0, 2, 2, 2, 3, 3}, + + /* + 0: --- header 0 --- + 1: --- header 1 --- + 2: model 0 + 3: model 1 + 4: --- header 2 --- + 5: --- header 3 --- + 6: --- header 4 --- + 7: model 2 + 8: --- header 5 --- + 9: --- header 6 --- + */ + modelIndices: []int{-1, 0, 1, 2, 3, 4}, + expectedViewIndices: []int{2, 2, 3, 7, 10, 10}, + + viewIndices: []int{-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + expectedModelIndices: []int{0, 0, 0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3}, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + // Expect lists of equal length for each test: + assert.Equal(t, len(s.modelIndices), len(s.expectedViewIndices)) + assert.Equal(t, len(s.viewIndices), len(s.expectedModelIndices)) + + modelInts := lo.Range(s.numModelItems) + viewModel := NewListViewModel[int](func() []int { return modelInts }) + var getNonModelItems func() []*NonModelItem + if s.nonModelIndices != nil { + getNonModelItems = func() []*NonModelItem { + return lo.Map(s.nonModelIndices, func(modelIndex int, _ int) *NonModelItem { + return &NonModelItem{Index: modelIndex, Content: ""} + }) + } + } + self := &ListRenderer{ + list: viewModel, + getDisplayStrings: func(startIdx int, endIdx int) [][]string { + return lo.Map(modelInts[startIdx:endIdx], + func(i int, _ int) []string { return []string{fmt.Sprint(i)} }) + }, + getNonModelItems: getNonModelItems, + } + + // Need to render first so that it knows the non-model items + self.renderLines(-1, -1) + + for i := 0; i < len(s.modelIndices); i++ { + assert.Equal(t, s.expectedViewIndices[i], self.ModelIndexToViewIndex(s.modelIndices[i])) + } + + for i := 0; i < len(s.viewIndices); i++ { + assert.Equal(t, s.expectedModelIndices[i], self.ViewIndexToModelIndex(s.viewIndices[i])) + } + }) + } +} diff --git a/pkg/gui/context/local_commits_context.go b/pkg/gui/context/local_commits_context.go index f8f7848f2..fa91e6a79 100644 --- a/pkg/gui/context/local_commits_context.go +++ b/pkg/gui/context/local_commits_context.go @@ -27,7 +27,7 @@ func NewLocalCommitsContext(c *ContextCommon) *LocalCommitsContext { c, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(startIdx int, endIdx int) [][]string { selectedCommitSha := "" if c.CurrentContext().GetKey() == LOCAL_COMMITS_CONTEXT_KEY { @@ -56,7 +56,7 @@ func NewLocalCommitsContext(c *ContextCommon) *LocalCommitsContext { c.UserConfig.Git.ParseEmoji, selectedCommitSha, startIdx, - length, + endIdx, shouldShowGraph(c), c.Model().BisectInfo, showYouAreHereLabel, @@ -74,8 +74,10 @@ func NewLocalCommitsContext(c *ContextCommon) *LocalCommitsContext { Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, c: c, refreshViewportOnChange: true, }, diff --git a/pkg/gui/context/menu_context.go b/pkg/gui/context/menu_context.go index 353f3d386..287ed92ec 100644 --- a/pkg/gui/context/menu_context.go +++ b/pkg/gui/context/menu_context.go @@ -34,10 +34,13 @@ func NewMenuContext( Focusable: true, HasUncontrolledBounds: true, })), - getDisplayStrings: viewModel.GetDisplayStrings, - list: viewModel, - c: c, - getColumnAlignments: func() []utils.Alignment { return viewModel.columnAlignment }, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: viewModel.GetDisplayStrings, + getColumnAlignments: func() []utils.Alignment { return viewModel.columnAlignment }, + getNonModelItems: viewModel.GetNonModelItems, + }, + c: c, }, } } @@ -79,7 +82,7 @@ func (self *MenuViewModel) SetMenuItems(items []*types.MenuItem, columnAlignment } // TODO: move into presentation package -func (self *MenuViewModel) GetDisplayStrings(_startIdx int, _length int) [][]string { +func (self *MenuViewModel) GetDisplayStrings(_ int, _ int) [][]string { menuItems := self.FilteredListViewModel.GetItems() showKeys := lo.SomeBy(menuItems, func(item *types.MenuItem) bool { return item.Key != nil @@ -111,6 +114,40 @@ func (self *MenuViewModel) GetDisplayStrings(_startIdx int, _length int) [][]str }) } +func (self *MenuViewModel) GetNonModelItems() []*NonModelItem { + // Don't display section headers when we are filtering. The reason is that + // filtering changes the order of the items (they are sorted by best match), + // so all the sections would be messed up. + if self.FilteredListViewModel.IsFiltering() { + return []*NonModelItem{} + } + + result := []*NonModelItem{} + menuItems := self.FilteredListViewModel.GetItems() + var prevSection *types.MenuSection = nil + for i, menuItem := range menuItems { + menuItem := menuItem + if menuItem.Section != nil && menuItem.Section != prevSection { + if prevSection != nil { + result = append(result, &NonModelItem{ + Index: i, + Column: 1, + Content: "", + }) + } + + result = append(result, &NonModelItem{ + Index: i, + Column: 1, + Content: style.FgGreen.SetBold().Sprintf("--- %s ---", menuItem.Section.Title), + }) + prevSection = menuItem.Section + } + } + + return result +} + func (self *MenuContext) GetKeybindings(opts types.KeybindingsOpts) []*types.Binding { basicBindings := self.ListContextTrait.GetKeybindings(opts) menuItemsWithKeys := lo.Filter(self.menuItems, func(item *types.MenuItem, _ int) bool { diff --git a/pkg/gui/context/reflog_commits_context.go b/pkg/gui/context/reflog_commits_context.go index 5038b1870..a90507e86 100644 --- a/pkg/gui/context/reflog_commits_context.go +++ b/pkg/gui/context/reflog_commits_context.go @@ -26,7 +26,7 @@ func NewReflogCommitsContext(c *ContextCommon) *ReflogCommitsContext { }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetReflogCommitListDisplayStrings( viewModel.GetItems(), c.State().GetRepoState().GetScreenMode() != types.SCREEN_NORMAL, @@ -49,9 +49,11 @@ func NewReflogCommitsContext(c *ContextCommon) *ReflogCommitsContext { Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/context/remote_branches_context.go b/pkg/gui/context/remote_branches_context.go index fbc91f352..144a8c369 100644 --- a/pkg/gui/context/remote_branches_context.go +++ b/pkg/gui/context/remote_branches_context.go @@ -27,7 +27,7 @@ func NewRemoteBranchesContext( }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetRemoteBranchListDisplayStrings(viewModel.GetItems(), c.Modes().Diffing.Ref) } @@ -43,9 +43,11 @@ func NewRemoteBranchesContext( Focusable: true, Transient: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/context/remotes_context.go b/pkg/gui/context/remotes_context.go index f5e2a97ab..035fb2321 100644 --- a/pkg/gui/context/remotes_context.go +++ b/pkg/gui/context/remotes_context.go @@ -24,7 +24,7 @@ func NewRemotesContext(c *ContextCommon) *RemotesContext { }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetRemoteListDisplayStrings(viewModel.GetItems(), c.Modes().Diffing.Ref) } @@ -38,9 +38,11 @@ func NewRemotesContext(c *ContextCommon) *RemotesContext { Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/context/stash_context.go b/pkg/gui/context/stash_context.go index 7bd4740f8..2b86d945f 100644 --- a/pkg/gui/context/stash_context.go +++ b/pkg/gui/context/stash_context.go @@ -26,7 +26,7 @@ func NewStashContext( }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetStashEntryListDisplayStrings(viewModel.GetItems(), c.Modes().Diffing.Ref) } @@ -40,9 +40,11 @@ func NewStashContext( Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/context/sub_commits_context.go b/pkg/gui/context/sub_commits_context.go index 2643d294b..43cc513c0 100644 --- a/pkg/gui/context/sub_commits_context.go +++ b/pkg/gui/context/sub_commits_context.go @@ -36,7 +36,7 @@ func NewSubCommitsContext( limitCommits: true, } - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(startIdx int, endIdx int) [][]string { // This can happen if a sub-commits view is asked to be rerendered while // it is invisble; for example when switching screen modes, which // rerenders all views. @@ -72,7 +72,7 @@ func NewSubCommitsContext( c.UserConfig.Git.ParseEmoji, selectedCommitSha, startIdx, - length, + endIdx, shouldShowGraph(c), git_commands.NewNullBisectInfo(), false, @@ -93,8 +93,10 @@ func NewSubCommitsContext( Focusable: true, Transient: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, c: c, refreshViewportOnChange: true, }, diff --git a/pkg/gui/context/submodules_context.go b/pkg/gui/context/submodules_context.go index e97fa4f5c..2cffd82d6 100644 --- a/pkg/gui/context/submodules_context.go +++ b/pkg/gui/context/submodules_context.go @@ -21,7 +21,7 @@ func NewSubmodulesContext(c *ContextCommon) *SubmodulesContext { }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetSubmoduleListDisplayStrings(viewModel.GetItems()) } @@ -35,9 +35,11 @@ func NewSubmodulesContext(c *ContextCommon) *SubmodulesContext { Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/context/suggestions_context.go b/pkg/gui/context/suggestions_context.go index d8b650642..e3c1f5f26 100644 --- a/pkg/gui/context/suggestions_context.go +++ b/pkg/gui/context/suggestions_context.go @@ -36,7 +36,7 @@ func NewSuggestionsContext( return state.Suggestions } - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetSuggestionListDisplayStrings(state.Suggestions) } @@ -54,9 +54,11 @@ func NewSuggestionsContext( Focusable: true, HasUncontrolledBounds: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/context/tags_context.go b/pkg/gui/context/tags_context.go index 95b845a28..4a9f525f6 100644 --- a/pkg/gui/context/tags_context.go +++ b/pkg/gui/context/tags_context.go @@ -26,7 +26,7 @@ func NewTagsContext( }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetTagListDisplayStrings(viewModel.GetItems(), c.Modes().Diffing.Ref) } @@ -40,9 +40,11 @@ func NewTagsContext( Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/context/working_tree_context.go b/pkg/gui/context/working_tree_context.go index 390c03b33..0e0b8d72b 100644 --- a/pkg/gui/context/working_tree_context.go +++ b/pkg/gui/context/working_tree_context.go @@ -23,7 +23,7 @@ func NewWorkingTreeContext(c *ContextCommon) *WorkingTreeContext { c.UserConfig.Gui.ShowFileTree, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { lines := presentation.RenderFileTree(viewModel, c.Modes().Diffing.Ref, c.Model().Submodules) return lo.Map(lines, func(line string, _ int) []string { return []string{line} @@ -41,9 +41,11 @@ func NewWorkingTreeContext(c *ContextCommon) *WorkingTreeContext { Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } diff --git a/pkg/gui/context/worktrees_context.go b/pkg/gui/context/worktrees_context.go index 055467b74..c616dd49e 100644 --- a/pkg/gui/context/worktrees_context.go +++ b/pkg/gui/context/worktrees_context.go @@ -21,7 +21,7 @@ func NewWorktreesContext(c *ContextCommon) *WorktreesContext { }, ) - getDisplayStrings := func(startIdx int, length int) [][]string { + getDisplayStrings := func(_ int, _ int) [][]string { return presentation.GetWorktreeDisplayStrings( c.Tr, viewModel.GetFilteredList(), @@ -38,9 +38,11 @@ func NewWorktreesContext(c *ContextCommon) *WorktreesContext { Kind: types.SIDE_CONTEXT, Focusable: true, })), - list: viewModel, - getDisplayStrings: getDisplayStrings, - c: c, + ListRenderer: ListRenderer{ + list: viewModel, + getDisplayStrings: getDisplayStrings, + }, + c: c, }, } } diff --git a/pkg/gui/controllers/list_controller.go b/pkg/gui/controllers/list_controller.go index 6094561f4..025561993 100644 --- a/pkg/gui/controllers/list_controller.go +++ b/pkg/gui/controllers/list_controller.go @@ -83,9 +83,11 @@ func (self *ListController) handleLineChange(change int) error { // we're not constantly re-rendering the main view. if before != after { if change == -1 { - checkScrollUp(self.context.GetViewTrait(), self.c.UserConfig, before, after) + checkScrollUp(self.context.GetViewTrait(), self.c.UserConfig, + self.context.ModelIndexToViewIndex(before), self.context.ModelIndexToViewIndex(after)) } else if change == 1 { - checkScrollDown(self.context.GetViewTrait(), self.c.UserConfig, before, after) + checkScrollDown(self.context.GetViewTrait(), self.c.UserConfig, + self.context.ModelIndexToViewIndex(before), self.context.ModelIndexToViewIndex(after)) } return self.context.HandleFocus(types.OnFocusOpts{}) @@ -112,7 +114,7 @@ func (self *ListController) HandleGotoBottom() error { func (self *ListController) HandleClick(opts gocui.ViewMouseBindingOpts) error { prevSelectedLineIdx := self.context.GetList().GetSelectedLineIdx() - newSelectedLineIdx := opts.Y + newSelectedLineIdx := self.context.ViewIndexToModelIndex(opts.Y) alreadyFocused := self.isFocused() if err := self.pushContextIfNotFocused(); err != nil { diff --git a/pkg/gui/controllers/options_menu_action.go b/pkg/gui/controllers/options_menu_action.go index f757ba0a1..341fe8fb3 100644 --- a/pkg/gui/controllers/options_menu_action.go +++ b/pkg/gui/controllers/options_menu_action.go @@ -18,23 +18,33 @@ func (self *OptionsMenuAction) Call() error { return nil } - bindings := self.getBindings(ctx) + local, global, navigation := self.getBindings(ctx) - menuItems := lo.Map(bindings, func(binding *types.Binding, _ int) *types.MenuItem { - return &types.MenuItem{ - OpensMenu: binding.OpensMenu, - Label: binding.Description, - OnPress: func() error { - if binding.Handler == nil { - return nil + menuItems := []*types.MenuItem{} + + appendBindings := func(bindings []*types.Binding, section *types.MenuSection) { + menuItems = append(menuItems, + lo.Map(bindings, func(binding *types.Binding, _ int) *types.MenuItem { + return &types.MenuItem{ + OpensMenu: binding.OpensMenu, + Label: binding.Description, + OnPress: func() error { + if binding.Handler == nil { + return nil + } + + return binding.Handler() + }, + Key: binding.Key, + Tooltip: binding.Tooltip, + Section: section, } + })...) + } - return binding.Handler() - }, - Key: binding.Key, - Tooltip: binding.Tooltip, - } - }) + appendBindings(local, &types.MenuSection{Title: self.c.Tr.KeybindingsMenuSectionLocal, Column: 1}) + appendBindings(global, &types.MenuSection{Title: self.c.Tr.KeybindingsMenuSectionGlobal, Column: 1}) + appendBindings(navigation, &types.MenuSection{Title: self.c.Tr.KeybindingsMenuSectionNavigation, Column: 1}) return self.c.Menu(types.CreateMenuOptions{ Title: self.c.Tr.Keybindings, @@ -44,7 +54,8 @@ func (self *OptionsMenuAction) Call() error { }) } -func (self *OptionsMenuAction) getBindings(context types.Context) []*types.Binding { +// Returns three slices of bindings: local, global, and navigation +func (self *OptionsMenuAction) getBindings(context types.Context) ([]*types.Binding, []*types.Binding, []*types.Binding) { var bindingsGlobal, bindingsPanel, bindingsNavigation []*types.Binding bindings, _ := self.c.GetInitialKeybindingsWithCustomCommands() @@ -61,14 +72,7 @@ func (self *OptionsMenuAction) getBindings(context types.Context) []*types.Bindi } } - resultBindings := []*types.Binding{} - resultBindings = append(resultBindings, uniqueBindings(bindingsPanel)...) - // adding a separator between the panel-specific bindings and the other bindings - resultBindings = append(resultBindings, &types.Binding{}) - resultBindings = append(resultBindings, uniqueBindings(bindingsGlobal)...) - resultBindings = append(resultBindings, uniqueBindings(bindingsNavigation)...) - - return resultBindings + return uniqueBindings(bindingsPanel), uniqueBindings(bindingsGlobal), uniqueBindings(bindingsNavigation) } // We shouldn't really need to do this. We should define alternative keys for the same diff --git a/pkg/gui/menu_panel.go b/pkg/gui/menu_panel.go index 9f3b3a55d..88095584d 100644 --- a/pkg/gui/menu_panel.go +++ b/pkg/gui/menu_panel.go @@ -1,7 +1,8 @@ package gui import ( - "github.com/jesseduffield/lazygit/pkg/gui/presentation" + "fmt" + "github.com/jesseduffield/lazygit/pkg/gui/types" "github.com/jesseduffield/lazygit/pkg/theme" "github.com/jesseduffield/lazygit/pkg/utils" @@ -27,7 +28,7 @@ func (gui *Gui) createMenu(opts types.CreateMenuOptions) error { } if item.OpensMenu { - item.LabelColumns[0] = presentation.OpensMenuStyle(item.LabelColumns[0]) + item.LabelColumns[0] = fmt.Sprintf("%s...", item.LabelColumns[0]) } maxColumnSize = utils.Max(maxColumnSize, len(item.LabelColumns)) diff --git a/pkg/gui/presentation/commits.go b/pkg/gui/presentation/commits.go index 6661e9d30..6ae04f95a 100644 --- a/pkg/gui/presentation/commits.go +++ b/pkg/gui/presentation/commits.go @@ -52,7 +52,7 @@ func GetCommitListDisplayStrings( parseEmoji bool, selectedCommitSha string, startIdx int, - length int, + endIdx int, showGraph bool, bisectInfo *git_commands.BisectInfo, showYouAreHereLabel bool, @@ -68,11 +68,10 @@ func GetCommitListDisplayStrings( return nil } - end := utils.Min(startIdx+length, len(commits)) // this is where my non-TODO commits begin - rebaseOffset := utils.Min(indexOfFirstNonTODOCommit(commits), end) + rebaseOffset := utils.Min(indexOfFirstNonTODOCommit(commits), endIdx) - filteredCommits := commits[startIdx:end] + filteredCommits := commits[startIdx:endIdx] bisectBounds := getbisectBounds(commits, bisectInfo) @@ -85,8 +84,8 @@ func GetCommitListDisplayStrings( pipeSets := loadPipesets(commits[rebaseOffset:]) pipeSetOffset := utils.Max(startIdx-rebaseOffset, 0) - graphPipeSets := pipeSets[pipeSetOffset:utils.Max(end-rebaseOffset, 0)] - graphCommits := commits[graphOffset:end] + graphPipeSets := pipeSets[pipeSetOffset:utils.Max(endIdx-rebaseOffset, 0)] + graphCommits := commits[graphOffset:endIdx] graphLines := graph.RenderAux( graphPipeSets, graphCommits, diff --git a/pkg/gui/presentation/commits_test.go b/pkg/gui/presentation/commits_test.go index 65122961e..16f1de660 100644 --- a/pkg/gui/presentation/commits_test.go +++ b/pkg/gui/presentation/commits_test.go @@ -41,7 +41,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { parseEmoji bool selectedCommitSha string startIdx int - length int + endIdx int showGraph bool bisectInfo *git_commands.BisectInfo showYouAreHereLabel bool @@ -52,7 +52,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { testName: "no commits", commits: []*models.Commit{}, startIdx: 0, - length: 1, + endIdx: 1, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -66,7 +66,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit2", Sha: "sha2"}, }, startIdx: 0, - length: 2, + endIdx: 2, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -83,7 +83,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit2", Sha: "sha2"}, }, startIdx: 0, - length: 2, + endIdx: 2, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -110,7 +110,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { currentBranchName: "current-branch", hasUpdateRefConfig: true, startIdx: 0, - length: 4, + endIdx: 4, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -135,7 +135,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { currentBranchName: "current-branch", hasUpdateRefConfig: true, startIdx: 0, - length: 2, + endIdx: 2, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -158,7 +158,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { currentBranchName: "current-branch", hasUpdateRefConfig: false, startIdx: 0, - length: 2, + endIdx: 2, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -179,7 +179,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "some-branch", CommitHash: "sha2"}, }, startIdx: 0, - length: 3, + endIdx: 3, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -200,7 +200,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit5", Sha: "sha5", Parents: []string{"sha7"}}, }, startIdx: 0, - length: 5, + endIdx: 5, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -223,7 +223,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit5", Sha: "sha5", Parents: []string{"sha7"}}, }, startIdx: 0, - length: 5, + endIdx: 5, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -247,7 +247,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit5", Sha: "sha5", Parents: []string{"sha7"}}, }, startIdx: 1, - length: 10, + endIdx: 5, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -270,7 +270,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit5", Sha: "sha5", Parents: []string{"sha7"}}, }, startIdx: 3, - length: 2, + endIdx: 5, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -291,7 +291,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit5", Sha: "sha5", Parents: []string{"sha7"}}, }, startIdx: 0, - length: 2, + endIdx: 2, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -312,7 +312,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit5", Sha: "sha5", Parents: []string{"sha7"}}, }, startIdx: 4, - length: 2, + endIdx: 5, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -332,7 +332,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit5", Sha: "sha5", Parents: []string{"sha7"}}, }, startIdx: 0, - length: 2, + endIdx: 2, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -351,7 +351,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { {Name: "commit3", Sha: "sha3", Parents: []string{"sha4"}}, }, startIdx: 0, - length: 5, + endIdx: 3, showGraph: true, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -373,7 +373,7 @@ func TestGetCommitListDisplayStrings(t *testing.T) { timeFormat: "2006-01-02", shortTimeFormat: "3:04PM", startIdx: 0, - length: 2, + endIdx: 2, showGraph: false, bisectInfo: git_commands.NewNullBisectInfo(), cherryPickedCommitShaSet: set.New[string](), @@ -416,13 +416,14 @@ func TestGetCommitListDisplayStrings(t *testing.T) { s.parseEmoji, s.selectedCommitSha, s.startIdx, - s.length, + s.endIdx, s.showGraph, s.bisectInfo, s.showYouAreHereLabel, ) - renderedResult := utils.RenderDisplayStrings(result, nil) + renderedLines, _ := utils.RenderDisplayStrings(result, nil) + renderedResult := strings.Join(renderedLines, "\n") t.Logf("\n%s", renderedResult) assert.EqualValues(t, s.expected, renderedResult) diff --git a/pkg/gui/presentation/menu.go b/pkg/gui/presentation/menu.go deleted file mode 100644 index c43896c22..000000000 --- a/pkg/gui/presentation/menu.go +++ /dev/null @@ -1,7 +0,0 @@ -package presentation - -import "github.com/jesseduffield/lazygit/pkg/gui/style" - -func OpensMenuStyle(str string) string { - return style.FgMagenta.Sprintf("%s...", str) -} diff --git a/pkg/gui/types/common.go b/pkg/gui/types/common.go index 919c15b0e..0b6a8e430 100644 --- a/pkg/gui/types/common.go +++ b/pkg/gui/types/common.go @@ -177,6 +177,11 @@ type PromptOpts struct { Mask bool } +type MenuSection struct { + Title string + Column int // The column that this section title should be aligned with +} + type MenuItem struct { Label string @@ -194,6 +199,14 @@ type MenuItem struct { // The tooltip will be displayed upon highlighting the menu item Tooltip string + + // Can be used to group menu items into sections with headers. MenuItems + // with the same Section should be contiguous, and will automatically get a + // section header. If nil, the item is not part of a section. + // Note that pointer comparison is used to determine whether two menu items + // belong to the same section, so make sure all your items in a given + // section point to the same MenuSection instance. + Section *MenuSection } type Model struct { diff --git a/pkg/gui/types/context.go b/pkg/gui/types/context.go index dca5b042c..7aa07056e 100644 --- a/pkg/gui/types/context.go +++ b/pkg/gui/types/context.go @@ -124,6 +124,8 @@ type IListContext interface { GetSelectedItemId() string GetList() IList + ViewIndexToModelIndex(int) int + ModelIndexToViewIndex(int) int FocusLine() IsListContext() // used for type switch diff --git a/pkg/i18n/english.go b/pkg/i18n/english.go index 50b62549f..93909cbdc 100644 --- a/pkg/i18n/english.go +++ b/pkg/i18n/english.go @@ -388,6 +388,9 @@ type TranslationSet struct { Panel string Keybindings string KeybindingsLegend string + KeybindingsMenuSectionLocal string + KeybindingsMenuSectionGlobal string + KeybindingsMenuSectionNavigation string RenameBranch string SetUnsetUpstream string NewGitFlowBranchPrompt string @@ -986,6 +989,9 @@ func EnglishTranslationSet() TranslationSet { ConflictsResolved: "All merge conflicts resolved. Continue?", Continue: "Continue", Keybindings: "Keybindings", + KeybindingsMenuSectionLocal: "Local", + KeybindingsMenuSectionGlobal: "Global", + KeybindingsMenuSectionNavigation: "Navigation", RebasingTitle: "Rebase '{{.checkedOutBranch}}' onto '{{.ref}}'", RebasingFromBaseCommitTitle: "Rebase '{{.checkedOutBranch}}' from marked base onto '{{.ref}}'", SimpleRebase: "Simple rebase", diff --git a/pkg/integration/components/view_driver.go b/pkg/integration/components/view_driver.go index 6778dd8dd..b5e985155 100644 --- a/pkg/integration/components/view_driver.go +++ b/pkg/integration/components/view_driver.go @@ -456,13 +456,13 @@ func (self *ViewDriver) NavigateToLine(matcher *TextMatcher) *ViewDriver { self.IsFocused() view := self.getView() + lines := view.BufferLines() var matchIndex int self.t.assertWithRetries(func() (bool, string) { matchIndex = -1 var matches []string - lines := view.BufferLines() // first we look for a duplicate on the current screen. We won't bother looking beyond that though. for i, line := range lines { ok, _ := matcher.test(line) @@ -486,19 +486,38 @@ func (self *ViewDriver) NavigateToLine(matcher *TextMatcher) *ViewDriver { return self } if selectedLineIdx == matchIndex { - self.SelectedLine(matcher) - } else if selectedLineIdx < matchIndex { - for i := selectedLineIdx; i < matchIndex; i++ { - self.SelectNextItem() - } - self.SelectedLine(matcher) - } else { - for i := selectedLineIdx; i > matchIndex; i-- { - self.SelectPreviousItem() - } - self.SelectedLine(matcher) + return self.SelectedLine(matcher) } + // At this point we can't just take the difference of selected and matched + // index and press up or down arrow this many times. The reason is that + // there might be section headers between those lines, and these will be + // skipped when pressing up or down arrow. So we must keep pressing the + // arrow key in a loop, and check after each one whether we now reached the + // target line. + var maxNumKeyPresses int + var keyPress func() + if selectedLineIdx < matchIndex { + maxNumKeyPresses = matchIndex - selectedLineIdx + keyPress = func() { self.SelectNextItem() } + } else { + maxNumKeyPresses = selectedLineIdx - matchIndex + keyPress = func() { self.SelectPreviousItem() } + } + + for i := 0; i < maxNumKeyPresses; i++ { + keyPress() + idx, err := self.getSelectedLineIdx() + if err != nil { + self.t.fail(err.Error()) + return self + } + if ok, _ := matcher.test(lines[idx]); ok { + return self + } + } + + self.t.fail(fmt.Sprintf("Could not navigate to item matching: %s. Lines:\n%s", matcher.name(), strings.Join(lines, "\n"))) return self } diff --git a/pkg/utils/formatting.go b/pkg/utils/formatting.go index bf4f0debd..47e17b612 100644 --- a/pkg/utils/formatting.go +++ b/pkg/utils/formatting.go @@ -5,6 +5,7 @@ import ( "github.com/mattn/go-runewidth" "github.com/samber/lo" + "golang.org/x/exp/slices" ) type Alignment int @@ -36,10 +37,14 @@ func WithPadding(str string, padding int, alignment Alignment) string { // defaults to left-aligning each column. If you want to set the alignment of // each column, pass in a slice of Alignment values. -func RenderDisplayStrings(displayStringsArr [][]string, columnAlignments []Alignment) string { - displayStringsArr = excludeBlankColumns(displayStringsArr) +// returns a list of strings that should be joined with "\n", and an array of +// the column positions +func RenderDisplayStrings(displayStringsArr [][]string, columnAlignments []Alignment) ([]string, []int) { + displayStringsArr, columnAlignments, removedColumns := excludeBlankColumns(displayStringsArr, columnAlignments) padWidths := getPadWidths(displayStringsArr) columnConfigs := make([]ColumnConfig, len(padWidths)) + columnPositions := make([]int, len(padWidths)+1) + columnPositions[0] = 0 for i, padWidth := range padWidths { // gracefully handle when columnAlignments is shorter than padWidths alignment := AlignLeft @@ -51,16 +56,23 @@ func RenderDisplayStrings(displayStringsArr [][]string, columnAlignments []Align Width: padWidth, Alignment: alignment, } + columnPositions[i+1] = columnPositions[i] + padWidth + 1 } - output := getPaddedDisplayStrings(displayStringsArr, columnConfigs) - - return output + // Add the removed columns back into columnPositions (a removed column gets + // the same position as the following column); clients should be able to rely + // on them all to be there + for _, removedColumn := range removedColumns { + if removedColumn < len(columnPositions) { + columnPositions = slices.Insert(columnPositions, removedColumn, columnPositions[removedColumn]) + } + } + return getPaddedDisplayStrings(displayStringsArr, columnConfigs), columnPositions } // NOTE: this mutates the input slice for the sake of performance -func excludeBlankColumns(displayStringsArr [][]string) [][]string { +func excludeBlankColumns(displayStringsArr [][]string, columnAlignments []Alignment) ([][]string, []Alignment, []int) { if len(displayStringsArr) == 0 { - return displayStringsArr + return displayStringsArr, columnAlignments, []int{} } // if all rows share a blank column, we want to remove that column @@ -76,26 +88,33 @@ outer: } if len(toRemove) == 0 { - return displayStringsArr + return displayStringsArr, columnAlignments, []int{} } // remove the columns for i, strings := range displayStringsArr { for j := len(toRemove) - 1; j >= 0; j-- { - strings = append(strings[:toRemove[j]], strings[toRemove[j]+1:]...) + strings = slices.Delete(strings, toRemove[j], toRemove[j]+1) } displayStringsArr[i] = strings } - return displayStringsArr + for j := len(toRemove) - 1; j >= 0; j-- { + if columnAlignments != nil && toRemove[j] < len(columnAlignments) { + columnAlignments = slices.Delete(columnAlignments, toRemove[j], toRemove[j]+1) + } + } + + return displayStringsArr, columnAlignments, toRemove } -func getPaddedDisplayStrings(stringArrays [][]string, columnConfigs []ColumnConfig) string { - builder := strings.Builder{} - for i, stringArray := range stringArrays { +func getPaddedDisplayStrings(stringArrays [][]string, columnConfigs []ColumnConfig) []string { + result := make([]string, 0, len(stringArrays)) + for _, stringArray := range stringArrays { if len(stringArray) == 0 { continue } + builder := strings.Builder{} for j, columnConfig := range columnConfigs { if len(stringArray)-1 < j { continue @@ -107,12 +126,9 @@ func getPaddedDisplayStrings(stringArrays [][]string, columnConfigs []ColumnConf continue } builder.WriteString(stringArray[len(columnConfigs)]) - - if i < len(stringArrays)-1 { - builder.WriteString("\n") - } + result = append(result, builder.String()) } - return builder.String() + return result } func getPadWidths(stringArrays [][]string) []int { diff --git a/pkg/utils/formatting_test.go b/pkg/utils/formatting_test.go index b1777911b..3858fd2ec 100644 --- a/pkg/utils/formatting_test.go +++ b/pkg/utils/formatting_test.go @@ -1,6 +1,7 @@ package utils import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -157,66 +158,90 @@ func TestTruncateWithEllipsis(t *testing.T) { func TestRenderDisplayStrings(t *testing.T) { type scenario struct { - input [][]string - columnAlignments []Alignment - expected string + input [][]string + columnAlignments []Alignment + expectedOutput string + expectedColumnPositions []int } tests := []scenario{ { - input: [][]string{{""}, {""}}, - columnAlignments: nil, - expected: "", + input: [][]string{{""}, {""}}, + columnAlignments: nil, + expectedOutput: "", + expectedColumnPositions: []int{0, 0}, }, { - input: [][]string{{"a"}, {""}}, - columnAlignments: nil, - expected: "a\n", + input: [][]string{{"a"}, {""}}, + columnAlignments: nil, + expectedOutput: "a\n", + expectedColumnPositions: []int{0}, }, { - input: [][]string{{"a"}, {"b"}}, - columnAlignments: nil, - expected: "a\nb", + input: [][]string{{"a"}, {"b"}}, + columnAlignments: nil, + expectedOutput: "a\nb", + expectedColumnPositions: []int{0}, }, { - input: [][]string{{"a", "b"}, {"c", "d"}}, - columnAlignments: nil, - expected: "a b\nc d", + input: [][]string{{"a", "b"}, {"c", "d"}}, + columnAlignments: nil, + expectedOutput: "a b\nc d", + expectedColumnPositions: []int{0, 2}, }, { - input: [][]string{{"a", "", "c"}, {"d", "", "f"}}, - columnAlignments: nil, - expected: "a c\nd f", + input: [][]string{{"a", "", "c"}, {"d", "", "f"}}, + columnAlignments: nil, + expectedOutput: "a c\nd f", + expectedColumnPositions: []int{0, 2, 2}, }, { - input: [][]string{{"a", "", "c", ""}, {"d", "", "f", ""}}, - columnAlignments: nil, - expected: "a c\nd f", + input: [][]string{{"a", "", "c", ""}, {"d", "", "f", ""}}, + columnAlignments: nil, + expectedOutput: "a c\nd f", + expectedColumnPositions: []int{0, 2, 2}, }, { - input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, - columnAlignments: nil, - expected: "abc d\ne f", + input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, + columnAlignments: nil, + expectedOutput: "abc d\ne f", + expectedColumnPositions: []int{0, 4, 4}, }, { - input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, - columnAlignments: []Alignment{AlignLeft, AlignLeft}, // same as nil (default) - expected: "abc d\ne f", + input: [][]string{{"", "abc", "", "", "d", "e"}, {"", "f", "", "", "g", "h"}}, + columnAlignments: nil, + expectedOutput: "abc d e\nf g h", + expectedColumnPositions: []int{0, 0, 4, 4, 4, 6}, }, { - input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, - columnAlignments: []Alignment{AlignRight, AlignLeft}, - expected: "abc d\n e f", + input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, + columnAlignments: []Alignment{AlignLeft, AlignLeft}, // same as nil (default) + expectedOutput: "abc d\ne f", + expectedColumnPositions: []int{0, 4, 4}, }, { - input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, - columnAlignments: []Alignment{AlignRight}, // gracefully defaults unspecified columns to left-align - expected: "abc d\n e f", + input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, + columnAlignments: []Alignment{AlignRight, AlignLeft}, + expectedOutput: "abc d\n e f", + expectedColumnPositions: []int{0, 4, 4}, + }, + { + input: [][]string{{"a", "", "bcd", "efg", "h"}, {"i", "", "j", "k", "l"}}, + columnAlignments: []Alignment{AlignLeft, AlignLeft, AlignRight, AlignLeft}, + expectedOutput: "a bcd efg h\ni j k l", + expectedColumnPositions: []int{0, 2, 2, 6, 10}, + }, + { + input: [][]string{{"abc", "", "d", ""}, {"e", "", "f", ""}}, + columnAlignments: []Alignment{AlignRight}, // gracefully defaults unspecified columns to left-align + expectedOutput: "abc d\n e f", + expectedColumnPositions: []int{0, 4, 4}, }, } for _, test := range tests { - output := RenderDisplayStrings(test.input, test.columnAlignments) - assert.EqualValues(t, test.expected, output) + output, columnPositions := RenderDisplayStrings(test.input, test.columnAlignments) + assert.EqualValues(t, test.expectedOutput, strings.Join(output, "\n")) + assert.EqualValues(t, test.expectedColumnPositions, columnPositions) } }