diff --git a/oauthproxy.go b/oauthproxy.go index 034fe6a3..a2053ad1 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -15,7 +15,9 @@ import ( "time" "github.com/coreos/go-oidc" + "github.com/justinas/alice" ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" @@ -23,6 +25,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/providers" @@ -98,6 +101,8 @@ type OAuthProxy struct { trustedIPs *ip.NetSet Banner string Footer string + + sessionChain alice.Chain } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided @@ -156,6 +161,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr } } + sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) + return &OAuthProxy{ CookieName: opts.Cookie.Name, CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), @@ -209,9 +216,45 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil, + sessionChain: sessionChain, }, nil } +func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { + chain := alice.New(middleware.NewScope()) + + if opts.SkipJwtBearerTokens { + sessionLoaders := []middlewareapi.TokenToSessionLoader{} + if opts.GetOIDCVerifier() != nil { + sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ + Verifier: opts.GetOIDCVerifier(), + TokenToSession: opts.GetProvider().CreateSessionStateFromBearerToken, + }) + } + + for _, verifier := range opts.GetJWTBearerVerifiers() { + sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ + Verifier: verifier, + }) + } + + chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) + } + + if validator != nil { + chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator)) + } + + chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, + ValidateSessionState: opts.GetProvider().ValidateSessionState, + })) + + return chain +} + // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will // redirect clients to once authenticated func (p *OAuthProxy) GetRedirectURI(host string) string { @@ -780,86 +823,20 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // Set-Cookie headers may be set on the response as a side-effect of calling this method. func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { var session *sessionsapi.SessionState - var err error - var saveSession, clearSession, revalidated bool - if p.skipJwtBearerTokens && req.Header.Get("Authorization") != "" { - session, err = p.GetJwtSession(req) - if err != nil { - logger.Printf("Error retrieving session from token in Authorization header: %s", err) - } - if session != nil { - saveSession = false - } - } + getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + session = middleware.GetRequestScope(req).Session + })) + getSession.ServeHTTP(rw, req) - remoteAddr := ip.GetClientString(p.realClientIPParser, req, true) if session == nil { - session, err = p.LoadCookiedSession(req) - if err != nil { - logger.Printf("Error loading cookied session: %s", err) - } - - if session != nil { - if session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { - logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) - saveSession = true - } - - if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil { - logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) - clearSession = true - session = nil - } else if ok { - saveSession = true - revalidated = true - } - } - } - - if session != nil && session.IsExpired() { - logger.Printf("Removing session: token expired %s", session) - session = nil - saveSession = false - clearSession = true - } - - if saveSession && !revalidated && session != nil && session.AccessToken != "" { - if !p.provider.ValidateSessionState(req.Context(), session) { - logger.Printf("Removing session: error validating %s", session) - saveSession = false - session = nil - clearSession = true - } + return nil, ErrNeedsLogin } if session != nil && session.Email != "" && !p.Validator(session.Email) { logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) - session = nil - saveSession = false - clearSession = true - } - - if saveSession && session != nil { - err = p.SaveSession(rw, req, session) - if err != nil { - logger.PrintAuthf(session.Email, req, logger.AuthError, "Save session error %s", err) - return nil, err - } - } - - if clearSession { + // Invalid session, clear it p.ClearSessionCookie(rw, req) - } - - if session == nil { - session, err = p.CheckBasicAuth(req) - if err != nil { - logger.Printf("Error during basic auth validation: %s", err) - } - } - - if session == nil { return nil, ErrNeedsLogin } @@ -997,36 +974,6 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { } } -// CheckBasicAuth checks the requests Authorization header for basic auth -// credentials and authenticates these against the proxies HtpasswdFile -func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { - if p.basicAuthValidator == nil { - return nil, nil - } - auth := req.Header.Get("Authorization") - if auth == "" { - return nil, nil - } - s := strings.SplitN(auth, " ", 2) - if len(s) != 2 || s[0] != "Basic" { - return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) - } - b, err := b64.StdEncoding.DecodeString(s[1]) - if err != nil { - return nil, err - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return nil, fmt.Errorf("invalid format %s", b) - } - if p.basicAuthValidator.Validate(pair[0], pair[1]) { - logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") - return &sessionsapi.SessionState{User: pair[0]}, nil - } - logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") - return nil, nil -} - // isAjax checks if a request is an ajax request func isAjax(req *http.Request) bool { acceptValues := req.Header.Values("Accept") @@ -1044,74 +991,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } - -// GetJwtSession loads a session based on a JWT token in the authorization header. -// (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers) -func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState, error) { - rawBearerToken, err := p.findBearerToken(req) - if err != nil { - return nil, err - } - - // If we are using an oidc provider, go ahead and try that provider first with its Verifier - // and Bearer Token -> Session converter - if p.mainJwtBearerVerifier != nil { - bearerToken, err := p.mainJwtBearerVerifier.Verify(req.Context(), rawBearerToken) - if err == nil { - return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken) - } - } - - // Otherwise, attempt to verify against the extra JWT issuers and use a more generic - // Bearer Token -> Session converter - for _, verifier := range p.extraJwtBearerVerifiers { - bearerToken, err := verifier.Verify(req.Context(), rawBearerToken) - if err != nil { - continue - } - - return (*providers.ProviderData)(nil).CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken) - } - return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization")) -} - -// findBearerToken finds a valid JWT token from the Authorization header of a given request. -func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) { - auth := req.Header.Get("Authorization") - s := strings.SplitN(auth, " ", 2) - if len(s) != 2 { - return "", fmt.Errorf("invalid authorization header %s", auth) - } - jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`) - var rawBearerToken string - if s[0] == "Bearer" && jwtRegex.MatchString(s[1]) { - rawBearerToken = s[1] - } else if s[0] == "Basic" { - // Check if we have a Bearer token masquerading in Basic - b, err := b64.StdEncoding.DecodeString(s[1]) - if err != nil { - return "", err - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return "", fmt.Errorf("invalid format %s", b) - } - user, password := pair[0], pair[1] - - // check user, user+password, or just password for a token - if jwtRegex.MatchString(user) { - // Support blank passwords or magic `x-oauth-basic` passwords - nothing else - if password == "" || password == "x-oauth-basic" { - rawBearerToken = user - } - } else if jwtRegex.MatchString(password) { - // support passwords and ignore user - rawBearerToken = password - } - } - if rawBearerToken == "" { - return "", fmt.Errorf("no valid bearer token found in authorization header") - } - - return rawBearerToken, nil -} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 00c33ff6..cb14c717 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1889,7 +1889,7 @@ func TestGetJwtSession(t *testing.T) { // Bearer expires := time.Unix(1912151821, 0) - session, err := test.proxy.GetJwtSession(test.req) + session, err := test.proxy.getAuthenticatedSession(test.rw, test.req) assert.NoError(t, err) assert.Equal(t, session.User, "1234567890") assert.Equal(t, session.Email, "john@example.com") @@ -1912,70 +1912,6 @@ func TestGetJwtSession(t *testing.T) { assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com") } -func TestFindJwtBearerToken(t *testing.T) { - p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}} - getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}} - - validToken := "eyJfoobar.eyJfoobar.12345asdf" - var token string - - // Bearer - getReq.Header = map[string][]string{ - "Authorization": {fmt.Sprintf("Bearer %s", validToken)}, - } - - token, err := p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - // Basic - no password - getReq.SetBasicAuth(token, "") - token, err = p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - // Basic - sentinel password - getReq.SetBasicAuth(token, "x-oauth-basic") - token, err = p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - // Basic - any username, password matching jwt pattern - getReq.SetBasicAuth("any-username-you-could-wish-for", token) - token, err = p.findBearerToken(getReq) - assert.NoError(t, err) - assert.Equal(t, validToken, token) - - failures := []string{ - // Too many parts - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - // Not enough parts - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA", - // Invalid encrypted key - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA", - // Invalid IV - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA", - // Invalid ciphertext - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA", - // Invalid tag - "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////", - // Invalid header - "W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - // Invalid header - "######.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - // Missing alc/enc params - "e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA", - } - - for _, failure := range failures { - getReq.Header = map[string][]string{ - "Authorization": {fmt.Sprintf("Bearer %s", failure)}, - } - _, err := p.findBearerToken(getReq) - assert.Error(t, err) - } -} - func Test_prepareNoCache(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { prepareNoCache(w) diff --git a/providers/provider_default.go b/providers/provider_default.go index 598b91e8..58e1b4fc 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -126,35 +126,8 @@ func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S return false, nil } +// CreateSessionStateFromBearerToken should be implemented to allow providers +// to convert ID tokens into sessions func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { - var claims struct { - Subject string `json:"sub"` - Email string `json:"email"` - Verified *bool `json:"email_verified"` - PreferredUsername string `json:"preferred_username"` - } - - if err := idToken.Claims(&claims); err != nil { - return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) - } - - if claims.Email == "" { - claims.Email = claims.Subject - } - - if claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) - } - - newSession := &sessions.SessionState{ - Email: claims.Email, - User: claims.Subject, - PreferredUsername: claims.PreferredUsername, - AccessToken: rawIDToken, - IDToken: rawIDToken, - RefreshToken: "", - ExpiresOn: &idToken.Expiry, - } - - return newSession, nil + return nil, errors.New("not implemented") } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 745bf2d4..74d7096f 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -2,15 +2,10 @@ package providers import ( "context" - "crypto/rand" - "crypto/rsa" "net/url" "testing" "time" - "github.com/coreos/go-oidc" - "github.com/dgrijalva/jwt-go" - "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -52,39 +47,3 @@ func TestAcrValuesConfigured(t *testing.T) { result := p.GetLoginURL("https://my.test.app/oauth", "") assert.Contains(t, result, "acr_values=testValue") } - -func TestCreateSessionStateFromBearerToken(t *testing.T) { - minimalIDToken := jwt.StandardClaims{ - Audience: "asdf1234", - ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), - Id: "id-some-id", - IssuedAt: time.Now().Unix(), - Issuer: "https://issuer.example.com", - NotBefore: 0, - Subject: "123456789", - } - // From oidc_test.go - verifier := oidc.NewVerifier( - "https://issuer.example.com", - fakeKeySetStub{}, - &oidc.Config{ClientID: "asdf1234"}, - ) - - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, minimalIDToken).SignedString(key) - assert.NoError(t, err) - // Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below - idToken, err := verifier.Verify(context.Background(), rawIDToken) - assert.NoError(t, err) - - session, err := (*ProviderData)(nil).CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken) - assert.NoError(t, err) - - assert.Equal(t, rawIDToken, session.AccessToken) - assert.Equal(t, rawIDToken, session.IDToken) - assert.Equal(t, "123456789", session.Email) - assert.Equal(t, "123456789", session.User) - assert.Empty(t, session.RefreshToken) - assert.Empty(t, session.PreferredUsername) -}