1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-08-08 22:46:33 +02:00

Implement graceful shutdown and propagate request context (#468)

* feature: Implement graceful shutdown

Propagate the request context to the Redis client.
It is possible to propagate a context cancel to Redis client if the connection is closed by the HTTP client.
The redis.Cmdable cannot use WithContext, so added the Client interface to handle redis.Client and redis.ClusterClient transparently.

Added handling of Unix signals to http server.

Upgrade go-redis/redis to v7.

* Update dependencies

- Upgrade golang/x/* and google-api-go
- Migrate fsnotify import from gopkg.in to github.com
- Replace bmizerany/assert with stretchr/testify/assert

* add doc for  wrapper interface

* Update CHANGELOG.md

* fix: upgrade fsnotify to v1.4.9

* fix: remove unnessary logging

* fix: wait until  all connections have been closed

* refactor: move chan to main for testing

* add assert to check if stop chan is empty

* add an idiomatic for sync.WaitGroup with timeout
This commit is contained in:
Mitsuo Heijo
2020-04-05 00:12:38 +09:00
committed by GitHub
parent bdc686103e
commit c7bfbdecef
11 changed files with 177 additions and 55 deletions

View File

@ -0,0 +1,59 @@
package redis
import (
"context"
"time"
"github.com/go-redis/redis/v7"
)
// Client is wrapper interface for redis.Client and redis.ClusterClient.
type Client interface {
Get(ctx context.Context, key string) ([]byte, error)
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
Del(ctx context.Context, key string) error
}
var _ Client = (*client)(nil)
type client struct {
*redis.Client
}
func newClient(c *redis.Client) Client {
return &client{Client: c}
}
func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
return c.WithContext(ctx).Get(key).Bytes()
}
func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
return c.WithContext(ctx).Set(key, value, expiration).Err()
}
func (c *client) Del(ctx context.Context, key string) error {
return c.WithContext(ctx).Del(key).Err()
}
var _ Client = (*clusterClient)(nil)
type clusterClient struct {
*redis.ClusterClient
}
func newClusterClient(c *redis.ClusterClient) Client {
return &clusterClient{ClusterClient: c}
}
func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) {
return c.WithContext(ctx).Get(key).Bytes()
}
func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
return c.WithContext(ctx).Set(key, value, expiration).Err()
}
func (c *clusterClient) Del(ctx context.Context, key string) error {
return c.WithContext(ctx).Del(key).Err()
}

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
@ -14,7 +15,7 @@ import (
"strings"
"time"
"github.com/go-redis/redis"
"github.com/go-redis/redis/v7"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
@ -33,19 +34,19 @@ type TicketData struct {
type SessionStore struct {
CookieCipher *encryption.Cipher
CookieOptions *options.CookieOptions
Cmdable redis.Cmdable
Client Client
}
// NewRedisSessionStore initialises a new instance of the SessionStore from
// the configuration given
func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
cmdable, err := newRedisCmdable(opts.RedisStoreOptions)
client, err := newRedisCmdable(opts.RedisStoreOptions)
if err != nil {
return nil, fmt.Errorf("error constructing redis client: %v", err)
}
rs := &SessionStore{
Cmdable: cmdable,
Client: client,
CookieCipher: opts.Cipher,
CookieOptions: cookieOpts,
}
@ -53,7 +54,7 @@ func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cook
}
func newRedisCmdable(opts options.RedisStoreOptions) (redis.Cmdable, error) {
func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) {
if opts.UseSentinel && opts.UseCluster {
return nil, fmt.Errorf("options redis-use-sentinel and redis-use-cluster are mutually exclusive")
}
@ -63,14 +64,14 @@ func newRedisCmdable(opts options.RedisStoreOptions) (redis.Cmdable, error) {
MasterName: opts.SentinelMasterName,
SentinelAddrs: opts.SentinelConnectionURLs,
})
return client, nil
return newClient(client), nil
}
if opts.UseCluster {
client := redis.NewClusterClient(&redis.ClusterOptions{
Addrs: opts.ClusterConnectionURLs,
})
return client, nil
return newClusterClient(client), nil
}
opt, err := redis.ParseURL(opts.RedisConnectionURL)
@ -104,7 +105,7 @@ func newRedisCmdable(opts options.RedisStoreOptions) (redis.Cmdable, error) {
}
client := redis.NewClient(opt)
return client, nil
return newClient(client), nil
}
// Save takes a sessions.SessionState and stores the information from it
@ -121,7 +122,8 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se
if err != nil {
return err
}
ticketString, err := store.storeValue(value, store.CookieOptions.CookieExpire, requestCookie)
ctx := req.Context()
ticketString, err := store.storeValue(ctx, value, store.CookieOptions.CookieExpire, requestCookie)
if err != nil {
return err
}
@ -149,7 +151,8 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro
if !ok {
return nil, fmt.Errorf("Cookie Signature not valid")
}
session, err := store.loadSessionFromString(val)
ctx := req.Context()
session, err := store.loadSessionFromString(ctx, val)
if err != nil {
return nil, fmt.Errorf("error loading session: %s", err)
}
@ -157,18 +160,17 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro
}
// loadSessionFromString loads the session based on the ticket value
func (store *SessionStore) loadSessionFromString(value string) (*sessions.SessionState, error) {
func (store *SessionStore) loadSessionFromString(ctx context.Context, value string) (*sessions.SessionState, error) {
ticket, err := decodeTicket(store.CookieOptions.CookieName, value)
if err != nil {
return nil, err
}
result, err := store.Cmdable.Get(ticket.asHandle(store.CookieOptions.CookieName)).Result()
resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.CookieOptions.CookieName))
if err != nil {
return nil, err
}
resultBytes := []byte(result)
block, err := aes.NewCipher(ticket.Secret)
if err != nil {
return nil, err
@ -214,7 +216,8 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
// If there's an issue decoding the ticket, ignore it
ticket, _ := decodeTicket(store.CookieOptions.CookieName, val)
if ticket != nil {
_, err := store.Cmdable.Del(ticket.asHandle(store.CookieOptions.CookieName)).Result()
ctx := req.Context()
err := store.Client.Del(ctx, ticket.asHandle(store.CookieOptions.CookieName))
if err != nil {
return fmt.Errorf("error clearing cookie from redis: %s", err)
}
@ -237,7 +240,7 @@ func (store *SessionStore) makeCookie(req *http.Request, value string, expires t
)
}
func (store *SessionStore) storeValue(value string, expiration time.Duration, requestCookie *http.Cookie) (string, error) {
func (store *SessionStore) storeValue(ctx context.Context, value string, expiration time.Duration, requestCookie *http.Cookie) (string, error) {
ticket, err := store.getTicket(requestCookie)
if err != nil {
return "", fmt.Errorf("error getting ticket: %v", err)
@ -254,7 +257,7 @@ func (store *SessionStore) storeValue(value string, expiration time.Duration, re
stream.XORKeyStream(ciphertext, []byte(value))
handle := ticket.asHandle(store.CookieOptions.CookieName)
err = store.Cmdable.Set(handle, ciphertext, expiration).Err()
err = store.Client.Set(ctx, handle, ciphertext, expiration)
if err != nil {
return "", err
}
@ -290,7 +293,7 @@ func newTicket() (*TicketData, error) {
return nil, fmt.Errorf("failed to create new ticket ID %s", err)
}
// ticketID is hex encoded
ticketID := fmt.Sprintf("%x", rawID)
ticketID := hex.EncodeToString(rawID)
secret := make([]byte, aes.BlockSize)
if _, err := io.ReadFull(rand.Reader, secret); err != nil {