1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-07-15 01:44:22 +02:00

Split RefreshSessionIfNeeded in two methods and use Redis lock

This commit is contained in:
Kevin Kreitner
2021-02-22 08:33:53 +01:00
parent b942eb1582
commit 69d6fc8a08
21 changed files with 297 additions and 145 deletions

1
go.mod
View File

@ -22,6 +22,7 @@ require (
github.com/onsi/gomega v1.10.2 github.com/onsi/gomega v1.10.2
github.com/pierrec/lz4 v2.5.2+incompatible github.com/pierrec/lz4 v2.5.2+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v0.9.3
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.6.3 github.com/spf13/viper v1.6.3
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.6.1

6
go.sum
View File

@ -24,6 +24,7 @@ github.com/alicebob/miniredis/v2 v2.13.0 h1:QPosMaxm+r6Qs+YcCtL2Z2a2RSdC9VfXJLpd
github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k= github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y= github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
@ -168,6 +169,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A= github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA= github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo= github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo=
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s= github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s=
@ -208,13 +210,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8=
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/common v0.4.0 h1:7etb9YClo3a6HjLzfl6rIQaU+FDfi0VSX39io3aQ+DM=
github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084 h1:sofwID9zm4tzrgykg80hfFph1mryUeLRsUfoocVVmRY=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=

View File

@ -276,10 +276,11 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
} }
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore, SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh, RefreshPeriod: opts.Cookie.Refresh,
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, RefreshSession: opts.GetProvider().RefreshSession,
ValidateSessionState: opts.GetProvider().ValidateSession, IsRefreshNeeded: opts.GetProvider().IsRefreshNeeded,
ValidateSessionState: opts.GetProvider().ValidateSession,
})) }))
return chain return chain

View File

@ -2,13 +2,13 @@ package sessions
import ( import (
"net/http" "net/http"
"time"
) )
// SessionStore is an interface to storing user sessions in the proxy // SessionStore is an interface to storing user sessions in the proxy
type SessionStore interface { type SessionStore interface {
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
Load(req *http.Request) (*SessionState, error) Load(req *http.Request) (*SessionState, error)
Lock(req *http.Request, expiration time.Duration) error LoadWithLock(req *http.Request) (*SessionState, error)
ReleaseLock(req *http.Request) error
Clear(rw http.ResponseWriter, req *http.Request) error Clear(rw http.ResponseWriter, req *http.Request) error
} }

View File

