diff --git a/pkg/app/pagewriter/static_pages.go b/pkg/app/pagewriter/static_pages.go index a149793f..3e1bdcdc 100644 --- a/pkg/app/pagewriter/static_pages.go +++ b/pkg/app/pagewriter/static_pages.go @@ -21,7 +21,7 @@ var defaultRobotsTxt []byte // staticPageWriter is used to write static pages. type staticPageWriter struct { - pages map[string][]byte + pageGetter *pageGetter errorPageWriter *errorPageWriter } @@ -32,13 +32,7 @@ func (s *staticPageWriter) WriteRobotsTxt(rw http.ResponseWriter, req *http.Requ // writePage writes the content of the page to the response writer. func (s *staticPageWriter) writePage(rw http.ResponseWriter, req *http.Request, pageName string) { - content, ok := s.pages[pageName] - if !ok { - // If the page isn't regiested, something went wrong and there is a bug. - // Tests should make sure this code path is never hit. - panic(fmt.Sprintf("Static page %q not found", pageName)) - } - _, err := rw.Write(content) + _, err := rw.Write(s.pageGetter.getPage(pageName)) if err != nil { logger.Printf("Error writing %q: %v", pageName, err) scope := middlewareapi.GetRequestScope(req) @@ -52,13 +46,13 @@ func (s *staticPageWriter) writePage(rw http.ResponseWriter, req *http.Request, } func newStaticPageWriter(customDir string, errorWriter *errorPageWriter) (*staticPageWriter, error) { - pages, err := loadStaticPages(customDir) + pageGetter, err := loadStaticPages(customDir) if err != nil { return nil, fmt.Errorf("could not load static pages: %v", err) } return &staticPageWriter{ - pages: pages, + pageGetter: pageGetter, errorPageWriter: errorWriter, }, nil } @@ -68,31 +62,57 @@ func newStaticPageWriter(customDir string, errorWriter *errorPageWriter) (*stati // instead. // Statis files include: // - robots.txt -func loadStaticPages(customDir string) (map[string][]byte, error) { - pages := make(map[string][]byte) +func loadStaticPages(customDir string) (*pageGetter, error) { + pages := newPageGetter(customDir) - if err := addStaticPage(pages, customDir, robotsTxtName, defaultRobotsTxt); err != nil { + if err := pages.addPage(robotsTxtName, defaultRobotsTxt); err != nil { return nil, fmt.Errorf("could not add robots.txt: %v", err) } return pages, nil } -// addStaticPage tries to load the named file from the custom directory. -// If no custom directory is provided, the default content is used instead. -func addStaticPage(pages map[string][]byte, customDir, fileName string, defaultContent []byte) error { - filePath := filepath.Join(customDir, fileName) - if customDir != "" && isFile(filePath) { +// pageGetter is used to load and read page content for static pages. +type pageGetter struct { + pages map[string][]byte + dir string +} + +// newPageGetter creates a new page getter for the custom directory. +func newPageGetter(customDir string) *pageGetter { + return &pageGetter{ + pages: make(map[string][]byte), + dir: customDir, + } +} + +// addPage loads a new page into the pageGetter. +// If the given file name does not exist in the custom directory, the default +// content will be used instead. +func (p *pageGetter) addPage(fileName string, defaultContent []byte) error { + filePath := filepath.Join(p.dir, fileName) + if p.dir != "" && isFile(filePath) { content, err := os.ReadFile(filePath) if err != nil { return fmt.Errorf("could not read file: %v", err) } - pages[fileName] = content + p.pages[fileName] = content return nil } // No custom content defined, use the default. - pages[fileName] = defaultContent + p.pages[fileName] = defaultContent return nil } + +// getPage returns the page content for a given page. +func (p *pageGetter) getPage(name string) []byte { + content, ok := p.pages[name] + if !ok { + // If the page isn't registered, something went wrong and there is a bug. + // Tests should make sure this code path is never hit. + panic(fmt.Sprintf("Static page %q not found", name)) + } + return content +} diff --git a/pkg/app/pagewriter/static_pages_test.go b/pkg/app/pagewriter/static_pages_test.go index 1d702674..f52451ba 100644 --- a/pkg/app/pagewriter/static_pages_test.go +++ b/pkg/app/pagewriter/static_pages_test.go @@ -16,7 +16,7 @@ import ( var _ = Describe("Static Pages", func() { var customDir string - const customRobots = "I AM A ROBOT!!!" + const customRobots = "User-agent: *\nAllow: /\n" var errorPage *errorPageWriter var request *http.Request @@ -110,8 +110,8 @@ var _ = Describe("Static Pages", func() { It("Loads the custom content", func() { pages, err := loadStaticPages(customDir) Expect(err).ToNot(HaveOccurred()) - Expect(pages).To(HaveLen(1)) - Expect(pages).To(HaveKeyWithValue(robotsTxtName, []byte(customRobots))) + Expect(pages.pages).To(HaveLen(1)) + Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(customRobots)) }) }) @@ -122,8 +122,8 @@ var _ = Describe("Static Pages", func() { pages, err := loadStaticPages(customDir) Expect(err).ToNot(HaveOccurred()) - Expect(pages).To(HaveLen(1)) - Expect(pages).To(HaveKeyWithValue(robotsTxtName, defaultRobotsTxt)) + Expect(pages.pages).To(HaveLen(1)) + Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(defaultRobotsTxt)) }) }) }) @@ -132,8 +132,8 @@ var _ = Describe("Static Pages", func() { It("Loads the default content", func() { pages, err := loadStaticPages("") Expect(err).ToNot(HaveOccurred()) - Expect(pages).To(HaveLen(1)) - Expect(pages).To(HaveKeyWithValue(robotsTxtName, defaultRobotsTxt)) + Expect(pages.pages).To(HaveLen(1)) + Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(defaultRobotsTxt)) }) }) })