From d3926cc0e14847e048a20ae51d0556b81d36a2ba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?=
 <bjorn.erik.pedersen@gmail.com>
Date: Mon, 18 Nov 2019 12:31:49 +0100
Subject: [PATCH] Add WithPreWrapper option

---
 formatters/html/html.go      | 89 +++++++++++++++++++++++++++---------
 formatters/html/html_test.go | 51 +++++++++++++++++++++
 2 files changed, 118 insertions(+), 22 deletions(-)

diff --git a/formatters/html/html.go b/formatters/html/html.go
index 3f68a7d..0faa9c3 100644
--- a/formatters/html/html.go
+++ b/formatters/html/html.go
@@ -26,7 +26,18 @@ func WithClasses() Option { return func(f *Formatter) { f.Classes = true } }
 func TabWidth(width int) Option { return func(f *Formatter) { f.tabWidth = width } }
 
 // PreventSurroundingPre prevents the surrounding pre tags around the generated code
-func PreventSurroundingPre() Option { return func(f *Formatter) { f.preventSurroundingPre = true } }
+func PreventSurroundingPre() Option {
+	return func(f *Formatter) {
+		f.preWrapper = nopPreWrapper
+	}
+}
+
+// WithPreWrapper allows control of the surrounding pre tags.
+func WithPreWrapper(wrapper PreWrapper) Option {
+	return func(f *Formatter) {
+		f.preWrapper = wrapper
+	}
+}
 
 // WithLineNumbers formats output with line numbers.
 func WithLineNumbers() Option {
@@ -64,6 +75,7 @@ func BaseLineNumber(n int) Option {
 func New(options ...Option) *Formatter {
 	f := &Formatter{
 		baseLineNumber: 1,
+		preWrapper:     defaultPreWrapper,
 	}
 	for _, option := range options {
 		option(f)
@@ -71,17 +83,57 @@ func New(options ...Option) *Formatter {
 	return f
 }
 
+// PreWrapper defines the operations supported in WithPreWrapper.
+type PreWrapper interface {
+	// Start is called to write a start <pre> element.
+	// The code flag tells whether this block surrounds
+	// highlighted code. This will be false when surrounding
+	// line numbers.
+	Start(code bool, styleAttr string) string
+
+	// End is called to write the end </pre> element.
+	End(code bool) string
+}
+
+type preWrapper struct {
+	start func(code bool, styleAttr string) string
+	end   func(code bool) string
+}
+
+func (p preWrapper) Start(code bool, styleAttr string) string {
+	return p.start(code, styleAttr)
+}
+
+func (p preWrapper) End(code bool) string {
+	return p.end(code)
+}
+
+var (
+	nopPreWrapper = preWrapper{
+		start: func(code bool, styleAttr string) string { return "" },
+		end:   func(code bool) string { return "" },
+	}
+	defaultPreWrapper = preWrapper{
+		start: func(code bool, styleAttr string) string {
+			return fmt.Sprintf("<pre%s>", styleAttr)
+		},
+		end: func(code bool) string {
+			return "</pre>"
+		},
+	}
+)
+
 // Formatter that generates HTML.
 type Formatter struct {
-	standalone            bool
-	prefix                string
-	Classes               bool // Exported field to detect when classes are being used
-	preventSurroundingPre bool
-	tabWidth              int
-	lineNumbers           bool
-	lineNumbersInTable    bool
-	highlightRanges       highlightRanges
-	baseLineNumber        int
+	standalone         bool
+	prefix             string
+	Classes            bool // Exported field to detect when classes are being used
+	preWrapper         PreWrapper
+	tabWidth           int
+	lineNumbers        bool
+	lineNumbersInTable bool
+	highlightRanges    highlightRanges
+	baseLineNumber     int
 }
 
 type highlightRanges [][2]int
@@ -129,9 +181,7 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma.
 		fmt.Fprintf(w, "<div%s>\n", f.styleAttr(css, chroma.Background))
 		fmt.Fprintf(w, "<table%s><tr>", f.styleAttr(css, chroma.LineTable))
 		fmt.Fprintf(w, "<td%s>\n", f.styleAttr(css, chroma.LineTableTD))
-		if !f.preventSurroundingPre {
-			fmt.Fprintf(w, "<pre%s>", f.styleAttr(css, chroma.Background))
-		}
+		fmt.Fprintf(w, f.preWrapper.Start(false, f.styleAttr(css, chroma.Background)))
 		for index := range lines {
 			line := f.baseLineNumber + index
 			highlight, next := f.shouldHighlight(highlightIndex, line)
@@ -148,16 +198,13 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma.
 				fmt.Fprintf(w, "</span>")
 			}
 		}
-		if !f.preventSurroundingPre {
-			fmt.Fprint(w, "</pre>")
-		}
+		fmt.Fprint(w, f.preWrapper.End(false))
 		fmt.Fprint(w, "</td>\n")
 		fmt.Fprintf(w, "<td%s>\n", f.styleAttr(css, chroma.LineTableTD, "width:100%"))
 	}
 
-	if !f.preventSurroundingPre {
-		fmt.Fprintf(w, "<pre%s>", f.styleAttr(css, chroma.Background))
-	}
+	fmt.Fprintf(w, f.preWrapper.Start(true, f.styleAttr(css, chroma.Background)))
+
 	highlightIndex = 0
 	for index, tokens := range lines {
 		// 1-based line number.
@@ -187,9 +234,7 @@ func (f *Formatter) writeHTML(w io.Writer, style *chroma.Style, tokens []chroma.
 		}
 	}
 
-	if !f.preventSurroundingPre {
-		fmt.Fprint(w, "</pre>")
-	}
+	fmt.Fprintf(w, f.preWrapper.End(true))
 
 	if wrapInTable {
 		fmt.Fprint(w, "</td></tr></table>\n")
diff --git a/formatters/html/html_test.go b/formatters/html/html_test.go
index 6f10f29..2d94c83 100644
--- a/formatters/html/html_test.go
+++ b/formatters/html/html_test.go
@@ -2,6 +2,7 @@ package html
 
 import (
 	"bytes"
+	"fmt"
 	"io/ioutil"
 	"strings"
 	"testing"
@@ -106,3 +107,53 @@ func TestTableLineNumberNewlines(t *testing.T) {
 </span><span class="lnt">4
 </span>`)
 }
+
+func TestWithPreWrapper(t *testing.T) {
+	wrapper := preWrapper{
+		start: func(code bool, styleAttr string) string {
+			return fmt.Sprintf("<foo%s id=\"code-%t\">", styleAttr, code)
+		},
+		end: func(code bool) string {
+			return fmt.Sprintf("</foo>")
+		},
+	}
+
+	format := func(f *Formatter) string {
+		it, err := lexers.Get("bash").Tokenise(nil, "echo FOO")
+		assert.NoError(t, err)
+
+		var buf bytes.Buffer
+		err = f.Format(&buf, styles.Fallback, it)
+		assert.NoError(t, err)
+
+		return buf.String()
+	}
+
+	t.Run("Regular", func(t *testing.T) {
+		s := format(New(WithClasses()))
+		assert.Equal(t, s, `<pre class="chroma"><span class="nb">echo</span> FOO</pre>`)
+	})
+
+	t.Run("PreventSurroundingPre", func(t *testing.T) {
+		s := format(New(PreventSurroundingPre(), WithClasses()))
+		assert.Equal(t, s, `<span class="nb">echo</span> FOO`)
+	})
+
+	t.Run("Wrapper", func(t *testing.T) {
+		s := format(New(WithPreWrapper(wrapper), WithClasses()))
+		assert.Equal(t, s, `<foo class="chroma" id="code-true"><span class="nb">echo</span> FOO</foo>`)
+	})
+
+	t.Run("Wrapper, LineNumbersInTable", func(t *testing.T) {
+		s := format(New(WithPreWrapper(wrapper), WithClasses(), WithLineNumbers(), LineNumbersInTable()))
+
+		assert.Equal(t, s, `<div class="chroma">
+<table class="lntable"><tr><td class="lntd">
+<foo class="chroma" id="code-false"><span class="lnt">1
+</span></foo></td>
+<td class="lntd">
+<foo class="chroma" id="code-true"><span class="nb">echo</span> FOO</foo></td></tr></table>
+</div>
+`)
+	})
+}