@ -24,10 +24,13 @@ type StoredSessionLoaderOptions struct {
RefreshPeriod time.Duration RefreshPeriod time.Duration
// Provider based sesssion refreshing // Provider based sesssion refreshing
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
// Provider based session refresh check
IsRefreshNeeded func(*sessionsapi.SessionState) bool
// Provider based session validation. // Provider based session validation.
// If the sesssion is older than `RefreshPeriod` but the provider doesn't // If the session is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation. // refresh it, we must re-validate using this validation.
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
} }
@ -38,10 +41,11 @@ type StoredSessionLoaderOptions struct {
// If a session was loader by a previous handler, it will not be replaced. // If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
ss := &storedSessionLoader{ ss := &storedSessionLoader{
store: opts.SessionStore, store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod, refreshPeriod: opts.RefreshPeriod,
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, refreshSessionWithProvider: opts.RefreshSession,
validateSessionState: opts.ValidateSessionState, isRefreshNeededWithProvider: opts.IsRefreshNeeded,
validateSessionState: opts.ValidateSessionState,
} }
return ss.loadSession return ss.loadSession
} }
@ -49,10 +53,11 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
// storedSessionLoader is responsible for loading sessions from cookie // storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store. // identified sessions in the session store.
type storedSessionLoader struct { type storedSessionLoader struct {
store sessionsapi.SessionStore store sessionsapi.SessionStore
refreshPeriod time.Duration refreshPeriod time.Duration
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) refreshSessionWithProvider func(context.Context, *sessionsapi.SessionState) (bool, error)
validateSessionState func(context.Context, *sessionsapi.SessionState) bool isRefreshNeededWithProvider func(*sessionsapi.SessionState) bool
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
} }
// loadSession attempts to load a session as identified by the request cookies. // loadSession attempts to load a session as identified by the request cookies.
@ -98,44 +103,54 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
return nil, nil return nil, nil
} }
err = s.refreshSessionIfNeeded(rw, req, session) if !s.isSessionRefreshNeeded(session) {
return session, nil
}
session, err = s.store.LoadWithLock(req)
if err != nil {
return nil, err
}
if session == nil {
// No session was found in the storage, nothing more to do
return nil, nil
}
if !s.isSessionRefreshNeeded(session) {
_ = s.store.ReleaseLock(req)
return session, nil
}
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
refreshed, err := s.refreshSession(rw, req, session)
_ = s.store.ReleaseLock(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
} }
if refreshed {
return session, nil
}
// Session wasn't refreshed, so make sure it's still valid
err = s.validateSession(req.Context(), session)
if err != nil {
return nil, err
}
return session, nil return session, nil
} }
// refreshSessionIfNeeded will attempt to refresh a session if the session // isSessionRefreshNeeded will check if the session need to be refreshed
// is older than the refresh period. func (s *storedSessionLoader) isSessionRefreshNeeded(session *sessionsapi.SessionState) bool {
// It is assumed that if the provider refreshes the session, the session is now if s.refreshPeriod > time.Duration(0) && session.Age() >= s.refreshPeriod {
// valid. return s.isRefreshNeededWithProvider(session)
// If the session requires refreshing but the provider does not refresh it,
// we must validate the session to ensure that the returned session is still
// valid.
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
// Refresh is disabled or the session is not old enough, do nothing
return nil
} }
return false
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
if err != nil {
return err
}
if !refreshed {
// Session wasn't refreshed, so make sure it's still valid
return s.validateSession(req.Context(), session)
}
return nil
} }
// refreshSessionWithProvider attempts to refresh the sessinon with the provider // refreshSession attempts to refresh the sessinon with the provider
// and will save the session if it was updated. // and will save the session if it was updated.
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) refreshed, err := s.refreshSessionWithProvider(req.Context(), session)
if err != nil { if err != nil {
return false, fmt.Errorf("error refreshing access token: %v", err) return false, fmt.Errorf("error refreshing access token: %v", err)
} }

View File

