You've already forked oauth2-proxy
							
							
				mirror of
				https://github.com/oauth2-proxy/oauth2-proxy.git
				synced 2025-10-30 23:47:52 +02:00 
			
		
		
		
	feat: readiness check (#1839)
* feat: readiness check * fix: no need for query param * docs: add a note * chore: move the readyness check to its own endpoint * docs(cr): add godoc Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
This commit is contained in:
		| @@ -12,8 +12,10 @@ | ||||
| - [#1882](https://github.com/oauth2-proxy/oauth2-proxy/pull/1882) Make `htpasswd.GetUsers` racecondition safe | ||||
| - [#1883](https://github.com/oauth2-proxy/oauth2-proxy/pull/1883) Ensure v8 manifest variant is set on docker images | ||||
| - [#1906](https://github.com/oauth2-proxy/oauth2-proxy/pull/1906) Fix PKCE code verifier generation to never use UTF-8 characters | ||||
| - [#1839](https://github.com/oauth2-proxy/oauth2-proxy/pull/1839) Add readiness checks for deeper health checks (@kobim) | ||||
| - [#1927](https://github.com/oauth2-proxy/oauth2-proxy/pull/1927) Fix default scope settings for none oidc providers | ||||
|  | ||||
|  | ||||
| # V7.4.0 | ||||
|  | ||||
| ## Release Highlights | ||||
|   | ||||
| @@ -449,11 +449,11 @@ spec: | ||||
|           timeoutSeconds: 1 | ||||
|         readinessProbe: | ||||
|           httpGet: | ||||
|             path: /ping | ||||
|             path: /ready | ||||
|             port: http | ||||
|             scheme: HTTP | ||||
|           initialDelaySeconds: 0 | ||||
|           timeoutSeconds: 1 | ||||
|           timeoutSeconds: 5 | ||||
|           successThreshold: 1 | ||||
|           periodSeconds: 10 | ||||
|         resources: | ||||
|   | ||||
| @@ -24,7 +24,7 @@ _oauth2_proxy() { | ||||
| 			COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) ) | ||||
| 			return 0 | ||||
| 			;; | ||||
| 		--@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors)) | ||||
| 		--@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|ready-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors)) | ||||
| 			return 0 | ||||
| 			;; | ||||
| 	esac | ||||
|   | ||||
| @@ -155,6 +155,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | ||||
| | `--provider-display-name` | string | Override the provider's name with the given string; used for the sign-in page | (depends on provider) | | ||||
| | `--ping-path` | string | the ping endpoint that can be used for basic health checks | `"/ping"` | | ||||
| | `--ping-user-agent` | string | a User-Agent that can be used for basic health checks | `""` (don't check user agent) | | ||||
| | `--ready-path` | string | the ready endpoint that can be used for deep health checks | `"/ready"` | | ||||
| | `--metrics-address` | string | the address prometheus metrics will be scraped from | `""` | | ||||
| | `--proxy-prefix` | string | the url root path that this proxy should be nested under (e.g. /`<oauth2>/sign_in`) | `"/oauth2"` | | ||||
| | `--proxy-websockets` | bool | enables WebSocket proxying | true | | ||||
| @@ -184,7 +185,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | ||||
| | `--set-basic-auth` | bool | set HTTP Basic Auth information in response (useful in Nginx auth_request mode) | false | | ||||
| | `--show-debug-on-error` | bool | show detailed error information on error pages (WARNING: this may contain sensitive information - do not use in production) | false | | ||||
| | `--signature-key` | string | GAP-Signature request signature key (algorithm:secretkey) | | | ||||
| | `--silence-ping-logging` | bool | disable logging of requests to ping endpoint | false | | ||||
| | `--silence-ping-logging` | bool | disable logging of requests to ping & ready endpoints | false | | ||||
| | `--skip-auth-preflight` | bool | will skip authentication for OPTIONS requests | false | | ||||
| | `--skip-auth-regex` | string \| list | (DEPRECATED for `--skip-auth-route`) bypass authentication for requests paths that match (may be given multiple times) | | | ||||
| | `--skip-auth-route` | string \| list | bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex  | | | ||||
| @@ -246,7 +247,7 @@ There are three different types of logging: standard, authentication, and HTTP r | ||||
|  | ||||
| Each type of logging has its own configurable format and variables. By default these formats are similar to the Apache Combined Log. | ||||
|  | ||||
| Logging of requests to the `/ping` endpoint (or using `--ping-user-agent`) can be disabled with `--silence-ping-logging` reducing log volume. This flag appends the `--ping-path` to `--exclude-logging-paths`. | ||||
| Logging of requests to the `/ping` endpoint (or using `--ping-user-agent`) and the `/ready` endpoint can be disabled with `--silence-ping-logging` reducing log volume. | ||||
|  | ||||
| ### Auth Log Format | ||||
| Authentication logs are logs which are guaranteed to contain a username or email address of a user attempting to authenticate. These logs are output by default in the below format: | ||||
|   | ||||
| @@ -7,6 +7,7 @@ OAuth2 Proxy responds directly to the following endpoints. All other endpoints w | ||||
|  | ||||
| - /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info | ||||
| - /ping - returns a 200 OK response, which is intended for use with health checks | ||||
| - /ready - returns a 200 OK response if all the underlying connections (e.g., Redis store) are connected | ||||
| - /metrics - Metrics endpoint for Prometheus to scrape, serve on the address specified by `--metrics-address`, disabled by default | ||||
| - /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) | ||||
| - /oauth2/sign_out - this URL is used to clear the session cookie | ||||
|   | ||||
| @@ -185,7 +185,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	preAuthChain, err := buildPreAuthChain(opts) | ||||
| 	preAuthChain, err := buildPreAuthChain(opts, sessionStore) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not build pre-auth chain: %v", err) | ||||
| 	} | ||||
| @@ -327,7 +327,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { | ||||
| // buildPreAuthChain constructs a chain that should process every request before | ||||
| // the OAuth2 Proxy authentication logic kicks in. | ||||
| // For example forcing HTTPS or health checks. | ||||
| func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||
| func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) { | ||||
| 	chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) | ||||
|  | ||||
| 	if opts.ForceHTTPS { | ||||
| @@ -351,12 +351,14 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||
| 	if opts.Logging.SilencePing { | ||||
| 		chain = chain.Append( | ||||
| 			middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), | ||||
| 			middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), | ||||
| 			middleware.NewRequestLogger(), | ||||
| 		) | ||||
| 	} else { | ||||
| 		chain = chain.Append( | ||||
| 			middleware.NewRequestLogger(), | ||||
| 			middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), | ||||
| 			middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -49,6 +49,7 @@ var _ = Describe("Load", func() { | ||||
| 		Options: Options{ | ||||
| 			ProxyPrefix:        "/oauth2", | ||||
| 			PingPath:           "/ping", | ||||
| 			ReadyPath:          "/ready", | ||||
| 			RealClientIPHeader: "X-Real-IP", | ||||
| 			ForceHTTPS:         false, | ||||
| 			Cookie:             cookieDefaults(), | ||||
|   | ||||
| @@ -43,7 +43,7 @@ func loggingFlagSet() *pflag.FlagSet { | ||||
|  | ||||
| 	flagSet.StringSlice("exclude-logging-path", []string{}, "Exclude logging requests to paths (eg: '/path1,/path2,/path3')") | ||||
| 	flagSet.Bool("logging-local-time", true, "If the time in log files and backup filenames are local or UTC time") | ||||
| 	flagSet.Bool("silence-ping-logging", false, "Disable logging of requests to ping endpoint") | ||||
| 	flagSet.Bool("silence-ping-logging", false, "Disable logging of requests to ping & ready endpoints") | ||||
| 	flagSet.String("request-id-header", "X-Request-Id", "Request header to use as the request ID") | ||||
|  | ||||
| 	flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") | ||||
|   | ||||
| @@ -21,6 +21,7 @@ type Options struct { | ||||
| 	ProxyPrefix        string   `flag:"proxy-prefix" cfg:"proxy_prefix"` | ||||
| 	PingPath           string   `flag:"ping-path" cfg:"ping_path"` | ||||
| 	PingUserAgent      string   `flag:"ping-user-agent" cfg:"ping_user_agent"` | ||||
| 	ReadyPath          string   `flag:"ready-path" cfg:"ready_path"` | ||||
| 	ReverseProxy       bool     `flag:"reverse-proxy" cfg:"reverse_proxy"` | ||||
| 	RealClientIPHeader string   `flag:"real-client-ip-header" cfg:"real_client_ip_header"` | ||||
| 	TrustedIPs         []string `flag:"trusted-ip" cfg:"trusted_ips"` | ||||
| @@ -96,6 +97,7 @@ func NewOptions() *Options { | ||||
| 		ProxyPrefix:        "/oauth2", | ||||
| 		Providers:          providerDefaults(), | ||||
| 		PingPath:           "/ping", | ||||
| 		ReadyPath:          "/ready", | ||||
| 		RealClientIPHeader: "X-Real-IP", | ||||
| 		ForceHTTPS:         false, | ||||
| 		Cookie:             cookieDefaults(), | ||||
| @@ -133,6 +135,7 @@ func NewFlagSet() *pflag.FlagSet { | ||||
| 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | ||||
| 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | ||||
| 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | ||||
| 	flagSet.String("ready-path", "/ready", "the ready endpoint that can be used for deep health checks") | ||||
| 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | ||||
| 	flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") | ||||
| 	flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") | ||||
|   | ||||
| @@ -12,6 +12,7 @@ type SessionStore interface { | ||||
| 	Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error | ||||
| 	Load(req *http.Request) (*SessionState, error) | ||||
| 	Clear(rw http.ResponseWriter, req *http.Request) error | ||||
| 	VerifyConnection(ctx context.Context) error | ||||
| } | ||||
|  | ||||
| var ErrLockNotObtained = errors.New("lock: not obtained") | ||||
|   | ||||
							
								
								
									
										40
									
								
								pkg/middleware/readynesscheck.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								pkg/middleware/readynesscheck.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/justinas/alice" | ||||
| ) | ||||
|  | ||||
| // Verifiable an interface for an object that has a connection to external | ||||
| // data source and exports a function to validate that connection | ||||
| type Verifiable interface { | ||||
| 	VerifyConnection(context.Context) error | ||||
| } | ||||
|  | ||||
| // NewReadynessCheck returns a middleware that performs deep health checks | ||||
| // (verifies the connection to any underlying store) on a specific `path` | ||||
| func NewReadynessCheck(path string, verifiable Verifiable) alice.Constructor { | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return readynessCheck(path, verifiable, next) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func readynessCheck(path string, verifiable Verifiable, next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		if path != "" && req.URL.EscapedPath() == path { | ||||
| 			if err := verifiable.VerifyConnection(req.Context()); err != nil { | ||||
| 				rw.WriteHeader(http.StatusInternalServerError) | ||||
| 				fmt.Fprintf(rw, "error: %v", err) | ||||
| 				return | ||||
| 			} | ||||
| 			rw.WriteHeader(http.StatusOK) | ||||
| 			fmt.Fprintf(rw, "OK") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										84
									
								
								pkg/middleware/readynesscheck_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								pkg/middleware/readynesscheck_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
|  | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
|  | ||||
| var _ = Describe("ReadynessCheck suite", func() { | ||||
| 	type requestTableInput struct { | ||||
| 		readyPath        string | ||||
| 		healthVerifiable Verifiable | ||||
| 		requestString    string | ||||
| 		expectedStatus   int | ||||
| 		expectedBody     string | ||||
| 	} | ||||
|  | ||||
| 	DescribeTable("when serving a request", | ||||
| 		func(in *requestTableInput) { | ||||
| 			req := httptest.NewRequest("", in.requestString, nil) | ||||
|  | ||||
| 			rw := httptest.NewRecorder() | ||||
|  | ||||
| 			handler := NewReadynessCheck(in.readyPath, in.healthVerifiable)(http.NotFoundHandler()) | ||||
| 			handler.ServeHTTP(rw, req) | ||||
|  | ||||
| 			Expect(rw.Code).To(Equal(in.expectedStatus)) | ||||
| 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | ||||
| 		}, | ||||
| 		Entry("when requesting the readyness check path", &requestTableInput{ | ||||
| 			readyPath:        "/ready", | ||||
| 			healthVerifiable: &fakeVerifiable{nil}, | ||||
| 			requestString:    "http://example.com/ready", | ||||
| 			expectedStatus:   200, | ||||
| 			expectedBody:     "OK", | ||||
| 		}), | ||||
| 		Entry("when requesting a different path", &requestTableInput{ | ||||
| 			readyPath:        "/ready", | ||||
| 			healthVerifiable: &fakeVerifiable{nil}, | ||||
| 			requestString:    "http://example.com/different", | ||||
| 			expectedStatus:   404, | ||||
| 			expectedBody:     "404 page not found\n", | ||||
| 		}), | ||||
| 		Entry("when a blank string is configured as a readyness check path and the request has no specific path", &requestTableInput{ | ||||
| 			readyPath:        "", | ||||
| 			healthVerifiable: &fakeVerifiable{nil}, | ||||
| 			requestString:    "http://example.com", | ||||
| 			expectedStatus:   404, | ||||
| 			expectedBody:     "404 page not found\n", | ||||
| 		}), | ||||
| 		Entry("with full health check and without an underlying error", &requestTableInput{ | ||||
| 			readyPath:        "/ready", | ||||
| 			healthVerifiable: &fakeVerifiable{nil}, | ||||
| 			requestString:    "http://example.com/ready", | ||||
| 			expectedStatus:   200, | ||||
| 			expectedBody:     "OK", | ||||
| 		}), | ||||
| 		Entry("with full health check and with an underlying error", &requestTableInput{ | ||||
| 			readyPath:        "/ready", | ||||
| 			healthVerifiable: &fakeVerifiable{func(ctx context.Context) error { return errors.New("failed to check") }}, | ||||
| 			requestString:    "http://example.com/ready", | ||||
| 			expectedStatus:   500, | ||||
| 			expectedBody:     "error: failed to check", | ||||
| 		}), | ||||
| 	) | ||||
| }) | ||||
|  | ||||
| type fakeVerifiable struct { | ||||
| 	mock func(context.Context) error | ||||
| } | ||||
|  | ||||
| func (v *fakeVerifiable) VerifyConnection(ctx context.Context) error { | ||||
| 	if v.mock != nil { | ||||
| 		return v.mock(ctx) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| var _ Verifiable = (*fakeVerifiable)(nil) | ||||
| @@ -793,3 +793,7 @@ func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (f *fakeSessionStore) VerifyConnection(_ context.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package cookie | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -82,6 +83,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // VerifyConnection always return no-error, as there's no connection | ||||
| // in this store | ||||
| func (s *SessionStore) VerifyConnection(_ context.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // cookieForSession serializes a session state for storage in a cookie | ||||
| func (s *SessionStore) cookieForSession(ss *sessions.SessionState) ([]byte, error) { | ||||
| 	if s.Minimal && (ss.AccessToken != "" || ss.IDToken != "" || ss.RefreshToken != "") { | ||||
|   | ||||
| @@ -15,4 +15,5 @@ type Store interface { | ||||
| 	Load(context.Context, string) ([]byte, error) | ||||
| 	Clear(context.Context, string) error | ||||
| 	Lock(key string) sessions.Lock | ||||
| 	VerifyConnection(context.Context) error | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package persistence | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| @@ -90,3 +91,8 @@ func (m *Manager) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| 		return m.Store.Clear(req.Context(), key) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // VerifyConnection validates the underlying store is ready and connected | ||||
| func (m *Manager) VerifyConnection(ctx context.Context) error { | ||||
| 	return m.Store.VerifyConnection(ctx) | ||||
| } | ||||
|   | ||||
| @@ -14,6 +14,7 @@ type Client interface { | ||||
| 	Lock(key string) sessions.Lock | ||||
| 	Set(ctx context.Context, key string, value []byte, expiration time.Duration) error | ||||
| 	Del(ctx context.Context, key string) error | ||||
| 	Ping(ctx context.Context) error | ||||
| } | ||||
|  | ||||
| var _ Client = (*client)(nil) | ||||
| @@ -44,6 +45,10 @@ func (c *client) Lock(key string) sessions.Lock { | ||||
| 	return NewLock(c.Client, key) | ||||
| } | ||||
|  | ||||
| func (c *client) Ping(ctx context.Context) error { | ||||
| 	return c.Client.Ping(ctx).Err() | ||||
| } | ||||
|  | ||||
| var _ Client = (*clusterClient)(nil) | ||||
|  | ||||
| type clusterClient struct { | ||||
| @@ -71,3 +76,7 @@ func (c *clusterClient) Del(ctx context.Context, key string) error { | ||||
| func (c *clusterClient) Lock(key string) sessions.Lock { | ||||
| 	return NewLock(c.ClusterClient, key) | ||||
| } | ||||
|  | ||||
| func (c *clusterClient) Ping(ctx context.Context) error { | ||||
| 	return c.ClusterClient.Ping(ctx).Err() | ||||
| } | ||||
|   | ||||
							
								
								
									
										159
									
								
								pkg/sessions/redis/client_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								pkg/sessions/redis/client_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | ||||
| package redis_test | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/base64" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/alicebob/miniredis/v2" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/redis" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
|  | ||||
| var _ = Describe("Redis Client Tests", func() { | ||||
| 	Context("with basic client", func() { | ||||
| 		RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions { | ||||
| 			return options.RedisStoreOptions{ | ||||
| 				ConnectionURL: "redis://" + mr.Addr(), | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	Context("with cluster client", func() { | ||||
| 		RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions { | ||||
| 			return options.RedisStoreOptions{ | ||||
| 				ClusterConnectionURLs: []string{"redis://" + mr.Addr()}, | ||||
| 				UseCluster:            true, | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|  | ||||
| type getOptsFunc func(mr *miniredis.Miniredis) options.RedisStoreOptions | ||||
|  | ||||
| func RunClientTests(getOptsFunc getOptsFunc) { | ||||
| 	var mr *miniredis.Miniredis | ||||
| 	var client redis.Client | ||||
| 	var err error | ||||
| 	var key string | ||||
| 	var ctx context.Context | ||||
|  | ||||
| 	BeforeEach(func() { | ||||
| 		mr, err = miniredis.Run() | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
|  | ||||
| 		client, err = redis.NewRedisClient(getOptsFunc(mr)) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
|  | ||||
| 		nonce, err := encryption.Nonce(32) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 		key = base64.RawURLEncoding.EncodeToString(nonce) | ||||
|  | ||||
| 		ctx = context.Background() | ||||
| 	}) | ||||
|  | ||||
| 	AfterEach(func() { | ||||
| 		if mr != nil { | ||||
| 			mr.Close() | ||||
| 			mr = nil | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	Context("when Get is called", func() { | ||||
| 		expectedValue := []byte("value") | ||||
|  | ||||
| 		BeforeEach(func() { | ||||
| 			client.Set(context.Background(), key, expectedValue, time.Duration(1*time.Minute)) | ||||
| 		}) | ||||
|  | ||||
| 		It("returns the saved value", func() { | ||||
| 			value, err := client.Get(ctx, key) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(value).To(Equal(value)) | ||||
| 		}) | ||||
|  | ||||
| 		It("does not return expired values", func() { | ||||
| 			mr.FastForward(5 * time.Minute) | ||||
|  | ||||
| 			_, err = client.Get(ctx, key) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 		}) | ||||
|  | ||||
| 		It("returns an error if value does not exist", func() { | ||||
| 			_, err = client.Get(ctx, "does-not-exists") | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	Context("using Lock", func() { | ||||
| 		It("maintains the lock", func() { | ||||
| 			lock := client.Lock(key) | ||||
|  | ||||
| 			err = lock.Obtain(ctx, 1*time.Minute) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
|  | ||||
| 			isLocked, err := lock.Peek(ctx) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(isLocked).To(BeTrue()) | ||||
|  | ||||
| 			err = lock.Release(ctx) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 		}) | ||||
|  | ||||
| 		It("reflects non-locked instance", func() { | ||||
| 			lock := client.Lock(key) | ||||
|  | ||||
| 			isLocked, err := lock.Peek(ctx) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(isLocked).To(BeFalse()) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	Context("when Set is called", func() { | ||||
| 		expectedValue := []byte("value") | ||||
|  | ||||
| 		It("sets the expected value", func() { | ||||
| 			err = client.Set(ctx, key, expectedValue, 1*time.Minute) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
|  | ||||
| 			value, err := client.Get(ctx, key) | ||||
| 			Expect(value).To(Equal(expectedValue)) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	Context("when Del is called", func() { | ||||
| 		It("does not return an error when key exists", func() { | ||||
| 			err = client.Set(ctx, key, []byte("dummy"), 1*time.Minute) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
|  | ||||
| 			err = client.Del(ctx, key) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
|  | ||||
| 			_, err = client.Get(ctx, key) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	Context("when Ping is called", func() { | ||||
| 		Context("when redis is up", func() { | ||||
| 			It("does not return an error", func() { | ||||
| 				err = client.Ping(ctx) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			}) | ||||
| 		}) | ||||
|  | ||||
| 		Context("when redis is down", func() { | ||||
| 			It("returns an error", func() { | ||||
| 				mr.Close() | ||||
| 				mr = nil | ||||
|  | ||||
| 				err = client.Ping(ctx) | ||||
| 				Expect(err).To(HaveOccurred()) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
| @@ -70,6 +70,12 @@ func (store *SessionStore) Lock(key string) sessions.Lock { | ||||
| 	return store.Client.Lock(key) | ||||
| } | ||||
|  | ||||
| // VerifyConnection verifies the redis connection is valid and the | ||||
| // server is responsive | ||||
| func (store *SessionStore) VerifyConnection(ctx context.Context) error { | ||||
| 	return store.Client.Ping(ctx) | ||||
| } | ||||
|  | ||||
| // NewRedisClient makes a redis.Client (either standalone, sentinel aware, or | ||||
| // redis cluster) | ||||
| func NewRedisClient(opts options.RedisStoreOptions) (Client, error) { | ||||
| @@ -205,3 +211,5 @@ func parseRedisURLs(urls []string) ([]string, *redis.Options, error) { | ||||
| 	} | ||||
| 	return addrs, redisOptions, nil | ||||
| } | ||||
|  | ||||
| var _ persistence.Store = (*SessionStore)(nil) | ||||
|   | ||||
| @@ -2,20 +2,15 @@ package redis | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/pem" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/Bose/minisentinel" | ||||
| 	"github.com/alicebob/miniredis/v2" | ||||
| 	"github.com/go-redis/redis/v9" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/persistence" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/tests" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" | ||||
| @@ -25,28 +20,6 @@ import ( | ||||
|  | ||||
| const redisPassword = "0123456789abcdefghijklmnopqrstuv" | ||||
|  | ||||
| // wrappedRedisLogger wraps a logger so that we can coerce the logger to | ||||
| // fit the expected signature for go-redis logging | ||||
| type wrappedRedisLogger struct { | ||||
| 	*log.Logger | ||||
| } | ||||
|  | ||||
| func (l *wrappedRedisLogger) Printf(_ context.Context, format string, v ...interface{}) { | ||||
| 	l.Logger.Printf(format, v...) | ||||
| } | ||||
|  | ||||
| func TestSessionStore(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 	logger.SetErrOutput(GinkgoWriter) | ||||
|  | ||||
| 	redisLogger := &wrappedRedisLogger{Logger: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)} | ||||
| 	redisLogger.SetOutput(GinkgoWriter) | ||||
| 	redis.SetLogger(redisLogger) | ||||
|  | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "Redis SessionStore") | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	cert   tls.Certificate | ||||
| 	caPath string | ||||
|   | ||||
							
								
								
									
										35
									
								
								pkg/sessions/redis/redis_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								pkg/sessions/redis/redis_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| package redis_test | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v9" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
|  | ||||
| // wrappedRedisLogger wraps a logger so that we can coerce the logger to | ||||
| // fit the expected signature for go-redis logging | ||||
| type wrappedRedisLogger struct { | ||||
| 	*log.Logger | ||||
| } | ||||
|  | ||||
| func (l *wrappedRedisLogger) Printf(_ context.Context, format string, v ...interface{}) { | ||||
| 	l.Logger.Printf(format, v...) | ||||
| } | ||||
|  | ||||
| func TestRedis(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 	logger.SetErrOutput(GinkgoWriter) | ||||
|  | ||||
| 	redisLogger := &wrappedRedisLogger{Logger: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)} | ||||
| 	redisLogger.SetOutput(GinkgoWriter) | ||||
| 	redis.SetLogger(redisLogger) | ||||
|  | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "Redis") | ||||
| } | ||||
| @@ -65,6 +65,10 @@ func (s *MockStore) Lock(key string) sessions.Lock { | ||||
| 	return lock | ||||
| } | ||||
|  | ||||
| func (s *MockStore) VerifyConnection(_ context.Context) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // FastForward simulates the flow of time to test expirations | ||||
| func (s *MockStore) FastForward(duration time.Duration) { | ||||
| 	for _, mockLock := range s.lockCache { | ||||
|   | ||||
| @@ -485,6 +485,12 @@ func SessionStoreInterfaceTests(in *testInput) { | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	Context("when VerifyConnection is called", func() { | ||||
| 		It("should return without an error", func() { | ||||
| 			Expect(in.ss().VerifyConnection(in.request.Context())).ToNot(HaveOccurred()) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func LoadSessionTests(in *testInput) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user