mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-08 04:03:58 +02:00
Merge pull request #82 from 18F/sign-in-redirect
Redirect to / when /oauth2/sign_in accessed
This commit is contained in:
commit
b0f0409f2b
@ -307,6 +307,11 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
|
||||
p.ClearCookie(rw, req)
|
||||
rw.WriteHeader(code)
|
||||
|
||||
redirect_url := req.URL.RequestURI()
|
||||
if redirect_url == signInPath {
|
||||
redirect_url = "/"
|
||||
}
|
||||
|
||||
t := struct {
|
||||
ProviderName string
|
||||
SignInMessage string
|
||||
@ -317,7 +322,7 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
|
||||
ProviderName: p.provider.Data().ProviderName,
|
||||
SignInMessage: p.SignInMessage,
|
||||
CustomLogin: p.displayCustomLoginForm(),
|
||||
Redirect: req.URL.RequestURI(),
|
||||
Redirect: redirect_url,
|
||||
Version: VERSION,
|
||||
}
|
||||
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -237,3 +238,69 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
||||
assert.Equal(t, 200, code)
|
||||
assert.Equal(t, "No access token found.", payload)
|
||||
}
|
||||
|
||||
type SignInPageTest struct {
|
||||
opts *Options
|
||||
proxy *OauthProxy
|
||||
sign_in_regexp *regexp.Regexp
|
||||
}
|
||||
|
||||
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
|
||||
|
||||
func NewSignInPageTest() *SignInPageTest {
|
||||
var sip_test SignInPageTest
|
||||
|
||||
sip_test.opts = NewOptions()
|
||||
sip_test.opts.Upstreams = append(sip_test.opts.Upstreams, "unused")
|
||||
sip_test.opts.CookieSecret = "foobar"
|
||||
sip_test.opts.ClientID = "bazquux"
|
||||
sip_test.opts.ClientSecret = "xyzzyplugh"
|
||||
sip_test.opts.Validate()
|
||||
|
||||
sip_test.proxy = NewOauthProxy(sip_test.opts, func(email string) bool {
|
||||
return true
|
||||
})
|
||||
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
|
||||
|
||||
return &sip_test
|
||||
}
|
||||
|
||||
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
||||
rw := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
||||
sip_test.proxy.ServeHTTP(rw, req)
|
||||
return rw.Code, rw.Body.String()
|
||||
}
|
||||
|
||||
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
||||
sip_test := NewSignInPageTest()
|
||||
const endpoint = "/some/random/endpoint"
|
||||
|
||||
code, body := sip_test.GetEndpoint(endpoint)
|
||||
assert.Equal(t, 403, code)
|
||||
|
||||
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
|
||||
if match == nil {
|
||||
t.Fatal("Did not find pattern in body: " +
|
||||
signInRedirectPattern + "\nBody:\n" + body)
|
||||
}
|
||||
if match[1] != endpoint {
|
||||
t.Fatal(`expected redirect to "` + endpoint +
|
||||
`", but was "` + match[1] + `"`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
||||
sip_test := NewSignInPageTest()
|
||||
code, body := sip_test.GetEndpoint("/oauth2/sign_in")
|
||||
assert.Equal(t, 200, code)
|
||||
|
||||
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
|
||||
if match == nil {
|
||||
t.Fatal("Did not find pattern in body: " +
|
||||
signInRedirectPattern + "\nBody:\n" + body)
|
||||
}
|
||||
if match[1] != "/" {
|
||||
t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user