diff --git a/CHANGELOG.md b/CHANGELOG.md index 690b4831..b5655e40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.2.1 +- [#1394](https://github.com/oauth2-proxy/oauth2-proxy/pull/1394) Add generic claim extractor to get claims from ID Tokens (@JoelSpeed) - [#1468](https://github.com/oauth2-proxy/oauth2-proxy/pull/1468) Implement session locking with session state lock (@JoelSpeed, @Bibob7) - [#1489](https://github.com/oauth2-proxy/oauth2-proxy/pull/1489) Fix Docker Buildx push to include build version (@JoelSpeed) - [#1477](https://github.com/oauth2-proxy/oauth2-proxy/pull/1477) Remove provider documentation for `Microsoft Azure AD` (@omBratteng) diff --git a/go.mod b/go.mod index ed1229d5..08a2e390 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/onsi/gomega v1.10.2 github.com/pierrec/lz4 v2.5.2+incompatible github.com/prometheus/client_golang v1.9.0 + github.com/spf13/cast v1.3.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.6.3 github.com/stretchr/testify v1.6.1 diff --git a/pkg/providers/util/claim_extractor.go b/pkg/providers/util/claim_extractor.go new file mode 100644 index 00000000..f0fe320e --- /dev/null +++ b/pkg/providers/util/claim_extractor.go @@ -0,0 +1,210 @@ +package util + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/bitly/go-simplejson" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" + "github.com/spf13/cast" +) + +// ClaimExtractor is used to extract claim values from an ID Token, or, if not +// present, from the profile URL. +type ClaimExtractor interface { + // GetClaim fetches a named claim and returns the value. + GetClaim(claim string) (interface{}, bool, error) + + // GetClaimInto fetches a named claim and puts the value into the destination. + GetClaimInto(claim string, dst interface{}) (bool, error) +} + +// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token. +// If needed, it will use the profile URL to look up a claim if it isn't present +// within the ID Token. +func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) { + payload, err := parseJWT(idToken) + if err != nil { + return nil, fmt.Errorf("failed to parse ID Token: %v", err) + } + + tokenClaims, err := simplejson.NewJson(payload) + if err != nil { + return nil, fmt.Errorf("failed to parse ID Token payload: %v", err) + } + + return &claimExtractor{ + ctx: ctx, + profileURL: profileURL, + requestHeaders: profileRequestHeaders, + tokenClaims: tokenClaims, + }, nil +} + +// claimExtractor implements the ClaimExtractor interface +type claimExtractor struct { + profileURL *url.URL + ctx context.Context + requestHeaders map[string][]string + tokenClaims *simplejson.Json + profileClaims *simplejson.Json +} + +// GetClaim will return the value claim if it exists. +// It will only return an error if the profile URL needs to be fetched due to +// the claim not being present in the ID Token. +func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) { + if claim == "" { + return nil, false, nil + } + + if value := getClaimFrom(claim, c.tokenClaims); value != nil { + return value, true, nil + } + + if c.profileClaims == nil { + profileClaims, err := c.loadProfileClaims() + if err != nil { + return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err) + } + + c.profileClaims = profileClaims + } + + if value := getClaimFrom(claim, c.profileClaims); value != nil { + return value, true, nil + } + + return nil, false, nil +} + +// loadProfileClaims will fetch the profileURL using the provided headers as +// authentication. +func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) { + if c.profileURL == nil || c.requestHeaders == nil { + // When no profileURL is set, we return a non-empty map so that + // we don't attempt to populate the profile claims again. + // If there are no headers, the request would be unauthorized so we also skip + // in this case too. + return simplejson.New(), nil + } + + claims, err := requests.New(c.profileURL.String()). + WithContext(c.ctx). + WithHeaders(c.requestHeaders). + Do(). + UnmarshalJSON() + if err != nil { + return nil, fmt.Errorf("error making request to profile URL: %v", err) + } + + return claims, nil +} + +// GetClaimInto loads a claim and places it into the destination interface. +// This will attempt to coerce the claim into the specified type. +// If it cannot be coerced, an error may be returned. +func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) { + value, exists, err := c.GetClaim(claim) + if err != nil { + return false, fmt.Errorf("could not get claim %q: %v", claim, err) + } + if !exists { + return false, nil + } + if err := coerceClaim(value, dst); err != nil { + return false, fmt.Errorf("could no coerce claim: %v", err) + } + + return true, nil +} + +// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130 +// We use it to grab the raw ID Token payload so that we can parse it into the JSON library. +func parseJWT(p string) ([]byte, error) { + parts := strings.Split(p, ".") + if len(parts) < 2 { + return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err) + } + return payload, nil +} + +// getClaimFrom gets a claim from a Json object. +// It can accept either a single claim name or a json path. +// Paths with indexes are not supported. +func getClaimFrom(claim string, src *simplejson.Json) interface{} { + claimParts := strings.Split(claim, ".") + return src.GetPath(claimParts...).Interface() +} + +// coerceClaim tries to convert the value into the destination interface type. +// If it can convert the value, it will then store the value in the destination +// interface. +func coerceClaim(value, dst interface{}) error { + switch d := dst.(type) { + case *string: + str, err := toString(value) + if err != nil { + return fmt.Errorf("could not convert value to string: %v", err) + } + *d = str + case *[]string: + strSlice, err := toStringSlice(value) + if err != nil { + return fmt.Errorf("could not convert value to string slice: %v", err) + } + *d = strSlice + case *bool: + *d = cast.ToBool(value) + default: + return fmt.Errorf("unknown type for destination: %T", dst) + } + return nil +} + +// toStringSlice converts an interface (either a slice or single value) into +// a slice of strings. +func toStringSlice(value interface{}) ([]string, error) { + var sliceValues []interface{} + switch v := value.(type) { + case []interface{}: + sliceValues = v + case interface{}: + sliceValues = []interface{}{v} + default: + sliceValues = cast.ToSlice(value) + } + + out := []string{} + for _, v := range sliceValues { + str, err := toString(v) + if err != nil { + return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err) + } + out = append(out, str) + } + return out, nil +} + +// toString coerces a value into a string. +// If it is non-string, marshal it into JSON. +func toString(value interface{}) (string, error) { + if str, err := cast.ToStringE(value); err == nil { + return str, nil + } + + jsonStr, err := json.Marshal(value) + if err != nil { + return "", err + } + return string(jsonStr), nil +} diff --git a/pkg/providers/util/claim_extractor_test.go b/pkg/providers/util/claim_extractor_test.go new file mode 100644 index 00000000..fb6220fe --- /dev/null +++ b/pkg/providers/util/claim_extractor_test.go @@ -0,0 +1,530 @@ +package util + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +const ( + emptyJSON = "{}" + profilePath = "/userinfo" + authorizedAccessToken = "valid_access_token" + basicIDTokenPayload = `{ + "user": "idTokenUser", + "email": "idTokenEmail", + "groups": [ + "idTokenGroup1", + "idTokenGroup2" + ] + }` + basicProfileURLPayload = `{ + "user": "profileUser", + "email": "profileEmail", + "groups": [ + "profileGroup1", + "profileGroup2" + ] + }` + nestedClaimPayload = `{ + "auth": { + "user": { + "username": "nestedUser" + } + } + }` + complexGroupsPayload = `{ + "groups": [ + { + "groupID": "group1", + "roles": ["admin"] + }, + { + "groupID": "group2", + "roles": ["user", "employee"] + } + ] + }` +) + +var _ = Describe("Claim Extractor Suite", func() { + Context("Claim Extractor", func() { + type newClaimExtractorTableInput struct { + idToken string + expectedError error + } + + DescribeTable("NewClaimExtractor", + func(in newClaimExtractorTableInput) { + _, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + }, + Entry("with a valid JWT", newClaimExtractorTableInput{ + idToken: createJWTFromPayload(basicIDTokenPayload), + expectedError: nil, + }), + Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{ + idToken: createJWTFromPayload("this is not JSON"), + expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"), + }), + Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{ + idToken: "eyJeyJ", + expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"), + }), + Entry("with an non-base64 IDToken", newClaimExtractorTableInput{ + idToken: "{metadata}.{payload}.{signature}", + expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"), + }), + ) + + type getClaimTableInput struct { + testClaimExtractorOpts + claim string + expectedValue interface{} + expectExists bool + expectedError error + } + + DescribeTable("GetClaim", + func(in getClaimTableInput) { + claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts) + Expect(err).ToNot(HaveOccurred()) + if serverClose != nil { + defer serverClose() + } + + value, exists, err := claimExtractor.GetClaim(in.claim) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + return + } + + Expect(err).ToNot(HaveOccurred()) + if in.expectedValue != nil { + Expect(value).To(Equal(in.expectedValue)) + } else { + Expect(value).To(BeNil()) + } + + Expect(exists).To(Equal(in.expectExists)) + }, + Entry("retrieves a string claim from ID Token when present", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + expectExists: true, + expectedValue: "idTokenUser", + expectedError: nil, + }), + Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + expectExists: true, + expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"}, + expectedError: nil, + }), + Entry("when the requested claim is the empty string", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + }, + claim: "", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + profileRequestHeaders: newAuthorizedHeader(), + }, + claim: "not_found", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "not_found", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: nil, + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "not_found", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the profile URL is unauthorized", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: emptyJSON, + setProfileURL: true, + profileRequestHeaders: make(http.Header), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "user", + expectExists: false, + expectedValue: nil, + expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"), + }), + Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: emptyJSON, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "user", + expectExists: true, + expectedValue: "profileUser", + expectedError: nil, + }), + Entry("retrieves a string claim from a nested path", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: nestedClaimPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "auth.user.username", + expectExists: true, + expectedValue: "nestedUser", + expectedError: nil, + }), + ) + }) + + It("GetClaim should only call the profile URL once", func() { + var counter int32 + countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&counter, 1) + rw.Write([]byte(basicProfileURLPayload)) + } + + claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{ + idTokenPayload: "{}", + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: countRequestsHandler, + }) + Expect(err).ToNot(HaveOccurred()) + if serverClose != nil { + defer serverClose() + } + + value, exists, err := claimExtractor.GetClaim("user") + Expect(err).ToNot(HaveOccurred()) + Expect(exists).To(BeTrue()) + Expect(value).To(Equal("profileUser")) + Expect(counter).To(BeEquivalentTo(1)) + + // Check a different claim, but expect the count not to increase + value, exists, err = claimExtractor.GetClaim("email") + Expect(err).ToNot(HaveOccurred()) + Expect(exists).To(BeTrue()) + Expect(value).To(Equal("profileEmail")) + Expect(counter).To(BeEquivalentTo(1)) + }) + + type getClaimIntoTableInput struct { + testClaimExtractorOpts + into interface{} + claim string + expectedValue interface{} + expectExists bool + expectedError error + } + + DescribeTable("GetClaimInto", + func(in getClaimIntoTableInput) { + claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts) + Expect(err).ToNot(HaveOccurred()) + if serverClose != nil { + defer serverClose() + } + + exists, err := claimExtractor.GetClaimInto(in.claim, in.into) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + return + } + + Expect(err).ToNot(HaveOccurred()) + if in.expectedValue != nil { + Expect(in.into).To(Equal(in.expectedValue)) + } else { + Expect(in.into).To(BeEmpty()) + } + + Expect(exists).To(Equal(in.expectExists)) + }, + Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + into: stringPointer(""), + expectExists: true, + expectedValue: stringPointer("idTokenUser"), + expectedError: nil, + }), + Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + into: stringSlicePointer([]string{}), + expectExists: true, + expectedValue: stringSlicePointer([]string{"idTokenUser"}), + expectedError: nil, + }), + Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + into: stringSlicePointer([]string{}), + expectExists: true, + expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}), + expectedError: nil, + }), + Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + into: stringPointer(""), + expectExists: true, + expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"), + expectedError: nil, + }), + Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + into: "", + expectExists: false, + expectedValue: "", + expectedError: errors.New("could no coerce claim: unknown type for destination: string"), + }), + Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: complexGroupsPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + into: stringSlicePointer([]string{}), + expectExists: true, + expectedValue: stringSlicePointer([]string{ + "{\"groupID\":\"group1\",\"roles\":[\"admin\"]}", + "{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}", + }), + expectedError: nil, + }), + Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "not_found", + into: stringPointer(""), + expectExists: false, + expectedValue: stringPointer(""), + expectedError: nil, + }), + Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: emptyJSON, + setProfileURL: true, + profileRequestHeaders: make(http.Header), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "user", + into: stringPointer(""), + expectExists: false, + expectedValue: stringPointer(""), + expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"), + }), + ) + + type coerceClaimTableInput struct { + value interface{} + dst interface{} + expectedDst interface{} + expectedError error + } + + DescribeTable("coerceClaim", + func(in coerceClaimTableInput) { + err := coerceClaim(in.value, in.dst) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + return + } + + Expect(err).ToNot(HaveOccurred()) + Expect(in.dst).To(Equal(in.expectedDst)) + }, + Entry("coerces a string to a string", coerceClaimTableInput{ + value: "some_string", + dst: stringPointer(""), + expectedDst: stringPointer("some_string"), + }), + Entry("coerces a slice to a string slice", coerceClaimTableInput{ + value: []interface{}{"a", "b"}, + dst: stringSlicePointer([]string{}), + expectedDst: stringSlicePointer([]string{"a", "b"}), + }), + Entry("coerces a bool to a bool", coerceClaimTableInput{ + value: true, + dst: boolPointer(false), + expectedDst: boolPointer(true), + }), + Entry("coerces a string to a bool", coerceClaimTableInput{ + value: "true", + dst: boolPointer(false), + expectedDst: boolPointer(true), + }), + Entry("coerces a map to a string", coerceClaimTableInput{ + value: map[string]interface{}{ + "foo": []interface{}{"bar", "baz"}, + }, + dst: stringPointer(""), + expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), + }), + ) +}) + +// ****************************************** +// Helpers for setting up the claim extractor +// ****************************************** + +type testClaimExtractorOpts struct { + idTokenPayload string + setProfileURL bool + profileRequestHeaders http.Header + profileRequestHandler http.HandlerFunc +} + +func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) { + var profileURL *url.URL + var closeServer func() + if in.setProfileURL { + server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler)) + closeServer = server.Close + + var err error + profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath) + Expect(err).ToNot(HaveOccurred()) + } + + rawIDToken := createJWTFromPayload(in.idTokenPayload) + + claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders) + return claimExtractor, closeServer, err +} + +func createJWTFromPayload(payload string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON)) + payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload)) + + return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header) +} + +func newAuthorizedHeader() http.Header { + headers := make(http.Header) + headers.Add("Authorization", "Bearer "+authorizedAccessToken) + return headers +} + +func hasAuthorizedHeader(headers http.Header) bool { + return headers.Get("Authorization") == "Bearer "+authorizedAccessToken +} + +// *********************** +// Typed Pointer Functions +// *********************** + +func stringPointer(in string) *string { + return &in +} + +func stringSlicePointer(in []string) *[]string { + return &in +} + +func boolPointer(in bool) *bool { + return &in +} + +// ****************************** +// Different profile URL handlers +// ****************************** + +func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) { + defer GinkgoRecover() + Expect(true).To(BeFalse(), "Unexpected request to profile URL") +} + +func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) { + if !hasAuthorizedHeader(req.Header) { + rw.WriteHeader(403) + rw.Write([]byte("Unauthorized")) + return + } + + rw.Write([]byte(basicProfileURLPayload)) +} diff --git a/pkg/providers/util/util_suite_test.go b/pkg/providers/util/util_suite_test.go new file mode 100644 index 00000000..c76d9813 --- /dev/null +++ b/pkg/providers/util/util_suite_test.go @@ -0,0 +1,17 @@ +package util + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestProviderUtilSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + logger.SetErrOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Provider Utils") +} diff --git a/providers/adfs.go b/providers/adfs.go index 797c8566..f5cbbfcc 100644 --- a/providers/adfs.go +++ b/providers/adfs.go @@ -103,16 +103,17 @@ func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt } func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { - idToken, err := p.Verifier.Verify(ctx, s.IDToken) + claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken) if err != nil { - return err + return fmt.Errorf("could not extract claims: %v", err) } - claims, err := p.getClaims(idToken) + + upn, found, err := claims.GetClaim(adfsUPNClaim) if err != nil { - return fmt.Errorf("couldn't extract claims from id_token (%v)", err) + return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err) } - upn := claims.raw[adfsUPNClaim] - if upn != nil { + + if found && fmt.Sprint(upn) != "" { s.Email = fmt.Sprint(upn) } return nil diff --git a/providers/adfs_test.go b/providers/adfs_test.go index 7eb1c487..93e61ea5 100755 --- a/providers/adfs_test.go +++ b/providers/adfs_test.go @@ -79,7 +79,7 @@ func testADFSBackend() *httptest.Server { { "access_token": "my_access_token", "id_token": "my_id_token", - "refresh_token": "my_refresh_token" + "refresh_token": "my_refresh_token" } ` userInfo := ` @@ -150,9 +150,7 @@ var _ = Describe("ADFS Provider Tests", func() { Context("with valid token", func() { It("should not throw an error", func() { rawIDToken, _ := newSignedTestIDToken(defaultIDToken) - idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) - Expect(err).To(BeNil()) - session, err := p.buildSessionFromClaims(idToken) + session, err := p.buildSessionFromClaims(rawIDToken, "") Expect(err).To(BeNil()) session.IDToken = rawIDToken err = p.EnrichSession(context.Background(), session) diff --git a/providers/auth_test.go b/providers/auth_test.go index 2ece923e..bda93b90 100644 --- a/providers/auth_test.go +++ b/providers/auth_test.go @@ -15,9 +15,20 @@ func CreateAuthorizedSession() *sessions.SessionState { } func IsAuthorizedInHeader(reqHeader http.Header) bool { - return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", authorizedAccessToken) + return IsAuthorizedInHeaderWithToken(reqHeader, authorizedAccessToken) +} + +func IsAuthorizedInHeaderWithToken(reqHeader http.Header, token string) bool { + return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", token) } func IsAuthorizedInURL(reqURL *url.URL) bool { return reqURL.Query().Get("access_token") == authorizedAccessToken } + +func isAuthorizedRefreshInURLWithToken(reqURL *url.URL, token string) bool { + if token == "" { + return false + } + return reqURL.Query().Get("refresh_token") == token +} diff --git a/providers/azure.go b/providers/azure.go index 39beb836..10ea701e 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -78,6 +78,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider { if p.ValidateURL == nil || p.ValidateURL.String() == "" { p.ValidateURL = p.ProfileURL } + p.getAuthorizationHeaderFunc = makeAzureHeader return &AzureProvider{ ProviderData: p, @@ -150,7 +151,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* session.CreatedAtNow() session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) - email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) + email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken, session.AccessToken) // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 @@ -163,7 +164,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* } if session.Email == "" { - email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken) + email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken, session.AccessToken) if err == nil && email != "" { session.Email = email } else { @@ -215,16 +216,16 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err // verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token // when oidc verifier is configured -func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) { +func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, rawIDToken string, accessToken string) (string, error) { email := "" - if token != "" && p.Verifier != nil { - token, err := p.Verifier.Verify(ctx, token) + if rawIDToken != "" && p.Verifier != nil { + _, err := p.Verifier.Verify(ctx, rawIDToken) // due to issues mentioned above, id_token may not be signed by AAD if err == nil { - claims, err := p.getClaims(token) + s, err := p.buildSessionFromClaims(rawIDToken, accessToken) if err == nil { - email = claims.Email + email = s.Email } else { logger.Printf("unable to get claims from token: %v", err) } @@ -287,7 +288,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess s.CreatedAtNow() s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) - email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) + email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken, s.AccessToken) // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 @@ -300,7 +301,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess } if s.Email == "" { - email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken) + email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken, s.AccessToken) if err == nil && email != "" { s.Email = email } else { diff --git a/providers/azure_test.go b/providers/azure_test.go index 3bb16d84..25b8c202 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -13,9 +13,8 @@ import ( "testing" "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt" - - oidc "github.com/coreos/go-oidc/v3/oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" @@ -145,11 +144,11 @@ func TestAzureSetTenant(t *testing.T) { assert.Equal(t, "openid", p.Data().Scope) } -func testAzureBackend(payload string) *httptest.Server { - return testAzureBackendWithError(payload, false) +func testAzureBackend(payload string, accessToken, refreshToken string) *httptest.Server { + return testAzureBackendWithError(payload, accessToken, refreshToken, false) } -func testAzureBackendWithError(payload string, injectError bool) *httptest.Server { +func testAzureBackendWithError(payload string, accessToken, refreshToken string, injectError bool) *httptest.Server { path := "/v1.0/me" return httptest.NewServer(http.HandlerFunc( @@ -163,7 +162,8 @@ func testAzureBackendWithError(payload string, injectError bool) *httptest.Serve w.WriteHeader(200) } w.Write([]byte(payload)) - } else if !IsAuthorizedInHeader(r.Header) { + } else if !IsAuthorizedInHeaderWithToken(r.Header, accessToken) && + !isAuthorizedRefreshInURLWithToken(r.URL, refreshToken) { w.WriteHeader(403) } else { w.WriteHeader(200) @@ -224,7 +224,7 @@ func TestAzureProviderEnrichSession(t *testing.T) { host string ) if testCase.PayloadFromAzureBackend != "" { - b = testAzureBackend(testCase.PayloadFromAzureBackend) + b = testAzureBackend(testCase.PayloadFromAzureBackend, authorizedAccessToken, "") defer b.Close() bURL, _ := url.Parse(b.URL) @@ -319,7 +319,7 @@ func TestAzureProviderRedeem(t *testing.T) { payloadBytes, err := json.Marshal(payload) assert.NoError(t, err) - b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError) + b := testAzureBackendWithError(string(payloadBytes), accessTokenString, testCase.RefreshToken, testCase.InjectRedeemURLError) defer b.Close() bURL, _ := url.Parse(b.URL) @@ -353,35 +353,44 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { func TestAzureProviderRefresh(t *testing.T) { email := "foo@example.com" + subject := "foo" idToken := idTokenClaims{ - StandardClaims: jwt.StandardClaims{Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532"}, - Email: email} + Email: email, + StandardClaims: jwt.StandardClaims{ + Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532", + Subject: subject, + }, + } idTokenString, err := newSignedTestIDToken(idToken) assert.NoError(t, err) + timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") assert.NoError(t, err) + + newAccessToken := "new_some_access_token" payload := azureOAuthPayload{ IDToken: idTokenString, RefreshToken: "new_some_refresh_token", - AccessToken: "new_some_access_token", + AccessToken: newAccessToken, ExpiresOn: timestamp.Unix(), } - payloadBytes, err := json.Marshal(payload) assert.NoError(t, err) - b := testAzureBackend(string(payloadBytes)) + + refreshToken := "some_refresh_token" + b := testAzureBackend(string(payloadBytes), newAccessToken, refreshToken) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) expires := time.Now().Add(time.Duration(-1) * time.Hour) - session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} + session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: refreshToken, IDToken: "some_id_token", ExpiresOn: &expires} refreshed, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) assert.True(t, refreshed) assert.NotEqual(t, session, nil) - assert.Equal(t, "new_some_access_token", session.AccessToken) + assert.Equal(t, newAccessToken, session.AccessToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken) assert.Equal(t, idTokenString, session.IDToken) assert.Equal(t, email, session.Email) diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 4c1196df..acbd6f70 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -57,6 +57,8 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { validateURL: digitalOceanDefaultProfileURL, scope: digitalOceanDefaultScope, }) + p.getAuthorizationHeaderFunc = makeOIDCHeader + return &DigitalOceanProvider{ProviderData: p} } diff --git a/providers/facebook.go b/providers/facebook.go index 6db9c38d..cfa836d7 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -58,6 +58,7 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { validateURL: facebookDefaultProfileURL, scope: facebookDefaultScope, }) + p.getAuthorizationHeaderFunc = makeOIDCHeader return &FacebookProvider{ProviderData: p} } diff --git a/providers/linkedin.go b/providers/linkedin.go index cac80222..3904d840 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -65,6 +65,8 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { validateURL: linkedinDefaultValidateURL, scope: linkedinDefaultScope, }) + p.getAuthorizationHeaderFunc = makeLinkedInHeader + return &LinkedInProvider{ProviderData: p} } diff --git a/providers/nextcloud.go b/providers/nextcloud.go index 4a074d6a..e9156016 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -1,13 +1,5 @@ package providers -import ( - "context" - "fmt" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" -) - // NextcloudProvider represents an Nextcloud based Identity Provider type NextcloudProvider struct { *ProviderData @@ -20,20 +12,11 @@ const nextCloudProviderName = "Nextcloud" // NewNextcloudProvider initiates a new NextcloudProvider func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { p.ProviderName = nextCloudProviderName + p.getAuthorizationHeaderFunc = makeOIDCHeader + if p.EmailClaim == OIDCEmailClaim { + // This implies the email claim has not been overridden, we should set a default + // for this provider + p.EmailClaim = "ocs.data.email" + } return &NextcloudProvider{ProviderData: p} } - -// GetEmailAddress returns the Account email address -func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - json, err := requests.New(p.ValidateURL.String()). - WithContext(ctx). - WithHeaders(makeOIDCHeader(s.AccessToken)). - Do(). - UnmarshalJSON() - if err != nil { - return "", fmt.Errorf("error making request: %v", err) - } - - email, err := json.Get("ocs").Get("data").Get("email").String() - return email, err -} diff --git a/providers/nextcloud_test.go b/providers/nextcloud_test.go index cd26885f..92f5030c 100644 --- a/providers/nextcloud_test.go +++ b/providers/nextcloud_test.go @@ -1,18 +1,13 @@ package providers import ( - "context" - "net/http" - "net/http/httptest" "net/url" "testing" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) const formatJSON = "format=json" -const userPath = "/ocs/v2.php/cloud/user" func testNextcloudProvider(hostname string) *NextcloudProvider { p := NewNextcloudProvider( @@ -32,23 +27,6 @@ func testNextcloudProvider(hostname string) *NextcloudProvider { return p } -func testNextcloudBackend(payload string) *httptest.Server { - path := userPath - query := formatJSON - - return httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != path || r.URL.RawQuery != query { - w.WriteHeader(404) - } else if !IsAuthorizedInHeader(r.Header) { - w.WriteHeader(403) - } else { - w.WriteHeader(200) - w.Write([]byte(payload)) - } - })) -} - func TestNextcloudProviderDefaults(t *testing.T) { p := testNextcloudProvider("") assert.NotEqual(t, nil, p) @@ -87,53 +65,3 @@ func TestNextcloudProviderOverrides(t *testing.T) { assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON, p.Data().ValidateURL.String()) } - -func TestNextcloudProviderGetEmailAddress(t *testing.T) { - b := testNextcloudBackend("{\"ocs\": {\"data\": { \"email\": \"michael.bland@gsa.gov\"}}}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testNextcloudProvider(bURL.Host) - p.ValidateURL.Path = userPath - p.ValidateURL.RawQuery = formatJSON - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) -} - -// Note that trying to trigger the "failed building request" case is not -// practical, since the only way it can fail is if the URL fails to parse. -func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) { - b := testNextcloudBackend("unused payload") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testNextcloudProvider(bURL.Host) - p.ValidateURL.Path = userPath - p.ValidateURL.RawQuery = formatJSON - - // We'll trigger a request failure by using an unexpected access - // token. Alternatively, we could allow the parsing of the payload as - // JSON to fail. - session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -} - -func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { - b := testNextcloudBackend("{\"foo\": \"bar\"}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testNextcloudProvider(bURL.Host) - p.ValidateURL.Path = userPath - p.ValidateURL.RawQuery = formatJSON - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -} diff --git a/providers/oidc.go b/providers/oidc.go index b1711d54..cccb8d70 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -5,12 +5,10 @@ import ( "errors" "fmt" "net/url" - "reflect" "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "golang.org/x/oauth2" ) @@ -24,6 +22,8 @@ type OIDCProvider struct { // NewOIDCProvider initiates a new OIDCProvider func NewOIDCProvider(p *ProviderData) *OIDCProvider { p.ProviderName = "OpenID Connect" + p.getAuthorizationHeaderFunc = makeOIDCHeader + return &OIDCProvider{ ProviderData: p, SkipNonce: true, @@ -68,21 +68,6 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s // EnrichSession is called after Redeem to allow providers to enrich session fields // such as User, Email, Groups with provider specific API calls. func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { - if p.ProfileURL.String() == "" { - if s.Email == "" { - return errors.New("id_token did not contain an email and profileURL is not defined") - } - return nil - } - - // Try to get missing emails or groups from a profileURL - if s.Email == "" || s.Groups == nil { - err := p.enrichFromProfileURL(ctx, s) - if err != nil { - logger.Errorf("Warning: Profile URL request failed: %v", err) - } - } - // If a mandatory email wasn't set, error at this point. if s.Email == "" { return errors.New("neither the id_token nor the profileURL set an email") @@ -90,42 +75,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta return nil } -// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of -// an OIDC profile URL -func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error { - respJSON, err := requests.New(p.ProfileURL.String()). - WithContext(ctx). - WithHeaders(makeOIDCHeader(s.AccessToken)). - Do(). - UnmarshalJSON() - if err != nil { - return err - } - - email, err := respJSON.Get(p.EmailClaim).String() - if err == nil && s.Email == "" { - s.Email = email - } - - if len(s.Groups) > 0 { - return nil - } - for _, group := range coerceArray(respJSON, p.GroupsClaim) { - formatted, err := formatGroup(group) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(group), err) - continue - } - s.Groups = append(s.Groups, formatted) - } - - return nil -} - // ValidateSession checks that the session's IDToken is still valid func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { - idToken, err := p.Verifier.Verify(ctx, s.IDToken) + _, err := p.Verifier.Verify(ctx, s.IDToken) if err != nil { logger.Errorf("id_token verification failed: %v", err) return false @@ -134,7 +86,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS if p.SkipNonce { return true } - err = p.checkNonce(s, idToken) + err = p.checkNonce(s) if err != nil { logger.Errorf("nonce verification failed: %v", err) return false @@ -212,7 +164,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) return nil, err } - ss, err := p.buildSessionFromClaims(idToken) + ss, err := p.buildSessionFromClaims(token, "") if err != nil { return nil, err } @@ -235,7 +187,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) // createSession takes an oauth2.Token and creates a SessionState from it. // It alters behavior if called from Redeem vs Refresh func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { - idToken, err := p.verifyIDToken(ctx, token) + _, err := p.verifyIDToken(ctx, token) if err != nil { switch err { case ErrMissingIDToken: @@ -248,14 +200,15 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r } } - ss, err := p.buildSessionFromClaims(idToken) + rawIDToken := getIDToken(token) + ss, err := p.buildSessionFromClaims(rawIDToken, token.AccessToken) if err != nil { return nil, err } ss.AccessToken = token.AccessToken ss.RefreshToken = token.RefreshToken - ss.IDToken = getIDToken(token) + ss.IDToken = rawIDToken ss.CreatedAtNow() ss.SetExpiresOn(token.Expiry) diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 3c87da00..1a98f462 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/http/httptest" @@ -54,6 +53,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { Scope: "openid profile offline_access", EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", Verifier: internaloidc.NewVerifier(oidc.NewVerifier( oidcIssuer, mockJWKS{}, @@ -142,333 +142,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { assert.Equal(t, defaultIDToken.Phone, session.Email) } -func TestOIDCProvider_EnrichSession(t *testing.T) { - testCases := map[string]struct { - ExistingSession *sessions.SessionState - EmailClaim string - GroupsClaim string - ProfileJSON map[string]interface{} - ExpectedError error - ExpectedSession *sessions.SessionState - }{ - "Already Populated": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Email": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "found@email.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Email: "found@email.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - - "Missing Email Only in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "found@email.com", - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Email: "found@email.com", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Email with Custom Claim": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "weird", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "weird": "weird@claim.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Email: "weird@claim.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Email not in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "groups": []string{"new", "thing"}, - }, - ExpectedError: errors.New("neither the id_token nor the profileURL set an email"), - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"new", "thing"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups with Complex Groups in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []map[string]interface{}{ - { - "groupId": "Admin Group Id", - "roles": []string{"Admin"}, - }, - }, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups with Singleton Complex Group in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": map[string]interface{}{ - "groupId": "Admin Group Id", - "roles": []string{"Admin"}, - }, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Empty Groups Claims": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups with Custom Claim": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "roles", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "roles": []string{"new", "thing", "roles"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"new", "thing", "roles"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups String Profile URL Response": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": "singleton", - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"singleton"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups in both Claims and Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - } - for testName, tc := range testCases { - t.Run(testName, func(t *testing.T) { - jsonResp, err := json.Marshal(tc.ProfileJSON) - assert.NoError(t, err) - - server, provider := newTestOIDCSetup(jsonResp) - provider.ProfileURL, err = url.Parse(server.URL) - assert.NoError(t, err) - - provider.EmailClaim = tc.EmailClaim - provider.GroupsClaim = tc.GroupsClaim - defer server.Close() - - err = provider.EnrichSession(context.Background(), tc.ExistingSession) - assert.Equal(t, tc.ExpectedError, err) - assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession) - }) - } -} - func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { idToken, _ := newSignedTestIDToken(defaultIDToken) @@ -565,11 +238,15 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { ExpectedGroups: []string{"test:c", "test:d"}, }, "Complex Groups Claim": { - IDToken: complexGroupsIDToken, - GroupsClaim: "groups", - ExpectedUser: "123456789", - ExpectedEmail: "complex@claims.com", - ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + IDToken: complexGroupsIDToken, + GroupsClaim: "groups", + ExpectedUser: "123456789", + ExpectedEmail: "complex@claims.com", + ExpectedGroups: []string{ + "{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", + "12345", + "Just::A::String", + }, }, } for testName, tc := range testCases { diff --git a/providers/provider_data.go b/providers/provider_data.go index 38f17405..13241ee7 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -5,20 +5,23 @@ import ( "errors" "fmt" "io/ioutil" + "net/http" "net/url" - "reflect" "strings" "github.com/coreos/go-oidc/v3/oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/util" "golang.org/x/oauth2" ) const ( OIDCEmailClaim = "email" OIDCGroupsClaim = "groups" + // This is not exported as it's not currently user configurable + oidcUserClaim = "sub" ) var OIDCAudienceClaims = []string{"aud"} @@ -52,6 +55,8 @@ type ProviderData struct { // Universal Group authorization data structure // any provider can set to consume AllowedGroups map[string]struct{} + + getAuthorizationHeaderFunc func(string) http.Header } // Data returns the ProviderData @@ -99,6 +104,10 @@ func (p *ProviderData) setProviderDefaults(defaults providerDefaults) { if p.Scope == "" { p.Scope = defaults.scope } + + if p.UserClaim == "" { + p.UserClaim = oidcUserClaim + } } // defaultURL will set return a default value if the given value is not set. @@ -120,17 +129,6 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL { // OIDC compliant // **************************************************************************** -// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload -type OIDCClaims struct { - Subject string `json:"sub"` - Email string `json:"-"` - Groups []string `json:"-"` - Verified *bool `json:"email_verified"` - Nonce string `json:"nonce"` - - raw map[string]interface{} -} - func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { rawIDToken := getIDToken(token) if strings.TrimSpace(rawIDToken) == "" { @@ -144,110 +142,80 @@ func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) ( // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState // with non-Token related fields. -func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { +func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*sessions.SessionState, error) { ss := &sessions.SessionState{} - if idToken == nil { + if rawIDToken == "" { return ss, nil } - claims, err := p.getClaims(idToken) + extractor, err := p.getClaimExtractor(rawIDToken, accessToken) if err != nil { - return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) + return nil, err } - ss.User = claims.Subject - ss.Email = claims.Email - ss.Groups = claims.Groups - - // Allow specialized providers that embed OIDCProvider to control the User - // claim. Not exposed as a configuration flag to generic OIDC provider - // users (yet). - if p.UserClaim != "" { - user, ok := claims.raw[p.UserClaim].(string) - if !ok { - return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim) + // Use a slice of a struct (vs map) here in case the same claim is used twice + for _, c := range []struct { + claim string + dst interface{} + }{ + {p.UserClaim, &ss.User}, + {p.EmailClaim, &ss.Email}, + {p.GroupsClaim, &ss.Groups}, + // TODO (@NickMeves) Deprecate for dynamic claim to session mapping + {"preferred_username", &ss.PreferredUsername}, + } { + if _, err := extractor.GetClaimInto(c.claim, c.dst); err != nil { + return nil, err } - ss.User = user - } - - // TODO (@NickMeves) Deprecate for dynamic claim to session mapping - if pref, ok := claims.raw["preferred_username"].(string); ok { - ss.PreferredUsername = pref } // `email_verified` must be present and explicitly set to `false` to be // considered unverified. verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail - if verifyEmail && claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + + var verified bool + exists, err := extractor.GetClaimInto("email_verified", &verified) + if err != nil { + return nil, err + } + + if verifyEmail && exists && !verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", ss.Email) } return ss, nil } -// getClaims extracts IDToken claims into an OIDCClaims -func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { - claims := &OIDCClaims{} - - // Extract default claims. - if err := idToken.Claims(&claims); err != nil { - return nil, fmt.Errorf("failed to parse default id_token claims: %v", err) - } - // Extract custom claims. - if err := idToken.Claims(&claims.raw); err != nil { - return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) +func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) { + extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken)) + if err != nil { + return nil, fmt.Errorf("could not initialise claim extractor: %v", err) } - email := claims.raw[p.EmailClaim] - if email != nil { - claims.Email = fmt.Sprint(email) - } - claims.Groups = p.extractGroups(claims.raw) - - return claims, nil + return extractor, nil } // checkNonce compares the session's nonce with the IDToken's nonce claim -func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error { - claims, err := p.getClaims(idToken) +func (p *ProviderData) checkNonce(s *sessions.SessionState) error { + extractor, err := p.getClaimExtractor(s.IDToken, "") if err != nil { return fmt.Errorf("id_token claims extraction failed: %v", err) } - if !s.CheckNonce(claims.Nonce) { + var nonce string + if _, err := extractor.GetClaimInto("nonce", &nonce); err != nil { + return fmt.Errorf("could not extract nonce from ID Token: %v", err) + } + + if !s.CheckNonce(nonce) { return errors.New("id_token nonce claim does not match the session nonce") } return nil } -// extractGroups extracts groups from a claim to a list in a type safe manner. -// If the claim isn't present, `nil` is returned. If the groups claim is -// present but empty, `[]string{}` is returned. -func (p *ProviderData) extractGroups(claims map[string]interface{}) []string { - rawClaim, ok := claims[p.GroupsClaim] - if !ok { - return nil +func (p *ProviderData) getAuthorizationHeader(accessToken string) http.Header { + if p.getAuthorizationHeaderFunc != nil && accessToken != "" { + return p.getAuthorizationHeaderFunc(accessToken) } - - // Handle traditional list-based groups as well as non-standard singleton - // based groups. Both variants support complex objects if needed. - var claimGroups []interface{} - switch raw := rawClaim.(type) { - case []interface{}: - claimGroups = raw - case interface{}: - claimGroups = []interface{}{raw} - } - - groups := []string{} - for _, rawGroup := range claimGroups { - formattedGroup, err := formatGroup(rawGroup) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(rawGroup), err) - continue - } - groups = append(groups, formattedGroup) - } - return groups + return nil } diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index 8e6d12c4..64c8326d 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -60,16 +60,30 @@ var ( StandardClaims: standardClaims, } + numericGroupsIDToken = idTokenClaims{ + Name: "Jane Dobbs", + Email: "janed@me.com", + Phone: "+4798765432", + Picture: "http://mugbook.com/janed/me.jpg", + Groups: []interface{}{1, 2, 3}, + Roles: []string{"test:c", "test:d"}, + Verified: &verified, + Nonce: encryption.HashNonce([]byte(oidcNonce)), + StandardClaims: standardClaims, + } + complexGroupsIDToken = idTokenClaims{ Name: "Complex Claim", Email: "complex@claims.com", Phone: "+5439871234", Picture: "http://mugbook.com/complex/claims.jpg", - Groups: []map[string]interface{}{ - { + Groups: []interface{}{ + map[string]interface{}{ "groupId": "Admin Group Id", "roles": []string{"Admin"}, }, + 12345, + "Just::A::String", }, Roles: []string{"test:simple", "test:roles"}, Verified: &verified, @@ -228,6 +242,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: false, EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", @@ -247,6 +262,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "unverified@email.com", @@ -259,10 +275,15 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ - User: "123456789", - Email: "complex@claims.com", - Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + User: "123456789", + Email: "complex@claims.com", + Groups: []string{ + "{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", + "12345", + "Just::A::String", + }, PreferredUsername: "Complex Claim", }, }, @@ -279,19 +300,25 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { PreferredUsername: "Jane Dobbs", }, }, - "User Claim Invalid": { + "User Claim switched to non string": { IDToken: defaultIDToken, AllowUnverified: true, - UserClaim: "groups", + UserClaim: "roles", EmailClaim: "email", GroupsClaim: "groups", - ExpectedError: errors.New("unable to extract custom UserClaim (groups)"), + ExpectedSession: &sessions.SessionState{ + User: "[\"test:c\",\"test:d\"]", + Email: "janed@me.com", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Jane Dobbs", + }, }, "Email Claim Switched": { IDToken: unverifiedIDToken, AllowUnverified: true, EmailClaim: "phone_number", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "+4025205729", @@ -304,9 +331,10 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "roles", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", - Email: "[test:c test:d]", + Email: "[\"test:c\",\"test:d\"]", Groups: []string{"test:a", "test:b"}, PreferredUsername: "Mystery Man", }, @@ -316,6 +344,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "aksjdfhjksadh", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "", @@ -328,6 +357,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: false, EmailClaim: "email", GroupsClaim: "roles", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", @@ -340,6 +370,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: false, EmailClaim: "email", GroupsClaim: "alskdjfsalkdjf", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", @@ -347,6 +378,32 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { PreferredUsername: "Jane Dobbs", }, }, + "Groups Claim Numeric values": { + IDToken: numericGroupsIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "groups", + UserClaim: "sub", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "janed@me.com", + Groups: []string{"1", "2", "3"}, + PreferredUsername: "Jane Dobbs", + }, + }, + "Groups Claim string values": { + IDToken: defaultIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "email", + UserClaim: "sub", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "janed@me.com", + Groups: []string{"janed@me.com"}, + PreferredUsername: "Jane Dobbs", + }, + }, } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { @@ -371,10 +428,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { rawIDToken, err := newSignedTestIDToken(tc.IDToken) g.Expect(err).ToNot(HaveOccurred()) - idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) - g.Expect(err).ToNot(HaveOccurred()) - - ss, err := provider.buildSessionFromClaims(idToken) + ss, err := provider.buildSessionFromClaims(rawIDToken, "") if err != nil { g.Expect(err).To(Equal(tc.ExpectedError)) } @@ -418,6 +472,12 @@ func TestProviderData_checkNonce(t *testing.T) { t.Run(testName, func(t *testing.T) { g := NewWithT(t) + // Ensure that the ID token in the session is valid (signed and contains a nonce) + // as the nonce claim is extracted to compare with the session nonce + rawIDToken, err := newSignedTestIDToken(tc.IDToken) + g.Expect(err).ToNot(HaveOccurred()) + tc.Session.IDToken = rawIDToken + verificationOptions := &internaloidc.IDTokenVerificationOptions{ AudienceClaims: []string{"aud"}, ClientID: oidcClientID, @@ -430,14 +490,7 @@ func TestProviderData_checkNonce(t *testing.T) { ), verificationOptions), } - rawIDToken, err := newSignedTestIDToken(tc.IDToken) - g.Expect(err).ToNot(HaveOccurred()) - - idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) - g.Expect(err).ToNot(HaveOccurred()) - - err = provider.checkNonce(tc.Session, idToken) - if err != nil { + if err := provider.checkNonce(tc.Session); err != nil { g.Expect(err).To(Equal(tc.ExpectedError)) } else { g.Expect(err).ToNot(HaveOccurred()) @@ -445,95 +498,3 @@ func TestProviderData_checkNonce(t *testing.T) { }) } } - -func TestProviderData_extractGroups(t *testing.T) { - testCases := map[string]struct { - Claims map[string]interface{} - GroupsClaim string - ExpectedGroups []string - }{ - "Standard String Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": []interface{}{"three", "string", "groups"}, - }, - GroupsClaim: "groups", - ExpectedGroups: []string{"three", "string", "groups"}, - }, - "Different Claim Name": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "roles": []interface{}{"three", "string", "roles"}, - }, - GroupsClaim: "roles", - ExpectedGroups: []string{"three", "string", "roles"}, - }, - "Numeric Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": []interface{}{1, 2, 3}, - }, - GroupsClaim: "groups", - ExpectedGroups: []string{"1", "2", "3"}, - }, - "Complex Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": []interface{}{ - map[string]interface{}{ - "groupId": "Admin Group Id", - "roles": []string{"Admin"}, - }, - 12345, - "Just::A::String", - }, - }, - GroupsClaim: "groups", - ExpectedGroups: []string{ - "{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", - "12345", - "Just::A::String", - }, - }, - "Missing Groups Claim Returns Nil": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - }, - GroupsClaim: "groups", - ExpectedGroups: nil, - }, - "Non List Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": "singleton", - }, - GroupsClaim: "groups", - ExpectedGroups: []string{"singleton"}, - }, - } - for testName, tc := range testCases { - t.Run(testName, func(t *testing.T) { - g := NewWithT(t) - - verificationOptions := &internaloidc.IDTokenVerificationOptions{ - AudienceClaims: []string{"aud"}, - ClientID: oidcClientID, - } - provider := &ProviderData{ - Verifier: internaloidc.NewVerifier(oidc.NewVerifier( - oidcIssuer, - mockJWKS{}, - &oidc.Config{ClientID: oidcClientID}, - ), verificationOptions), - } - provider.GroupsClaim = tc.GroupsClaim - - groups := provider.extractGroups(tc.Claims) - if tc.ExpectedGroups != nil { - g.Expect(groups).To(Equal(tc.ExpectedGroups)) - } else { - g.Expect(groups).To(BeNil()) - } - }) - } -} diff --git a/providers/util.go b/providers/util.go index e6fdc344..0507dde0 100644 --- a/providers/util.go +++ b/providers/util.go @@ -6,7 +6,6 @@ import ( "net/http" "net/url" - "github.com/bitly/go-simplejson" "golang.org/x/oauth2" ) @@ -83,18 +82,3 @@ func formatGroup(rawGroup interface{}) (string, error) { } return string(jsonGroup), nil } - -// coerceArray extracts a field from simplejson.Json that might be a -// singleton or a list and coerces it into a list. -func coerceArray(sj *simplejson.Json, key string) []interface{} { - array, err := sj.Get(key).Array() - if err == nil { - return array - } - - single := sj.Get(key).Interface() - if single == nil { - return nil - } - return []interface{}{single} -}