1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-05-31 23:19:50 +02:00

feat: readiness check (#1839)

* feat: readiness check

* fix: no need for query param

* docs: add a note

* chore: move the readyness check to its own endpoint

* docs(cr): add godoc

Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
This commit is contained in:
Kobi Meirson 2022-12-23 11:08:12 +02:00 committed by GitHub
parent 8b77c97009
commit f753ec1ca5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 382 additions and 35 deletions

View File

@ -12,8 +12,10 @@
- [#1882](https://github.com/oauth2-proxy/oauth2-proxy/pull/1882) Make `htpasswd.GetUsers` racecondition safe - [#1882](https://github.com/oauth2-proxy/oauth2-proxy/pull/1882) Make `htpasswd.GetUsers` racecondition safe
- [#1883](https://github.com/oauth2-proxy/oauth2-proxy/pull/1883) Ensure v8 manifest variant is set on docker images - [#1883](https://github.com/oauth2-proxy/oauth2-proxy/pull/1883) Ensure v8 manifest variant is set on docker images
- [#1906](https://github.com/oauth2-proxy/oauth2-proxy/pull/1906) Fix PKCE code verifier generation to never use UTF-8 characters - [#1906](https://github.com/oauth2-proxy/oauth2-proxy/pull/1906) Fix PKCE code verifier generation to never use UTF-8 characters
- [#1839](https://github.com/oauth2-proxy/oauth2-proxy/pull/1839) Add readiness checks for deeper health checks (@kobim)
- [#1927](https://github.com/oauth2-proxy/oauth2-proxy/pull/1927) Fix default scope settings for none oidc providers - [#1927](https://github.com/oauth2-proxy/oauth2-proxy/pull/1927) Fix default scope settings for none oidc providers
# V7.4.0 # V7.4.0
## Release Highlights ## Release Highlights

View File

@ -449,11 +449,11 @@ spec:
timeoutSeconds: 1 timeoutSeconds: 1
readinessProbe: readinessProbe:
httpGet: httpGet:
path: /ping path: /ready
port: http port: http
scheme: HTTP scheme: HTTP
initialDelaySeconds: 0 initialDelaySeconds: 0
timeoutSeconds: 1 timeoutSeconds: 5
successThreshold: 1 successThreshold: 1
periodSeconds: 10 periodSeconds: 10
resources: resources:

View File

@ -24,7 +24,7 @@ _oauth2_proxy() {
COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) ) COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) )
return 0 return 0
;; ;;
--@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors)) --@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|ready-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors))
return 0 return 0
;; ;;
esac esac

View File

