1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-12-17 23:48:13 +02:00

Use ErrNotImplemented in default refresh implementation

This commit is contained in:
Nick Meves
2021-06-12 11:41:03 -07:00
parent baf6cf3816
commit ff914d7e17
5 changed files with 54 additions and 29 deletions

View File

@@ -11,19 +11,20 @@ import (
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
)
// StoredSessionLoaderOptions cotnains all of the requirements to construct
// StoredSessionLoaderOptions contains all of the requirements to construct
// a stored session loader.
// All options must be provided.
type StoredSessionLoaderOptions struct {
// Session storage basckend
// Session storage backend
SessionStore sessionsapi.SessionStore
// How often should sessions be refreshed
RefreshPeriod time.Duration
// Provider based sesssion refreshing
// Provider based session refreshing
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
// Provider based session validation.
@@ -115,7 +116,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
return nil
}
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
err := s.refreshSession(rw, req, session)
if err != nil {
// If a preemptive refresh fails, we still keep the session
@@ -131,21 +132,27 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
// and will save the session if it was updated.
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
refreshed, err := s.sessionRefresher(req.Context(), session)
if err != nil {
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
return fmt.Errorf("error refreshing tokens: %v", err)
}
// HACK:
// Providers that don't implement `RefreshSession` use the default
// implementation which returns `ErrNotImplemented`.
// Pretend it refreshed to reset the refresh timer so that `ValidateSession`
// isn't triggered every subsequent request and is only called once during
// this request.
if errors.Is(err, providers.ErrNotImplemented) {
refreshed = true
}
// Session not refreshed, nothing to persist.
if !refreshed {
return nil
}
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
//
// HACK:
// Providers that don't implement `RefreshSession` use the default
// implementation. It always returns `refreshed == true`, so the
// `session.CreatedAt` is updated and doesn't trigger `ValidateSession`
// every subsequent request.
// (In case underlying provider implementations forget)
session.CreatedAtNow()
// Because the session was refreshed, make sure to save it