1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-26 05:27:28 +02:00
oauth2-proxy/pkg/app/pagewriter/sign_in_page_test.go
Miks Kalnins 54d44ccb8f
Allow specifying URL as input for custom sign in logo (#1330)
* Allow specifying URL as input for custom logos

* Fix typo

Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>

* Update changelog

* Only allow HTTPS URLs

Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
Co-authored-by: Nick Meves <nicholas.meves@gmail.com>
2021-09-05 09:23:22 -07:00

168 lines
5.4 KiB
Go

package pagewriter
import (
"errors"
"fmt"
"html/template"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
// Specs for the sign-in page writer and for loadCustomLogo, which turns the
// configured custom-logo setting into the logo markup used by the sign-in page.
var _ = Describe("SignIn Page", func() {
	Context("SignIn Page Writer", func() {
		var request *http.Request
		var signInPage *signInPageWriter

		BeforeEach(func() {
			// Minimal error template rendering "<Title> | <RequestID>" so the
			// failure spec below can assert on the exact response body.
			errorTmpl, err := template.New("").Parse("{{.Title}} | {{.RequestID}}")
			Expect(err).ToNot(HaveOccurred())
			errorPage := &errorPageWriter{
				template: errorTmpl,
			}

			// Sign-in template that prints each of these fields separated by
			// spaces, so the whole rendered page can be compared as one string.
			tmpl, err := template.New("").Parse("{{.ProxyPrefix}} {{.ProviderName}} {{.SignInMessage}} {{.Footer}} {{.Version}} {{.Redirect}} {{.CustomLogin}} {{.LogoData}}")
			Expect(err).ToNot(HaveOccurred())

			signInPage = &signInPageWriter{
				template:         tmpl,
				errorPageWriter:  errorPage,
				proxyPrefix:      "/prefix/",
				providerName:     "My Provider",
				signInMessage:    "Sign In Here",
				footer:           "Custom Footer Text",
				version:          "v0.0.0-test",
				displayLoginForm: true,
				logoData:         "Logo Data",
			}

			// Attach a request scope carrying testRequestID; the error page
			// output asserted below is expected to echo it.
			request = httptest.NewRequest("", "http://127.0.0.1/", nil)
			request = middlewareapi.AddRequestScope(request, &middlewareapi.RequestScope{
				RequestID: testRequestID,
			})
		})

		Context("WriteSignInPage", func() {
			It("Writes the template to the response writer", func() {
				recorder := httptest.NewRecorder()
				signInPage.WriteSignInPage(recorder, request, "/redirect")

				body, err := ioutil.ReadAll(recorder.Result().Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(string(body)).To(Equal("/prefix/ My Provider Sign In Here Custom Footer Text v0.0.0-test /redirect true Logo Data"))
			})

			It("Writes an error if the template can't be rendered", func() {
				// Replace the template with one that parses but references a
				// field the page data does not provide, so executing it fails
				// and the writer has to fall back to the error page.
				tmpl, err := template.New("").Parse("{{.Unknown}}")
				Expect(err).ToNot(HaveOccurred())
				signInPage.template = tmpl

				recorder := httptest.NewRecorder()
				signInPage.WriteSignInPage(recorder, request, "/redirect")

				body, err := ioutil.ReadAll(recorder.Result().Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(string(body)).To(Equal(fmt.Sprintf("Internal Server Error | %s", testRequestID)))
			})
		})
	})

	Context("loadCustomLogo", func() {
		var customDir string
		const fakeImageData = "Fake Image Data"

		BeforeEach(func() {
			// Create a temporary directory holding logo.<ext> for every
			// extension under test, each file containing the same placeholder bytes.
			var err error
			customDir, err = ioutil.TempDir("", "oauth2-proxy-sign-in-page-test")
			Expect(err).ToNot(HaveOccurred())

			for _, ext := range []string{".svg", ".png", ".jpg", ".jpeg", ".gif"} {
				fileName := filepath.Join(customDir, fmt.Sprintf("logo%s", ext))
				Expect(ioutil.WriteFile(fileName, []byte(fakeImageData), 0600)).To(Succeed())
			}
		})

		AfterEach(func() {
			Expect(os.RemoveAll(customDir)).To(Succeed())
		})

		// loadCustomLogoTableInput describes one table case: the configured
		// logo path, the expected error (nil for success) and the expected
		// logo data returned by loadCustomLogo.
		type loadCustomLogoTableInput struct {
			logoPath     string
			expectedErr  error
			expectedData string
		}

		DescribeTable("should load the logo based on configuration", func(in loadCustomLogoTableInput) {
			// Paths starting with "customDir/" are placeholders for the
			// temporary directory created in BeforeEach. TrimPrefix is required
			// here: TrimLeft treats its argument as a set of characters and
			// would also strip leading file-name characters (e.g. turning
			// "customDir/custom.svg" into ".svg").
			logoPath := in.logoPath
			if strings.HasPrefix(logoPath, "customDir/") {
				logoPath = filepath.Join(customDir, strings.TrimPrefix(logoPath, "customDir/"))
			}

			data, err := loadCustomLogo(logoPath)
			if in.expectedErr != nil {
				Expect(err).To(MatchError(in.expectedErr.Error()))
			} else {
				Expect(err).ToNot(HaveOccurred())
			}
			Expect(data).To(Equal(in.expectedData))
		},
			Entry("with no custom logo path", loadCustomLogoTableInput{
				logoPath:     "",
				expectedErr:  nil,
				expectedData: defaultLogoData,
			}),
			Entry("when disabling the logo display", loadCustomLogoTableInput{
				logoPath:     "-",
				expectedErr:  nil,
				expectedData: "",
			}),
			Entry("with HTTPS URL", loadCustomLogoTableInput{
				logoPath:     "https://raw.githubusercontent.com/oauth2-proxy/oauth2-proxy/master/docs/static/img/logos/OAuth2_Proxy_icon.png",
				expectedErr:  nil,
				expectedData: "<img src=\"https://raw.githubusercontent.com/oauth2-proxy/oauth2-proxy/master/docs/static/img/logos/OAuth2_Proxy_icon.png\" alt=\"Logo\" />",
			}),
			Entry("with an svg custom logo", loadCustomLogoTableInput{
				logoPath:     "customDir/logo.svg",
				expectedErr:  nil,
				expectedData: fakeImageData,
			}),
			// NOTE(review): the png/jpg/jpeg expectations below have an empty
			// src, so the bytes written in BeforeEach are not reflected in them
			// (unlike the svg case). If loadCustomLogo embeds the file contents
			// (e.g. as a data URI), these values need updating — confirm
			// against the implementation.
			Entry("with a png custom logo", loadCustomLogoTableInput{
				logoPath:     "customDir/logo.png",
				expectedErr:  nil,
				expectedData: "<img src=\"\" alt=\"Logo\" />",
			}),
			Entry("with a jpg custom logo", loadCustomLogoTableInput{
				logoPath:     "customDir/logo.jpg",
				expectedErr:  nil,
				expectedData: "<img src=\"\" alt=\"Logo\" />",
			}),
			Entry("with a jpeg custom logo", loadCustomLogoTableInput{
				logoPath:     "customDir/logo.jpeg",
				expectedErr:  nil,
				expectedData: "<img src=\"\" alt=\"Logo\" />",
			}),
			Entry("with a gif custom logo", loadCustomLogoTableInput{
				logoPath:     "customDir/logo.gif",
				expectedErr:  errors.New("unknown extension: \".gif\", supported extensions are .svg, .jpg, .jpeg and .png"),
				expectedData: "",
			}),
			// The expected message uses the Unix wording for a missing file;
			// it presumably differs on Windows.
			Entry("when the logo does not exist", loadCustomLogoTableInput{
				logoPath:     "unknown.svg",
				expectedErr:  errors.New("could not read logo file: open unknown.svg: no such file or directory"),
				expectedData: "",
			}),
		)
	})
})