@ -155,6 +155,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
| `--provider-display-name` | string | Override the provider's name with the given string; used for the sign-in page | (depends on provider) | | `--provider-display-name` | string | Override the provider's name with the given string; used for the sign-in page | (depends on provider) |
| `--ping-path` | string | the ping endpoint that can be used for basic health checks | `"/ping"` | | `--ping-path` | string | the ping endpoint that can be used for basic health checks | `"/ping"` |
| `--ping-user-agent` | string | a User-Agent that can be used for basic health checks | `""` (don't check user agent) | | `--ping-user-agent` | string | a User-Agent that can be used for basic health checks | `""` (don't check user agent) |
| `--ready-path` | string | the ready endpoint that can be used for deep health checks | `"/ready"` |
| `--metrics-address` | string | the address prometheus metrics will be scraped from | `""` | | `--metrics-address` | string | the address prometheus metrics will be scraped from | `""` |
| `--proxy-prefix` | string | the url root path that this proxy should be nested under (e.g. /`<oauth2>/sign_in`) | `"/oauth2"` | | `--proxy-prefix` | string | the url root path that this proxy should be nested under (e.g. /`<oauth2>/sign_in`) | `"/oauth2"` |
| `--proxy-websockets` | bool | enables WebSocket proxying | true | | `--proxy-websockets` | bool | enables WebSocket proxying | true |
@ -184,7 +185,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
| `--set-basic-auth` | bool | set HTTP Basic Auth information in response (useful in Nginx auth_request mode) | false | | `--set-basic-auth` | bool | set HTTP Basic Auth information in response (useful in Nginx auth_request mode) | false |
| `--show-debug-on-error` | bool | show detailed error information on error pages (WARNING: this may contain sensitive information - do not use in production) | false | | `--show-debug-on-error` | bool | show detailed error information on error pages (WARNING: this may contain sensitive information - do not use in production) | false |
| `--signature-key` | string | GAP-Signature request signature key (algorithm:secretkey) | | | `--signature-key` | string | GAP-Signature request signature key (algorithm:secretkey) | |
| `--silence-ping-logging` | bool | disable logging of requests to ping endpoint | false | | `--silence-ping-logging` | bool | disable logging of requests to ping & ready endpoints | false |
| `--skip-auth-preflight` | bool | will skip authentication for OPTIONS requests | false | | `--skip-auth-preflight` | bool | will skip authentication for OPTIONS requests | false |
| `--skip-auth-regex` | string \| list | (DEPRECATED for `--skip-auth-route`) bypass authentication for requests paths that match (may be given multiple times) | | | `--skip-auth-regex` | string \| list | (DEPRECATED for `--skip-auth-route`) bypass authentication for requests paths that match (may be given multiple times) | |
| `--skip-auth-route` | string \| list | bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex | | | `--skip-auth-route` | string \| list | bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex | |
@ -246,7 +247,7 @@ There are three different types of logging: standard, authentication, and HTTP r
Each type of logging has its own configurable format and variables. By default these formats are similar to the Apache Combined Log. Each type of logging has its own configurable format and variables. By default these formats are similar to the Apache Combined Log.
Logging of requests to the `/ping` endpoint (or using `--ping-user-agent`) can be disabled with `--silence-ping-logging` reducing log volume. This flag appends the `--ping-path` to `--exclude-logging-paths`. Logging of requests to the `/ping` endpoint (or using `--ping-user-agent`) and the `/ready` endpoint can be disabled with `--silence-ping-logging` reducing log volume.
### Auth Log Format ### Auth Log Format
Authentication logs are logs which are guaranteed to contain a username or email address of a user attempting to authenticate. These logs are output by default in the below format: Authentication logs are logs which are guaranteed to contain a username or email address of a user attempting to authenticate. These logs are output by default in the below format:

View File

@ -7,6 +7,7 @@ OAuth2 Proxy responds directly to the following endpoints. All other endpoints w
- /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info - /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info
- /ping - returns a 200 OK response, which is intended for use with health checks - /ping - returns a 200 OK response, which is intended for use with health checks
- /ready - returns a 200 OK response if all the underlying connections (e.g., Redis store) are connected
- /metrics - Metrics endpoint for Prometheus to scrape, serve on the address specified by `--metrics-address`, disabled by default - /metrics - Metrics endpoint for Prometheus to scrape, serve on the address specified by `--metrics-address`, disabled by default
- /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) - /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies)
- /oauth2/sign_out - this URL is used to clear the session cookie - /oauth2/sign_out - this URL is used to clear the session cookie

View File

