1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-05-29 23:17:38 +02:00

Merge pull request #1394 from oauth2-proxy/claim-extractor

Add generic claim extractor to get claims from ID Tokens
This commit is contained in:
Joel Speed 2022-02-16 10:37:20 +00:00 committed by GitHub
commit 9832844c8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 971 additions and 733 deletions

View File

@ -8,6 +8,7 @@
## Changes since v7.2.1
- [#1394](https://github.com/oauth2-proxy/oauth2-proxy/pull/1394) Add generic claim extractor to get claims from ID Tokens (@JoelSpeed)
- [#1468](https://github.com/oauth2-proxy/oauth2-proxy/pull/1468) Implement session locking with session state lock (@JoelSpeed, @Bibob7)
- [#1489](https://github.com/oauth2-proxy/oauth2-proxy/pull/1489) Fix Docker Buildx push to include build version (@JoelSpeed)
- [#1477](https://github.com/oauth2-proxy/oauth2-proxy/pull/1477) Remove provider documentation for `Microsoft Azure AD` (@omBratteng)

1
go.mod
View File

@ -23,6 +23,7 @@ require (
github.com/onsi/gomega v1.10.2
github.com/pierrec/lz4 v2.5.2+incompatible
github.com/prometheus/client_golang v1.9.0
github.com/spf13/cast v1.3.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.6.3
github.com/stretchr/testify v1.6.1

View File

@ -0,0 +1,210 @@
package util
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/bitly/go-simplejson"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"github.com/spf13/cast"
)
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
// present, from the profile URL.
type ClaimExtractor interface {
// GetClaim fetches a named claim and returns the value.
GetClaim(claim string) (interface{}, bool, error)
// GetClaimInto fetches a named claim and puts the value into the destination.
GetClaimInto(claim string, dst interface{}) (bool, error)
}
// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
// If needed, it will use the profile URL to look up a claim if it isn't present
// within the ID Token.
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
payload, err := parseJWT(idToken)
if err != nil {
return nil, fmt.Errorf("failed to parse ID Token: %v", err)
}
tokenClaims, err := simplejson.NewJson(payload)
if err != nil {
return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
}
return &claimExtractor{
ctx: ctx,
profileURL: profileURL,
requestHeaders: profileRequestHeaders,
tokenClaims: tokenClaims,
}, nil
}
// claimExtractor implements the ClaimExtractor interface
type claimExtractor struct {
profileURL *url.URL
ctx context.Context
requestHeaders map[string][]string
tokenClaims *simplejson.Json
profileClaims *simplejson.Json
}
// GetClaim will return the value claim if it exists.
// It will only return an error if the profile URL needs to be fetched due to
// the claim not being present in the ID Token.
func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
if claim == "" {
return nil, false, nil
}
if value := getClaimFrom(claim, c.tokenClaims); value != nil {
return value, true, nil
}
if c.profileClaims == nil {
profileClaims, err := c.loadProfileClaims()
if err != nil {
return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
}
c.profileClaims = profileClaims
}
if value := getClaimFrom(claim, c.profileClaims); value != nil {
return value, true, nil
}
return nil, false, nil
}
// loadProfileClaims will fetch the profileURL using the provided headers as
// authentication.
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
if c.profileURL == nil || c.requestHeaders == nil {
// When no profileURL is set, we return a non-empty map so that
// we don't attempt to populate the profile claims again.
// If there are no headers, the request would be unauthorized so we also skip
// in this case too.
return simplejson.New(), nil
}
claims, err := requests.New(c.profileURL.String()).
WithContext(c.ctx).
WithHeaders(c.requestHeaders).
Do().
UnmarshalJSON()
if err != nil {
return nil, fmt.Errorf("error making request to profile URL: %v", err)
}
return claims, nil
}
// GetClaimInto loads a claim and places it into the destination interface.
// This will attempt to coerce the claim into the specified type.
// If it cannot be coerced, an error may be returned.
func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
value, exists, err := c.GetClaim(claim)
if err != nil {
return false, fmt.Errorf("could not get claim %q: %v", claim, err)
}
if !exists {
return false, nil
}
if err := coerceClaim(value, dst); err != nil {
return false, fmt.Errorf("could no coerce claim: %v", err)
}
return true, nil
}
// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
// We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
}
// getClaimFrom gets a claim from a Json object.
// It can accept either a single claim name or a json path.
// Paths with indexes are not supported.
func getClaimFrom(claim string, src *simplejson.Json) interface{} {
claimParts := strings.Split(claim, ".")
return src.GetPath(claimParts...).Interface()
}
// coerceClaim tries to convert the value into the destination interface type.
// If it can convert the value, it will then store the value in the destination
// interface.
func coerceClaim(value, dst interface{}) error {
switch d := dst.(type) {
case *string:
str, err := toString(value)
if err != nil {
return fmt.Errorf("could not convert value to string: %v", err)
}
*d = str
case *[]string:
strSlice, err := toStringSlice(value)
if err != nil {
return fmt.Errorf("could not convert value to string slice: %v", err)
}
*d = strSlice
case *bool:
*d = cast.ToBool(value)
default:
return fmt.Errorf("unknown type for destination: %T", dst)
}
return nil
}
// toStringSlice converts an interface (either a slice or single value) into
// a slice of strings.
func toStringSlice(value interface{}) ([]string, error) {
var sliceValues []interface{}
switch v := value.(type) {
case []interface{}:
sliceValues = v
case interface{}:
sliceValues = []interface{}{v}
default:
sliceValues = cast.ToSlice(value)
}
out := []string{}
for _, v := range sliceValues {
str, err := toString(v)
if err != nil {
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
}
out = append(out, str)
}
return out, nil
}
// toString coerces a value into a string.
// If it is non-string, marshal it into JSON.
func toString(value interface{}) (string, error) {
if str, err := cast.ToStringE(value); err == nil {
return str, nil
}
jsonStr, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(jsonStr), nil
}