@ -39,6 +39,17 @@ var _ = Describe("Stored Session Suite", func() {
} }
} }
var defaultIsRefreshNeededFunc = func(ss *sessionsapi.SessionState) bool {
switch ss.RefreshToken {
case refresh:
return true
case noRefresh:
return false
default:
return false
}
}
var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool {
return ss.AccessToken != "Invalid" return ss.AccessToken != "Invalid"
} }
@ -86,13 +97,14 @@ var _ = Describe("Stored Session Suite", func() {
} }
type storedSessionLoaderTableInput struct { type storedSessionLoaderTableInput struct {
requestHeaders http.Header requestHeaders http.Header
existingSession *sessionsapi.SessionState existingSession *sessionsapi.SessionState
expectedSession *sessionsapi.SessionState expectedSession *sessionsapi.SessionState
store sessionsapi.SessionStore store sessionsapi.SessionStore
refreshPeriod time.Duration refreshPeriod time.Duration
refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
validateSession func(context.Context, *sessionsapi.SessionState) bool isRefreshSessionNeeded func(*sessionsapi.SessionState) bool
validateSession func(context.Context, *sessionsapi.SessionState) bool
} }
DescribeTable("when serving a request", DescribeTable("when serving a request",
@ -109,10 +121,11 @@ var _ = Describe("Stored Session Suite", func() {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
opts := &StoredSessionLoaderOptions{ opts := &StoredSessionLoaderOptions{
SessionStore: in.store, SessionStore: in.store,
RefreshPeriod: in.refreshPeriod, RefreshPeriod: in.refreshPeriod,
RefreshSessionIfNeeded: in.refreshSession, RefreshSession: in.refreshSession,
ValidateSessionState: in.validateSession, IsRefreshNeeded: in.isRefreshSessionNeeded,
ValidateSessionState: in.validateSession,
} }
// Create the handler with a next handler that will capture the session // Create the handler with a next handler that will capture the session
@ -126,24 +139,26 @@ var _ = Describe("Stored Session Suite", func() {
Expect(gotSession).To(Equal(in.expectedSession)) Expect(gotSession).To(Equal(in.expectedSession))
}, },
Entry("with no cookie", storedSessionLoaderTableInput{ Entry("with no cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{}, requestHeaders: http.Header{},
existingSession: nil, existingSession: nil,
expectedSession: nil, expectedSession: nil,
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("with an invalid cookie", storedSessionLoaderTableInput{ Entry("with an invalid cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NonExistent"}, "Cookie": []string{"_oauth2_proxy=NonExistent"},
}, },
existingSession: nil, existingSession: nil,
expectedSession: nil, expectedSession: nil,
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("with an existing session", storedSessionLoaderTableInput{ Entry("with an existing session", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
@ -155,10 +170,11 @@ var _ = Describe("Stored Session Suite", func() {
expectedSession: &sessionsapi.SessionState{ expectedSession: &sessionsapi.SessionState{
RefreshToken: "Existing", RefreshToken: "Existing",
}, },
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("with a session that has not expired", storedSessionLoaderTableInput{ Entry("with a session that has not expired", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
@ -170,21 +186,23 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &createdPast, CreatedAt: &createdPast,
ExpiresOn: &createdFuture, ExpiresOn: &createdFuture,
}, },
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, "Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
}, },
existingSession: nil, existingSession: nil,
expectedSession: nil, expectedSession: nil,
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
@ -196,10 +214,11 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &createdPast, CreatedAt: &createdPast,
ExpiresOn: &createdFuture, ExpiresOn: &createdFuture,
}, },
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 10 * time.Minute, refreshPeriod: 10 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
@ -211,37 +230,40 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &createdPast, CreatedAt: &createdPast,
ExpiresOn: &createdFuture, ExpiresOn: &createdFuture,
}, },
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("when the provider refresh fails", storedSessionLoaderTableInput{ Entry("when the provider refresh fails", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"}, "Cookie": []string{"_oauth2_proxy=RefreshError"},
}, },
existingSession: nil, existingSession: nil,
expectedSession: nil, expectedSession: nil,
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
requestHeaders: http.Header{ requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, "Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
}, },
existingSession: nil, existingSession: nil,
expectedSession: nil, expectedSession: nil,
store: defaultSessionStore, store: defaultSessionStore,
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc, refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc, isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}), }),
) )
}) })
Context("refreshSessionIfNeeded", func() { Context("isSessionRefreshNeeded", func() {
type refreshSessionIfNeededTableInput struct { type refreshSessionIfNeededTableInput struct {
refreshPeriod time.Duration refreshPeriod time.Duration
session *sessionsapi.SessionState session *sessionsapi.SessionState
@ -261,7 +283,7 @@ var _ = Describe("Stored Session Suite", func() {
s := &storedSessionLoader{ s := &storedSessionLoader{
refreshPeriod: in.refreshPeriod, refreshPeriod: in.refreshPeriod,
store: &fakeSessionStore{}, store: &fakeSessionStore{},
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
refreshed = true refreshed = true
switch ss.RefreshToken { switch ss.RefreshToken {
case refresh: case refresh:
@ -272,6 +294,17 @@ var _ = Describe("Stored Session Suite", func() {
return false, errors.New("error refreshing session") return false, errors.New("error refreshing session")
} }
}, },
isRefreshNeededWithProvider: func(ss *sessionsapi.SessionState) bool {
refreshed = true
switch ss.RefreshToken {
case refresh:
return true
case noRefresh:
return false
default:
return false
}
},
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
validated = true validated = true
return ss.AccessToken != "Invalid" return ss.AccessToken != "Invalid"
@ -279,7 +312,10 @@ var _ = Describe("Stored Session Suite", func() {
} }
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
err := s.refreshSessionIfNeeded(nil, req, in.session) var err error
if s.isSessionRefreshNeeded(in.session) {
refreshed, err = s.refreshSession(nil, req, in.session)
}
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
} else { } else {
@ -364,7 +400,7 @@ var _ = Describe("Stored Session Suite", func() {
) )
}) })
Context("refreshSessionWithProvider", func() { Context("refreshSession", func() {
type refreshSessionWithProviderTableInput struct { type refreshSessionWithProviderTableInput struct {
session *sessionsapi.SessionState session *sessionsapi.SessionState
expectedErr error expectedErr error
@ -388,7 +424,7 @@ var _ = Describe("Stored Session Suite", func() {
return nil return nil
}, },
}, },
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken { switch ss.RefreshToken {
case refresh: case refresh:
return true, nil return true, nil
@ -401,7 +437,7 @@ var _ = Describe("Stored Session Suite", func() {
} }
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) refreshed, err := s.refreshSession(nil, req, in.session)
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
} else { } else {
@ -515,6 +551,17 @@ func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, e
return nil, nil return nil, nil
} }
func (f *fakeSessionStore) LoadWithLock(req *http.Request) (*sessionsapi.SessionState, error) {
if f.LoadFunc != nil {
return f.LoadFunc(req)
}
return nil, nil
}
func (f *fakeSessionStore) ReleaseLock(req *http.Request) error {
return nil
}
func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
if f.ClearFunc != nil { if f.ClearFunc != nil {
return f.ClearFunc(rw, req) return f.ClearFunc(rw, req)

View File

@ -66,7 +66,14 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
return session, nil return session, nil
} }
func (s *SessionStore) Lock(req *http.Request, expirationTime time.Duration) error { // Load reads sessions.SessionState information from Cookies within the
// HTTP request object
func (s *SessionStore) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
return s.Load(req)
}
// Release the session lock
func (s *SessionStore) ReleaseLock(req *http.Request) error {
return nil return nil
} }

View File

@ -11,6 +11,7 @@ import (
type Store interface { type Store interface {
Save(context.Context, string, []byte, time.Duration) error Save(context.Context, string, []byte, time.Duration) error
Load(context.Context, string) ([]byte, error) Load(context.Context, string) ([]byte, error)
Lock(context.Context, string, time.Duration) error LoadWithLock(context.Context, string) ([]byte, error)
ReleaseLock(context.Context, string) error
Clear(context.Context, string) error Clear(context.Context, string) error
} }

View File

@ -65,16 +65,28 @@ func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) {
}) })
} }
// Lock reads sessions.SessionState in a session store. It will // Load reads sessions.SessionState information from a session store. It will
// use the session ticket from the http.Request's cookie. // use the session ticket from the http.Request's cookie.
func (m *Manager) Lock(req *http.Request, expiration time.Duration) error { func (m *Manager) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
tckt, err := decodeTicketFromRequest(req, m.Options)
if err != nil {
return nil, err
}
return tckt.loadSession(func(key string) ([]byte, error) {
return m.Store.LoadWithLock(req.Context(), key)
})
}
// Release the session lock
func (m *Manager) ReleaseLock(req *http.Request) error {
tckt, err := decodeTicketFromRequest(req, m.Options) tckt, err := decodeTicketFromRequest(req, m.Options)
if err != nil { if err != nil {
return err return err
} }
return tckt.lockSession(func(key string) error { return tckt.releaseSession(func(key string) error {
return m.Store.Lock(req.Context(), key, expiration) return m.Store.ReleaseLock(req.Context(), key)
}) })
} }

View File

@ -139,13 +139,11 @@ func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) {
return sessions.DecodeSessionState(ciphertext, c, false) return sessions.DecodeSessionState(ciphertext, c, false)
} }
// lockSession loads a session from the disk store via the passed loadFunc // releaseSession releases a potential locked session
// using the ticket.id as the key. It then decodes the SessionState using func (t *ticket) releaseSession(loader lockFunc) error {
// ticket.secret to make the AES-GCM cipher.
func (t *ticket) lockSession(loader lockFunc) error {
err := loader(t.id) err := loader(t.id)
if err != nil { if err != nil {
return fmt.Errorf("failed to lock the session state with the ticket: %v", err) return fmt.Errorf("failed to release session state with the ticket: %v", err)
} }
return nil return nil
} }

