From 0e119d7c84b165f04b31bb74fbfc2c96f8269a02 Mon Sep 17 00:00:00 2001 From: Alexander Block Date: Wed, 4 Nov 2020 20:25:59 +0100 Subject: [PATCH] Azure token refresh (#754) * Implement azure token refresh Based on original PR https://github.com/oauth2-proxy/oauth2-proxy/pull/278 * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: Joel Speed * Set CreatedAt to Now() on token refresh Co-authored-by: Joel Speed --- CHANGELOG.md | 7 +++++ providers/azure.go | 66 +++++++++++++++++++++++++++++++++++++++++ providers/azure_test.go | 50 +++++++++++++++++++++++++++++-- 3 files changed, 121 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9f3f1f1..67eb73ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,9 +23,16 @@ via the login url. If this option was used in the past, behavior will change with this release as it will affect the tokens returned by Azure. In the past, the tokens were always for `https://graph.microsoft.com` (the default) and will now be for the configured resource (if it exists, otherwise it will run into errors) +- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) The Azure provider now has token refresh functionality implemented. This means that there won't + be any redirects in the browser anymore when tokens expire, but instead a token refresh is initiated + in the background, which leads to new tokens being returned in the cookies. + - Please note that `--cookie-refresh` must be 0 (the default) or equal to the token lifespan configured in Azure AD to make + Azure token refresh reliable. Setting this value to 0 means that it relies on the provider implementation + to decide if a refresh is required. ## Changes since v6.1.1 +- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock) - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) - [#796](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves) - [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed) diff --git a/providers/azure.go b/providers/azure.go index 234aaff2..d65b11f4 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "net/http" "net/url" "time" @@ -74,6 +75,9 @@ func NewAzureProvider(p *ProviderData) *AzureProvider { if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { p.ProtectedResource = azureDefaultProtectResourceURL } + if p.ValidateURL == nil || p.ValidateURL.String() == "" { + p.ValidateURL = p.ProfileURL + } return &AzureProvider{ ProviderData: p, @@ -103,6 +107,7 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) { } } +// Redeem exchanges the OAuth2 authentication token for an ID token func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") @@ -123,6 +128,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } + // blindly try json and x-www-form-urlencoded var jsonResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` @@ -151,6 +157,61 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s RefreshToken: jsonResponse.RefreshToken, } return + +} + +// RefreshSessionIfNeeded checks if the session has expired and uses the +// RefreshToken to fetch a new ID token if required +func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + return false, nil + } + + origExpiration := s.ExpiresOn + + err := p.redeemRefreshToken(ctx, s) + if err != nil { + return false, fmt.Errorf("unable to redeem refresh token: %v", err) + } + + fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) + return true, nil +} + +func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { + params := url.Values{} + params.Add("client_id", p.ClientID) + params.Add("client_secret", p.ClientSecret) + params.Add("refresh_token", s.RefreshToken) + params.Add("grant_type", "refresh_token") + + var jsonResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresOn int64 `json:"expires_on,string"` + IDToken string `json:"id_token"` + } + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). + UnmarshalInto(&jsonResponse) + + if err != nil { + return + } + + now := time.Now() + expires := time.Unix(jsonResponse.ExpiresOn, 0) + s.AccessToken = jsonResponse.AccessToken + s.IDToken = jsonResponse.IDToken + s.RefreshToken = jsonResponse.RefreshToken + s.CreatedAt = &now + s.ExpiresOn = &expires + return } func makeAzureHeader(accessToken string) http.Header { @@ -219,3 +280,8 @@ func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) return a.String() } + +// ValidateSessionState validates the AccessToken +func (p *AzureProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { + return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken)) +} diff --git a/providers/azure_test.go b/providers/azure_test.go index 6e2e4e97..9e3cabf7 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) @@ -42,7 +44,7 @@ func TestNewAzureProvider(t *testing.T) { g.Expect(providerData.LoginURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/authorize")) g.Expect(providerData.RedeemURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/token")) g.Expect(providerData.ProfileURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me")) - g.Expect(providerData.ValidateURL.String()).To(Equal("")) + g.Expect(providerData.ValidateURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me")) g.Expect(providerData.Scope).To(Equal("openid")) } @@ -97,7 +99,7 @@ func TestAzureSetTenant(t *testing.T) { p.Data().ProfileURL.String()) assert.Equal(t, "https://graph.microsoft.com", p.Data().ProtectedResource.String()) - assert.Equal(t, "", p.Data().ValidateURL.String()) + assert.Equal(t, "https://graph.microsoft.com/v1.0/me", p.Data().ValidateURL.String()) assert.Equal(t, "openid", p.Data().Scope) } @@ -220,3 +222,47 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { result := p.GetLoginURL("https://my.test.app/oauth", "") assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) } + +func TestAzureProviderGetsTokensInRedeem(t *testing.T) { + b := testAzureBackend(`{ "access_token": "some_access_token", "refresh_token": "some_refresh_token", "expires_on": "1136239445", "id_token": "some_id_token" }`) + defer b.Close() + timestamp, _ := time.Parse(time.RFC3339, "2006-01-02T22:04:05Z") + bURL, _ := url.Parse(b.URL) + p := testAzureProvider(bURL.Host) + + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") + assert.Equal(t, nil, err) + assert.NotEqual(t, session, nil) + assert.Equal(t, "some_access_token", session.AccessToken) + assert.Equal(t, "some_refresh_token", session.RefreshToken) + assert.Equal(t, "some_id_token", session.IDToken) + assert.Equal(t, timestamp, session.ExpiresOn.UTC()) +} + +func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { + p := testAzureProvider("") + + expires := time.Now().Add(time.Duration(1) * time.Hour) + session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} + refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) + assert.Equal(t, nil, err) + assert.False(t, refreshNeeded) +} + +func TestAzureProviderRefreshWhenExpired(t *testing.T) { + b := testAzureBackend(`{ "access_token": "new_some_access_token", "refresh_token": "new_some_refresh_token", "expires_on": "32693148245", "id_token": "new_some_id_token" }`) + defer b.Close() + timestamp, _ := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") + bURL, _ := url.Parse(b.URL) + p := testAzureProvider(bURL.Host) + + expires := time.Now().Add(time.Duration(-1) * time.Hour) + session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} + _, err := p.RefreshSessionIfNeeded(context.Background(), session) + assert.Equal(t, nil, err) + assert.NotEqual(t, session, nil) + assert.Equal(t, "new_some_access_token", session.AccessToken) + assert.Equal(t, "new_some_refresh_token", session.RefreshToken) + assert.Equal(t, "new_some_id_token", session.IDToken) + assert.Equal(t, timestamp, session.ExpiresOn.UTC()) +}