From 7fa6d2d024b4a46d477bc1a4da58e377ee60fbce Mon Sep 17 00:00:00 2001
From: Nick Meves <nick.meves@greenhouse.io>
Date: Sat, 6 Mar 2021 15:33:40 -0800
Subject: [PATCH] Manage session time fields centrally

---
 oauthproxy.go                        | 15 +++++++++++---
 pkg/apis/sessions/session_state.go   | 29 ++++++++++++++++++++++++---
 pkg/middleware/stored_session.go     |  3 +--
 pkg/sessions/cookie/session_store.go |  3 +--
 pkg/sessions/persistence/manager.go  |  3 +--
 providers/azure.go                   | 30 +++++++++++-----------------
 providers/gitlab.go                  | 12 ++++++-----
 providers/google.go                  | 25 +++++++++++------------
 providers/logingov.go                | 17 ++++++++--------
 providers/oidc.go                    |  9 +++++----
 providers/provider_default.go        |  9 ++++++---
 11 files changed, 91 insertions(+), 64 deletions(-)

diff --git a/oauthproxy.go b/oauthproxy.go
index b0c94eb0..c3a5693d 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e
 	if err != nil {
 		return nil, err
 	}
+
+	// Force setting these in case the Provider didn't
+	if s.CreatedAt == nil {
+		s.CreatedAtNow()
+	}
+	if s.ExpiresOn == nil {
+		s.ExpiresIn(p.CookieOptions.Expire)
+	}
+
 	return s, nil
 }
 
@@ -861,9 +870,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
 
 // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 var noCacheHeaders = map[string]string{
-	"Expires":         time.Unix(0, 0).Format(time.RFC1123),
-	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0",
-	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
+	"Expires":        time.Unix(0, 0).Format(time.RFC1123),
+	"Cache-Control":  "no-cache, no-store, must-revalidate, max-age=0",
+	"X-Accel-Expire": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
 }
 
 // prepareNoCache prepares headers for preventing browser caching.
diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go
index e1ee4a6c..9e77609e 100644
--- a/pkg/apis/sessions/session_state.go
+++ b/pkg/apis/sessions/session_state.go
@@ -11,6 +11,7 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
 	"github.com/pierrec/lz4"
 	"github.com/vmihailenco/msgpack/v4"
@@ -32,7 +33,8 @@ type SessionState struct {
 	Groups            []string `msgpack:"g,omitempty"`
 	PreferredUsername string   `msgpack:"pu,omitempty"`
 
-	Lock Lock `msgpack:"-"`
+	Clock clock.Clock `msgpack:"-"`
+	Lock  Lock        `msgpack:"-"`
 }
 
 func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
@@ -63,9 +65,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
 	return s.Lock.Peek(ctx)
 }
 
+// CreatedAtNow sets a SessionState's CreatedAt to now
+func (s *SessionState) CreatedAtNow() {
+	now := s.Clock.Now()
+	s.CreatedAt = &now
+}
+
+// SetExpiresOn sets an expiration
+func (s *SessionState) SetExpiresOn(exp time.Time) {
+	s.ExpiresOn = &exp
+}
+
+// ExpiresIn sets an expiration a certain duration from CreatedAt.
+// CreatedAt will be set to time.Now if it is unset.
+func (s *SessionState) ExpiresIn(d time.Duration) {
+	if s.CreatedAt == nil {
+		s.CreatedAtNow()
+	}
+	exp := s.CreatedAt.Add(d)
+	s.ExpiresOn = &exp
+}
+
 // IsExpired checks whether the session has expired
 func (s *SessionState) IsExpired() bool {
-	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
+	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) {
 		return true
 	}
 	return false
@@ -74,7 +97,7 @@ func (s *SessionState) IsExpired() bool {
 // Age returns the age of a session
 func (s *SessionState) Age() time.Duration {
 	if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
-		return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
+		return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt)
 	}
 	return 0
 }
diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go
index b3737581..9f69ba64 100644
--- a/pkg/middleware/stored_session.go
+++ b/pkg/middleware/stored_session.go
@@ -142,8 +142,7 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
 	}
 
 	// If we refreshed, update the `CreatedAt` time to reset the refresh timer