View File

@ -2,6 +2,7 @@ package redis
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/bsm/redislock" "github.com/bsm/redislock"
@ -12,6 +13,7 @@ import (
type Client interface { type Client interface {
Get(ctx context.Context, key string) ([]byte, error) Get(ctx context.Context, key string) ([]byte, error)
Lock(ctx context.Context, key string, expiration time.Duration) error Lock(ctx context.Context, key string, expiration time.Duration) error
Unlock(ctx context.Context, key string) error
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
Del(ctx context.Context, key string) error Del(ctx context.Context, key string) error
} }
@ -21,20 +23,21 @@ var _ Client = (*client)(nil)
type client struct { type client struct {
*redis.Client *redis.Client
locker *redislock.Client locker *redislock.Client
lock *redislock.Lock locks map[string]*redislock.Lock
} }
func newClient(c *redis.Client) Client { func newClient(c *redis.Client) Client {
return &client{ return &client{
Client: c, Client: c,
locker: redislock.New(c), locker: redislock.New(c),
locks: map[string]*redislock.Lock{},
} }
} }
func (c *client) Get(ctx context.Context, key string) ([]byte, error) { func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
if c.lock != nil { if c.locks[key] != nil {
for { for {
ttl, err := c.lock.TTL(ctx) ttl, err := c.locks[key].TTL(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -47,27 +50,37 @@ func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
} }
func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error { func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error {
if c.locks[key] != nil {
return fmt.Errorf("locks for key %s already exists", key)
}
lock, err := c.locker.Obtain(ctx, key, expiration, nil) lock, err := c.locker.Obtain(ctx, key, expiration, nil)
if err != nil { if err != nil {
return err return err
} }
c.lock = lock c.locks[key] = lock
return nil return nil
} }
func (c *client) Unlock(ctx context.Context, key string) error {
if c.locks[key] == nil {
return nil
}
return c.locks[key].Release(ctx)
}
func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error { func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
err := c.Client.Set(ctx, key, value, expiration).Err() err := c.Client.Set(ctx, key, value, expiration).Err()
if err != nil { if err != nil {
return err return err
} }
if c.lock == nil { if c.locks[key] == nil {
return nil return nil
} }
err = c.lock.Release(ctx) err = c.locks[key].Release(ctx)
if err != nil { if err != nil {
return err return err
} }
c.lock = nil c.locks = nil
return nil return nil
} }
@ -80,13 +93,14 @@ var _ Client = (*clusterClient)(nil)
type clusterClient struct { type clusterClient struct {
*redis.ClusterClient *redis.ClusterClient
locker *redislock.Client locker *redislock.Client
lock *redislock.Lock locks map[string]*redislock.Lock
} }
func newClusterClient(c *redis.ClusterClient) Client { func newClusterClient(c *redis.ClusterClient) Client {
return &clusterClient{ return &clusterClient{
ClusterClient: c, ClusterClient: c,
locker: redislock.New(c), locker: redislock.New(c),
locks: map[string]*redislock.Lock{},
} }
} }
@ -95,14 +109,24 @@ func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) {
} }
func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error { func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error {
if c.locks[key] != nil {
return fmt.Errorf("locks for key %s already exists", key)
}
lock, err := c.locker.Obtain(ctx, key, expiration, nil) lock, err := c.locker.Obtain(ctx, key, expiration, nil)
if err != nil { if err != nil {
return err return err
} }
c.lock = lock c.locks[key] = lock
return nil return nil
} }
func (c *clusterClient) Unlock(ctx context.Context, key string) error {
if c.locks[key] == nil {
return nil
}
return c.locks[key].Release(ctx)
}
func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error { func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
return c.ClusterClient.Set(ctx, key, value, expiration).Err() return c.ClusterClient.Set(ctx, key, value, expiration).Err()
} }

View File

@ -54,11 +54,25 @@ func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error)
return value, nil return value, nil
} }
// Lock sessions.SessionState information from a persistence // ReleaseLock sessions.SessionState information from a persistence
func (store *SessionStore) Lock(ctx context.Context, key string, expiration time.Duration) error { func (store *SessionStore) LoadWithLock(ctx context.Context, key string) ([]byte, error) {
err := store.Client.Lock(ctx, key, expiration) value, err := store.Client.Get(ctx, key)
if err != nil { if err != nil {
return fmt.Errorf("error setting redis lock: %v", err) return nil, fmt.Errorf("error loading redis session: %v", err)
}
err = store.Client.Lock(ctx, key, 200*time.Millisecond)
if err != nil {
return nil, fmt.Errorf("error setting redis locks: %v", err)
}
return value, nil
}
// ReleaseLock sessions.SessionState information from a persistence
func (store *SessionStore) ReleaseLock(ctx context.Context, key string) error {
err := store.Client.Unlock(ctx, key)
if err != nil {
return fmt.Errorf("error releasing redis locks: %v", err)
} }
return nil return nil
} }

