1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-08 04:03:58 +02:00

Use upn claim as a fallback in Enrich & Refresh

Only when `email` claim is missing, fallback to `upn` claim which may have it.
This commit is contained in:
Nick Meves 2021-06-22 18:50:47 -07:00
parent a53198725e
commit 4980f6af7d
2 changed files with 138 additions and 10 deletions

View File

@ -1,14 +1,22 @@
package providers
import (
"context"
"fmt"
"net/url"
"strings"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
)
// ADFSProvider represents an ADFS based Identity Provider
type ADFSProvider struct {
*OIDCProvider
skipScope bool
// Expose for unit testing
oidcEnrichFunc func(context.Context, *sessions.SessionState) error
oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error)
}
var _ Provider = (*ADFSProvider)(nil)
@ -17,7 +25,7 @@ const (
adfsProviderName = "ADFS"
adfsDefaultScope = "openid email profile"
adfsSkipScope = false
adfsEmailClaim = "upn"
adfsUPNClaim = "upn"
)
// NewADFSProvider initiates a new ADFSProvider
@ -26,7 +34,6 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider {
name: adfsProviderName,
scope: adfsDefaultScope,
})
p.EmailClaim = adfsEmailClaim
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
resource := p.ProtectedResource.String()
@ -39,12 +46,16 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider {
}
}
oidcProvider := &OIDCProvider{
ProviderData: p,
SkipNonce: true,
}
return &ADFSProvider{
OIDCProvider: &OIDCProvider{
ProviderData: p,
SkipNonce: true,
},
skipScope: adfsSkipScope,
OIDCProvider: oidcProvider,
skipScope: adfsSkipScope,
oidcEnrichFunc: oidcProvider.EnrichSession,
oidcRefreshFunc: oidcProvider.RefreshSession,
}
}
@ -68,3 +79,44 @@ func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string {
}
return loginURL.String()
}
// EnrichSession calls the OIDC ProfileURL to backfill any fields missing
// from the claims. If Email is missing, falls back to ADFS `upn` claim.
func (p *ADFSProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
err := p.oidcEnrichFunc(ctx, s)
if err != nil {
return err
}
if s.Email == "" {
return p.fallbackUPN(ctx, s)
}
return nil
}
// RefreshSession refreshes via the OIDC implementation. If email is missing,
// falls back to ADFS `upn` claim.
func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
refreshed, err := p.oidcRefreshFunc(ctx, s)
if err != nil || s.Email != "" {
return refreshed, err
}
err = p.fallbackUPN(ctx, s)
return refreshed, err
}
func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error {
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil {
return err
}
claims, err := p.getClaims(idToken)
if err != nil {
return fmt.Errorf("couldn't extract claims from id_token (%v)", err)
}
upn := claims.raw[adfsUPNClaim]
if upn != nil {
s.Email = fmt.Sprint(upn)
}
return nil
}

View File

@ -2,6 +2,8 @@ package providers
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"net/http"
"net/http/httptest"
@ -9,6 +11,7 @@ import (
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/dgrijalva/jwt-go"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
@ -25,8 +28,18 @@ func (fakeADFSJwks) VerifySignature(_ context.Context, jwt string) (payload []by
return decodeString, nil
}
func testADFSProvider(hostname string) *ADFSProvider {
type adfsClaims struct {
UPN string `json:"upn,omitempty"`
idTokenClaims
}
func newSignedTestADFSToken(tokenClaims adfsClaims) (string, error) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
return standardClaims.SignedString(key)
}
func testADFSProvider(hostname string) *ADFSProvider {
o := oidc.NewVerifier(
"https://issuer.example.com",
fakeADFSJwks{},
@ -41,6 +54,7 @@ func testADFSProvider(hostname string) *ADFSProvider {
ValidateURL: &url.URL{},
Scope: "",
Verifier: o,
EmailClaim: OIDCEmailClaim,
})
if hostname != "" {
@ -54,7 +68,6 @@ func testADFSProvider(hostname string) *ADFSProvider {
}
func testADFSBackend() *httptest.Server {
authResponse := `
{
"access_token": "my_access_token",
@ -129,7 +142,6 @@ var _ = Describe("ADFS Provider Tests", func() {
Context("with valid token", func() {
It("should not throw an error", func() {
p.EmailClaim = "email"
rawIDToken, _ := newSignedTestIDToken(defaultIDToken)
idToken, err := p.Verifier.Verify(context.Background(), rawIDToken)
Expect(err).To(BeNil())
@ -202,4 +214,68 @@ var _ = Describe("ADFS Provider Tests", func() {
}),
)
})
Context("UPN Fallback", func() {
var idToken string
var session *sessions.SessionState
BeforeEach(func() {
var err error
idToken, err = newSignedTestADFSToken(adfsClaims{
UPN: "upn@company.com",
idTokenClaims: minimalIDToken,
})
Expect(err).ToNot(HaveOccurred())
session = &sessions.SessionState{
IDToken: idToken,
}
})
Describe("EnrichSession", func() {
It("uses email claim if present", func() {
p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error {
s.Email = "person@company.com"
return nil
}
err := p.EnrichSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("person@company.com"))
})
It("falls back to UPN claim if Email is missing", func() {
p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error {
return nil
}
err := p.EnrichSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("upn@company.com"))
})
})
Describe("RefreshSession", func() {
It("uses email claim if present", func() {
p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) {
s.Email = "person@company.com"
return true, nil
}
_, err := p.RefreshSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("person@company.com"))
})
It("falls back to UPN claim if Email is missing", func() {
p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) {
return true, nil
}
_, err := p.RefreshSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("upn@company.com"))
})
})
})
})