1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-05-23 22:40:31 +02:00

Merge pull request #1086 from oauth2-proxy/early-refresh

Convert RefreshSessionIfNeeded into RefreshSession
This commit is contained in:
Nick Meves 2021-06-22 17:13:14 -07:00 committed by GitHub
commit 16a9893a19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 302 additions and 248 deletions

View File

@ -4,10 +4,16 @@
## Important Notes ## Important Notes
- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session
deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade
to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret`
value and force all sessions to reauthenticate.
## Breaking Changes ## Breaking Changes
## Changes since v7.1.3 ## Changes since v7.1.3
- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves)
- [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed) - [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed)
- [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed) - [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed)
- [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi) - [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi)

View File

@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
} }
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore, SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh, RefreshPeriod: opts.Cookie.Refresh,
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, RefreshSession: opts.GetProvider().RefreshSession,
ValidateSessionState: opts.GetProvider().ValidateSession, ValidateSession: opts.GetProvider().ValidateSession,
})) }))
return chain return chain
@ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Force setting these in case the Provider didn't
if s.CreatedAt == nil {
s.CreatedAtNow()
}
if s.ExpiresOn == nil {
s.ExpiresIn(p.CookieOptions.Expire)
}
return s, nil return s, nil
} }

View File

@ -3,14 +3,12 @@ package sessions
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"reflect"
"time" "time"
"unicode/utf8"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/pierrec/lz4" "github.com/pierrec/lz4"
"github.com/vmihailenco/msgpack/v4" "github.com/vmihailenco/msgpack/v4"
@ -32,7 +30,9 @@ type SessionState struct {
Groups []string `msgpack:"g,omitempty"` Groups []string `msgpack:"g,omitempty"`
PreferredUsername string `msgpack:"pu,omitempty"` PreferredUsername string `msgpack:"pu,omitempty"`
Lock Lock `msgpack:"-"` // Internal helpers, not serialized
Clock clock.Clock `msgpack:"-"`
Lock Lock `msgpack:"-"`
} }
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
@ -63,9 +63,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
return s.Lock.Peek(ctx) return s.Lock.Peek(ctx)
} }
// CreatedAtNow sets a SessionState's CreatedAt to now
func (s *SessionState) CreatedAtNow() {
now := s.Clock.Now()
s.CreatedAt = &now
}
// SetExpiresOn sets an expiration
func (s *SessionState) SetExpiresOn(exp time.Time) {
s.ExpiresOn = &exp
}
// ExpiresIn sets an expiration a certain duration from CreatedAt.
// CreatedAt will be set to time.Now if it is unset.
func (s *SessionState) ExpiresIn(d time.Duration) {
if s.CreatedAt == nil {
s.CreatedAtNow()
}
exp := s.CreatedAt.Add(d)
s.ExpiresOn = &exp
}
// IsExpired checks whether the session has expired // IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool { func (s *SessionState) IsExpired() bool {
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) {
return true return true
} }
return false return false
@ -74,7 +95,7 @@ func (s *SessionState) IsExpired() bool {
// Age returns the age of a session // Age returns the age of a session
func (s *SessionState) Age() time.Duration { func (s *SessionState) Age() time.Duration {
if s.CreatedAt != nil && !s.CreatedAt.IsZero() { if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
return time.Now().Truncate(time.Second).Sub(*s.CreatedAt) return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt)
} }
return 0 return 0
} }
@ -177,11 +198,6 @@ func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*Ses
return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
} }
err = ss.validate()
if err != nil {
return nil, err
}
return &ss, nil return &ss, nil
} }
@ -235,35 +251,3 @@ func lz4Decompress(compressed []byte) ([]byte, error) {
return payload, nil return payload, nil
} }
// validate ensures the decoded session is non-empty and contains valid data
//
// Non-empty check is needed due to ensure the non-authenticated AES-CFB
// decryption doesn't result in garbage data that collides with a valid
// MessagePack header bytes (which MessagePack will unmarshal to an empty
// default SessionState). <1% chance, but observed with random test data.
//
// UTF-8 check ensures the strings are valid and not raw bytes overloaded
// into Latin-1 encoding. The occurs when legacy unencrypted fields are
// decrypted with AES-CFB which results in random bytes.
func (s *SessionState) validate() error {
for _, field := range []string{
s.User,
s.Email,
s.PreferredUsername,
s.AccessToken,
s.IDToken,
s.RefreshToken,
} {
if !utf8.ValidString(field) {
return errors.New("invalid non-UTF8 field in session")
}
}
empty := new(SessionState)
if reflect.DeepEqual(*s, *empty) {
return errors.New("invalid empty session unmarshalled")
}
return nil
}

View File

@ -16,6 +16,30 @@ func timePtr(t time.Time) *time.Time {
return &t return &t
} }
func TestCreatedAtNow(t *testing.T) {
g := NewWithT(t)
ss := &SessionState{}
now := time.Unix(1234567890, 0)
ss.Clock.Set(now)
ss.CreatedAtNow()
g.Expect(*ss.CreatedAt).To(Equal(now))
}
func TestExpiresIn(t *testing.T) {
g := NewWithT(t)
ss := &SessionState{}
now := time.Unix(1234567890, 0)
ss.Clock.Set(now)
ttl := time.Duration(743) * time.Second
ss.ExpiresIn(ttl)
g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl)))
}
func TestString(t *testing.T) { func TestString(t *testing.T) {
g := NewWithT(t) g := NewWithT(t)
created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z") created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z")

View File

@ -63,13 +63,10 @@ func Reset() *clockapi.Mock {
// package. // package.
type Clock struct { type Clock struct {
mock *clockapi.Mock mock *clockapi.Mock
sync.Mutex
} }
// Set sets the Clock to a clock.Mock at the given time.Time // Set sets the Clock to a clock.Mock at the given time.Time
func (c *Clock) Set(t time.Time) { func (c *Clock) Set(t time.Time) {
c.Lock()
defer c.Unlock()
if c.mock == nil { if c.mock == nil {
c.mock = clockapi.NewMock() c.mock = clockapi.NewMock()
} }
@ -79,8 +76,6 @@ func (c *Clock) Set(t time.Time) {
// Add moves clock forward time.Duration if it is mocked. It will error // Add moves clock forward time.Duration if it is mocked. It will error
// if the clock is not mocked. // if the clock is not mocked.
func (c *Clock) Add(d time.Duration) error { func (c *Clock) Add(d time.Duration) error {
c.Lock()
defer c.Unlock()
if c.mock == nil { if c.mock == nil {
return errors.New("clock not mocked") return errors.New("clock not mocked")
} }
@ -91,8 +86,6 @@ func (c *Clock) Add(d time.Duration) error {
// Reset removes local clock.Mock. Returns any existing Mock if set in case // Reset removes local clock.Mock. Returns any existing Mock if set in case
// lingering time operations are attached to it. // lingering time operations are attached to it.
func (c *Clock) Reset() *clockapi.Mock { func (c *Clock) Reset() *clockapi.Mock {
c.Lock()
defer c.Unlock()
existing := c.mock existing := c.mock
c.mock = nil c.mock = nil
return existing return existing

View File

@ -11,25 +11,26 @@ import (
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
) )
// StoredSessionLoaderOptions cotnains all of the requirements to construct // StoredSessionLoaderOptions contains all of the requirements to construct
// a stored session loader. // a stored session loader.
// All options must be provided. // All options must be provided.
type StoredSessionLoaderOptions struct { type StoredSessionLoaderOptions struct {
// Session storage basckend // Session storage backend
SessionStore sessionsapi.SessionStore SessionStore sessionsapi.SessionStore
// How often should sessions be refreshed // How often should sessions be refreshed
RefreshPeriod time.Duration RefreshPeriod time.Duration
// Provider based sesssion refreshing // Provider based session refreshing
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
// Provider based session validation. // Provider based session validation.
// If the sesssion is older than `RefreshPeriod` but the provider doesn't // If the sesssion is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation. // refresh it, we must re-validate using this validation.
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool ValidateSession func(context.Context, *sessionsapi.SessionState) bool
} }
// NewStoredSessionLoader creates a new storedSessionLoader which loads // NewStoredSessionLoader creates a new storedSessionLoader which loads
@ -38,10 +39,10 @@ type StoredSessionLoaderOptions struct {
// If a session was loader by a previous handler, it will not be replaced. // If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
ss := &storedSessionLoader{ ss := &storedSessionLoader{
store: opts.SessionStore, store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod, refreshPeriod: opts.RefreshPeriod,
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, sessionRefresher: opts.RefreshSession,
validateSessionState: opts.ValidateSessionState, sessionValidator: opts.ValidateSession,
} }
return ss.loadSession return ss.loadSession
} }
@ -49,10 +50,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
// storedSessionLoader is responsible for loading sessions from cookie // storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store. // identified sessions in the session store.
type storedSessionLoader struct { type storedSessionLoader struct {
store sessionsapi.SessionStore store sessionsapi.SessionStore
refreshPeriod time.Duration refreshPeriod time.Duration
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
validateSessionState func(context.Context, *sessionsapi.SessionState) bool sessionValidator func(context.Context, *sessionsapi.SessionState) bool
} }
// loadSession attempts to load a session as identified by the request cookies. // loadSession attempts to load a session as identified by the request cookies.
@ -108,49 +109,59 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
// refreshSessionIfNeeded will attempt to refresh a session if the session // refreshSessionIfNeeded will attempt to refresh a session if the session
// is older than the refresh period. // is older than the refresh period.
// It is assumed that if the provider refreshes the session, the session is now // Success or fail, we will then validate the session.
// valid.
// If the session requires refreshing but the provider does not refresh it,
// we must validate the session to ensure that the returned session is still
// valid.
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
// Refresh is disabled or the session is not old enough, do nothing // Refresh is disabled or the session is not old enough, do nothing
return nil return nil
} }
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
refreshed, err := s.refreshSessionWithProvider(rw, req, session) err := s.refreshSession(rw, req, session)
if err != nil { if err != nil {
return err // If a preemptive refresh fails, we still keep the session
// if validateSession succeeds.
logger.Errorf("Unable to refresh session: %v", err)
} }
if !refreshed { // Validate all sessions after any Redeem/Refresh operation (fail or success)
// Session wasn't refreshed, so make sure it's still valid return s.validateSession(req.Context(), session)
return s.validateSession(req.Context(), session)
}
return nil
} }
// refreshSessionWithProvider attempts to refresh the sessinon with the provider // refreshSession attempts to refresh the session with the provider
// and will save the session if it was updated. // and will save the session if it was updated.
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) refreshed, err := s.sessionRefresher(req.Context(), session)
if err != nil { if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
return false, fmt.Errorf("error refreshing access token: %v", err) return fmt.Errorf("error refreshing tokens: %v", err)
} }
if !refreshed { // HACK:
return false, nil // Providers that don't implement `RefreshSession` use the default
// implementation which returns `ErrNotImplemented`.
// Pretend it refreshed to reset the refresh timer so that `ValidateSession`
// isn't triggered every subsequent request and is only called once during
// this request.
if errors.Is(err, providers.ErrNotImplemented) {
refreshed = true
} }
// Session not refreshed, nothing to persist.
if !refreshed {
return nil
}
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
// (In case underlying provider implementations forget)
session.CreatedAtNow()
// Because the session was refreshed, make sure to save it // Because the session was refreshed, make sure to save it
err = s.store.Save(rw, req, session) err = s.store.Save(rw, req, session)
if err != nil { if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
return false, fmt.Errorf("error saving session: %v", err) return fmt.Errorf("error saving session: %v", err)
} }
return true, nil return nil
} }
// validateSession checks whether the session has expired and performs // validateSession checks whether the session has expired and performs
@ -161,7 +172,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
return errors.New("session is expired") return errors.New("session is expired")
} }
if !s.validateSessionState(ctx, session) { if !s.sessionValidator(ctx, session) {
return errors.New("session is invalid") return errors.New("session is invalid")
} }

View File

@ -10,6 +10,8 @@ import (
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -17,15 +19,17 @@ import (
var _ = Describe("Stored Session Suite", func() { var _ = Describe("Stored Session Suite", func() {
const ( const (
refresh = "Refresh" refresh = "Refresh"
noRefresh = "NoRefresh" noRefresh = "NoRefresh"
notImplemented = "NotImplemented"
) )
var ctx = context.Background() var ctx = context.Background()
Context("StoredSessionLoader", func() { Context("StoredSessionLoader", func() {
createdPast := time.Now().Add(-5 * time.Minute) now := time.Now()
createdFuture := time.Now().Add(5 * time.Minute) createdPast := now.Add(-5 * time.Minute)
createdFuture := now.Add(5 * time.Minute)
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken { switch ss.RefreshToken {
@ -85,6 +89,14 @@ var _ = Describe("Stored Session Suite", func() {
}, },
} }
BeforeEach(func() {
clock.Set(now)
})
AfterEach(func() {
clock.Reset()
})
type storedSessionLoaderTableInput struct { type storedSessionLoaderTableInput struct {
requestHeaders http.Header requestHeaders http.Header
existingSession *sessionsapi.SessionState existingSession *sessionsapi.SessionState
@ -109,10 +121,10 @@ var _ = Describe("Stored Session Suite", func() {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
opts := &StoredSessionLoaderOptions{ opts := &StoredSessionLoaderOptions{
SessionStore: in.store, SessionStore: in.store,
RefreshPeriod: in.refreshPeriod, RefreshPeriod: in.refreshPeriod,
RefreshSessionIfNeeded: in.refreshSession, RefreshSession: in.refreshSession,
ValidateSessionState: in.validateSession, ValidateSession: in.validateSession,
} }
// Create the handler with a next handler that will capture the session // Create the handler with a next handler that will capture the session
@ -208,6 +220,21 @@ var _ = Describe("Stored Session Suite", func() {
existingSession: nil, existingSession: nil,
expectedSession: &sessionsapi.SessionState{ expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed", RefreshToken: "Refreshed",
CreatedAt: &now,
ExpiresOn: &createdFuture,
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "RefreshError",
CreatedAt: &createdPast, CreatedAt: &createdPast,
ExpiresOn: &createdFuture, ExpiresOn: &createdFuture,
}, },
@ -216,7 +243,7 @@ var _ = Describe("Stored Session Suite", func() {
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, validateSession: defaultValidateFunc,
}), }),
Entry("when the provider refresh fails", storedSessionLoaderTableInput{ Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"}, "Cookie": []string{"_oauth2_proxy=RefreshError"},
}, },
@ -225,7 +252,7 @@ var _ = Describe("Stored Session Suite", func() {
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false },
}), }),
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
@ -261,18 +288,20 @@ var _ = Describe("Stored Session Suite", func() {
s := &storedSessionLoader{ s := &storedSessionLoader{
refreshPeriod: in.refreshPeriod, refreshPeriod: in.refreshPeriod,
store: &fakeSessionStore{}, store: &fakeSessionStore{},
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
refreshed = true refreshed = true
switch ss.RefreshToken { switch ss.RefreshToken {
case refresh: case refresh:
return true, nil return true, nil
case noRefresh: case noRefresh:
return false, nil return false, nil
case notImplemented:
return false, providers.ErrNotImplemented
default: default:
return false, errors.New("error refreshing session") return false, errors.New("error refreshing session")
} }
}, },
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
validated = true validated = true
return ss.AccessToken != "Invalid" return ss.AccessToken != "Invalid"
}, },
@ -326,7 +355,7 @@ var _ = Describe("Stored Session Suite", func() {
}, },
expectedErr: nil, expectedErr: nil,
expectRefreshed: true, expectRefreshed: true,
expectValidated: false, expectValidated: true,
}), }),
Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
@ -339,15 +368,25 @@ var _ = Describe("Stored Session Suite", func() {
expectRefreshed: true, expectRefreshed: true,
expectValidated: true, expectValidated: true,
}), }),
Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{
RefreshToken: notImplemented,
CreatedAt: &createdPast,
},
expectedErr: nil,
expectRefreshed: true,
expectValidated: true,
}),
Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{ session: &sessionsapi.SessionState{
RefreshToken: "RefreshError", RefreshToken: "RefreshError",
CreatedAt: &createdPast, CreatedAt: &createdPast,
}, },
expectedErr: errors.New("error refreshing access token: error refreshing session"), expectedErr: nil,
expectRefreshed: true, expectRefreshed: true,
expectValidated: false, expectValidated: true,
}), }),
Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
@ -364,12 +403,11 @@ var _ = Describe("Stored Session Suite", func() {
) )
}) })
Context("refreshSessionWithProvider", func() { Context("refreshSession", func() {
type refreshSessionWithProviderTableInput struct { type refreshSessionWithProviderTableInput struct {
session *sessionsapi.SessionState session *sessionsapi.SessionState
expectedErr error expectedErr error
expectRefreshed bool expectSaved bool
expectSaved bool
} }
now := time.Now() now := time.Now()
@ -388,12 +426,14 @@ var _ = Describe("Stored Session Suite", func() {
return nil return nil
}, },
}, },
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken { switch ss.RefreshToken {
case refresh: case refresh:
return true, nil return true, nil
case noRefresh: case noRefresh:
return false, nil return false, nil
case notImplemented:
return false, providers.ErrNotImplemented
default: default:
return false, errors.New("error refreshing session") return false, errors.New("error refreshing session")
} }
@ -402,30 +442,34 @@ var _ = Describe("Stored Session Suite", func() {
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) err := s.refreshSession(nil, req, in.session)
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
} else { } else {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
Expect(refreshed).To(Equal(in.expectRefreshed))
Expect(saved).To(Equal(in.expectSaved)) Expect(saved).To(Equal(in.expectSaved))
}, },
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{ session: &sessionsapi.SessionState{
RefreshToken: noRefresh, RefreshToken: noRefresh,
}, },
expectedErr: nil, expectedErr: nil,
expectRefreshed: false, expectSaved: false,
expectSaved: false,
}), }),
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{ session: &sessionsapi.SessionState{
RefreshToken: refresh, RefreshToken: refresh,
}, },
expectedErr: nil, expectedErr: nil,
expectRefreshed: true, expectSaved: true,
expectSaved: true, }),
Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: notImplemented,
},
expectedErr: nil,
expectSaved: true,
}), }),
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{ session: &sessionsapi.SessionState{
@ -433,18 +477,16 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &now, CreatedAt: &now,
ExpiresOn: &now, ExpiresOn: &now,
}, },
expectedErr: errors.New("error refreshing access token: error refreshing session"), expectedErr: errors.New("error refreshing tokens: error refreshing session"),
expectRefreshed: false, expectSaved: false,
expectSaved: false,
}), }),
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{ session: &sessionsapi.SessionState{
RefreshToken: refresh, RefreshToken: refresh,
AccessToken: "NoSave", AccessToken: "NoSave",
}, },
expectedErr: errors.New("error saving session: unable to save session"), expectedErr: errors.New("error saving session: unable to save session"),
expectRefreshed: false, expectSaved: true,
expectSaved: true,
}), }),
) )
}) })
@ -454,7 +496,7 @@ var _ = Describe("Stored Session Suite", func() {
BeforeEach(func() { BeforeEach(func() {
s = &storedSessionLoader{ s = &storedSessionLoader{
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
return ss.AccessToken == "Valid" return ss.AccessToken == "Valid"
}, },
} }

View File

@ -36,8 +36,7 @@ type SessionStore struct {
// within Cookies set on the HTTP response writer // within Cookies set on the HTTP response writer
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
now := time.Now() ss.CreatedAtNow()
ss.CreatedAt = &now
} }
value, err := s.cookieForSession(ss) value, err := s.cookieForSession(ss)
if err != nil { if err != nil {

View File

@ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager {
// from the persistent data store. // from the persistent data store.
func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
if s.CreatedAt == nil || s.CreatedAt.IsZero() { if s.CreatedAt == nil || s.CreatedAt.IsZero() {
now := time.Now() s.CreatedAtNow()
s.CreatedAt = &now
} }
tckt, err := decodeTicketFromRequest(req, m.Options) tckt, err := decodeTicketFromRequest(req, m.Options)

View File

@ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
return nil, err return nil, err
} }
created := time.Now()
expires := time.Unix(jsonResponse.ExpiresOn, 0)
session := &sessions.SessionState{ session := &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken, RefreshToken: jsonResponse.RefreshToken,
} }
session.CreatedAtNow()
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
@ -239,28 +236,29 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st
return email, nil return email, nil
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
// RefreshToken to fetch a new ID token if required func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.RefreshToken == "" {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }
origExpiration := s.ExpiresOn
err := p.redeemRefreshToken(ctx, s) err := p.redeemRefreshToken(ctx, s)
if err != nil { if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err) return false, fmt.Errorf("unable to redeem refresh token: %v", err)
} }
logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
return true, nil return true, nil
} }
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
clientSecret, err := p.GetClientSecret()
if err != nil {
return err
}
params := url.Values{} params := url.Values{}
params.Add("client_id", p.ClientID) params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret) params.Add("client_secret", clientSecret)
params.Add("refresh_token", s.RefreshToken) params.Add("refresh_token", s.RefreshToken)
params.Add("grant_type", "refresh_token") params.Add("grant_type", "refresh_token")
@ -278,18 +276,16 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
SetHeader("Content-Type", "application/x-www-form-urlencoded"). SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do(). Do().
UnmarshalInto(&jsonResponse) UnmarshalInto(&jsonResponse)
if err != nil { if err != nil {
return return err
} }
now := time.Now()
expires := time.Unix(jsonResponse.ExpiresOn, 0)
s.AccessToken = jsonResponse.AccessToken s.AccessToken = jsonResponse.AccessToken
s.IDToken = jsonResponse.IDToken s.IDToken = jsonResponse.IDToken
s.RefreshToken = jsonResponse.RefreshToken s.RefreshToken = jsonResponse.RefreshToken
s.CreatedAt = &now
s.ExpiresOn = &expires s.CreatedAtNow()
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
@ -312,7 +308,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
} }
} }
return return nil
} }
func makeAzureHeader(accessToken string) http.Header { func makeAzureHeader(accessToken string) http.Header {

View File

@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
} }
func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { func TestAzureProviderRefresh(t *testing.T) {
p := testAzureProvider("")
expires := time.Now().Add(time.Duration(1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
assert.Equal(t, nil, err)
assert.False(t, refreshNeeded)
}
func TestAzureProviderRefreshWhenExpired(t *testing.T) {
email := "foo@example.com" email := "foo@example.com"
idToken := idTokenClaims{Email: email} idToken := idTokenClaims{Email: email}
idTokenString, err := newSignedTestIDToken(idToken) idTokenString, err := newSignedTestIDToken(idToken)
@ -373,9 +363,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
expires := time.Now().Add(time.Duration(-1) * time.Hour) expires := time.Now().Add(time.Duration(-1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
refreshed, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.True(t, refreshNeeded) assert.True(t, refreshed)
assert.NotEqual(t, session, nil) assert.NotEqual(t, session, nil)
assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, "new_some_access_token", session.AccessToken)
assert.Equal(t, "new_some_refresh_token", session.RefreshToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken)

View File

@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
return r.Email, nil return r.Email, nil
} }
// ValidateSessionState validates the AccessToken // ValidateSession validates the AccessToken
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
} }

View File

@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() {
} }
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
// RefreshToken to fetch a new ID token if required func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil return false, nil
} }
@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return true, nil return true, nil
} }
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return err
} }
c := oauth2.Config{ c := oauth2.Config{
@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
if err != nil { if err != nil {
return fmt.Errorf("unable to update session: %v", err) return fmt.Errorf("unable to update session: %v", err)
} }
s.AccessToken = newSession.AccessToken *s = *newSession
s.IDToken = newSession.IDToken
s.RefreshToken = newSession.RefreshToken return nil
s.CreatedAt = newSession.CreatedAt
s.ExpiresOn = newSession.ExpiresOn
s.Email = newSession.Email
return
} }
type gitlabUserInfo struct { type gitlabUserInfo struct {
@ -264,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token)
} }
} }
created := time.Now() ss := &sessions.SessionState{
return &sessions.SessionState{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
IDToken: getIDToken(token), IDToken: getIDToken(token),
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
CreatedAt: &created, }
ExpiresOn: &idToken.Expiry,
}, nil ss.CreatedAtNow()
ss.SetExpiresOn(idToken.Expiry)
return ss, nil
} }
// ValidateSession checks that the session's IDToken is still valid // ValidateSession checks that the session's IDToken is still valid

View File

@ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
return nil, err return nil, err
} }
created := time.Now() ss := &sessions.SessionState{
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken, RefreshToken: jsonResponse.RefreshToken,
Email: c.Email, Email: c.Email,
User: c.Subject, User: c.Subject,
}, nil }
ss.CreatedAtNow()
ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
return ss, nil
} }
// EnrichSession checks the listed Google Groups configured and adds any // EnrichSession checks the listed Google Groups configured and adds any
// that the user is a member of to session.Groups. // that the user is a member of to session.Groups.
func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error {
// TODO (@NickMeves) - Move to pure EnrichSession logic and stop // TODO (@NickMeves) - Move to pure EnrichSession logic and stop
// reusing legacy `groupValidator`. // reusing legacy `groupValidator`.
// //
@ -266,14 +265,13 @@ func userInGroup(service *admin.Service, group string, email string) bool {
return false return false
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
// RefreshToken to fetch a new ID token if required func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil return false, nil
} }
newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken) err := p.redeemRefreshToken(ctx, s)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -286,26 +284,20 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
} }
origExpiration := s.ExpiresOn
expires := time.Now().Add(duration).Truncate(time.Second)
s.AccessToken = newToken
s.IDToken = newIDToken
s.ExpiresOn = &expires
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
return true, nil return true, nil
} }
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return err
} }
params := url.Values{} params := url.Values{}
params.Add("client_id", p.ClientID) params.Add("client_id", p.ClientID)
params.Add("client_secret", clientSecret) params.Add("client_secret", clientSecret)
params.Add("refresh_token", refreshToken) params.Add("refresh_token", s.RefreshToken)
params.Add("grant_type", "refresh_token") params.Add("grant_type", "refresh_token")
var data struct { var data struct {
@ -322,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
Do(). Do().
UnmarshalInto(&data) UnmarshalInto(&data)
if err != nil { if err != nil {
return "", "", 0, err return err
} }
token = data.AccessToken s.AccessToken = data.AccessToken
idToken = data.IDToken s.IDToken = data.IDToken
expires = time.Duration(data.ExpiresIn) * time.Second
return s.CreatedAtNow()
s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second)
return nil
} }

View File

@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
return email, nil return email, nil
} }
// ValidateSessionState validates the AccessToken // ValidateSession validates the AccessToken
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
} }

View File

@ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) {
if code == "" { if code == "" {
return nil, ErrMissingCode return nil, ErrMissingCode
} }
@ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
return nil, err return nil, err
} }
created := time.Now() session := &sessions.SessionState{
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
// Store the data that we found in the session state
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: &created,
ExpiresOn: &expires,
Email: email, Email: email,
}, nil }
session.CreatedAtNow()
session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
return session, nil
} }
// GetLoginURL overrides GetLoginURL to add login.gov parameters // GetLoginURL overrides GetLoginURL to add login.gov parameters

View File

@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
return true return true
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
// RefreshToken to fetch a new Access Token (and optional ID token) if required func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil return false, nil
} }
@ -155,7 +154,6 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
return false, fmt.Errorf("unable to redeem refresh token: %v", err) return false, fmt.Errorf("unable to redeem refresh token: %v", err)
} }
logger.Printf("refreshed session: %s", s)
return true, nil return true, nil
} }
@ -227,7 +225,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
ss.AccessToken = token ss.AccessToken = token
ss.IDToken = token ss.IDToken = token
ss.RefreshToken = "" ss.RefreshToken = ""
ss.ExpiresOn = &idToken.Expiry
ss.CreatedAtNow()
ss.SetExpiresOn(idToken.Expiry)
return ss, nil return ss, nil
} }
@ -257,9 +257,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
ss.RefreshToken = token.RefreshToken ss.RefreshToken = token.RefreshToken
ss.IDToken = getIDToken(token) ss.IDToken = getIDToken(token)
created := time.Now() ss.CreatedAtNow()
ss.CreatedAt = &created ss.SetExpiresOn(token.Expiry)
ss.ExpiresOn = &token.Expiry
return ss, nil return ss, nil
} }

View File

@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
User: "11223344", User: "11223344",
} }
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, "janedoe@example.com", existingSession.Email) assert.Equal(t, "janedoe@example.com", existingSession.Email)
@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
Email: "changeit", Email: "changeit",
User: "changeit", User: "changeit",
} }
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, defaultIDToken.Email, existingSession.Email) assert.Equal(t, defaultIDToken.Email, existingSession.Email)

View File

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
@ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration
if token := values.Get("access_token"); token != "" { if token := values.Get("access_token"); token != "" {
created := time.Now() ss := &sessions.SessionState{
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil AccessToken: token,
}
ss.CreatedAtNow()
return ss, nil
} }
return nil, fmt.Errorf("no access token found %s", result.Body()) return nil, fmt.Errorf("no access token found %s", result.Body())
@ -126,10 +129,9 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
return validateToken(ctx, p, s.AccessToken, nil) return validateToken(ctx, p, s.AccessToken, nil)
} }
// RefreshSessionIfNeeded should refresh the user's session if required and // RefreshSession refreshes the user's session
// do nothing if a refresh is not required func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { return false, ErrNotImplemented
return false, nil
} }
// CreateSessionFromToken converts Bearer IDTokens into sessions // CreateSessionFromToken converts Bearer IDTokens into sessions

View File

@ -14,12 +14,20 @@ import (
func TestRefresh(t *testing.T) { func TestRefresh(t *testing.T) {
p := &ProviderData{} p := &ProviderData{}
expires := time.Now().Add(time.Duration(-11) * time.Minute) now := time.Unix(1234567890, 10)
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ expires := time.Unix(1234567890, 0)
ExpiresOn: &expires,
}) ss := &sessions.SessionState{}
assert.Equal(t, false, refreshed) ss.Clock.Set(now)
assert.Equal(t, nil, err) ss.SetExpiresOn(expires)
refreshed, err := p.RefreshSession(context.Background(), ss)
assert.False(t, refreshed)
assert.Equal(t, ErrNotImplemented, err)
refreshed, err = p.RefreshSession(context.Background(), nil)
assert.False(t, refreshed)
assert.Equal(t, ErrNotImplemented, err)
} }
func TestAcrValuesNotConfigured(t *testing.T) { func TestAcrValuesNotConfigured(t *testing.T) {

View File

@ -9,14 +9,14 @@ import (
// Provider represents an upstream identity provider implementation // Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetLoginURL(redirectURI, finalRedirect string, nonce string) string
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
// Deprecated: Migrate to EnrichSession // Deprecated: Migrate to EnrichSession
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
GetLoginURL(redirectURI, state, nonce string) string
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
EnrichSession(ctx context.Context, s *sessions.SessionState) error EnrichSession(ctx context.Context, s *sessions.SessionState) error
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSession(ctx context.Context, s *sessions.SessionState) bool ValidateSession(ctx context.Context, s *sessions.SessionState) bool
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
} }