View File

@ -160,8 +160,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil {
return false, nil return false, nil
} }
@ -176,6 +176,11 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.
return true, nil return true, nil
} }
// IsRefreshNeeded checks if the session has expired
func (p *AzureProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "")
}
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
params := url.Values{} params := url.Values{}
params.Add("client_id", p.ClientID) params.Add("client_id", p.ClientID)

View File

@ -244,7 +244,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
expires := time.Now().Add(time.Duration(1) * time.Hour) expires := time.Now().Add(time.Duration(1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) refreshNeeded, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.False(t, refreshNeeded) assert.False(t, refreshNeeded)
} }
@ -258,7 +258,7 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
expires := time.Now().Add(time.Duration(-1) * time.Hour) expires := time.Now().Add(time.Duration(-1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
_, err := p.RefreshSessionIfNeeded(context.Background(), session) _, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil) assert.NotEqual(t, session, nil)
assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, "new_some_access_token", session.AccessToken)

View File

@ -123,7 +123,7 @@ func (p *GitLabProvider) SetProjectScope() {
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil return false, nil
} }
@ -139,6 +139,11 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return true, nil return true, nil
} }
// IsRefreshNeeded checks if the session has expired
func (p *GitLabProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
}
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {

View File

@ -268,8 +268,8 @@ func userInGroup(service *admin.Service, group string, email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { if s == nil {
return false, nil return false, nil
} }
@ -295,6 +295,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return true, nil return true, nil
} }
// IsRefreshNeeded checks if the session has expired
func (p *GoogleProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
}
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()

