1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-04-19 12:12:39 +02:00

Decouple TokenToSession from OIDC & add a generic VerifyFunc

This commit is contained in:
Nick Meves 2020-10-23 23:34:06 -07:00
parent e9f787957e
commit 3e9717d489
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
8 changed files with 102 additions and 55 deletions

View File

@ -269,14 +269,18 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
sessionLoaders := []middlewareapi.TokenToSessionLoader{} sessionLoaders := []middlewareapi.TokenToSessionLoader{}
if opts.GetOIDCVerifier() != nil { if opts.GetOIDCVerifier() != nil {
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
Verifier: opts.GetOIDCVerifier(), Verifier: func(ctx context.Context, token string) (interface{}, error) {
TokenToSession: opts.GetProvider().CreateSessionFromBearer, return opts.GetOIDCVerifier().Verify(ctx, token)
},
TokenToSession: opts.GetProvider().CreateSessionFromToken,
}) })
} }
for _, verifier := range opts.GetJWTBearerVerifiers() { for _, verifier := range opts.GetJWTBearerVerifiers() {
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
Verifier: verifier, Verifier: func(ctx context.Context, token string) (interface{}, error) {
return verifier.Verify(ctx, token)
},
}) })
} }

View File

@ -3,22 +3,24 @@ package middleware
import ( import (
"context" "context"
"github.com/coreos/go-oidc"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
) )
// TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a // TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a
// SessionState. // SessionState.
type TokenToSessionFunc func(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) type TokenToSessionFunc func(ctx context.Context, token string, verify VerifyFunc) (*sessionsapi.SessionState, error)
// VerifyFunc takes a raw bearer token and verifies it
type VerifyFunc func(ctx context.Context, token string) (interface{}, error)
// TokenToSessionLoader pairs a token verifier with the correct converter function // TokenToSessionLoader pairs a token verifier with the correct converter function
// to convert the ID Token to a SessionState. // to convert the ID Token to a SessionState.
type TokenToSessionLoader struct { type TokenToSessionLoader struct {
// Verfier is used to verify that the ID Token was signed by the claimed issuer // Verifier is used to verify that the ID Token was signed by the claimed issuer
// and that the token has not been tampered with. // and that the token has not been tampered with.
Verifier *oidc.IDTokenVerifier Verifier VerifyFunc
// TokenToSession converts a rawIDToken and an idToken to a SessionState. // TokenToSession converts a raw bearer token to a SessionState.
// (Optional) If not set a default basic implementation is used. // (Optional) If not set a default basic implementation is used.
TokenToSession TokenToSessionFunc TokenToSession TokenToSessionFunc
} }

View File

