diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d522260..5225d798 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ ## Changes since v5.1.1 +- [#514](https://github.com/oauth2-proxy/oauth2-proxy/pull/514) Add basic string functions to templates - [#524](https://github.com/oauth2-proxy/oauth2-proxy/pull/524) Sign cookies with SHA256 (@NickMeves) - [#515](https://github.com/oauth2-proxy/oauth2-proxy/pull/515) Drop configure script in favour of native Makefile env and checks (@JoelSpeed) - [#487](https://github.com/oauth2-proxy/oauth2-proxy/pull/487) Switch flags to PFlag to remove StringArray (@JoelSpeed) diff --git a/templates.go b/templates.go index b0e9014d..39e9e14e 100644 --- a/templates.go +++ b/templates.go @@ -3,6 +3,7 @@ package main import ( "html/template" "path" + "strings" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" ) @@ -12,7 +13,11 @@ func loadTemplates(dir string) *template.Template { return getTemplates() } logger.Printf("using custom template directory %q", dir) - t, err := template.New("").ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html")) + funcMap := template.FuncMap{ + "ToUpper": strings.ToUpper, + "ToLower": strings.ToLower, + } + t, err := template.New("").Funcs(funcMap).ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html")) if err != nil { logger.Fatalf("failed parsing template %s", err) } diff --git a/templates_test.go b/templates_test.go index 49e1a9dd..63757a0c 100644 --- a/templates_test.go +++ b/templates_test.go @@ -1,11 +1,61 @@ package main import ( + "bytes" + "io/ioutil" + "log" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" ) +func TestLoadTemplates(t *testing.T) { + data := struct { + TestString string + }{ + TestString: "Testing", + } + + templates := loadTemplates("") + assert.NotEqual(t, templates, nil) + + var defaultSignin bytes.Buffer + templates.ExecuteTemplate(&defaultSignin, "sign_in.html", data) + assert.Equal(t, "\n", defaultSignin.String()[0:16]) + + var defaultError bytes.Buffer + templates.ExecuteTemplate(&defaultError, "error.html", data) + assert.Equal(t, "\n", defaultError.String()[0:16]) + + dir, err := ioutil.TempDir("", "templatetest") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(dir) + + templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` + signInFile := filepath.Join(dir, "sign_in.html") + if err := ioutil.WriteFile(signInFile, []byte(templateHTML), 0666); err != nil { + log.Fatal(err) + } + errorFile := filepath.Join(dir, "error.html") + if err := ioutil.WriteFile(errorFile, []byte(templateHTML), 0666); err != nil { + log.Fatal(err) + } + templates = loadTemplates(dir) + assert.NotEqual(t, templates, nil) + + var sitpl bytes.Buffer + templates.ExecuteTemplate(&sitpl, "sign_in.html", data) + assert.Equal(t, "Testing testing TESTING", sitpl.String()) + + var errtpl bytes.Buffer + templates.ExecuteTemplate(&errtpl, "error.html", data) + assert.Equal(t, "Testing testing TESTING", errtpl.String()) +} + func TestTemplatesCompile(t *testing.T) { templates := getTemplates() assert.NotEqual(t, templates, nil)