1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-17 21:17:53 +02:00

Split RefreshSessionIfNeeded in two methods and use Redis lock

This commit is contained in:
Kevin Kreitner 2021-02-22 08:33:53 +01:00
parent b942eb1582
commit 69d6fc8a08
21 changed files with 297 additions and 145 deletions

1
go.mod
View File

@ -22,6 +22,7 @@ require (
github.com/onsi/gomega v1.10.2
github.com/pierrec/lz4 v2.5.2+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v0.9.3
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.6.3
github.com/stretchr/testify v1.6.1

6
go.sum
View File

@ -24,6 +24,7 @@ github.com/alicebob/miniredis/v2 v2.13.0 h1:QPosMaxm+r6Qs+YcCtL2Z2a2RSdC9VfXJLpd
github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
@ -168,6 +169,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo=
github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s=
@ -208,13 +210,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8=
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/common v0.4.0 h1:7etb9YClo3a6HjLzfl6rIQaU+FDfi0VSX39io3aQ+DM=
github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084 h1:sofwID9zm4tzrgykg80hfFph1mryUeLRsUfoocVVmRY=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=

View File

@ -276,10 +276,11 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
}
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
ValidateSessionState: opts.GetProvider().ValidateSession,
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: opts.GetProvider().RefreshSession,
IsRefreshNeeded: opts.GetProvider().IsRefreshNeeded,
ValidateSessionState: opts.GetProvider().ValidateSession,
}))
return chain

View File

@ -2,13 +2,13 @@ package sessions
import (
"net/http"
"time"
)
// SessionStore is an interface to storing user sessions in the proxy
type SessionStore interface {
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
Load(req *http.Request) (*SessionState, error)
Lock(req *http.Request, expiration time.Duration) error
LoadWithLock(req *http.Request) (*SessionState, error)
ReleaseLock(req *http.Request) error
Clear(rw http.ResponseWriter, req *http.Request) error
}

View File

@ -24,10 +24,13 @@ type StoredSessionLoaderOptions struct {
RefreshPeriod time.Duration
// Provider based sesssion refreshing
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
// Provider based session refresh check
IsRefreshNeeded func(*sessionsapi.SessionState) bool
// Provider based session validation.
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
// If the session is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation.
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
}
@ -38,10 +41,11 @@ type StoredSessionLoaderOptions struct {
// If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
ss := &storedSessionLoader{
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
validateSessionState: opts.ValidateSessionState,
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
refreshSessionWithProvider: opts.RefreshSession,
isRefreshNeededWithProvider: opts.IsRefreshNeeded,
validateSessionState: opts.ValidateSessionState,
}
return ss.loadSession
}
@ -49,10 +53,11 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
// storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store.
type storedSessionLoader struct {
store sessionsapi.SessionStore
refreshPeriod time.Duration
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
store sessionsapi.SessionStore
refreshPeriod time.Duration
refreshSessionWithProvider func(context.Context, *sessionsapi.SessionState) (bool, error)
isRefreshNeededWithProvider func(*sessionsapi.SessionState) bool
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
}
// loadSession attempts to load a session as identified by the request cookies.
@ -98,44 +103,54 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
return nil, nil
}
err = s.refreshSessionIfNeeded(rw, req, session)
if !s.isSessionRefreshNeeded(session) {
return session, nil
}
session, err = s.store.LoadWithLock(req)
if err != nil {
return nil, err
}
if session == nil {
// No session was found in the storage, nothing more to do
return nil, nil
}
if !s.isSessionRefreshNeeded(session) {
_ = s.store.ReleaseLock(req)
return session, nil
}
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
refreshed, err := s.refreshSession(rw, req, session)
_ = s.store.ReleaseLock(req)
if err != nil {
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
}
if refreshed {
return session, nil
}
// Session wasn't refreshed, so make sure it's still valid
err = s.validateSession(req.Context(), session)
if err != nil {
return nil, err
}
return session, nil
}
// refreshSessionIfNeeded will attempt to refresh a session if the session
// is older than the refresh period.
// It is assumed that if the provider refreshes the session, the session is now
// valid.
// If the session requires refreshing but the provider does not refresh it,
// we must validate the session to ensure that the returned session is still
// valid.
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
// Refresh is disabled or the session is not old enough, do nothing
return nil
// isSessionRefreshNeeded will check if the session need to be refreshed
func (s *storedSessionLoader) isSessionRefreshNeeded(session *sessionsapi.SessionState) bool {
if s.refreshPeriod > time.Duration(0) && session.Age() >= s.refreshPeriod {
return s.isRefreshNeededWithProvider(session)
}
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
if err != nil {
return err
}
if !refreshed {
// Session wasn't refreshed, so make sure it's still valid
return s.validateSession(req.Context(), session)
}
return nil
return false
}
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
// refreshSession attempts to refresh the sessinon with the provider
// and will save the session if it was updated.
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
refreshed, err := s.refreshSessionWithProvider(req.Context(), session)
if err != nil {
return false, fmt.Errorf("error refreshing access token: %v", err)
}

