mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
286 lines
7.6 KiB
Go
286 lines
7.6 KiB
Go
package providers
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/bitly/go-simplejson"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
|
)
|
|
|
|
// AzureProvider represents an Azure based Identity Provider
|
|
type AzureProvider struct {
|
|
*ProviderData
|
|
Tenant string
|
|
}
|
|
|
|
var _ Provider = (*AzureProvider)(nil)
|
|
|
|
// azureProviderName is the provider name handed to setProviderDefaults.
const azureProviderName = "Azure"

// azureDefaultScope is the default scope handed to setProviderDefaults.
const azureDefaultScope = "openid"
|
|
|
|
// The Azure default endpoints below are kept as pre-built *url.URL values
// (rather than strings) so they can be used directly as provider defaults.
// The equivalent string form is noted on each one.

// azureDefaultLoginURL is https://login.microsoftonline.com/common/oauth2/authorize.
var azureDefaultLoginURL = &url.URL{
	Scheme: "https",
	Host:   "login.microsoftonline.com",
	Path:   "/common/oauth2/authorize",
}

// azureDefaultRedeemURL is https://login.microsoftonline.com/common/oauth2/token.
var azureDefaultRedeemURL = &url.URL{
	Scheme: "https",
	Host:   "login.microsoftonline.com",
	Path:   "/common/oauth2/token",
}

// azureDefaultProfileURL is https://graph.microsoft.com/v1.0/me, the profile
// endpoint queried for the user's email address.
var azureDefaultProfileURL = &url.URL{
	Scheme: "https",
	Host:   "graph.microsoft.com",
	Path:   "/v1.0/me",
}

