You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-07-15 01:44:22 +02:00
Azure token refresh (#754)
* Implement azure token refresh Based on original PR https://github.com/oauth2-proxy/oauth2-proxy/pull/278 * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> * Set CreatedAt to Now() on token refresh Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
This commit is contained in:
@ -23,9 +23,16 @@
|
|||||||
via the login url. If this option was used in the past, behavior will change with this release as it will
|
via the login url. If this option was used in the past, behavior will change with this release as it will
|
||||||
affect the tokens returned by Azure. In the past, the tokens were always for `https://graph.microsoft.com` (the default)
|
affect the tokens returned by Azure. In the past, the tokens were always for `https://graph.microsoft.com` (the default)
|
||||||
and will now be for the configured resource (if it exists, otherwise it will run into errors)
|
and will now be for the configured resource (if it exists, otherwise it will run into errors)
|
||||||
|
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) The Azure provider now has token refresh functionality implemented. This means that there won't
|
||||||
|
be any redirects in the browser anymore when tokens expire, but instead a token refresh is initiated
|
||||||
|
in the background, which leads to new tokens being returned in the cookies.
|
||||||
|
- Please note that `--cookie-refresh` must be 0 (the default) or equal to the token lifespan configured in Azure AD to make
|
||||||
|
Azure token refresh reliable. Setting this value to 0 means that it relies on the provider implementation
|
||||||
|
to decide if a refresh is required.
|
||||||
|
|
||||||
## Changes since v6.1.1
|
## Changes since v6.1.1
|
||||||
|
|
||||||
|
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock)
|
||||||
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
|
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
|
||||||
- [#796](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves)
|
- [#796](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves)
|
||||||
- [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed)
|
- [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
@ -74,6 +75,9 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
|
|||||||
if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
|
if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
|
||||||
p.ProtectedResource = azureDefaultProtectResourceURL
|
p.ProtectedResource = azureDefaultProtectResourceURL
|
||||||
}
|
}
|
||||||
|
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
|
||||||
|
p.ValidateURL = p.ProfileURL
|
||||||
|
}
|
||||||
|
|
||||||
return &AzureProvider{
|
return &AzureProvider{
|
||||||
ProviderData: p,
|
ProviderData: p,
|
||||||
@ -103,6 +107,7 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err = errors.New("missing code")
|
||||||
@ -123,6 +128,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
params.Add("resource", p.ProtectedResource.String())
|
params.Add("resource", p.ProtectedResource.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// blindly try json and x-www-form-urlencoded
|
||||||
var jsonResponse struct {
|
var jsonResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
@ -151,6 +157,61 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
|
// RefreshToken to fetch a new ID token if required
|
||||||
|
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
origExpiration := s.ExpiresOn
|
||||||
|
|
||||||
|
err := p.redeemRefreshToken(ctx, s)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("client_id", p.ClientID)
|
||||||
|
params.Add("client_secret", p.ClientSecret)
|
||||||
|
params.Add("refresh_token", s.RefreshToken)
|
||||||
|
params.Add("grant_type", "refresh_token")
|
||||||
|
|
||||||
|
var jsonResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresOn int64 `json:"expires_on,string"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = requests.New(p.RedeemURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithMethod("POST").
|
||||||
|
WithBody(bytes.NewBufferString(params.Encode())).
|
||||||
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&jsonResponse)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
||||||
|
s.AccessToken = jsonResponse.AccessToken
|
||||||
|
s.IDToken = jsonResponse.IDToken
|
||||||
|
s.RefreshToken = jsonResponse.RefreshToken
|
||||||
|
s.CreatedAt = &now
|
||||||
|
s.ExpiresOn = &expires
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeAzureHeader(accessToken string) http.Header {
|
func makeAzureHeader(accessToken string) http.Header {
|
||||||
@ -219,3 +280,8 @@ func (p *AzureProvider) GetLoginURL(redirectURI, state string) string {
|
|||||||
a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams)
|
a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams)
|
||||||
return a.String()
|
return a.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateSessionState validates the AccessToken
|
||||||
|
func (p *AzureProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
|
return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken))
|
||||||
|
}
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@ -42,7 +44,7 @@ func TestNewAzureProvider(t *testing.T) {
|
|||||||
g.Expect(providerData.LoginURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/authorize"))
|
g.Expect(providerData.LoginURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/authorize"))
|
||||||
g.Expect(providerData.RedeemURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/token"))
|
g.Expect(providerData.RedeemURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/token"))
|
||||||
g.Expect(providerData.ProfileURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me"))
|
g.Expect(providerData.ProfileURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me"))
|
||||||
g.Expect(providerData.ValidateURL.String()).To(Equal(""))
|
g.Expect(providerData.ValidateURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me"))
|
||||||
g.Expect(providerData.Scope).To(Equal("openid"))
|
g.Expect(providerData.Scope).To(Equal("openid"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +99,7 @@ func TestAzureSetTenant(t *testing.T) {
|
|||||||
p.Data().ProfileURL.String())
|
p.Data().ProfileURL.String())
|
||||||
assert.Equal(t, "https://graph.microsoft.com",
|
assert.Equal(t, "https://graph.microsoft.com",
|
||||||
p.Data().ProtectedResource.String())
|
p.Data().ProtectedResource.String())
|
||||||
assert.Equal(t, "", p.Data().ValidateURL.String())
|
assert.Equal(t, "https://graph.microsoft.com/v1.0/me", p.Data().ValidateURL.String())
|
||||||
assert.Equal(t, "openid", p.Data().Scope)
|
assert.Equal(t, "openid", p.Data().Scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -220,3 +222,47 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
|
|||||||
result := p.GetLoginURL("https://my.test.app/oauth", "")
|
result := p.GetLoginURL("https://my.test.app/oauth", "")
|
||||||
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
|
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAzureProviderGetsTokensInRedeem(t *testing.T) {
|
||||||
|
b := testAzureBackend(`{ "access_token": "some_access_token", "refresh_token": "some_refresh_token", "expires_on": "1136239445", "id_token": "some_id_token" }`)
|
||||||
|
defer b.Close()
|
||||||
|
timestamp, _ := time.Parse(time.RFC3339, "2006-01-02T22:04:05Z")
|
||||||
|
bURL, _ := url.Parse(b.URL)
|
||||||
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
|
session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, session, nil)
|
||||||
|
assert.Equal(t, "some_access_token", session.AccessToken)
|
||||||
|
assert.Equal(t, "some_refresh_token", session.RefreshToken)
|
||||||
|
assert.Equal(t, "some_id_token", session.IDToken)
|
||||||
|
assert.Equal(t, timestamp, session.ExpiresOn.UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
||||||
|
p := testAzureProvider("")
|
||||||
|
|
||||||
|
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||||
|
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||||
|
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.False(t, refreshNeeded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
||||||
|
b := testAzureBackend(`{ "access_token": "new_some_access_token", "refresh_token": "new_some_refresh_token", "expires_on": "32693148245", "id_token": "new_some_id_token" }`)
|
||||||
|
defer b.Close()
|
||||||
|
timestamp, _ := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z")
|
||||||
|
bURL, _ := url.Parse(b.URL)
|
||||||
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
|
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||||
|
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||||
|
_, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, session, nil)
|
||||||
|
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||||
|
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
||||||
|
assert.Equal(t, "new_some_id_token", session.IDToken)
|
||||||
|
assert.Equal(t, timestamp, session.ExpiresOn.UTC())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user