1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-07-07 01:17:14 +02:00

Generalize and extend default CreateSessionFromToken

This commit is contained in:
Nick Meves
2020-11-15 18:57:48 -08:00
parent 44fa8316a1
commit 22f60e9b63
10 changed files with 148 additions and 209 deletions

View File

@ -13,7 +13,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/coreos/go-oidc"
"github.com/justinas/alice" "github.com/justinas/alice"
ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
@ -101,8 +100,6 @@ type OAuthProxy struct {
PreferEmailToUser bool PreferEmailToUser bool
skipAuthPreflight bool skipAuthPreflight bool
skipJwtBearerTokens bool skipJwtBearerTokens bool
mainJwtBearerVerifier *oidc.IDTokenVerifier
extraJwtBearerVerifiers []*oidc.IDTokenVerifier
templates *template.Template templates *template.Template
realClientIPParser ipapi.RealClientIPParser realClientIPParser ipapi.RealClientIPParser
trustedIPs *ip.NetSet trustedIPs *ip.NetSet
@ -212,8 +209,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
whitelistDomains: opts.WhitelistDomains, whitelistDomains: opts.WhitelistDomains,
skipAuthPreflight: opts.SkipAuthPreflight, skipAuthPreflight: opts.SkipAuthPreflight,
skipJwtBearerTokens: opts.SkipJwtBearerTokens, skipJwtBearerTokens: opts.SkipJwtBearerTokens,
mainJwtBearerVerifier: opts.GetOIDCVerifier(),
extraJwtBearerVerifiers: opts.GetJWTBearerVerifiers(),
realClientIPParser: opts.GetRealClientIPParser(), realClientIPParser: opts.GetRealClientIPParser(),
SkipProviderButton: opts.SkipProviderButton, SkipProviderButton: opts.SkipProviderButton,
templates: templates, templates: templates,
@ -266,22 +261,13 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
chain := alice.New() chain := alice.New()
if opts.SkipJwtBearerTokens { if opts.SkipJwtBearerTokens {
sessionLoaders := []middlewareapi.TokenToSessionLoader{} sessionLoaders := []middlewareapi.TokenToSessionFunc{
if opts.GetOIDCVerifier() != nil { opts.GetProvider().CreateSessionFromToken,
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
Verifier: func(ctx context.Context, token string) (interface{}, error) {
return opts.GetOIDCVerifier().Verify(ctx, token)
},
TokenToSession: opts.GetProvider().CreateSessionFromToken,
})
} }
for _, verifier := range opts.GetJWTBearerVerifiers() { for _, verifier := range opts.GetJWTBearerVerifiers() {
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ sessionLoaders = append(sessionLoaders,
Verifier: func(ctx context.Context, token string) (interface{}, error) { middlewareapi.CreateTokenToSessionFunc(verifier.Verify))
return verifier.Verify(ctx, token)
},
})
} }
chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders))

View File

@ -2,25 +2,57 @@ package middleware
import ( import (
"context" "context"
"fmt"
"github.com/coreos/go-oidc"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
) )
// TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a // TokenToSessionFunc takes a raw ID Token and converts it into a SessionState.
// SessionState. type TokenToSessionFunc func(ctx context.Context, token string) (*sessionsapi.SessionState, error)
type TokenToSessionFunc func(ctx context.Context, token string, verify VerifyFunc) (*sessionsapi.SessionState, error)
// VerifyFunc takes a raw bearer token and verifies it // VerifyFunc takes a raw bearer token and verifies it returning the converted
type VerifyFunc func(ctx context.Context, token string) (interface{}, error) // oidc.IDToken representation of the token.
type VerifyFunc func(ctx context.Context, token string) (*oidc.IDToken, error)
// TokenToSessionLoader pairs a token verifier with the correct converter function // CreateTokenToSessionFunc provides a handler that is a default implementation
// to convert the ID Token to a SessionState. // for converting a JWT into a session.
type TokenToSessionLoader struct { func CreateTokenToSessionFunc(verify VerifyFunc) TokenToSessionFunc {
// Verifier is used to verify that the ID Token was signed by the claimed issuer return func(ctx context.Context, token string) (*sessionsapi.SessionState, error) {
// and that the token has not been tampered with. var claims struct {
Verifier VerifyFunc Subject string `json:"sub"`
Email string `json:"email"`
Verified *bool `json:"email_verified"`
PreferredUsername string `json:"preferred_username"`
}
// TokenToSession converts a raw bearer token to a SessionState. idToken, err := verify(ctx, token)
// (Optional) If not set a default basic implementation is used. if err != nil {
TokenToSession TokenToSessionFunc return nil, err
}
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
}
if claims.Email == "" {
claims.Email = claims.Subject
}
if claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
newSession := &sessionsapi.SessionState{
Email: claims.Email,
User: claims.Subject,
PreferredUsername: claims.PreferredUsername,
AccessToken: token,
IDToken: token,
RefreshToken: "",
ExpiresOn: &idToken.Expiry,
}
return newSession, nil
}
} }