View File

@ -115,8 +115,8 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new Access Token (and optional ID token) if required // RefreshToken to fetch a new Access Token (and optional ID token) if required
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { if s == nil {
return false, nil return false, nil
} }
@ -129,6 +129,11 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
return true, nil return true, nil
} }
// IsRefreshNeeded checks if the session has expired
func (p *OIDCProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
}
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the // redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
// Access Token and (probably) the ID Token. // Access Token and (probably) the ID Token.
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {

View File

@ -467,7 +467,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
User: "11223344", User: "11223344",
} }
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, "janedoe@example.com", existingSession.Email) assert.Equal(t, "janedoe@example.com", existingSession.Email)
@ -500,7 +500,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
Email: "changeit", Email: "changeit",
User: "changeit", User: "changeit",
} }
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, defaultIDToken.Email, existingSession.Email) assert.Equal(t, defaultIDToken.Email, existingSession.Email)

View File

@ -128,10 +128,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
// RefreshSessionIfNeeded should refresh the user's session if required and // RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required // do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
return false, nil return false, nil
} }
// IsRefreshNeeded should return true if the user's session need to be refreshed
func (p *ProviderData) IsRefreshNeeded(_ *sessions.SessionState) bool {
return false
}
// CreateSessionFromToken converts Bearer IDTokens into sessions // CreateSessionFromToken converts Bearer IDTokens into sessions
func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
if p.Verifier != nil { if p.Verifier != nil {

View File

@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
p := &ProviderData{} p := &ProviderData{}
expires := time.Now().Add(time.Duration(-11) * time.Minute) expires := time.Now().Add(time.Duration(-11) * time.Minute)
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
ExpiresOn: &expires, ExpiresOn: &expires,
}) })
assert.Equal(t, false, refreshed) assert.Equal(t, false, refreshed)

View File

@ -16,7 +16,8 @@ type Provider interface {
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSession(ctx context.Context, s *sessions.SessionState) bool ValidateSession(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
IsRefreshNeeded(s *sessions.SessionState) bool
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
} }