@ -185,7 +185,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, err return nil, err
} }
preAuthChain, err := buildPreAuthChain(opts) preAuthChain, err := buildPreAuthChain(opts, sessionStore)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err) return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
} }
@ -327,7 +327,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {
// buildPreAuthChain constructs a chain that should process every request before // buildPreAuthChain constructs a chain that should process every request before
// the OAuth2 Proxy authentication logic kicks in. // the OAuth2 Proxy authentication logic kicks in.
// For example forcing HTTPS or health checks. // For example forcing HTTPS or health checks.
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader))
if opts.ForceHTTPS { if opts.ForceHTTPS {
@ -351,12 +351,14 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
if opts.Logging.SilencePing { if opts.Logging.SilencePing {
chain = chain.Append( chain = chain.Append(
middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
middleware.NewReadynessCheck(opts.ReadyPath, sessionStore),
middleware.NewRequestLogger(), middleware.NewRequestLogger(),
) )
} else { } else {
chain = chain.Append( chain = chain.Append(
middleware.NewRequestLogger(), middleware.NewRequestLogger(),
middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
middleware.NewReadynessCheck(opts.ReadyPath, sessionStore),
) )
} }

View File

@ -49,6 +49,7 @@ var _ = Describe("Load", func() {
Options: Options{ Options: Options{
ProxyPrefix: "/oauth2", ProxyPrefix: "/oauth2",
PingPath: "/ping", PingPath: "/ping",
ReadyPath: "/ready",
RealClientIPHeader: "X-Real-IP", RealClientIPHeader: "X-Real-IP",
ForceHTTPS: false, ForceHTTPS: false,
Cookie: cookieDefaults(), Cookie: cookieDefaults(),

View File

@ -43,7 +43,7 @@ func loggingFlagSet() *pflag.FlagSet {
flagSet.StringSlice("exclude-logging-path", []string{}, "Exclude logging requests to paths (eg: '/path1,/path2,/path3')") flagSet.StringSlice("exclude-logging-path", []string{}, "Exclude logging requests to paths (eg: '/path1,/path2,/path3')")
flagSet.Bool("logging-local-time", true, "If the time in log files and backup filenames are local or UTC time") flagSet.Bool("logging-local-time", true, "If the time in log files and backup filenames are local or UTC time")
flagSet.Bool("silence-ping-logging", false, "Disable logging of requests to ping endpoint") flagSet.Bool("silence-ping-logging", false, "Disable logging of requests to ping & ready endpoints")
flagSet.String("request-id-header", "X-Request-Id", "Request header to use as the request ID") flagSet.String("request-id-header", "X-Request-Id", "Request header to use as the request ID")
flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") flagSet.String("logging-filename", "", "File to log requests to, empty for stdout")

View File

@ -21,6 +21,7 @@ type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"` ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"`
PingPath string `flag:"ping-path" cfg:"ping_path"` PingPath string `flag:"ping-path" cfg:"ping_path"`
PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"` PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"`
ReadyPath string `flag:"ready-path" cfg:"ready_path"`
ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"` ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"`
RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"` RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"`
TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"` TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"`
@ -96,6 +97,7 @@ func NewOptions() *Options {
ProxyPrefix: "/oauth2", ProxyPrefix: "/oauth2",
Providers: providerDefaults(), Providers: providerDefaults(),
PingPath: "/ping", PingPath: "/ping",
ReadyPath: "/ready",
RealClientIPHeader: "X-Real-IP", RealClientIPHeader: "X-Real-IP",
ForceHTTPS: false, ForceHTTPS: false,
Cookie: cookieDefaults(), Cookie: cookieDefaults(),
@ -133,6 +135,7 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)")
flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks")
flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks")
flagSet.String("ready-path", "/ready", "the ready endpoint that can be used for deep health checks")
flagSet.String("session-store-type", "cookie", "the session storage provider to use") flagSet.String("session-store-type", "cookie", "the session storage provider to use")
flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)")
flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])")

View File

@ -12,6 +12,7 @@ type SessionStore interface {
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
Load(req *http.Request) (*SessionState, error) Load(req *http.Request) (*SessionState, error)
Clear(rw http.ResponseWriter, req *http.Request) error Clear(rw http.ResponseWriter, req *http.Request) error
VerifyConnection(ctx context.Context) error
} }
var ErrLockNotObtained = errors.New("lock: not obtained") var ErrLockNotObtained = errors.New("lock: not obtained")

View File

@ -0,0 +1,40 @@
package middleware
import (
"context"
"fmt"
"net/http"
"github.com/justinas/alice"
)
// Verifiable an interface for an object that has a connection to external
// data source and exports a function to validate that connection
type Verifiable interface {
VerifyConnection(context.Context) error
}
// NewReadynessCheck returns a middleware that performs deep health checks
// (verifies the connection to any underlying store) on a specific `path`
func NewReadynessCheck(path string, verifiable Verifiable) alice.Constructor {
return func(next http.Handler) http.Handler {
return readynessCheck(path, verifiable, next)
}
}
func readynessCheck(path string, verifiable Verifiable, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if path != "" && req.URL.EscapedPath() == path {
if err := verifiable.VerifyConnection(req.Context()); err != nil {
rw.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(rw, "error: %v", err)
return
}
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "OK")
return
}
next.ServeHTTP(rw, req)
})
}

View File

@ -0,0 +1,84 @@
package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
var _ = Describe("ReadynessCheck suite", func() {
type requestTableInput struct {
readyPath string
healthVerifiable Verifiable
requestString string
expectedStatus int
expectedBody string
}
DescribeTable("when serving a request",
func(in *requestTableInput) {
req := httptest.NewRequest("", in.requestString, nil)
rw := httptest.NewRecorder()
handler := NewReadynessCheck(in.readyPath, in.healthVerifiable)(http.NotFoundHandler())
handler.ServeHTTP(rw, req)
Expect(rw.Code).To(Equal(in.expectedStatus))
Expect(rw.Body.String()).To(Equal(in.expectedBody))
},
Entry("when requesting the readyness check path", &requestTableInput{
readyPath: "/ready",
healthVerifiable: &fakeVerifiable{nil},
requestString: "http://example.com/ready",
expectedStatus: 200,
expectedBody: "OK",
}),
Entry("when requesting a different path", &requestTableInput{
readyPath: "/ready",
healthVerifiable: &fakeVerifiable{nil},
requestString: "http://example.com/different",
expectedStatus: 404,
expectedBody: "404 page not found\n",
}),
Entry("when a blank string is configured as a readyness check path and the request has no specific path", &requestTableInput{
readyPath: "",
healthVerifiable: &fakeVerifiable{nil},
requestString: "http://example.com",
expectedStatus: 404,
expectedBody: "404 page not found\n",
}),
Entry("with full health check and without an underlying error", &requestTableInput{
readyPath: "/ready",
healthVerifiable: &fakeVerifiable{nil},
requestString: "http://example.com/ready",
expectedStatus: 200,
expectedBody: "OK",
}),
Entry("with full health check and with an underlying error", &requestTableInput{
readyPath: "/ready",
healthVerifiable: &fakeVerifiable{func(ctx context.Context) error { return errors.New("failed to check") }},
requestString: "http://example.com/ready",
expectedStatus: 500,
expectedBody: "error: failed to check",
}),
)
})
type fakeVerifiable struct {
mock func(context.Context) error
}
func (v *fakeVerifiable) VerifyConnection(ctx context.Context) error {
if v.mock != nil {
return v.mock(ctx)
}
return nil
}
var _ Verifiable = (*fakeVerifiable)(nil)

View File

@ -793,3 +793,7 @@ func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
} }
return nil return nil
} }
func (f *fakeSessionStore) VerifyConnection(_ context.Context) error {
return nil
}

View File

@ -1,6 +1,7 @@
package cookie package cookie
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -82,6 +83,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
return nil return nil
} }
// VerifyConnection always return no-error, as there's no connection
// in this store
func (s *SessionStore) VerifyConnection(_ context.Context) error {
return nil
}
// cookieForSession serializes a session state for storage in a cookie // cookieForSession serializes a session state for storage in a cookie
func (s *SessionStore) cookieForSession(ss *sessions.SessionState) ([]byte, error) { func (s *SessionStore) cookieForSession(ss *sessions.SessionState) ([]byte, error) {
if s.Minimal && (ss.AccessToken != "" || ss.IDToken != "" || ss.RefreshToken != "") { if s.Minimal && (ss.AccessToken != "" || ss.IDToken != "" || ss.RefreshToken != "") {

View File

@ -15,4 +15,5 @@ type Store interface {
Load(context.Context, string) ([]byte, error) Load(context.Context, string) ([]byte, error)
Clear(context.Context, string) error Clear(context.Context, string) error
Lock(key string) sessions.Lock Lock(key string) sessions.Lock
VerifyConnection(context.Context) error
} }

View File

@ -1,6 +1,7 @@
package persistence package persistence
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
@ -90,3 +91,8 @@ func (m *Manager) Clear(rw http.ResponseWriter, req *http.Request) error {
return m.Store.Clear(req.Context(), key) return m.Store.Clear(req.Context(), key)
}) })
} }
// VerifyConnection validates the underlying store is ready and connected
func (m *Manager) VerifyConnection(ctx context.Context) error {
return m.Store.VerifyConnection(ctx)
}

View File

@ -14,6 +14,7 @@ type Client interface {
Lock(key string) sessions.Lock Lock(key string) sessions.Lock
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
Del(ctx context.Context, key string) error Del(ctx context.Context, key string) error
Ping(ctx context.Context) error
} }
var _ Client = (*client)(nil) var _ Client = (*client)(nil)
@ -44,6 +45,10 @@ func (c *client) Lock(key string) sessions.Lock {
return NewLock(c.Client, key) return NewLock(c.Client, key)
} }
func (c *client) Ping(ctx context.Context) error {
return c.Client.Ping(ctx).Err()
}
var _ Client = (*clusterClient)(nil) var _ Client = (*clusterClient)(nil)
type clusterClient struct { type clusterClient struct {
@ -71,3 +76,7 @@ func (c *clusterClient) Del(ctx context.Context, key string) error {
func (c *clusterClient) Lock(key string) sessions.Lock { func (c *clusterClient) Lock(key string) sessions.Lock {
return NewLock(c.ClusterClient, key) return NewLock(c.ClusterClient, key)
} }
func (c *clusterClient) Ping(ctx context.Context) error {
return c.ClusterClient.Ping(ctx).Err()
}

View File

@ -0,0 +1,159 @@
package redis_test
import (
"context"
"encoding/base64"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/redis"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Redis Client Tests", func() {
Context("with basic client", func() {
RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions {
return options.RedisStoreOptions{
ConnectionURL: "redis://" + mr.Addr(),
}
})
})
Context("with cluster client", func() {
RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions {
return options.RedisStoreOptions{
ClusterConnectionURLs: []string{"redis://" + mr.Addr()},
UseCluster: true,
}
})
})
})
type getOptsFunc func(mr *miniredis.Miniredis) options.RedisStoreOptions
func RunClientTests(getOptsFunc getOptsFunc) {
var mr *miniredis.Miniredis
var client redis.Client
var err error
var key string
var ctx context.Context
BeforeEach(func() {
mr, err = miniredis.Run()
Expect(err).ToNot(HaveOccurred())
client, err = redis.NewRedisClient(getOptsFunc(mr))
Expect(err).ToNot(HaveOccurred())
nonce, err := encryption.Nonce(32)
Expect(err).ToNot(HaveOccurred())
key = base64.RawURLEncoding.EncodeToString(nonce)
ctx = context.Background()
})
AfterEach(func() {
if mr != nil {
mr.Close()
mr = nil
}
})
Context("when Get is called", func() {
expectedValue := []byte("value")
BeforeEach(func() {
client.Set(context.Background(), key, expectedValue, time.Duration(1*time.Minute))
})
It("returns the saved value", func() {
value, err := client.Get(ctx, key)
Expect(err).ToNot(HaveOccurred())
Expect(value).To(Equal(value))
})
It("does not return expired values", func() {
mr.FastForward(5 * time.Minute)
_, err = client.Get(ctx, key)
Expect(err).To(HaveOccurred())
})
It("returns an error if value does not exist", func() {
_, err = client.Get(ctx, "does-not-exists")
Expect(err).To(HaveOccurred())
})
})
Context("using Lock", func() {
It("maintains the lock", func() {
lock := client.Lock(key)
err = lock.Obtain(ctx, 1*time.Minute)
Expect(err).ToNot(HaveOccurred())
isLocked, err := lock.Peek(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(isLocked).To(BeTrue())
err = lock.Release(ctx)
Expect(err).ToNot(HaveOccurred())
})
It("reflects non-locked instance", func() {
lock := client.Lock(key)
isLocked, err := lock.Peek(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(isLocked).To(BeFalse())
})
})
Context("when Set is called", func() {
expectedValue := []byte("value")
It("sets the expected value", func() {
err = client.Set(ctx, key, expectedValue, 1*time.Minute)
Expect(err).ToNot(HaveOccurred())
value, err := client.Get(ctx, key)
Expect(value).To(Equal(expectedValue))
Expect(err).ToNot(HaveOccurred())
})
})
Context("when Del is called", func() {
It("does not return an error when key exists", func() {
err = client.Set(ctx, key, []byte("dummy"), 1*time.Minute)
Expect(err).ToNot(HaveOccurred())
err = client.Del(ctx, key)
Expect(err).ToNot(HaveOccurred())
_, err = client.Get(ctx, key)
Expect(err).To(HaveOccurred())
})
})
Context("when Ping is called", func() {
Context("when redis is up", func() {
It("does not return an error", func() {
err = client.Ping(ctx)
Expect(err).ToNot(HaveOccurred())
})
})
Context("when redis is down", func() {
It("returns an error", func() {
mr.Close()
mr = nil
err = client.Ping(ctx)
Expect(err).To(HaveOccurred())
})
})
})
}

View File

@ -70,6 +70,12 @@ func (store *SessionStore) Lock(key string) sessions.Lock {
return store.Client.Lock(key) return store.Client.Lock(key)
} }
// VerifyConnection verifies the redis connection is valid and the
// server is responsive
func (store *SessionStore) VerifyConnection(ctx context.Context) error {
return store.Client.Ping(ctx)
}
// NewRedisClient makes a redis.Client (either standalone, sentinel aware, or // NewRedisClient makes a redis.Client (either standalone, sentinel aware, or
// redis cluster) // redis cluster)
func NewRedisClient(opts options.RedisStoreOptions) (Client, error) { func NewRedisClient(opts options.RedisStoreOptions) (Client, error) {
@ -205,3 +211,5 @@ func parseRedisURLs(urls []string) ([]string, *redis.Options, error) {
} }
return addrs, redisOptions, nil return addrs, redisOptions, nil
} }
var _ persistence.Store = (*SessionStore)(nil)

View File

@ -2,20 +2,15 @@ package redis
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/pem" "encoding/pem"
"log"
"os" "os"
"testing"
"time" "time"
"github.com/Bose/minisentinel" "github.com/Bose/minisentinel"
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v9"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/persistence" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/persistence"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/tests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/tests"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
@ -25,28 +20,6 @@ import (
const redisPassword = "0123456789abcdefghijklmnopqrstuv" const redisPassword = "0123456789abcdefghijklmnopqrstuv"
// wrappedRedisLogger wraps a logger so that we can coerce the logger to
// fit the expected signature for go-redis logging
type wrappedRedisLogger struct {
*log.Logger
}
func (l *wrappedRedisLogger) Printf(_ context.Context, format string, v ...interface{}) {
l.Logger.Printf(format, v...)
}
func TestSessionStore(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
redisLogger := &wrappedRedisLogger{Logger: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)}
redisLogger.SetOutput(GinkgoWriter)
redis.SetLogger(redisLogger)
RegisterFailHandler(Fail)
RunSpecs(t, "Redis SessionStore")
}
var ( var (
cert tls.Certificate cert tls.Certificate
caPath string caPath string

View File

@ -0,0 +1,35 @@
package redis_test
import (
"context"
"log"
"os"
"testing"
"github.com/go-redis/redis/v9"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// wrappedRedisLogger wraps a logger so that we can coerce the logger to
// fit the expected signature for go-redis logging
type wrappedRedisLogger struct {
*log.Logger
}
func (l *wrappedRedisLogger) Printf(_ context.Context, format string, v ...interface{}) {
l.Logger.Printf(format, v...)
}
func TestRedis(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
redisLogger := &wrappedRedisLogger{Logger: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)}
redisLogger.SetOutput(GinkgoWriter)
redis.SetLogger(redisLogger)
RegisterFailHandler(Fail)
RunSpecs(t, "Redis")
}

View File

@ -65,6 +65,10 @@ func (s *MockStore) Lock(key string) sessions.Lock {
return lock return lock
} }
func (s *MockStore) VerifyConnection(_ context.Context) error {
return nil
}
// FastForward simulates the flow of time to test expirations // FastForward simulates the flow of time to test expirations
func (s *MockStore) FastForward(duration time.Duration) { func (s *MockStore) FastForward(duration time.Duration) {
for _, mockLock := range s.lockCache { for _, mockLock := range s.lockCache {

View File

@ -485,6 +485,12 @@ func SessionStoreInterfaceTests(in *testInput) {
}) })
}) })
}) })
Context("when VerifyConnection is called", func() {
It("should return without an error", func() {
Expect(in.ss().VerifyConnection(in.request.Context())).ToNot(HaveOccurred())
})
})
} }
func LoadSessionTests(in *testInput) { func LoadSessionTests(in *testInput) {