diff --git a/oauthproxy_test.go b/oauthproxy_test.go index d0fd9481..9f60439b 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -404,6 +404,78 @@ func (tp *TestProvider) ValidateSessionState(ctx context.Context, session *sessi return tp.ValidToken } +func Test_redeemCode(t *testing.T) { + opts := baseTestOptions() + err := validation.Validate(opts) + assert.NoError(t, err) + + proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + if err != nil { + t.Fatal(err) + } + + _, err = proxy.redeemCode(context.Background(), "www.example.com", "") + assert.Error(t, err) +} + +func Test_enrichSession(t *testing.T) { + const ( + sessionUser = "Mr Session" + sessionEmail = "session@example.com" + providerEmail = "provider@example.com" + ) + + testCases := map[string]struct { + session *sessions.SessionState + expectedUser string + expectedEmail string + }{ + "Session already has enrichable fields": { + session: &sessions.SessionState{ + User: sessionUser, + Email: sessionEmail, + }, + expectedUser: sessionUser, + expectedEmail: sessionEmail, + }, + "Session is missing Email and GetEmailAddress is implemented": { + session: &sessions.SessionState{ + User: sessionUser, + }, + expectedUser: sessionUser, + expectedEmail: providerEmail, + }, + "Session is missing User and GetUserName is not implemented": { + session: &sessions.SessionState{ + Email: sessionEmail, + }, + expectedUser: "", + expectedEmail: sessionEmail, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + opts := baseTestOptions() + err := validation.Validate(opts) + assert.NoError(t, err) + + // intentionally set after validation.Validate(opts) since it will clobber + // our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider` + opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)) + proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + if err != nil { + t.Fatal(err) + } + + err = proxy.enrichSession(context.Background(), tc.session) + assert.NoError(t, err) + assert.Equal(t, tc.expectedUser, tc.session.User) + assert.Equal(t, tc.expectedEmail, tc.session.Email) + }) + } +} + func TestBasicAuthPassword(t *testing.T) { providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Printf("%#v", r)