2021-06-13 10:00:12 +02:00
|
|
|
package providers
|
|
|
|
|
|
|
|
import (
|
2021-06-22 18:50:47 -07:00
|
|
|
"context"
|
|
|
|
"fmt"
|
2021-06-13 10:00:12 +02:00
|
|
|
"net/url"
|
|
|
|
"strings"
|
2021-06-22 18:50:47 -07:00
|
|
|
|
2022-02-15 11:18:32 +00:00
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
2021-06-22 18:50:47 -07:00
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
2021-06-13 10:00:12 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
// ADFSProvider represents an ADFS based Identity Provider
|
|
|
|
type ADFSProvider struct {
|
|
|
|
*OIDCProvider
|
2021-06-22 18:50:47 -07:00
|
|
|
|
2021-06-19 16:06:58 -07:00
|
|
|
skipScope bool
|
2021-06-22 18:50:47 -07:00
|
|
|
// Expose for unit testing
|
|
|
|
oidcEnrichFunc func(context.Context, *sessions.SessionState) error
|
|
|
|
oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error)
|
2021-06-13 10:00:12 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
var _ Provider = (*ADFSProvider)(nil)
|
|
|
|
|
|
|
|
// adfsProviderName is the default display name for the ADFS provider.
const adfsProviderName = "ADFS"

// adfsDefaultScope is the scope requested when none is configured.
const adfsDefaultScope = "openid email profile"

// adfsUPNClaim is the token claim used as an email fallback when the OIDC
// flow does not yield an email address.
const adfsUPNClaim = "upn"
|
|
|
|
|
|
|
|
// NewADFSProvider initiates a new ADFSProvider
|
2022-02-15 11:18:32 +00:00
|
|
|
func NewADFSProvider(p *ProviderData, opts options.ADFSOptions) *ADFSProvider {
|
2021-06-13 10:00:12 +02:00
|
|
|
p.setProviderDefaults(providerDefaults{
|
2021-06-19 16:06:58 -07:00
|
|
|
name: adfsProviderName,
|
|
|
|
scope: adfsDefaultScope,
|
2021-06-13 10:00:12 +02:00
|
|
|
})
|
|
|
|
|
|
|
|
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
|
|
|
|
resource := p.ProtectedResource.String()
|
|
|
|
if !strings.HasSuffix(resource, "/") {
|
|
|
|
resource += "/"
|
|
|
|
}
|
|
|
|
|
|
|
|
if p.Scope != "" && !strings.HasPrefix(p.Scope, resource) {
|
|
|
|
p.Scope = resource + p.Scope
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-06-22 18:50:47 -07:00
|
|
|
oidcProvider := &OIDCProvider{
|
|
|
|
ProviderData: p,
|
2021-06-22 18:57:21 -07:00
|
|
|
SkipNonce: false,
|
2021-06-22 18:50:47 -07:00
|
|
|
}
|
|
|
|
|
2021-06-13 10:00:12 +02:00
|
|
|
return &ADFSProvider{
|
2021-06-22 18:50:47 -07:00
|
|
|
OIDCProvider: oidcProvider,
|
2022-02-15 11:18:32 +00:00
|
|
|
skipScope: opts.SkipScope,
|
2021-06-22 18:50:47 -07:00
|
|
|
oidcEnrichFunc: oidcProvider.EnrichSession,
|
|
|
|
oidcRefreshFunc: oidcProvider.RefreshSession,
|
2021-06-13 10:00:12 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// GetLoginURL Override to double encode the state parameter. If not query params are lost
|
|
|
|
// More info here: https://docs.microsoft.com/en-us/powerapps/maker/portals/configure/configure-saml2-settings
|
|
|
|
func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string {
|
|
|
|
extraParams := url.Values{}
|
|
|
|
if !p.SkipNonce {
|
|
|
|
extraParams.Add("nonce", nonce)
|
|
|
|
}
|
|
|
|
loginURL := makeLoginURL(p.Data(), redirectURI, url.QueryEscape(state), extraParams)
|
2021-06-19 16:06:58 -07:00
|
|
|
if p.skipScope {
|
2021-06-13 10:00:12 +02:00
|
|
|
q := loginURL.Query()
|
|
|
|
q.Del("scope")
|
|
|
|
loginURL.RawQuery = q.Encode()
|
|
|
|
}
|
|
|
|
return loginURL.String()
|
|
|
|
}
|
2021-06-22 18:50:47 -07:00
|
|
|
|
|
|
|
// EnrichSession calls the OIDC ProfileURL to backfill any fields missing
|
|
|
|
// from the claims. If Email is missing, falls back to ADFS `upn` claim.
|
|
|
|
func (p *ADFSProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
|
|
|
err := p.oidcEnrichFunc(ctx, s)
|
2021-07-03 13:40:34 -07:00
|
|
|
if err != nil || s.Email == "" {
|
|
|
|
// OIDC only errors if email is missing
|
2021-06-22 18:50:47 -07:00
|
|
|
return p.fallbackUPN(ctx, s)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// RefreshSession refreshes via the OIDC implementation. If email is missing,
|
|
|
|
// falls back to ADFS `upn` claim.
|
|
|
|
func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
|
|
|
refreshed, err := p.oidcRefreshFunc(ctx, s)
|
|
|
|
if err != nil || s.Email != "" {
|
|
|
|
return refreshed, err
|
|
|
|
}
|
|
|
|
err = p.fallbackUPN(ctx, s)
|
|
|
|
return refreshed, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error {
|
2021-06-26 11:49:08 +01:00
|
|
|
claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken)
|
2021-06-22 18:50:47 -07:00
|
|
|
if err != nil {
|
2021-06-26 11:49:08 +01:00
|
|
|
return fmt.Errorf("could not extract claims: %v", err)
|
2021-06-22 18:50:47 -07:00
|
|
|
}
|
2021-06-26 11:49:08 +01:00
|
|
|
|
|
|
|
upn, found, err := claims.GetClaim(adfsUPNClaim)
|
2021-06-22 18:50:47 -07:00
|
|
|
if err != nil {
|
2021-06-26 11:49:08 +01:00
|
|
|
return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err)
|
2021-06-22 18:50:47 -07:00
|
|
|
}
|
2021-06-26 11:49:08 +01:00
|
|
|
|
|
|
|
if found && fmt.Sprint(upn) != "" {
|
2021-06-22 18:50:47 -07:00
|
|
|
s.Email = fmt.Sprint(upn)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|