mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-05-23 22:40:31 +02:00
Merge pull request #1086 from oauth2-proxy/early-refresh
Convert RefreshSessionIfNeeded into RefreshSession
This commit is contained in:
commit
16a9893a19
@ -4,10 +4,16 @@
|
||||
|
||||
## Important Notes
|
||||
|
||||
- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session
|
||||
deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade
|
||||
to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret`
|
||||
value and force all sessions to reauthenticate.
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
## Changes since v7.1.3
|
||||
|
||||
- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves)
|
||||
- [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed)
|
||||
- [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed)
|
||||
- [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi)
|
||||
|
@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
|
||||
}
|
||||
|
||||
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||
SessionStore: sessionStore,
|
||||
RefreshPeriod: opts.Cookie.Refresh,
|
||||
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
|
||||
ValidateSessionState: opts.GetProvider().ValidateSession,
|
||||
SessionStore: sessionStore,
|
||||
RefreshPeriod: opts.Cookie.Refresh,
|
||||
RefreshSession: opts.GetProvider().RefreshSession,
|
||||
ValidateSession: opts.GetProvider().ValidateSession,
|
||||
}))
|
||||
|
||||
return chain
|
||||
@ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Force setting these in case the Provider didn't
|
||||
if s.CreatedAt == nil {
|
||||
s.CreatedAtNow()
|
||||
}
|
||||
if s.ExpiresOn == nil {
|
||||
s.ExpiresIn(p.CookieOptions.Expire)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -3,14 +3,12 @@ package sessions
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||
"github.com/pierrec/lz4"
|
||||
"github.com/vmihailenco/msgpack/v4"
|
||||
@ -32,7 +30,9 @@ type SessionState struct {
|
||||
Groups []string `msgpack:"g,omitempty"`
|
||||
PreferredUsername string `msgpack:"pu,omitempty"`
|
||||
|
||||
Lock Lock `msgpack:"-"`
|
||||
// Internal helpers, not serialized
|
||||
Clock clock.Clock `msgpack:"-"`
|
||||
Lock Lock `msgpack:"-"`
|
||||
}
|
||||
|
||||
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
|
||||
@ -63,9 +63,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
|
||||
return s.Lock.Peek(ctx)
|
||||
}
|
||||
|
||||
// CreatedAtNow sets a SessionState's CreatedAt to now
|
||||
func (s *SessionState) CreatedAtNow() {
|
||||
now := s.Clock.Now()
|
||||
s.CreatedAt = &now
|
||||
}
|
||||
|
||||
// SetExpiresOn sets an expiration
|
||||
func (s *SessionState) SetExpiresOn(exp time.Time) {
|
||||
s.ExpiresOn = &exp
|
||||
}
|
||||
|
||||
// ExpiresIn sets an expiration a certain duration from CreatedAt.
|
||||
// CreatedAt will be set to time.Now if it is unset.
|
||||
func (s *SessionState) ExpiresIn(d time.Duration) {
|
||||
if s.CreatedAt == nil {
|
||||
s.CreatedAtNow()
|
||||
}
|
||||
exp := s.CreatedAt.Add(d)
|
||||
s.ExpiresOn = &exp
|
||||
}
|
||||
|
||||
// IsExpired checks whether the session has expired
|
||||
func (s *SessionState) IsExpired() bool {
|
||||
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
|
||||
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@ -74,7 +95,7 @@ func (s *SessionState) IsExpired() bool {
|
||||
// Age returns the age of a session
|
||||
func (s *SessionState) Age() time.Duration {
|
||||
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
||||
return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
|
||||
return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@ -177,11 +198,6 @@ func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*Ses
|
||||
return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
|
||||
}
|
||||
|
||||
err = ss.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ss, nil
|
||||
}
|
||||
|
||||
@ -235,35 +251,3 @@ func lz4Decompress(compressed []byte) ([]byte, error) {
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// validate ensures the decoded session is non-empty and contains valid data
|
||||
//
|
||||
// Non-empty check is needed due to ensure the non-authenticated AES-CFB
|
||||
// decryption doesn't result in garbage data that collides with a valid
|
||||
// MessagePack header bytes (which MessagePack will unmarshal to an empty
|
||||
// default SessionState). <1% chance, but observed with random test data.
|
||||
//
|
||||
// UTF-8 check ensures the strings are valid and not raw bytes overloaded
|
||||
// into Latin-1 encoding. The occurs when legacy unencrypted fields are
|
||||
// decrypted with AES-CFB which results in random bytes.
|
||||
func (s *SessionState) validate() error {
|
||||
for _, field := range []string{
|
||||
s.User,
|
||||
s.Email,
|
||||
s.PreferredUsername,
|
||||
s.AccessToken,
|
||||
s.IDToken,
|
||||
s.RefreshToken,
|
||||
} {
|
||||
if !utf8.ValidString(field) {
|
||||
return errors.New("invalid non-UTF8 field in session")
|
||||
}
|
||||
}
|
||||
|
||||
empty := new(SessionState)
|
||||
if reflect.DeepEqual(*s, *empty) {
|
||||
return errors.New("invalid empty session unmarshalled")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -16,6 +16,30 @@ func timePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
func TestCreatedAtNow(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
ss := &SessionState{}
|
||||
|
||||
now := time.Unix(1234567890, 0)
|
||||
ss.Clock.Set(now)
|
||||
|
||||
ss.CreatedAtNow()
|
||||
g.Expect(*ss.CreatedAt).To(Equal(now))
|
||||
}
|
||||
|
||||
func TestExpiresIn(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
ss := &SessionState{}
|
||||
|
||||
now := time.Unix(1234567890, 0)
|
||||
ss.Clock.Set(now)
|
||||
|
||||
ttl := time.Duration(743) * time.Second
|
||||
ss.ExpiresIn(ttl)
|
||||
|
||||
g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl)))
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z")
|
||||
|
@ -63,13 +63,10 @@ func Reset() *clockapi.Mock {
|
||||
// package.
|
||||
type Clock struct {
|
||||
mock *clockapi.Mock
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// Set sets the Clock to a clock.Mock at the given time.Time
|
||||
func (c *Clock) Set(t time.Time) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.mock == nil {
|
||||
c.mock = clockapi.NewMock()
|
||||
}
|
||||
@ -79,8 +76,6 @@ func (c *Clock) Set(t time.Time) {
|
||||
// Add moves clock forward time.Duration if it is mocked. It will error
|
||||
// if the clock is not mocked.
|
||||
func (c *Clock) Add(d time.Duration) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.mock == nil {
|
||||
return errors.New("clock not mocked")
|
||||
}
|
||||
@ -91,8 +86,6 @@ func (c *Clock) Add(d time.Duration) error {
|
||||
// Reset removes local clock.Mock. Returns any existing Mock if set in case
|
||||
// lingering time operations are attached to it.
|
||||
func (c *Clock) Reset() *clockapi.Mock {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
existing := c.mock
|
||||
c.mock = nil
|
||||
return existing
|
||||
|
@ -11,25 +11,26 @@ import (
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||
)
|
||||
|
||||
// StoredSessionLoaderOptions cotnains all of the requirements to construct
|
||||
// StoredSessionLoaderOptions contains all of the requirements to construct
|
||||
// a stored session loader.
|
||||
// All options must be provided.
|
||||
type StoredSessionLoaderOptions struct {
|
||||
// Session storage basckend
|
||||
// Session storage backend
|
||||
SessionStore sessionsapi.SessionStore
|
||||
|
||||
// How often should sessions be refreshed
|
||||
RefreshPeriod time.Duration
|
||||
|
||||
// Provider based sesssion refreshing
|
||||
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
// Provider based session refreshing
|
||||
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
|
||||
// Provider based session validation.
|
||||
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
||||
// refresh it, we must re-validate using this validation.
|
||||
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||
}
|
||||
|
||||
// NewStoredSessionLoader creates a new storedSessionLoader which loads
|
||||
@ -38,10 +39,10 @@ type StoredSessionLoaderOptions struct {
|
||||
// If a session was loader by a previous handler, it will not be replaced.
|
||||
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
||||
ss := &storedSessionLoader{
|
||||
store: opts.SessionStore,
|
||||
refreshPeriod: opts.RefreshPeriod,
|
||||
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
|
||||
validateSessionState: opts.ValidateSessionState,
|
||||
store: opts.SessionStore,
|
||||
refreshPeriod: opts.RefreshPeriod,
|
||||
sessionRefresher: opts.RefreshSession,
|
||||
sessionValidator: opts.ValidateSession,
|
||||
}
|
||||
return ss.loadSession
|
||||
}
|
||||
@ -49,10 +50,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
|
||||
// storedSessionLoader is responsible for loading sessions from cookie
|
||||
// identified sessions in the session store.
|
||||
type storedSessionLoader struct {
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
|
||||
}
|
||||
|
||||
// loadSession attempts to load a session as identified by the request cookies.
|
||||
@ -108,49 +109,59 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
|
||||
|
||||
// refreshSessionIfNeeded will attempt to refresh a session if the session
|
||||
// is older than the refresh period.
|
||||
// It is assumed that if the provider refreshes the session, the session is now
|
||||
// valid.
|
||||
// If the session requires refreshing but the provider does not refresh it,
|
||||
// we must validate the session to ensure that the returned session is still
|
||||
// valid.
|
||||
// Success or fail, we will then validate the session.
|
||||
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
|
||||
// Refresh is disabled or the session is not old enough, do nothing
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
||||
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
|
||||
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
||||
err := s.refreshSession(rw, req, session)
|
||||
if err != nil {
|
||||
return err
|
||||
// If a preemptive refresh fails, we still keep the session
|
||||
// if validateSession succeeds.
|
||||
logger.Errorf("Unable to refresh session: %v", err)
|
||||
}
|
||||
|
||||
if !refreshed {
|
||||
// Session wasn't refreshed, so make sure it's still valid
|
||||
return s.validateSession(req.Context(), session)
|
||||
}
|
||||
return nil
|
||||
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
||||
return s.validateSession(req.Context(), session)
|
||||
}
|
||||
|
||||
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
|
||||
// refreshSession attempts to refresh the session with the provider
|
||||
// and will save the session if it was updated.
|
||||
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
||||
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error refreshing access token: %v", err)
|
||||
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||
refreshed, err := s.sessionRefresher(req.Context(), session)
|
||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||
return fmt.Errorf("error refreshing tokens: %v", err)
|
||||
}
|
||||
|
||||
if !refreshed {
|
||||
return false, nil
|
||||
// HACK:
|
||||
// Providers that don't implement `RefreshSession` use the default
|
||||
// implementation which returns `ErrNotImplemented`.
|
||||
// Pretend it refreshed to reset the refresh timer so that `ValidateSession`
|
||||
// isn't triggered every subsequent request and is only called once during
|
||||
// this request.
|
||||
if errors.Is(err, providers.ErrNotImplemented) {
|
||||
refreshed = true
|
||||
}
|
||||
|
||||
// Session not refreshed, nothing to persist.
|
||||
if !refreshed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
|
||||
// (In case underlying provider implementations forget)
|
||||
session.CreatedAtNow()
|
||||
|
||||
// Because the session was refreshed, make sure to save it
|
||||
err = s.store.Save(rw, req, session)
|
||||
if err != nil {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
||||
return false, fmt.Errorf("error saving session: %v", err)
|
||||
return fmt.Errorf("error saving session: %v", err)
|
||||
}
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateSession checks whether the session has expired and performs
|
||||
@ -161,7 +172,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
|
||||
return errors.New("session is expired")
|
||||
}
|
||||
|
||||
if !s.validateSessionState(ctx, session) {
|
||||
if !s.sessionValidator(ctx, session) {
|
||||
return errors.New("session is invalid")
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
@ -17,15 +19,17 @@ import (
|
||||
|
||||
var _ = Describe("Stored Session Suite", func() {
|
||||
const (
|
||||
refresh = "Refresh"
|
||||
noRefresh = "NoRefresh"
|
||||
refresh = "Refresh"
|
||||
noRefresh = "NoRefresh"
|
||||
notImplemented = "NotImplemented"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
Context("StoredSessionLoader", func() {
|
||||
createdPast := time.Now().Add(-5 * time.Minute)
|
||||
createdFuture := time.Now().Add(5 * time.Minute)
|
||||
now := time.Now()
|
||||
createdPast := now.Add(-5 * time.Minute)
|
||||
createdFuture := now.Add(5 * time.Minute)
|
||||
|
||||
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
switch ss.RefreshToken {
|
||||
@ -85,6 +89,14 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
},
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
clock.Set(now)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
clock.Reset()
|
||||
})
|
||||
|
||||
type storedSessionLoaderTableInput struct {
|
||||
requestHeaders http.Header
|
||||
existingSession *sessionsapi.SessionState
|
||||
@ -109,10 +121,10 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
opts := &StoredSessionLoaderOptions{
|
||||
SessionStore: in.store,
|
||||
RefreshPeriod: in.refreshPeriod,
|
||||
RefreshSessionIfNeeded: in.refreshSession,
|
||||
ValidateSessionState: in.validateSession,
|
||||
SessionStore: in.store,
|
||||
RefreshPeriod: in.refreshPeriod,
|
||||
RefreshSession: in.refreshSession,
|
||||
ValidateSession: in.validateSession,
|
||||
}
|
||||
|
||||
// Create the handler with a next handler that will capture the session
|
||||
@ -208,6 +220,21 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
existingSession: nil,
|
||||
expectedSession: &sessionsapi.SessionState{
|
||||
RefreshToken: "Refreshed",
|
||||
CreatedAt: &now,
|
||||
ExpiresOn: &createdFuture,
|
||||
},
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
||||
},
|
||||
existingSession: nil,
|
||||
expectedSession: &sessionsapi.SessionState{
|
||||
RefreshToken: "RefreshError",
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
},
|
||||
@ -216,7 +243,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
|
||||
Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
||||
},
|
||||
@ -225,7 +252,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false },
|
||||
}),
|
||||
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
@ -261,18 +288,20 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
s := &storedSessionLoader{
|
||||
refreshPeriod: in.refreshPeriod,
|
||||
store: &fakeSessionStore{},
|
||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
refreshed = true
|
||||
switch ss.RefreshToken {
|
||||
case refresh:
|
||||
return true, nil
|
||||
case noRefresh:
|
||||
return false, nil
|
||||
case notImplemented:
|
||||
return false, providers.ErrNotImplemented
|
||||
default:
|
||||
return false, errors.New("error refreshing session")
|
||||
}
|
||||
},
|
||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||
sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||
validated = true
|
||||
return ss.AccessToken != "Invalid"
|
||||
},
|
||||
@ -326,7 +355,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectRefreshed: true,
|
||||
expectValidated: false,
|
||||
expectValidated: true,
|
||||
}),
|
||||
Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
@ -339,15 +368,25 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
expectRefreshed: true,
|
||||
expectValidated: true,
|
||||
}),
|
||||
Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{
|
||||
Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: notImplemented,
|
||||
CreatedAt: &createdPast,
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectRefreshed: true,
|
||||
expectValidated: true,
|
||||
}),
|
||||
Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: "RefreshError",
|
||||
CreatedAt: &createdPast,
|
||||
},
|
||||
expectedErr: errors.New("error refreshing access token: error refreshing session"),
|
||||
expectedErr: nil,
|
||||
expectRefreshed: true,
|
||||
expectValidated: false,
|
||||
expectValidated: true,
|
||||
}),
|
||||
Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
@ -364,12 +403,11 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
)
|
||||
})
|
||||
|
||||
Context("refreshSessionWithProvider", func() {
|
||||
Context("refreshSession", func() {
|
||||
type refreshSessionWithProviderTableInput struct {
|
||||
session *sessionsapi.SessionState
|
||||
expectedErr error
|
||||
expectRefreshed bool
|
||||
expectSaved bool
|
||||
session *sessionsapi.SessionState
|
||||
expectedErr error
|
||||
expectSaved bool
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
@ -388,12 +426,14 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
switch ss.RefreshToken {
|
||||
case refresh:
|
||||
return true, nil
|
||||
case noRefresh:
|
||||
return false, nil
|
||||
case notImplemented:
|
||||
return false, providers.ErrNotImplemented
|
||||
default:
|
||||
return false, errors.New("error refreshing session")
|
||||
}
|
||||
@ -402,30 +442,34 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
|
||||
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session)
|
||||
err := s.refreshSession(nil, req, in.session)
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(in.expectedErr))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(refreshed).To(Equal(in.expectRefreshed))
|
||||
Expect(saved).To(Equal(in.expectSaved))
|
||||
},
|
||||
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: noRefresh,
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectRefreshed: false,
|
||||
expectSaved: false,
|
||||
expectedErr: nil,
|
||||
expectSaved: false,
|
||||
}),
|
||||
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: refresh,
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectRefreshed: true,
|
||||
expectSaved: true,
|
||||
expectedErr: nil,
|
||||
expectSaved: true,
|
||||
}),
|
||||
Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: notImplemented,
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectSaved: true,
|
||||
}),
|
||||
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
@ -433,18 +477,16 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
CreatedAt: &now,
|
||||
ExpiresOn: &now,
|
||||
},
|
||||
expectedErr: errors.New("error refreshing access token: error refreshing session"),
|
||||
expectRefreshed: false,
|
||||
expectSaved: false,
|
||||
expectedErr: errors.New("error refreshing tokens: error refreshing session"),
|
||||
expectSaved: false,
|
||||
}),
|
||||
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: refresh,
|
||||
AccessToken: "NoSave",
|
||||
},
|
||||
expectedErr: errors.New("error saving session: unable to save session"),
|
||||
expectRefreshed: false,
|
||||
expectSaved: true,
|
||||
expectedErr: errors.New("error saving session: unable to save session"),
|
||||
expectSaved: true,
|
||||
}),
|
||||
)
|
||||
})
|
||||
@ -454,7 +496,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
s = &storedSessionLoader{
|
||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||
sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||
return ss.AccessToken == "Valid"
|
||||
},
|
||||
}
|
||||
|
@ -36,8 +36,7 @@ type SessionStore struct {
|
||||
// within Cookies set on the HTTP response writer
|
||||
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
||||
if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
|
||||
now := time.Now()
|
||||
ss.CreatedAt = &now
|
||||
ss.CreatedAtNow()
|
||||
}
|
||||
value, err := s.cookieForSession(ss)
|
||||
if err != nil {
|
||||
|
@ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager {
|
||||
// from the persistent data store.
|
||||
func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||
if s.CreatedAt == nil || s.CreatedAt.IsZero() {
|
||||
now := time.Now()
|
||||
s.CreatedAt = &now
|
||||
s.CreatedAtNow()
|
||||
}
|
||||
|
||||
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||
|
@ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
||||
|
||||
session := &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
}
|
||||
session.CreatedAtNow()
|
||||
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
|
||||
|
||||
@ -239,28 +236,29 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
origExpiration := s.ExpiresOn
|
||||
|
||||
err := p.redeemRefreshToken(ctx, s)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("refresh_token", s.RefreshToken)
|
||||
params.Add("grant_type", "refresh_token")
|
||||
|
||||
@ -278,18 +276,16 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
||||
s.AccessToken = jsonResponse.AccessToken
|
||||
s.IDToken = jsonResponse.IDToken
|
||||
s.RefreshToken = jsonResponse.RefreshToken
|
||||
s.CreatedAt = &now
|
||||
s.ExpiresOn = &expires
|
||||
|
||||
s.CreatedAtNow()
|
||||
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
|
||||
|
||||
@ -312,7 +308,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeAzureHeader(accessToken string) http.Header {
|
||||
|
@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
|
||||
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
|
||||
}
|
||||
|
||||
func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
||||
p := testAzureProvider("")
|
||||
|
||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.False(t, refreshNeeded)
|
||||
}
|
||||
|
||||
func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
||||
func TestAzureProviderRefresh(t *testing.T) {
|
||||
email := "foo@example.com"
|
||||
idToken := idTokenClaims{Email: email}
|
||||
idTokenString, err := newSignedTestIDToken(idToken)
|
||||
@ -373,9 +363,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
||||
|
||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.True(t, refreshNeeded)
|
||||
assert.True(t, refreshed)
|
||||
assert.NotEqual(t, session, nil)
|
||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
||||
|
@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
return r.Email, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
// ValidateSession validates the AccessToken
|
||||
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
|
||||
}
|
||||
|
@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() {
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
c := oauth2.Config{
|
||||
@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to update session: %v", err)
|
||||
}
|
||||
s.AccessToken = newSession.AccessToken
|
||||
s.IDToken = newSession.IDToken
|
||||
s.RefreshToken = newSession.RefreshToken
|
||||
s.CreatedAt = newSession.CreatedAt
|
||||
s.ExpiresOn = newSession.ExpiresOn
|
||||
s.Email = newSession.Email
|
||||
return
|
||||
*s = *newSession
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type gitlabUserInfo struct {
|
||||
@ -264,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token)
|
||||
}
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
return &sessions.SessionState{
|
||||
ss := &sessions.SessionState{
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: getIDToken(token),
|
||||
RefreshToken: token.RefreshToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &idToken.Expiry,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ss.CreatedAtNow()
|
||||
ss.SetExpiresOn(idToken.Expiry)
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// ValidateSession checks that the session's IDToken is still valid
|
||||
|
@ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||
|
||||
return &sessions.SessionState{
|
||||
ss := &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
Email: c.Email,
|
||||
User: c.Subject,
|
||||
}, nil
|
||||
}
|
||||
ss.CreatedAtNow()
|
||||
ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// EnrichSession checks the listed Google Groups configured and adds any
|
||||
// that the user is a member of to session.Groups.
|
||||
func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||
func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error {
|
||||
// TODO (@NickMeves) - Move to pure EnrichSession logic and stop
|
||||
// reusing legacy `groupValidator`.
|
||||
//
|
||||
@ -266,14 +265,13 @@ func userInGroup(service *admin.Service, group string, email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken)
|
||||
err := p.redeemRefreshToken(ctx, s)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -286,26 +284,20 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
||||
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
||||
}
|
||||
|
||||
origExpiration := s.ExpiresOn
|
||||
expires := time.Now().Add(duration).Truncate(time.Second)
|
||||
s.AccessToken = newToken
|
||||
s.IDToken = newIDToken
|
||||
s.ExpiresOn = &expires
|
||||
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
||||
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("refresh_token", refreshToken)
|
||||
params.Add("refresh_token", s.RefreshToken)
|
||||
params.Add("grant_type", "refresh_token")
|
||||
|
||||
var data struct {
|
||||
@ -322,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
|
||||
Do().
|
||||
UnmarshalInto(&data)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return err
|
||||
}
|
||||
|
||||
token = data.AccessToken
|
||||
idToken = data.IDToken
|
||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||
return
|
||||
s.AccessToken = data.AccessToken
|
||||
s.IDToken = data.IDToken
|
||||
|
||||
s.CreatedAtNow()
|
||||
s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
// ValidateSession validates the AccessToken
|
||||
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
@ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||
|
||||
// Store the data that we found in the session state
|
||||
return &sessions.SessionState{
|
||||
session := &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
Email: email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
session.CreatedAtNow()
|
||||
session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
||||
|
@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
return true
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -155,7 +154,6 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("refreshed session: %s", s)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@ -227,7 +225,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
|
||||
ss.AccessToken = token
|
||||
ss.IDToken = token
|
||||
ss.RefreshToken = ""
|
||||
ss.ExpiresOn = &idToken.Expiry
|
||||
|
||||
ss.CreatedAtNow()
|
||||
ss.SetExpiresOn(idToken.Expiry)
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
@ -257,9 +257,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
|
||||
ss.RefreshToken = token.RefreshToken
|
||||
ss.IDToken = getIDToken(token)
|
||||
|
||||
created := time.Now()
|
||||
ss.CreatedAt = &created
|
||||
ss.ExpiresOn = &token.Expiry
|
||||
ss.CreatedAtNow()
|
||||
ss.SetExpiresOn(token.Expiry)
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||
User: "11223344",
|
||||
}
|
||||
|
||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
||||
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, refreshed, true)
|
||||
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
||||
@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
||||
Email: "changeit",
|
||||
User: "changeit",
|
||||
}
|
||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
||||
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, refreshed, true)
|
||||
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
@ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration
|
||||
if token := values.Get("access_token"); token != "" {
|
||||
created := time.Now()
|
||||
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
|
||||
ss := &sessions.SessionState{
|
||||
AccessToken: token,
|
||||
}
|
||||
ss.CreatedAtNow()
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no access token found %s", result.Body())
|
||||
@ -126,10 +129,9 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
return validateToken(ctx, p, s.AccessToken, nil)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||
// do nothing if a refresh is not required
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
||||
return false, nil
|
||||
// RefreshSession refreshes the user's session
|
||||
func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
||||
return false, ErrNotImplemented
|
||||
}
|
||||
|
||||
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||
|
@ -14,12 +14,20 @@ import (
|
||||
func TestRefresh(t *testing.T) {
|
||||
p := &ProviderData{}
|
||||
|
||||
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
||||
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
|
||||
ExpiresOn: &expires,
|
||||
})
|
||||
assert.Equal(t, false, refreshed)
|
||||
assert.Equal(t, nil, err)
|
||||
now := time.Unix(1234567890, 10)
|
||||
expires := time.Unix(1234567890, 0)
|
||||
|
||||
ss := &sessions.SessionState{}
|
||||
ss.Clock.Set(now)
|
||||
ss.SetExpiresOn(expires)
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), ss)
|
||||
assert.False(t, refreshed)
|
||||
assert.Equal(t, ErrNotImplemented, err)
|
||||
|
||||
refreshed, err = p.RefreshSession(context.Background(), nil)
|
||||
assert.False(t, refreshed)
|
||||
assert.Equal(t, ErrNotImplemented, err)
|
||||
}
|
||||
|
||||
func TestAcrValuesNotConfigured(t *testing.T) {
|
||||
|
@ -9,14 +9,14 @@ import (
|
||||
// Provider represents an upstream identity provider implementation
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
GetLoginURL(redirectURI, finalRedirect string, nonce string) string
|
||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||
// Deprecated: Migrate to EnrichSession
|
||||
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
||||
GetLoginURL(redirectURI, state, nonce string) string
|
||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||
EnrichSession(ctx context.Context, s *sessions.SessionState) error
|
||||
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user