diff --git a/formatters/html/html.go b/formatters/html/html.go index 3f68a7d..0faa9c3 100644 --- a/formatters/html/html.go +++ b/formatters/html/html.go @@ -26,7 +26,18 @@ func WithClasses() Option { return func(f *Formatter) { f.Classes = true } } func TabWidth(width int) Option { return func(f *Formatter) { f.tabWidth = width } } // PreventSurroundingPre prevents the surrounding pre tags around the generated code -func PreventSurroundingPre() Option { return func(f *Formatter) { f.preventSurroundingPre = true } } +func PreventSurroundingPre() Option { + return func(f *Formatter) { + f.preWrapper = nopPreWrapper + } +} + +// WithPreWrapper allows control of the surrounding pre tags. +func WithPreWrapper(wrapper PreWrapper) Option { + return func(f *Formatter) { + f.preWrapper = wrapper + } +} // WithLineNumbers formats output with line numbers. func WithLineNumbers() Option { @@ -64,6 +75,7 @@ func BaseLineNumber(n int) Option { func New(options ...Option) *Formatter { f := &Formatter{ baseLineNumber: 1, + preWrapper: defaultPreWrapper, } for _, option := range options { option(f) @@ -71,17 +83,57 @@ func New(options ...Option) *Formatter { return f } +// PreWrapper defines the operations supported in WithPreWrapper. +type PreWrapper interface { + // Start is called to write a start
 element.
+	// The code flag tells whether this block surrounds
+	// highlighted code. This will be false when surrounding
+	// line numbers.
+	Start(code bool, styleAttr string) string
+
+	// End is called to write the end 
element. + End(code bool) string +} + +type preWrapper struct { + start func(code bool, styleAttr string) string + end func(code bool) string +} + +func (p preWrapper) Start(code bool, styleAttr string) string { + return p.start(code, styleAttr) +} + +func (p preWrapper) End(code bool) string { + return p.end(code) +} + +var ( + nopPreWrapper = preWrapper{ + start: func(code bool, styleAttr string) string { return "" }, + end: func(code bool) string { return "" }, + } + defaultPreWrapper = preWrapper{ + start: func(code bool, styleAttr string) string { + return fmt.Sprintf("", styleAttr) + }, + end: func(code bool) string { + return "" + }, + } +) + // Formatter that generates HTML. type Formatter struct { - standalone bool - prefix string - Classes bool // Exported field to detect when classes are being used - preventSurroundingPre bool - tabWidth int - lineNumbers bool - lineNumbersInTable bool - highlightRanges highlightRanges - baseLineNumber int + standalone bool + prefix string + Classes bool // Exported field to detect when classes are being used + preWrapper PreWrapper + tabWidth int + lineNumbers bool + lineNumbersInTable bool + highlightRanges highlightRanges + baseLineNumber int } type highlightRanges [][2]int @@ -129,9 +181,7 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma. fmt.Fprintf(w, "\n", f.styleAttr(css, chroma.Background)) fmt.Fprintf(w, "", f.styleAttr(css, chroma.LineTable)) fmt.Fprintf(w, "\n", f.styleAttr(css, chroma.LineTableTD)) - if !f.preventSurroundingPre { - fmt.Fprintf(w, "", f.styleAttr(css, chroma.Background)) - } + fmt.Fprintf(w, f.preWrapper.Start(false, f.styleAttr(css, chroma.Background))) for index := range lines { line := f.baseLineNumber + index highlight, next := f.shouldHighlight(highlightIndex, line) @@ -148,16 +198,13 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma. fmt.Fprintf(w, "") } } - if !f.preventSurroundingPre { - fmt.Fprint(w, "") - } + fmt.Fprint(w, f.preWrapper.End(false)) fmt.Fprint(w, "\n") fmt.Fprintf(w, "\n", f.styleAttr(css, chroma.LineTableTD, "width:100%")) } - if !f.preventSurroundingPre { - fmt.Fprintf(w, "", f.styleAttr(css, chroma.Background)) - } + fmt.Fprintf(w, f.preWrapper.Start(true, f.styleAttr(css, chroma.Background))) + highlightIndex = 0 for index, tokens := range lines { // 1-based line number. @@ -187,9 +234,7 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma. } } - if !f.preventSurroundingPre { - fmt.Fprint(w, "") - } + fmt.Fprintf(w, f.preWrapper.End(true)) if wrapInTable { fmt.Fprint(w, "\n") diff --git a/formatters/html/html_test.go b/formatters/html/html_test.go index 6f10f29..2d94c83 100644 --- a/formatters/html/html_test.go +++ b/formatters/html/html_test.go @@ -2,6 +2,7 @@ package html import ( "bytes" + "fmt" "io/ioutil" "strings" "testing" @@ -106,3 +107,53 @@ func TestTableLineNumberNewlines(t *testing.T) { 4 `) } + +func TestWithPreWrapper(t *testing.T) { + wrapper := preWrapper{ + start: func(code bool, styleAttr string) string { + return fmt.Sprintf("", styleAttr, code) + }, + end: func(code bool) string { + return fmt.Sprintf("") + }, + } + + format := func(f *Formatter) string { + it, err := lexers.Get("bash").Tokenise(nil, "echo FOO") + assert.NoError(t, err) + + var buf bytes.Buffer + err = f.Format(&buf, styles.Fallback, it) + assert.NoError(t, err) + + return buf.String() + } + + t.Run("Regular", func(t *testing.T) { + s := format(New(WithClasses())) + assert.Equal(t, s, `
echo FOO
`) + }) + + t.Run("PreventSurroundingPre", func(t *testing.T) { + s := format(New(PreventSurroundingPre(), WithClasses())) + assert.Equal(t, s, `echo FOO`) + }) + + t.Run("Wrapper", func(t *testing.T) { + s := format(New(WithPreWrapper(wrapper), WithClasses())) + assert.Equal(t, s, `echo FOO`) + }) + + t.Run("Wrapper, LineNumbersInTable", func(t *testing.T) { + s := format(New(WithPreWrapper(wrapper), WithClasses(), WithLineNumbers(), LineNumbersInTable())) + + assert.Equal(t, s, `
+ +
+1 + +echo FOO
+
+`) + }) +}