diff --git a/CHANGELOG.md b/CHANGELOG.md index 001dbe6c..2c2fde8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - [#3001](https://github.com/oauth2-proxy/oauth2-proxy/pull/3001) Allow to set non-default authorization request response mode (@stieler-it) - [#3041](https://github.com/oauth2-proxy/oauth2-proxy/pull/3041) chore(deps): upgrade to latest golang v1.23.x release (@TheImplementer) - [#1916](https://github.com/oauth2-proxy/oauth2-proxy/pull/1916) fix: role extraction from access token in keycloak oidc (@Elektordi / @tuunit) +- [#3014](https://github.com/oauth2-proxy/oauth2-proxy/pull/3014) feat: ability to parse JWT encoded profile claims (@ikarius) # V7.8.2 diff --git a/pkg/providers/util/claim_extractor.go b/pkg/providers/util/claim_extractor.go index ec2fac90..9ab7a8c8 100644 --- a/pkg/providers/util/claim_extractor.go +++ b/pkg/providers/util/claim_extractor.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "mime" "net/http" "net/url" "strings" @@ -94,11 +95,25 @@ func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) { return simplejson.New(), nil } - claims, err := requests.New(c.profileURL.String()). + builder := requests.New(c.profileURL.String()). WithContext(c.ctx). WithHeaders(c.requestHeaders). - Do(). - UnmarshalSimpleJSON() + Do() + + // We first check if the result is a JWT token + // https://openid.net/specs/openid-connect-core-1_0-final.html#UserInfoResponse + mediaType, _, parseErr := mime.ParseMediaType(builder.Headers().Get("Content-Type")) + + if parseErr == nil && mediaType == "application/jwt" { + // Decode and use JWT payload as profile claims + if pl, err := parseJWT(string(builder.Body())); err == nil { + return simplejson.NewJson(pl) + } + } + + // Otherwise, process as normal JSON payload + claims, err := builder.UnmarshalSimpleJSON() + if err != nil { return nil, fmt.Errorf("error making request to profile URL: %v", err) } diff --git a/pkg/providers/util/claim_extractor_test.go b/pkg/providers/util/claim_extractor_test.go index b6d0b513..4ce4606f 100644 --- a/pkg/providers/util/claim_extractor_test.go +++ b/pkg/providers/util/claim_extractor_test.go @@ -497,6 +497,54 @@ var _ = Describe("Claim Extractor Suite", func() { expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), }), ) + + It("should extract claims from a JWT response", func() { + jwtResponsePayload := `{ + "user": "jwtUser", + "email": "jwtEmail", + "groups": [ + "jwtGroup1", + "jwtGroup2" + ] + }` + + jwtResponseHandler := func(rw http.ResponseWriter, req *http.Request) { + if !hasAuthorizedHeader(req.Header) { + rw.WriteHeader(403) + rw.Write([]byte("Unauthorized")) + return + } + + rw.Header().Set("Content-Type", "application/jwt; charset=utf-8") + rw.Write([]byte(createJWTFromPayload(jwtResponsePayload))) + } + + claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{ + idTokenPayload: emptyJSON, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: jwtResponseHandler, + }) + Expect(err).ToNot(HaveOccurred()) + if serverClose != nil { + defer serverClose() + } + + value, exists, err := claimExtractor.GetClaim("user") + Expect(err).ToNot(HaveOccurred()) + Expect(exists).To(BeTrue()) + Expect(value).To(Equal("jwtUser")) + + value, exists, err = claimExtractor.GetClaim("email") + Expect(err).ToNot(HaveOccurred()) + Expect(exists).To(BeTrue()) + Expect(value).To(Equal("jwtEmail")) + + value, exists, err = claimExtractor.GetClaim("groups") + Expect(err).ToNot(HaveOccurred()) + Expect(exists).To(BeTrue()) + Expect(value).To(Equal([]interface{}{"jwtGroup1", "jwtGroup2"})) + }) }) // ******************************************