From 0151ca11f6e4d2a97a384a7f77292a0aa5ddd6e6 Mon Sep 17 00:00:00 2001
From: Joel  Speed <joel.speed@hotmail.co.uk>
Date: Sat, 6 Feb 2021 18:56:31 +0000
Subject: [PATCH] Move template loading to app package

---
 oauthproxy.go                        |   6 +-
 pkg/app/app_suite_test.go            |  17 ++
 templates.go => pkg/app/templates.go | 264 +++++++++++++++------------
 pkg/app/templates_test.go            | 199 ++++++++++++++++++++
 templates_test.go                    |  62 -------
 5 files changed, 370 insertions(+), 178 deletions(-)
 create mode 100644 pkg/app/app_suite_test.go
 rename templates.go => pkg/app/templates.go (75%)
 create mode 100644 pkg/app/templates_test.go
 delete mode 100644 templates_test.go

diff --git a/oauthproxy.go b/oauthproxy.go
index 0cd7106a..0ef5060c 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -18,6 +18,7 @@ import (
 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
+	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app"
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
@@ -116,7 +117,10 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
 		return nil, fmt.Errorf("error initialising session store: %v", err)
 	}
 
-	templates := loadTemplates(opts.Templates.Path)
+	templates, err := app.LoadTemplates(opts.Templates.Path)
+	if err != nil {
+		return nil, fmt.Errorf("error loading templates: %v", err)
+	}
 	proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix)
 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler)
 	if err != nil {
diff --git a/pkg/app/app_suite_test.go b/pkg/app/app_suite_test.go
new file mode 100644
index 00000000..d2df0233
--- /dev/null
+++ b/pkg/app/app_suite_test.go
@@ -0,0 +1,17 @@
+package app
+
+import (
+	"testing"
+
+	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+)
+
+func TestOptionsSuite(t *testing.T) {
+	logger.SetOutput(GinkgoWriter)
+	logger.SetErrOutput(GinkgoWriter)
+
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "App Suite")
+}
diff --git a/templates.go b/pkg/app/templates.go
similarity index 75%
rename from templates.go
rename to pkg/app/templates.go
index 15dcbc75..ef38c902 100644
--- a/templates.go
+++ b/pkg/app/templates.go
@@ -1,124 +1,20 @@
-package main
+package app
 
 import (
+	"fmt"
 	"html/template"
-	"path"
+	"os"
+	"path/filepath"
 	"strings"
 
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
 )
 