-	// TODO: Implement
-	// session.CreatedAtNow()
+	session.CreatedAtNow()
 
 	// Because the session was refreshed, make sure to save it
 	err = s.store.Save(rw, req, session)
diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go
index ce51ed07..1b3c12de 100644
--- a/pkg/sessions/cookie/session_store.go
+++ b/pkg/sessions/cookie/session_store.go
@@ -36,8 +36,7 @@ type SessionStore struct {
 // within Cookies set on the HTTP response writer
 func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
 	if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
-		now := time.Now()
-		ss.CreatedAt = &now
+		ss.CreatedAtNow()
 	}
 	value, err := s.cookieForSession(ss)
 	if err != nil {
diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go
index 49225171..3215b257 100644
--- a/pkg/sessions/persistence/manager.go
+++ b/pkg/sessions/persistence/manager.go
@@ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager {
 // from the persistent data store.
 func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
 	if s.CreatedAt == nil || s.CreatedAt.IsZero() {
-		now := time.Now()
-		s.CreatedAt = &now
+		s.CreatedAtNow()
 	}
 
 	tckt, err := decodeTicketFromRequest(req, m.Options)
diff --git a/providers/azure.go b/providers/azure.go
index f66d3764..46d7e302 100644
--- a/providers/azure.go
+++ b/providers/azure.go
@@ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
 		return nil, err
 	}
 
-	created := time.Now()
-	expires := time.Unix(jsonResponse.ExpiresOn, 0)
-
 	session := &sessions.SessionState{
 		AccessToken:  jsonResponse.AccessToken,
 		IDToken:      jsonResponse.IDToken,
-		CreatedAt:    &created,
-		ExpiresOn:    &expires,
 		RefreshToken: jsonResponse.RefreshToken,
 	}
+	session.CreatedAtNow()
+	session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
 
 	email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
 
@@ -239,10 +236,9 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st
 	return email, nil
 }
 
-// RefreshSessionIfNeeded checks if the session has expired and uses the
-// RefreshToken to fetch a new ID token if required
-func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
-	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
+// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
+func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
+	if s == nil || s.RefreshToken == "" {
 		return false, nil
 	}
 
@@ -257,7 +253,7 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.
 	return true, nil
 }
 
-func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
+func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
 	params := url.Values{}
 	params.Add("client_id", p.ClientID)
 	params.Add("client_secret", p.ClientSecret)
@@ -271,25 +267,23 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
 		IDToken      string `json:"id_token"`
 	}
 
-	err = requests.New(p.RedeemURL.String()).
+	err := requests.New(p.RedeemURL.String()).
 		WithContext(ctx).
 		WithMethod("POST").
 		WithBody(bytes.NewBufferString(params.Encode())).
 		SetHeader("Content-Type", "application/x-www-form-urlencoded").
 		Do().
 		UnmarshalInto(&jsonResponse)
-
 	if err != nil {
-		return
+		return err
 	}
 
-	now := time.Now()
-	expires := time.Unix(jsonResponse.ExpiresOn, 0)
 	s.AccessToken = jsonResponse.AccessToken
 	s.IDToken = jsonResponse.IDToken
 	s.RefreshToken = jsonResponse.RefreshToken
-	s.CreatedAt = &now
-	s.ExpiresOn = &expires
+
+	s.CreatedAtNow()
+	s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
 
 	email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
 
@@ -312,7 +306,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
 		}
 	}
 
-	return
+	return nil
 }
 
 func makeAzureHeader(accessToken string) http.Header {
diff --git a/providers/gitlab.go b/providers/gitlab.go
index ca9a8bf2..a2b11df7 100644
--- a/providers/gitlab.go
+++ b/providers/gitlab.go
@@ -259,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token)
 		}
 	}
 
-	created := time.Now()
-	return &sessions.SessionState{
+	ss := &sessions.SessionState{
 		AccessToken:  token.AccessToken,
 		IDToken:      getIDToken(token),
 		RefreshToken: token.RefreshToken,
-		CreatedAt:    &created,
-		ExpiresOn:    &idToken.Expiry,
-	}, nil
+	}
+
+	ss.CreatedAtNow()
+	ss.SetExpiresOn(idToken.Expiry)
+
+	return ss, nil
 }
 
 // ValidateSession checks that the session's IDToken is still valid
