2021-06-26 11:48:49 +01:00
|
|
|
package util
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/base64"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/bitly/go-simplejson"
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
|
|
|
"github.com/spf13/cast"
|
|
|
|
)
|
|
|
|
|
|
|
|
// ClaimExtractor looks up claim values. Claims carried by the ID Token are
// consulted first; a claim missing there may instead be looked up via the
// profile URL.
type ClaimExtractor interface {
	// GetClaim looks up the named claim. It returns the raw value, whether
	// the claim was found, and any error hit while performing the lookup.
	GetClaim(claim string) (interface{}, bool, error)

	// GetClaimInto looks up the named claim and writes its value into dst.
	// It reports whether the claim was found and any error hit on the way.
	GetClaimInto(claim string, dst interface{}) (bool, error)
}
|
|
|
|
|
|
|
|
// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
|
|
|
|
// If needed, it will use the profile URL to look up a claim if it isn't present
|
|
|
|
// within the ID Token.
|
|
|
|
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
|
|
|
|
payload, err := parseJWT(idToken)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to parse ID Token: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
tokenClaims, err := simplejson.NewJson(payload)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &claimExtractor{
|
|
|
|
ctx: ctx,
|
|
|
|
profileURL: profileURL,
|
|
|
|
requestHeaders: profileRequestHeaders,
|
|
|
|
tokenClaims: tokenClaims,
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// claimExtractor implements the ClaimExtractor interface
|
|
|
|
type claimExtractor struct {
|
|
|
|
profileURL *url.URL
|
|
|
|
ctx context.Context
|
|
|
|
requestHeaders map[string][]string
|
|
|
|
tokenClaims *simplejson.Json
|
|
|
|
profileClaims *simplejson.Json
|
|
|
|
}
|
|
|
|
|
|
|
|
// GetClaim will return the value claim if it exists.
|
|
|
|
// It will only return an error if the profile URL needs to be fetched due to
|
|
|
|
// the claim not being present in the ID Token.
|
|
|
|
func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
|
|
|
|
if claim == "" {
|
|
|
|
return nil, false, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
if value := getClaimFrom(claim, c.tokenClaims); value != nil {
|
|
|
|
return value, true, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.profileClaims == nil {
|
|
|
|
profileClaims, err := c.loadProfileClaims()
|
|
|
|
if err != nil {
|
|
|
|
return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
c.profileClaims = profileClaims
|
|
|
|
}
|
|
|
|
|
|
|
|
if value := getClaimFrom(claim, c.profileClaims); value != nil {
|
|
|
|
return value, true, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, false, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// loadProfileClaims will fetch the profileURL using the provided headers as
|
|
|
|
// authentication.
|
|
|
|
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
|
2022-02-18 14:09:07 +00:00
|
|
|
if c.profileURL == nil || c.profileURL.String() == "" || c.requestHeaders == nil {
|
2021-06-26 11:48:49 +01:00
|
|
|
// When no profileURL is set, we return a non-empty map so that
|
|
|
|
// we don't attempt to populate the profile claims again.
|
|
|
|
// If there are no headers, the request would be unauthorized so we also skip
|
|
|
|
// in this case too.
|
|
|
|
return simplejson.New(), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
claims, err := requests.New(c.profileURL.String()).
|
|
|
|
WithContext(c.ctx).
|
|
|
|
WithHeaders(c.requestHeaders).
|
|
|
|
Do().
|
|
|
|
UnmarshalJSON()
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error making request to profile URL: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return claims, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// GetClaimInto loads a claim and places it into the destination interface.
|
|
|
|
// This will attempt to coerce the claim into the specified type.
|
|
|
|
// If it cannot be coerced, an error may be returned.
|
|
|
|
func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
|
|
|
|
value, exists, err := c.GetClaim(claim)
|
|
|
|
if err != nil {
|
|
|
|
return false, fmt.Errorf("could not get claim %q: %v", claim, err)
|
|
|
|
}
|
|
|
|
if !exists {
|
|
|
|
return false, nil
|
|
|
|
}
|
|
|
|
if err := coerceClaim(value, dst); err != nil {
|
|
|
|
return false, fmt.Errorf("could no coerce claim: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return true, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// parseJWT returns the base64url-decoded payload (the second dot-separated
// segment) of a compact JWT. It does not verify the signature.
//
// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
// We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
//
// Only the presence of a second segment is checked. Two segments are accepted,
// even though the error text (kept identical to go-oidc) mentions 3 parts. The
// payload must be unpadded base64url. Decoding errors are wrapped with %w, so
// callers can match base64.CorruptInputError with errors.As.
func parseJWT(p string) ([]byte, error) {
	parts := strings.Split(p, ".")
	if len(parts) < 2 {
		return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
	}

	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("oidc: malformed jwt payload: %w", err)
	}

	return payload, nil
}
|
|
|
|
|
|
|
|
// getClaimFrom gets a claim from a Json object.
|
|
|
|
// It can accept either a single claim name or a json path.
|
|
|
|
// Paths with indexes are not supported.
|
|
|
|
func getClaimFrom(claim string, src *simplejson.Json) interface{} {
|
|
|
|
claimParts := strings.Split(claim, ".")
|
|
|
|
return src.GetPath(claimParts...).Interface()
|
|
|
|
}
|
|
|
|
|
|
|
|
// coerceClaim tries to convert the value into the destination interface type.
|
|
|
|
// If it can convert the value, it will then store the value in the destination
|
|
|
|
// interface.
|
|
|
|
func coerceClaim(value, dst interface{}) error {
|
|
|
|
switch d := dst.(type) {
|
|
|
|
case *string:
|
|
|
|
str, err := toString(value)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("could not convert value to string: %v", err)
|
|
|
|
}
|
|
|
|
*d = str
|
|
|
|
case *[]string:
|
|
|
|
strSlice, err := toStringSlice(value)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("could not convert value to string slice: %v", err)
|
|
|
|
}
|
|
|
|
*d = strSlice
|
|
|
|
case *bool:
|
|
|
|
*d = cast.ToBool(value)
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("unknown type for destination: %T", dst)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// toStringSlice converts an interface (either a slice or single value) into
|
|
|
|
// a slice of strings.
|
|
|
|
func toStringSlice(value interface{}) ([]string, error) {
|
|
|
|
var sliceValues []interface{}
|
|
|
|
switch v := value.(type) {
|
|
|
|
case []interface{}:
|
|
|
|
sliceValues = v
|
|
|
|
case interface{}:
|
|
|
|
sliceValues = []interface{}{v}
|
|
|
|
default:
|
|
|
|
sliceValues = cast.ToSlice(value)
|
|
|
|
}
|
|
|
|
|
|
|
|
out := []string{}
|
|
|
|
for _, v := range sliceValues {
|
|
|
|
str, err := toString(v)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
|
|
|
|
}
|
|
|
|
out = append(out, str)
|
|
|
|
}
|
|
|
|
return out, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// toString coerces a value into a string.
|
|
|
|
// If it is non-string, marshal it into JSON.
|
|
|
|
func toString(value interface{}) (string, error) {
|
|
|
|
if str, err := cast.ToStringE(value); err == nil {
|
|
|
|
return str, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
jsonStr, err := json.Marshal(value)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
return string(jsonStr), nil
|
|
|
|
}
|