1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-07-07 01:17:14 +02:00

Generalize and extend default CreateSessionFromToken

This commit is contained in:
Nick Meves
2020-11-15 18:57:48 -08:00
parent 44fa8316a1
commit 22f60e9b63
10 changed files with 148 additions and 209 deletions

View File

@ -13,7 +13,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/coreos/go-oidc"
"github.com/justinas/alice" "github.com/justinas/alice"
ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
@ -78,36 +77,34 @@ type OAuthProxy struct {
AuthOnlyPath string AuthOnlyPath string
UserInfoPath string UserInfoPath string
allowedRoutes []allowedRoute allowedRoutes []allowedRoute
redirectURL *url.URL // the url to receive requests at redirectURL *url.URL // the url to receive requests at
whitelistDomains []string whitelistDomains []string
provider providers.Provider provider providers.Provider
providerNameOverride string providerNameOverride string
sessionStore sessionsapi.SessionStore sessionStore sessionsapi.SessionStore
ProxyPrefix string ProxyPrefix string
SignInMessage string SignInMessage string
basicAuthValidator basic.Validator basicAuthValidator basic.Validator
displayHtpasswdForm bool displayHtpasswdForm bool
serveMux http.Handler serveMux http.Handler
SetXAuthRequest bool SetXAuthRequest bool
PassBasicAuth bool PassBasicAuth bool
SetBasicAuth bool SetBasicAuth bool
SkipProviderButton bool SkipProviderButton bool
PassUserHeaders bool PassUserHeaders bool
BasicAuthPassword string BasicAuthPassword string
PassAccessToken bool PassAccessToken bool
SetAuthorization bool SetAuthorization bool
PassAuthorization bool PassAuthorization bool
PreferEmailToUser bool PreferEmailToUser bool
skipAuthPreflight bool skipAuthPreflight bool
skipJwtBearerTokens bool skipJwtBearerTokens bool
mainJwtBearerVerifier *oidc.IDTokenVerifier templates *template.Template
extraJwtBearerVerifiers []*oidc.IDTokenVerifier realClientIPParser ipapi.RealClientIPParser
templates *template.Template trustedIPs *ip.NetSet
realClientIPParser ipapi.RealClientIPParser Banner string
trustedIPs *ip.NetSet Footer string
Banner string
Footer string
sessionChain alice.Chain sessionChain alice.Chain
headersChain alice.Chain headersChain alice.Chain
@ -202,25 +199,23 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix), AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix),
UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix),
ProxyPrefix: opts.ProxyPrefix, ProxyPrefix: opts.ProxyPrefix,
provider: opts.GetProvider(), provider: opts.GetProvider(),
providerNameOverride: opts.ProviderName, providerNameOverride: opts.ProviderName,
sessionStore: sessionStore, sessionStore: sessionStore,
serveMux: upstreamProxy, serveMux: upstreamProxy,
redirectURL: redirectURL, redirectURL: redirectURL,
allowedRoutes: allowedRoutes, allowedRoutes: allowedRoutes,
whitelistDomains: opts.WhitelistDomains, whitelistDomains: opts.WhitelistDomains,
skipAuthPreflight: opts.SkipAuthPreflight, skipAuthPreflight: opts.SkipAuthPreflight,
skipJwtBearerTokens: opts.SkipJwtBearerTokens, skipJwtBearerTokens: opts.SkipJwtBearerTokens,
mainJwtBearerVerifier: opts.GetOIDCVerifier(), realClientIPParser: opts.GetRealClientIPParser(),
extraJwtBearerVerifiers: opts.GetJWTBearerVerifiers(), SkipProviderButton: opts.SkipProviderButton,
realClientIPParser: opts.GetRealClientIPParser(), templates: templates,
SkipProviderButton: opts.SkipProviderButton, trustedIPs: trustedIPs,
templates: templates, Banner: opts.Banner,
trustedIPs: trustedIPs, Footer: opts.Footer,
Banner: opts.Banner, SignInMessage: buildSignInMessage(opts),
Footer: opts.Footer,
SignInMessage: buildSignInMessage(opts),
basicAuthValidator: basicAuthValidator, basicAuthValidator: basicAuthValidator,
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
@ -266,22 +261,13 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
chain := alice.New() chain := alice.New()
if opts.SkipJwtBearerTokens { if opts.SkipJwtBearerTokens {
sessionLoaders := []middlewareapi.TokenToSessionLoader{} sessionLoaders := []middlewareapi.TokenToSessionFunc{
if opts.GetOIDCVerifier() != nil { opts.GetProvider().CreateSessionFromToken,
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
Verifier: func(ctx context.Context, token string) (interface{}, error) {
return opts.GetOIDCVerifier().Verify(ctx, token)
},
TokenToSession: opts.GetProvider().CreateSessionFromToken,
})
} }
for _, verifier := range opts.GetJWTBearerVerifiers() { for _, verifier := range opts.GetJWTBearerVerifiers() {
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ sessionLoaders = append(sessionLoaders,
Verifier: func(ctx context.Context, token string) (interface{}, error) { middlewareapi.CreateTokenToSessionFunc(verifier.Verify))
return verifier.Verify(ctx, token)
},
})
} }
chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders))

