1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-23 21:50:48 +02:00

Azure token refresh ()

* Implement azure token refresh

Based on original PR https://github.com/oauth2-proxy/oauth2-proxy/pull/278

* Update CHANGELOG.md

* Apply suggestions from code review

Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>

* Set CreatedAt to Now() on token refresh

Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
This commit is contained in:
Alexander Block 2020-11-04 20:25:59 +01:00 committed by GitHub
parent 65016c8da1
commit 0e119d7c84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 121 additions and 2 deletions

@ -23,9 +23,16 @@
via the login url. If this option was used in the past, behavior will change with this release as it will
affect the tokens returned by Azure. In the past, the tokens were always for `https://graph.microsoft.com` (the default)
and will now be for the configured resource (if it exists, otherwise it will run into errors)
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) The Azure provider now has token refresh functionality implemented. This means that there won't
be any redirects in the browser anymore when tokens expire, but instead a token refresh is initiated
in the background, which leads to new tokens being returned in the cookies.
- Please note that `--cookie-refresh` must be 0 (the default) or equal to the token lifespan configured in Azure AD to make
Azure token refresh reliable. Setting this value to 0 means that it relies on the provider implementation
to decide if a refresh is required.
## Changes since v6.1.1
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock)
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
- [#796](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves)
- [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed)

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"
@ -74,6 +75,9 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
p.ProtectedResource = azureDefaultProtectResourceURL
}
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
p.ValidateURL = p.ProfileURL
}
return &AzureProvider{
ProviderData: p,
@ -103,6 +107,7 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
}
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" {
err = errors.New("missing code")
@ -123,6 +128,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String())
}
// blindly try json and x-www-form-urlencoded
var jsonResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
@ -151,6 +157,61 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
RefreshToken: jsonResponse.RefreshToken,
}
return
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}
origExpiration := s.ExpiresOn
err := p.redeemRefreshToken(ctx, s)
if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}
fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
return true, nil
}
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("refresh_token", s.RefreshToken)
params.Add("grant_type", "refresh_token")
var jsonResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresOn int64 `json:"expires_on,string"`
IDToken string `json:"id_token"`
}
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse)
if err != nil {
return
}
now := time.Now()
expires := time.Unix(jsonResponse.ExpiresOn, 0)
s.AccessToken = jsonResponse.AccessToken
s.IDToken = jsonResponse.IDToken
s.RefreshToken = jsonResponse.RefreshToken
s.CreatedAt = &now
s.ExpiresOn = &expires
return
}
func makeAzureHeader(accessToken string) http.Header {
@ -219,3 +280,8 @@ func (p *AzureProvider) GetLoginURL(redirectURI, state string) string {
a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams)
return a.String()
}
// ValidateSessionState validates the AccessToken
func (p *AzureProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken))
}

@ -8,6 +8,8 @@ import (
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
)
@ -42,7 +44,7 @@ func TestNewAzureProvider(t *testing.T) {
g.Expect(providerData.LoginURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/authorize"))
g.Expect(providerData.RedeemURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/token"))
g.Expect(providerData.ProfileURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me"))
g.Expect(providerData.ValidateURL.String()).To(Equal(""))
g.Expect(providerData.ValidateURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me"))
g.Expect(providerData.Scope).To(Equal("openid"))
}
@ -97,7 +99,7 @@ func TestAzureSetTenant(t *testing.T) {
p.Data().ProfileURL.String())
assert.Equal(t, "https://graph.microsoft.com",
p.Data().ProtectedResource.String())
assert.Equal(t, "", p.Data().ValidateURL.String())
assert.Equal(t, "https://graph.microsoft.com/v1.0/me", p.Data().ValidateURL.String())
assert.Equal(t, "openid", p.Data().Scope)
}
@ -220,3 +222,47 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
result := p.GetLoginURL("https://my.test.app/oauth", "")
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
}
func TestAzureProviderGetsTokensInRedeem(t *testing.T) {
b := testAzureBackend(`{ "access_token": "some_access_token", "refresh_token": "some_refresh_token", "expires_on": "1136239445", "id_token": "some_id_token" }`)
defer b.Close()
timestamp, _ := time.Parse(time.RFC3339, "2006-01-02T22:04:05Z")
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil)
assert.Equal(t, "some_access_token", session.AccessToken)
assert.Equal(t, "some_refresh_token", session.RefreshToken)
assert.Equal(t, "some_id_token", session.IDToken)
assert.Equal(t, timestamp, session.ExpiresOn.UTC())
}
func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
p := testAzureProvider("")
expires := time.Now().Add(time.Duration(1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
assert.Equal(t, nil, err)
assert.False(t, refreshNeeded)
}
func TestAzureProviderRefreshWhenExpired(t *testing.T) {
b := testAzureBackend(`{ "access_token": "new_some_access_token", "refresh_token": "new_some_refresh_token", "expires_on": "32693148245", "id_token": "new_some_id_token" }`)
defer b.Close()
timestamp, _ := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z")
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
expires := time.Now().Add(time.Duration(-1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
_, err := p.RefreshSessionIfNeeded(context.Background(), session)
assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil)
assert.Equal(t, "new_some_access_token", session.AccessToken)
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
assert.Equal(t, "new_some_id_token", session.IDToken)
assert.Equal(t, timestamp, session.ExpiresOn.UTC())
}