mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-17 21:17:53 +02:00
Split RefreshSessionIfNeeded in two methods and use Redis lock
This commit is contained in:
parent
b942eb1582
commit
69d6fc8a08
1
go.mod
1
go.mod
@ -22,6 +22,7 @@ require (
|
||||
github.com/onsi/gomega v1.10.2
|
||||
github.com/pierrec/lz4 v2.5.2+incompatible
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/prometheus/client_golang v0.9.3
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.6.3
|
||||
github.com/stretchr/testify v1.6.1
|
||||
|
6
go.sum
6
go.sum
@ -24,6 +24,7 @@ github.com/alicebob/miniredis/v2 v2.13.0 h1:QPosMaxm+r6Qs+YcCtL2Z2a2RSdC9VfXJLpd
|
||||
github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k=
|
||||
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
|
||||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
|
||||
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
|
||||
@ -168,6 +169,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP
|
||||
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
|
||||
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo=
|
||||
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s=
|
||||
@ -208,13 +210,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||
github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8=
|
||||
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
|
||||
github.com/prometheus/common v0.4.0 h1:7etb9YClo3a6HjLzfl6rIQaU+FDfi0VSX39io3aQ+DM=
|
||||
github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084 h1:sofwID9zm4tzrgykg80hfFph1mryUeLRsUfoocVVmRY=
|
||||
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
|
||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||
|
@ -276,10 +276,11 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
|
||||
}
|
||||
|
||||
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||
SessionStore: sessionStore,
|
||||
RefreshPeriod: opts.Cookie.Refresh,
|
||||
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
|
||||
ValidateSessionState: opts.GetProvider().ValidateSession,
|
||||
SessionStore: sessionStore,
|
||||
RefreshPeriod: opts.Cookie.Refresh,
|
||||
RefreshSession: opts.GetProvider().RefreshSession,
|
||||
IsRefreshNeeded: opts.GetProvider().IsRefreshNeeded,
|
||||
ValidateSessionState: opts.GetProvider().ValidateSession,
|
||||
}))
|
||||
|
||||
return chain
|
||||
|
@ -2,13 +2,13 @@ package sessions
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionStore is an interface to storing user sessions in the proxy
|
||||
type SessionStore interface {
|
||||
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
|
||||
Load(req *http.Request) (*SessionState, error)
|
||||
Lock(req *http.Request, expiration time.Duration) error
|
||||
LoadWithLock(req *http.Request) (*SessionState, error)
|
||||
ReleaseLock(req *http.Request) error
|
||||
Clear(rw http.ResponseWriter, req *http.Request) error
|
||||
}
|
||||
|
@ -24,10 +24,13 @@ type StoredSessionLoaderOptions struct {
|
||||
RefreshPeriod time.Duration
|
||||
|
||||
// Provider based sesssion refreshing
|
||||
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
|
||||
// Provider based session refresh check
|
||||
IsRefreshNeeded func(*sessionsapi.SessionState) bool
|
||||
|
||||
// Provider based session validation.
|
||||
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
||||
// If the session is older than `RefreshPeriod` but the provider doesn't
|
||||
// refresh it, we must re-validate using this validation.
|
||||
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||
}
|
||||
@ -38,10 +41,11 @@ type StoredSessionLoaderOptions struct {
|
||||
// If a session was loader by a previous handler, it will not be replaced.
|
||||
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
||||
ss := &storedSessionLoader{
|
||||
store: opts.SessionStore,
|
||||
refreshPeriod: opts.RefreshPeriod,
|
||||
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
|
||||
validateSessionState: opts.ValidateSessionState,
|
||||
store: opts.SessionStore,
|
||||
refreshPeriod: opts.RefreshPeriod,
|
||||
refreshSessionWithProvider: opts.RefreshSession,
|
||||
isRefreshNeededWithProvider: opts.IsRefreshNeeded,
|
||||
validateSessionState: opts.ValidateSessionState,
|
||||
}
|
||||
return ss.loadSession
|
||||
}
|
||||
@ -49,10 +53,11 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
|
||||
// storedSessionLoader is responsible for loading sessions from cookie
|
||||
// identified sessions in the session store.
|
||||
type storedSessionLoader struct {
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
refreshSessionWithProvider func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
isRefreshNeededWithProvider func(*sessionsapi.SessionState) bool
|
||||
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||
}
|
||||
|
||||
// loadSession attempts to load a session as identified by the request cookies.
|
||||
@ -98,44 +103,54 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err = s.refreshSessionIfNeeded(rw, req, session)
|
||||
if !s.isSessionRefreshNeeded(session) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
session, err = s.store.LoadWithLock(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if session == nil {
|
||||
// No session was found in the storage, nothing more to do
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !s.isSessionRefreshNeeded(session) {
|
||||
_ = s.store.ReleaseLock(req)
|
||||
return session, nil
|
||||
}
|
||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
||||
refreshed, err := s.refreshSession(rw, req, session)
|
||||
_ = s.store.ReleaseLock(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
|
||||
}
|
||||
|
||||
if refreshed {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Session wasn't refreshed, so make sure it's still valid
|
||||
err = s.validateSession(req.Context(), session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// refreshSessionIfNeeded will attempt to refresh a session if the session
|
||||
// is older than the refresh period.
|
||||
// It is assumed that if the provider refreshes the session, the session is now
|
||||
// valid.
|
||||
// If the session requires refreshing but the provider does not refresh it,
|
||||
// we must validate the session to ensure that the returned session is still
|
||||
// valid.
|
||||
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
|
||||
// Refresh is disabled or the session is not old enough, do nothing
|
||||
return nil
|
||||
// isSessionRefreshNeeded will check if the session need to be refreshed
|
||||
func (s *storedSessionLoader) isSessionRefreshNeeded(session *sessionsapi.SessionState) bool {
|
||||
if s.refreshPeriod > time.Duration(0) && session.Age() >= s.refreshPeriod {
|
||||
return s.isRefreshNeededWithProvider(session)
|
||||
}
|
||||
|
||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
||||
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !refreshed {
|
||||
// Session wasn't refreshed, so make sure it's still valid
|
||||
return s.validateSession(req.Context(), session)
|
||||
}
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
|
||||
// refreshSession attempts to refresh the sessinon with the provider
|
||||
// and will save the session if it was updated.
|
||||
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
||||
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
|
||||
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
||||
refreshed, err := s.refreshSessionWithProvider(req.Context(), session)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error refreshing access token: %v", err)
|
||||
}
|
||||
|
@ -39,6 +39,17 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
}
|
||||
}
|
||||
|
||||
var defaultIsRefreshNeededFunc = func(ss *sessionsapi.SessionState) bool {
|
||||
switch ss.RefreshToken {
|
||||
case refresh:
|
||||
return true
|
||||
case noRefresh:
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||
return ss.AccessToken != "Invalid"
|
||||
}
|
||||
@ -86,13 +97,14 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
}
|
||||
|
||||
type storedSessionLoaderTableInput struct {
|
||||
requestHeaders http.Header
|
||||
existingSession *sessionsapi.SessionState
|
||||
expectedSession *sessionsapi.SessionState
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
validateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||
requestHeaders http.Header
|
||||
existingSession *sessionsapi.SessionState
|
||||
expectedSession *sessionsapi.SessionState
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
isRefreshSessionNeeded func(*sessionsapi.SessionState) bool
|
||||
validateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||
}
|
||||
|
||||
DescribeTable("when serving a request",
|
||||
@ -109,10 +121,11 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
opts := &StoredSessionLoaderOptions{
|
||||
SessionStore: in.store,
|
||||
RefreshPeriod: in.refreshPeriod,
|
||||
RefreshSessionIfNeeded: in.refreshSession,
|
||||
ValidateSessionState: in.validateSession,
|
||||
SessionStore: in.store,
|
||||
RefreshPeriod: in.refreshPeriod,
|
||||
RefreshSession: in.refreshSession,
|
||||
IsRefreshNeeded: in.isRefreshSessionNeeded,
|
||||
ValidateSessionState: in.validateSession,
|
||||
}
|
||||
|
||||
// Create the handler with a next handler that will capture the session
|
||||
@ -126,24 +139,26 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
Expect(gotSession).To(Equal(in.expectedSession))
|
||||
},
|
||||
Entry("with no cookie", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{},
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
requestHeaders: http.Header{},
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("with an invalid cookie", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
"Cookie": []string{"_oauth2_proxy=NonExistent"},
|
||||
},
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("with an existing session", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
@ -155,10 +170,11 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
expectedSession: &sessionsapi.SessionState{
|
||||
RefreshToken: "Existing",
|
||||
},
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("with a session that has not expired", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
@ -170,21 +186,23 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
},
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
|
||||
},
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
@ -196,10 +214,11 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
},
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 10 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 10 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
@ -211,37 +230,40 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
},
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
||||
},
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{
|
||||
"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
|
||||
},
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
existingSession: nil,
|
||||
expectedSession: nil,
|
||||
store: defaultSessionStore,
|
||||
refreshPeriod: 1 * time.Minute,
|
||||
refreshSession: defaultRefreshFunc,
|
||||
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||
validateSession: defaultValidateFunc,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
Context("refreshSessionIfNeeded", func() {
|
||||
Context("isSessionRefreshNeeded", func() {
|
||||
type refreshSessionIfNeededTableInput struct {
|
||||
refreshPeriod time.Duration
|
||||
session *sessionsapi.SessionState
|
||||
@ -261,7 +283,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
s := &storedSessionLoader{
|
||||
refreshPeriod: in.refreshPeriod,
|
||||
store: &fakeSessionStore{},
|
||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
refreshed = true
|
||||
switch ss.RefreshToken {
|
||||
case refresh:
|
||||
@ -272,6 +294,17 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
return false, errors.New("error refreshing session")
|
||||
}
|
||||
},
|
||||
isRefreshNeededWithProvider: func(ss *sessionsapi.SessionState) bool {
|
||||
refreshed = true
|
||||
switch ss.RefreshToken {
|
||||
case refresh:
|
||||
return true
|
||||
case noRefresh:
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
},
|
||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||
validated = true
|
||||
return ss.AccessToken != "Invalid"
|
||||
@ -279,7 +312,10 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
err := s.refreshSessionIfNeeded(nil, req, in.session)
|
||||
var err error
|
||||
if s.isSessionRefreshNeeded(in.session) {
|
||||
refreshed, err = s.refreshSession(nil, req, in.session)
|
||||
}
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(in.expectedErr))
|
||||
} else {
|
||||
@ -364,7 +400,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
)
|
||||
})
|
||||
|
||||
Context("refreshSessionWithProvider", func() {
|
||||
Context("refreshSession", func() {
|
||||
type refreshSessionWithProviderTableInput struct {
|
||||
session *sessionsapi.SessionState
|
||||
expectedErr error
|
||||
@ -388,7 +424,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
switch ss.RefreshToken {
|
||||
case refresh:
|
||||
return true, nil
|
||||
@ -401,7 +437,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session)
|
||||
refreshed, err := s.refreshSession(nil, req, in.session)
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(in.expectedErr))
|
||||
} else {
|
||||
@ -515,6 +551,17 @@ func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, e
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionStore) LoadWithLock(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
if f.LoadFunc != nil {
|
||||
return f.LoadFunc(req)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionStore) ReleaseLock(req *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||
if f.ClearFunc != nil {
|
||||
return f.ClearFunc(rw, req)
|
||||
|
@ -66,7 +66,14 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *SessionStore) Lock(req *http.Request, expirationTime time.Duration) error {
|
||||
// Load reads sessions.SessionState information from Cookies within the
|
||||
// HTTP request object
|
||||
func (s *SessionStore) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
|
||||
return s.Load(req)
|
||||
}
|
||||
|
||||
// Release the session lock
|
||||
func (s *SessionStore) ReleaseLock(req *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
type Store interface {
|
||||
Save(context.Context, string, []byte, time.Duration) error
|
||||
Load(context.Context, string) ([]byte, error)
|
||||
Lock(context.Context, string, time.Duration) error
|
||||
LoadWithLock(context.Context, string) ([]byte, error)
|
||||
ReleaseLock(context.Context, string) error
|
||||
Clear(context.Context, string) error
|
||||
}
|
||||
|
@ -65,16 +65,28 @@ func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||
})
|
||||
}
|
||||
|
||||
// Lock reads sessions.SessionState in a session store. It will
|
||||
// Load reads sessions.SessionState information from a session store. It will
|
||||
// use the session ticket from the http.Request's cookie.
|
||||
func (m *Manager) Lock(req *http.Request, expiration time.Duration) error {
|
||||
func (m *Manager) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
|
||||
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tckt.loadSession(func(key string) ([]byte, error) {
|
||||
return m.Store.LoadWithLock(req.Context(), key)
|
||||
})
|
||||
}
|
||||
|
||||
// Release the session lock
|
||||
func (m *Manager) ReleaseLock(req *http.Request) error {
|
||||
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tckt.lockSession(func(key string) error {
|
||||
return m.Store.Lock(req.Context(), key, expiration)
|
||||
return tckt.releaseSession(func(key string) error {
|
||||
return m.Store.ReleaseLock(req.Context(), key)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -139,13 +139,11 @@ func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) {
|
||||
return sessions.DecodeSessionState(ciphertext, c, false)
|
||||
}
|
||||
|
||||
// lockSession loads a session from the disk store via the passed loadFunc
|
||||
// using the ticket.id as the key. It then decodes the SessionState using
|
||||
// ticket.secret to make the AES-GCM cipher.
|
||||
func (t *ticket) lockSession(loader lockFunc) error {
|
||||
// releaseSession releases a potential locked session
|
||||
func (t *ticket) releaseSession(loader lockFunc) error {
|
||||
err := loader(t.id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lock the session state with the ticket: %v", err)
|
||||
return fmt.Errorf("failed to release session state with the ticket: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bsm/redislock"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
type Client interface {
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Lock(ctx context.Context, key string, expiration time.Duration) error
|
||||
Unlock(ctx context.Context, key string) error
|
||||
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
|
||||
Del(ctx context.Context, key string) error
|
||||
}
|
||||
@ -21,20 +23,21 @@ var _ Client = (*client)(nil)
|
||||
type client struct {
|
||||
*redis.Client
|
||||
locker *redislock.Client
|
||||
lock *redislock.Lock
|
||||
locks map[string]*redislock.Lock
|
||||
}
|
||||
|
||||
func newClient(c *redis.Client) Client {
|
||||
return &client{
|
||||
Client: c,
|
||||
locker: redislock.New(c),
|
||||
locks: map[string]*redislock.Lock{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
if c.lock != nil {
|
||||
if c.locks[key] != nil {
|
||||
for {
|
||||
ttl, err := c.lock.TTL(ctx)
|
||||
ttl, err := c.locks[key].TTL(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -47,27 +50,37 @@ func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
||||
if c.locks[key] != nil {
|
||||
return fmt.Errorf("locks for key %s already exists", key)
|
||||
}
|
||||
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.lock = lock
|
||||
c.locks[key] = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) Unlock(ctx context.Context, key string) error {
|
||||
if c.locks[key] == nil {
|
||||
return nil
|
||||
}
|
||||
return c.locks[key].Release(ctx)
|
||||
}
|
||||
|
||||
func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
|
||||
err := c.Client.Set(ctx, key, value, expiration).Err()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.lock == nil {
|
||||
if c.locks[key] == nil {
|
||||
return nil
|
||||
}
|
||||
err = c.lock.Release(ctx)
|
||||
err = c.locks[key].Release(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.lock = nil
|
||||
c.locks = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -80,13 +93,14 @@ var _ Client = (*clusterClient)(nil)
|
||||
type clusterClient struct {
|
||||
*redis.ClusterClient
|
||||
locker *redislock.Client
|
||||
lock *redislock.Lock
|
||||
locks map[string]*redislock.Lock
|
||||
}
|
||||
|
||||
func newClusterClient(c *redis.ClusterClient) Client {
|
||||
return &clusterClient{
|
||||
ClusterClient: c,
|
||||
locker: redislock.New(c),
|
||||
locks: map[string]*redislock.Lock{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,14 +109,24 @@ func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
||||
if c.locks[key] != nil {
|
||||
return fmt.Errorf("locks for key %s already exists", key)
|
||||
}
|
||||
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.lock = lock
|
||||
c.locks[key] = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clusterClient) Unlock(ctx context.Context, key string) error {
|
||||
if c.locks[key] == nil {
|
||||
return nil
|
||||
}
|
||||
return c.locks[key].Release(ctx)
|
||||
}
|
||||
|
||||
func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
|
||||
return c.ClusterClient.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
@ -54,11 +54,25 @@ func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error)
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Lock sessions.SessionState information from a persistence
|
||||
func (store *SessionStore) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
||||
err := store.Client.Lock(ctx, key, expiration)
|
||||
// ReleaseLock sessions.SessionState information from a persistence
|
||||
func (store *SessionStore) LoadWithLock(ctx context.Context, key string) ([]byte, error) {
|
||||
value, err := store.Client.Get(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting redis lock: %v", err)
|
||||
return nil, fmt.Errorf("error loading redis session: %v", err)
|
||||
}
|
||||
|
||||
err = store.Client.Lock(ctx, key, 200*time.Millisecond)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error setting redis locks: %v", err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ReleaseLock sessions.SessionState information from a persistence
|
||||
func (store *SessionStore) ReleaseLock(ctx context.Context, key string) error {
|
||||
err := store.Client.Unlock(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error releasing redis locks: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -160,8 +160,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -176,6 +176,11 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IsRefreshNeeded checks if the session has expired
|
||||
func (p *AzureProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||
return !(s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "")
|
||||
}
|
||||
|
||||
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
|
@ -244,7 +244,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
||||
|
||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||
refreshNeeded, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.False(t, refreshNeeded)
|
||||
}
|
||||
@ -258,7 +258,7 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
||||
|
||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
_, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||
_, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, session, nil)
|
||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||
|
@ -123,7 +123,7 @@ func (p *GitLabProvider) SetProjectScope() {
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
@ -139,6 +139,11 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IsRefreshNeeded checks if the session has expired
|
||||
func (p *GitLabProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
|
||||
}
|
||||
|
||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
|
@ -268,8 +268,8 @@ func userInGroup(service *admin.Service, group string, email string) bool {
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -295,6 +295,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IsRefreshNeeded checks if the session has expired
|
||||
func (p *GoogleProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
|
@ -115,8 +115,8 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -129,6 +129,11 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IsRefreshNeeded checks if the session has expired
|
||||
func (p *OIDCProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
|
||||
}
|
||||
|
||||
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
|
||||
// Access Token and (probably) the ID Token.
|
||||
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
|
@ -467,7 +467,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||
User: "11223344",
|
||||
}
|
||||
|
||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
||||
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, refreshed, true)
|
||||
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
||||
@ -500,7 +500,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
||||
Email: "changeit",
|
||||
User: "changeit",
|
||||
}
|
||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
||||
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, refreshed, true)
|
||||
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
||||
|
@ -128,10 +128,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
|
||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||
// do nothing if a refresh is not required
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
||||
func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// IsRefreshNeeded should return true if the user's session need to be refreshed
|
||||
func (p *ProviderData) IsRefreshNeeded(_ *sessions.SessionState) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||
func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
||||
if p.Verifier != nil {
|
||||
|
@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
|
||||
p := &ProviderData{}
|
||||
|
||||
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
||||
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
|
||||
refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
|
||||
ExpiresOn: &expires,
|
||||
})
|
||||
assert.Equal(t, false, refreshed)
|
||||
|
@ -16,7 +16,8 @@ type Provider interface {
|
||||
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
||||
GetLoginURL(redirectURI, finalRedirect string) string
|
||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
IsRefreshNeeded(s *sessions.SessionState) bool
|
||||
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user