View File

@ -0,0 +1,530 @@
package util
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
const (
emptyJSON = "{}"
profilePath = "/userinfo"
authorizedAccessToken = "valid_access_token"
basicIDTokenPayload = `{
"user": "idTokenUser",
"email": "idTokenEmail",
"groups": [
"idTokenGroup1",
"idTokenGroup2"
]
}`
basicProfileURLPayload = `{
"user": "profileUser",
"email": "profileEmail",
"groups": [
"profileGroup1",
"profileGroup2"
]
}`
nestedClaimPayload = `{
"auth": {
"user": {
"username": "nestedUser"
}
}
}`
complexGroupsPayload = `{
"groups": [
{
"groupID": "group1",
"roles": ["admin"]
},
{
"groupID": "group2",
"roles": ["user", "employee"]
}
]
}`
)
var _ = Describe("Claim Extractor Suite", func() {
Context("Claim Extractor", func() {
type newClaimExtractorTableInput struct {
idToken string
expectedError error
}
DescribeTable("NewClaimExtractor",
func(in newClaimExtractorTableInput) {
_, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
} else {
Expect(err).ToNot(HaveOccurred())
}
},
Entry("with a valid JWT", newClaimExtractorTableInput{
idToken: createJWTFromPayload(basicIDTokenPayload),
expectedError: nil,
}),
Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{
idToken: createJWTFromPayload("this is not JSON"),
expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"),
}),
Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{
idToken: "eyJeyJ",
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"),
}),
Entry("with an non-base64 IDToken", newClaimExtractorTableInput{
idToken: "{metadata}.{payload}.{signature}",
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"),
}),
)
type getClaimTableInput struct {
testClaimExtractorOpts
claim string
expectedValue interface{}
expectExists bool
expectedError error
}
DescribeTable("GetClaim",
func(in getClaimTableInput) {
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
value, exists, err := claimExtractor.GetClaim(in.claim)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
if in.expectedValue != nil {
Expect(value).To(Equal(in.expectedValue))
} else {
Expect(value).To(BeNil())
}
Expect(exists).To(Equal(in.expectExists))
},
Entry("retrieves a string claim from ID Token when present", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
expectExists: true,
expectedValue: "idTokenUser",
expectedError: nil,
}),
Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
expectExists: true,
expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"},
expectedError: nil,
}),
Entry("when the requested claim is the empty string", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
},
claim: "",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
profileRequestHeaders: newAuthorizedHeader(),
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: nil,
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the profile URL is unauthorized", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: make(http.Header),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
expectExists: false,
expectedValue: nil,
expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
}),
Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
expectExists: true,
expectedValue: "profileUser",
expectedError: nil,
}),
Entry("retrieves a string claim from a nested path", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: nestedClaimPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "auth.user.username",
expectExists: true,
expectedValue: "nestedUser",
expectedError: nil,
}),
)
})
It("GetClaim should only call the profile URL once", func() {
var counter int32
countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) {
atomic.AddInt32(&counter, 1)
rw.Write([]byte(basicProfileURLPayload))
}
claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{
idTokenPayload: "{}",
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: countRequestsHandler,
})
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
value, exists, err := claimExtractor.GetClaim("user")
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
Expect(value).To(Equal("profileUser"))
Expect(counter).To(BeEquivalentTo(1))
// Check a different claim, but expect the count not to increase
value, exists, err = claimExtractor.GetClaim("email")
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
Expect(value).To(Equal("profileEmail"))
Expect(counter).To(BeEquivalentTo(1))
})
type getClaimIntoTableInput struct {
testClaimExtractorOpts
into interface{}
claim string
expectedValue interface{}
expectExists bool
expectedError error
}
DescribeTable("GetClaimInto",
func(in getClaimIntoTableInput) {
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
exists, err := claimExtractor.GetClaimInto(in.claim, in.into)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
if in.expectedValue != nil {
Expect(in.into).To(Equal(in.expectedValue))
} else {
Expect(in.into).To(BeEmpty())
}
Expect(exists).To(Equal(in.expectExists))
},
Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: stringPointer(""),
expectExists: true,
expectedValue: stringPointer("idTokenUser"),
expectedError: nil,
}),
Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{"idTokenUser"}),
expectedError: nil,
}),
Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}),
expectedError: nil,
}),
Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringPointer(""),
expectExists: true,
expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"),
expectedError: nil,
}),
Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: "",
expectExists: false,
expectedValue: "",
expectedError: errors.New("could no coerce claim: unknown type for destination: string"),
}),
Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: complexGroupsPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{
"{\"groupID\":\"group1\",\"roles\":[\"admin\"]}",
"{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}",
}),
expectedError: nil,
}),
Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "not_found",
into: stringPointer(""),
expectExists: false,
expectedValue: stringPointer(""),
expectedError: nil,
}),
Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: make(http.Header),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
into: stringPointer(""),
expectExists: false,
expectedValue: stringPointer(""),
expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
}),
)
type coerceClaimTableInput struct {
value interface{}
dst interface{}
expectedDst interface{}
expectedError error
}
DescribeTable("coerceClaim",
func(in coerceClaimTableInput) {
err := coerceClaim(in.value, in.dst)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
Expect(in.dst).To(Equal(in.expectedDst))
},
Entry("coerces a string to a string", coerceClaimTableInput{
value: "some_string",
dst: stringPointer(""),
expectedDst: stringPointer("some_string"),
}),
Entry("coerces a slice to a string slice", coerceClaimTableInput{
value: []interface{}{"a", "b"},
dst: stringSlicePointer([]string{}),
expectedDst: stringSlicePointer([]string{"a", "b"}),
}),
Entry("coerces a bool to a bool", coerceClaimTableInput{
value: true,
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a string to a bool", coerceClaimTableInput{
value: "true",
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a map to a string", coerceClaimTableInput{
value: map[string]interface{}{
"foo": []interface{}{"bar", "baz"},
},
dst: stringPointer(""),
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
}),
)
})
// ******************************************
// Helpers for setting up the claim extractor
// ******************************************
type testClaimExtractorOpts struct {
idTokenPayload string
setProfileURL bool
profileRequestHeaders http.Header
profileRequestHandler http.HandlerFunc
}
func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) {
var profileURL *url.URL
var closeServer func()
if in.setProfileURL {
server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler))
closeServer = server.Close
var err error
profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath)
Expect(err).ToNot(HaveOccurred())
}
rawIDToken := createJWTFromPayload(in.idTokenPayload)
claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders)
return claimExtractor, closeServer, err
}
func createJWTFromPayload(payload string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON))
payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload))
return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header)
}
func newAuthorizedHeader() http.Header {
headers := make(http.Header)
headers.Add("Authorization", "Bearer "+authorizedAccessToken)
return headers
}
func hasAuthorizedHeader(headers http.Header) bool {
return headers.Get("Authorization") == "Bearer "+authorizedAccessToken
}
// ***********************
// Typed Pointer Functions
// ***********************
func stringPointer(in string) *string {
return &in
}
func stringSlicePointer(in []string) *[]string {
return &in
}
func boolPointer(in bool) *bool {
return &in
}
// ******************************
// Different profile URL handlers
// ******************************
func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) {
defer GinkgoRecover()
Expect(true).To(BeFalse(), "Unexpected request to profile URL")
}
func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) {
if !hasAuthorizedHeader(req.Header) {
rw.WriteHeader(403)
rw.Write([]byte("Unauthorized"))
return
}
rw.Write([]byte(basicProfileURLPayload))
}

