mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-05-29 23:17:38 +02:00
Merge pull request #1394 from oauth2-proxy/claim-extractor
Add generic claim extractor to get claims from ID Tokens
This commit is contained in:
commit
9832844c8a
@ -8,6 +8,7 @@
|
||||
|
||||
## Changes since v7.2.1
|
||||
|
||||
- [#1394](https://github.com/oauth2-proxy/oauth2-proxy/pull/1394) Add generic claim extractor to get claims from ID Tokens (@JoelSpeed)
|
||||
- [#1468](https://github.com/oauth2-proxy/oauth2-proxy/pull/1468) Implement session locking with session state lock (@JoelSpeed, @Bibob7)
|
||||
- [#1489](https://github.com/oauth2-proxy/oauth2-proxy/pull/1489) Fix Docker Buildx push to include build version (@JoelSpeed)
|
||||
- [#1477](https://github.com/oauth2-proxy/oauth2-proxy/pull/1477) Remove provider documentation for `Microsoft Azure AD` (@omBratteng)
|
||||
|
1
go.mod
1
go.mod
@ -23,6 +23,7 @@ require (
|
||||
github.com/onsi/gomega v1.10.2
|
||||
github.com/pierrec/lz4 v2.5.2+incompatible
|
||||
github.com/prometheus/client_golang v1.9.0
|
||||
github.com/spf13/cast v1.3.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.6.3
|
||||
github.com/stretchr/testify v1.6.1
|
||||
|
210
pkg/providers/util/claim_extractor.go
Normal file
210
pkg/providers/util/claim_extractor.go
Normal file
@ -0,0 +1,210 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
|
||||
// present, from the profile URL.
|
||||
type ClaimExtractor interface {
|
||||
// GetClaim fetches a named claim and returns the value.
|
||||
GetClaim(claim string) (interface{}, bool, error)
|
||||
|
||||
// GetClaimInto fetches a named claim and puts the value into the destination.
|
||||
GetClaimInto(claim string, dst interface{}) (bool, error)
|
||||
}
|
||||
|
||||
// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
|
||||
// If needed, it will use the profile URL to look up a claim if it isn't present
|
||||
// within the ID Token.
|
||||
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
|
||||
payload, err := parseJWT(idToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ID Token: %v", err)
|
||||
}
|
||||
|
||||
tokenClaims, err := simplejson.NewJson(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
|
||||
}
|
||||
|
||||
return &claimExtractor{
|
||||
ctx: ctx,
|
||||
profileURL: profileURL,
|
||||
requestHeaders: profileRequestHeaders,
|
||||
tokenClaims: tokenClaims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// claimExtractor implements the ClaimExtractor interface
|
||||
type claimExtractor struct {
|
||||
profileURL *url.URL
|
||||
ctx context.Context
|
||||
requestHeaders map[string][]string
|
||||
tokenClaims *simplejson.Json
|
||||
profileClaims *simplejson.Json
|
||||
}
|
||||
|
||||
// GetClaim will return the value claim if it exists.
|
||||
// It will only return an error if the profile URL needs to be fetched due to
|
||||
// the claim not being present in the ID Token.
|
||||
func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
|
||||
if claim == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if value := getClaimFrom(claim, c.tokenClaims); value != nil {
|
||||
return value, true, nil
|
||||
}
|
||||
|
||||
if c.profileClaims == nil {
|
||||
profileClaims, err := c.loadProfileClaims()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
|
||||
}
|
||||
|
||||
c.profileClaims = profileClaims
|
||||
}
|
||||
|
||||
if value := getClaimFrom(claim, c.profileClaims); value != nil {
|
||||
return value, true, nil
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// loadProfileClaims will fetch the profileURL using the provided headers as
|
||||
// authentication.
|
||||
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
|
||||
if c.profileURL == nil || c.requestHeaders == nil {
|
||||
// When no profileURL is set, we return a non-empty map so that
|
||||
// we don't attempt to populate the profile claims again.
|
||||
// If there are no headers, the request would be unauthorized so we also skip
|
||||
// in this case too.
|
||||
return simplejson.New(), nil
|
||||
}
|
||||
|
||||
claims, err := requests.New(c.profileURL.String()).
|
||||
WithContext(c.ctx).
|
||||
WithHeaders(c.requestHeaders).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request to profile URL: %v", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// GetClaimInto loads a claim and places it into the destination interface.
|
||||
// This will attempt to coerce the claim into the specified type.
|
||||
// If it cannot be coerced, an error may be returned.
|
||||
func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
|
||||
value, exists, err := c.GetClaim(claim)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("could not get claim %q: %v", claim, err)
|
||||
}
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
if err := coerceClaim(value, dst); err != nil {
|
||||
return false, fmt.Errorf("could no coerce claim: %v", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
|
||||
// We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
|
||||
func parseJWT(p string) ([]byte, error) {
|
||||
parts := strings.Split(p, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// getClaimFrom gets a claim from a Json object.
|
||||
// It can accept either a single claim name or a json path.
|
||||
// Paths with indexes are not supported.
|
||||
func getClaimFrom(claim string, src *simplejson.Json) interface{} {
|
||||
claimParts := strings.Split(claim, ".")
|
||||
return src.GetPath(claimParts...).Interface()
|
||||
}
|
||||
|
||||
// coerceClaim tries to convert the value into the destination interface type.
|
||||
// If it can convert the value, it will then store the value in the destination
|
||||
// interface.
|
||||
func coerceClaim(value, dst interface{}) error {
|
||||
switch d := dst.(type) {
|
||||
case *string:
|
||||
str, err := toString(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string: %v", err)
|
||||
}
|
||||
*d = str
|
||||
case *[]string:
|
||||
strSlice, err := toStringSlice(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string slice: %v", err)
|
||||
}
|
||||
*d = strSlice
|
||||
case *bool:
|
||||
*d = cast.ToBool(value)
|
||||
default:
|
||||
return fmt.Errorf("unknown type for destination: %T", dst)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// toStringSlice converts an interface (either a slice or single value) into
|
||||
// a slice of strings.
|
||||
func toStringSlice(value interface{}) ([]string, error) {
|
||||
var sliceValues []interface{}
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
sliceValues = v
|
||||
case interface{}:
|
||||
sliceValues = []interface{}{v}
|
||||
default:
|
||||
sliceValues = cast.ToSlice(value)
|
||||
}
|
||||
|
||||
out := []string{}
|
||||
for _, v := range sliceValues {
|
||||
str, err := toString(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
|
||||
}
|
||||
out = append(out, str)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// toString coerces a value into a string.
|
||||
// If it is non-string, marshal it into JSON.
|
||||
func toString(value interface{}) (string, error) {
|
||||
if str, err := cast.ToStringE(value); err == nil {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
jsonStr, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(jsonStr), nil
|
||||
}
|
530
pkg/providers/util/claim_extractor_test.go
Normal file
530
pkg/providers/util/claim_extractor_test.go
Normal file
@ -0,0 +1,530 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
emptyJSON = "{}"
|
||||
profilePath = "/userinfo"
|
||||
authorizedAccessToken = "valid_access_token"
|
||||
basicIDTokenPayload = `{
|
||||
"user": "idTokenUser",
|
||||
"email": "idTokenEmail",
|
||||
"groups": [
|
||||
"idTokenGroup1",
|
||||
"idTokenGroup2"
|
||||
]
|
||||
}`
|
||||
basicProfileURLPayload = `{
|
||||
"user": "profileUser",
|
||||
"email": "profileEmail",
|
||||
"groups": [
|
||||
"profileGroup1",
|
||||
"profileGroup2"
|
||||
]
|
||||
}`
|
||||
nestedClaimPayload = `{
|
||||
"auth": {
|
||||
"user": {
|
||||
"username": "nestedUser"
|
||||
}
|
||||
}
|
||||
}`
|
||||
complexGroupsPayload = `{
|
||||
"groups": [
|
||||
{
|
||||
"groupID": "group1",
|
||||
"roles": ["admin"]
|
||||
},
|
||||
{
|
||||
"groupID": "group2",
|
||||
"roles": ["user", "employee"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
)
|
||||
|
||||
var _ = Describe("Claim Extractor Suite", func() {
|
||||
Context("Claim Extractor", func() {
|
||||
type newClaimExtractorTableInput struct {
|
||||
idToken string
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("NewClaimExtractor",
|
||||
func(in newClaimExtractorTableInput) {
|
||||
_, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
},
|
||||
Entry("with a valid JWT", newClaimExtractorTableInput{
|
||||
idToken: createJWTFromPayload(basicIDTokenPayload),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{
|
||||
idToken: createJWTFromPayload("this is not JSON"),
|
||||
expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"),
|
||||
}),
|
||||
Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{
|
||||
idToken: "eyJeyJ",
|
||||
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"),
|
||||
}),
|
||||
Entry("with an non-base64 IDToken", newClaimExtractorTableInput{
|
||||
idToken: "{metadata}.{payload}.{signature}",
|
||||
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"),
|
||||
}),
|
||||
)
|
||||
|
||||
type getClaimTableInput struct {
|
||||
testClaimExtractorOpts
|
||||
claim string
|
||||
expectedValue interface{}
|
||||
expectExists bool
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("GetClaim",
|
||||
func(in getClaimTableInput) {
|
||||
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if serverClose != nil {
|
||||
defer serverClose()
|
||||
}
|
||||
|
||||
value, exists, err := claimExtractor.GetClaim(in.claim)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if in.expectedValue != nil {
|
||||
Expect(value).To(Equal(in.expectedValue))
|
||||
} else {
|
||||
Expect(value).To(BeNil())
|
||||
}
|
||||
|
||||
Expect(exists).To(Equal(in.expectExists))
|
||||
},
|
||||
Entry("retrieves a string claim from ID Token when present", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
expectExists: true,
|
||||
expectedValue: "idTokenUser",
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
expectExists: true,
|
||||
expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"},
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the empty string", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
},
|
||||
claim: "",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
},
|
||||
claim: "not_found",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "not_found",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: nil,
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "not_found",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the profile URL is unauthorized", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: emptyJSON,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: make(http.Header),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
|
||||
}),
|
||||
Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: emptyJSON,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
expectExists: true,
|
||||
expectedValue: "profileUser",
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string claim from a nested path", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: nestedClaimPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "auth.user.username",
|
||||
expectExists: true,
|
||||
expectedValue: "nestedUser",
|
||||
expectedError: nil,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
It("GetClaim should only call the profile URL once", func() {
|
||||
var counter int32
|
||||
countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
rw.Write([]byte(basicProfileURLPayload))
|
||||
}
|
||||
|
||||
claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{
|
||||
idTokenPayload: "{}",
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: countRequestsHandler,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if serverClose != nil {
|
||||
defer serverClose()
|
||||
}
|
||||
|
||||
value, exists, err := claimExtractor.GetClaim("user")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(value).To(Equal("profileUser"))
|
||||
Expect(counter).To(BeEquivalentTo(1))
|
||||
|
||||
// Check a different claim, but expect the count not to increase
|
||||
value, exists, err = claimExtractor.GetClaim("email")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(value).To(Equal("profileEmail"))
|
||||
Expect(counter).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
type getClaimIntoTableInput struct {
|
||||
testClaimExtractorOpts
|
||||
into interface{}
|
||||
claim string
|
||||
expectedValue interface{}
|
||||
expectExists bool
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("GetClaimInto",
|
||||
func(in getClaimIntoTableInput) {
|
||||
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if serverClose != nil {
|
||||
defer serverClose()
|
||||
}
|
||||
|
||||
exists, err := claimExtractor.GetClaimInto(in.claim, in.into)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if in.expectedValue != nil {
|
||||
Expect(in.into).To(Equal(in.expectedValue))
|
||||
} else {
|
||||
Expect(in.into).To(BeEmpty())
|
||||
}
|
||||
|
||||
Expect(exists).To(Equal(in.expectExists))
|
||||
},
|
||||
Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: stringPointer(""),
|
||||
expectExists: true,
|
||||
expectedValue: stringPointer("idTokenUser"),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: stringSlicePointer([]string{}),
|
||||
expectExists: true,
|
||||
expectedValue: stringSlicePointer([]string{"idTokenUser"}),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
into: stringSlicePointer([]string{}),
|
||||
expectExists: true,
|
||||
expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
into: stringPointer(""),
|
||||
expectExists: true,
|
||||
expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: "",
|
||||
expectExists: false,
|
||||
expectedValue: "",
|
||||
expectedError: errors.New("could no coerce claim: unknown type for destination: string"),
|
||||
}),
|
||||
Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: complexGroupsPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
into: stringSlicePointer([]string{}),
|
||||
expectExists: true,
|
||||
expectedValue: stringSlicePointer([]string{
|
||||
"{\"groupID\":\"group1\",\"roles\":[\"admin\"]}",
|
||||
"{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}",
|
||||
}),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "not_found",
|
||||
into: stringPointer(""),
|
||||
expectExists: false,
|
||||
expectedValue: stringPointer(""),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: emptyJSON,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: make(http.Header),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: stringPointer(""),
|
||||
expectExists: false,
|
||||
expectedValue: stringPointer(""),
|
||||
expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
|
||||
}),
|
||||
)
|
||||
|
||||
type coerceClaimTableInput struct {
|
||||
value interface{}
|
||||
dst interface{}
|
||||
expectedDst interface{}
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("coerceClaim",
|
||||
func(in coerceClaimTableInput) {
|
||||
err := coerceClaim(in.value, in.dst)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(in.dst).To(Equal(in.expectedDst))
|
||||
},
|
||||
Entry("coerces a string to a string", coerceClaimTableInput{
|
||||
value: "some_string",
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("some_string"),
|
||||
}),
|
||||
Entry("coerces a slice to a string slice", coerceClaimTableInput{
|
||||
value: []interface{}{"a", "b"},
|
||||
dst: stringSlicePointer([]string{}),
|
||||
expectedDst: stringSlicePointer([]string{"a", "b"}),
|
||||
}),
|
||||
Entry("coerces a bool to a bool", coerceClaimTableInput{
|
||||
value: true,
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
}),
|
||||
Entry("coerces a string to a bool", coerceClaimTableInput{
|
||||
value: "true",
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
}),
|
||||
Entry("coerces a map to a string", coerceClaimTableInput{
|
||||
value: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "baz"},
|
||||
},
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
// ******************************************
|
||||
// Helpers for setting up the claim extractor
|
||||
// ******************************************
|
||||
|
||||
type testClaimExtractorOpts struct {
|
||||
idTokenPayload string
|
||||
setProfileURL bool
|
||||
profileRequestHeaders http.Header
|
||||
profileRequestHandler http.HandlerFunc
|
||||
}
|
||||
|
||||
func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) {
|
||||
var profileURL *url.URL
|
||||
var closeServer func()
|
||||
if in.setProfileURL {
|
||||
server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler))
|
||||
closeServer = server.Close
|
||||
|
||||
var err error
|
||||
profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
rawIDToken := createJWTFromPayload(in.idTokenPayload)
|
||||
|
||||
claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders)
|
||||
return claimExtractor, closeServer, err
|
||||
}
|
||||
|
||||
func createJWTFromPayload(payload string) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON))
|
||||
payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header)
|
||||
}
|
||||
|
||||
func newAuthorizedHeader() http.Header {
|
||||
headers := make(http.Header)
|
||||
headers.Add("Authorization", "Bearer "+authorizedAccessToken)
|
||||
return headers
|
||||
}
|
||||
|
||||
func hasAuthorizedHeader(headers http.Header) bool {
|
||||
return headers.Get("Authorization") == "Bearer "+authorizedAccessToken
|
||||
}
|
||||
|
||||
// ***********************
|
||||
// Typed Pointer Functions
|
||||
// ***********************
|
||||
|
||||
func stringPointer(in string) *string {
|
||||
return &in
|
||||
}
|
||||
|
||||
func stringSlicePointer(in []string) *[]string {
|
||||
return &in
|
||||
}
|
||||
|
||||
func boolPointer(in bool) *bool {
|
||||
return &in
|
||||
}
|
||||
|
||||
// ******************************
|
||||
// Different profile URL handlers
|
||||
// ******************************
|
||||
|
||||
func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
Expect(true).To(BeFalse(), "Unexpected request to profile URL")
|
||||
}
|
||||
|
||||
func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
if !hasAuthorizedHeader(req.Header) {
|
||||
rw.WriteHeader(403)
|
||||
rw.Write([]byte("Unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
rw.Write([]byte(basicProfileURLPayload))
|
||||
}
|
17
pkg/providers/util/util_suite_test.go
Normal file
17
pkg/providers/util/util_suite_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestProviderUtilSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
logger.SetErrOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Provider Utils")
|
||||
}
|
@ -103,16 +103,17 @@ func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt
|
||||
}
|
||||
|
||||
func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error {
|
||||
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("could not extract claims: %v", err)
|
||||
}
|
||||
claims, err := p.getClaims(idToken)
|
||||
|
||||
upn, found, err := claims.GetClaim(adfsUPNClaim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||
return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err)
|
||||
}
|
||||
upn := claims.raw[adfsUPNClaim]
|
||||
if upn != nil {
|
||||
|
||||
if found && fmt.Sprint(upn) != "" {
|
||||
s.Email = fmt.Sprint(upn)
|
||||
}
|
||||
return nil
|
||||
|
@ -79,7 +79,7 @@ func testADFSBackend() *httptest.Server {
|
||||
{
|
||||
"access_token": "my_access_token",
|
||||
"id_token": "my_id_token",
|
||||
"refresh_token": "my_refresh_token"
|
||||
"refresh_token": "my_refresh_token"
|
||||
}
|
||||
`
|
||||
userInfo := `
|
||||
@ -150,9 +150,7 @@ var _ = Describe("ADFS Provider Tests", func() {
|
||||
Context("with valid token", func() {
|
||||
It("should not throw an error", func() {
|
||||
rawIDToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||
idToken, err := p.Verifier.Verify(context.Background(), rawIDToken)
|
||||
Expect(err).To(BeNil())
|
||||
session, err := p.buildSessionFromClaims(idToken)
|
||||
session, err := p.buildSessionFromClaims(rawIDToken, "")
|
||||
Expect(err).To(BeNil())
|
||||
session.IDToken = rawIDToken
|
||||
err = p.EnrichSession(context.Background(), session)
|
||||
|
@ -15,9 +15,20 @@ func CreateAuthorizedSession() *sessions.SessionState {
|
||||
}
|
||||
|
||||
func IsAuthorizedInHeader(reqHeader http.Header) bool {
|
||||
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", authorizedAccessToken)
|
||||
return IsAuthorizedInHeaderWithToken(reqHeader, authorizedAccessToken)
|
||||
}
|
||||
|
||||
func IsAuthorizedInHeaderWithToken(reqHeader http.Header, token string) bool {
|
||||
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", token)
|
||||
}
|
||||
|
||||
func IsAuthorizedInURL(reqURL *url.URL) bool {
|
||||
return reqURL.Query().Get("access_token") == authorizedAccessToken
|
||||
}
|
||||
|
||||
func isAuthorizedRefreshInURLWithToken(reqURL *url.URL, token string) bool {
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
return reqURL.Query().Get("refresh_token") == token
|
||||
}
|
||||
|
@ -78,6 +78,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
|
||||
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
|
||||
p.ValidateURL = p.ProfileURL
|
||||
}
|
||||
p.getAuthorizationHeaderFunc = makeAzureHeader
|
||||
|
||||
return &AzureProvider{
|
||||
ProviderData: p,
|
||||
@ -150,7 +151,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
||||
session.CreatedAtNow()
|
||||
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken, session.AccessToken)
|
||||
|
||||
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
|
||||
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
|
||||
@ -163,7 +164,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
||||
}
|
||||
|
||||
if session.Email == "" {
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken)
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken, session.AccessToken)
|
||||
if err == nil && email != "" {
|
||||
session.Email = email
|
||||
} else {
|
||||
@ -215,16 +216,16 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err
|
||||
|
||||
// verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token
|
||||
// when oidc verifier is configured
|
||||
func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) {
|
||||
func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, rawIDToken string, accessToken string) (string, error) {
|
||||
email := ""
|
||||
|
||||
if token != "" && p.Verifier != nil {
|
||||
token, err := p.Verifier.Verify(ctx, token)
|
||||
if rawIDToken != "" && p.Verifier != nil {
|
||||
_, err := p.Verifier.Verify(ctx, rawIDToken)
|
||||
// due to issues mentioned above, id_token may not be signed by AAD
|
||||
if err == nil {
|
||||
claims, err := p.getClaims(token)
|
||||
s, err := p.buildSessionFromClaims(rawIDToken, accessToken)
|
||||
if err == nil {
|
||||
email = claims.Email
|
||||
email = s.Email
|
||||
} else {
|
||||
logger.Printf("unable to get claims from token: %v", err)
|
||||
}
|
||||
@ -287,7 +288,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
||||
s.CreatedAtNow()
|
||||
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken, s.AccessToken)
|
||||
|
||||
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
|
||||
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
|
||||
@ -300,7 +301,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
||||
}
|
||||
|
||||
if s.Email == "" {
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken)
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken, s.AccessToken)
|
||||
if err == nil && email != "" {
|
||||
s.Email = email
|
||||
} else {
|
||||
|
@ -13,9 +13,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
|
||||
|
||||
@ -145,11 +144,11 @@ func TestAzureSetTenant(t *testing.T) {
|
||||
assert.Equal(t, "openid", p.Data().Scope)
|
||||
}
|
||||
|
||||
func testAzureBackend(payload string) *httptest.Server {
|
||||
return testAzureBackendWithError(payload, false)
|
||||
func testAzureBackend(payload string, accessToken, refreshToken string) *httptest.Server {
|
||||
return testAzureBackendWithError(payload, accessToken, refreshToken, false)
|
||||
}
|
||||
|
||||
func testAzureBackendWithError(payload string, injectError bool) *httptest.Server {
|
||||
func testAzureBackendWithError(payload string, accessToken, refreshToken string, injectError bool) *httptest.Server {
|
||||
path := "/v1.0/me"
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
@ -163,7 +162,8 @@ func testAzureBackendWithError(payload string, injectError bool) *httptest.Serve
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
w.Write([]byte(payload))
|
||||
} else if !IsAuthorizedInHeader(r.Header) {
|
||||
} else if !IsAuthorizedInHeaderWithToken(r.Header, accessToken) &&
|
||||
!isAuthorizedRefreshInURLWithToken(r.URL, refreshToken) {
|
||||
w.WriteHeader(403)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
@ -224,7 +224,7 @@ func TestAzureProviderEnrichSession(t *testing.T) {
|
||||
host string
|
||||
)
|
||||
if testCase.PayloadFromAzureBackend != "" {
|
||||
b = testAzureBackend(testCase.PayloadFromAzureBackend)
|
||||
b = testAzureBackend(testCase.PayloadFromAzureBackend, authorizedAccessToken, "")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
@ -319,7 +319,7 @@ func TestAzureProviderRedeem(t *testing.T) {
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError)
|
||||
b := testAzureBackendWithError(string(payloadBytes), accessTokenString, testCase.RefreshToken, testCase.InjectRedeemURLError)
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
@ -353,35 +353,44 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
|
||||
|
||||
func TestAzureProviderRefresh(t *testing.T) {
|
||||
email := "foo@example.com"
|
||||
subject := "foo"
|
||||
idToken := idTokenClaims{
|
||||
StandardClaims: jwt.StandardClaims{Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532"},
|
||||
Email: email}
|
||||
Email: email,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532",
|
||||
Subject: subject,
|
||||
},
|
||||
}
|
||||
idTokenString, err := newSignedTestIDToken(idToken)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z")
|
||||
assert.NoError(t, err)
|
||||
|
||||
newAccessToken := "new_some_access_token"
|
||||
payload := azureOAuthPayload{
|
||||
IDToken: idTokenString,
|
||||
RefreshToken: "new_some_refresh_token",
|
||||
AccessToken: "new_some_access_token",
|
||||
AccessToken: newAccessToken,
|
||||
ExpiresOn: timestamp.Unix(),
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
assert.NoError(t, err)
|
||||
b := testAzureBackend(string(payloadBytes))
|
||||
|
||||
refreshToken := "some_refresh_token"
|
||||
b := testAzureBackend(string(payloadBytes), newAccessToken, refreshToken)
|
||||
defer b.Close()
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: refreshToken, IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.True(t, refreshed)
|
||||
assert.NotEqual(t, session, nil)
|
||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||
assert.Equal(t, newAccessToken, session.AccessToken)
|
||||
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
||||
assert.Equal(t, idTokenString, session.IDToken)
|
||||
assert.Equal(t, email, session.Email)
|
||||
|
@ -57,6 +57,8 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider {
|
||||
validateURL: digitalOceanDefaultProfileURL,
|
||||
scope: digitalOceanDefaultScope,
|
||||
})
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
|
||||
return &DigitalOceanProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
|
@ -58,6 +58,7 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
|
||||
validateURL: facebookDefaultProfileURL,
|
||||
scope: facebookDefaultScope,
|
||||
})
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
return &FacebookProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
|
@ -65,6 +65,8 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
|
||||
validateURL: linkedinDefaultValidateURL,
|
||||
scope: linkedinDefaultScope,
|
||||
})
|
||||
p.getAuthorizationHeaderFunc = makeLinkedInHeader
|
||||
|
||||
return &LinkedInProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
|
@ -1,13 +1,5 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
)
|
||||
|
||||
// NextcloudProvider represents an Nextcloud based Identity Provider
|
||||
type NextcloudProvider struct {
|
||||
*ProviderData
|
||||
@ -20,20 +12,11 @@ const nextCloudProviderName = "Nextcloud"
|
||||
// NewNextcloudProvider initiates a new NextcloudProvider
|
||||
func NewNextcloudProvider(p *ProviderData) *NextcloudProvider {
|
||||
p.ProviderName = nextCloudProviderName
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
if p.EmailClaim == OIDCEmailClaim {
|
||||
// This implies the email claim has not been overridden, we should set a default
|
||||
// for this provider
|
||||
p.EmailClaim = "ocs.data.email"
|
||||
}
|
||||
return &NextcloudProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(makeOIDCHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error making request: %v", err)
|
||||
}
|
||||
|
||||
email, err := json.Get("ocs").Get("data").Get("email").String()
|
||||
return email, err
|
||||
}
|
||||
|
@ -1,18 +1,13 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const formatJSON = "format=json"
|
||||
const userPath = "/ocs/v2.php/cloud/user"
|
||||
|
||||
func testNextcloudProvider(hostname string) *NextcloudProvider {
|
||||
p := NewNextcloudProvider(
|
||||
@ -32,23 +27,6 @@ func testNextcloudProvider(hostname string) *NextcloudProvider {
|
||||
return p
|
||||
}
|
||||
|
||||
func testNextcloudBackend(payload string) *httptest.Server {
|
||||
path := userPath
|
||||
query := formatJSON
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != path || r.URL.RawQuery != query {
|
||||
w.WriteHeader(404)
|
||||
} else if !IsAuthorizedInHeader(r.Header) {
|
||||
w.WriteHeader(403)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(payload))
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func TestNextcloudProviderDefaults(t *testing.T) {
|
||||
p := testNextcloudProvider("")
|
||||
assert.NotEqual(t, nil, p)
|
||||
@ -87,53 +65,3 @@ func TestNextcloudProviderOverrides(t *testing.T) {
|
||||
assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON,
|
||||
p.Data().ValidateURL.String())
|
||||
}
|
||||
|
||||
func TestNextcloudProviderGetEmailAddress(t *testing.T) {
|
||||
b := testNextcloudBackend("{\"ocs\": {\"data\": { \"email\": \"michael.bland@gsa.gov\"}}}")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testNextcloudProvider(bURL.Host)
|
||||
p.ValidateURL.Path = userPath
|
||||
p.ValidateURL.RawQuery = formatJSON
|
||||
|
||||
session := CreateAuthorizedSession()
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
}
|
||||
|
||||
// Note that trying to trigger the "failed building request" case is not
|
||||
// practical, since the only way it can fail is if the URL fails to parse.
|
||||
func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
b := testNextcloudBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testNextcloudProvider(bURL.Host)
|
||||
p.ValidateURL.Path = userPath
|
||||
p.ValidateURL.RawQuery = formatJSON
|
||||
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
}
|
||||
|
||||
func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testNextcloudBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testNextcloudProvider(bURL.Host)
|
||||
p.ValidateURL.Path = userPath
|
||||
p.ValidateURL.RawQuery = formatJSON
|
||||
|
||||
session := CreateAuthorizedSession()
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
}
|
||||
|
@ -5,12 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -24,6 +22,8 @@ type OIDCProvider struct {
|
||||
// NewOIDCProvider initiates a new OIDCProvider
|
||||
func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
||||
p.ProviderName = "OpenID Connect"
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
|
||||
return &OIDCProvider{
|
||||
ProviderData: p,
|
||||
SkipNonce: true,
|
||||
@ -68,21 +68,6 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s
|
||||
// EnrichSession is called after Redeem to allow providers to enrich session fields
|
||||
// such as User, Email, Groups with provider specific API calls.
|
||||
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||
if p.ProfileURL.String() == "" {
|
||||
if s.Email == "" {
|
||||
return errors.New("id_token did not contain an email and profileURL is not defined")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to get missing emails or groups from a profileURL
|
||||
if s.Email == "" || s.Groups == nil {
|
||||
err := p.enrichFromProfileURL(ctx, s)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: Profile URL request failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If a mandatory email wasn't set, error at this point.
|
||||
if s.Email == "" {
|
||||
return errors.New("neither the id_token nor the profileURL set an email")
|
||||
@ -90,42 +75,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
|
||||
return nil
|
||||
}
|
||||
|
||||
// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
|
||||
// an OIDC profile URL
|
||||
func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
|
||||
respJSON, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(makeOIDCHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
email, err := respJSON.Get(p.EmailClaim).String()
|
||||
if err == nil && s.Email == "" {
|
||||
s.Email = email
|
||||
}
|
||||
|
||||
if len(s.Groups) > 0 {
|
||||
return nil
|
||||
}
|
||||
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
|
||||
formatted, err := formatGroup(group)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
reflect.TypeOf(group), err)
|
||||
continue
|
||||
}
|
||||
s.Groups = append(s.Groups, formatted)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSession checks that the session's IDToken is still valid
|
||||
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
if err != nil {
|
||||
logger.Errorf("id_token verification failed: %v", err)
|
||||
return false
|
||||
@ -134,7 +86,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
if p.SkipNonce {
|
||||
return true
|
||||
}
|
||||
err = p.checkNonce(s, idToken)
|
||||
err = p.checkNonce(s)
|
||||
if err != nil {
|
||||
logger.Errorf("nonce verification failed: %v", err)
|
||||
return false
|
||||
@ -212,7 +164,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss, err := p.buildSessionFromClaims(idToken)
|
||||
ss, err := p.buildSessionFromClaims(token, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -235,7 +187,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
|
||||
// createSession takes an oauth2.Token and creates a SessionState from it.
|
||||
// It alters behavior if called from Redeem vs Refresh
|
||||
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
|
||||
idToken, err := p.verifyIDToken(ctx, token)
|
||||
_, err := p.verifyIDToken(ctx, token)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrMissingIDToken:
|
||||
@ -248,14 +200,15 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
|
||||
}
|
||||
}
|
||||
|
||||
ss, err := p.buildSessionFromClaims(idToken)
|
||||
rawIDToken := getIDToken(token)
|
||||
ss, err := p.buildSessionFromClaims(rawIDToken, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss.AccessToken = token.AccessToken
|
||||
ss.RefreshToken = token.RefreshToken
|
||||
ss.IDToken = getIDToken(token)
|
||||
ss.IDToken = rawIDToken
|
||||
|
||||
ss.CreatedAtNow()
|
||||
ss.SetExpiresOn(token.Expiry)
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -54,6 +53,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
||||
Scope: "openid profile offline_access",
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
@ -142,333 +142,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
||||
assert.Equal(t, defaultIDToken.Phone, session.Email)
|
||||
}
|
||||
|
||||
func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ExistingSession *sessions.SessionState
|
||||
EmailClaim string
|
||||
GroupsClaim string
|
||||
ProfileJSON map[string]interface{}
|
||||
ExpectedError error
|
||||
ExpectedSession *sessions.SessionState
|
||||
}{
|
||||
"Already Populated": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "found@email.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "found@email.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
|
||||
"Missing Email Only in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "found@email.com",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "found@email.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email with Custom Claim": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "weird",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"weird": "weird@claim.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "weird@claim.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email not in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"new", "thing"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Complex Groups in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []map[string]interface{}{
|
||||
{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Singleton Complex Group in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Empty Groups Claims": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Custom Claim": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "roles",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"roles": []string{"new", "thing", "roles"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"new", "thing", "roles"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups String Profile URL Response": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": "singleton",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"singleton"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups in both Claims and Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
jsonResp, err := json.Marshal(tc.ProfileJSON)
|
||||
assert.NoError(t, err)
|
||||
|
||||
server, provider := newTestOIDCSetup(jsonResp)
|
||||
provider.ProfileURL, err = url.Parse(server.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
provider.EmailClaim = tc.EmailClaim
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
defer server.Close()
|
||||
|
||||
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
|
||||
assert.Equal(t, tc.ExpectedError, err)
|
||||
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||
|
||||
idToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||
@ -565,11 +238,15 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||
ExpectedGroups: []string{"test:c", "test:d"},
|
||||
},
|
||||
"Complex Groups Claim": {
|
||||
IDToken: complexGroupsIDToken,
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "complex@claims.com",
|
||||
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: complexGroupsIDToken,
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "complex@claims.com",
|
||||
ExpectedGroups: []string{
|
||||
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||
"12345",
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
|
@ -5,20 +5,23 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/util"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
OIDCEmailClaim = "email"
|
||||
OIDCGroupsClaim = "groups"
|
||||
// This is not exported as it's not currently user configurable
|
||||
oidcUserClaim = "sub"
|
||||
)
|
||||
|
||||
var OIDCAudienceClaims = []string{"aud"}
|
||||
@ -52,6 +55,8 @@ type ProviderData struct {
|
||||
// Universal Group authorization data structure
|
||||
// any provider can set to consume
|
||||
AllowedGroups map[string]struct{}
|
||||
|
||||
getAuthorizationHeaderFunc func(string) http.Header
|
||||
}
|
||||
|
||||
// Data returns the ProviderData
|
||||
@ -99,6 +104,10 @@ func (p *ProviderData) setProviderDefaults(defaults providerDefaults) {
|
||||
if p.Scope == "" {
|
||||
p.Scope = defaults.scope
|
||||
}
|
||||
|
||||
if p.UserClaim == "" {
|
||||
p.UserClaim = oidcUserClaim
|
||||
}
|
||||
}
|
||||
|
||||
// defaultURL will set return a default value if the given value is not set.
|
||||
@ -120,17 +129,6 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL {
|
||||
// OIDC compliant
|
||||
// ****************************************************************************
|
||||
|
||||
// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
|
||||
type OIDCClaims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"-"`
|
||||
Groups []string `json:"-"`
|
||||
Verified *bool `json:"email_verified"`
|
||||
Nonce string `json:"nonce"`
|
||||
|
||||
raw map[string]interface{}
|
||||
}
|
||||
|
||||
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
||||
rawIDToken := getIDToken(token)
|
||||
if strings.TrimSpace(rawIDToken) == "" {
|
||||
@ -144,110 +142,80 @@ func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (
|
||||
|
||||
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
|
||||
// with non-Token related fields.
|
||||
func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*sessions.SessionState, error) {
|
||||
ss := &sessions.SessionState{}
|
||||
|
||||
if idToken == nil {
|
||||
if rawIDToken == "" {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
claims, err := p.getClaims(idToken)
|
||||
extractor, err := p.getClaimExtractor(rawIDToken, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss.User = claims.Subject
|
||||
ss.Email = claims.Email
|
||||
ss.Groups = claims.Groups
|
||||
|
||||
// Allow specialized providers that embed OIDCProvider to control the User
|
||||
// claim. Not exposed as a configuration flag to generic OIDC provider
|
||||
// users (yet).
|
||||
if p.UserClaim != "" {
|
||||
user, ok := claims.raw[p.UserClaim].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim)
|
||||
// Use a slice of a struct (vs map) here in case the same claim is used twice
|
||||
for _, c := range []struct {
|
||||
claim string
|
||||
dst interface{}
|
||||
}{
|
||||
{p.UserClaim, &ss.User},
|
||||
{p.EmailClaim, &ss.Email},
|
||||
{p.GroupsClaim, &ss.Groups},
|
||||
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
|
||||
{"preferred_username", &ss.PreferredUsername},
|
||||
} {
|
||||
if _, err := extractor.GetClaimInto(c.claim, c.dst); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ss.User = user
|
||||
}
|
||||
|
||||
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
|
||||
if pref, ok := claims.raw["preferred_username"].(string); ok {
|
||||
ss.PreferredUsername = pref
|
||||
}
|
||||
|
||||
// `email_verified` must be present and explicitly set to `false` to be
|
||||
// considered unverified.
|
||||
verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail
|
||||
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||
|
||||
var verified bool
|
||||
exists, err := extractor.GetClaimInto("email_verified", &verified)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if verifyEmail && exists && !verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", ss.Email)
|
||||
}
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// getClaims extracts IDToken claims into an OIDCClaims
|
||||
func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
|
||||
claims := &OIDCClaims{}
|
||||
|
||||
// Extract default claims.
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
|
||||
}
|
||||
// Extract custom claims.
|
||||
if err := idToken.Claims(&claims.raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
||||
func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) {
|
||||
extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not initialise claim extractor: %v", err)
|
||||
}
|
||||
|
||||
email := claims.raw[p.EmailClaim]
|
||||
if email != nil {
|
||||
claims.Email = fmt.Sprint(email)
|
||||
}
|
||||
claims.Groups = p.extractGroups(claims.raw)
|
||||
|
||||
return claims, nil
|
||||
return extractor, nil
|
||||
}
|
||||
|
||||
// checkNonce compares the session's nonce with the IDToken's nonce claim
|
||||
func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error {
|
||||
claims, err := p.getClaims(idToken)
|
||||
func (p *ProviderData) checkNonce(s *sessions.SessionState) error {
|
||||
extractor, err := p.getClaimExtractor(s.IDToken, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("id_token claims extraction failed: %v", err)
|
||||
}
|
||||
if !s.CheckNonce(claims.Nonce) {
|
||||
var nonce string
|
||||
if _, err := extractor.GetClaimInto("nonce", &nonce); err != nil {
|
||||
return fmt.Errorf("could not extract nonce from ID Token: %v", err)
|
||||
}
|
||||
|
||||
if !s.CheckNonce(nonce) {
|
||||
return errors.New("id_token nonce claim does not match the session nonce")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractGroups extracts groups from a claim to a list in a type safe manner.
|
||||
// If the claim isn't present, `nil` is returned. If the groups claim is
|
||||
// present but empty, `[]string{}` is returned.
|
||||
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
|
||||
rawClaim, ok := claims[p.GroupsClaim]
|
||||
if !ok {
|
||||
return nil
|
||||
func (p *ProviderData) getAuthorizationHeader(accessToken string) http.Header {
|
||||
if p.getAuthorizationHeaderFunc != nil && accessToken != "" {
|
||||
return p.getAuthorizationHeaderFunc(accessToken)
|
||||
}
|
||||
|
||||
// Handle traditional list-based groups as well as non-standard singleton
|
||||
// based groups. Both variants support complex objects if needed.
|
||||
var claimGroups []interface{}
|
||||
switch raw := rawClaim.(type) {
|
||||
case []interface{}:
|
||||
claimGroups = raw
|
||||
case interface{}:
|
||||
claimGroups = []interface{}{raw}
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
for _, rawGroup := range claimGroups {
|
||||
formattedGroup, err := formatGroup(rawGroup)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
reflect.TypeOf(rawGroup), err)
|
||||
continue
|
||||
}
|
||||
groups = append(groups, formattedGroup)
|
||||
}
|
||||
return groups
|
||||
return nil
|
||||
}
|
||||
|
@ -60,16 +60,30 @@ var (
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
|
||||
numericGroupsIDToken = idTokenClaims{
|
||||
Name: "Jane Dobbs",
|
||||
Email: "janed@me.com",
|
||||
Phone: "+4798765432",
|
||||
Picture: "http://mugbook.com/janed/me.jpg",
|
||||
Groups: []interface{}{1, 2, 3},
|
||||
Roles: []string{"test:c", "test:d"},
|
||||
Verified: &verified,
|
||||
Nonce: encryption.HashNonce([]byte(oidcNonce)),
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
|
||||
complexGroupsIDToken = idTokenClaims{
|
||||
Name: "Complex Claim",
|
||||
Email: "complex@claims.com",
|
||||
Phone: "+5439871234",
|
||||
Picture: "http://mugbook.com/complex/claims.jpg",
|
||||
Groups: []map[string]interface{}{
|
||||
{
|
||||
Groups: []interface{}{
|
||||
map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
12345,
|
||||
"Just::A::String",
|
||||
},
|
||||
Roles: []string{"test:simple", "test:roles"},
|
||||
Verified: &verified,
|
||||
@ -228,6 +242,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
@ -247,6 +262,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "unverified@email.com",
|
||||
@ -259,10 +275,15 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "complex@claims.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
User: "123456789",
|
||||
Email: "complex@claims.com",
|
||||
Groups: []string{
|
||||
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||
"12345",
|
||||
"Just::A::String",
|
||||
},
|
||||
PreferredUsername: "Complex Claim",
|
||||
},
|
||||
},
|
||||
@ -279,19 +300,25 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"User Claim Invalid": {
|
||||
"User Claim switched to non string": {
|
||||
IDToken: defaultIDToken,
|
||||
AllowUnverified: true,
|
||||
UserClaim: "groups",
|
||||
UserClaim: "roles",
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedError: errors.New("unable to extract custom UserClaim (groups)"),
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "[\"test:c\",\"test:d\"]",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Email Claim Switched": {
|
||||
IDToken: unverifiedIDToken,
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "phone_number",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "+4025205729",
|
||||
@ -304,9 +331,10 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "roles",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "[test:c test:d]",
|
||||
Email: "[\"test:c\",\"test:d\"]",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Mystery Man",
|
||||
},
|
||||
@ -316,6 +344,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "aksjdfhjksadh",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "",
|
||||
@ -328,6 +357,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "roles",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
@ -340,6 +370,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "alskdjfsalkdjf",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
@ -347,6 +378,32 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Groups Claim Numeric values": {
|
||||
IDToken: numericGroupsIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"1", "2", "3"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Groups Claim string values": {
|
||||
IDToken: defaultIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "email",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"janed@me.com"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
@ -371,10 +428,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
ss, err := provider.buildSessionFromClaims(idToken)
|
||||
ss, err := provider.buildSessionFromClaims(rawIDToken, "")
|
||||
if err != nil {
|
||||
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||
}
|
||||
@ -418,6 +472,12 @@ func TestProviderData_checkNonce(t *testing.T) {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
// Ensure that the ID token in the session is valid (signed and contains a nonce)
|
||||
// as the nonce claim is extracted to compare with the session nonce
|
||||
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
tc.Session.IDToken = rawIDToken
|
||||
|
||||
verificationOptions := &internaloidc.IDTokenVerificationOptions{
|
||||
AudienceClaims: []string{"aud"},
|
||||
ClientID: oidcClientID,
|
||||
@ -430,14 +490,7 @@ func TestProviderData_checkNonce(t *testing.T) {
|
||||
), verificationOptions),
|
||||
}
|
||||
|
||||
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = provider.checkNonce(tc.Session, idToken)
|
||||
if err != nil {
|
||||
if err := provider.checkNonce(tc.Session); err != nil {
|
||||
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||
} else {
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
@ -445,95 +498,3 @@ func TestProviderData_checkNonce(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderData_extractGroups(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
Claims map[string]interface{}
|
||||
GroupsClaim string
|
||||
ExpectedGroups []string
|
||||
}{
|
||||
"Standard String Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{"three", "string", "groups"},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"three", "string", "groups"},
|
||||
},
|
||||
"Different Claim Name": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"roles": []interface{}{"three", "string", "roles"},
|
||||
},
|
||||
GroupsClaim: "roles",
|
||||
ExpectedGroups: []string{"three", "string", "roles"},
|
||||
},
|
||||
"Numeric Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{1, 2, 3},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"1", "2", "3"},
|
||||
},
|
||||
"Complex Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{
|
||||
map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
12345,
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{
|
||||
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||
"12345",
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
"Missing Groups Claim Returns Nil": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: nil,
|
||||
},
|
||||
"Non List Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": "singleton",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"singleton"},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
verificationOptions := &internaloidc.IDTokenVerificationOptions{
|
||||
AudienceClaims: []string{"aud"},
|
||||
ClientID: oidcClientID,
|
||||
}
|
||||
provider := &ProviderData{
|
||||
Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
&oidc.Config{ClientID: oidcClientID},
|
||||
), verificationOptions),
|
||||
}
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
|
||||
groups := provider.extractGroups(tc.Claims)
|
||||
if tc.ExpectedGroups != nil {
|
||||
g.Expect(groups).To(Equal(tc.ExpectedGroups))
|
||||
} else {
|
||||
g.Expect(groups).To(BeNil())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -83,18 +82,3 @@ func formatGroup(rawGroup interface{}) (string, error) {
|
||||
}
|
||||
return string(jsonGroup), nil
|
||||
}
|
||||
|
||||
// coerceArray extracts a field from simplejson.Json that might be a
|
||||
// singleton or a list and coerces it into a list.
|
||||
func coerceArray(sj *simplejson.Json, key string) []interface{} {
|
||||
array, err := sj.Get(key).Array()
|
||||
if err == nil {
|
||||
return array
|
||||
}
|
||||
|
||||
single := sj.Get(key).Interface()
|
||||
if single == nil {
|
||||
return nil
|
||||
}
|
||||
return []interface{}{single}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user