1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-26 05:27:28 +02:00
oauth2-proxy/pkg/providers/util/claim_extractor_test.go

549 lines
17 KiB
Go

package util
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
const (
emptyJSON = "{}"
profilePath = "/userinfo"
authorizedAccessToken = "valid_access_token"
basicIDTokenPayload = `{
"user": "idTokenUser",
"email": "idTokenEmail",
"groups": [
"idTokenGroup1",
"idTokenGroup2"
]
}`
basicProfileURLPayload = `{
"user": "profileUser",
"email": "profileEmail",
"groups": [
"profileGroup1",
"profileGroup2"
]
}`
nestedClaimPayload = `{
"auth": {
"user": {
"username": "nestedUser"
}
}
}`
complexGroupsPayload = `{
"groups": [
{
"groupID": "group1",
"roles": ["admin"]
},
{
"groupID": "group2",
"roles": ["user", "employee"]
}
]
}`
)
var _ = Describe("Claim Extractor Suite", func() {
Context("Claim Extractor", func() {
type newClaimExtractorTableInput struct {
idToken string
expectedError error
}
DescribeTable("NewClaimExtractor",
func(in newClaimExtractorTableInput) {
_, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
} else {
Expect(err).ToNot(HaveOccurred())
}
},
Entry("with a valid JWT", newClaimExtractorTableInput{
idToken: createJWTFromPayload(basicIDTokenPayload),
expectedError: nil,
}),
Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{
idToken: createJWTFromPayload("this is not JSON"),
expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"),
}),
Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{
idToken: "eyJeyJ",
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"),
}),
Entry("with an non-base64 IDToken", newClaimExtractorTableInput{
idToken: "{metadata}.{payload}.{signature}",
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"),
}),
)
type getClaimTableInput struct {
testClaimExtractorOpts
claim string
expectedValue interface{}
expectExists bool
expectedError error
}
DescribeTable("GetClaim",
func(in getClaimTableInput) {
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
value, exists, err := claimExtractor.GetClaim(in.claim)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
if in.expectedValue != nil {
Expect(value).To(Equal(in.expectedValue))
} else {
Expect(value).To(BeNil())
}
Expect(exists).To(Equal(in.expectExists))
},
Entry("retrieves a string claim from ID Token when present", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
expectExists: true,
expectedValue: "idTokenUser",
expectedError: nil,
}),
Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
expectExists: true,
expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"},
expectedError: nil,
}),
Entry("when the requested claim is the empty string", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
},
claim: "",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
profileRequestHeaders: newAuthorizedHeader(),
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: nil,
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the profile URL is unauthorized", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: make(http.Header),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
expectExists: false,
expectedValue: nil,
expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
}),
Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
expectExists: true,
expectedValue: "profileUser",
expectedError: nil,
}),
Entry("retrieves a string claim from a nested path", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: nestedClaimPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "auth.user.username",
expectExists: true,
expectedValue: "nestedUser",
expectedError: nil,
}),
)
})
It("GetClaim should only call the profile URL once", func() {
var counter int32
countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) {
atomic.AddInt32(&counter, 1)
rw.Write([]byte(basicProfileURLPayload))
}
claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{
idTokenPayload: "{}",
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: countRequestsHandler,
})
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
value, exists, err := claimExtractor.GetClaim("user")
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
Expect(value).To(Equal("profileUser"))
Expect(counter).To(BeEquivalentTo(1))
// Check a different claim, but expect the count not to increase
value, exists, err = claimExtractor.GetClaim("email")
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
Expect(value).To(Equal("profileEmail"))
Expect(counter).To(BeEquivalentTo(1))
})
It("GetClaim should not return an error with a non-nil empty ProfileURL", func() {
claims, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{
idTokenPayload: "{}",
profileRequestHeaders: newAuthorizedHeader(),
})
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
// Set the ProfileURL to be empty, but not nil
claims.(*claimExtractor).profileURL = &url.URL{}
value, exists, err := claims.GetClaim("user")
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeFalse())
Expect(value).To(BeNil())
})
type getClaimIntoTableInput struct {
testClaimExtractorOpts
into interface{}
claim string
expectedValue interface{}
expectExists bool
expectedError error
}
DescribeTable("GetClaimInto",
func(in getClaimIntoTableInput) {
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
exists, err := claimExtractor.GetClaimInto(in.claim, in.into)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
if in.expectedValue != nil {
Expect(in.into).To(Equal(in.expectedValue))
} else {
Expect(in.into).To(BeEmpty())
}
Expect(exists).To(Equal(in.expectExists))
},
Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: stringPointer(""),
expectExists: true,
expectedValue: stringPointer("idTokenUser"),
expectedError: nil,
}),
Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{"idTokenUser"}),
expectedError: nil,
}),
Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}),
expectedError: nil,
}),
Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringPointer(""),
expectExists: true,
expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"),
expectedError: nil,
}),
Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: "",
expectExists: false,
expectedValue: "",
expectedError: errors.New("could no coerce claim: unknown type for destination: string"),
}),
Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: complexGroupsPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{
"{\"groupID\":\"group1\",\"roles\":[\"admin\"]}",
"{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}",
}),
expectedError: nil,
}),
Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "not_found",
into: stringPointer(""),
expectExists: false,
expectedValue: stringPointer(""),
expectedError: nil,
}),
Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: make(http.Header),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
into: stringPointer(""),
expectExists: false,
expectedValue: stringPointer(""),
expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
}),
)
type coerceClaimTableInput struct {
value interface{}
dst interface{}
expectedDst interface{}
expectedError error
}
DescribeTable("coerceClaim",
func(in coerceClaimTableInput) {
err := coerceClaim(in.value, in.dst)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
Expect(in.dst).To(Equal(in.expectedDst))
},
Entry("coerces a string to a string", coerceClaimTableInput{
value: "some_string",
dst: stringPointer(""),
expectedDst: stringPointer("some_string"),
}),
Entry("coerces a slice to a string slice", coerceClaimTableInput{
value: []interface{}{"a", "b"},
dst: stringSlicePointer([]string{}),
expectedDst: stringSlicePointer([]string{"a", "b"}),
}),
Entry("coerces a bool to a bool", coerceClaimTableInput{
value: true,
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a string to a bool", coerceClaimTableInput{
value: "true",
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a map to a string", coerceClaimTableInput{
value: map[string]interface{}{
"foo": []interface{}{"bar", "baz"},
},
dst: stringPointer(""),
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
}),
)
})
// ******************************************
// Helpers for setting up the claim extractor
// ******************************************
type testClaimExtractorOpts struct {
idTokenPayload string
setProfileURL bool
profileRequestHeaders http.Header
profileRequestHandler http.HandlerFunc
}
func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) {
var profileURL *url.URL
var closeServer func()
if in.setProfileURL {
server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler))
closeServer = server.Close
var err error
profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath)
Expect(err).ToNot(HaveOccurred())
}
rawIDToken := createJWTFromPayload(in.idTokenPayload)
claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders)
return claimExtractor, closeServer, err
}
func createJWTFromPayload(payload string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON))
payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload))
return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header)
}
func newAuthorizedHeader() http.Header {
headers := make(http.Header)
headers.Add("Authorization", "Bearer "+authorizedAccessToken)
return headers
}
func hasAuthorizedHeader(headers http.Header) bool {
return headers.Get("Authorization") == "Bearer "+authorizedAccessToken
}
// ***********************
// Typed Pointer Functions
// ***********************
func stringPointer(in string) *string {
return &in
}
func stringSlicePointer(in []string) *[]string {
return &in
}
func boolPointer(in bool) *bool {
return &in
}
// ******************************
// Different profile URL handlers
// ******************************
func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) {
defer GinkgoRecover()
Expect(true).To(BeFalse(), "Unexpected request to profile URL")
}
func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) {
if !hasAuthorizedHeader(req.Header) {
rw.WriteHeader(403)
rw.Write([]byte("Unauthorized"))
return
}
rw.Write([]byte(basicProfileURLPayload))
}