diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index b5abc1fd..e036f693 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -108,6 +108,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h } session, err = s.store.LoadWithLock(req) + defer s.store.ReleaseLock(req) if err != nil { return nil, err } @@ -117,12 +118,10 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h } if !s.isSessionRefreshNeeded(session) { - _ = s.store.ReleaseLock(req) return session, nil } logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) refreshed, err := s.refreshSession(rw, req, session) - _ = s.store.ReleaseLock(req) if err != nil { return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) } diff --git a/pkg/sessions/redis/client.go b/pkg/sessions/redis/client.go index 2510dbbb..9e598026 100644 --- a/pkg/sessions/redis/client.go +++ b/pkg/sessions/redis/client.go @@ -51,7 +51,7 @@ func (c *client) Get(ctx context.Context, key string) ([]byte, error) { func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error { if c.locks[key] != nil { - return fmt.Errorf("locks for key %s already exists", key) + return fmt.Errorf("lock for key %s already exists", key) } lock, err := c.locker.Obtain(ctx, key, expiration, nil) if err != nil { @@ -65,23 +65,13 @@ func (c *client) Unlock(ctx context.Context, key string) error { if c.locks[key] == nil { return nil } - return c.locks[key].Release(ctx) + err := c.locks[key].Release(ctx) + delete(c.locks, key) + return err } func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error { - err := c.Client.Set(ctx, key, value, expiration).Err() - if err != nil { - return err - } - if c.locks[key] == nil { - return nil - } - err = c.locks[key].Release(ctx) - if err != nil { - return err - } - c.locks = nil - return nil + return c.Client.Set(ctx, key, value, expiration).Err() } func (c *client) Del(ctx context.Context, key string) error { @@ -105,12 +95,23 @@ func newClusterClient(c *redis.ClusterClient) Client { } func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) { + if c.locks[key] != nil { + for { + ttl, err := c.locks[key].TTL(ctx) + if err != nil { + return nil, err + } + if ttl <= 0 { + break + } + } + } return c.ClusterClient.Get(ctx, key).Bytes() } func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error { if c.locks[key] != nil { - return fmt.Errorf("locks for key %s already exists", key) + return fmt.Errorf("lock for key %s already exists", key) } lock, err := c.locker.Obtain(ctx, key, expiration, nil) if err != nil { @@ -124,7 +125,9 @@ func (c *clusterClient) Unlock(ctx context.Context, key string) error { if c.locks[key] == nil { return nil } - return c.locks[key].Release(ctx) + err := c.locks[key].Release(ctx) + delete(c.locks, key) + return err } func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error { diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index b5fe662f..84394029 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -54,16 +54,17 @@ func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error) return value, nil } -// ReleaseLock sessions.SessionState information from a persistence +// Load reads sessions.SessionState information from a persistence +// cookie within the HTTP request object and locks it on redis func (store *SessionStore) LoadWithLock(ctx context.Context, key string) ([]byte, error) { value, err := store.Client.Get(ctx, key) if err != nil { return nil, fmt.Errorf("error loading redis session: %v", err) } - err = store.Client.Lock(ctx, key, 200*time.Millisecond) + err = store.Client.Lock(ctx, key, 300*time.Millisecond) if err != nil { - return nil, fmt.Errorf("error setting redis locks: %v", err) + return nil, fmt.Errorf("error setting redis lock: %v", err) } return value, nil }