diff --git a/oauthproxy.go b/oauthproxy.go
index 0cd7106a..0ef5060c 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -18,6 +18,7 @@ import (
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
@@ -116,7 +117,10 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, fmt.Errorf("error initialising session store: %v", err)
}
- templates := loadTemplates(opts.Templates.Path)
+ templates, err := app.LoadTemplates(opts.Templates.Path)
+ if err != nil {
+ return nil, fmt.Errorf("error loading templates: %v", err)
+ }
proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix)
upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler)
if err != nil {
diff --git a/pkg/app/app_suite_test.go b/pkg/app/app_suite_test.go
new file mode 100644
index 00000000..d2df0233
--- /dev/null
+++ b/pkg/app/app_suite_test.go
@@ -0,0 +1,17 @@
+package app
+
+import (
+ "testing"
+
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+func TestOptionsSuite(t *testing.T) {
+ logger.SetOutput(GinkgoWriter)
+ logger.SetErrOutput(GinkgoWriter)
+
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "App Suite")
+}
diff --git a/templates.go b/pkg/app/templates.go
similarity index 75%
rename from templates.go
rename to pkg/app/templates.go
index 15dcbc75..ef38c902 100644
--- a/templates.go
+++ b/pkg/app/templates.go
@@ -1,124 +1,20 @@
-package main
+package app
import (
+ "fmt"
"html/template"
- "path"
+ "os"
+ "path/filepath"
"strings"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)
-func loadTemplates(dir string) *template.Template {
- if dir == "" {
- return getTemplates()
- }
- logger.Printf("using custom template directory %q", dir)
- funcMap := template.FuncMap{
- "ToUpper": strings.ToUpper,
- "ToLower": strings.ToLower,
- }
- t, err := template.New("").Funcs(funcMap).ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html"))
- if err != nil {
- logger.Fatalf("failed parsing template %s", err)
- }
- return t
-}
+const (
+ errorTemplateName = "error.html"
+ signInTemplateName = "sign_in.html"
-func getTemplates() *template.Template {
- t, err := template.New("foo").Parse(`{{define "sign_in.html"}}
-
-
-
-
-
- Sign In
-
-
-
-
-
-
-
-
-
-
-
- {{ if .CustomLogin }}
-
-
-
-
-
-
-
-
-
-
-{{end}}`)
- if err != nil {
- logger.Fatalf("failed parsing template %s", err)
- }
-
- t, err = t.Parse(`{{define "error.html"}}
+ defaultErrorTemplate = `{{define "error.html"}}
@@ -215,9 +111,147 @@ func getTemplates() *template.Template {
-{{end}}`)
+{{end}}`
+
+ defaultSignInTemplate = `{{define "sign_in.html"}}
+
+
+
+
+
+ Sign In
+
+
+
+
+
+
+
+
+
+
+
+ {{ if .CustomLogin }}
+
+
+
+
+
+
+
+
+
+
+{{end}}`
+)
+
+// LoadTemplates adds the Sign In and Error templates from the custom template
+// directory, or uses the defaults if they do not exist or the custom directory
+// is not provided.
+func LoadTemplates(customDir string) (*template.Template, error) {
+ t := template.New("").Funcs(template.FuncMap{
+ "ToUpper": strings.ToUpper,
+ "ToLower": strings.ToLower,
+ })
+ var err error
+ t, err = addTemplate(t, customDir, signInTemplateName, defaultSignInTemplate)
if err != nil {
- logger.Fatalf("failed parsing template %s", err)
+ return nil, fmt.Errorf("could not add Sign In template: %v", err)
}
- return t
+ t, err = addTemplate(t, customDir, errorTemplateName, defaultErrorTemplate)
+ if err != nil {
+ return nil, fmt.Errorf("could not add Error template: %v", err)
+ }
+
+ return t, nil
+}
+
+// addTemplate will add the template from the custom directory if provided,
+// else it will add the default template.
+func addTemplate(t *template.Template, customDir, fileName, defaultTemplate string) (*template.Template, error) {
+ filePath := filepath.Join(customDir, fileName)
+ if customDir != "" && isFile(filePath) {
+ t, err := t.ParseFiles(filePath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse template %s: %v", filePath, err)
+ }
+ return t, nil
+ }
+ t, err := t.Parse(defaultTemplate)
+ if err != nil {
+ // This should not happen.
+ // Default templates should be tested and so should never fail to parse.
+ logger.Panic("Could not parse defaultTemplate: ", err)
+ }
+ return t, nil
+}
+
+// isFile checks if the file exists and checks whether it is a regular file.
+// If either of these fail then it cannot be used as a template file.
+func isFile(fileName string) bool {
+ info, err := os.Stat(fileName)
+ if err != nil {
+ logger.Errorf("Could not load file %s: %v, will use default template", fileName, err)
+ return false
+ }
+ return info.Mode().IsRegular()
}
diff --git a/pkg/app/templates_test.go b/pkg/app/templates_test.go
new file mode 100644
index 00000000..66f38b7f
--- /dev/null
+++ b/pkg/app/templates_test.go
@@ -0,0 +1,199 @@
+package app
+
+import (
+ "bytes"
+ "html/template"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Templates", func() {
+ var customDir string
+
+ BeforeEach(func() {
+ var err error
+ customDir, err = ioutil.TempDir("", "oauth2-proxy-templates-test")
+ Expect(err).ToNot(HaveOccurred())
+
+ templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}`
+ signInFile := filepath.Join(customDir, signInTemplateName)
+ Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0666)).To(Succeed())
+ errorFile := filepath.Join(customDir, errorTemplateName)
+ Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0666)).To(Succeed())
+ })
+
+ AfterEach(func() {
+ Expect(os.RemoveAll(customDir)).To(Succeed())
+ })
+
+ Context("LoadTemplates", func() {
+ var data interface{}
+ var t *template.Template
+
+ BeforeEach(func() {
+ data = struct {
+ // For default templates
+ ProxyPrefix string
+ Redirect string
+ Footer string
+
+ // For default sign_in template
+ SignInMessage string
+ ProviderName string
+ CustomLogin bool
+
+ // For default error template
+ StatusCode int
+ Title string
+ Message string
+
+ // For custom templates
+ TestString string
+ }{
+ ProxyPrefix: "",
+ Redirect: "",
+ Footer: "