View File

@ -39,6 +39,17 @@ var _ = Describe("Stored Session Suite", func() {
}
}
var defaultIsRefreshNeededFunc = func(ss *sessionsapi.SessionState) bool {
switch ss.RefreshToken {
case refresh:
return true
case noRefresh:
return false
default:
return false
}
}
var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool {
return ss.AccessToken != "Invalid"
}
@ -86,13 +97,14 @@ var _ = Describe("Stored Session Suite", func() {
}
type storedSessionLoaderTableInput struct {
requestHeaders http.Header
existingSession *sessionsapi.SessionState
expectedSession *sessionsapi.SessionState
store sessionsapi.SessionStore
refreshPeriod time.Duration
refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
validateSession func(context.Context, *sessionsapi.SessionState) bool
requestHeaders http.Header
existingSession *sessionsapi.SessionState
expectedSession *sessionsapi.SessionState
store sessionsapi.SessionStore
refreshPeriod time.Duration
refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
isRefreshSessionNeeded func(*sessionsapi.SessionState) bool
validateSession func(context.Context, *sessionsapi.SessionState) bool
}
DescribeTable("when serving a request",
@ -109,10 +121,11 @@ var _ = Describe("Stored Session Suite", func() {
rw := httptest.NewRecorder()
opts := &StoredSessionLoaderOptions{
SessionStore: in.store,
RefreshPeriod: in.refreshPeriod,
RefreshSessionIfNeeded: in.refreshSession,
ValidateSessionState: in.validateSession,
SessionStore: in.store,
RefreshPeriod: in.refreshPeriod,
RefreshSession: in.refreshSession,
IsRefreshNeeded: in.isRefreshSessionNeeded,
ValidateSessionState: in.validateSession,
}
// Create the handler with a next handler that will capture the session
@ -126,24 +139,26 @@ var _ = Describe("Stored Session Suite", func() {
Expect(gotSession).To(Equal(in.expectedSession))
},
Entry("with no cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
requestHeaders: http.Header{},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("with an invalid cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NonExistent"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("with an existing session", storedSessionLoaderTableInput{
requestHeaders: http.Header{
@ -155,10 +170,11 @@ var _ = Describe("Stored Session Suite", func() {
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Existing",
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that has not expired", storedSessionLoaderTableInput{
requestHeaders: http.Header{
@ -170,21 +186,23 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{
@ -196,10 +214,11 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
},
store: defaultSessionStore,
refreshPeriod: 10 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
store: defaultSessionStore,
refreshPeriod: 10 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{
@ -211,37 +230,40 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
validateSession: defaultValidateFunc,
}),
)
})
Context("refreshSessionIfNeeded", func() {
Context("isSessionRefreshNeeded", func() {
type refreshSessionIfNeededTableInput struct {
refreshPeriod time.Duration
session *sessionsapi.SessionState
@ -261,7 +283,7 @@ var _ = Describe("Stored Session Suite", func() {
s := &storedSessionLoader{
refreshPeriod: in.refreshPeriod,
store: &fakeSessionStore{},
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
refreshed = true
switch ss.RefreshToken {
case refresh:
@ -272,6 +294,17 @@ var _ = Describe("Stored Session Suite", func() {
return false, errors.New("error refreshing session")
}
},
isRefreshNeededWithProvider: func(ss *sessionsapi.SessionState) bool {
refreshed = true
switch ss.RefreshToken {
case refresh:
return true
case noRefresh:
return false
default:
return false
}
},
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
validated = true
return ss.AccessToken != "Invalid"
@ -279,7 +312,10 @@ var _ = Describe("Stored Session Suite", func() {
}
req := httptest.NewRequest("", "/", nil)
err := s.refreshSessionIfNeeded(nil, req, in.session)
var err error
if s.isSessionRefreshNeeded(in.session) {
refreshed, err = s.refreshSession(nil, req, in.session)
}
if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr))
} else {
@ -364,7 +400,7 @@ var _ = Describe("Stored Session Suite", func() {
)
})
Context("refreshSessionWithProvider", func() {
Context("refreshSession", func() {
type refreshSessionWithProviderTableInput struct {
session *sessionsapi.SessionState
expectedErr error
@ -388,7 +424,7 @@ var _ = Describe("Stored Session Suite", func() {
return nil
},
},
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken {
case refresh:
return true, nil
@ -401,7 +437,7 @@ var _ = Describe("Stored Session Suite", func() {
}
req := httptest.NewRequest("", "/", nil)
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session)
refreshed, err := s.refreshSession(nil, req, in.session)
if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr))
} else {
@ -515,6 +551,17 @@ func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, e
return nil, nil
}
func (f *fakeSessionStore) LoadWithLock(req *http.Request) (*sessionsapi.SessionState, error) {
if f.LoadFunc != nil {
return f.LoadFunc(req)
}
return nil, nil
}
func (f *fakeSessionStore) ReleaseLock(req *http.Request) error {
return nil
}
func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
if f.ClearFunc != nil {
return f.ClearFunc(rw, req)

View File

@ -66,7 +66,14 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
return session, nil
}
func (s *SessionStore) Lock(req *http.Request, expirationTime time.Duration) error {
// Load reads sessions.SessionState information from Cookies within the
// HTTP request object
func (s *SessionStore) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
return s.Load(req)
}
// Release the session lock
func (s *SessionStore) ReleaseLock(req *http.Request) error {
return nil
}

View File

@ -11,6 +11,7 @@ import (
type Store interface {
Save(context.Context, string, []byte, time.Duration) error
Load(context.Context, string) ([]byte, error)
Lock(context.Context, string, time.Duration) error
LoadWithLock(context.Context, string) ([]byte, error)
ReleaseLock(context.Context, string) error
Clear(context.Context, string) error
}

View File

@ -65,16 +65,28 @@ func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) {
})
}
// Lock reads sessions.SessionState in a session store. It will
// Load reads sessions.SessionState information from a session store. It will
// use the session ticket from the http.Request's cookie.
func (m *Manager) Lock(req *http.Request, expiration time.Duration) error {
func (m *Manager) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
tckt, err := decodeTicketFromRequest(req, m.Options)
if err != nil {
return nil, err
}
return tckt.loadSession(func(key string) ([]byte, error) {
return m.Store.LoadWithLock(req.Context(), key)
})
}
// Release the session lock
func (m *Manager) ReleaseLock(req *http.Request) error {
tckt, err := decodeTicketFromRequest(req, m.Options)
if err != nil {
return err
}
return tckt.lockSession(func(key string) error {
return m.Store.Lock(req.Context(), key, expiration)
return tckt.releaseSession(func(key string) error {
return m.Store.ReleaseLock(req.Context(), key)
})
}

View File

@ -139,13 +139,11 @@ func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) {
return sessions.DecodeSessionState(ciphertext, c, false)
}
// lockSession loads a session from the disk store via the passed loadFunc
// using the ticket.id as the key. It then decodes the SessionState using
// ticket.secret to make the AES-GCM cipher.
func (t *ticket) lockSession(loader lockFunc) error {
// releaseSession releases a potential locked session
func (t *ticket) releaseSession(loader lockFunc) error {
err := loader(t.id)
if err != nil {
return fmt.Errorf("failed to lock the session state with the ticket: %v", err)
return fmt.Errorf("failed to release session state with the ticket: %v", err)
}
return nil
}

View File

@ -2,6 +2,7 @@ package redis
import (
"context"
"fmt"
"time"
"github.com/bsm/redislock"
@ -12,6 +13,7 @@ import (
type Client interface {
Get(ctx context.Context, key string) ([]byte, error)
Lock(ctx context.Context, key string, expiration time.Duration) error
Unlock(ctx context.Context, key string) error
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
Del(ctx context.Context, key string) error
}
@ -21,20 +23,21 @@ var _ Client = (*client)(nil)
type client struct {
*redis.Client
locker *redislock.Client
lock *redislock.Lock
locks map[string]*redislock.Lock
}
func newClient(c *redis.Client) Client {
return &client{
Client: c,
locker: redislock.New(c),
locks: map[string]*redislock.Lock{},
}
}
func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
if c.lock != nil {
if c.locks[key] != nil {
for {
ttl, err := c.lock.TTL(ctx)
ttl, err := c.locks[key].TTL(ctx)
if err != nil {
return nil, err
}
@ -47,27 +50,37 @@ func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
}
func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error {
if c.locks[key] != nil {
return fmt.Errorf("locks for key %s already exists", key)
}
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
if err != nil {
return err
}
c.lock = lock
c.locks[key] = lock
return nil
}
func (c *client) Unlock(ctx context.Context, key string) error {
if c.locks[key] == nil {
return nil
}
return c.locks[key].Release(ctx)
}
func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
err := c.Client.Set(ctx, key, value, expiration).Err()
if err != nil {
return err
}
if c.lock == nil {
if c.locks[key] == nil {
return nil
}
err = c.lock.Release(ctx)
err = c.locks[key].Release(ctx)
if err != nil {
return err
}
c.lock = nil
c.locks = nil
return nil
}
@ -80,13 +93,14 @@ var _ Client = (*clusterClient)(nil)
type clusterClient struct {
*redis.ClusterClient
locker *redislock.Client
lock *redislock.Lock
locks map[string]*redislock.Lock
}
func newClusterClient(c *redis.ClusterClient) Client {
return &clusterClient{
ClusterClient: c,
locker: redislock.New(c),
locks: map[string]*redislock.Lock{},
}
}
@ -95,14 +109,24 @@ func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) {
}
func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error {
if c.locks[key] != nil {
return fmt.Errorf("locks for key %s already exists", key)
}
lock, err := c.locker.Obtain(ctx, key, expiration, nil)
if err != nil {
return err
}
c.lock = lock
c.locks[key] = lock
return nil
}
func (c *clusterClient) Unlock(ctx context.Context, key string) error {
if c.locks[key] == nil {
return nil
}
return c.locks[key].Release(ctx)
}
func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
return c.ClusterClient.Set(ctx, key, value, expiration).Err()
}

View File

@ -54,11 +54,25 @@ func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error)
return value, nil
}
// Lock sessions.SessionState information from a persistence
func (store *SessionStore) Lock(ctx context.Context, key string, expiration time.Duration) error {
err := store.Client.Lock(ctx, key, expiration)
// ReleaseLock sessions.SessionState information from a persistence
func (store *SessionStore) LoadWithLock(ctx context.Context, key string) ([]byte, error) {
value, err := store.Client.Get(ctx, key)
if err != nil {
return fmt.Errorf("error setting redis lock: %v", err)
return nil, fmt.Errorf("error loading redis session: %v", err)
}
err = store.Client.Lock(ctx, key, 200*time.Millisecond)
if err != nil {
return nil, fmt.Errorf("error setting redis locks: %v", err)
}
return value, nil
}
// ReleaseLock sessions.SessionState information from a persistence
func (store *SessionStore) ReleaseLock(ctx context.Context, key string) error {
err := store.Client.Unlock(ctx, key)
if err != nil {
return fmt.Errorf("error releasing redis locks: %v", err)
}
return nil
}

View File

@ -160,8 +160,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil {
return false, nil
}
@ -176,6 +176,11 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.
return true, nil
}
// IsRefreshNeeded checks if the session has expired
func (p *AzureProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "")
}
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
params := url.Values{}
params.Add("client_id", p.ClientID)

View File

@ -244,7 +244,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
expires := time.Now().Add(time.Duration(1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
refreshNeeded, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err)
assert.False(t, refreshNeeded)
}
@ -258,7 +258,7 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
expires := time.Now().Add(time.Duration(-1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
_, err := p.RefreshSessionIfNeeded(context.Background(), session)
_, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil)
assert.Equal(t, "new_some_access_token", session.AccessToken)

View File

@ -123,7 +123,7 @@ func (p *GitLabProvider) SetProjectScope() {
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil
}
@ -139,6 +139,11 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return true, nil
}
// IsRefreshNeeded checks if the session has expired
func (p *GitLabProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
}
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
clientSecret, err := p.GetClientSecret()
if err != nil {

View File

@ -268,8 +268,8 @@ func userInGroup(service *admin.Service, group string, email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil {
return false, nil
}
@ -295,6 +295,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return true, nil
}
// IsRefreshNeeded checks if the session has expired
func (p *GoogleProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
}
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
clientSecret, err := p.GetClientSecret()

View File

@ -115,8 +115,8 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new Access Token (and optional ID token) if required
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil {
return false, nil
}
@ -129,6 +129,11 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
return true, nil
}
// IsRefreshNeeded checks if the session has expired
func (p *OIDCProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
}
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
// Access Token and (probably) the ID Token.
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {

View File

@ -467,7 +467,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
User: "11223344",
}
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true)
assert.Equal(t, "janedoe@example.com", existingSession.Email)
@ -500,7 +500,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
Email: "changeit",
User: "changeit",
}
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true)
assert.Equal(t, defaultIDToken.Email, existingSession.Email)

View File

@ -128,10 +128,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
// RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
return false, nil
}
// IsRefreshNeeded should return true if the user's session need to be refreshed
func (p *ProviderData) IsRefreshNeeded(_ *sessions.SessionState) bool {
return false
}
// CreateSessionFromToken converts Bearer IDTokens into sessions
func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
if p.Verifier != nil {

View File

@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
p := &ProviderData{}
expires := time.Now().Add(time.Duration(-11) * time.Minute)
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
ExpiresOn: &expires,
})
assert.Equal(t, false, refreshed)

View File

@ -16,7 +16,8 @@ type Provider interface {
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
IsRefreshNeeded(s *sessions.SessionState) bool
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
}