-func loadTemplates(dir string) *template.Template {
-	if dir == "" {
-		return getTemplates()
-	}
-	logger.Printf("using custom template directory %q", dir)
-	funcMap := template.FuncMap{
-		"ToUpper": strings.ToUpper,
-		"ToLower": strings.ToLower,
-	}
-	t, err := template.New("").Funcs(funcMap).ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html"))
-	if err != nil {
-		logger.Fatalf("failed parsing template %s", err)
-	}
-	return t
-}
+const (
+	errorTemplateName  = "error.html"
+	signInTemplateName = "sign_in.html"
 
-func getTemplates() *template.Template {
-	t, err := template.New("foo").Parse(`{{define "sign_in.html"}}
-<!DOCTYPE html>
-<html lang="en" charset="utf-8">
-  <head>
-    <meta charset="utf-8">
-    <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
-    <title>Sign In</title>
-    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css">
-
-    <style>
-      body {
-        height: 100vh;
-      }
-      .sign-in-box {
-        max-width: 400px;
-        margin: 1.25rem auto;
-      }
-      footer a {
-        text-decoration: underline;
-      }
-    </style>
-
-    <script>
-      if (window.location.hash) {
-        (function() {
-          var inputs = document.getElementsByName('rd');
-          for (var i = 0; i < inputs.length; i++) {
-            // Add hash, but make sure it is only added once
-            var idx = inputs[i].value.indexOf('#');
-            if (idx >= 0) {
-              // Remove existing hash from URL
-              inputs[i].value = inputs[i].value.substr(0, idx);
-            }
-            inputs[i].value += window.location.hash;
-          }
-        })();
-      }
-    </script>
-  </head>
-  <body class="has-background-light">
-  <section class="section">
-    <div class="box block sign-in-box has-text-centered">
-      <form method="GET" action="{{.ProxyPrefix}}/start">
-        <input type="hidden" name="rd" value="{{.Redirect}}">
-          {{ if .SignInMessage }}
-          <p class="block">{{.SignInMessage}}</p>
-          {{ end}}
-          <button type="submit" class="button block is-primary">Sign in with {{.ProviderName}}</button>
-      </form>
-
-      {{ if .CustomLogin }}
-      <hr>
-
-      <form method="POST" action="{{.ProxyPrefix}}/sign_in" class="block">
-        <input type="hidden" name="rd" value="{{.Redirect}}">
-
-        <div class="field">
-          <label class="label" for="username">Username</label>
-          <div class="control">
-            <input class="input" type="email" placeholder="e.g. userx@example.com"  name="username" id="username">
-          </div>
-        </div>
-
-        <div class="field">
-          <label class="label" for="password">Password</label>
-          <div class="control">
-            <input class="input" type="password" placeholder="********" name="password" id="password">
-          </div>
-        </div>
-        <button class="button is-primary">Sign in</button>
-        {{ end }}
-    </form>
-    </div>
-  </section>
-
-  <footer class="footer has-text-grey has-background-light is-size-7">
-    <div class="content has-text-centered">
-    	{{ if eq .Footer "-" }}
-    	{{ else if eq .Footer ""}}
-    	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p>
-    	{{ else }}
-    	<p>{{.Footer}}</p>
-    	{{ end }}
-    </div>
-	</footer>
-
-  </body>
-</html>
-{{end}}`)
-	if err != nil {
-		logger.Fatalf("failed parsing template %s", err)
-	}
-
-	t, err = t.Parse(`{{define "error.html"}}
+	defaultErrorTemplate = `{{define "error.html"}}
 <!DOCTYPE html>
 <html lang="en" charset="utf-8">
 <head>
@@ -215,9 +111,147 @@ func getTemplates() *template.Template {
 
   </body>
 </html>
-{{end}}`)
+{{end}}`
+
+	defaultSignInTemplate = `{{define "sign_in.html"}}
+<!DOCTYPE html>
+<html lang="en" charset="utf-8">
+  <head>
+    <meta charset="utf-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
+    <title>Sign In</title>
+    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css">
+
+    <style>
+      body {
+        height: 100vh;
+      }
+      .sign-in-box {
+        max-width: 400px;
+        margin: 1.25rem auto;
+      }
+      footer a {
+        text-decoration: underline;
+      }
+    </style>
+
+    <script>
+      if (window.location.hash) {
+        (function() {
+          var inputs = document.getElementsByName('rd');
+          for (var i = 0; i < inputs.length; i++) {
+            // Add hash, but make sure it is only added once
+            var idx = inputs[i].value.indexOf('#');
+            if (idx >= 0) {
+              // Remove existing hash from URL
+              inputs[i].value = inputs[i].value.substr(0, idx);
+            }
+            inputs[i].value += window.location.hash;
+          }
+        })();
+      }
+    </script>
+  </head>
+  <body class="has-background-light">
+  <section class="section">
+    <div class="box block sign-in-box has-text-centered">
+      <form method="GET" action="{{.ProxyPrefix}}/start">
+        <input type="hidden" name="rd" value="{{.Redirect}}">
+          {{ if .SignInMessage }}
+          <p class="block">{{.SignInMessage}}</p>
+          {{ end}}
+          <button type="submit" class="button block is-primary">Sign in with {{.ProviderName}}</button>
+      </form>
+
+      {{ if .CustomLogin }}
+      <hr>
+
+      <form method="POST" action="{{.ProxyPrefix}}/sign_in" class="block">
+        <input type="hidden" name="rd" value="{{.Redirect}}">
+
+        <div class="field">
+          <label class="label" for="username">Username</label>
+          <div class="control">
+            <input class="input" type="email" placeholder="e.g. userx@example.com"  name="username" id="username">
+          </div>
+        </div>
+
+        <div class="field">
+          <label class="label" for="password">Password</label>
+          <div class="control">
+            <input class="input" type="password" placeholder="********" name="password" id="password">
+          </div>
+        </div>
+        <button class="button is-primary">Sign in</button>
+        {{ end }}
+    </form>
+    </div>
+  </section>
+
+  <footer class="footer has-text-grey has-background-light is-size-7">
+    <div class="content has-text-centered">
+    	{{ if eq .Footer "-" }}
+    	{{ else if eq .Footer ""}}
+    	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p>
+    	{{ else }}
+    	<p>{{.Footer}}</p>
+    	{{ end }}
+    </div>
+	</footer>
+
+  </body>
+</html>
+{{end}}`
+)
+
+// LoadTemplates adds the Sign In and Error templates from the custom template
+// directory, or uses the defaults if they do not exist or the custom directory
+// is not provided.
+func LoadTemplates(customDir string) (*template.Template, error) {
+	t := template.New("").Funcs(template.FuncMap{
+		"ToUpper": strings.ToUpper,
+		"ToLower": strings.ToLower,
+	})
+	var err error
+	t, err = addTemplate(t, customDir, signInTemplateName, defaultSignInTemplate)
 	if err != nil {
-		logger.Fatalf("failed parsing template %s", err)
+		return nil, fmt.Errorf("could not add Sign In template: %v", err)
 	}
-	return t
+	t, err = addTemplate(t, customDir, errorTemplateName, defaultErrorTemplate)
+	if err != nil {
+		return nil, fmt.Errorf("could not add Error template: %v", err)
+	}
+
+	return t, nil
+}
+
+// addTemplate will add the template from the custom directory if provided,
+// else it will add the default template.
+func addTemplate(t *template.Template, customDir, fileName, defaultTemplate string) (*template.Template, error) {
+	filePath := filepath.Join(customDir, fileName)
+	if customDir != "" && isFile(filePath) {
+		t, err := t.ParseFiles(filePath)
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse template %s: %v", filePath, err)
+		}
+		return t, nil
+	}
+	t, err := t.Parse(defaultTemplate)
+	if err != nil {
+		// This should not happen.
+		// Default templates should be tested and so should never fail to parse.
+		logger.Panic("Could not parse defaultTemplate: ", err)
+	}
+	return t, nil
+}
+
+// isFile checks if the file exists and checks whether it is a regular file.
+// If either of these fail then it cannot be used as a template file.
+func isFile(fileName string) bool {
+	info, err := os.Stat(fileName)
+	if err != nil {
+		logger.Errorf("Could not load file %s: %v, will use default template", fileName, err)
+		return false
+	}
+	return info.Mode().IsRegular()
 }
diff --git a/pkg/app/templates_test.go b/pkg/app/templates_test.go
new file mode 100644
index 00000000..66f38b7f
--- /dev/null
+++ b/pkg/app/templates_test.go
@@ -0,0 +1,199 @@
+package app
+
+import (
+	"bytes"
+	"html/template"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("Templates", func() {
+	var customDir string
+
+	BeforeEach(func() {
+		var err error
+		customDir, err = ioutil.TempDir("", "oauth2-proxy-templates-test")
+		Expect(err).ToNot(HaveOccurred())
+
+		templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}`
+		signInFile := filepath.Join(customDir, signInTemplateName)
+		Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0666)).To(Succeed())
+		errorFile := filepath.Join(customDir, errorTemplateName)
+		Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0666)).To(Succeed())
+	})
+
+	AfterEach(func() {
+		Expect(os.RemoveAll(customDir)).To(Succeed())
+	})
+
+	Context("LoadTemplates", func() {
+		var data interface{}
+		var t *template.Template
+
+		BeforeEach(func() {
+			data = struct {
+				// For default templates
+				ProxyPrefix string
+				Redirect    string
+				Footer      string
+
+				// For default sign_in template
+				SignInMessage string
+				ProviderName  string
+				CustomLogin   bool
+
+				// For default error template
+				StatusCode int
+				Title      string
+				Message    string
+
+				// For custom templates
+				TestString string
+			}{
+				ProxyPrefix: "<proxy-prefix>",
+				Redirect:    "<redirect>",
+				Footer:      "<footer>",
+
+				SignInMessage: "<sign-in-message>",
+				ProviderName:  "<provider-name>",
+				CustomLogin:   false,
+
+				StatusCode: 404,
+				Title:      "<title>",
+				Message:    "<message>",
+
+				TestString: "Testing",
+			}
+		})
+
+		Context("With no custom directory", func() {
+			BeforeEach(func() {
+				var err error
+				t, err = LoadTemplates("")
+				Expect(err).ToNot(HaveOccurred())
+			})
+
+			It("Use the default sign_in page", func() {
+				buf := bytes.NewBuffer([]byte{})
+				Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed())
+				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>"))
+			})
+
+			It("Use the default error page", func() {
+				buf := bytes.NewBuffer([]byte{})
+				Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed())
+				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>"))
+			})
+		})
+
+		Context("With a custom directory", func() {
+			Context("With both templates", func() {
+				BeforeEach(func() {
+					var err error
+					t, err = LoadTemplates(customDir)
+					Expect(err).ToNot(HaveOccurred())
+				})
+
+				It("Use the custom sign_in page", func() {
+					buf := bytes.NewBuffer([]byte{})
+					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed())
+					Expect(buf.String()).To(Equal("Testing testing TESTING"))
+				})
+
+				It("Use the custom error page", func() {
+					buf := bytes.NewBuffer([]byte{})
+					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed())
+					Expect(buf.String()).To(Equal("Testing testing TESTING"))
+				})
+			})
+
+			Context("With no error template", func() {
+				BeforeEach(func() {
+					Expect(os.Remove(filepath.Join(customDir, errorTemplateName))).To(Succeed())
+
+					var err error
+					t, err = LoadTemplates(customDir)
+					Expect(err).ToNot(HaveOccurred())
+				})
+
+				It("Use the custom sign_in page", func() {
+					buf := bytes.NewBuffer([]byte{})
+					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed())
+					Expect(buf.String()).To(Equal("Testing testing TESTING"))
+				})
+
+				It("Use the default error page", func() {
+					buf := bytes.NewBuffer([]byte{})
+					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed())
+					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>"))
+				})
+			})
+
+			Context("With no sign_in template", func() {
+				BeforeEach(func() {
+					Expect(os.Remove(filepath.Join(customDir, signInTemplateName))).To(Succeed())
+
+					var err error
+					t, err = LoadTemplates(customDir)
+					Expect(err).ToNot(HaveOccurred())
+				})
+
+				It("Use the default sign_in page", func() {
+					buf := bytes.NewBuffer([]byte{})
+					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed())
+					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>"))
+				})
+
+				It("Use the custom error page", func() {
+					buf := bytes.NewBuffer([]byte{})
+					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed())
+					Expect(buf.String()).To(Equal("Testing testing TESTING"))
+				})
+			})
+
+			Context("With an invalid sign_in template", func() {
+				BeforeEach(func() {
+					signInFile := filepath.Join(customDir, signInTemplateName)
+					Expect(ioutil.WriteFile(signInFile, []byte("{{"), 0666))
+				})
+
+				It("Should return an error when loading templates", func() {
+					t, err := LoadTemplates(customDir)
+					Expect(err).To(MatchError(HavePrefix("could not add Sign In template:")))
+					Expect(t).To(BeNil())
+				})
+			})
+
+			Context("With an invalid error template", func() {
+				BeforeEach(func() {
+					errorFile := filepath.Join(customDir, errorTemplateName)
+					Expect(ioutil.WriteFile(errorFile, []byte("{{"), 0666))
+				})
+
+				It("Should return an error when loading templates", func() {
+					t, err := LoadTemplates(customDir)
+					Expect(err).To(MatchError(HavePrefix("could not add Error template:")))
+					Expect(t).To(BeNil())
+				})
+			})
+		})
+	})
+
+	Context("isFile", func() {
+		It("with a valid file", func() {
+			Expect(isFile(filepath.Join(customDir, signInTemplateName))).To(BeTrue())
+		})
+
+		It("with a directory", func() {
+			Expect(isFile(customDir)).To(BeFalse())
+		})
+
+		It("with an invalid file", func() {
+			Expect(isFile(filepath.Join(customDir, "does_not_exist.html"))).To(BeFalse())
+		})
+	})
+})
diff --git a/templates_test.go b/templates_test.go
deleted file mode 100644
index 63757a0c..00000000
--- a/templates_test.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package main
-
-import (
-	"bytes"
-	"io/ioutil"
-	"log"
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestLoadTemplates(t *testing.T) {
-	data := struct {
-		TestString string
-	}{
-		TestString: "Testing",
-	}
-
-	templates := loadTemplates("")
-	assert.NotEqual(t, templates, nil)
-
-	var defaultSignin bytes.Buffer
-	templates.ExecuteTemplate(&defaultSignin, "sign_in.html", data)
-	assert.Equal(t, "\n<!DOCTYPE html>", defaultSignin.String()[0:16])
-
-	var defaultError bytes.Buffer
-	templates.ExecuteTemplate(&defaultError, "error.html", data)
-	assert.Equal(t, "\n<!DOCTYPE html>", defaultError.String()[0:16])
-
-	dir, err := ioutil.TempDir("", "templatetest")
-	if err != nil {
-		log.Fatal(err)
-	}
-	defer os.RemoveAll(dir)
-
-	templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}`
-	signInFile := filepath.Join(dir, "sign_in.html")
-	if err := ioutil.WriteFile(signInFile, []byte(templateHTML), 0666); err != nil {
-		log.Fatal(err)
-	}
-	errorFile := filepath.Join(dir, "error.html")
-	if err := ioutil.WriteFile(errorFile, []byte(templateHTML), 0666); err != nil {
-		log.Fatal(err)
-	}
-	templates = loadTemplates(dir)
-	assert.NotEqual(t, templates, nil)
-
-	var sitpl bytes.Buffer
-	templates.ExecuteTemplate(&sitpl, "sign_in.html", data)
-	assert.Equal(t, "Testing testing TESTING", sitpl.String())
-
-	var errtpl bytes.Buffer
-	templates.ExecuteTemplate(&errtpl, "error.html", data)
-	assert.Equal(t, "Testing testing TESTING", errtpl.String())
-}
-
-func TestTemplatesCompile(t *testing.T) {
-	templates := getTemplates()
-	assert.NotEqual(t, templates, nil)
-}