View File

@ -0,0 +1,17 @@
package util
import (
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func TestProviderUtilSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Provider Utils")
}

View File

@ -103,16 +103,17 @@ func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt
}
func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error {
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken)
if err != nil {
return err
return fmt.Errorf("could not extract claims: %v", err)
}
claims, err := p.getClaims(idToken)
upn, found, err := claims.GetClaim(adfsUPNClaim)
if err != nil {
return fmt.Errorf("couldn't extract claims from id_token (%v)", err)
return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err)
}
upn := claims.raw[adfsUPNClaim]
if upn != nil {
if found && fmt.Sprint(upn) != "" {
s.Email = fmt.Sprint(upn)
}
return nil

View File

@ -79,7 +79,7 @@ func testADFSBackend() *httptest.Server {
{
"access_token": "my_access_token",
"id_token": "my_id_token",
"refresh_token": "my_refresh_token"
"refresh_token": "my_refresh_token"
}
`
userInfo := `
@ -150,9 +150,7 @@ var _ = Describe("ADFS Provider Tests", func() {
Context("with valid token", func() {
It("should not throw an error", func() {
rawIDToken, _ := newSignedTestIDToken(defaultIDToken)
idToken, err := p.Verifier.Verify(context.Background(), rawIDToken)
Expect(err).To(BeNil())
session, err := p.buildSessionFromClaims(idToken)
session, err := p.buildSessionFromClaims(rawIDToken, "")
Expect(err).To(BeNil())
session.IDToken = rawIDToken
err = p.EnrichSession(context.Background(), session)

View File

@ -15,9 +15,20 @@ func CreateAuthorizedSession() *sessions.SessionState {
}
func IsAuthorizedInHeader(reqHeader http.Header) bool {
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", authorizedAccessToken)
return IsAuthorizedInHeaderWithToken(reqHeader, authorizedAccessToken)
}
func IsAuthorizedInHeaderWithToken(reqHeader http.Header, token string) bool {
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", token)
}
func IsAuthorizedInURL(reqURL *url.URL) bool {
return reqURL.Query().Get("access_token") == authorizedAccessToken
}
func isAuthorizedRefreshInURLWithToken(reqURL *url.URL, token string) bool {
if token == "" {
return false
}
return reqURL.Query().Get("refresh_token") == token
}

View File

@ -78,6 +78,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
p.ValidateURL = p.ProfileURL
}
p.getAuthorizationHeaderFunc = makeAzureHeader
return &AzureProvider{
ProviderData: p,
@ -150,7 +151,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
session.CreatedAtNow()
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken, session.AccessToken)
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
@ -163,7 +164,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
}
if session.Email == "" {
email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken)
email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken, session.AccessToken)
if err == nil && email != "" {
session.Email = email
} else {
@ -215,16 +216,16 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err
// verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token
// when oidc verifier is configured
func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) {
func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, rawIDToken string, accessToken string) (string, error) {
email := ""
if token != "" && p.Verifier != nil {
token, err := p.Verifier.Verify(ctx, token)
if rawIDToken != "" && p.Verifier != nil {
_, err := p.Verifier.Verify(ctx, rawIDToken)
// due to issues mentioned above, id_token may not be signed by AAD
if err == nil {
claims, err := p.getClaims(token)
s, err := p.buildSessionFromClaims(rawIDToken, accessToken)
if err == nil {
email = claims.Email
email = s.Email
} else {
logger.Printf("unable to get claims from token: %v", err)
}
@ -287,7 +288,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
s.CreatedAtNow()
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken, s.AccessToken)
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
@ -300,7 +301,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
}
if s.Email == "" {
email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken)
email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken, s.AccessToken)
if err == nil && email != "" {
s.Email = email
} else {

View File

@ -13,9 +13,8 @@ import (
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
oidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
@ -145,11 +144,11 @@ func TestAzureSetTenant(t *testing.T) {
assert.Equal(t, "openid", p.Data().Scope)
}
func testAzureBackend(payload string) *httptest.Server {
return testAzureBackendWithError(payload, false)
func testAzureBackend(payload string, accessToken, refreshToken string) *httptest.Server {
return testAzureBackendWithError(payload, accessToken, refreshToken, false)
}
func testAzureBackendWithError(payload string, injectError bool) *httptest.Server {
func testAzureBackendWithError(payload string, accessToken, refreshToken string, injectError bool) *httptest.Server {
path := "/v1.0/me"
return httptest.NewServer(http.HandlerFunc(
@ -163,7 +162,8 @@ func testAzureBackendWithError(payload string, injectError bool) *httptest.Serve
w.WriteHeader(200)
}
w.Write([]byte(payload))
} else if !IsAuthorizedInHeader(r.Header) {
} else if !IsAuthorizedInHeaderWithToken(r.Header, accessToken) &&
!isAuthorizedRefreshInURLWithToken(r.URL, refreshToken) {
w.WriteHeader(403)
} else {
w.WriteHeader(200)
@ -224,7 +224,7 @@ func TestAzureProviderEnrichSession(t *testing.T) {
host string
)
if testCase.PayloadFromAzureBackend != "" {
b = testAzureBackend(testCase.PayloadFromAzureBackend)
b = testAzureBackend(testCase.PayloadFromAzureBackend, authorizedAccessToken, "")
defer b.Close()
bURL, _ := url.Parse(b.URL)
@ -319,7 +319,7 @@ func TestAzureProviderRedeem(t *testing.T) {
payloadBytes, err := json.Marshal(payload)
assert.NoError(t, err)
b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError)
b := testAzureBackendWithError(string(payloadBytes), accessTokenString, testCase.RefreshToken, testCase.InjectRedeemURLError)
defer b.Close()
bURL, _ := url.Parse(b.URL)
@ -353,35 +353,44 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
func TestAzureProviderRefresh(t *testing.T) {
email := "foo@example.com"
subject := "foo"
idToken := idTokenClaims{
StandardClaims: jwt.StandardClaims{Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532"},
Email: email}
Email: email,
StandardClaims: jwt.StandardClaims{
Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532",
Subject: subject,
},
}
idTokenString, err := newSignedTestIDToken(idToken)
assert.NoError(t, err)
timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z")
assert.NoError(t, err)
newAccessToken := "new_some_access_token"
payload := azureOAuthPayload{
IDToken: idTokenString,
RefreshToken: "new_some_refresh_token",
AccessToken: "new_some_access_token",
AccessToken: newAccessToken,
ExpiresOn: timestamp.Unix(),
}
payloadBytes, err := json.Marshal(payload)
assert.NoError(t, err)
b := testAzureBackend(string(payloadBytes))
refreshToken := "some_refresh_token"
b := testAzureBackend(string(payloadBytes), newAccessToken, refreshToken)
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
expires := time.Now().Add(time.Duration(-1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: refreshToken, IDToken: "some_id_token", ExpiresOn: &expires}
refreshed, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err)
assert.True(t, refreshed)
assert.NotEqual(t, session, nil)
assert.Equal(t, "new_some_access_token", session.AccessToken)
assert.Equal(t, newAccessToken, session.AccessToken)
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
assert.Equal(t, idTokenString, session.IDToken)
assert.Equal(t, email, session.Email)

View File

@ -57,6 +57,8 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider {
validateURL: digitalOceanDefaultProfileURL,
scope: digitalOceanDefaultScope,
})
p.getAuthorizationHeaderFunc = makeOIDCHeader
return &DigitalOceanProvider{ProviderData: p}
}

View File

@ -58,6 +58,7 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
validateURL: facebookDefaultProfileURL,
scope: facebookDefaultScope,
})
p.getAuthorizationHeaderFunc = makeOIDCHeader
return &FacebookProvider{ProviderData: p}
}

View File

@ -65,6 +65,8 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
validateURL: linkedinDefaultValidateURL,
scope: linkedinDefaultScope,
})
p.getAuthorizationHeaderFunc = makeLinkedInHeader
return &LinkedInProvider{ProviderData: p}
}

View File

@ -1,13 +1,5 @@
package providers
import (
"context"
"fmt"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
)
// NextcloudProvider represents an Nextcloud based Identity Provider
type NextcloudProvider struct {
*ProviderData
@ -20,20 +12,11 @@ const nextCloudProviderName = "Nextcloud"
// NewNextcloudProvider initiates a new NextcloudProvider
func NewNextcloudProvider(p *ProviderData) *NextcloudProvider {
p.ProviderName = nextCloudProviderName
p.getAuthorizationHeaderFunc = makeOIDCHeader
if p.EmailClaim == OIDCEmailClaim {
// This implies the email claim has not been overridden, we should set a default
// for this provider
p.EmailClaim = "ocs.data.email"
}
return &NextcloudProvider{ProviderData: p}
}
// GetEmailAddress returns the Account email address
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
WithHeaders(makeOIDCHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", fmt.Errorf("error making request: %v", err)
}
email, err := json.Get("ocs").Get("data").Get("email").String()
return email, err
}

View File

@ -1,18 +1,13 @@
package providers
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
const formatJSON = "format=json"
const userPath = "/ocs/v2.php/cloud/user"
func testNextcloudProvider(hostname string) *NextcloudProvider {
p := NewNextcloudProvider(
@ -32,23 +27,6 @@ func testNextcloudProvider(hostname string) *NextcloudProvider {
return p
}
func testNextcloudBackend(payload string) *httptest.Server {
path := userPath
query := formatJSON
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != path || r.URL.RawQuery != query {
w.WriteHeader(404)
} else if !IsAuthorizedInHeader(r.Header) {
w.WriteHeader(403)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}
func TestNextcloudProviderDefaults(t *testing.T) {
p := testNextcloudProvider("")
assert.NotEqual(t, nil, p)
@ -87,53 +65,3 @@ func TestNextcloudProviderOverrides(t *testing.T) {
assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON,
p.Data().ValidateURL.String())
}
func TestNextcloudProviderGetEmailAddress(t *testing.T) {
b := testNextcloudBackend("{\"ocs\": {\"data\": { \"email\": \"michael.bland@gsa.gov\"}}}")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testNextcloudProvider(bURL.Host)
p.ValidateURL.Path = userPath
p.ValidateURL.RawQuery = formatJSON
session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}
// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testNextcloudBackend("unused payload")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testNextcloudProvider(bURL.Host)
p.ValidateURL.Path = userPath
p.ValidateURL.RawQuery = formatJSON
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testNextcloudBackend("{\"foo\": \"bar\"}")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testNextcloudProvider(bURL.Host)
p.ValidateURL.Path = userPath
p.ValidateURL.RawQuery = formatJSON
session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}

View File

@ -5,12 +5,10 @@ import (
"errors"
"fmt"
"net/url"
"reflect"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"golang.org/x/oauth2"
)
@ -24,6 +22,8 @@ type OIDCProvider struct {
// NewOIDCProvider initiates a new OIDCProvider
func NewOIDCProvider(p *ProviderData) *OIDCProvider {
p.ProviderName = "OpenID Connect"
p.getAuthorizationHeaderFunc = makeOIDCHeader
return &OIDCProvider{
ProviderData: p,
SkipNonce: true,
@ -68,21 +68,6 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s
// EnrichSession is called after Redeem to allow providers to enrich session fields
// such as User, Email, Groups with provider specific API calls.
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
if p.ProfileURL.String() == "" {
if s.Email == "" {
return errors.New("id_token did not contain an email and profileURL is not defined")
}
return nil
}
// Try to get missing emails or groups from a profileURL
if s.Email == "" || s.Groups == nil {
err := p.enrichFromProfileURL(ctx, s)
if err != nil {
logger.Errorf("Warning: Profile URL request failed: %v", err)
}
}
// If a mandatory email wasn't set, error at this point.
if s.Email == "" {
return errors.New("neither the id_token nor the profileURL set an email")
@ -90,42 +75,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
return nil
}
// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
// an OIDC profile URL
func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
respJSON, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(makeOIDCHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return err
}
email, err := respJSON.Get(p.EmailClaim).String()
if err == nil && s.Email == "" {
s.Email = email
}
if len(s.Groups) > 0 {
return nil
}
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
formatted, err := formatGroup(group)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
reflect.TypeOf(group), err)
continue
}
s.Groups = append(s.Groups, formatted)
}
return nil
}
// ValidateSession checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
_, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil {
logger.Errorf("id_token verification failed: %v", err)
return false
@ -134,7 +86,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
if p.SkipNonce {
return true
}
err = p.checkNonce(s, idToken)
err = p.checkNonce(s)
if err != nil {
logger.Errorf("nonce verification failed: %v", err)
return false
@ -212,7 +164,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
return nil, err
}
ss, err := p.buildSessionFromClaims(idToken)
ss, err := p.buildSessionFromClaims(token, "")
if err != nil {
return nil, err
}
@ -235,7 +187,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
// createSession takes an oauth2.Token and creates a SessionState from it.
// It alters behavior if called from Redeem vs Refresh
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
idToken, err := p.verifyIDToken(ctx, token)
_, err := p.verifyIDToken(ctx, token)
if err != nil {
switch err {
case ErrMissingIDToken:
@ -248,14 +200,15 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
}
}
ss, err := p.buildSessionFromClaims(idToken)
rawIDToken := getIDToken(token)
ss, err := p.buildSessionFromClaims(rawIDToken, token.AccessToken)
if err != nil {
return nil, err
}
ss.AccessToken = token.AccessToken
ss.RefreshToken = token.RefreshToken
ss.IDToken = getIDToken(token)
ss.IDToken = rawIDToken
ss.CreatedAtNow()
ss.SetExpiresOn(token.Expiry)

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -54,6 +53,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
Scope: "openid profile offline_access",
EmailClaim: "email",
GroupsClaim: "groups",
UserClaim: "sub",
Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
oidcIssuer,
mockJWKS{},
@ -142,333 +142,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
assert.Equal(t, defaultIDToken.Phone, session.Email)
}
func TestOIDCProvider_EnrichSession(t *testing.T) {
testCases := map[string]struct {
ExistingSession *sessions.SessionState
EmailClaim string
GroupsClaim string
ProfileJSON map[string]interface{}
ExpectedError error
ExpectedSession *sessions.SessionState
}{
"Already Populated": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "found@email.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "found@email.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email Only in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "found@email.com",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "found@email.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email with Custom Claim": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "weird",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"weird": "weird@claim.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "weird@claim.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email not in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"groups": []string{"new", "thing"},
},
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"new", "thing"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Complex Groups in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []map[string]interface{}{
{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Singleton Complex Group in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Empty Groups Claims": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Custom Claim": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "roles",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"roles": []string{"new", "thing", "roles"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"new", "thing", "roles"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups String Profile URL Response": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": "singleton",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"singleton"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups in both Claims and Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
jsonResp, err := json.Marshal(tc.ProfileJSON)
assert.NoError(t, err)
server, provider := newTestOIDCSetup(jsonResp)
provider.ProfileURL, err = url.Parse(server.URL)
assert.NoError(t, err)
provider.EmailClaim = tc.EmailClaim
provider.GroupsClaim = tc.GroupsClaim
defer server.Close()
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
assert.Equal(t, tc.ExpectedError, err)
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
})
}
}
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
idToken, _ := newSignedTestIDToken(defaultIDToken)
@ -565,11 +238,15 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
ExpectedGroups: []string{"test:c", "test:d"},
},
"Complex Groups Claim": {
IDToken: complexGroupsIDToken,
GroupsClaim: "groups",
ExpectedUser: "123456789",
ExpectedEmail: "complex@claims.com",
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: complexGroupsIDToken,
GroupsClaim: "groups",
ExpectedUser: "123456789",
ExpectedEmail: "complex@claims.com",
ExpectedGroups: []string{
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
"12345",
"Just::A::String",
},
},
}
for testName, tc := range testCases {

View File

@ -5,20 +5,23 @@ import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/util"
"golang.org/x/oauth2"
)
const (
OIDCEmailClaim = "email"
OIDCGroupsClaim = "groups"
// This is not exported as it's not currently user configurable
oidcUserClaim = "sub"
)
var OIDCAudienceClaims = []string{"aud"}
@ -52,6 +55,8 @@ type ProviderData struct {
// Universal Group authorization data structure
// any provider can set to consume
AllowedGroups map[string]struct{}
getAuthorizationHeaderFunc func(string) http.Header
}
// Data returns the ProviderData
@ -99,6 +104,10 @@ func (p *ProviderData) setProviderDefaults(defaults providerDefaults) {
if p.Scope == "" {
p.Scope = defaults.scope
}
if p.UserClaim == "" {
p.UserClaim = oidcUserClaim
}
}
// defaultURL will set return a default value if the given value is not set.
@ -120,17 +129,6 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL {
// OIDC compliant
// ****************************************************************************
// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
type OIDCClaims struct {
Subject string `json:"sub"`
Email string `json:"-"`
Groups []string `json:"-"`
Verified *bool `json:"email_verified"`
Nonce string `json:"nonce"`
raw map[string]interface{}
}
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
rawIDToken := getIDToken(token)
if strings.TrimSpace(rawIDToken) == "" {
@ -144,110 +142,80 @@ func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
// with non-Token related fields.
func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) {
func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*sessions.SessionState, error) {
ss := &sessions.SessionState{}
if idToken == nil {
if rawIDToken == "" {
return ss, nil
}
claims, err := p.getClaims(idToken)
extractor, err := p.getClaimExtractor(rawIDToken, accessToken)
if err != nil {
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
return nil, err
}
ss.User = claims.Subject
ss.Email = claims.Email
ss.Groups = claims.Groups
// Allow specialized providers that embed OIDCProvider to control the User
// claim. Not exposed as a configuration flag to generic OIDC provider
// users (yet).
if p.UserClaim != "" {
user, ok := claims.raw[p.UserClaim].(string)
if !ok {
return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim)
// Use a slice of a struct (vs map) here in case the same claim is used twice
for _, c := range []struct {
claim string
dst interface{}
}{
{p.UserClaim, &ss.User},
{p.EmailClaim, &ss.Email},
{p.GroupsClaim, &ss.Groups},
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
{"preferred_username", &ss.PreferredUsername},
} {
if _, err := extractor.GetClaimInto(c.claim, c.dst); err != nil {
return nil, err
}
ss.User = user
}
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
if pref, ok := claims.raw["preferred_username"].(string); ok {
ss.PreferredUsername = pref
}
// `email_verified` must be present and explicitly set to `false` to be
// considered unverified.
verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail
if verifyEmail && claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
var verified bool
exists, err := extractor.GetClaimInto("email_verified", &verified)
if err != nil {
return nil, err
}
if verifyEmail && exists && !verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", ss.Email)
}
return ss, nil
}
// getClaims extracts IDToken claims into an OIDCClaims
func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
claims := &OIDCClaims{}
// Extract default claims.
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
}
// Extract custom claims.
if err := idToken.Claims(&claims.raw); err != nil {
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) {
extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken))
if err != nil {
return nil, fmt.Errorf("could not initialise claim extractor: %v", err)
}
email := claims.raw[p.EmailClaim]
if email != nil {
claims.Email = fmt.Sprint(email)
}
claims.Groups = p.extractGroups(claims.raw)
return claims, nil
return extractor, nil
}
// checkNonce compares the session's nonce with the IDToken's nonce claim
func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error {
claims, err := p.getClaims(idToken)
func (p *ProviderData) checkNonce(s *sessions.SessionState) error {
extractor, err := p.getClaimExtractor(s.IDToken, "")
if err != nil {
return fmt.Errorf("id_token claims extraction failed: %v", err)
}
if !s.CheckNonce(claims.Nonce) {
var nonce string
if _, err := extractor.GetClaimInto("nonce", &nonce); err != nil {
return fmt.Errorf("could not extract nonce from ID Token: %v", err)
}
if !s.CheckNonce(nonce) {
return errors.New("id_token nonce claim does not match the session nonce")
}
return nil
}
// extractGroups extracts groups from a claim to a list in a type safe manner.
// If the claim isn't present, `nil` is returned. If the groups claim is
// present but empty, `[]string{}` is returned.
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
rawClaim, ok := claims[p.GroupsClaim]
if !ok {
return nil
func (p *ProviderData) getAuthorizationHeader(accessToken string) http.Header {
if p.getAuthorizationHeaderFunc != nil && accessToken != "" {
return p.getAuthorizationHeaderFunc(accessToken)
}
// Handle traditional list-based groups as well as non-standard singleton
// based groups. Both variants support complex objects if needed.
var claimGroups []interface{}
switch raw := rawClaim.(type) {
case []interface{}:
claimGroups = raw
case interface{}:
claimGroups = []interface{}{raw}
}
groups := []string{}
for _, rawGroup := range claimGroups {
formattedGroup, err := formatGroup(rawGroup)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
reflect.TypeOf(rawGroup), err)
continue
}
groups = append(groups, formattedGroup)
}
return groups
return nil
}

View File

@ -60,16 +60,30 @@ var (
StandardClaims: standardClaims,
}
numericGroupsIDToken = idTokenClaims{
Name: "Jane Dobbs",
Email: "janed@me.com",
Phone: "+4798765432",
Picture: "http://mugbook.com/janed/me.jpg",
Groups: []interface{}{1, 2, 3},
Roles: []string{"test:c", "test:d"},
Verified: &verified,
Nonce: encryption.HashNonce([]byte(oidcNonce)),
StandardClaims: standardClaims,
}
complexGroupsIDToken = idTokenClaims{
Name: "Complex Claim",
Email: "complex@claims.com",
Phone: "+5439871234",
Picture: "http://mugbook.com/complex/claims.jpg",
Groups: []map[string]interface{}{
{
Groups: []interface{}{
map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
12345,
"Just::A::String",
},
Roles: []string{"test:simple", "test:roles"},
Verified: &verified,
@ -228,6 +242,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
@ -247,6 +262,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true,
EmailClaim: "email",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "unverified@email.com",
@ -259,10 +275,15 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true,
EmailClaim: "email",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "complex@claims.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
User: "123456789",
Email: "complex@claims.com",
Groups: []string{
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
"12345",
"Just::A::String",
},
PreferredUsername: "Complex Claim",
},
},
@ -279,19 +300,25 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
PreferredUsername: "Jane Dobbs",
},
},
"User Claim Invalid": {
"User Claim switched to non string": {
IDToken: defaultIDToken,
AllowUnverified: true,
UserClaim: "groups",
UserClaim: "roles",
EmailClaim: "email",
GroupsClaim: "groups",
ExpectedError: errors.New("unable to extract custom UserClaim (groups)"),
ExpectedSession: &sessions.SessionState{
User: "[\"test:c\",\"test:d\"]",
Email: "janed@me.com",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Jane Dobbs",
},
},
"Email Claim Switched": {
IDToken: unverifiedIDToken,
AllowUnverified: true,
EmailClaim: "phone_number",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "+4025205729",
@ -304,9 +331,10 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true,
EmailClaim: "roles",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "[test:c test:d]",
Email: "[\"test:c\",\"test:d\"]",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Mystery Man",
},
@ -316,6 +344,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true,
EmailClaim: "aksjdfhjksadh",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "",
@ -328,6 +357,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "roles",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
@ -340,6 +370,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "alskdjfsalkdjf",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
@ -347,6 +378,32 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
PreferredUsername: "Jane Dobbs",
},
},
"Groups Claim Numeric values": {
IDToken: numericGroupsIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: []string{"1", "2", "3"},
PreferredUsername: "Jane Dobbs",
},
},
"Groups Claim string values": {
IDToken: defaultIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "email",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: []string{"janed@me.com"},
PreferredUsername: "Jane Dobbs",
},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
@ -371,10 +428,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
g.Expect(err).ToNot(HaveOccurred())
ss, err := provider.buildSessionFromClaims(idToken)
ss, err := provider.buildSessionFromClaims(rawIDToken, "")
if err != nil {
g.Expect(err).To(Equal(tc.ExpectedError))
}
@ -418,6 +472,12 @@ func TestProviderData_checkNonce(t *testing.T) {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
// Ensure that the ID token in the session is valid (signed and contains a nonce)
// as the nonce claim is extracted to compare with the session nonce
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())
tc.Session.IDToken = rawIDToken
verificationOptions := &internaloidc.IDTokenVerificationOptions{
AudienceClaims: []string{"aud"},
ClientID: oidcClientID,
@ -430,14 +490,7 @@ func TestProviderData_checkNonce(t *testing.T) {
), verificationOptions),
}
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
g.Expect(err).ToNot(HaveOccurred())
err = provider.checkNonce(tc.Session, idToken)
if err != nil {
if err := provider.checkNonce(tc.Session); err != nil {
g.Expect(err).To(Equal(tc.ExpectedError))
} else {
g.Expect(err).ToNot(HaveOccurred())
@ -445,95 +498,3 @@ func TestProviderData_checkNonce(t *testing.T) {
})
}
}
func TestProviderData_extractGroups(t *testing.T) {
testCases := map[string]struct {
Claims map[string]interface{}
GroupsClaim string
ExpectedGroups []string
}{
"Standard String Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{"three", "string", "groups"},
},
GroupsClaim: "groups",
ExpectedGroups: []string{"three", "string", "groups"},
},
"Different Claim Name": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"roles": []interface{}{"three", "string", "roles"},
},
GroupsClaim: "roles",
ExpectedGroups: []string{"three", "string", "roles"},
},
"Numeric Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{1, 2, 3},
},
GroupsClaim: "groups",
ExpectedGroups: []string{"1", "2", "3"},
},
"Complex Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{
map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
12345,
"Just::A::String",
},
},
GroupsClaim: "groups",
ExpectedGroups: []string{
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
"12345",
"Just::A::String",
},
},
"Missing Groups Claim Returns Nil": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
},
GroupsClaim: "groups",
ExpectedGroups: nil,
},
"Non List Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": "singleton",
},
GroupsClaim: "groups",
ExpectedGroups: []string{"singleton"},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
verificationOptions := &internaloidc.IDTokenVerificationOptions{
AudienceClaims: []string{"aud"},
ClientID: oidcClientID,
}
provider := &ProviderData{
Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
oidcIssuer,
mockJWKS{},
&oidc.Config{ClientID: oidcClientID},
), verificationOptions),
}
provider.GroupsClaim = tc.GroupsClaim
groups := provider.extractGroups(tc.Claims)
if tc.ExpectedGroups != nil {
g.Expect(groups).To(Equal(tc.ExpectedGroups))
} else {
g.Expect(groups).To(BeNil())
}
})
}
}

View File

@ -6,7 +6,6 @@ import (
"net/http"
"net/url"
"github.com/bitly/go-simplejson"
"golang.org/x/oauth2"
)
@ -83,18 +82,3 @@ func formatGroup(rawGroup interface{}) (string, error) {
}
return string(jsonGroup), nil
}
// coerceArray extracts a field from simplejson.Json that might be a
// singleton or a list and coerces it into a list.
func coerceArray(sj *simplejson.Json, key string) []interface{} {
array, err := sj.Get(key).Array()
if err == nil {
return array
}
single := sj.Get(key).Interface()
if single == nil {
return nil
}
return []interface{}{single}
}