View File

@ -2,25 +2,57 @@ package middleware
import ( import (
"context" "context"
"fmt"
"github.com/coreos/go-oidc"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
) )
// TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a // TokenToSessionFunc takes a raw ID Token and converts it into a SessionState.
// SessionState. type TokenToSessionFunc func(ctx context.Context, token string) (*sessionsapi.SessionState, error)
type TokenToSessionFunc func(ctx context.Context, token string, verify VerifyFunc) (*sessionsapi.SessionState, error)
// VerifyFunc takes a raw bearer token and verifies it // VerifyFunc takes a raw bearer token and verifies it returning the converted
type VerifyFunc func(ctx context.Context, token string) (interface{}, error) // oidc.IDToken representation of the token.
type VerifyFunc func(ctx context.Context, token string) (*oidc.IDToken, error)
// TokenToSessionLoader pairs a token verifier with the correct converter function // CreateTokenToSessionFunc provides a handler that is a default implementation
// to convert the ID Token to a SessionState. // for converting a JWT into a session.
type TokenToSessionLoader struct { func CreateTokenToSessionFunc(verify VerifyFunc) TokenToSessionFunc {
// Verifier is used to verify that the ID Token was signed by the claimed issuer return func(ctx context.Context, token string) (*sessionsapi.SessionState, error) {
// and that the token has not been tampered with. var claims struct {
Verifier VerifyFunc Subject string `json:"sub"`
Email string `json:"email"`
Verified *bool `json:"email_verified"`
PreferredUsername string `json:"preferred_username"`
}
// TokenToSession converts a raw bearer token to a SessionState. idToken, err := verify(ctx, token)
// (Optional) If not set a default basic implementation is used. if err != nil {
TokenToSession TokenToSessionFunc return nil, err
}
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
}
if claims.Email == "" {
claims.Email = claims.Subject
}
if claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
newSession := &sessionsapi.SessionState{
Email: claims.Email,
User: claims.Subject,
PreferredUsername: claims.PreferredUsername,
AccessToken: token,
IDToken: token,
RefreshToken: "",
ExpiresOn: &idToken.Expiry,
}
return newSession, nil
}
} }

View File

