mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-05-29 23:17:38 +02:00
Add claim extractor provider util
This commit is contained in:
parent
44dc3cad77
commit
537e596904
1
go.mod
1
go.mod
@ -23,6 +23,7 @@ require (
|
|||||||
github.com/onsi/gomega v1.10.2
|
github.com/onsi/gomega v1.10.2
|
||||||
github.com/pierrec/lz4 v2.5.2+incompatible
|
github.com/pierrec/lz4 v2.5.2+incompatible
|
||||||
github.com/prometheus/client_golang v1.9.0
|
github.com/prometheus/client_golang v1.9.0
|
||||||
|
github.com/spf13/cast v1.3.0
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/spf13/viper v1.6.3
|
github.com/spf13/viper v1.6.3
|
||||||
github.com/stretchr/testify v1.6.1
|
github.com/stretchr/testify v1.6.1
|
||||||
|
210
pkg/providers/util/claim_extractor.go
Normal file
210
pkg/providers/util/claim_extractor.go
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitly/go-simplejson"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||||
|
"github.com/spf13/cast"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
|
||||||
|
// present, from the profile URL.
|
||||||
|
type ClaimExtractor interface {
|
||||||
|
// GetClaim fetches a named claim and returns the value.
|
||||||
|
GetClaim(claim string) (interface{}, bool, error)
|
||||||
|
|
||||||
|
// GetClaimInto fetches a named claim and puts the value into the destination.
|
||||||
|
GetClaimInto(claim string, dst interface{}) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
|
||||||
|
// If needed, it will use the profile URL to look up a claim if it isn't present
|
||||||
|
// within the ID Token.
|
||||||
|
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
|
||||||
|
payload, err := parseJWT(idToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse ID Token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenClaims, err := simplejson.NewJson(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &claimExtractor{
|
||||||
|
ctx: ctx,
|
||||||
|
profileURL: profileURL,
|
||||||
|
requestHeaders: profileRequestHeaders,
|
||||||
|
tokenClaims: tokenClaims,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// claimExtractor implements the ClaimExtractor interface
|
||||||
|
type claimExtractor struct {
|
||||||
|
profileURL *url.URL
|
||||||
|
ctx context.Context
|
||||||
|
requestHeaders map[string][]string
|
||||||
|
tokenClaims *simplejson.Json
|
||||||
|
profileClaims *simplejson.Json
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaim will return the value claim if it exists.
|
||||||
|
// It will only return an error if the profile URL needs to be fetched due to
|
||||||
|
// the claim not being present in the ID Token.
|
||||||
|
func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
|
||||||
|
if claim == "" {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := getClaimFrom(claim, c.tokenClaims); value != nil {
|
||||||
|
return value, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.profileClaims == nil {
|
||||||
|
profileClaims, err := c.loadProfileClaims()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.profileClaims = profileClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := getClaimFrom(claim, c.profileClaims); value != nil {
|
||||||
|
return value, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadProfileClaims will fetch the profileURL using the provided headers as
|
||||||
|
// authentication.
|
||||||
|
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
|
||||||
|
if c.profileURL == nil || c.requestHeaders == nil {
|
||||||
|
// When no profileURL is set, we return a non-empty map so that
|
||||||
|
// we don't attempt to populate the profile claims again.
|
||||||
|
// If there are no headers, the request would be unauthorized so we also skip
|
||||||
|
// in this case too.
|
||||||
|
return simplejson.New(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := requests.New(c.profileURL.String()).
|
||||||
|
WithContext(c.ctx).
|
||||||
|
WithHeaders(c.requestHeaders).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error making request to profile URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaimInto loads a claim and places it into the destination interface.
|
||||||
|
// This will attempt to coerce the claim into the specified type.
|
||||||
|
// If it cannot be coerced, an error may be returned.
|
||||||
|
func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
|
||||||
|
value, exists, err := c.GetClaim(claim)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("could not get claim %q: %v", claim, err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err := coerceClaim(value, dst); err != nil {
|
||||||
|
return false, fmt.Errorf("could no coerce claim: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
|
||||||
|
// We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
|
||||||
|
func parseJWT(p string) ([]byte, error) {
|
||||||
|
parts := strings.Split(p, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
|
||||||
|
}
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||||
|
}
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClaimFrom gets a claim from a Json object.
|
||||||
|
// It can accept either a single claim name or a json path.
|
||||||
|
// Paths with indexes are not supported.
|
||||||
|
func getClaimFrom(claim string, src *simplejson.Json) interface{} {
|
||||||
|
claimParts := strings.Split(claim, ".")
|
||||||
|
return src.GetPath(claimParts...).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// coerceClaim tries to convert the value into the destination interface type.
|
||||||
|
// If it can convert the value, it will then store the value in the destination
|
||||||
|
// interface.
|
||||||
|
func coerceClaim(value, dst interface{}) error {
|
||||||
|
switch d := dst.(type) {
|
||||||
|
case *string:
|
||||||
|
str, err := toString(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not convert value to string: %v", err)
|
||||||
|
}
|
||||||
|
*d = str
|
||||||
|
case *[]string:
|
||||||
|
strSlice, err := toStringSlice(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not convert value to string slice: %v", err)
|
||||||
|
}
|
||||||
|
*d = strSlice
|
||||||
|
case *bool:
|
||||||
|
*d = cast.ToBool(value)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown type for destination: %T", dst)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toStringSlice converts an interface (either a slice or single value) into
|
||||||
|
// a slice of strings.
|
||||||
|
func toStringSlice(value interface{}) ([]string, error) {
|
||||||
|
var sliceValues []interface{}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []interface{}:
|
||||||
|
sliceValues = v
|
||||||
|
case interface{}:
|
||||||
|
sliceValues = []interface{}{v}
|
||||||
|
default:
|
||||||
|
sliceValues = cast.ToSlice(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := []string{}
|
||||||
|
for _, v := range sliceValues {
|
||||||
|
str, err := toString(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
|
||||||
|
}
|
||||||
|
out = append(out, str)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toString coerces a value into a string.
|
||||||
|
// If it is non-string, marshal it into JSON.
|
||||||
|
func toString(value interface{}) (string, error) {
|
||||||
|
if str, err := cast.ToStringE(value); err == nil {
|
||||||
|
return str, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonStr, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(jsonStr), nil
|
||||||
|
}
|
530
pkg/providers/util/claim_extractor_test.go
Normal file
530
pkg/providers/util/claim_extractor_test.go
Normal file
@ -0,0 +1,530 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
emptyJSON = "{}"
|
||||||
|
profilePath = "/userinfo"
|
||||||
|
authorizedAccessToken = "valid_access_token"
|
||||||
|
basicIDTokenPayload = `{
|
||||||
|
"user": "idTokenUser",
|
||||||
|
"email": "idTokenEmail",
|
||||||
|
"groups": [
|
||||||
|
"idTokenGroup1",
|
||||||
|
"idTokenGroup2"
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
basicProfileURLPayload = `{
|
||||||
|
"user": "profileUser",
|
||||||
|
"email": "profileEmail",
|
||||||
|
"groups": [
|
||||||
|
"profileGroup1",
|
||||||
|
"profileGroup2"
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
nestedClaimPayload = `{
|
||||||
|
"auth": {
|
||||||
|
"user": {
|
||||||
|
"username": "nestedUser"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
complexGroupsPayload = `{
|
||||||
|
"groups": [
|
||||||
|
{
|
||||||
|
"groupID": "group1",
|
||||||
|
"roles": ["admin"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"groupID": "group2",
|
||||||
|
"roles": ["user", "employee"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Claim Extractor Suite", func() {
|
||||||
|
Context("Claim Extractor", func() {
|
||||||
|
type newClaimExtractorTableInput struct {
|
||||||
|
idToken string
|
||||||
|
expectedError error
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("NewClaimExtractor",
|
||||||
|
func(in newClaimExtractorTableInput) {
|
||||||
|
_, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil)
|
||||||
|
if in.expectedError != nil {
|
||||||
|
Expect(err).To(MatchError(in.expectedError))
|
||||||
|
} else {
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Entry("with a valid JWT", newClaimExtractorTableInput{
|
||||||
|
idToken: createJWTFromPayload(basicIDTokenPayload),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{
|
||||||
|
idToken: createJWTFromPayload("this is not JSON"),
|
||||||
|
expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"),
|
||||||
|
}),
|
||||||
|
Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{
|
||||||
|
idToken: "eyJeyJ",
|
||||||
|
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"),
|
||||||
|
}),
|
||||||
|
Entry("with an non-base64 IDToken", newClaimExtractorTableInput{
|
||||||
|
idToken: "{metadata}.{payload}.{signature}",
|
||||||
|
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
type getClaimTableInput struct {
|
||||||
|
testClaimExtractorOpts
|
||||||
|
claim string
|
||||||
|
expectedValue interface{}
|
||||||
|
expectExists bool
|
||||||
|
expectedError error
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("GetClaim",
|
||||||
|
func(in getClaimTableInput) {
|
||||||
|
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if serverClose != nil {
|
||||||
|
defer serverClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists, err := claimExtractor.GetClaim(in.claim)
|
||||||
|
if in.expectedError != nil {
|
||||||
|
Expect(err).To(MatchError(in.expectedError))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if in.expectedValue != nil {
|
||||||
|
Expect(value).To(Equal(in.expectedValue))
|
||||||
|
} else {
|
||||||
|
Expect(value).To(BeNil())
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(exists).To(Equal(in.expectExists))
|
||||||
|
},
|
||||||
|
Entry("retrieves a string claim from ID Token when present", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "user",
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: "idTokenUser",
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "groups",
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"},
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("when the requested claim is the empty string", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
},
|
||||||
|
claim: "",
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: nil,
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
},
|
||||||
|
claim: "not_found",
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: nil,
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: requiresAuthProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "not_found",
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: nil,
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: nil,
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "not_found",
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: nil,
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("when the profile URL is unauthorized", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: emptyJSON,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: make(http.Header),
|
||||||
|
profileRequestHandler: requiresAuthProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "user",
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: nil,
|
||||||
|
expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
|
||||||
|
}),
|
||||||
|
Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: emptyJSON,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: requiresAuthProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "user",
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: "profileUser",
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("retrieves a string claim from a nested path", getClaimTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: nestedClaimPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "auth.user.username",
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: "nestedUser",
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("GetClaim should only call the profile URL once", func() {
|
||||||
|
var counter int32
|
||||||
|
countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) {
|
||||||
|
atomic.AddInt32(&counter, 1)
|
||||||
|
rw.Write([]byte(basicProfileURLPayload))
|
||||||
|
}
|
||||||
|
|
||||||
|
claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{
|
||||||
|
idTokenPayload: "{}",
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: countRequestsHandler,
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if serverClose != nil {
|
||||||
|
defer serverClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists, err := claimExtractor.GetClaim("user")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(exists).To(BeTrue())
|
||||||
|
Expect(value).To(Equal("profileUser"))
|
||||||
|
Expect(counter).To(BeEquivalentTo(1))
|
||||||
|
|
||||||
|
// Check a different claim, but expect the count not to increase
|
||||||
|
value, exists, err = claimExtractor.GetClaim("email")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(exists).To(BeTrue())
|
||||||
|
Expect(value).To(Equal("profileEmail"))
|
||||||
|
Expect(counter).To(BeEquivalentTo(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
type getClaimIntoTableInput struct {
|
||||||
|
testClaimExtractorOpts
|
||||||
|
into interface{}
|
||||||
|
claim string
|
||||||
|
expectedValue interface{}
|
||||||
|
expectExists bool
|
||||||
|
expectedError error
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("GetClaimInto",
|
||||||
|
func(in getClaimIntoTableInput) {
|
||||||
|
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if serverClose != nil {
|
||||||
|
defer serverClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := claimExtractor.GetClaimInto(in.claim, in.into)
|
||||||
|
if in.expectedError != nil {
|
||||||
|
Expect(err).To(MatchError(in.expectedError))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if in.expectedValue != nil {
|
||||||
|
Expect(in.into).To(Equal(in.expectedValue))
|
||||||
|
} else {
|
||||||
|
Expect(in.into).To(BeEmpty())
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(exists).To(Equal(in.expectExists))
|
||||||
|
},
|
||||||
|
Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "user",
|
||||||
|
into: stringPointer(""),
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: stringPointer("idTokenUser"),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "user",
|
||||||
|
into: stringSlicePointer([]string{}),
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: stringSlicePointer([]string{"idTokenUser"}),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "groups",
|
||||||
|
into: stringSlicePointer([]string{}),
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "groups",
|
||||||
|
into: stringPointer(""),
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "user",
|
||||||
|
into: "",
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: "",
|
||||||
|
expectedError: errors.New("could no coerce claim: unknown type for destination: string"),
|
||||||
|
}),
|
||||||
|
Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: complexGroupsPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "groups",
|
||||||
|
into: stringSlicePointer([]string{}),
|
||||||
|
expectExists: true,
|
||||||
|
expectedValue: stringSlicePointer([]string{
|
||||||
|
"{\"groupID\":\"group1\",\"roles\":[\"admin\"]}",
|
||||||
|
"{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}",
|
||||||
|
}),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: basicIDTokenPayload,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: newAuthorizedHeader(),
|
||||||
|
profileRequestHandler: requiresAuthProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "not_found",
|
||||||
|
into: stringPointer(""),
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: stringPointer(""),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{
|
||||||
|
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||||
|
idTokenPayload: emptyJSON,
|
||||||
|
setProfileURL: true,
|
||||||
|
profileRequestHeaders: make(http.Header),
|
||||||
|
profileRequestHandler: requiresAuthProfileHandler,
|
||||||
|
},
|
||||||
|
claim: "user",
|
||||||
|
into: stringPointer(""),
|
||||||
|
expectExists: false,
|
||||||
|
expectedValue: stringPointer(""),
|
||||||
|
expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
type coerceClaimTableInput struct {
|
||||||
|
value interface{}
|
||||||
|
dst interface{}
|
||||||
|
expectedDst interface{}
|
||||||
|
expectedError error
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("coerceClaim",
|
||||||
|
func(in coerceClaimTableInput) {
|
||||||
|
err := coerceClaim(in.value, in.dst)
|
||||||
|
if in.expectedError != nil {
|
||||||
|
Expect(err).To(MatchError(in.expectedError))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(in.dst).To(Equal(in.expectedDst))
|
||||||
|
},
|
||||||
|
Entry("coerces a string to a string", coerceClaimTableInput{
|
||||||
|
value: "some_string",
|
||||||
|
dst: stringPointer(""),
|
||||||
|
expectedDst: stringPointer("some_string"),
|
||||||
|
}),
|
||||||
|
Entry("coerces a slice to a string slice", coerceClaimTableInput{
|
||||||
|
value: []interface{}{"a", "b"},
|
||||||
|
dst: stringSlicePointer([]string{}),
|
||||||
|
expectedDst: stringSlicePointer([]string{"a", "b"}),
|
||||||
|
}),
|
||||||
|
Entry("coerces a bool to a bool", coerceClaimTableInput{
|
||||||
|
value: true,
|
||||||
|
dst: boolPointer(false),
|
||||||
|
expectedDst: boolPointer(true),
|
||||||
|
}),
|
||||||
|
Entry("coerces a string to a bool", coerceClaimTableInput{
|
||||||
|
value: "true",
|
||||||
|
dst: boolPointer(false),
|
||||||
|
expectedDst: boolPointer(true),
|
||||||
|
}),
|
||||||
|
Entry("coerces a map to a string", coerceClaimTableInput{
|
||||||
|
value: map[string]interface{}{
|
||||||
|
"foo": []interface{}{"bar", "baz"},
|
||||||
|
},
|
||||||
|
dst: stringPointer(""),
|
||||||
|
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// ******************************************
|
||||||
|
// Helpers for setting up the claim extractor
|
||||||
|
// ******************************************
|
||||||
|
|
||||||
|
type testClaimExtractorOpts struct {
|
||||||
|
idTokenPayload string
|
||||||
|
setProfileURL bool
|
||||||
|
profileRequestHeaders http.Header
|
||||||
|
profileRequestHandler http.HandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) {
|
||||||
|
var profileURL *url.URL
|
||||||
|
var closeServer func()
|
||||||
|
if in.setProfileURL {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler))
|
||||||
|
closeServer = server.Close
|
||||||
|
|
||||||
|
var err error
|
||||||
|
profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
|
||||||
|
rawIDToken := createJWTFromPayload(in.idTokenPayload)
|
||||||
|
|
||||||
|
claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders)
|
||||||
|
return claimExtractor, closeServer, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func createJWTFromPayload(payload string) string {
|
||||||
|
header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON))
|
||||||
|
payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthorizedHeader() http.Header {
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Add("Authorization", "Bearer "+authorizedAccessToken)
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasAuthorizedHeader(headers http.Header) bool {
|
||||||
|
return headers.Get("Authorization") == "Bearer "+authorizedAccessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// ***********************
|
||||||
|
// Typed Pointer Functions
|
||||||
|
// ***********************
|
||||||
|
|
||||||
|
func stringPointer(in string) *string {
|
||||||
|
return &in
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSlicePointer(in []string) *[]string {
|
||||||
|
return &in
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolPointer(in bool) *bool {
|
||||||
|
return &in
|
||||||
|
}
|
||||||
|
|
||||||
|
// ******************************
|
||||||
|
// Different profile URL handlers
|
||||||
|
// ******************************
|
||||||
|
|
||||||
|
func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(true).To(BeFalse(), "Unexpected request to profile URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if !hasAuthorizedHeader(req.Header) {
|
||||||
|
rw.WriteHeader(403)
|
||||||
|
rw.Write([]byte("Unauthorized"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.Write([]byte(basicProfileURLPayload))
|
||||||
|
}
|
17
pkg/providers/util/util_suite_test.go
Normal file
17
pkg/providers/util/util_suite_test.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProviderUtilSuite(t *testing.T) {
|
||||||
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
logger.SetErrOutput(GinkgoWriter)
|
||||||
|
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Provider Utils")
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user