@ -13,14 +13,14 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
) )
const jwtRegexFormat = `^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`
func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor {
for i, loader := range sessionLoaders { for i, loader := range sessionLoaders {
if loader.TokenToSession == nil { if loader.TokenToSession == nil {
sessionLoaders[i] = middlewareapi.TokenToSessionLoader{ sessionLoaders[i] = middlewareapi.TokenToSessionLoader{
Verifier: loader.Verifier, Verifier: loader.Verifier,
TokenToSession: createSessionStateFromBearerToken, TokenToSession: createSessionFromToken,
} }
} }
} }
@ -75,24 +75,24 @@ func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.Sessio
return nil, nil return nil, nil
} }
rawBearerToken, err := j.findBearerTokenFromHeader(auth) token, err := j.findTokenFromHeader(auth)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, loader := range j.sessionLoaders { for _, loader := range j.sessionLoaders {
bearerToken, err := loader.Verifier.Verify(req.Context(), rawBearerToken) session, err := loader.TokenToSession(req.Context(), token, loader.Verifier)
if err == nil { if err == nil {
// The token was verified, convert it to a session return session, nil
return loader.TokenToSession(req.Context(), rawBearerToken, bearerToken)
} }
} }
// TODO (@NickMeves) Aggregate error logs in the chain
return nil, fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization")) return nil, fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))
} }
// findBearerTokenFromHeader finds a valid JWT token from the Authorization header of a given request. // findTokenFromHeader finds a valid JWT token from the Authorization header of a given request.
func (j *jwtSessionLoader) findBearerTokenFromHeader(header string) (string, error) { func (j *jwtSessionLoader) findTokenFromHeader(header string) (string, error) {
tokenType, token, err := splitAuthHeader(header) tokenType, token, err := splitAuthHeader(header)
if err != nil { if err != nil {
return "", err return "", err
@ -133,9 +133,9 @@ func (j *jwtSessionLoader) getBasicToken(token string) (string, error) {
return "", fmt.Errorf("invalid basic auth token found in authorization header") return "", fmt.Errorf("invalid basic auth token found in authorization header")
} }
// createSessionStateFromBearerToken is a default implementation for converting // createSessionFromToken is a default implementation for converting
// a JWT into a session state. // a JWT into a session state.
func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) { func createSessionFromToken(ctx context.Context, token string, verify middlewareapi.VerifyFunc) (*sessionsapi.SessionState, error) {
var claims struct { var claims struct {
Subject string `json:"sub"` Subject string `json:"sub"`
Email string `json:"email"` Email string `json:"email"`
@ -143,6 +143,16 @@ func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, i
PreferredUsername string `json:"preferred_username"` PreferredUsername string `json:"preferred_username"`
} }
verifiedToken, err := verify(ctx, token)
if err != nil {
return nil, err
}
idToken, ok := verifiedToken.(*oidc.IDToken)
if !ok {
return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token)
}
if err := idToken.Claims(&claims); err != nil { if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
} }
@ -159,8 +169,8 @@ func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, i
Email: claims.Email, Email: claims.Email,
User: claims.Subject, User: claims.Subject,
PreferredUsername: claims.PreferredUsername, PreferredUsername: claims.PreferredUsername,
AccessToken: rawIDToken, AccessToken: token,
IDToken: rawIDToken, IDToken: token,
RefreshToken: "", RefreshToken: "",
ExpiresOn: &idToken.Expiry, ExpiresOn: &idToken.Expiry,
} }

View File

