1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-21 21:47:11 +02:00

Create seperate page getter

This commit is contained in:
Joel Speed 2021-03-20 14:58:47 +00:00
parent f3bd61b371
commit 6c6fd4f862
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
2 changed files with 47 additions and 27 deletions

View File

@ -21,7 +21,7 @@ var defaultRobotsTxt []byte
// staticPageWriter is used to write static pages. // staticPageWriter is used to write static pages.
type staticPageWriter struct { type staticPageWriter struct {
pages map[string][]byte pageGetter *pageGetter
errorPageWriter *errorPageWriter errorPageWriter *errorPageWriter
} }
@ -32,13 +32,7 @@ func (s *staticPageWriter) WriteRobotsTxt(rw http.ResponseWriter, req *http.Requ
// writePage writes the content of the page to the response writer. // writePage writes the content of the page to the response writer.
func (s *staticPageWriter) writePage(rw http.ResponseWriter, req *http.Request, pageName string) { func (s *staticPageWriter) writePage(rw http.ResponseWriter, req *http.Request, pageName string) {
content, ok := s.pages[pageName] _, err := rw.Write(s.pageGetter.getPage(pageName))
if !ok {
// If the page isn't regiested, something went wrong and there is a bug.
// Tests should make sure this code path is never hit.
panic(fmt.Sprintf("Static page %q not found", pageName))
}
_, err := rw.Write(content)
if err != nil { if err != nil {
logger.Printf("Error writing %q: %v", pageName, err) logger.Printf("Error writing %q: %v", pageName, err)
scope := middlewareapi.GetRequestScope(req) scope := middlewareapi.GetRequestScope(req)
@ -52,13 +46,13 @@ func (s *staticPageWriter) writePage(rw http.ResponseWriter, req *http.Request,
} }
func newStaticPageWriter(customDir string, errorWriter *errorPageWriter) (*staticPageWriter, error) { func newStaticPageWriter(customDir string, errorWriter *errorPageWriter) (*staticPageWriter, error) {
pages, err := loadStaticPages(customDir) pageGetter, err := loadStaticPages(customDir)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not load static pages: %v", err) return nil, fmt.Errorf("could not load static pages: %v", err)
} }
return &staticPageWriter{ return &staticPageWriter{
pages: pages, pageGetter: pageGetter,
errorPageWriter: errorWriter, errorPageWriter: errorWriter,
}, nil }, nil
} }
@ -68,31 +62,57 @@ func newStaticPageWriter(customDir string, errorWriter *errorPageWriter) (*stati
// instead. // instead.
// Statis files include: // Statis files include:
// - robots.txt // - robots.txt
func loadStaticPages(customDir string) (map[string][]byte, error) { func loadStaticPages(customDir string) (*pageGetter, error) {
pages := make(map[string][]byte) pages := newPageGetter(customDir)
if err := addStaticPage(pages, customDir, robotsTxtName, defaultRobotsTxt); err != nil { if err := pages.addPage(robotsTxtName, defaultRobotsTxt); err != nil {
return nil, fmt.Errorf("could not add robots.txt: %v", err) return nil, fmt.Errorf("could not add robots.txt: %v", err)
} }
return pages, nil return pages, nil
} }
// addStaticPage tries to load the named file from the custom directory. // pageGetter is used to load and read page content for static pages.
// If no custom directory is provided, the default content is used instead. type pageGetter struct {
func addStaticPage(pages map[string][]byte, customDir, fileName string, defaultContent []byte) error { pages map[string][]byte
filePath := filepath.Join(customDir, fileName) dir string
if customDir != "" && isFile(filePath) { }
// newPageGetter creates a new page getter for the custom directory.
func newPageGetter(customDir string) *pageGetter {
return &pageGetter{
pages: make(map[string][]byte),
dir: customDir,
}
}
// addPage loads a new page into the pageGetter.
// If the given file name does not exist in the custom directory, the default
// content will be used instead.
func (p *pageGetter) addPage(fileName string, defaultContent []byte) error {
filePath := filepath.Join(p.dir, fileName)
if p.dir != "" && isFile(filePath) {
content, err := os.ReadFile(filePath) content, err := os.ReadFile(filePath)
if err != nil { if err != nil {
return fmt.Errorf("could not read file: %v", err) return fmt.Errorf("could not read file: %v", err)
} }
pages[fileName] = content p.pages[fileName] = content
return nil return nil
} }
// No custom content defined, use the default. // No custom content defined, use the default.
pages[fileName] = defaultContent p.pages[fileName] = defaultContent
return nil return nil
} }
// getPage returns the page content for a given page.
func (p *pageGetter) getPage(name string) []byte {
content, ok := p.pages[name]
if !ok {
// If the page isn't registered, something went wrong and there is a bug.
// Tests should make sure this code path is never hit.
panic(fmt.Sprintf("Static page %q not found", name))
}
return content
}

View File

@ -16,7 +16,7 @@ import (
var _ = Describe("Static Pages", func() { var _ = Describe("Static Pages", func() {
var customDir string var customDir string
const customRobots = "I AM A ROBOT!!!" const customRobots = "User-agent: *\nAllow: /\n"
var errorPage *errorPageWriter var errorPage *errorPageWriter
var request *http.Request var request *http.Request
@ -110,8 +110,8 @@ var _ = Describe("Static Pages", func() {
It("Loads the custom content", func() { It("Loads the custom content", func() {
pages, err := loadStaticPages(customDir) pages, err := loadStaticPages(customDir)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(pages).To(HaveLen(1)) Expect(pages.pages).To(HaveLen(1))
Expect(pages).To(HaveKeyWithValue(robotsTxtName, []byte(customRobots))) Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(customRobots))
}) })
}) })
@ -122,8 +122,8 @@ var _ = Describe("Static Pages", func() {
pages, err := loadStaticPages(customDir) pages, err := loadStaticPages(customDir)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(pages).To(HaveLen(1)) Expect(pages.pages).To(HaveLen(1))
Expect(pages).To(HaveKeyWithValue(robotsTxtName, defaultRobotsTxt)) Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(defaultRobotsTxt))
}) })
}) })
}) })
@ -132,8 +132,8 @@ var _ = Describe("Static Pages", func() {
It("Loads the default content", func() { It("Loads the default content", func() {
pages, err := loadStaticPages("") pages, err := loadStaticPages("")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(pages).To(HaveLen(1)) Expect(pages.pages).To(HaveLen(1))
Expect(pages).To(HaveKeyWithValue(robotsTxtName, defaultRobotsTxt)) Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(defaultRobotsTxt))
}) })
}) })
}) })