1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-24 08:52:25 +02:00

Streamline ErrMissingCode in provider Redeem methods

This commit is contained in:
Nick Meves 2020-10-23 19:35:15 -07:00
parent b92fd4b0bb
commit 1b3b00443a
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
6 changed files with 36 additions and 37 deletions

View File

@ -394,7 +394,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
if code == "" {
return nil, errors.New("missing code")
return nil, providers.ErrMissingCode
}
redirectURI := p.GetRedirectURI(host)
s, err := p.provider.Redeem(ctx, redirectURI, code)

View File

@ -1133,6 +1133,11 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
Email: "john.doe@example.com", AccessToken: "my_access_token"}
err = test.SaveSession(startSession)
assert.NoError(t, err)
test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusOK, test.rw.Code)
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes))
}
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {

View File

@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err = errors.New("missing code")
return
return nil, ErrMissingCode
}
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return nil, err
}
params := url.Values{}
@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
created := time.Now()
expires := time.Unix(jsonResponse.ExpiresOn, 0)
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken,
}
return
}, nil
}
// RefreshSessionIfNeeded checks if the session has expired and uses the

View File

@ -126,8 +126,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err := errors.New("missing code")
return nil, err
return nil, ErrMissingCode
}
clientSecret, err := p.GetClientSecret()
if err != nil {

View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/rsa"
"errors"
"fmt"
"math/rand"
"net/url"
@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err = errors.New("missing code")
return
return nil, ErrMissingCode
}
claims := &jwt.StandardClaims{
@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
ss, err := token.SignedString(p.JWTKey)
if err != nil {
return
return nil, err
}
params := url.Values{}
@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
// check nonce here
err = checkNonce(jsonResponse.IDToken, p)
if err != nil {
return
return nil, err
}
// Get the email address
var email string
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
if err != nil {
return
return nil, err
}
created := time.Now()
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
// Store the data that we found in the session state
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: &created,
ExpiresOn: &expires,
Email: email,
}
return
}, nil
}
// GetLoginURL overrides GetLoginURL to add login.gov parameters

View File

@ -19,18 +19,21 @@ var (
// implementation method that doesn't have sensible defaults
ErrNotImplemented = errors.New("not implemented")
// ErrMissingCode is returned when a Redeem method is called with an empty
// code
ErrMissingCode = errors.New("missing code")
_ Provider = (*ProviderData)(nil)
)
// Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err = errors.New("missing code")
return
return nil, ErrMissingCode
}
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return nil, err
}
params := url.Values{}
@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
}
err = result.UnmarshalInto(&jsonResponse)
if err == nil {
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
}
return
}, nil
}
var v url.Values
v, err = url.ParseQuery(string(result.Body()))
values, err := url.ParseQuery(string(result.Body()))
if err != nil {
return
return nil, err
}
if a := v.Get("access_token"); a != "" {
if token := values.Get("access_token"); token != "" {
created := time.Now()
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
} else {
err = fmt.Errorf("no access token found %s", result.Body())
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
}
return
return nil, fmt.Errorf("no access token found %s", result.Body())
}
// GetLoginURL with typical oauth parameters
@ -100,7 +100,7 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session
// Authorize performs global authorization on an authenticated session.
// This is not used for fine-grained per route authorization rules.
func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) {
if len(p.AllowedGroups) == 0 {
return true, nil
}