You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-07-15 01:44:22 +02:00
Split RefreshSessionIfNeeded in two methods and use Redis lock
This commit is contained in:
1
go.mod
1
go.mod
@ -22,6 +22,7 @@ require (
|
|||||||
github.com/onsi/gomega v1.10.2
|
github.com/onsi/gomega v1.10.2
|
||||||
github.com/pierrec/lz4 v2.5.2+incompatible
|
github.com/pierrec/lz4 v2.5.2+incompatible
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||||
|
github.com/prometheus/client_golang v0.9.3
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/spf13/viper v1.6.3
|
github.com/spf13/viper v1.6.3
|
||||||
github.com/stretchr/testify v1.6.1
|
github.com/stretchr/testify v1.6.1
|
||||||
|
6
go.sum
6
go.sum
@ -24,6 +24,7 @@ github.com/alicebob/miniredis/v2 v2.13.0 h1:QPosMaxm+r6Qs+YcCtL2Z2a2RSdC9VfXJLpd
|
|||||||
github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k=
|
github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k=
|
||||||
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
||||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||||
|
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
|
||||||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||||
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
|
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
|
||||||
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
|
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
|
||||||
@ -168,6 +169,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP
|
|||||||
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||||
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
|
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
|
||||||
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
|
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
|
||||||
|
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo=
|
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo=
|
||||||
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s=
|
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s=
|
||||||
@ -208,13 +210,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
|||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||||
|
github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8=
|
||||||
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
|
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
|
||||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
|
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=
|
||||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
|
github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
|
||||||
|
github.com/prometheus/common v0.4.0 h1:7etb9YClo3a6HjLzfl6rIQaU+FDfi0VSX39io3aQ+DM=
|
||||||
github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||||
|
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084 h1:sofwID9zm4tzrgykg80hfFph1mryUeLRsUfoocVVmRY=
|
||||||
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
|
||||||
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
|
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
|
||||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||||
|
@ -276,10 +276,11 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||||
SessionStore: sessionStore,
|
SessionStore: sessionStore,
|
||||||
RefreshPeriod: opts.Cookie.Refresh,
|
RefreshPeriod: opts.Cookie.Refresh,
|
||||||
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
|
RefreshSession: opts.GetProvider().RefreshSession,
|
||||||
ValidateSessionState: opts.GetProvider().ValidateSession,
|
IsRefreshNeeded: opts.GetProvider().IsRefreshNeeded,
|
||||||
|
ValidateSessionState: opts.GetProvider().ValidateSession,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
return chain
|
return chain
|
||||||
|
@ -2,13 +2,13 @@ package sessions
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionStore is an interface to storing user sessions in the proxy
|
// SessionStore is an interface to storing user sessions in the proxy
|
||||||
type SessionStore interface {
|
type SessionStore interface {
|
||||||
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
|
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
|
||||||
Load(req *http.Request) (*SessionState, error)
|
Load(req *http.Request) (*SessionState, error)
|
||||||
Lock(req *http.Request, expiration time.Duration) error
|
LoadWithLock(req *http.Request) (*SessionState, error)
|
||||||
|
ReleaseLock(req *http.Request) error
|
||||||
Clear(rw http.ResponseWriter, req *http.Request) error
|
Clear(rw http.ResponseWriter, req *http.Request) error
|
||||||
}
|
}
|
||||||
|
@ -24,10 +24,13 @@ type StoredSessionLoaderOptions struct {
|
|||||||
RefreshPeriod time.Duration
|
RefreshPeriod time.Duration
|
||||||
|
|
||||||
// Provider based sesssion refreshing
|
// Provider based sesssion refreshing
|
||||||
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||||
|
|
||||||
|
// Provider based session refresh check
|
||||||
|
IsRefreshNeeded func(*sessionsapi.SessionState) bool
|
||||||
|
|
||||||
// Provider based session validation.
|
// Provider based session validation.
|
||||||
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
// If the session is older than `RefreshPeriod` but the provider doesn't
|
||||||
// refresh it, we must re-validate using this validation.
|
// refresh it, we must re-validate using this validation.
|
||||||
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||||
}
|
}
|
||||||
@ -38,10 +41,11 @@ type StoredSessionLoaderOptions struct {
|
|||||||
// If a session was loader by a previous handler, it will not be replaced.
|
// If a session was loader by a previous handler, it will not be replaced.
|
||||||
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
||||||
ss := &storedSessionLoader{
|
ss := &storedSessionLoader{
|
||||||
store: opts.SessionStore,
|
store: opts.SessionStore,
|
||||||
refreshPeriod: opts.RefreshPeriod,
|
refreshPeriod: opts.RefreshPeriod,
|
||||||
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
|
refreshSessionWithProvider: opts.RefreshSession,
|
||||||
validateSessionState: opts.ValidateSessionState,
|
isRefreshNeededWithProvider: opts.IsRefreshNeeded,
|
||||||
|
validateSessionState: opts.ValidateSessionState,
|
||||||
}
|
}
|
||||||
return ss.loadSession
|
return ss.loadSession
|
||||||
}
|
}
|
||||||
@ -49,10 +53,11 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
|
|||||||
// storedSessionLoader is responsible for loading sessions from cookie
|
// storedSessionLoader is responsible for loading sessions from cookie
|
||||||
// identified sessions in the session store.
|
// identified sessions in the session store.
|
||||||
type storedSessionLoader struct {
|
type storedSessionLoader struct {
|
||||||
store sessionsapi.SessionStore
|
store sessionsapi.SessionStore
|
||||||
refreshPeriod time.Duration
|
refreshPeriod time.Duration
|
||||||
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
refreshSessionWithProvider func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||||
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
isRefreshNeededWithProvider func(*sessionsapi.SessionState) bool
|
||||||
|
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadSession attempts to load a session as identified by the request cookies.
|
// loadSession attempts to load a session as identified by the request cookies.
|
||||||
@ -98,44 +103,54 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.refreshSessionIfNeeded(rw, req, session)
|
if !s.isSessionRefreshNeeded(session) {
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err = s.store.LoadWithLock(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
// No session was found in the storage, nothing more to do
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.isSessionRefreshNeeded(session) {
|
||||||
|
_ = s.store.ReleaseLock(req)
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
||||||
|
refreshed, err := s.refreshSession(rw, req, session)
|
||||||
|
_ = s.store.ReleaseLock(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
|
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if refreshed {
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session wasn't refreshed, so make sure it's still valid
|
||||||
|
err = s.validateSession(req.Context(), session)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshSessionIfNeeded will attempt to refresh a session if the session
|
// isSessionRefreshNeeded will check if the session need to be refreshed
|
||||||
// is older than the refresh period.
|
func (s *storedSessionLoader) isSessionRefreshNeeded(session *sessionsapi.SessionState) bool {
|
||||||
// It is assumed that if the provider refreshes the session, the session is now
|
if s.refreshPeriod > time.Duration(0) && session.Age() >= s.refreshPeriod {
|
||||||
// valid.
|
return s.isRefreshNeededWithProvider(session)
|
||||||
// If the session requires refreshing but the provider does not refresh it,
|
|
||||||
// we must validate the session to ensure that the returned session is still
|
|
||||||
// valid.
|
|
||||||
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
|
||||||
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
|
|
||||||
// Refresh is disabled or the session is not old enough, do nothing
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
|
||||||
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !refreshed {
|
|
||||||
// Session wasn't refreshed, so make sure it's still valid
|
|
||||||
return s.validateSession(req.Context(), session)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
|
// refreshSession attempts to refresh the sessinon with the provider
|
||||||
// and will save the session if it was updated.
|
// and will save the session if it was updated.
|
||||||
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
||||||
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
|
refreshed, err := s.refreshSessionWithProvider(req.Context(), session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("error refreshing access token: %v", err)
|
return false, fmt.Errorf("error refreshing access token: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,17 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultIsRefreshNeededFunc = func(ss *sessionsapi.SessionState) bool {
|
||||||
|
switch ss.RefreshToken {
|
||||||
|
case refresh:
|
||||||
|
return true
|
||||||
|
case noRefresh:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||||
return ss.AccessToken != "Invalid"
|
return ss.AccessToken != "Invalid"
|
||||||
}
|
}
|
||||||
@ -86,13 +97,14 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type storedSessionLoaderTableInput struct {
|
type storedSessionLoaderTableInput struct {
|
||||||
requestHeaders http.Header
|
requestHeaders http.Header
|
||||||
existingSession *sessionsapi.SessionState
|
existingSession *sessionsapi.SessionState
|
||||||
expectedSession *sessionsapi.SessionState
|
expectedSession *sessionsapi.SessionState
|
||||||
store sessionsapi.SessionStore
|
store sessionsapi.SessionStore
|
||||||
refreshPeriod time.Duration
|
refreshPeriod time.Duration
|
||||||
refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||||
validateSession func(context.Context, *sessionsapi.SessionState) bool
|
isRefreshSessionNeeded func(*sessionsapi.SessionState) bool
|
||||||
|
validateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
DescribeTable("when serving a request",
|
DescribeTable("when serving a request",
|
||||||
@ -109,10 +121,11 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
opts := &StoredSessionLoaderOptions{
|
opts := &StoredSessionLoaderOptions{
|
||||||
SessionStore: in.store,
|
SessionStore: in.store,
|
||||||
RefreshPeriod: in.refreshPeriod,
|
RefreshPeriod: in.refreshPeriod,
|
||||||
RefreshSessionIfNeeded: in.refreshSession,
|
RefreshSession: in.refreshSession,
|
||||||
ValidateSessionState: in.validateSession,
|
IsRefreshNeeded: in.isRefreshSessionNeeded,
|
||||||
|
ValidateSessionState: in.validateSession,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the handler with a next handler that will capture the session
|
// Create the handler with a next handler that will capture the session
|
||||||
@ -126,24 +139,26 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
Expect(gotSession).To(Equal(in.expectedSession))
|
Expect(gotSession).To(Equal(in.expectedSession))
|
||||||
},
|
},
|
||||||
Entry("with no cookie", storedSessionLoaderTableInput{
|
Entry("with no cookie", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{},
|
requestHeaders: http.Header{},
|
||||||
existingSession: nil,
|
existingSession: nil,
|
||||||
expectedSession: nil,
|
expectedSession: nil,
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("with an invalid cookie", storedSessionLoaderTableInput{
|
Entry("with an invalid cookie", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
"Cookie": []string{"_oauth2_proxy=NonExistent"},
|
"Cookie": []string{"_oauth2_proxy=NonExistent"},
|
||||||
},
|
},
|
||||||
existingSession: nil,
|
existingSession: nil,
|
||||||
expectedSession: nil,
|
expectedSession: nil,
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("with an existing session", storedSessionLoaderTableInput{
|
Entry("with an existing session", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
@ -155,10 +170,11 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
expectedSession: &sessionsapi.SessionState{
|
expectedSession: &sessionsapi.SessionState{
|
||||||
RefreshToken: "Existing",
|
RefreshToken: "Existing",
|
||||||
},
|
},
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("with a session that has not expired", storedSessionLoaderTableInput{
|
Entry("with a session that has not expired", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
@ -170,21 +186,23 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
},
|
},
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
|
Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
|
"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
|
||||||
},
|
},
|
||||||
existingSession: nil,
|
existingSession: nil,
|
||||||
expectedSession: nil,
|
expectedSession: nil,
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
|
Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
@ -196,10 +214,11 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
},
|
},
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 10 * time.Minute,
|
refreshPeriod: 10 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
|
Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
@ -211,37 +230,40 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
},
|
},
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
|
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
||||||
},
|
},
|
||||||
existingSession: nil,
|
existingSession: nil,
|
||||||
expectedSession: nil,
|
expectedSession: nil,
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
|
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
|
"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
|
||||||
},
|
},
|
||||||
existingSession: nil,
|
existingSession: nil,
|
||||||
expectedSession: nil,
|
expectedSession: nil,
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("refreshSessionIfNeeded", func() {
|
Context("isSessionRefreshNeeded", func() {
|
||||||
type refreshSessionIfNeededTableInput struct {
|
type refreshSessionIfNeededTableInput struct {
|
||||||
refreshPeriod time.Duration
|
refreshPeriod time.Duration
|
||||||
session *sessionsapi.SessionState
|
session *sessionsapi.SessionState
|
||||||
@ -261,7 +283,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
s := &storedSessionLoader{
|
s := &storedSessionLoader{
|
||||||
refreshPeriod: in.refreshPeriod,
|
refreshPeriod: in.refreshPeriod,
|
||||||
store: &fakeSessionStore{},
|
store: &fakeSessionStore{},
|
||||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
refreshed = true
|
refreshed = true
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
case refresh:
|
case refresh:
|
||||||
@ -272,6 +294,17 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
return false, errors.New("error refreshing session")
|
return false, errors.New("error refreshing session")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
isRefreshNeededWithProvider: func(ss *sessionsapi.SessionState) bool {
|
||||||
|
refreshed = true
|
||||||
|
switch ss.RefreshToken {
|
||||||
|
case refresh:
|
||||||
|
return true
|
||||||
|
case noRefresh:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
},
|
||||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||||
validated = true
|
validated = true
|
||||||
return ss.AccessToken != "Invalid"
|
return ss.AccessToken != "Invalid"
|
||||||
@ -279,7 +312,10 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
err := s.refreshSessionIfNeeded(nil, req, in.session)
|
var err error
|
||||||
|
if s.isSessionRefreshNeeded(in.session) {
|
||||||
|
refreshed, err = s.refreshSession(nil, req, in.session)
|
||||||
|
}
|
||||||
if in.expectedErr != nil {
|
if in.expectedErr != nil {
|
||||||
Expect(err).To(MatchError(in.expectedErr))
|
Expect(err).To(MatchError(in.expectedErr))
|
||||||
} else {
|
} else {
|
||||||
@ -364,7 +400,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("refreshSessionWithProvider", func() {
|
Context("refreshSession", func() {
|
||||||
type refreshSessionWithProviderTableInput struct {
|
type refreshSessionWithProviderTableInput struct {
|
||||||
session *sessionsapi.SessionState
|
session *sessionsapi.SessionState
|
||||||
expectedErr error
|
expectedErr error
|
||||||
@ -388,7 +424,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
case refresh:
|
case refresh:
|
||||||
return true, nil
|
return true, nil
|
||||||
@ -401,7 +437,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session)
|
refreshed, err := s.refreshSession(nil, req, in.session)
|
||||||
if in.expectedErr != nil {
|
if in.expectedErr != nil {
|
||||||
Expect(err).To(MatchError(in.expectedErr))
|
Expect(err).To(MatchError(in.expectedErr))
|
||||||
} else {
|
} else {
|
||||||
@ -515,6 +551,17 @@ func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, e
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fakeSessionStore) LoadWithLock(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
|
if f.LoadFunc != nil {
|
||||||
|
return f.LoadFunc(req)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSessionStore) ReleaseLock(req *http.Request) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||||
if f.ClearFunc != nil {
|
if f.ClearFunc != nil {
|
||||||
return f.ClearFunc(rw, req)
|
return f.ClearFunc(rw, req)
|
||||||
|
@ -66,7 +66,14 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionStore) Lock(req *http.Request, expirationTime time.Duration) error {
|
// Load reads sessions.SessionState information from Cookies within the
|
||||||
|
// HTTP request object
|
||||||
|
func (s *SessionStore) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
|
||||||
|
return s.Load(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the session lock
|
||||||
|
func (s *SessionStore) ReleaseLock(req *http.Request) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
type Store interface {
|
type Store interface {
|
||||||
Save(context.Context, string, []byte, time.Duration) error
|
Save(context.Context, string, []byte, time.Duration) error
|
||||||
Load(context.Context, string) ([]byte, error)
|
Load(context.Context, string) ([]byte, error)
|
||||||
Lock(context.Context, string, time.Duration) error
|
LoadWithLock(context.Context, string) ([]byte, error)
|
||||||
|
ReleaseLock(context.Context, string) error
|
||||||
Clear(context.Context, string) error
|
Clear(context.Context, string) error
|
||||||
}
|
}
|
||||||
|
@ -65,16 +65,28 @@ func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lock reads sessions.SessionState in a session store. It will
|
// Load reads sessions.SessionState information from a session store. It will
|
||||||
// use the session ticket from the http.Request's cookie.
|
// use the session ticket from the http.Request's cookie.
|
||||||
func (m *Manager) Lock(req *http.Request, expiration time.Duration) error {
|
func (m *Manager) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
|
||||||
|
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tckt.loadSession(func(key string) ([]byte, error) {
|
||||||
|
return m.Store.LoadWithLock(req.Context(), key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the session lock
|
||||||
|
func (m *Manager) ReleaseLock(req *http.Request) error {
|
||||||
tckt, err := decodeTicketFromRequest(req, m.Options)
|
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return tckt.lockSession(func(key string) error {
|
return tckt.releaseSession(func(key string) error {
|
||||||
return m.Store.Lock(req.Context(), key, expiration)
|
return m.Store.ReleaseLock(req.Context(), key)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,13 +139,11 @@ func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) {
|
|||||||
return sessions.DecodeSessionState(ciphertext, c, false)
|
return sessions.DecodeSessionState(ciphertext, c, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// lockSession loads a session from the disk store via the passed loadFunc
|
// releaseSession releases a potential locked session
|
||||||
// using the ticket.id as the key. It then decodes the SessionState using
|
func (t *ticket) releaseSession(loader lockFunc) error {
|
||||||
// ticket.secret to make the AES-GCM cipher.
|
|
||||||
func (t *ticket) lockSession(loader lockFunc) error {
|
|
||||||
err := loader(t.id)
|
err := loader(t.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to lock the session state with the ticket: %v", err)
|
return fmt.Errorf("failed to release session state with the ticket: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package redis
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bsm/redislock"
|
"github.com/bsm/redislock"
|
||||||
@ -12,6 +13,7 @@ import (
|
|||||||
type Client interface {
|
type Client interface {
|
||||||
Get(ctx context.Context, key string) ([]byte, error)
|
Get(ctx context.Context, key string) ([]byte, error)
|
||||||
Lock(ctx context.Context, key string, expiration time.Duration) error
|
Lock(ctx context.Context, key string, expiration time.Duration) error
|
||||||
|
Unlock(ctx context.Context, key string) error
|
||||||
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
|
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
|
||||||
Del(ctx context.Context, key string) error
|
Del(ctx context.Context, key string) error
|
||||||
}
|
}
|
||||||
@ -21,20 +23,21 @@ var _ Client = (*client)(nil)
|
|||||||
type client struct {
|
type client struct {
|
||||||
*redis.Client
|
*redis.Client
|
||||||
locker *redislock.Client
|
locker *redislock.Client
|
||||||
lock *redislock.Lock
|
locks map[string]*redislock.Lock
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(c *redis.Client) Client {
|
func newClient(c *redis.Client) Client {
|
||||||
return &client{
|
return &client{
|
||||||
Client: c,
|
Client: c,
|
||||||
locker: redislock.New(c),
|
locker: redislock.New(c),
|
||||||
|
locks: map[string]*redislock.Lock{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
|
func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
|
||||||
if c.lock != nil {
|
if c.locks[key] != nil {
|
||||||
for {
|
for {
|
||||||
ttl, err := c.lock.TTL(ctx)
|
ttl, err := c.locks[key].TTL(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -47,27 +50,37 @@ func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
||||||
|
if c.locks[key] != nil {
|
||||||
|
return fmt.Errorf("locks for key %s already exists", key)
|
||||||
|
}
|
||||||
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
|
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.lock = lock
|
c.locks[key] = lock
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *client) Unlock(ctx context.Context, key string) error {
|
||||||
|
if c.locks[key] == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.locks[key].Release(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
|
func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
|
||||||
err := c.Client.Set(ctx, key, value, expiration).Err()
|
err := c.Client.Set(ctx, key, value, expiration).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.lock == nil {
|
if c.locks[key] == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.lock.Release(ctx)
|
err = c.locks[key].Release(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.lock = nil
|
c.locks = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,13 +93,14 @@ var _ Client = (*clusterClient)(nil)
|
|||||||
type clusterClient struct {
|
type clusterClient struct {
|
||||||
*redis.ClusterClient
|
*redis.ClusterClient
|
||||||
locker *redislock.Client
|
locker *redislock.Client
|
||||||
lock *redislock.Lock
|
locks map[string]*redislock.Lock
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClusterClient(c *redis.ClusterClient) Client {
|
func newClusterClient(c *redis.ClusterClient) Client {
|
||||||
return &clusterClient{
|
return &clusterClient{
|
||||||
ClusterClient: c,
|
ClusterClient: c,
|
||||||
locker: redislock.New(c),
|
locker: redislock.New(c),
|
||||||
|
locks: map[string]*redislock.Lock{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,14 +109,24 @@ func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
||||||
|
if c.locks[key] != nil {
|
||||||
|
return fmt.Errorf("locks for key %s already exists", key)
|
||||||
|
}
|
||||||
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
|
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.lock = lock
|
c.locks[key] = lock
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *clusterClient) Unlock(ctx context.Context, key string) error {
|
||||||
|
if c.locks[key] == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.locks[key].Release(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
|
func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
|
||||||
return c.ClusterClient.Set(ctx, key, value, expiration).Err()
|
return c.ClusterClient.Set(ctx, key, value, expiration).Err()
|
||||||
}
|
}
|
||||||
|
@ -54,11 +54,25 @@ func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error)
|
|||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lock sessions.SessionState information from a persistence
|
// ReleaseLock sessions.SessionState information from a persistence
|
||||||
func (store *SessionStore) Lock(ctx context.Context, key string, expiration time.Duration) error {
|
func (store *SessionStore) LoadWithLock(ctx context.Context, key string) ([]byte, error) {
|
||||||
err := store.Client.Lock(ctx, key, expiration)
|
value, err := store.Client.Get(ctx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error setting redis lock: %v", err)
|
return nil, fmt.Errorf("error loading redis session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.Client.Lock(ctx, key, 200*time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error setting redis locks: %v", err)
|
||||||
|
}
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReleaseLock sessions.SessionState information from a persistence
|
||||||
|
func (store *SessionStore) ReleaseLock(ctx context.Context, key string) error {
|
||||||
|
err := store.Client.Unlock(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error releasing redis locks: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -160,8 +160,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
if s == nil {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,6 +176,11 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRefreshNeeded checks if the session has expired
|
||||||
|
func (p *AzureProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||||
|
return !(s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "")
|
||||||
|
}
|
||||||
|
|
||||||
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("client_id", p.ClientID)
|
params.Add("client_id", p.ClientID)
|
||||||
|
@ -244,7 +244,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
|||||||
|
|
||||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
refreshNeeded, err := p.RefreshSession(context.Background(), session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.False(t, refreshNeeded)
|
assert.False(t, refreshNeeded)
|
||||||
}
|
}
|
||||||
@ -258,7 +258,7 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
|||||||
|
|
||||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||||
_, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
_, err := p.RefreshSession(context.Background(), session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, session, nil)
|
assert.NotEqual(t, session, nil)
|
||||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||||
|
@ -123,7 +123,7 @@ func (p *GitLabProvider) SetProjectScope() {
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@ -139,6 +139,11 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRefreshNeeded checks if the session has expired
|
||||||
|
func (p *GitLabProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||||
|
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
|
||||||
|
}
|
||||||
|
|
||||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -268,8 +268,8 @@ func userInGroup(service *admin.Service, group string, email string) bool {
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
if s == nil {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,6 +295,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRefreshNeeded checks if the session has expired
|
||||||
|
func (p *GoogleProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||||
|
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
|
||||||
|
}
|
||||||
|
|
||||||
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
||||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
|
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
|
@ -115,8 +115,8 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
||||||
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
if s == nil {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,6 +129,11 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRefreshNeeded checks if the session has expired
|
||||||
|
func (p *OIDCProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
|
||||||
|
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
|
||||||
|
}
|
||||||
|
|
||||||
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
|
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
|
||||||
// Access Token and (probably) the ID Token.
|
// Access Token and (probably) the ID Token.
|
||||||
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
@ -467,7 +467,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
|||||||
User: "11223344",
|
User: "11223344",
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, refreshed, true)
|
assert.Equal(t, refreshed, true)
|
||||||
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
||||||
@ -500,7 +500,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
|||||||
Email: "changeit",
|
Email: "changeit",
|
||||||
User: "changeit",
|
User: "changeit",
|
||||||
}
|
}
|
||||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, refreshed, true)
|
assert.Equal(t, refreshed, true)
|
||||||
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
||||||
|
@ -128,10 +128,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||||
// do nothing if a refresh is not required
|
// do nothing if a refresh is not required
|
||||||
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRefreshNeeded should return true if the user's session need to be refreshed
|
||||||
|
func (p *ProviderData) IsRefreshNeeded(_ *sessions.SessionState) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||||
func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
||||||
if p.Verifier != nil {
|
if p.Verifier != nil {
|
||||||
|
@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
|
|||||||
p := &ProviderData{}
|
p := &ProviderData{}
|
||||||
|
|
||||||
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
||||||
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
|
refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
|
||||||
ExpiresOn: &expires,
|
ExpiresOn: &expires,
|
||||||
})
|
})
|
||||||
assert.Equal(t, false, refreshed)
|
assert.Equal(t, false, refreshed)
|
||||||
|
@ -16,7 +16,8 @@ type Provider interface {
|
|||||||
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
||||||
GetLoginURL(redirectURI, finalRedirect string) string
|
GetLoginURL(redirectURI, finalRedirect string) string
|
||||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
|
IsRefreshNeeded(s *sessions.SessionState) bool
|
||||||
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user