diff --git a/formatters/html/html.go b/formatters/html/html.go
index 3f68a7d..0faa9c3 100644
--- a/formatters/html/html.go
+++ b/formatters/html/html.go
@@ -26,7 +26,18 @@ func WithClasses() Option { return func(f *Formatter) { f.Classes = true } }
func TabWidth(width int) Option { return func(f *Formatter) { f.tabWidth = width } }
// PreventSurroundingPre prevents the surrounding pre tags around the generated code
-func PreventSurroundingPre() Option { return func(f *Formatter) { f.preventSurroundingPre = true } }
+func PreventSurroundingPre() Option {
+ return func(f *Formatter) {
+ f.preWrapper = nopPreWrapper
+ }
+}
+
+// WithPreWrapper allows control of the surrounding pre tags.
+func WithPreWrapper(wrapper PreWrapper) Option {
+ return func(f *Formatter) {
+ f.preWrapper = wrapper
+ }
+}
// WithLineNumbers formats output with line numbers.
func WithLineNumbers() Option {
@@ -64,6 +75,7 @@ func BaseLineNumber(n int) Option {
func New(options ...Option) *Formatter {
f := &Formatter{
baseLineNumber: 1,
+ preWrapper: defaultPreWrapper,
}
for _, option := range options {
option(f)
@@ -71,17 +83,57 @@ func New(options ...Option) *Formatter {
return f
}
+// PreWrapper defines the operations supported in WithPreWrapper.
+type PreWrapper interface {
+ // Start is called to write a start
\n", f.styleAttr(css, chroma.Background))
fmt.Fprintf(w, "
", f.styleAttr(css, chroma.LineTable))
fmt.Fprintf(w, "\n", f.styleAttr(css, chroma.LineTableTD))
- if !f.preventSurroundingPre {
- fmt.Fprintf(w, "", f.styleAttr(css, chroma.Background))
- }
+ fmt.Fprintf(w, f.preWrapper.Start(false, f.styleAttr(css, chroma.Background)))
for index := range lines {
line := f.baseLineNumber + index
highlight, next := f.shouldHighlight(highlightIndex, line)
@@ -148,16 +198,13 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma.
fmt.Fprintf(w, "")
}
}
- if !f.preventSurroundingPre {
- fmt.Fprint(w, " ")
- }
+ fmt.Fprint(w, f.preWrapper.End(false))
fmt.Fprint(w, " | \n")
fmt.Fprintf(w, "\n", f.styleAttr(css, chroma.LineTableTD, "width:100%"))
}
- if !f.preventSurroundingPre {
- fmt.Fprintf(w, "", f.styleAttr(css, chroma.Background))
- }
+ fmt.Fprintf(w, f.preWrapper.Start(true, f.styleAttr(css, chroma.Background)))
+
highlightIndex = 0
for index, tokens := range lines {
// 1-based line number.
@@ -187,9 +234,7 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma.
}
}
- if !f.preventSurroundingPre {
- fmt.Fprint(w, " ")
- }
+ fmt.Fprintf(w, f.preWrapper.End(true))
if wrapInTable {
fmt.Fprint(w, " |
\n")
diff --git a/formatters/html/html_test.go b/formatters/html/html_test.go
index 6f10f29..2d94c83 100644
--- a/formatters/html/html_test.go
+++ b/formatters/html/html_test.go
@@ -2,6 +2,7 @@ package html
import (
"bytes"
+ "fmt"
"io/ioutil"
"strings"
"testing"
@@ -106,3 +107,53 @@ func TestTableLineNumberNewlines(t *testing.T) {
4
`)
}
+
+func TestWithPreWrapper(t *testing.T) {
+ wrapper := preWrapper{
+ start: func(code bool, styleAttr string) string {
+ return fmt.Sprintf("
", styleAttr, code)
+ },
+ end: func(code bool) string {
+ return fmt.Sprintf("")
+ },
+ }
+
+ format := func(f *Formatter) string {
+ it, err := lexers.Get("bash").Tokenise(nil, "echo FOO")
+ assert.NoError(t, err)
+
+ var buf bytes.Buffer
+ err = f.Format(&buf, styles.Fallback, it)
+ assert.NoError(t, err)
+
+ return buf.String()
+ }
+
+ t.Run("Regular", func(t *testing.T) {
+ s := format(New(WithClasses()))
+ assert.Equal(t, s, `
echo FOO
`)
+ })
+
+ t.Run("PreventSurroundingPre", func(t *testing.T) {
+ s := format(New(PreventSurroundingPre(), WithClasses()))
+ assert.Equal(t, s, `
echo FOO`)
+ })
+
+ t.Run("Wrapper", func(t *testing.T) {
+ s := format(New(WithPreWrapper(wrapper), WithClasses()))
+ assert.Equal(t, s, `
echo FOO`)
+ })
+
+ t.Run("Wrapper, LineNumbersInTable", func(t *testing.T) {
+ s := format(New(WithPreWrapper(wrapper), WithClasses(), WithLineNumbers(), LineNumbersInTable()))
+
+ assert.Equal(t, s, `
+`)
+ })
+}