diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 9add52ed..001d7347 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -587,6 +587,53 @@ func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { return rw.Code, rw.Body.String() } +type AlwaysSuccessfulValidator struct { +} + +func (AlwaysSuccessfulValidator) Validate(user, password string) bool { + return true +} + +func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) { + userGroups := []string{"somegroup", "someothergroup"} + + opts := baseTestOptions() + opts.HtpasswdUserGroups = userGroups + err := validation.Validate(opts) + if err != nil { + t.Fatal(err) + } + + proxy, err := NewOAuthProxy(opts, func(email string) bool { + return true + }) + if err != nil { + t.Fatal(err) + } + proxy.basicAuthValidator = AlwaysSuccessfulValidator{} + + rw := httptest.NewRecorder() + formData := url.Values{} + formData.Set("username", "someuser") + formData.Set("password", "somepass") + signInReq, _ := http.NewRequest(http.MethodPost, "/oauth2/sign_in", strings.NewReader(formData.Encode())) + signInReq.Header.Add("Content-Type", "application/x-www-form-urlencoded") + proxy.ServeHTTP(rw, signInReq) + + assert.Equal(t, http.StatusFound, rw.Code) + + req, _ := http.NewRequest(http.MethodGet, "/something", strings.NewReader(formData.Encode())) + for _, c := range rw.Result().Cookies() { + req.AddCookie(c) + } + + s, err := proxy.sessionStore.Load(req) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, userGroups, s.Groups) +} + func TestSignInPageIncludesTargetRedirect(t *testing.T) { sipTest, err := NewSignInPageTest(false) if err != nil {