You've already forked oauth2-proxy
							
							
				mirror of
				https://github.com/oauth2-proxy/oauth2-proxy.git
				synced 2025-10-30 23:47:52 +02:00 
			
		
		
		
	Move template loading to app package
This commit is contained in:
		| @@ -18,6 +18,7 @@ import ( | |||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||||
| @@ -116,7 +117,10 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||||||
| 		return nil, fmt.Errorf("error initialising session store: %v", err) | 		return nil, fmt.Errorf("error initialising session store: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	templates := loadTemplates(opts.Templates.Path) | 	templates, err := app.LoadTemplates(opts.Templates.Path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error loading templates: %v", err) | ||||||
|  | 	} | ||||||
| 	proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) | 	proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) | ||||||
| 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) | 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								pkg/app/app_suite_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								pkg/app/app_suite_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | package app | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestOptionsSuite(t *testing.T) { | ||||||
|  | 	logger.SetOutput(GinkgoWriter) | ||||||
|  | 	logger.SetErrOutput(GinkgoWriter) | ||||||
|  |  | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "App Suite") | ||||||
|  | } | ||||||
| @@ -1,124 +1,20 @@ | |||||||
| package main | package app | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"html/template" | 	"html/template" | ||||||
| 	"path" | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func loadTemplates(dir string) *template.Template { | const ( | ||||||
| 	if dir == "" { | 	errorTemplateName  = "error.html" | ||||||
| 		return getTemplates() | 	signInTemplateName = "sign_in.html" | ||||||
| 	} |  | ||||||
| 	logger.Printf("using custom template directory %q", dir) |  | ||||||
| 	funcMap := template.FuncMap{ |  | ||||||
| 		"ToUpper": strings.ToUpper, |  | ||||||
| 		"ToLower": strings.ToLower, |  | ||||||
| 	} |  | ||||||
| 	t, err := template.New("").Funcs(funcMap).ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html")) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("failed parsing template %s", err) |  | ||||||
| 	} |  | ||||||
| 	return t |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| func getTemplates() *template.Template { | 	defaultErrorTemplate = `{{define "error.html"}} | ||||||
| 	t, err := template.New("foo").Parse(`{{define "sign_in.html"}} |  | ||||||
| <!DOCTYPE html> |  | ||||||
| <html lang="en" charset="utf-8"> |  | ||||||
|   <head> |  | ||||||
|     <meta charset="utf-8"> |  | ||||||
|     <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> |  | ||||||
|     <title>Sign In</title> |  | ||||||
|     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css"> |  | ||||||
| 
 |  | ||||||
|     <style> |  | ||||||
|       body { |  | ||||||
|         height: 100vh; |  | ||||||
|       } |  | ||||||
|       .sign-in-box { |  | ||||||
|         max-width: 400px; |  | ||||||
|         margin: 1.25rem auto; |  | ||||||
|       } |  | ||||||
|       footer a { |  | ||||||
|         text-decoration: underline; |  | ||||||
|       } |  | ||||||
|     </style> |  | ||||||
| 
 |  | ||||||
|     <script> |  | ||||||
|       if (window.location.hash) { |  | ||||||
|         (function() { |  | ||||||
|           var inputs = document.getElementsByName('rd'); |  | ||||||
|           for (var i = 0; i < inputs.length; i++) { |  | ||||||
|             // Add hash, but make sure it is only added once |  | ||||||
|             var idx = inputs[i].value.indexOf('#'); |  | ||||||
|             if (idx >= 0) { |  | ||||||
|               // Remove existing hash from URL |  | ||||||
|               inputs[i].value = inputs[i].value.substr(0, idx); |  | ||||||
|             } |  | ||||||
|             inputs[i].value += window.location.hash; |  | ||||||
|           } |  | ||||||
|         })(); |  | ||||||
|       } |  | ||||||
|     </script> |  | ||||||
|   </head> |  | ||||||
|   <body class="has-background-light"> |  | ||||||
|   <section class="section"> |  | ||||||
|     <div class="box block sign-in-box has-text-centered"> |  | ||||||
|       <form method="GET" action="{{.ProxyPrefix}}/start"> |  | ||||||
|         <input type="hidden" name="rd" value="{{.Redirect}}"> |  | ||||||
|           {{ if .SignInMessage }} |  | ||||||
|           <p class="block">{{.SignInMessage}}</p> |  | ||||||
|           {{ end}} |  | ||||||
|           <button type="submit" class="button block is-primary">Sign in with {{.ProviderName}}</button> |  | ||||||
|       </form> |  | ||||||
| 
 |  | ||||||
|       {{ if .CustomLogin }} |  | ||||||
|       <hr> |  | ||||||
| 
 |  | ||||||
|       <form method="POST" action="{{.ProxyPrefix}}/sign_in" class="block"> |  | ||||||
|         <input type="hidden" name="rd" value="{{.Redirect}}"> |  | ||||||
| 
 |  | ||||||
|         <div class="field"> |  | ||||||
|           <label class="label" for="username">Username</label> |  | ||||||
|           <div class="control"> |  | ||||||
|             <input class="input" type="email" placeholder="e.g. userx@example.com"  name="username" id="username"> |  | ||||||
|           </div> |  | ||||||
|         </div> |  | ||||||
| 
 |  | ||||||
|         <div class="field"> |  | ||||||
|           <label class="label" for="password">Password</label> |  | ||||||
|           <div class="control"> |  | ||||||
|             <input class="input" type="password" placeholder="********" name="password" id="password"> |  | ||||||
|           </div> |  | ||||||
|         </div> |  | ||||||
|         <button class="button is-primary">Sign in</button> |  | ||||||
|         {{ end }} |  | ||||||
|     </form> |  | ||||||
|     </div> |  | ||||||
|   </section> |  | ||||||
| 
 |  | ||||||
|   <footer class="footer has-text-grey has-background-light is-size-7"> |  | ||||||
|     <div class="content has-text-centered"> |  | ||||||
|     	{{ if eq .Footer "-" }} |  | ||||||
|     	{{ else if eq .Footer ""}} |  | ||||||
|     	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p> |  | ||||||
|     	{{ else }} |  | ||||||
|     	<p>{{.Footer}}</p> |  | ||||||
|     	{{ end }} |  | ||||||
|     </div> |  | ||||||
| 	</footer> |  | ||||||
| 
 |  | ||||||
|   </body> |  | ||||||
| </html> |  | ||||||
| {{end}}`) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("failed parsing template %s", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	t, err = t.Parse(`{{define "error.html"}} |  | ||||||
| <!DOCTYPE html> | <!DOCTYPE html> | ||||||
| <html lang="en" charset="utf-8"> | <html lang="en" charset="utf-8"> | ||||||
| <head> | <head> | ||||||
| @@ -215,9 +111,147 @@ func getTemplates() *template.Template { | |||||||
| 
 | 
 | ||||||
|   </body> |   </body> | ||||||
| </html> | </html> | ||||||
| {{end}}`) | {{end}}` | ||||||
| 	if err != nil { | 
 | ||||||
| 		logger.Fatalf("failed parsing template %s", err) | 	defaultSignInTemplate = `{{define "sign_in.html"}} | ||||||
|  | <!DOCTYPE html> | ||||||
|  | <html lang="en" charset="utf-8"> | ||||||
|  |   <head> | ||||||
|  |     <meta charset="utf-8"> | ||||||
|  |     <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> | ||||||
|  |     <title>Sign In</title> | ||||||
|  |     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css"> | ||||||
|  | 
 | ||||||
|  |     <style> | ||||||
|  |       body { | ||||||
|  |         height: 100vh; | ||||||
|       } |       } | ||||||
| 	return t |       .sign-in-box { | ||||||
|  |         max-width: 400px; | ||||||
|  |         margin: 1.25rem auto; | ||||||
|  |       } | ||||||
|  |       footer a { | ||||||
|  |         text-decoration: underline; | ||||||
|  |       } | ||||||
|  |     </style> | ||||||
|  | 
 | ||||||
|  |     <script> | ||||||
|  |       if (window.location.hash) { | ||||||
|  |         (function() { | ||||||
|  |           var inputs = document.getElementsByName('rd'); | ||||||
|  |           for (var i = 0; i < inputs.length; i++) { | ||||||
|  |             // Add hash, but make sure it is only added once | ||||||
|  |             var idx = inputs[i].value.indexOf('#'); | ||||||
|  |             if (idx >= 0) { | ||||||
|  |               // Remove existing hash from URL | ||||||
|  |               inputs[i].value = inputs[i].value.substr(0, idx); | ||||||
|  |             } | ||||||
|  |             inputs[i].value += window.location.hash; | ||||||
|  |           } | ||||||
|  |         })(); | ||||||
|  |       } | ||||||
|  |     </script> | ||||||
|  |   </head> | ||||||
|  |   <body class="has-background-light"> | ||||||
|  |   <section class="section"> | ||||||
|  |     <div class="box block sign-in-box has-text-centered"> | ||||||
|  |       <form method="GET" action="{{.ProxyPrefix}}/start"> | ||||||
|  |         <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  |           {{ if .SignInMessage }} | ||||||
|  |           <p class="block">{{.SignInMessage}}</p> | ||||||
|  |           {{ end}} | ||||||
|  |           <button type="submit" class="button block is-primary">Sign in with {{.ProviderName}}</button> | ||||||
|  |       </form> | ||||||
|  | 
 | ||||||
|  |       {{ if .CustomLogin }} | ||||||
|  |       <hr> | ||||||
|  | 
 | ||||||
|  |       <form method="POST" action="{{.ProxyPrefix}}/sign_in" class="block"> | ||||||
|  |         <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  | 
 | ||||||
|  |         <div class="field"> | ||||||
|  |           <label class="label" for="username">Username</label> | ||||||
|  |           <div class="control"> | ||||||
|  |             <input class="input" type="email" placeholder="e.g. userx@example.com"  name="username" id="username"> | ||||||
|  |           </div> | ||||||
|  |         </div> | ||||||
|  | 
 | ||||||
|  |         <div class="field"> | ||||||
|  |           <label class="label" for="password">Password</label> | ||||||
|  |           <div class="control"> | ||||||
|  |             <input class="input" type="password" placeholder="********" name="password" id="password"> | ||||||
|  |           </div> | ||||||
|  |         </div> | ||||||
|  |         <button class="button is-primary">Sign in</button> | ||||||
|  |         {{ end }} | ||||||
|  |     </form> | ||||||
|  |     </div> | ||||||
|  |   </section> | ||||||
|  | 
 | ||||||
|  |   <footer class="footer has-text-grey has-background-light is-size-7"> | ||||||
|  |     <div class="content has-text-centered"> | ||||||
|  |     	{{ if eq .Footer "-" }} | ||||||
|  |     	{{ else if eq .Footer ""}} | ||||||
|  |     	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p> | ||||||
|  |     	{{ else }} | ||||||
|  |     	<p>{{.Footer}}</p> | ||||||
|  |     	{{ end }} | ||||||
|  |     </div> | ||||||
|  | 	</footer> | ||||||
|  | 
 | ||||||
|  |   </body> | ||||||
|  | </html> | ||||||
|  | {{end}}` | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // LoadTemplates adds the Sign In and Error templates from the custom template | ||||||
|  | // directory, or uses the defaults if they do not exist or the custom directory | ||||||
|  | // is not provided. | ||||||
|  | func LoadTemplates(customDir string) (*template.Template, error) { | ||||||
|  | 	t := template.New("").Funcs(template.FuncMap{ | ||||||
|  | 		"ToUpper": strings.ToUpper, | ||||||
|  | 		"ToLower": strings.ToLower, | ||||||
|  | 	}) | ||||||
|  | 	var err error | ||||||
|  | 	t, err = addTemplate(t, customDir, signInTemplateName, defaultSignInTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not add Sign In template: %v", err) | ||||||
|  | 	} | ||||||
|  | 	t, err = addTemplate(t, customDir, errorTemplateName, defaultErrorTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not add Error template: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // addTemplate will add the template from the custom directory if provided, | ||||||
|  | // else it will add the default template. | ||||||
|  | func addTemplate(t *template.Template, customDir, fileName, defaultTemplate string) (*template.Template, error) { | ||||||
|  | 	filePath := filepath.Join(customDir, fileName) | ||||||
|  | 	if customDir != "" && isFile(filePath) { | ||||||
|  | 		t, err := t.ParseFiles(filePath) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("failed to parse template %s: %v", filePath, err) | ||||||
|  | 		} | ||||||
|  | 		return t, nil | ||||||
|  | 	} | ||||||
|  | 	t, err := t.Parse(defaultTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// This should not happen. | ||||||
|  | 		// Default templates should be tested and so should never fail to parse. | ||||||
|  | 		logger.Panic("Could not parse defaultTemplate: ", err) | ||||||
|  | 	} | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // isFile checks if the file exists and checks whether it is a regular file. | ||||||
|  | // If either of these fail then it cannot be used as a template file. | ||||||
|  | func isFile(fileName string) bool { | ||||||
|  | 	info, err := os.Stat(fileName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Could not load file %s: %v, will use default template", fileName, err) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return info.Mode().IsRegular() | ||||||
| } | } | ||||||
							
								
								
									
										199
									
								
								pkg/app/templates_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										199
									
								
								pkg/app/templates_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,199 @@ | |||||||
|  | package app | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"html/template" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  |  | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var _ = Describe("Templates", func() { | ||||||
|  | 	var customDir string | ||||||
|  |  | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		var err error | ||||||
|  | 		customDir, err = ioutil.TempDir("", "oauth2-proxy-templates-test") | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  |  | ||||||
|  | 		templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` | ||||||
|  | 		signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 		Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0666)).To(Succeed()) | ||||||
|  | 		errorFile := filepath.Join(customDir, errorTemplateName) | ||||||
|  | 		Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0666)).To(Succeed()) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	AfterEach(func() { | ||||||
|  | 		Expect(os.RemoveAll(customDir)).To(Succeed()) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	Context("LoadTemplates", func() { | ||||||
|  | 		var data interface{} | ||||||
|  | 		var t *template.Template | ||||||
|  |  | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			data = struct { | ||||||
|  | 				// For default templates | ||||||
|  | 				ProxyPrefix string | ||||||
|  | 				Redirect    string | ||||||
|  | 				Footer      string | ||||||
|  |  | ||||||
|  | 				// For default sign_in template | ||||||
|  | 				SignInMessage string | ||||||
|  | 				ProviderName  string | ||||||
|  | 				CustomLogin   bool | ||||||
|  |  | ||||||
|  | 				// For default error template | ||||||
|  | 				StatusCode int | ||||||
|  | 				Title      string | ||||||
|  | 				Message    string | ||||||
|  |  | ||||||
|  | 				// For custom templates | ||||||
|  | 				TestString string | ||||||
|  | 			}{ | ||||||
|  | 				ProxyPrefix: "<proxy-prefix>", | ||||||
|  | 				Redirect:    "<redirect>", | ||||||
|  | 				Footer:      "<footer>", | ||||||
|  |  | ||||||
|  | 				SignInMessage: "<sign-in-message>", | ||||||
|  | 				ProviderName:  "<provider-name>", | ||||||
|  | 				CustomLogin:   false, | ||||||
|  |  | ||||||
|  | 				StatusCode: 404, | ||||||
|  | 				Title:      "<title>", | ||||||
|  | 				Message:    "<message>", | ||||||
|  |  | ||||||
|  | 				TestString: "Testing", | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		Context("With no custom directory", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				t, err = LoadTemplates("") | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			It("Use the default sign_in page", func() { | ||||||
|  | 				buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 				Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			It("Use the default error page", func() { | ||||||
|  | 				buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 				Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		Context("With a custom directory", func() { | ||||||
|  | 			Context("With both templates", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					var err error | ||||||
|  | 					t, err = LoadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Use the custom sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Use the custom error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			Context("With no error template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					Expect(os.Remove(filepath.Join(customDir, errorTemplateName))).To(Succeed()) | ||||||
|  |  | ||||||
|  | 					var err error | ||||||
|  | 					t, err = LoadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Use the custom sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Use the default error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			Context("With no sign_in template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					Expect(os.Remove(filepath.Join(customDir, signInTemplateName))).To(Succeed()) | ||||||
|  |  | ||||||
|  | 					var err error | ||||||
|  | 					t, err = LoadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Use the default sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Use the custom error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			Context("With an invalid sign_in template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 					Expect(ioutil.WriteFile(signInFile, []byte("{{"), 0666)) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Should return an error when loading templates", func() { | ||||||
|  | 					t, err := LoadTemplates(customDir) | ||||||
|  | 					Expect(err).To(MatchError(HavePrefix("could not add Sign In template:"))) | ||||||
|  | 					Expect(t).To(BeNil()) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			Context("With an invalid error template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					errorFile := filepath.Join(customDir, errorTemplateName) | ||||||
|  | 					Expect(ioutil.WriteFile(errorFile, []byte("{{"), 0666)) | ||||||
|  | 				}) | ||||||
|  |  | ||||||
|  | 				It("Should return an error when loading templates", func() { | ||||||
|  | 					t, err := LoadTemplates(customDir) | ||||||
|  | 					Expect(err).To(MatchError(HavePrefix("could not add Error template:"))) | ||||||
|  | 					Expect(t).To(BeNil()) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	Context("isFile", func() { | ||||||
|  | 		It("with a valid file", func() { | ||||||
|  | 			Expect(isFile(filepath.Join(customDir, signInTemplateName))).To(BeTrue()) | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		It("with a directory", func() { | ||||||
|  | 			Expect(isFile(customDir)).To(BeFalse()) | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		It("with an invalid file", func() { | ||||||
|  | 			Expect(isFile(filepath.Join(customDir, "does_not_exist.html"))).To(BeFalse()) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
| @@ -1,62 +0,0 @@ | |||||||
| package main |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"io/ioutil" |  | ||||||
| 	"log" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/stretchr/testify/assert" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestLoadTemplates(t *testing.T) { |  | ||||||
| 	data := struct { |  | ||||||
| 		TestString string |  | ||||||
| 	}{ |  | ||||||
| 		TestString: "Testing", |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	templates := loadTemplates("") |  | ||||||
| 	assert.NotEqual(t, templates, nil) |  | ||||||
|  |  | ||||||
| 	var defaultSignin bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&defaultSignin, "sign_in.html", data) |  | ||||||
| 	assert.Equal(t, "\n<!DOCTYPE html>", defaultSignin.String()[0:16]) |  | ||||||
|  |  | ||||||
| 	var defaultError bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&defaultError, "error.html", data) |  | ||||||
| 	assert.Equal(t, "\n<!DOCTYPE html>", defaultError.String()[0:16]) |  | ||||||
|  |  | ||||||
| 	dir, err := ioutil.TempDir("", "templatetest") |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer os.RemoveAll(dir) |  | ||||||
|  |  | ||||||
| 	templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` |  | ||||||
| 	signInFile := filepath.Join(dir, "sign_in.html") |  | ||||||
| 	if err := ioutil.WriteFile(signInFile, []byte(templateHTML), 0666); err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	errorFile := filepath.Join(dir, "error.html") |  | ||||||
| 	if err := ioutil.WriteFile(errorFile, []byte(templateHTML), 0666); err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	templates = loadTemplates(dir) |  | ||||||
| 	assert.NotEqual(t, templates, nil) |  | ||||||
|  |  | ||||||
| 	var sitpl bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&sitpl, "sign_in.html", data) |  | ||||||
| 	assert.Equal(t, "Testing testing TESTING", sitpl.String()) |  | ||||||
|  |  | ||||||
| 	var errtpl bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&errtpl, "error.html", data) |  | ||||||
| 	assert.Equal(t, "Testing testing TESTING", errtpl.String()) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestTemplatesCompile(t *testing.T) { |  | ||||||
| 	templates := getTemplates() |  | ||||||
| 	assert.NotEqual(t, templates, nil) |  | ||||||
| } |  | ||||||
		Reference in New Issue
	
	Block a user