diff --git a/CHANGELOG.md b/CHANGELOG.md index 25f29079..341b20d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ## Changes since v6.0.0 +- [#688](https://github.com/oauth2-proxy/oauth2-proxy/pull/688) Refactor session loading to make use of middleware pattern (@JoelSpeed) - [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed) - [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed) - [#624](https://github.com/oauth2-proxy/oauth2-proxy/pull/624) Allow stripping authentication headers from whitelisted requests with `--skip-auth-strip-headers` (@NickMeves) diff --git a/oauthproxy.go b/oauthproxy.go index 034fe6a3..a2053ad1 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -15,7 +15,9 @@ import ( "time" "github.com/coreos/go-oidc" + "github.com/justinas/alice" ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" @@ -23,6 +25,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/providers" @@ -98,6 +101,8 @@ type OAuthProxy struct { trustedIPs *ip.NetSet Banner string Footer string + + sessionChain alice.Chain } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided @@ -156,6 +161,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr } } + sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) + return &OAuthProxy{ CookieName: opts.Cookie.Name, CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), @@ -209,9 +216,45 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil, + sessionChain: sessionChain, }, nil } +func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { + chain := alice.New(middleware.NewScope()) + + if opts.SkipJwtBearerTokens { + sessionLoaders := []middlewareapi.TokenToSessionLoader{} + if opts.GetOIDCVerifier() != nil { + sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ + Verifier: opts.GetOIDCVerifier(), + TokenToSession: opts.GetProvider().CreateSessionStateFromBearerToken, + }) + } + + for _, verifier := range opts.GetJWTBearerVerifiers() { + sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ + Verifier: verifier, + }) + } + + chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) + } + + if validator != nil { + chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator)) + } + + chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, + ValidateSessionState: opts.GetProvider().ValidateSessionState, + })) + + return chain +} + // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will // redirect clients to once authenticated func (p *OAuthProxy) GetRedirectURI(host string) string { @@ -780,86 +823,20 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // Set-Cookie headers may be set on the response as a side-effect of calling this method. func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { var session *sessionsapi.SessionState - var err error - var saveSession, clearSession, revalidated bool - if p.skipJwtBearerTokens && req.Header.Get("Authorization") != "" { - session, err = p.GetJwtSession(req) - if err != nil { - logger.Printf("Error retrieving session from token in Authorization header: %s", err) - } - if session != nil { - saveSession = false - } - } + getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + session = middleware.GetRequestScope(req).Session + })) + getSession.ServeHTTP(rw, req) - remoteAddr := ip.GetClientString(p.realClientIPParser, req, true) if session == nil { - session, err = p.LoadCookiedSession(req) - if err != nil { - logger.Printf("Error loading cookied session: %s", err) - } - - if session != nil { - if session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { - logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) - saveSession = true - } - - if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil { - logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) - clearSession = true - session = nil - } else if ok { - saveSession = true - revalidated = true - } - } - } - - if session != nil && session.IsExpired() { - logger.Printf("Removing session: token expired %s", session) - session = nil - saveSession = false - clearSession = true - } - - if saveSession && !revalidated && session != nil && session.AccessToken != "" { - if !p.provider.ValidateSessionState(req.Context(), session) { - logger.Printf("Removing session: error validating %s", session) - saveSession = false - session = nil - clearSession = true - } + return nil, ErrNeedsLogin } if session != nil && session.Email != "" && !p.Validator(session.Email) { logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) - session = nil - saveSession = false - clearSession = true - } - - if saveSession && session != nil { - err = p.SaveSession(rw, req, session) - if err != nil { - logger.PrintAuthf(session.Email, req, logger.AuthError, "Save session error %s", err) - return nil, err - } - } - - if clearSession { + // Invalid session, clear it p.ClearSessionCookie(rw, req) - } - - if session == nil { - session, err = p.CheckBasicAuth(req) - if err != nil { - logger.Printf("Error during basic auth validation: %s", err) - } - } - - if session == nil { return nil, ErrNeedsLogin } @@ -997,36 +974,6 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { } } -// CheckBasicAuth checks the requests Authorization header for basic auth -// credentials and authenticates these against the proxies HtpasswdFile -func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { - if p.basicAuthValidator == nil { - return nil, nil - } - auth := req.Header.Get("Authorization") - if auth == "" { - return nil, nil - } - s := strings.SplitN(auth, " ", 2) - if len(s) != 2 || s[0] != "Basic" { - return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) - } - b, err := b64.StdEncoding.DecodeString(s[1]) - if err != nil { - return nil, err - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return nil, fmt.Errorf("invalid format %s", b) - } - if p.basicAuthValidator.Validate(pair[0], pair[1]) { - logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") - return &sessionsapi.SessionState{User: pair[0]}, nil - } - logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") - return nil, nil -} - // isAjax checks if a request is an ajax request func isAjax(req *http.Request) bool { acceptValues := req.Header.Values("Accept") @@ -1044,74 +991,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } - -// GetJwtSession loads a session based on a JWT token in the authorization header. -// (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers) -func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState, error) { - rawBearerToken, err := p.findBearerToken(req) - if err != nil { - return nil, err - } - - // If we are using an oidc provider, go ahead and try that provider first with its Verifier - // and Bearer Token -> Session converter - if p.mainJwtBearerVerifier != nil { - bearerToken, err := p.mainJwtBearerVerifier.Verify(req.Context(), rawBearerToken) - if err == nil { - return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken) - } - } - - // Otherwise, attempt to verify against the extra JWT issuers and use a more generic - // Bearer Token -> Session converter - for _, verifier := range p.extraJwtBearerVerifiers { - bearerToken, err := verifier.Verify(req.Context(), rawBearerToken) - if err != nil { - continue - } - - return (*providers.ProviderData)(nil).CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken) - } - return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization")) -} - -// findBearerToken finds a valid JWT token from the Authorization header of a given request. -func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) { - auth := req.Header.Get("Authorization") - s := strings.SplitN(auth, " ", 2) - if len(s) != 2 { - return "", fmt.Errorf("invalid authorization header %s", auth) - } - jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`) - var rawBearerToken string - if s[0] == "Bearer" && jwtRegex.MatchString(s[1]) { - rawBearerToken = s[1] - } else if s[0] == "Basic" { - // Check if we have a Bearer token masquerading in Basic - b, err := b64.StdEncoding.DecodeString(s[1]) - if err != nil { - return "", err - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return "", fmt.Errorf("invalid format %s", b) - } - user, password := pair[0], pair[1] - - // check user, user+password, or just password for a token - if jwtRegex.MatchString(user) { - // Support blank passwords or magic `x-oauth-basic` passwords - nothing else - if password == "" || password == "x-oauth-basic" { - rawBearerToken = user - } - } else if jwtRegex.MatchString(password) { - // support passwords and ignore user - rawBearerToken = password - } - } - if rawBearerToken == "" { - return "", fmt.Errorf("no valid bearer token found in authorization header") - } - - return rawBearerToken, nil -} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 00c33ff6..cb14c717 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1889,7 +1889,7 @@ func TestGetJwtSession(t *testing.T) { // Bearer expires := time.Unix(1912151821, 0) - session, err := test.proxy.GetJwtSession(test.req) + session, err := test.proxy.getAuthenticatedSession(test.rw, test.req) assert.NoError(t, err) assert.Equal(t, session.User, "1234567890") assert.Equal(t, session.Email, "john@example.com") @@ -1912,70 +1912,6 @@ func TestGetJwtSession(t *testing.T) { assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com") } -func TestFindJwtBearerToken(t *testing.T) { - p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}} - getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}} - - validToken := "eyJfoobar.eyJfoobar.12345asdf" - var token string - - // Bearer - getReq.Header = map[string][]string{ - "Authorization": {fmt.Sprintf("Bearer %s", validToken)}, - } - - token, err := p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - // Basic - no password - getReq.SetBasicAuth(token, "") - token, err = p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - // Basic - sentinel password - getReq.SetBasicAuth(token, "x-oauth-basic") - token, err = p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - // Basic - any username, password matching jwt pattern - getReq.SetBasicAuth("any-username-you-could-wish-for", token) - token, err = p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - failures := []string{ - // Too many parts - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - // Not enough parts - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA", - // Invalid encrypted key - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA", - // Invalid IV - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA", - // Invalid ciphertext - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA", - // Invalid tag - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////", - // Invalid header - "W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - // Invalid header - "######.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - // Missing alc/enc params - "e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - } - - for _, failure := range failures { - getReq.Header = map[string][]string{ - "Authorization": {fmt.Sprintf("Bearer %s", failure)}, - } - _, err := p.findBearerToken(getReq) - assert.Error(t, err) - } -} - func Test_prepareNoCache(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { prepareNoCache(w) diff --git a/pkg/apis/middleware/scope.go b/pkg/apis/middleware/scope.go new file mode 100644 index 00000000..c8153d1a --- /dev/null +++ b/pkg/apis/middleware/scope.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" +) + +// RequestScope contains information regarding the request that is being made. +// The RequestScope is used to pass information between different middlewares +// within the chain. +type RequestScope struct { + // Session details the authenticated users information (if it exists). + Session *sessions.SessionState + + // SaveSession indicates whether the session storage should attempt to save + // the session or not. + SaveSession bool + + // ClearSession indicates whether the user should be logged out or not. + ClearSession bool + + // SessionRevalidated indicates whether the session has been revalidated since + // it was loaded or not. + SessionRevalidated bool +} diff --git a/pkg/apis/middleware/session.go b/pkg/apis/middleware/session.go new file mode 100644 index 00000000..344ba31e --- /dev/null +++ b/pkg/apis/middleware/session.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "context" + + "github.com/coreos/go-oidc" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" +) + +// TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a +// SessionState. +type TokenToSessionFunc func(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) + +// TokenToSessionLoader pairs a token verifier with the correct converter function +// to convert the ID Token to a SessionState. +type TokenToSessionLoader struct { + // Verfier is used to verify that the ID Token was signed by the claimed issuer + // and that the token has not been tampered with. + Verifier *oidc.IDTokenVerifier + + // TokenToSession converts a rawIDToken and an idToken to a SessionState. + // (Optional) If not set a default basic implementation is used. + TokenToSession TokenToSessionFunc +} diff --git a/pkg/middleware/basic_session.go b/pkg/middleware/basic_session.go new file mode 100644 index 00000000..734f4d83 --- /dev/null +++ b/pkg/middleware/basic_session.go @@ -0,0 +1,88 @@ +package middleware + +import ( + "fmt" + "net/http" + + "github.com/justinas/alice" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor { + return func(next http.Handler) http.Handler { + return loadBasicAuthSession(validator, next) + } +} + +// loadBasicAuthSession attmepts to load a session from basic auth credentials +// stored in an Authorization header within the request. +// If no authorization header is found, or the header is invalid, no session +// will be loaded and the request will be passed to the next handler. +// If a session was loaded by a previous handler, it will not be replaced. +func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + if scope.Session != nil { + // The session was already loaded, pass to the next handler + next.ServeHTTP(rw, req) + return + } + + session, err := getBasicSession(validator, req) + if err != nil { + logger.Printf("Error retrieving session from token in Authorization header: %v", err) + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +// getBasicSession attempts to load a basic session from the request. +// If the credentials in the request exist within the htpasswdMap, +// a new session will be created. +func getBasicSession(validator basic.Validator, req *http.Request) (*sessionsapi.SessionState, error) { + auth := req.Header.Get("Authorization") + if auth == "" { + // No auth header provided, so don't attempt to load a session + return nil, nil + } + + user, password, err := findBasicCredentialsFromHeader(auth) + if err != nil { + return nil, err + } + + if validator.Validate(user, password) { + logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") + return &sessionsapi.SessionState{User: user}, nil + } + + logger.PrintAuthf(user, req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") + return nil, nil +} + +// findBasicCredentialsFromHeader finds basic auth credneitals from the +// Authorization header of a given request. +func findBasicCredentialsFromHeader(header string) (string, string, error) { + tokenType, token, err := splitAuthHeader(header) + if err != nil { + return "", "", err + } + + if tokenType != "Basic" { + return "", "", fmt.Errorf("invalid Authorization header: %q", header) + } + + user, password, err := getBasicAuthCredentials(token) + if err != nil { + return "", "", fmt.Errorf("error decoding basic auth credentials: %v", err) + } + + return user, password, nil +} diff --git a/pkg/middleware/basic_session_test.go b/pkg/middleware/basic_session_test.go new file mode 100644 index 00000000..109b09a5 --- /dev/null +++ b/pkg/middleware/basic_session_test.go @@ -0,0 +1,132 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +const ( + adminUser = "admin" + adminPassword = "Adm1n1str$t0r" + user1 = "user1" + user1Password = "UsErOn3P455" + user2 = "user2" + user2Password = "us3r2P455W0Rd!" +) + +var _ = Describe("Basic Auth Session Suite", func() { + Context("BasicAuthSessionLoader", func() { + + type basicAuthSessionLoaderTableInput struct { + authorizationHeader string + existingSession *sessionsapi.SessionState + expectedSession *sessionsapi.SessionState + } + + DescribeTable("with an authorization header", + func(in basicAuthSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the authorization header and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header.Set("Authorization", in.authorizationHeader) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + + validator := fakeBasicValidator{ + users: map[string]string{ + adminUser: adminPassword, + user1: user1Password, + user2: user2Password, + }, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + })) + handler.ServeHTTP(rw, req) + + Expect(gotSession).To(Equal(in.expectedSession)) + }, + Entry("", basicAuthSessionLoaderTableInput{ + authorizationHeader: "", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef", basicAuthSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef (with existing session)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Bearer ", basicAuthSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", adminPassword), + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic ", basicAuthSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Basic %s", adminPassword), + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic Base64(:) (with existing session)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic OlVzRXJPbjNQNDU1", + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Basic Base64(user1:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic dXNlcjE6VXNFck9uM1A0NTU=", + existingSession: nil, + expectedSession: &sessionsapi.SessionState{User: "user1"}, + }), + Entry("Basic Base64(user2:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic dXNlcjI6VXNFck9uM1A0NTU=", + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic Base64(user2:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic dXNlcjI6dXMzcjJQNDU1VzBSZCE=", + existingSession: nil, + expectedSession: &sessionsapi.SessionState{User: "user2"}, + }), + Entry("Basic Base64(admin:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic YWRtaW46QWRtMW4xc3RyJHQwcg==", + existingSession: nil, + expectedSession: &sessionsapi.SessionState{User: "admin"}, + }), + ) + }) +}) + +type fakeBasicValidator struct { + users map[string]string +} + +func (f fakeBasicValidator) Validate(user, password string) bool { + if f.users == nil { + return false + } + if realPassword, ok := f.users[user]; ok { + return realPassword == password + } + return false +} diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go new file mode 100644 index 00000000..3a0b65fc --- /dev/null +++ b/pkg/middleware/jwt_session.go @@ -0,0 +1,168 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "regexp" + + "github.com/coreos/go-oidc" + "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +const jwtRegexFormat = `^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` + +func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { + for i, loader := range sessionLoaders { + if loader.TokenToSession == nil { + sessionLoaders[i] = middlewareapi.TokenToSessionLoader{ + Verifier: loader.Verifier, + TokenToSession: createSessionStateFromBearerToken, + } + } + } + + js := &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + sessionLoaders: sessionLoaders, + } + return js.loadSession +} + +// jwtSessionLoader is responsible for loading sessions from JWTs in +// Authorization headers. +type jwtSessionLoader struct { + jwtRegex *regexp.Regexp + sessionLoaders []middlewareapi.TokenToSessionLoader +} + +// loadSession attempts to load a session from a JWT stored in an Authorization +// header within the request. +// If no authorization header is found, or the header is invalid, no session +// will be loaded and the request will be passed to the next handler. +// If a session was loaded by a previous handler, it will not be replaced. +func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + if scope.Session != nil { + // The session was already loaded, pass to the next handler + next.ServeHTTP(rw, req) + return + } + + session, err := j.getJwtSession(req) + if err != nil { + logger.Printf("Error retrieving session from token in Authorization header: %v", err) + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +// getJwtSession loads a session based on a JWT token in the authorization header. +// (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers) +func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.SessionState, error) { + auth := req.Header.Get("Authorization") + if auth == "" { + // No auth header provided, so don't attempt to load a session + return nil, nil + } + + rawBearerToken, err := j.findBearerTokenFromHeader(auth) + if err != nil { + return nil, err + } + + for _, loader := range j.sessionLoaders { + bearerToken, err := loader.Verifier.Verify(req.Context(), rawBearerToken) + if err == nil { + // The token was verified, convert it to a session + return loader.TokenToSession(req.Context(), rawBearerToken, bearerToken) + } + } + + return nil, fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization")) +} + +// findBearerTokenFromHeader finds a valid JWT token from the Authorization header of a given request. +func (j *jwtSessionLoader) findBearerTokenFromHeader(header string) (string, error) { + tokenType, token, err := splitAuthHeader(header) + if err != nil { + return "", err + } + + if tokenType == "Bearer" && j.jwtRegex.MatchString(token) { + // Found a JWT as a bearer token + return token, nil + } + + if tokenType == "Basic" { + // Check if we have a Bearer token masquerading in Basic + return j.getBasicToken(token) + } + + return "", fmt.Errorf("no valid bearer token found in authorization header") +} + +// getBasicToken tries to extract a token from the basic value provided. +func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { + user, password, err := getBasicAuthCredentials(token) + if err != nil { + return "", err + } + + // check user, user+password, or just password for a token + if j.jwtRegex.MatchString(user) { + // Support blank passwords or magic `x-oauth-basic` passwords - nothing else + if password == "" || password == "x-oauth-basic" { + return user, nil + } + } else if j.jwtRegex.MatchString(password) { + // support passwords and ignore user + return password, nil + } + + return "", fmt.Errorf("invalid basic auth token found in authorization header") +} + +// createSessionStateFromBearerToken is a default implementation for converting +// a JWT into a session state. +func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) { + var claims struct { + Subject string `json:"sub"` + Email string `json:"email"` + Verified *bool `json:"email_verified"` + PreferredUsername string `json:"preferred_username"` + } + + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) + } + + if claims.Email == "" { + claims.Email = claims.Subject + } + + if claims.Verified != nil && !*claims.Verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + } + + newSession := &sessionsapi.SessionState{ + Email: claims.Email, + User: claims.Subject, + PreferredUsername: claims.PreferredUsername, + AccessToken: rawIDToken, + IDToken: rawIDToken, + RefreshToken: "", + ExpiresOn: &idToken.Expiry, + } + + return newSession, nil +} diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go new file mode 100644 index 00000000..5148ad28 --- /dev/null +++ b/pkg/middleware/jwt_session_test.go @@ -0,0 +1,492 @@ +package middleware + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "time" + + "github.com/coreos/go-oidc" + "github.com/dgrijalva/jwt-go" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +type noOpKeySet struct { +} + +func (noOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { + splitStrings := strings.Split(jwt, ".") + payloadString := splitStrings[1] + return base64.RawURLEncoding.DecodeString(payloadString) +} + +var _ = Describe("JWT Session Suite", func() { + /* token payload: + { + "sub": "1234567890", + "aud": "https://test.myapp.com", + "name": "John Doe", + "email": "john@example.com", + "iss": "https://issuer.example.com", + "iat": 1553691215, + "exp": 1912151821 + } + */ + const verifiedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + + "eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" + + "WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" + + "E1LCJleHAiOjE5MTIxNTE4MjF9." + + "rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" + + "OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8" + + const verifiedTokenXOAuthBasicBase64 = `ZXlKaGJHY2lPaUpTVXpJMU5pSXNJblI1Y0NJNklrcFhWQ0o5LmV5SnpkV0lpT2lJeE1qTTBOVFkz +T0Rrd0lpd2lZWFZrSWpvaWFIUjBjSE02THk5MFpYTjBMbTE1WVhCd0xtTnZiU0lzSW01aGJXVWlP +aUpLYjJodUlFUnZaU0lzSW1WdFlXbHNJam9pYW05b2JrQmxlR0Z0Y0d4bExtTnZiU0lzSW1semN5 +STZJbWgwZEhCek9pOHZhWE56ZFdWeUxtVjRZVzF3YkdVdVkyOXRJaXdpYVdGMElqb3hOVFV6Tmpr +eE1qRTFMQ0psZUhBaU9qRTVNVEl4TlRFNE1qRjkuckxWeXpPbkVsZFVxX3BOa2ZhLVdpVjhUVkpZ +V3laQ2FNMkFtX3VvOEZHZzExekQ3bC1xbXozeDFzZVR2cXBINlkwVHkwMGZtdjZkSm5HbkM4V01u +UFhRaW9kUlRmaEJTZU9LWk11MEhrTUQyc2c1MnpsS2tiZkxUTzZpYzVWbmJWZ3dqanJCOGFtX1Rh +Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` + + var verifiedSessionExpiry = time.Unix(1912151821, 0) + var verifiedSession = &sessionsapi.SessionState{ + AccessToken: verifiedToken, + IDToken: verifiedToken, + Email: "john@example.com", + User: "1234567890", + ExpiresOn: &verifiedSessionExpiry, + } + + // validToken will pass the token regex so can be used to check token fetching + // is valid. It will not pass the OIDC Verifier however. + const validToken = "eyJfoobar.eyJfoobar.12345asdf" + + Context("JwtSessionLoader", func() { + var verifier *oidc.IDTokenVerifier + const nonVerifiedToken = validToken + + BeforeEach(func() { + keyset := noOpKeySet{} + verifier = oidc.NewVerifier("https://issuer.example.com", keyset, + &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + }) + + type jwtSessionLoaderTableInput struct { + authorizationHeader string + existingSession *sessionsapi.SessionState + expectedSession *sessionsapi.SessionState + } + + DescribeTable("with an authorization header", + func(in jwtSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the authorization header and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header.Set("Authorization", in.authorizationHeader) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + + sessionLoaders := []middlewareapi.TokenToSessionLoader{ + { + Verifier: verifier, + }, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + })) + handler.ServeHTTP(rw, req) + + Expect(gotSession).To(Equal(in.expectedSession)) + }, + Entry("", jwtSessionLoaderTableInput{ + authorizationHeader: "", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef", jwtSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef (with existing session)", jwtSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Bearer ", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), + existingSession: nil, + expectedSession: verifiedSession, + }), + Entry("Bearer ", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", nonVerifiedToken), + existingSession: nil, + expectedSession: nil, + }), + Entry("Bearer (with existing session)", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Basic Base64(:) (No password)", jwtSessionLoaderTableInput{ + authorizationHeader: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic Base64(:x-oauth-basic) (Sentinel password)", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Basic %s", verifiedTokenXOAuthBasicBase64), + existingSession: nil, + expectedSession: verifiedSession, + }), + ) + + }) + + Context("getJWTSession", func() { + var j *jwtSessionLoader + const nonVerifiedToken = validToken + + BeforeEach(func() { + keyset := noOpKeySet{} + verifier := oidc.NewVerifier("https://issuer.example.com", keyset, + &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + + j = &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + sessionLoaders: []middlewareapi.TokenToSessionLoader{ + { + Verifier: verifier, + TokenToSession: createSessionStateFromBearerToken, + }, + }, + } + }) + + type getJWTSessionTableInput struct { + authorizationHeader string + expectedErr error + expectedSession *sessionsapi.SessionState + } + + DescribeTable("with an authorization header", + func(in getJWTSessionTableInput) { + req := httptest.NewRequest("", "/", nil) + req.Header.Set("Authorization", in.authorizationHeader) + + session, err := j.getJwtSession(req) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(session).To(Equal(in.expectedSession)) + }, + Entry("", getJWTSessionTableInput{ + authorizationHeader: "", + expectedErr: nil, + expectedSession: nil, + }), + Entry("abcdef", getJWTSessionTableInput{ + authorizationHeader: "abcdef", + expectedErr: errors.New("invalid authorization header: \"abcdef\""), + expectedSession: nil, + }), + Entry("Bearer abcdef", getJWTSessionTableInput{ + authorizationHeader: "Bearer abcdef", + expectedErr: errors.New("no valid bearer token found in authorization header"), + expectedSession: nil, + }), + Entry("Bearer ", getJWTSessionTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", nonVerifiedToken), + expectedErr: errors.New("unable to verify jwt token: \"Bearer eyJfoobar.eyJfoobar.12345asdf\""), + expectedSession: nil, + }), + Entry("Bearer ", getJWTSessionTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), + expectedErr: nil, + expectedSession: verifiedSession, + }), + Entry("Basic Base64(:) (No password)", getJWTSessionTableInput{ + authorizationHeader: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + expectedErr: errors.New("unable to verify jwt token: \"Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6\""), + expectedSession: nil, + }), + Entry("Basic Base64(:x-oauth-basic) (Sentinel password)", getJWTSessionTableInput{ + authorizationHeader: fmt.Sprintf("Basic %s", verifiedTokenXOAuthBasicBase64), + expectedErr: nil, + expectedSession: verifiedSession, + }), + ) + }) + + Context("findBearerTokenFromHeader", func() { + var j *jwtSessionLoader + + BeforeEach(func() { + j = &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + } + }) + + type findBearerTokenFromHeaderTableInput struct { + header string + expectedErr error + expectedToken string + } + + DescribeTable("with a header", + func(in findBearerTokenFromHeaderTableInput) { + token, err := j.findBearerTokenFromHeader(in.header) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(token).To(Equal(in.expectedToken)) + }, + Entry("Bearer", findBearerTokenFromHeaderTableInput{ + header: "Bearer", + expectedErr: errors.New("invalid authorization header: \"Bearer\""), + expectedToken: "", + }), + Entry("Bearer abc def", findBearerTokenFromHeaderTableInput{ + header: "Bearer abc def", + expectedErr: errors.New("invalid authorization header: \"Bearer abc def\""), + expectedToken: "", + }), + Entry("Bearer abcdef", findBearerTokenFromHeaderTableInput{ + header: "Bearer abcdef", + expectedErr: errors.New("no valid bearer token found in authorization header"), + expectedToken: "", + }), + Entry("Bearer ", findBearerTokenFromHeaderTableInput{ + header: fmt.Sprintf("Bearer %s", validToken), + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic invalid-base64", findBearerTokenFromHeaderTableInput{ + header: "Basic invalid-base64", + expectedErr: errors.New("invalid basic auth token: illegal base64 data at input byte 7"), + expectedToken: "", + }), + Entry("Basic Base64(:) (No password)", findBearerTokenFromHeaderTableInput{ + header: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic Base64(:x-oauth-basic) (Sentinel password)", findBearerTokenFromHeaderTableInput{ + header: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6eC1vYXV0aC1iYXNpYw==", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic Base64(any-user:) (Matching password)", findBearerTokenFromHeaderTableInput{ + header: "Basic YW55LXVzZXI6ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY=", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic Base64(any-user:any-password) (No matches)", findBearerTokenFromHeaderTableInput{ + header: "Basic YW55LXVzZXI6YW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid basic auth token found in authorization header"), + expectedToken: "", + }), + Entry("Basic Base64(any-user any-password) (Invalid format)", findBearerTokenFromHeaderTableInput{ + header: "Basic YW55LXVzZXIgYW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid format: \"any-user any-password\""), + expectedToken: "", + }), + Entry("Something ", findBearerTokenFromHeaderTableInput{ + header: fmt.Sprintf("Something %s", validToken), + expectedErr: errors.New("no valid bearer token found in authorization header"), + expectedToken: "", + }), + ) + + }) + + Context("getBasicToken", func() { + var j *jwtSessionLoader + + BeforeEach(func() { + j = &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + } + }) + + type getBasicTokenTableInput struct { + token string + expectedErr error + expectedToken string + } + + DescribeTable("with a token", + func(in getBasicTokenTableInput) { + token, err := j.getBasicToken(in.token) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(token).To(Equal(in.expectedToken)) + }, + Entry("invalid-base64", getBasicTokenTableInput{ + token: "invalid-base64", + expectedErr: errors.New("invalid basic auth token: illegal base64 data at input byte 7"), + expectedToken: "", + }), + Entry("Base64(:) (No password)", getBasicTokenTableInput{ + token: "ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Base64(:x-oauth-basic) (Sentinel password)", getBasicTokenTableInput{ + token: "ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6eC1vYXV0aC1iYXNpYw==", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Base64(any-user:) (Matching password)", getBasicTokenTableInput{ + token: "YW55LXVzZXI6ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY=", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Base64(any-user:any-password) (No matches)", getBasicTokenTableInput{ + token: "YW55LXVzZXI6YW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid basic auth token found in authorization header"), + expectedToken: "", + }), + Entry("Base64(any-user any-password) (Invalid format)", getBasicTokenTableInput{ + token: "YW55LXVzZXIgYW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid format: \"any-user any-password\""), + expectedToken: "", + }), + ) + }) + + Context("createSessionStateFromBearerToken", func() { + ctx := context.Background() + expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) + verified := true + notVerified := false + + type idTokenClaims struct { + Email string `json:"email,omitempty"` + Verified *bool `json:"email_verified,omitempty"` + jwt.StandardClaims + } + + type createSessionStateTableInput struct { + idToken idTokenClaims + expectedErr error + expectedUser string + expectedEmail string + expectedExpires *time.Time + } + + DescribeTable("when creating a session from an IDToken", + func(in createSessionStateTableInput) { + verifier := oidc.NewVerifier( + "https://issuer.example.com", + noOpKeySet{}, + &oidc.Config{ClientID: "asdf1234"}, + ) + + key, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + + rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) + Expect(err).ToNot(HaveOccurred()) + + // Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below + idToken, err := verifier.Verify(context.Background(), rawIDToken) + Expect(err).ToNot(HaveOccurred()) + + session, err := createSessionStateFromBearerToken(ctx, rawIDToken, idToken) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + Expect(session).To(BeNil()) + return + } + + Expect(err).ToNot(HaveOccurred()) + Expect(session.AccessToken).To(Equal(rawIDToken)) + Expect(session.IDToken).To(Equal(rawIDToken)) + Expect(session.User).To(Equal(in.expectedUser)) + Expect(session.Email).To(Equal(in.expectedEmail)) + Expect(session.ExpiresOn.Unix()).To(Equal(in.expectedExpires.Unix())) + Expect(session.RefreshToken).To(BeEmpty()) + Expect(session.PreferredUsername).To(BeEmpty()) + }, + Entry("with no email", createSessionStateTableInput{ + idToken: idTokenClaims{ + StandardClaims: jwt.StandardClaims{ + Audience: "asdf1234", + ExpiresAt: expiresFuture.Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: "https://issuer.example.com", + NotBefore: 0, + Subject: "123456789", + }, + }, + expectedErr: nil, + expectedUser: "123456789", + expectedEmail: "123456789", + expectedExpires: &expiresFuture, + }), + Entry("with a verified email", createSessionStateTableInput{ + idToken: idTokenClaims{ + StandardClaims: jwt.StandardClaims{ + Audience: "asdf1234", + ExpiresAt: expiresFuture.Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: "https://issuer.example.com", + NotBefore: 0, + Subject: "123456789", + }, + Email: "foo@example.com", + Verified: &verified, + }, + expectedErr: nil, + expectedUser: "123456789", + expectedEmail: "foo@example.com", + expectedExpires: &expiresFuture, + }), + Entry("with a non-verified email", createSessionStateTableInput{ + idToken: idTokenClaims{ + StandardClaims: jwt.StandardClaims{ + Audience: "asdf1234", + ExpiresAt: expiresFuture.Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: "https://issuer.example.com", + NotBefore: 0, + Subject: "123456789", + }, + Email: "foo@example.com", + Verified: ¬Verified, + }, + expectedErr: errors.New("email in id_token (foo@example.com) isn't verified"), + }), + ) + }) +}) diff --git a/pkg/middleware/scope.go b/pkg/middleware/scope.go new file mode 100644 index 00000000..d5925ad4 --- /dev/null +++ b/pkg/middleware/scope.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" +) + +type scopeKey string + +// requestScopeKey uses a typed string to reduce likelihood of clasing +// with other context keys +const requestScopeKey scopeKey = "request-scope" + +func NewScope() alice.Constructor { + return addScope +} + +// addScope injects a new request scope into the request context. +func addScope(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := &middlewareapi.RequestScope{} + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + requestWithScope := req.WithContext(contextWithScope) + next.ServeHTTP(rw, requestWithScope) + }) +} + +// GetRequestScope returns the current request scope from the given request +func GetRequestScope(req *http.Request) *middlewareapi.RequestScope { + scope := req.Context().Value(requestScopeKey) + if scope == nil { + return nil + } + + return scope.(*middlewareapi.RequestScope) +} diff --git a/pkg/middleware/scope_test.go b/pkg/middleware/scope_test.go new file mode 100644 index 00000000..5a998bb0 --- /dev/null +++ b/pkg/middleware/scope_test.go @@ -0,0 +1,94 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Scope Suite", func() { + Context("NewScope", func() { + var request, nextRequest *http.Request + var rw http.ResponseWriter + + BeforeEach(func() { + var err error + request, err = http.NewRequest("", "http://127.0.0.1/", nil) + Expect(err).ToNot(HaveOccurred()) + + rw = httptest.NewRecorder() + + handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) + }) + + It("does not add a scope to the original request", func() { + Expect(request.Context().Value(requestScopeKey)).To(BeNil()) + }) + + It("cannot load a scope from the original request using GetRequestScope", func() { + Expect(GetRequestScope(request)).To(BeNil()) + }) + + It("adds a scope to the request for the next handler", func() { + Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil()) + }) + + It("can load a scope from the next handler's request using GetRequestScope", func() { + Expect(GetRequestScope(nextRequest)).ToNot(BeNil()) + }) + }) + + Context("GetRequestScope", func() { + var request *http.Request + + BeforeEach(func() { + var err error + request, err = http.NewRequest("", "http://127.0.0.1/", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("with a scope", func() { + var scope *middlewareapi.RequestScope + + BeforeEach(func() { + scope = &middlewareapi.RequestScope{} + contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope) + request = request.WithContext(contextWithScope) + }) + + It("returns the scope", func() { + s := GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + }) + + Context("if the scope is then modified", func() { + BeforeEach(func() { + Expect(scope.SaveSession).To(BeFalse()) + scope.SaveSession = true + }) + + It("returns the updated session", func() { + s := GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + Expect(s.SaveSession).To(BeTrue()) + }) + }) + }) + + Context("without a scope", func() { + It("returns nil", func() { + Expect(GetRequestScope(request)).To(BeNil()) + }) + }) + }) +}) diff --git a/pkg/middleware/session_utils.go b/pkg/middleware/session_utils.go new file mode 100644 index 00000000..00277d4f --- /dev/null +++ b/pkg/middleware/session_utils.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "encoding/base64" + "fmt" + "strings" +) + +// splitAuthHeader takes the auth header value and splits it into the token type +// and the token value. +func splitAuthHeader(header string) (string, string, error) { + s := strings.Split(header, " ") + if len(s) != 2 { + return "", "", fmt.Errorf("invalid authorization header: %q", header) + } + return s[0], s[1], nil +} + +// getBasicAuthCredentials decodes a basic auth token and extracts the user +// and password pair. +func getBasicAuthCredentials(token string) (string, string, error) { + b, err := base64.StdEncoding.DecodeString(token) + if err != nil { + return "", "", fmt.Errorf("invalid basic auth token: %v", err) + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return "", "", fmt.Errorf("invalid format: %q", b) + } + // user, password + return pair[0], pair[1], nil +} diff --git a/pkg/middleware/session_utils_test.go b/pkg/middleware/session_utils_test.go new file mode 100644 index 00000000..26eb84bb --- /dev/null +++ b/pkg/middleware/session_utils_test.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "errors" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Session utilities suite", func() { + Context("splitAuthHeader", func() { + type splitAuthTableInput struct { + header string + expectedErr error + expectedTokenType string + expectedTokenValue string + } + + DescribeTable("with a header value", + func(in splitAuthTableInput) { + tt, tv, err := splitAuthHeader(in.header) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(tt).To(Equal(in.expectedTokenType)) + Expect(tv).To(Equal(in.expectedTokenValue)) + }, + Entry("Bearer abcdef", splitAuthTableInput{ + header: "Bearer abcdef", + expectedErr: nil, + expectedTokenType: "Bearer", + expectedTokenValue: "abcdef", + }), + Entry("Bearer", splitAuthTableInput{ + header: "Bearer", + expectedErr: errors.New("invalid authorization header: \"Bearer\""), + expectedTokenType: "", + expectedTokenValue: "", + }), + Entry("Bearer abc def", splitAuthTableInput{ + header: "Bearer abc def", + expectedErr: errors.New("invalid authorization header: \"Bearer abc def\""), + expectedTokenType: "", + expectedTokenValue: "", + }), + ) + }) + + Context("getBasicAuthCredentials", func() { + type getBasicAuthCredentialsTableInput struct { + token string + expectedErr error + expectedUser string + expectedPassword string + } + + DescribeTable("from token", + func(in getBasicAuthCredentialsTableInput) { + user, password, err := getBasicAuthCredentials(in.token) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(user).To(Equal(in.expectedUser)) + Expect(password).To(Equal(in.expectedPassword)) + }, + Entry("", getBasicAuthCredentialsTableInput{ + token: "", + expectedErr: errors.New("invalid format: \"\""), + expectedUser: "", + expectedPassword: "", + }), + Entry("invalid-base64", getBasicAuthCredentialsTableInput{ + token: "invalid-base64", + expectedErr: errors.New("invalid basic auth token: illegal base64 data at input byte 7"), + expectedUser: "", + expectedPassword: "", + }), + Entry("Base64(some-user:some-password)", getBasicAuthCredentialsTableInput{ + token: "c29tZS11c2VyOnNvbWUtcGFzc3dvcmQ=", + expectedErr: nil, + expectedUser: "some-user", + expectedPassword: "some-password", + }), + Entry("Base64(no-password:)", getBasicAuthCredentialsTableInput{ + token: "bm8tcGFzc3dvcmQ6", + expectedErr: nil, + expectedUser: "no-password", + expectedPassword: "", + }), + Entry("Base64(:no-user)", getBasicAuthCredentialsTableInput{ + token: "Om5vLXVzZXI=", + expectedErr: nil, + expectedUser: "", + expectedPassword: "no-user", + }), + ) + }) +}) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go new file mode 100644 index 00000000..dd4d1405 --- /dev/null +++ b/pkg/middleware/stored_session.go @@ -0,0 +1,165 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/justinas/alice" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +// StoredSessionLoaderOptions cotnains all of the requirements to construct +// a stored session loader. +// All options must be provided. +type StoredSessionLoaderOptions struct { + // Session storage basckend + SessionStore sessionsapi.SessionStore + + // How often should sessions be refreshed + RefreshPeriod time.Duration + + // Provider based sesssion refreshing + RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + + // Provider based session validation. + // If the sesssion is older than `RefreshPeriod` but the provider doesn't + // refresh it, we must re-validate using this validation. + ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool +} + +// NewStoredSessionLoader creates a new storedSessionLoader which loads +// sessions from the session store. +// If no session is found, the request will be passed to the nex handler. +// If a session was loader by a previous handler, it will not be replaced. +func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { + ss := &storedSessionLoader{ + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, + validateSessionState: opts.ValidateSessionState, + } + return ss.loadSession +} + +// storedSessionLoader is responsible for loading sessions from cookie +// identified sessions in the session store. +type storedSessionLoader struct { + store sessionsapi.SessionStore + refreshPeriod time.Duration + refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + validateSessionState func(context.Context, *sessionsapi.SessionState) bool +} + +// loadSession attempts to load a session as identified by the request cookies. +// If no session is found, the request will be passed to the nex handler. +// If a session was loader by a previous handler, it will not be replaced. +func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + if scope.Session != nil { + // The session was already loaded, pass to the next handler + next.ServeHTTP(rw, req) + return + } + + session, err := s.getValidatedSession(rw, req) + if err != nil { + // In the case when there was an error loading the session, + // we should clear the session + logger.Printf("Error loading cookied session: %v, removing session", err) + s.store.Clear(rw, req) + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +// getValidatedSession is responsible for loading a session and making sure +// that is is valid. +func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { + session, err := s.store.Load(req) + if err != nil { + return nil, err + } + if session == nil { + // No session was found in the storage, nothing more to do + return nil, nil + } + + err = s.refreshSessionIfNeeded(rw, req, session) + if err != nil { + return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) + } + + return session, nil +} + +// refreshSessionIfNeeded will attempt to refresh a session if the session +// is older than the refresh period. +// It is assumed that if the provider refreshes the session, the session is now +// valid. +// If the session requires refreshing but the provider does not refresh it, +// we must validate the session to ensure that the returned session is still +// valid. +func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { + // Refresh is disabled or the session is not old enough, do nothing + return nil + } + + logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) + refreshed, err := s.refreshSessionWithProvider(rw, req, session) + if err != nil { + return err + } + + if !refreshed { + // Session wasn't refreshed, so make sure it's still valid + return s.validateSession(req.Context(), session) + } + return nil +} + +// refreshSessionWithProvider attempts to refresh the sessinon with the provider +// and will save the session if it was updated. +func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { + refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) + if err != nil { + return false, fmt.Errorf("error refreshing access token: %v", err) + } + + if !refreshed { + return false, nil + } + + // Because the session was refreshed, make sure to save it + err = s.store.Save(rw, req, session) + if err != nil { + logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) + return false, fmt.Errorf("error saving session: %v", err) + } + return true, nil +} + +// validateSession checks whether the session has expired and performs +// provider validation on the session. +// An error implies the session is not longer valid. +func (s *storedSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error { + if session.IsExpired() { + return errors.New("session is expired") + } + + if !s.validateSessionState(ctx, session) { + return errors.New("session is invalid") + } + + return nil +} diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go new file mode 100644 index 00000000..1721b309 --- /dev/null +++ b/pkg/middleware/stored_session_test.go @@ -0,0 +1,524 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "time" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stored Session Suite", func() { + const ( + refresh = "Refresh" + noRefresh = "NoRefresh" + ) + + var ctx = context.Background() + + Context("StoredSessionLoader", func() { + createdPast := time.Now().Add(-5 * time.Minute) + createdFuture := time.Now().Add(5 * time.Minute) + + var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + switch ss.RefreshToken { + case refresh: + ss.RefreshToken = "Refreshed" + return true, nil + case noRefresh: + return false, nil + default: + return false, errors.New("error refreshing session") + } + } + + var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { + return ss.AccessToken != "Invalid" + } + + var defaultSessionStore = &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + switch req.Header.Get("Cookie") { + case "_oauth2_proxy=NoRefreshSession": + return &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=InvalidNoRefreshSession": + return &sessionsapi.SessionState{ + AccessToken: "Invalid", + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=ExpiredNoRefreshSession": + return &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdPast, + }, nil + case "_oauth2_proxy=RefreshSession": + return &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=RefreshError": + return &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=NonExistent": + return nil, fmt.Errorf("invalid cookie") + default: + return nil, nil + } + }, + } + + type storedSessionLoaderTableInput struct { + requestHeaders http.Header + existingSession *sessionsapi.SessionState + expectedSession *sessionsapi.SessionState + store sessionsapi.SessionStore + refreshPeriod time.Duration + refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) + validateSession func(context.Context, *sessionsapi.SessionState) bool + } + + DescribeTable("when serving a request", + func(in storedSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the request headesr and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header = in.requestHeaders + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + + opts := &StoredSessionLoaderOptions{ + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSessionIfNeeded: in.refreshSession, + ValidateSessionState: in.validateSession, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + })) + handler.ServeHTTP(rw, req) + + Expect(gotSession).To(Equal(in.expectedSession)) + }, + Entry("with no cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{}, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an invalid cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NonExistent"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an existing session", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshSession"}, + }, + existingSession: &sessionsapi.SessionState{ + RefreshToken: "Existing", + }, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Existing", + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that has not expired", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 10 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + ) + }) + + Context("refreshSessionIfNeeded", func() { + type refreshSessionIfNeededTableInput struct { + refreshPeriod time.Duration + session *sessionsapi.SessionState + expectedErr error + expectRefreshed bool + expectValidated bool + } + + createdPast := time.Now().Add(-5 * time.Minute) + createdFuture := time.Now().Add(5 * time.Minute) + + DescribeTable("with a session", + func(in refreshSessionIfNeededTableInput) { + refreshed := false + validated := false + + s := &storedSessionLoader{ + refreshPeriod: in.refreshPeriod, + store: &fakeSessionStore{}, + refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + refreshed = true + switch ss.RefreshToken { + case refresh: + return true, nil + case noRefresh: + return false, nil + default: + return false, errors.New("error refreshing session") + } + }, + validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + validated = true + return ss.AccessToken != "Invalid" + }, + } + + req := httptest.NewRequest("", "/", nil) + err := s.refreshSessionIfNeeded(nil, req, in.session) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(refreshed).To(Equal(in.expectRefreshed)) + Expect(validated).To(Equal(in.expectValidated)) + }, + Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ + refreshPeriod: time.Duration(0), + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdFuture, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + }), + Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ + refreshPeriod: time.Duration(0), + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + }), + Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdFuture, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + }), + Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: false, + }), + Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + }), + Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &createdPast, + }, + expectedErr: errors.New("error refreshing access token: error refreshing session"), + expectRefreshed: true, + expectValidated: false, + }), + Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + AccessToken: "Invalid", + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + expectedErr: errors.New("session is invalid"), + expectRefreshed: true, + expectValidated: true, + }), + ) + }) + + Context("refreshSessionWithProvider", func() { + type refreshSessionWithProviderTableInput struct { + session *sessionsapi.SessionState + expectedErr error + expectRefreshed bool + expectSaved bool + } + + now := time.Now() + + DescribeTable("when refreshing with the provider", + func(in refreshSessionWithProviderTableInput) { + saved := false + + s := &storedSessionLoader{ + store: &fakeSessionStore{ + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, ss *sessionsapi.SessionState) error { + saved = true + if ss.AccessToken == "NoSave" { + return errors.New("unable to save session") + } + return nil + }, + }, + refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + switch ss.RefreshToken { + case refresh: + return true, nil + case noRefresh: + return false, nil + default: + return false, errors.New("error refreshing session") + } + }, + } + + req := httptest.NewRequest("", "/", nil) + refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(refreshed).To(Equal(in.expectRefreshed)) + Expect(saved).To(Equal(in.expectSaved)) + }, + Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + }, + expectedErr: nil, + expectRefreshed: false, + expectSaved: false, + }), + Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + }, + expectedErr: nil, + expectRefreshed: true, + expectSaved: true, + }), + Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &now, + ExpiresOn: &now, + }, + expectedErr: errors.New("error refreshing access token: error refreshing session"), + expectRefreshed: false, + expectSaved: false, + }), + Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + AccessToken: "NoSave", + }, + expectedErr: errors.New("error saving session: unable to save session"), + expectRefreshed: false, + expectSaved: true, + }), + ) + }) + + Context("validateSession", func() { + var s *storedSessionLoader + + BeforeEach(func() { + s = &storedSessionLoader{ + validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + return ss.AccessToken == "Valid" + }, + } + }) + + Context("with a valid session", func() { + It("does not return an error", func() { + expires := time.Now().Add(1 * time.Minute) + session := &sessionsapi.SessionState{ + AccessToken: "Valid", + ExpiresOn: &expires, + } + Expect(s.validateSession(ctx, session)).To(Succeed()) + }) + }) + + Context("with an expired session", func() { + It("returns an error", func() { + created := time.Now().Add(-5 * time.Minute) + expires := time.Now().Add(-1 * time.Minute) + session := &sessionsapi.SessionState{ + AccessToken: "Valid", + CreatedAt: &created, + ExpiresOn: &expires, + } + Expect(s.validateSession(ctx, session)).To(MatchError("session is expired")) + }) + }) + + Context("with an invalid session", func() { + It("returns an error", func() { + expires := time.Now().Add(1 * time.Minute) + session := &sessionsapi.SessionState{ + AccessToken: "Invalid", + ExpiresOn: &expires, + } + Expect(s.validateSession(ctx, session)).To(MatchError("session is invalid")) + }) + }) + }) +}) + +type fakeSessionStore struct { + SaveFunc func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error + LoadFunc func(req *http.Request) (*sessionsapi.SessionState, error) + ClearFunc func(rw http.ResponseWriter, req *http.Request) error +} + +func (f *fakeSessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { + if f.SaveFunc != nil { + return f.SaveFunc(rw, req, s) + } + return nil +} +func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, error) { + if f.LoadFunc != nil { + return f.LoadFunc(req) + } + return nil, nil +} + +func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { + if f.ClearFunc != nil { + return f.ClearFunc(rw, req) + } + return nil +} diff --git a/providers/provider_default.go b/providers/provider_default.go index 598b91e8..58e1b4fc 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -126,35 +126,8 @@ func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S return false, nil } +// CreateSessionStateFromBearerToken should be implemented to allow providers +// to convert ID tokens into sessions func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { - var claims struct { - Subject string `json:"sub"` - Email string `json:"email"` - Verified *bool `json:"email_verified"` - PreferredUsername string `json:"preferred_username"` - } - - if err := idToken.Claims(&claims); err != nil { - return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) - } - - if claims.Email == "" { - claims.Email = claims.Subject - } - - if claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) - } - - newSession := &sessions.SessionState{ - Email: claims.Email, - User: claims.Subject, - PreferredUsername: claims.PreferredUsername, - AccessToken: rawIDToken, - IDToken: rawIDToken, - RefreshToken: "", - ExpiresOn: &idToken.Expiry, - } - - return newSession, nil + return nil, errors.New("not implemented") } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 745bf2d4..74d7096f 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -2,15 +2,10 @@ package providers import ( "context" - "crypto/rand" - "crypto/rsa" "net/url" "testing" "time" - "github.com/coreos/go-oidc" - "github.com/dgrijalva/jwt-go" - "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -52,39 +47,3 @@ func TestAcrValuesConfigured(t *testing.T) { result := p.GetLoginURL("https://my.test.app/oauth", "") assert.Contains(t, result, "acr_values=testValue") } - -func TestCreateSessionStateFromBearerToken(t *testing.T) { - minimalIDToken := jwt.StandardClaims{ - Audience: "asdf1234", - ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), - Id: "id-some-id", - IssuedAt: time.Now().Unix(), - Issuer: "https://issuer.example.com", - NotBefore: 0, - Subject: "123456789", - } - // From oidc_test.go - verifier := oidc.NewVerifier( - "https://issuer.example.com", - fakeKeySetStub{}, - &oidc.Config{ClientID: "asdf1234"}, - ) - - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, minimalIDToken).SignedString(key) - assert.NoError(t, err) - // Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below - idToken, err := verifier.Verify(context.Background(), rawIDToken) - assert.NoError(t, err) - - session, err := (*ProviderData)(nil).CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken) - assert.NoError(t, err) - - assert.Equal(t, rawIDToken, session.AccessToken) - assert.Equal(t, rawIDToken, session.IDToken) - assert.Equal(t, "123456789", session.Email) - assert.Equal(t, "123456789", session.User) - assert.Empty(t, session.RefreshToken) - assert.Empty(t, session.PreferredUsername) -}