diff --git a/oauthproxy.go b/oauthproxy.go index 0cd7106a..0ef5060c 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -18,6 +18,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" @@ -116,7 +117,10 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("error initialising session store: %v", err) } - templates := loadTemplates(opts.Templates.Path) + templates, err := app.LoadTemplates(opts.Templates.Path) + if err != nil { + return nil, fmt.Errorf("error loading templates: %v", err) + } proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) if err != nil { diff --git a/pkg/app/app_suite_test.go b/pkg/app/app_suite_test.go new file mode 100644 index 00000000..d2df0233 --- /dev/null +++ b/pkg/app/app_suite_test.go @@ -0,0 +1,17 @@ +package app + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestOptionsSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + logger.SetErrOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "App Suite") +} diff --git a/templates.go b/pkg/app/templates.go similarity index 75% rename from templates.go rename to pkg/app/templates.go index 15dcbc75..ef38c902 100644 --- a/templates.go +++ b/pkg/app/templates.go @@ -1,124 +1,20 @@ -package main +package app import ( + "fmt" "html/template" - "path" + "os" + "path/filepath" "strings" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) -func loadTemplates(dir string) *template.Template { - if dir == "" { - return getTemplates() - } - logger.Printf("using custom template directory %q", dir) - funcMap := template.FuncMap{ - "ToUpper": strings.ToUpper, - "ToLower": strings.ToLower, - } - t, err := template.New("").Funcs(funcMap).ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html")) - if err != nil { - logger.Fatalf("failed parsing template %s", err) - } - return t -} +const ( + errorTemplateName = "error.html" + signInTemplateName = "sign_in.html" -func getTemplates() *template.Template { - t, err := template.New("foo").Parse(`{{define "sign_in.html"}} - - - - - - Sign In - - - - - - - -
-
-
- - {{ if .SignInMessage }} -

{{.SignInMessage}}

- {{ end}} - -
- - {{ if .CustomLogin }} -
- -
- - -
- -
- -
-
- -
- -
- -
-
- - {{ end }} -
-
-
- - - - - -{{end}}`) - if err != nil { - logger.Fatalf("failed parsing template %s", err) - } - - t, err = t.Parse(`{{define "error.html"}} + defaultErrorTemplate = `{{define "error.html"}} @@ -215,9 +111,147 @@ func getTemplates() *template.Template { -{{end}}`) +{{end}}` + + defaultSignInTemplate = `{{define "sign_in.html"}} + + + + + + Sign In + + + + + + + +
+
+
+ + {{ if .SignInMessage }} +

{{.SignInMessage}}

+ {{ end}} + +
+ + {{ if .CustomLogin }} +
+ +
+ + +
+ +
+ +
+
+ +
+ +
+ +
+
+ + {{ end }} +
+
+
+ + + + + +{{end}}` +) + +// LoadTemplates adds the Sign In and Error templates from the custom template +// directory, or uses the defaults if they do not exist or the custom directory +// is not provided. +func LoadTemplates(customDir string) (*template.Template, error) { + t := template.New("").Funcs(template.FuncMap{ + "ToUpper": strings.ToUpper, + "ToLower": strings.ToLower, + }) + var err error + t, err = addTemplate(t, customDir, signInTemplateName, defaultSignInTemplate) if err != nil { - logger.Fatalf("failed parsing template %s", err) + return nil, fmt.Errorf("could not add Sign In template: %v", err) } - return t + t, err = addTemplate(t, customDir, errorTemplateName, defaultErrorTemplate) + if err != nil { + return nil, fmt.Errorf("could not add Error template: %v", err) + } + + return t, nil +} + +// addTemplate will add the template from the custom directory if provided, +// else it will add the default template. +func addTemplate(t *template.Template, customDir, fileName, defaultTemplate string) (*template.Template, error) { + filePath := filepath.Join(customDir, fileName) + if customDir != "" && isFile(filePath) { + t, err := t.ParseFiles(filePath) + if err != nil { + return nil, fmt.Errorf("failed to parse template %s: %v", filePath, err) + } + return t, nil + } + t, err := t.Parse(defaultTemplate) + if err != nil { + // This should not happen. + // Default templates should be tested and so should never fail to parse. + logger.Panic("Could not parse defaultTemplate: ", err) + } + return t, nil +} + +// isFile checks if the file exists and checks whether it is a regular file. +// If either of these fail then it cannot be used as a template file. +func isFile(fileName string) bool { + info, err := os.Stat(fileName) + if err != nil { + logger.Errorf("Could not load file %s: %v, will use default template", fileName, err) + return false + } + return info.Mode().IsRegular() } diff --git a/pkg/app/templates_test.go b/pkg/app/templates_test.go new file mode 100644 index 00000000..66f38b7f --- /dev/null +++ b/pkg/app/templates_test.go @@ -0,0 +1,199 @@ +package app + +import ( + "bytes" + "html/template" + "io/ioutil" + "os" + "path/filepath" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Templates", func() { + var customDir string + + BeforeEach(func() { + var err error + customDir, err = ioutil.TempDir("", "oauth2-proxy-templates-test") + Expect(err).ToNot(HaveOccurred()) + + templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` + signInFile := filepath.Join(customDir, signInTemplateName) + Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0666)).To(Succeed()) + errorFile := filepath.Join(customDir, errorTemplateName) + Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0666)).To(Succeed()) + }) + + AfterEach(func() { + Expect(os.RemoveAll(customDir)).To(Succeed()) + }) + + Context("LoadTemplates", func() { + var data interface{} + var t *template.Template + + BeforeEach(func() { + data = struct { + // For default templates + ProxyPrefix string + Redirect string + Footer string + + // For default sign_in template + SignInMessage string + ProviderName string + CustomLogin bool + + // For default error template + StatusCode int + Title string + Message string + + // For custom templates + TestString string + }{ + ProxyPrefix: "", + Redirect: "", + Footer: "