diff --git a/providers/google.go b/providers/google.go
index 49eae1c1..0cfd3e1c 100644
--- a/providers/google.go
+++ b/providers/google.go
@@ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
 		return nil, err
 	}
 
-	created := time.Now()
-	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
-
-	return &sessions.SessionState{
+	ss := &sessions.SessionState{
 		AccessToken:  jsonResponse.AccessToken,
 		IDToken:      jsonResponse.IDToken,
-		CreatedAt:    &created,
-		ExpiresOn:    &expires,
 		RefreshToken: jsonResponse.RefreshToken,
 		Email:        c.Email,
 		User:         c.Subject,
-	}, nil
+	}
+	ss.CreatedAtNow()
+	ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
+
+	return ss, nil
 }
 
 // EnrichSession checks the listed Google Groups configured and adds any
 // that the user is a member of to session.Groups.
-func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
+func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error {
 	// TODO (@NickMeves) - Move to pure EnrichSession logic and stop
 	// reusing legacy `groupValidator`.
 	//
@@ -272,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session
 		return false, nil
 	}
 
-	newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken)
+	newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken)
 	if err != nil {
 		return false, err
 	}
@@ -285,12 +284,12 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session
 		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
 	}
 
-	origExpiration := s.ExpiresOn
-	expires := time.Now().Add(duration).Truncate(time.Second)
 	s.AccessToken = newToken
 	s.IDToken = newIDToken
-	s.ExpiresOn = &expires
-	logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
+
+	s.CreatedAtNow()
+	s.ExpiresIn(ttl)
+
 	return true, nil
 }
 
diff --git a/providers/logingov.go b/providers/logingov.go
index 0f625208..43f361f3 100644
--- a/providers/logingov.go
+++ b/providers/logingov.go
@@ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
 }
 
 // Redeem exchanges the OAuth2 authentication token for an ID token
-func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
+func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) {
 	if code == "" {
 		return nil, ErrMissingCode
 	}
@@ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
 		return nil, err
 	}
 
-	created := time.Now()
-	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
-
-	// Store the data that we found in the session state
-	return &sessions.SessionState{
+	session := &sessions.SessionState{
 		AccessToken: jsonResponse.AccessToken,
 		IDToken:     jsonResponse.IDToken,
-		CreatedAt:   &created,
-		ExpiresOn:   &expires,
 		Email:       email,
-	}, nil
+	}
+
+	session.CreatedAtNow()
+	session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
+
+	return session, nil
 }
 
 // GetLoginURL overrides GetLoginURL to add login.gov parameters
diff --git a/providers/oidc.go b/providers/oidc.go
index 3e1e79a8..2cbbd009 100644
--- a/providers/oidc.go
+++ b/providers/oidc.go
@@ -226,7 +226,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
 	ss.AccessToken = token
 	ss.IDToken = token
 	ss.RefreshToken = ""
-	ss.ExpiresOn = &idToken.Expiry
+
+	ss.CreatedAtNow()
+	ss.SetExpiresOn(idToken.Expiry)
 
 	return ss, nil
 }
@@ -256,9 +258,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
 	ss.RefreshToken = token.RefreshToken
 	ss.IDToken = getIDToken(token)
 
-	created := time.Now()
-	ss.CreatedAt = &created
-	ss.ExpiresOn = &token.Expiry
+	ss.CreatedAtNow()
+	ss.SetExpiresOn(token.Expiry)
 
 	return ss, nil
 }
diff --git a/providers/provider_default.go b/providers/provider_default.go
index be57f0e5..0a62c240 100644
--- a/providers/provider_default.go
+++ b/providers/provider_default.go
@@ -6,7 +6,6 @@ import (
 	"errors"
 	"fmt"
 	"net/url"
-	"time"
 
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
@@ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s
 	if err != nil {
 		return nil, err
 	}
+	// TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration
 	if token := values.Get("access_token"); token != "" {
-		created := time.Now()
-		return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
+		ss := &sessions.SessionState{
+			AccessToken: token,
+		}
+		ss.CreatedAtNow()
+		return ss, nil
 	}
 
 	return nil, fmt.Errorf("no access token found %s", result.Body())