// azureDefaultProtectResourceURL is https://graph.microsoft.com, used as the
// default "resource" parameter for login and redeem requests.
var azureDefaultProtectResourceURL = &url.URL{
	Scheme: "https",
	Host:   "graph.microsoft.com",
}
|
|
|
|
// NewAzureProvider initiates a new AzureProvider
|
|
func NewAzureProvider(p *ProviderData) *AzureProvider {
|
|
p.setProviderDefaults(providerDefaults{
|
|
name: azureProviderName,
|
|
loginURL: azureDefaultLoginURL,
|
|
redeemURL: azureDefaultRedeemURL,
|
|
profileURL: azureDefaultProfileURL,
|
|
validateURL: nil,
|
|
scope: azureDefaultScope,
|
|
})
|
|
|
|
if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
|
|
p.ProtectedResource = azureDefaultProtectResourceURL
|
|
}
|
|
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
|
|
p.ValidateURL = p.ProfileURL
|
|
}
|
|
|
|
return &AzureProvider{
|
|
ProviderData: p,
|
|
Tenant: "common",
|
|
}
|
|
}
|
|
|
|
// Configure defaults the AzureProvider configuration options
|
|
func (p *AzureProvider) Configure(tenant string) {
|
|
if tenant == "" || tenant == "common" {
|
|
// tenant is empty or default, remain on the default "common" tenant
|
|
return
|
|
}
|
|
|
|
// Specific tennant specified, override the Login and RedeemURLs
|
|
p.Tenant = tenant
|
|
overrideTenantURL(p.LoginURL, azureDefaultLoginURL, tenant, "authorize")
|
|
overrideTenantURL(p.RedeemURL, azureDefaultRedeemURL, tenant, "token")
|
|
}
|
|
|
|
// overrideTenantURL rewrites *current in place to the tenant-specific
// endpoint https://login.microsoftonline.com/<tenant>/oauth2/<path>, but
// only when current is empty or still equal to defaultURL. An endpoint that
// was explicitly configured to something else is left alone.
//
// A nil current cannot be updated through this signature, so it is ignored
// instead of being dereferenced. A nil defaultURL matches only an empty
// current. Because the update is made through the pointer, every holder of
// the same *url.URL observes the change.
func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
	if current == nil {
		return
	}

	currentStr := current.String()
	isDefault := defaultURL != nil && currentStr == defaultURL.String()
	if currentStr != "" && !isDefault {
		// Explicitly configured, non-default endpoint: keep it.
		return
	}

	*current = url.URL{
		Scheme: "https",
		Host:   "login.microsoftonline.com",
		Path:   "/" + tenant + "/oauth2/" + path,
	}
}
|
|
|
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
|
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
|
if code == "" {
|
|
return nil, ErrMissingCode
|
|
}
|
|
clientSecret, err := p.GetClientSecret()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
params := url.Values{}
|
|
params.Add("redirect_uri", redirectURL)
|
|
params.Add("client_id", p.ClientID)
|
|
params.Add("client_secret", clientSecret)
|
|
params.Add("code", code)
|
|
params.Add("grant_type", "authorization_code")
|
|
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
|
|
params.Add("resource", p.ProtectedResource.String())
|
|
}
|
|
|
|
// blindly try json and x-www-form-urlencoded
|
|
var jsonResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresOn int64 `json:"expires_on,string"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
|
|
err = requests.New(p.RedeemURL.String()).
|
|
WithContext(ctx).
|
|
WithMethod("POST").
|
|
WithBody(bytes.NewBufferString(params.Encode())).
|
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
|
Do().
|
|
UnmarshalInto(&jsonResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
created := time.Now()
|
|
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
|
|
|
return &sessions.SessionState{
|
|
AccessToken: jsonResponse.AccessToken,
|
|
IDToken: jsonResponse.IDToken,
|
|
CreatedAt: &created,
|
|
ExpiresOn: &expires,
|
|
RefreshToken: jsonResponse.RefreshToken,
|
|
}, nil
|
|
}
|
|
|
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
|
// RefreshToken to fetch a new ID token if required
|
|
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
|
return false, nil
|
|
}
|
|
|
|
origExpiration := s.ExpiresOn
|
|
|
|
err := p.redeemRefreshToken(ctx, s)
|
|
if err != nil {
|
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
|
}
|
|
|
|
fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
|
|
return true, nil
|
|
}
|
|
|
|
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
|
params := url.Values{}
|
|
params.Add("client_id", p.ClientID)
|
|
params.Add("client_secret", p.ClientSecret)
|
|
params.Add("refresh_token", s.RefreshToken)
|
|
params.Add("grant_type", "refresh_token")
|
|
|
|
var jsonResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresOn int64 `json:"expires_on,string"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
|
|
err = requests.New(p.RedeemURL.String()).
|
|
WithContext(ctx).
|
|
WithMethod("POST").
|
|
WithBody(bytes.NewBufferString(params.Encode())).
|
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
|
Do().
|
|
UnmarshalInto(&jsonResponse)
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
|
s.AccessToken = jsonResponse.AccessToken
|
|
s.IDToken = jsonResponse.IDToken
|
|
s.RefreshToken = jsonResponse.RefreshToken
|
|
s.CreatedAt = &now
|
|
s.ExpiresOn = &expires
|
|
return
|
|
}
|
|
|
|
func makeAzureHeader(accessToken string) http.Header {
|
|
return makeAuthorizationHeader(tokenTypeBearer, accessToken, nil)
|
|
}
|
|
|
|
func getEmailFromJSON(json *simplejson.Json) (string, error) {
|
|
var email string
|
|
var err error
|
|
|
|
email, err = json.Get("mail").String()
|
|
|
|
if err != nil || email == "" {
|
|
otherMails, otherMailsErr := json.Get("otherMails").Array()
|
|
if len(otherMails) > 0 {
|
|
email = otherMails[0].(string)
|
|
}
|
|
err = otherMailsErr
|
|
}
|
|
|
|
return email, err
|
|
}
|
|
|
|
// GetEmailAddress returns the Account email address
|
|
func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
|
var email string
|
|
var err error
|
|
|
|
if s.AccessToken == "" {
|
|
return "", errors.New("missing access token")
|
|
}
|
|
|
|
json, err := requests.New(p.ProfileURL.String()).
|
|
WithContext(ctx).
|
|
WithHeaders(makeAzureHeader(s.AccessToken)).
|
|
Do().
|
|
UnmarshalJSON()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
email, err = getEmailFromJSON(json)
|
|
if err == nil && email != "" {
|
|
return email, err
|
|
}
|
|
|
|
email, err = json.Get("userPrincipalName").String()
|
|
if err != nil {
|
|
logger.Errorf("failed making request %s", err)
|
|
return "", err
|
|
}
|
|
|
|
if email == "" {
|
|
logger.Errorf("failed to get email address")
|
|
return "", err
|
|
}
|
|
|
|
return email, err
|
|
}
|
|
|
|
func (p *AzureProvider) GetLoginURL(redirectURI, state string) string {
|
|
extraParams := url.Values{}
|
|
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
|
|
extraParams.Add("resource", p.ProtectedResource.String())
|
|
}
|
|
a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams)
|
|
return a.String()
|
|
}
|
|
|
|
// ValidateSession validates the AccessToken
|
|
func (p *AzureProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
|
return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken))
|
|
}
|