@ -1,12 +1,10 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"regexp" "regexp"
"github.com/coreos/go-oidc"
"github.com/justinas/alice" "github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
@ -16,16 +14,7 @@ import (
const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`
func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc) alice.Constructor {
for i, loader := range sessionLoaders {
if loader.TokenToSession == nil {
sessionLoaders[i] = middlewareapi.TokenToSessionLoader{
Verifier: loader.Verifier,
TokenToSession: createSessionFromToken,
}
}
}
js := &jwtSessionLoader{ js := &jwtSessionLoader{
jwtRegex: regexp.MustCompile(jwtRegexFormat), jwtRegex: regexp.MustCompile(jwtRegexFormat),
sessionLoaders: sessionLoaders, sessionLoaders: sessionLoaders,
@ -37,7 +26,7 @@ func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) al
// Authorization headers. // Authorization headers.
type jwtSessionLoader struct { type jwtSessionLoader struct {
jwtRegex *regexp.Regexp jwtRegex *regexp.Regexp
sessionLoaders []middlewareapi.TokenToSessionLoader sessionLoaders []middlewareapi.TokenToSessionFunc
} }
// loadSession attempts to load a session from a JWT stored in an Authorization // loadSession attempts to load a session from a JWT stored in an Authorization
@ -83,7 +72,7 @@ func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.Sessio
errs := []error{fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))} errs := []error{fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))}
for _, loader := range j.sessionLoaders { for _, loader := range j.sessionLoaders {
session, err := loader.TokenToSession(req.Context(), token, loader.Verifier) session, err := loader(req.Context(), token)
if err == nil { if err == nil {
return session, nil return session, nil
} else { } else {
@ -135,48 +124,3 @@ func (j *jwtSessionLoader) getBasicToken(token string) (string, error) {
return "", fmt.Errorf("invalid basic auth token found in authorization header") return "", fmt.Errorf("invalid basic auth token found in authorization header")
} }
// createSessionFromToken is a default implementation for converting
// a JWT into a session state.
func createSessionFromToken(ctx context.Context, token string, verify middlewareapi.VerifyFunc) (*sessionsapi.SessionState, error) {
var claims struct {
Subject string `json:"sub"`
Email string `json:"email"`
Verified *bool `json:"email_verified"`
PreferredUsername string `json:"preferred_username"`
}
verifiedToken, err := verify(ctx, token)
if err != nil {
return nil, err
}
idToken, ok := verifiedToken.(*oidc.IDToken)
if !ok {
return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token)
}
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
}
if claims.Email == "" {
claims.Email = claims.Subject
}
if claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
newSession := &sessionsapi.SessionState{
Email: claims.Email,
User: claims.Subject,
PreferredUsername: claims.PreferredUsername,
AccessToken: token,
IDToken: token,
RefreshToken: "",
ExpiresOn: &idToken.Expiry,
}
return newSession, nil
}

View File

@ -26,7 +26,7 @@ import (
type noOpKeySet struct { type noOpKeySet struct {
} }
func (noOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { func (noOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
splitStrings := strings.Split(jwt, ".") splitStrings := strings.Split(jwt, ".")
payloadString := splitStrings[1] payloadString := splitStrings[1]
return base64.RawURLEncoding.DecodeString(payloadString) return base64.RawURLEncoding.DecodeString(payloadString)
@ -78,16 +78,14 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
const nonVerifiedToken = validToken const nonVerifiedToken = validToken
BeforeEach(func() { BeforeEach(func() {
verifier = func(ctx context.Context, token string) (interface{}, error) { verifier = oidc.NewVerifier(
return oidc.NewVerifier( "https://issuer.example.com",
"https://issuer.example.com", noOpKeySet{},
noOpKeySet{}, &oidc.Config{
&oidc.Config{ ClientID: "https://test.myapp.com",
ClientID: "https://test.myapp.com", SkipExpiryCheck: true,
SkipExpiryCheck: true, },
}, ).Verify
).Verify(ctx, token)
}
}) })
type jwtSessionLoaderTableInput struct { type jwtSessionLoaderTableInput struct {
@ -110,10 +108,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
sessionLoaders := []middlewareapi.TokenToSessionLoader{ sessionLoaders := []middlewareapi.TokenToSessionFunc{
{ middlewareapi.CreateTokenToSessionFunc(verifier),
Verifier: verifier,
},
} }
// Create the handler with a next handler that will capture the session // Create the handler with a next handler that will capture the session
@ -175,24 +171,19 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
const nonVerifiedToken = validToken const nonVerifiedToken = validToken
BeforeEach(func() { BeforeEach(func() {
verifier := func(ctx context.Context, token string) (interface{}, error) { verifier := oidc.NewVerifier(
return oidc.NewVerifier( "https://issuer.example.com",
"https://issuer.example.com", noOpKeySet{},
noOpKeySet{}, &oidc.Config{
&oidc.Config{ ClientID: "https://test.myapp.com",
ClientID: "https://test.myapp.com", SkipExpiryCheck: true,
SkipExpiryCheck: true, },
}, ).Verify
).Verify(ctx, token)
}
j = &jwtSessionLoader{ j = &jwtSessionLoader{
jwtRegex: regexp.MustCompile(jwtRegexFormat), jwtRegex: regexp.MustCompile(jwtRegexFormat),
sessionLoaders: []middlewareapi.TokenToSessionLoader{ sessionLoaders: []middlewareapi.TokenToSessionFunc{
{ middlewareapi.CreateTokenToSessionFunc(verifier),
Verifier: verifier,
TokenToSession: createSessionFromToken,
},
}, },
} }
}) })
@ -402,7 +393,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
) )
}) })
Context("createSessionFromToken", func() { Context("CreateTokenToSessionFunc", func() {
ctx := context.Background() ctx := context.Background()
expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) expiresFuture := time.Now().Add(time.Duration(5) * time.Minute)
verified := true verified := true
@ -414,7 +405,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
jwt.StandardClaims jwt.StandardClaims
} }
type createSessionStateTableInput struct { type tokenToSessionTableInput struct {
idToken idTokenClaims idToken idTokenClaims
expectedErr error expectedErr error
expectedUser string expectedUser string
@ -423,8 +414,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
} }
DescribeTable("when creating a session from an IDToken", DescribeTable("when creating a session from an IDToken",
func(in createSessionStateTableInput) { func(in tokenToSessionTableInput) {
verifier := func(ctx context.Context, token string) (interface{}, error) { verifier := func(ctx context.Context, token string) (*oidc.IDToken, error) {
oidcVerifier := oidc.NewVerifier( oidcVerifier := oidc.NewVerifier(
"https://issuer.example.com", "https://issuer.example.com",
noOpKeySet{}, noOpKeySet{},
@ -443,7 +434,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
session, err := createSessionFromToken(ctx, rawIDToken, verifier) session, err := middlewareapi.CreateTokenToSessionFunc(verifier)(ctx, rawIDToken)
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
Expect(session).To(BeNil()) Expect(session).To(BeNil())
@ -459,7 +450,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
Expect(session.RefreshToken).To(BeEmpty()) Expect(session.RefreshToken).To(BeEmpty())
Expect(session.PreferredUsername).To(BeEmpty()) Expect(session.PreferredUsername).To(BeEmpty())
}, },
Entry("with no email", createSessionStateTableInput{ Entry("with no email", tokenToSessionTableInput{
idToken: idTokenClaims{ idToken: idTokenClaims{
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Audience: "asdf1234", Audience: "asdf1234",
@ -476,7 +467,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
expectedEmail: "123456789", expectedEmail: "123456789",
expectedExpires: &expiresFuture, expectedExpires: &expiresFuture,
}), }),
Entry("with a verified email", createSessionStateTableInput{ Entry("with a verified email", tokenToSessionTableInput{
idToken: idTokenClaims{ idToken: idTokenClaims{
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Audience: "asdf1234", Audience: "asdf1234",
@ -495,7 +486,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
expectedEmail: "foo@example.com", expectedEmail: "foo@example.com",
expectedExpires: &expiresFuture, expectedExpires: &expiresFuture,
}), }),
Entry("with a non-verified email", createSessionStateTableInput{ Entry("with a non-verified email", tokenToSessionTableInput{
idToken: idTokenClaims{ idToken: idTokenClaims{
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Audience: "asdf1234", Audience: "asdf1234",

View File

@ -233,6 +233,9 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
// Make the OIDC Verifier accessible to all providers that can support it
p.Verifier = o.GetOIDCVerifier()
p.SetAllowedGroups(o.AllowedGroups) p.SetAllowedGroups(o.AllowedGroups)
provider := providers.New(o.ProviderType, p) provider := providers.New(o.ProviderType, p)
@ -273,18 +276,14 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.UserIDClaim = o.UserIDClaim p.UserIDClaim = o.UserIDClaim
p.GroupsClaim = o.OIDCGroupsClaim p.GroupsClaim = o.OIDCGroupsClaim
if o.GetOIDCVerifier() == nil { if p.Verifier == nil {
msgs = append(msgs, "oidc provider requires an oidc issuer URL") msgs = append(msgs, "oidc provider requires an oidc issuer URL")
} else {
p.Verifier = o.GetOIDCVerifier()
} }
case *providers.GitLabProvider: case *providers.GitLabProvider:
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.Groups = o.GitLabGroup p.Groups = o.GitLabGroup
if o.GetOIDCVerifier() != nil { if p.Verifier == nil {
p.Verifier = o.GetOIDCVerifier()
} else {
// Initialize with default verifier for gitlab.com // Initialize with default verifier for gitlab.com
ctx := context.Background() ctx := context.Background()

View File

@ -11,7 +11,6 @@ import (
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
@ -23,7 +22,6 @@ const emailClaim = "email"
type OIDCProvider struct { type OIDCProvider struct {
*ProviderData *ProviderData
Verifier *oidc.IDTokenVerifier
AllowUnverifiedEmail bool AllowUnverifiedEmail bool
UserIDClaim string UserIDClaim string
GroupsClaim string GroupsClaim string
@ -176,17 +174,12 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
return newSession, nil return newSession, nil
} }
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) { func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
verifiedToken, err := verify(ctx, token) idToken, err := p.Verifier.Verify(ctx, token)
if err != nil { if err != nil {
return nil, err return nil, err
} }
idToken, ok := verifiedToken.(*oidc.IDToken)
if !ok {
return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token)
}
newSession, err := p.createSessionStateInternal(ctx, idToken, nil) newSession, err := p.createSessionStateInternal(ctx, idToken, nil)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -144,16 +144,17 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
Scheme: serverURL.Scheme, Scheme: serverURL.Scheme,
Host: serverURL.Host, Host: serverURL.Host,
Path: "/api"}, Path: "/api"},
Scope: "openid profile offline_access"} Scope: "openid profile offline_access",
p := &OIDCProvider{
ProviderData: providerData,
Verifier: oidc.NewVerifier( Verifier: oidc.NewVerifier(
"https://issuer.example.com", "https://issuer.example.com",
fakeKeySetStub{}, fakeKeySetStub{},
&oidc.Config{ClientID: clientID}, &oidc.Config{ClientID: clientID},
), ),
UserIDClaim: "email", }
p := &OIDCProvider{
ProviderData: providerData,
UserIDClaim: "email",
} }
return p return p
@ -347,18 +348,7 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) {
rawIDToken, err := newSignedTestIDToken(tc.IDToken) rawIDToken, err := newSignedTestIDToken(tc.IDToken)
assert.NoError(t, err) assert.NoError(t, err)
verifyFunc := func(ctx context.Context, token string) (interface{}, error) { ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken)
keyset := fakeKeySetStub{}
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
idToken, err := verifier.Verify(ctx, token)
assert.NoError(t, err)
return idToken, nil
}
ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken, verifyFunc)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.ExpectedUser, ss.User) assert.Equal(t, tc.ExpectedUser, ss.User)

View File

@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"net/url" "net/url"
"github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
) )
@ -26,6 +27,7 @@ type ProviderData struct {
ClientSecretFile string ClientSecretFile string
Scope string Scope string
Prompt string Prompt string
Verifier *oidc.IDTokenVerifier
// Universal Group authorization data structure // Universal Group authorization data structure
// any provider can set to consume // any provider can set to consume

View File

@ -126,6 +126,9 @@ func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.Ses
// CreateSessionStateFromBearerToken should be implemented to allow providers // CreateSessionStateFromBearerToken should be implemented to allow providers
// to convert ID tokens into sessions // to convert ID tokens into sessions
func (p *ProviderData) CreateSessionFromToken(_ context.Context, _ string, _ middleware.VerifyFunc) (*sessions.SessionState, error) { func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
if p.Verifier != nil {
return middleware.CreateTokenToSessionFunc(p.Verifier.Verify)(ctx, token)
}
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }

View File

@ -3,7 +3,6 @@ package providers
import ( import (
"context" "context"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
) )
@ -18,7 +17,7 @@ type Provider interface {
ValidateSession(ctx context.Context, s *sessions.SessionState) bool ValidateSession(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
} }
// New provides a new Provider based on the configured provider string // New provides a new Provider based on the configured provider string