diff --git a/CHANGELOG.md b/CHANGELOG.md index bd635065..f1f3dde1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.0.1 +- [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) - [#1054](https://github.com/oauth2-proxy/oauth2-proxy/pull/1054) Update to Go 1.16 (@JoelSpeed) - [#1052](https://github.com/oauth2-proxy/oauth2-proxy/pull/1052) Update golangci-lint to latest version (v1.36.0) (@JoelSpeed) - [#1043](https://github.com/oauth2-proxy/oauth2-proxy/pull/1043) Refactor Sign In Page rendering and capture all page rendering code in pagewriter package (@JoelSpeed) diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index b9017977..66750004 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -40,6 +40,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--cookie-secure` | bool | set [secure (HTTPS only) cookie flag](https://owasp.org/www-community/controls/SecureFlag) | true | | `--cookie-samesite` | string | set SameSite cookie attribute (`"lax"`, `"strict"`, `"none"`, or `""`). | `""` | | `--custom-templates-dir` | string | path to custom html templates | | +| `--custom-sign-in-logo` | string | path to an custom image for the sign_in page logo. Use \"-\" to disable default logo. | | `--display-htpasswd-form` | bool | display username / password login form if an htpasswd file is provided | true | | `--email-domain` | string \| list | authenticate emails with the specified domain (may be given multiple times). Use `*` to authenticate any email | | | `--errors-to-info-log` | bool | redirects error-level logging to default log channel instead of stderr | | diff --git a/oauthproxy.go b/oauthproxy.go index 7bf524aa..82f89e6b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -123,6 +123,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{ TemplatesPath: opts.Templates.Path, + CustomLogo: opts.Templates.CustomLogo, ProxyPrefix: opts.ProxyPrefix, Footer: opts.Templates.Footer, Version: VERSION, diff --git a/pkg/apis/options/app.go b/pkg/apis/options/app.go index 76c4f84a..4d6353b8 100644 --- a/pkg/apis/options/app.go +++ b/pkg/apis/options/app.go @@ -11,6 +11,12 @@ type Templates struct { // If either file is missing, the default will be used instead. Path string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` + // CustomLogo is the path to a logo that should replace the default logo + // on the sign_in page template. + // Supported formats are .svg, .png, .jpg and .jpeg. + // To disable the default logo, set this value to "-". + CustomLogo string `flag:"custom-sign-in-logo" cfg:"custom_sign_in_logo"` + // Banner overides the default sign_in page banner text. If unspecified, // the message will give users a list of allowed email domains. Banner string `flag:"banner" cfg:"banner"` @@ -34,6 +40,7 @@ func templatesFlagSet() *pflag.FlagSet { flagSet := pflag.NewFlagSet("templates", pflag.ExitOnError) flagSet.String("custom-templates-dir", "", "path to custom html templates") + flagSet.String("custom-sign-in-logo", "", "path to an custom image for the sign_in page logo. Use \"-\" to disable default logo.") flagSet.String("banner", "", "custom banner string. Use \"-\" to disable default banner.") flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") diff --git a/pkg/app/pagewriter/default_logo.svg b/pkg/app/pagewriter/default_logo.svg new file mode 100644 index 00000000..37851c2a --- /dev/null +++ b/pkg/app/pagewriter/default_logo.svg @@ -0,0 +1 @@ +OAuth2_Proxy_logo_v3 diff --git a/pkg/app/pagewriter/pagewriter.go b/pkg/app/pagewriter/pagewriter.go index fdc8ec30..ad79aee2 100644 --- a/pkg/app/pagewriter/pagewriter.go +++ b/pkg/app/pagewriter/pagewriter.go @@ -49,6 +49,10 @@ type Opts struct { // SignInMessage is the messge displayed above the login button. SignInMessage string + + // CustomLogo is the path to a logo to be displayed on the sign in page. + // The logo can be either PNG, JPG/JPEG or SVG. + CustomLogo string } // NewWriter constructs a Writer from the options given to allow @@ -59,6 +63,11 @@ func NewWriter(opts Opts) (Writer, error) { return nil, fmt.Errorf("error loading templates: %v", err) } + logoData, err := loadCustomLogo(opts.CustomLogo) + if err != nil { + return nil, fmt.Errorf("error loading logo: %v", err) + } + errorPage := &errorPageWriter{ template: templates.Lookup("error.html"), proxyPrefix: opts.ProxyPrefix, @@ -76,6 +85,7 @@ func NewWriter(opts Opts) (Writer, error) { footer: opts.Footer, version: opts.Version, displayLoginForm: opts.DisplayLoginForm, + logoData: logoData, } return &pageWriter{ diff --git a/pkg/app/pagewriter/sign_in.html b/pkg/app/pagewriter/sign_in.html index e148c2b5..652b674e 100644 --- a/pkg/app/pagewriter/sign_in.html +++ b/pkg/app/pagewriter/sign_in.html @@ -15,6 +15,9 @@ max-width: 400px; margin: 1.25rem auto; } + .logo-box { + margin: 1.5rem 3rem; + } footer a { text-decoration: underline; } @@ -40,6 +43,12 @@
+ {{ if .LogoData }} +
+ {{.LogoData}} +
+ {{ end }} +
{{ if .SignInMessage }} diff --git a/pkg/app/pagewriter/sign_in_page.go b/pkg/app/pagewriter/sign_in_page.go index df96126c..43f84576 100644 --- a/pkg/app/pagewriter/sign_in_page.go +++ b/pkg/app/pagewriter/sign_in_page.go @@ -1,12 +1,24 @@ package pagewriter import ( + // Import embed to allow importing default logo + _ "embed" + + "encoding/base64" + "fmt" + "os" + "path/filepath" + "strings" + "html/template" "net/http" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) +//go:embed default_logo.svg +var defaultLogoData string + // signInPageWriter is used to render sign-in pages. type signInPageWriter struct { // Template is the sign-in page HTML template. @@ -33,6 +45,10 @@ type signInPageWriter struct { // DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page. displayLoginForm bool + + // LogoData is the logo to render in the template. + // This should contain valid html. + logoData string } // WriteSignInPage writes the sign-in page to the given response writer. @@ -48,6 +64,7 @@ func (s *signInPageWriter) WriteSignInPage(rw http.ResponseWriter, redirectURL s Version string ProxyPrefix string Footer template.HTML + LogoData template.HTML }{ ProviderName: s.providerName, SignInMessage: template.HTML(s.signInMessage), @@ -56,6 +73,7 @@ func (s *signInPageWriter) WriteSignInPage(rw http.ResponseWriter, redirectURL s Version: s.version, ProxyPrefix: s.proxyPrefix, Footer: template.HTML(s.footer), + LogoData: template.HTML(s.logoData), } err := s.template.Execute(rw, t) @@ -64,3 +82,42 @@ func (s *signInPageWriter) WriteSignInPage(rw http.ResponseWriter, redirectURL s s.errorPageWriter.WriteErrorPage(rw, http.StatusInternalServerError, redirectURL, err.Error()) } } + +// loadCustomLogo loads the logo file from the path and encodes it to an HTML +// entity. If no custom logo is provided, the OAuth2 Proxy Icon is used instead. +func loadCustomLogo(logoPath string) (string, error) { + if logoPath == "" { + // The default logo is an SVG so this will be valid to just return. + return defaultLogoData, nil + } + + if logoPath == "-" { + // Return no logo when the custom logo is set to `-`. + // This disables the logo rendering. + return "", nil + } + + logoData, err := os.ReadFile(logoPath) + if err != nil { + return "", fmt.Errorf("could not read logo file: %v", err) + } + + extension := strings.ToLower(filepath.Ext(logoPath)) + switch extension { + case ".svg": + return string(logoData), nil + case ".jpg", ".jpeg": + return encodeImg(logoData, "jpeg"), nil + case ".png": + return encodeImg(logoData, "png"), nil + default: + return "", fmt.Errorf("unknown extension: %q, supported extensions are .svg, .jpg, .jpeg and .png", extension) + } +} + +// encodeImg takes the raw image data and converts it to an HTML Img tag with +// a base64 data source. +func encodeImg(data []byte, format string) string { + b64Data := base64.StdEncoding.EncodeToString(data) + return fmt.Sprintf("\"Logo\"", format, b64Data) +} diff --git a/pkg/app/pagewriter/sign_in_page_test.go b/pkg/app/pagewriter/sign_in_page_test.go index eefcbe14..d32c93be 100644 --- a/pkg/app/pagewriter/sign_in_page_test.go +++ b/pkg/app/pagewriter/sign_in_page_test.go @@ -1,61 +1,154 @@ package pagewriter import ( + "errors" + "fmt" "html/template" "io/ioutil" "net/http/httptest" + "os" + "path/filepath" + "strings" . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) -var _ = Describe("SignIn Page Writer", func() { - var signInPage *signInPageWriter +var _ = Describe("SignIn Page", func() { - BeforeEach(func() { - errorTmpl, err := template.New("").Parse("{{.Title}}") - Expect(err).ToNot(HaveOccurred()) - errorPage := &errorPageWriter{ - template: errorTmpl, - } + Context("SignIn Page Writer", func() { + var signInPage *signInPageWriter - tmpl, err := template.New("").Parse("{{.ProxyPrefix}} {{.ProviderName}} {{.SignInMessage}} {{.Footer}} {{.Version}} {{.Redirect}} {{.CustomLogin}}") - Expect(err).ToNot(HaveOccurred()) + BeforeEach(func() { + errorTmpl, err := template.New("").Parse("{{.Title}}") + Expect(err).ToNot(HaveOccurred()) + errorPage := &errorPageWriter{ + template: errorTmpl, + } - signInPage = &signInPageWriter{ - template: tmpl, - errorPageWriter: errorPage, - proxyPrefix: "/prefix/", - providerName: "My Provider", - signInMessage: "Sign In Here", - footer: "Custom Footer Text", - version: "v0.0.0-test", - displayLoginForm: true, - } + tmpl, err := template.New("").Parse("{{.ProxyPrefix}} {{.ProviderName}} {{.SignInMessage}} {{.Footer}} {{.Version}} {{.Redirect}} {{.CustomLogin}} {{.LogoData}}") + Expect(err).ToNot(HaveOccurred()) + + signInPage = &signInPageWriter{ + template: tmpl, + errorPageWriter: errorPage, + proxyPrefix: "/prefix/", + providerName: "My Provider", + signInMessage: "Sign In Here", + footer: "Custom Footer Text", + version: "v0.0.0-test", + displayLoginForm: true, + logoData: "Logo Data", + } + }) + + Context("WriteSignInPage", func() { + It("Writes the template to the response writer", func() { + recorder := httptest.NewRecorder() + signInPage.WriteSignInPage(recorder, "/redirect") + + body, err := ioutil.ReadAll(recorder.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("/prefix/ My Provider Sign In Here Custom Footer Text v0.0.0-test /redirect true Logo Data")) + }) + + It("Writes an error if the template can't be rendered", func() { + // Overwrite the template with something bad + tmpl, err := template.New("").Parse("{{.Unknown}}") + Expect(err).ToNot(HaveOccurred()) + signInPage.template = tmpl + + recorder := httptest.NewRecorder() + signInPage.WriteSignInPage(recorder, "/redirect") + + body, err := ioutil.ReadAll(recorder.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Internal Server Error")) + }) + }) }) - Context("WriteSignInPage", func() { - It("Writes the template to the response writer", func() { - recorder := httptest.NewRecorder() - signInPage.WriteSignInPage(recorder, "/redirect") + Context("loadCustomLogo", func() { + var customDir string - body, err := ioutil.ReadAll(recorder.Result().Body) + const fakeImageData = "Fake Image Data" + + BeforeEach(func() { + var err error + customDir, err = ioutil.TempDir("", "oauth2-proxy-sign-in-page-test") Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(Equal("/prefix/ My Provider Sign In Here Custom Footer Text v0.0.0-test /redirect true")) + + for _, ext := range []string{".svg", ".png", ".jpg", ".jpeg", ".gif"} { + fileName := filepath.Join(customDir, fmt.Sprintf("logo%s", ext)) + Expect(ioutil.WriteFile(fileName, []byte(fakeImageData), 0600)).To(Succeed()) + } }) - It("Writes an error if the template can't be rendered", func() { - // Overwrite the template with something bad - tmpl, err := template.New("").Parse("{{.Unknown}}") - Expect(err).ToNot(HaveOccurred()) - signInPage.template = tmpl - - recorder := httptest.NewRecorder() - signInPage.WriteSignInPage(recorder, "/redirect") - - body, err := ioutil.ReadAll(recorder.Result().Body) - Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(Equal("Internal Server Error")) + AfterEach(func() { + Expect(os.RemoveAll(customDir)).To(Succeed()) }) + + type loadCustomLogoTableInput struct { + logoPath string + expectedErr error + expectedData string + } + + DescribeTable("should load the logo based on configuration", func(in loadCustomLogoTableInput) { + logoPath := in.logoPath + if strings.HasPrefix(logoPath, "customDir/") { + logoPath = filepath.Join(customDir, strings.TrimLeft(logoPath, "customDir/")) + } + + data, err := loadCustomLogo(logoPath) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr.Error())) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(data).To(Equal(in.expectedData)) + }, + Entry("with no custom logo path", loadCustomLogoTableInput{ + logoPath: "", + expectedErr: nil, + expectedData: defaultLogoData, + }), + Entry("when disabling the logo display", loadCustomLogoTableInput{ + logoPath: "-", + expectedErr: nil, + expectedData: "", + }), + Entry("with an svg custom logo", loadCustomLogoTableInput{ + logoPath: "customDir/logo.svg", + expectedErr: nil, + expectedData: fakeImageData, + }), + Entry("with a png custom logo", loadCustomLogoTableInput{ + logoPath: "customDir/logo.png", + expectedErr: nil, + expectedData: "\"Logo\"", + }), + Entry("with a jpg custom logo", loadCustomLogoTableInput{ + logoPath: "customDir/logo.jpg", + expectedErr: nil, + expectedData: "\"Logo\"", + }), + Entry("with a jpeg custom logo", loadCustomLogoTableInput{ + logoPath: "customDir/logo.jpeg", + expectedErr: nil, + expectedData: "\"Logo\"", + }), + Entry("with a gif custom logo", loadCustomLogoTableInput{ + logoPath: "customDir/logo.gif", + expectedErr: errors.New("unknown extension: \".gif\", supported extensions are .svg, .jpg, .jpeg and .png"), + expectedData: "", + }), + Entry("when the logo does not exist", loadCustomLogoTableInput{ + logoPath: "unknown.svg", + expectedErr: errors.New("could not read logo file: open unknown.svg: no such file or directory"), + expectedData: "", + }), + ) }) }) diff --git a/pkg/app/pagewriter/templates_test.go b/pkg/app/pagewriter/templates_test.go index 49b6ee55..afa6294f 100644 --- a/pkg/app/pagewriter/templates_test.go +++ b/pkg/app/pagewriter/templates_test.go @@ -45,6 +45,7 @@ var _ = Describe("Templates", func() { SignInMessage string ProviderName string CustomLogin bool + LogoData string // For default error template StatusCode int @@ -61,6 +62,7 @@ var _ = Describe("Templates", func() { SignInMessage: "", ProviderName: "", CustomLogin: false, + LogoData: "", StatusCode: 404, Title: "",