1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-25 22:00:56 +02:00
oauth2-proxy/providers/logingov.go
2020-07-19 18:34:55 +01:00

264 lines
6.8 KiB
Go

package providers
import (
	"bytes"
	"context"
	crand "crypto/rand"
	"crypto/rsa"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net/http"
	"net/url"
	"time"

	"github.com/dgrijalva/jwt-go"
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
	"gopkg.in/square/go-jose.v2"
)
// LoginGovProvider represents an OIDC based Identity Provider.
// It embeds *ProviderData and adds the login.gov specific values used by
// NewLoginGovProvider, GetLoginURL, Redeem and checkNonce.
type LoginGovProvider struct {
*ProviderData
// TODO (@timothy-spencer): Ideally, the nonce would be in the session state, but the session state
// is created only upon code redemption, not during the auth, when this must be supplied.
// Nonce is generated once in NewLoginGovProvider, sent as the "nonce" query
// parameter by GetLoginURL, and compared with the id_token's nonce claim by
// checkNonce.
Nonce string
// JWTKey is the RSA private key Redeem uses to sign the client assertion JWT.
JWTKey *rsa.PrivateKey
// PubJWKURL is the URL checkNonce downloads the JSON Web Key Set from in
// order to verify the id_token.
PubJWKURL *url.URL
}
// Compile-time check that *LoginGovProvider satisfies the Provider interface.
var _ Provider = (*LoginGovProvider)(nil)
// letters is the alphabet that nonces and JWT IDs are drawn from.
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// randSeq returns a string of n characters, each drawn uniformly from letters.
//
// The result is used as the OIDC nonce and as the ID of the client assertion
// JWT, so it has to be unpredictable. Characters are therefore taken from
// crypto/rand rather than from the unseeded math/rand generator. Should the
// system randomness source fail, the remaining characters are taken from
// math/rand as a last resort, because this function has no way to report an
// error to its callers.
//
// As with make, a negative n panics.
func randSeq(n int) string {
	b := make([]rune, n)

	// Bytes at or above limit are discarded so that the modulo below does not
	// favour the first few letters (256 is not a multiple of len(letters)).
	limit := 256 - 256%len(letters)

	// Roughly one byte in five is rejected, so ask for a little more than n.
	buf := make([]byte, n+n/4+8)

	filled := 0
	for filled < n {
		if _, err := crand.Read(buf); err != nil {
			// Last resort only: better a weaker value than a panic here.
			for ; filled < n; filled++ {
				b[filled] = letters[rand.Intn(len(letters))]
			}
			break
		}
		for _, c := range buf {
			if filled == n {
				break
			}
			if int(c) >= limit {
				continue
			}
			b[filled] = letters[int(c)%len(letters)]
			filled++
		}
	}
	return string(b)
}
const (
// loginGovProviderName is the provider name NewLoginGovProvider passes to
// setProviderDefaults.
loginGovProviderName = "login.gov"
// loginGovDefaultScope is the scope NewLoginGovProvider passes to
// setProviderDefaults.
loginGovDefaultScope = "email openid"
)
var (
// Default Login URL for LoginGov.
// Pre-parsed URL of https://secure.login.gov/openid_connect/authorize.
loginGovDefaultLoginURL = &url.URL{
Scheme: "https",
Host: "secure.login.gov",
Path: "/openid_connect/authorize",
}
// Default Redeem URL for LoginGov.
// Pre-parsed URL of https://secure.login.gov/api/openid_connect/token.
loginGovDefaultRedeemURL = &url.URL{
Scheme: "https",
Host: "secure.login.gov",
Path: "/api/openid_connect/token",
}
// Default Profile URL for LoginGov.
// Pre-parsed URL of https://secure.login.gov/api/openid_connect/userinfo.
loginGovDefaultProfileURL = &url.URL{
Scheme: "https",
Host: "secure.login.gov",
Path: "/api/openid_connect/userinfo",
}
)
// NewLoginGovProvider wraps p in a LoginGovProvider. The login.gov name,
// login/redeem/profile URLs and scope are handed to p.setProviderDefaults,
// and a fresh 32-character nonce is generated for the returned provider.
// JWTKey and PubJWKURL are left unset.
func NewLoginGovProvider(p *ProviderData) *LoginGovProvider {
	// validateURL is deliberately left at its zero value (nil).
	defaults := providerDefaults{
		name:       loginGovProviderName,
		loginURL:   loginGovDefaultLoginURL,
		redeemURL:  loginGovDefaultRedeemURL,
		profileURL: loginGovDefaultProfileURL,
		scope:      loginGovDefaultScope,
	}
	p.setProviderDefaults(defaults)

	provider := &LoginGovProvider{ProviderData: p}
	provider.Nonce = randSeq(32)
	return provider
}
// loginGovCustomClaims is the claim set the id_token is decoded into by
// checkNonce: the fields listed here, keyed by their JSON names, plus the
// embedded jwt.StandardClaims. checkNonce only inspects Nonce.
type loginGovCustomClaims struct {
Acr string `json:"acr"`
// Nonce is compared with the provider's Nonce in checkNonce.
Nonce string `json:"nonce"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Birthdate string `json:"birthdate"`
AtHash string `json:"at_hash"`
CHash string `json:"c_hash"`
jwt.StandardClaims
}
// checkNonce parses and verifies idToken and confirms that its nonce claim
// equals p.Nonce.
//
// The verification key is the first key of the JSON Web Key Set downloaded
// from p.PubJWKURL. A non-nil error is returned when p.PubJWKURL is unset,
// when the key set cannot be fetched or decoded, when it contains no keys,
// when the token fails to parse or verify, or when the nonce does not match.
func checkNonce(idToken string, p *LoginGovProvider) (err error) {
	if p.PubJWKURL == nil {
		return errors.New("missing public JWK URL")
	}
	jwkURL := p.PubJWKURL.String()

	token, err := jwt.ParseWithClaims(idToken, &loginGovCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
		resp, getErr := http.Get(jwkURL)
		if getErr != nil {
			return nil, getErr
		}
		// Deferred so the body is released on every path, including the
		// non-200 one.
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("got %d from %q", resp.StatusCode, jwkURL)
		}

		body, readErr := ioutil.ReadAll(resp.Body)
		if readErr != nil {
			return nil, readErr
		}

		var pubkeys jose.JSONWebKeySet
		if jsonErr := json.Unmarshal(body, &pubkeys); jsonErr != nil {
			return nil, jsonErr
		}
		// Guard the index below: an empty set would otherwise panic.
		if len(pubkeys.Keys) == 0 {
			return nil, fmt.Errorf("no keys found in JWK set from %q", jwkURL)
		}
		return pubkeys.Keys[0].Key, nil
	})
	if err != nil {
		return err
	}

	claims, ok := token.Claims.(*loginGovCustomClaims)
	if !ok {
		return errors.New("unexpected claims type in id_token")
	}
	if claims.Nonce != p.Nonce {
		return fmt.Errorf("nonce validation failed")
	}
	return nil
}
// emailFromUserInfo queries userInfoEndpoint with accessToken as a bearer
// token and returns the email address found in the response. An error is
// returned if the request or decoding fails, if the email is empty, or if
// the email is not marked as verified.
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) {
	// Only the two attributes needed to vouch for the address are decoded.
	var userInfo struct {
		Email         string `json:"email"`
		EmailVerified bool   `json:"email_verified"`
	}

	request := requests.New(userInfoEndpoint).
		WithContext(ctx).
		SetHeader("Authorization", "Bearer "+accessToken)
	if err := request.Do().UnmarshalInto(&userInfo); err != nil {
		return "", err
	}

	switch {
	case userInfo.Email == "":
		return "", fmt.Errorf("missing email")
	case !userInfo.EmailVerified:
		return "", fmt.Errorf("email %s not listed as verified", userInfo.Email)
	}
	return userInfo.Email, nil
}
// Redeem trades the authorization code for tokens at p.RedeemURL and turns
// them into a session.
//
// The client authenticates with an RS256-signed JWT assertion built from
// p.ClientID and p.JWTKey. After the token response is decoded, the id_token
// is checked with checkNonce and the email is fetched from p.ProfileURL with
// the access token. The redirectURL argument is accepted but not used here.
// On any failure a nil session and the error are returned.
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
	if code == "" {
		return nil, errors.New("missing code")
	}

	// Client assertion: issued by and about this client, addressed to the
	// token endpoint, valid for five minutes, with a random ID.
	assertion := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), &jwt.StandardClaims{
		Issuer:    p.ClientID,
		Subject:   p.ClientID,
		Audience:  p.RedeemURL.String(),
		ExpiresAt: time.Now().Add(5 * time.Minute).Unix(),
		Id:        randSeq(32),
	})
	signedAssertion, err := assertion.SignedString(p.JWTKey)
	if err != nil {
		return nil, err
	}

	form := url.Values{
		"client_assertion":      {signedAssertion},
		"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
		"code":                  {code},
		"grant_type":            {"authorization_code"},
	}

	// Shape of the JSON body returned by the token endpoint.
	var tokenResponse struct {
		AccessToken string `json:"access_token"`
		IDToken     string `json:"id_token"`
		TokenType   string `json:"token_type"`
		ExpiresIn   int64  `json:"expires_in"`
	}
	err = requests.New(p.RedeemURL.String()).
		WithContext(ctx).
		WithMethod("POST").
		WithBody(bytes.NewBufferString(form.Encode())).
		SetHeader("Content-Type", "application/x-www-form-urlencoded").
		Do().
		UnmarshalInto(&tokenResponse)
	if err != nil {
		return nil, err
	}

	// The id_token must verify and carry the nonce this provider sent.
	if err = checkNonce(tokenResponse.IDToken, p); err != nil {
		return nil, err
	}

	email, err := emailFromUserInfo(ctx, tokenResponse.AccessToken, p.ProfileURL.String())
	if err != nil {
		return nil, err
	}

	created := time.Now()
	lifetime := time.Duration(tokenResponse.ExpiresIn) * time.Second
	expires := time.Now().Add(lifetime).Truncate(time.Second)

	return &sessions.SessionState{
		AccessToken: tokenResponse.AccessToken,
		IDToken:     tokenResponse.IDToken,
		CreatedAt:   &created,
		ExpiresOn:   &expires,
		Email:       email,
	}, nil
}
// GetLoginURL returns the authorization URL for a login attempt: a copy of
// p.LoginURL whose query carries the redirect URI, approval prompt, scope,
// client ID, response type "code", state, acr_values and the provider nonce.
// When p.AcrValues is empty, the LOA1 assurance URI is sent as acr_values.
func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string {
	// Operate on a copy so the configured LoginURL itself is never modified.
	loginURL := *p.LoginURL

	acrValues := p.AcrValues
	if acrValues == "" {
		acrValues = "http://idmanagement.gov/ns/assurance/loa/1"
	}

	// A parse error is ignored here; whatever pairs were parsed are kept.
	query, _ := url.ParseQuery(loginURL.RawQuery)

	// These keys replace any value already present in the configured query.
	query.Set("redirect_uri", redirectURI)
	query.Set("approval_prompt", p.ApprovalPrompt)
	query.Set("client_id", p.ClientID)
	query.Set("response_type", "code")

	// These keys are appended to any value already present.
	query.Add("scope", p.Scope)
	query.Add("state", state)
	query.Add("acr_values", acrValues)
	query.Add("nonce", p.Nonce)

	loginURL.RawQuery = query.Encode()
	return loginURL.String()
}