@ -73,13 +73,20 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
const validToken = "eyJfoobar.eyJfoobar.12345asdf" const validToken = "eyJfoobar.eyJfoobar.12345asdf"
Context("JwtSessionLoader", func() { Context("JwtSessionLoader", func() {
var verifier *oidc.IDTokenVerifier var verifier middlewareapi.VerifyFunc
const nonVerifiedToken = validToken const nonVerifiedToken = validToken
BeforeEach(func() { BeforeEach(func() {
keyset := noOpKeySet{} verifier = func(ctx context.Context, token string) (interface{}, error) {
verifier = oidc.NewVerifier("https://issuer.example.com", keyset, return oidc.NewVerifier(
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) "https://issuer.example.com",
noOpKeySet{},
&oidc.Config{
ClientID: "https://test.myapp.com",
SkipExpiryCheck: true,
},
).Verify(ctx, token)
}
}) })
type jwtSessionLoaderTableInput struct { type jwtSessionLoaderTableInput struct {
@ -167,16 +174,23 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
const nonVerifiedToken = validToken const nonVerifiedToken = validToken
BeforeEach(func() { BeforeEach(func() {
keyset := noOpKeySet{} verifier := func(ctx context.Context, token string) (interface{}, error) {
verifier := oidc.NewVerifier("https://issuer.example.com", keyset, return oidc.NewVerifier(
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) "https://issuer.example.com",
noOpKeySet{},
&oidc.Config{
ClientID: "https://test.myapp.com",
SkipExpiryCheck: true,
},
).Verify(ctx, token)
}
j = &jwtSessionLoader{ j = &jwtSessionLoader{
jwtRegex: regexp.MustCompile(jwtRegexFormat), jwtRegex: regexp.MustCompile(jwtRegexFormat),
sessionLoaders: []middlewareapi.TokenToSessionLoader{ sessionLoaders: []middlewareapi.TokenToSessionLoader{
{ {
Verifier: verifier, Verifier: verifier,
TokenToSession: createSessionStateFromBearerToken, TokenToSession: createSessionFromToken,
}, },
}, },
} }
@ -239,7 +253,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
) )
}) })
Context("findBearerTokenFromHeader", func() { Context("findTokenFromHeader", func() {
var j *jwtSessionLoader var j *jwtSessionLoader
BeforeEach(func() { BeforeEach(func() {
@ -256,7 +270,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
DescribeTable("with a header", DescribeTable("with a header",
func(in findBearerTokenFromHeaderTableInput) { func(in findBearerTokenFromHeaderTableInput) {
token, err := j.findBearerTokenFromHeader(in.header) token, err := j.findTokenFromHeader(in.header)
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
} else { } else {
@ -381,7 +395,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
) )
}) })
Context("createSessionStateFromBearerToken", func() { Context("createSessionFromToken", func() {
ctx := context.Background() ctx := context.Background()
expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) expiresFuture := time.Now().Add(time.Duration(5) * time.Minute)
verified := true verified := true
@ -403,23 +417,26 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
DescribeTable("when creating a session from an IDToken", DescribeTable("when creating a session from an IDToken",
func(in createSessionStateTableInput) { func(in createSessionStateTableInput) {
verifier := oidc.NewVerifier( verifier := func(ctx context.Context, token string) (interface{}, error) {
oidcVerifier := oidc.NewVerifier(
"https://issuer.example.com", "https://issuer.example.com",
noOpKeySet{}, noOpKeySet{},
&oidc.Config{ClientID: "asdf1234"}, &oidc.Config{ClientID: "asdf1234"},
) )
idToken, err := oidcVerifier.Verify(ctx, token)
Expect(err).ToNot(HaveOccurred())
return idToken, nil
}
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below session, err := createSessionFromToken(ctx, rawIDToken, verifier)
idToken, err := verifier.Verify(context.Background(), rawIDToken)
Expect(err).ToNot(HaveOccurred())
session, err := createSessionStateFromBearerToken(ctx, rawIDToken, idToken)
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
Expect(session).To(BeNil()) Expect(session).To(BeNil())

View File

@ -11,6 +11,7 @@ import (
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
@ -175,14 +176,24 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
return newSession, nil return newSession, nil
} }
func (p *OIDCProvider) CreateSessionFromBearer(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) {
verifiedToken, err := verify(ctx, token)
if err != nil {
return nil, err
}
idToken, ok := verifiedToken.(*oidc.IDToken)
if !ok {
return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token)
}
newSession, err := p.createSessionStateInternal(ctx, idToken, nil) newSession, err := p.createSessionStateInternal(ctx, idToken, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
newSession.AccessToken = rawIDToken newSession.AccessToken = token
newSession.IDToken = rawIDToken newSession.IDToken = token
newSession.RefreshToken = "" newSession.RefreshToken = ""
newSession.ExpiresOn = &idToken.Expiry newSession.ExpiresOn = &idToken.Expiry

View File

@ -347,14 +347,18 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) {
rawIDToken, err := newSignedTestIDToken(tc.IDToken) rawIDToken, err := newSignedTestIDToken(tc.IDToken)
assert.NoError(t, err) assert.NoError(t, err)
verifyFunc := func(ctx context.Context, token string) (interface{}, error) {
keyset := fakeKeySetStub{} keyset := fakeKeySetStub{}
verifier := oidc.NewVerifier("https://issuer.example.com", keyset, verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(ctx, token)
assert.NoError(t, err) assert.NoError(t, err)
ss, err := provider.CreateSessionFromBearer(context.Background(), rawIDToken, idToken) return idToken, nil
}
ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken, verifyFunc)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.ExpectedUser, ss.User) assert.Equal(t, tc.ExpectedUser, ss.User)

View File

@ -8,8 +8,7 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
) )
@ -127,6 +126,6 @@ func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.Ses
// CreateSessionStateFromBearerToken should be implemented to allow providers // CreateSessionStateFromBearerToken should be implemented to allow providers
// to convert ID tokens into sessions // to convert ID tokens into sessions
func (p *ProviderData) CreateSessionFromBearer(_ context.Context, _ string, _ *oidc.IDToken) (*sessions.SessionState, error) { func (p *ProviderData) CreateSessionFromToken(_ context.Context, _ string, _ middleware.VerifyFunc) (*sessions.SessionState, error) {
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }

View File

@ -3,7 +3,7 @@ package providers
import ( import (
"context" "context"
"github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
) )
@ -18,7 +18,7 @@ type Provider interface {
ValidateSession(ctx context.Context, s *sessions.SessionState) bool ValidateSession(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
CreateSessionFromBearer(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error)
} }
// New provides a new Provider based on the configured provider string // New provides a new Provider based on the configured provider string