From 2b9e1bbba0998095cea6cbb8dd4625d93fc3d6fc Mon Sep 17 00:00:00 2001
From: Nick Meves <nick.meves@greenhouse.io>
Date: Sun, 27 Sep 2020 11:46:29 -0700
Subject: [PATCH] Add EnrichSessionState as main post-Redeem session updater

---
 CHANGELOG.md                       |  2 ++
 oauthproxy.go                      |  7 ++++---
 oauthproxy_test.go                 |  8 ++++----
 providers/provider_default.go      | 18 +++++++++++++-----
 providers/provider_default_test.go |  6 ++++++
 providers/providers.go             |  3 +++
 6 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index eaca5c86..5dfc03e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,9 +26,11 @@
 ## Changes since v6.1.1
 
 - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
+- [#767](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves)
 - [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed)
 - [#753](https://github.com/oauth2-proxy/oauth2-proxy/pull/753) Pass resource parameter in login url (@codablock)
 - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) Add `--skip-auth-route` configuration option for `METHOD=pathRegex` based allowlists (@NickMeves)
+- [#767](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves)
 - [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Stop accepting legacy SHA1 signed cookies (@NickMeves)
 - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) Validate Redis configuration options at startup (@NickMeves)
 - [#791](https://github.com/oauth2-proxy/oauth2-proxy/pull/791) Remove GetPreferredUsername method from provider interface (@NickMeves)
diff --git a/oauthproxy.go b/oauthproxy.go
index 9f80f643..db18e660 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -360,7 +360,7 @@ func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessio
 	return s, nil
 }
 
-func (p *OAuthProxy) enrichSession(ctx context.Context, s *sessionsapi.SessionState) error {
+func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
 	var err error
 	if s.Email == "" {
 		s.Email, err = p.provider.GetEmailAddress(ctx, s)
@@ -374,7 +374,8 @@ func (p *OAuthProxy) enrichSession(ctx context.Context, s *sessionsapi.SessionSt
 			return err
 		}
 	}
-	return nil
+
+	return p.provider.EnrichSessionState(ctx, s)
 }
 
 // MakeCSRFCookie creates a cookie for CSRF
@@ -831,7 +832,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	err = p.enrichSession(req.Context(), session)
+	err = p.enrichSessionState(req.Context(), session)
 	if err != nil {
 		logger.Errorf("Error creating session during OAuth2 callback: %v", err)
 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
diff --git a/oauthproxy_test.go b/oauthproxy_test.go
index 9f60439b..46ee24b8 100644
--- a/oauthproxy_test.go
+++ b/oauthproxy_test.go
@@ -396,11 +396,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
 	}
 }
 
-func (tp *TestProvider) GetEmailAddress(ctx context.Context, session *sessions.SessionState) (string, error) {
+func (tp *TestProvider) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, error) {
 	return tp.EmailAddress, nil
 }
 
-func (tp *TestProvider) ValidateSessionState(ctx context.Context, session *sessions.SessionState) bool {
+func (tp *TestProvider) ValidateSessionState(_ context.Context, _ *sessions.SessionState) bool {
 	return tp.ValidToken
 }
 
@@ -468,7 +468,7 @@ func Test_enrichSession(t *testing.T) {
 				t.Fatal(err)
 			}
 
-			err = proxy.enrichSession(context.Background(), tc.session)
+			err = proxy.enrichSessionState(context.Background(), tc.session)
 			assert.NoError(t, err)
 			assert.Equal(t, tc.expectedUser, tc.session.User)
 			assert.Equal(t, tc.expectedEmail, tc.session.Email)
@@ -1955,7 +1955,7 @@ func TestClearSingleCookie(t *testing.T) {
 type NoOpKeySet struct {
 }
 
-func (NoOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
+func (NoOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
 	splitStrings := strings.Split(jwt, ".")
 	payloadString := splitStrings[1]
 	return base64.RawURLEncoding.DecodeString(payloadString)
diff --git a/providers/provider_default.go b/providers/provider_default.go
index 4e96f0e0..87bef08e 100644
--- a/providers/provider_default.go
+++ b/providers/provider_default.go
@@ -87,21 +87,29 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
 }
 
 // GetEmailAddress returns the Account email address
-func (p *ProviderData) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
+// DEPRECATED: Migrate to EnrichSessionState
+func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, error) {
 	return "", ErrNotImplemented
 }
 
 // GetUserName returns the Account username
-func (p *ProviderData) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
+// DEPRECATED: Migrate to EnrichSessionState
+func (p *ProviderData) GetUserName(_ context.Context, _ *sessions.SessionState) (string, error) {
 	return "", ErrNotImplemented
 }
 
 // ValidateGroup validates that the provided email exists in the configured provider
 // email group(s).
-func (p *ProviderData) ValidateGroup(email string) bool {
+func (p *ProviderData) ValidateGroup(_ string) bool {
 	return true
 }
 
+// EnrichSessionState is called after Redeem to allow providers to enrich session fields
+// such as User, Email, Groups with provider specific API calls.
+func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
+	return nil
+}
+
 // ValidateSessionState validates the AccessToken
 func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
 	return validateToken(ctx, p, s.AccessToken, nil)
@@ -109,12 +117,12 @@ func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.Ses
 
 // RefreshSessionIfNeeded should refresh the user's session if required and
 // do nothing if a refresh is not required
-func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
+func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
 	return false, nil
 }
 
 // CreateSessionStateFromBearerToken should be implemented to allow providers
 // to convert ID tokens into sessions
-func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
+func (p *ProviderData) CreateSessionStateFromBearerToken(_ context.Context, _ string, _ *oidc.IDToken) (*sessions.SessionState, error) {
 	return nil, ErrNotImplemented
 }
diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go
index 8597ac66..f04fe607 100644
--- a/providers/provider_default_test.go
+++ b/providers/provider_default_test.go
@@ -47,3 +47,9 @@ func TestAcrValuesConfigured(t *testing.T) {
 	result := p.GetLoginURL("https://my.test.app/oauth", "")
 	assert.Contains(t, result, "acr_values=testValue")
 }
+
+func TestEnrichSessionState(t *testing.T) {
+	p := &ProviderData{}
+	s := &sessions.SessionState{}
+	assert.NoError(t, p.EnrichSessionState(context.Background(), s))
+}
diff --git a/providers/providers.go b/providers/providers.go
index e92b3293..5e43613e 100644
--- a/providers/providers.go
+++ b/providers/providers.go
@@ -10,10 +10,13 @@ import (
 // Provider represents an upstream identity provider implementation
 type Provider interface {
 	Data() *ProviderData
+	// DEPRECATED: Migrate to EnrichSessionState
 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
+	// DEPRECATED: Migrate to EnrichSessionState
 	GetUserName(ctx context.Context, s *sessions.SessionState) (string, error)
 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
 	ValidateGroup(string) bool
+	EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
 	ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
 	GetLoginURL(redirectURI, finalRedirect string) string
 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)