View File

@ -1,12 +1,10 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"regexp" "regexp"
"github.com/coreos/go-oidc"
"github.com/justinas/alice" "github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
@ -16,16 +14,7 @@ import (
const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`
func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc) alice.Constructor {
for i, loader := range sessionLoaders {
if loader.TokenToSession == nil {
sessionLoaders[i] = middlewareapi.TokenToSessionLoader{
Verifier: loader.Verifier,
TokenToSession: createSessionFromToken,
}
}
}
js := &jwtSessionLoader{ js := &jwtSessionLoader{
jwtRegex: regexp.MustCompile(jwtRegexFormat), jwtRegex: regexp.MustCompile(jwtRegexFormat),
sessionLoaders: sessionLoaders, sessionLoaders: sessionLoaders,
@ -37,7 +26,7 @@ func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) al
// Authorization headers. // Authorization headers.
type jwtSessionLoader struct { type jwtSessionLoader struct {
jwtRegex *regexp.Regexp jwtRegex *regexp.Regexp
sessionLoaders []middlewareapi.TokenToSessionLoader sessionLoaders []middlewareapi.TokenToSessionFunc
} }
// loadSession attempts to load a session from a JWT stored in an Authorization // loadSession attempts to load a session from a JWT stored in an Authorization
@ -83,7 +72,7 @@ func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.Sessio
errs := []error{fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))} errs := []error{fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))}
for _, loader := range j.sessionLoaders { for _, loader := range j.sessionLoaders {
session, err := loader.TokenToSession(req.Context(), token, loader.Verifier) session, err := loader(req.Context(), token)
if err == nil { if err == nil {
return session, nil return session, nil
} else { } else {
@ -135,48 +124,3 @@ func (j *jwtSessionLoader) getBasicToken(token string) (string, error) {
return "", fmt.Errorf("invalid basic auth token found in authorization header") return "", fmt.Errorf("invalid basic auth token found in authorization header")
} }
// createSessionFromToken is a default implementation for converting
// a JWT into a session state.
func createSessionFromToken(ctx context.Context, token string, verify middlewareapi.VerifyFunc) (*sessionsapi.SessionState, error) {
var claims struct {
Subject string `json:"sub"`
Email string `json:"email"`
Verified *bool `json:"email_verified"`
PreferredUsername string `json:"preferred_username"`
}
verifiedToken, err := verify(ctx, token)
if err != nil {
return nil, err
}
idToken, ok := verifiedToken.(*oidc.IDToken)
if !ok {
return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token)
}
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
}
if claims.Email == "" {
claims.Email = claims.Subject
}
if claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
newSession := &sessionsapi.SessionState{
Email: claims.Email,
User: claims.Subject,
PreferredUsername: claims.PreferredUsername,
AccessToken: token,
IDToken: token,
RefreshToken: "",
ExpiresOn: &idToken.Expiry,
}
return newSession, nil
}

View File

@ -26,7 +26,7 @@ import (
type noOpKeySet struct { type noOpKeySet struct {
} }
func (noOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { func (noOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
splitStrings := strings.Split(jwt, ".") splitStrings := strings.Split(jwt, ".")
payloadString := splitStrings[1] payloadString := splitStrings[1]
return base64.RawURLEncoding.DecodeString(payloadString) return base64.RawURLEncoding.DecodeString(payloadString)
@ -78,16 +78,14 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
const nonVerifiedToken = validToken const nonVerifiedToken = validToken
BeforeEach(func() { BeforeEach(func() {
verifier = func(ctx context.Context, token string) (interface{}, error) { verifier = oidc.NewVerifier(
return oidc.NewVerifier(
"https://issuer.example.com", "https://issuer.example.com",
noOpKeySet{}, noOpKeySet{},
&oidc.Config{ &oidc.Config{
ClientID: "https://test.myapp.com", ClientID: "https://test.myapp.com",
SkipExpiryCheck: true, SkipExpiryCheck: true,
}, },
).Verify(ctx, token) ).Verify
}
}) })
type jwtSessionLoaderTableInput struct { type jwtSessionLoaderTableInput struct {
@ -110,10 +108,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
sessionLoaders := []middlewareapi.TokenToSessionLoader{ sessionLoaders := []middlewareapi.TokenToSessionFunc{
{ middlewareapi.CreateTokenToSessionFunc(verifier),
Verifier: verifier,
},
} }
// Create the handler with a next handler that will capture the session // Create the handler with a next handler that will capture the session
@ -175,24 +171,19 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
const nonVerifiedToken = validToken const nonVerifiedToken = validToken
BeforeEach(func() { BeforeEach(func() {
verifier := func(ctx context.Context, token string) (interface{}, error) { verifier := oidc.NewVerifier(
return oidc.NewVerifier(
"https://issuer.example.com", "https://issuer.example.com",
noOpKeySet{}, noOpKeySet{},
&oidc.Config{ &oidc.Config{
ClientID: "https://test.myapp.com", ClientID: "https://test.myapp.com",
SkipExpiryCheck: true, SkipExpiryCheck: true,
}, },
).Verify(ctx, token) ).Verify
}
j = &jwtSessionLoader{ j = &jwtSessionLoader{
jwtRegex: regexp.MustCompile(jwtRegexFormat), jwtRegex: regexp.MustCompile(jwtRegexFormat),
sessionLoaders: []middlewareapi.TokenToSessionLoader{ sessionLoaders: []middlewareapi.TokenToSessionFunc{
{ middlewareapi.CreateTokenToSessionFunc(verifier),
Verifier: verifier,
TokenToSession: createSessionFromToken,
},
}, },
} }
}) })
@ -402,7 +393,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
) )
}) })
Context("createSessionFromToken", func() { Context("CreateTokenToSessionFunc", func() {
ctx := context.Background() ctx := context.Background()
expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) expiresFuture := time.Now().Add(time.Duration(5) * time.Minute)
verified := true verified := true
@ -414,7 +405,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
jwt.StandardClaims jwt.StandardClaims
} }
type createSessionStateTableInput struct { type tokenToSessionTableInput struct {
idToken idTokenClaims idToken idTokenClaims
expectedErr error expectedErr error
expectedUser string expectedUser string
@ -423,8 +414,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
} }
DescribeTable("when creating a session from an IDToken", DescribeTable("when creating a session from an IDToken",
func(in createSessionStateTableInput) { func(in tokenToSessionTableInput) {
verifier := func(ctx context.Context, token string) (interface{}, error) { verifier := func(ctx context.Context, token string) (*oidc.IDToken, error) {
oidcVerifier := oidc.NewVerifier( oidcVerifier := oidc.NewVerifier(
"https://issuer.example.com", "https://issuer.example.com",
noOpKeySet{}, noOpKeySet{},
@ -443,7 +434,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
session, err := createSessionFromToken(ctx, rawIDToken, verifier) session, err := middlewareapi.CreateTokenToSessionFunc(verifier)(ctx, rawIDToken)
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
Expect(session).To(BeNil()) Expect(session).To(BeNil())
@ -459,7 +450,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
Expect(session.RefreshToken).To(BeEmpty()) Expect(session.RefreshToken).To(BeEmpty())
Expect(session.PreferredUsername).To(BeEmpty()) Expect(session.PreferredUsername).To(BeEmpty())
}, },
Entry("with no email", createSessionStateTableInput{ Entry("with no email", tokenToSessionTableInput{
idToken: idTokenClaims{ idToken: idTokenClaims{
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Audience: "asdf1234", Audience: "asdf1234",
@ -476,7 +467,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
expectedEmail: "123456789", expectedEmail: "123456789",
expectedExpires: &expiresFuture, expectedExpires: &expiresFuture,
}), }),
Entry("with a verified email", createSessionStateTableInput{ Entry("with a verified email", tokenToSessionTableInput{
idToken: idTokenClaims{ idToken: idTokenClaims{
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Audience: "asdf1234", Audience: "asdf1234",
@ -495,7 +486,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
expectedEmail: "foo@example.com", expectedEmail: "foo@example.com",
expectedExpires: &expiresFuture, expectedExpires: &expiresFuture,
}), }),
Entry("with a non-verified email", createSessionStateTableInput{ Entry("with a non-verified email", tokenToSessionTableInput{
idToken: idTokenClaims{ idToken: idTokenClaims{
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Audience: "asdf1234", Audience: "asdf1234",

View File

@ -233,6 +233,9 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
// Make the OIDC Verifier accessible to all providers that can support it
p.Verifier = o.GetOIDCVerifier()
p.SetAllowedGroups(o.AllowedGroups) p.SetAllowedGroups(o.AllowedGroups)
provider := providers.New(o.ProviderType, p) provider := providers.New(o.ProviderType, p)
@ -273,18 +276,14 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.UserIDClaim = o.UserIDClaim p.UserIDClaim = o.UserIDClaim
p.GroupsClaim = o.OIDCGroupsClaim p.GroupsClaim = o.OIDCGroupsClaim
if o.GetOIDCVerifier() == nil { if p.Verifier == nil {
msgs = append(msgs, "oidc provider requires an oidc issuer URL") msgs = append(msgs, "oidc provider requires an oidc issuer URL")
} else {
p.Verifier = o.GetOIDCVerifier()
} }
case *providers.GitLabProvider: case *providers.GitLabProvider:
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.Groups = o.GitLabGroup p.Groups = o.GitLabGroup
if o.GetOIDCVerifier() != nil { if p.Verifier == nil {
p.Verifier = o.GetOIDCVerifier()
} else {
// Initialize with default verifier for gitlab.com // Initialize with default verifier for gitlab.com
ctx := context.Background() ctx := context.Background()

View File

@ -11,7 +11,6 @@ import (
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
@ -23,7 +22,6 @@ const emailClaim = "email"
type OIDCProvider struct { type OIDCProvider struct {
*ProviderData *ProviderData
Verifier *oidc.IDTokenVerifier
AllowUnverifiedEmail bool AllowUnverifiedEmail bool
UserIDClaim string UserIDClaim string
GroupsClaim string GroupsClaim string
@ -176,17 +174,12 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
return newSession, nil return newSession, nil
} }
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) { func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
verifiedToken, err := verify(ctx, token) idToken, err := p.Verifier.Verify(ctx, token)
if err != nil { if err != nil {
return nil, err return nil, err
} }
idToken, ok := verifiedToken.(*oidc.IDToken)
if !ok {
return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token)
}
newSession, err := p.createSessionStateInternal(ctx, idToken, nil) newSession, err := p.createSessionStateInternal(ctx, idToken, nil)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -144,15 +144,16 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
Scheme: serverURL.Scheme, Scheme: serverURL.Scheme,
Host: serverURL.Host, Host: serverURL.Host,
Path: "/api"}, Path: "/api"},
Scope: "openid profile offline_access"} Scope: "openid profile offline_access",
p := &OIDCProvider{
ProviderData: providerData,
Verifier: oidc.NewVerifier( Verifier: oidc.NewVerifier(
"https://issuer.example.com", "https://issuer.example.com",
fakeKeySetStub{}, fakeKeySetStub{},
&oidc.Config{ClientID: clientID}, &oidc.Config{ClientID: clientID},
), ),
}
p := &OIDCProvider{
ProviderData: providerData,
UserIDClaim: "email", UserIDClaim: "email",
} }
@ -347,18 +348,7 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) {
rawIDToken, err := newSignedTestIDToken(tc.IDToken) rawIDToken, err := newSignedTestIDToken(tc.IDToken)
assert.NoError(t, err) assert.NoError(t, err)
verifyFunc := func(ctx context.Context, token string) (interface{}, error) { ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken)
keyset := fakeKeySetStub{}
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
idToken, err := verifier.Verify(ctx, token)
assert.NoError(t, err)
return idToken, nil
}
ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken, verifyFunc)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.ExpectedUser, ss.User) assert.Equal(t, tc.ExpectedUser, ss.User)

View File

@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"net/url" "net/url"
"github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
) )
@ -26,6 +27,7 @@ type ProviderData struct {
ClientSecretFile string ClientSecretFile string
Scope string Scope string
Prompt string Prompt string
Verifier *oidc.IDTokenVerifier
// Universal Group authorization data structure // Universal Group authorization data structure
// any provider can set to consume // any provider can set to consume

View File

@ -126,6 +126,9 @@ func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.Ses
// CreateSessionStateFromBearerToken should be implemented to allow providers // CreateSessionStateFromBearerToken should be implemented to allow providers
// to convert ID tokens into sessions // to convert ID tokens into sessions
func (p *ProviderData) CreateSessionFromToken(_ context.Context, _ string, _ middleware.VerifyFunc) (*sessions.SessionState, error) { func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
if p.Verifier != nil {
return middleware.CreateTokenToSessionFunc(p.Verifier.Verify)(ctx, token)
}
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }

View File

@ -3,7 +3,6 @@ package providers
import ( import (
"context" "context"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
) )
@ -18,7 +17,7 @@ type Provider interface {
ValidateSession(ctx context.Context, s *sessions.SessionState) bool ValidateSession(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
} }
// New provides a new Provider based on the configured provider string // New provides a new Provider based on the configured provider string