mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
Streamline ErrMissingCode in provider Redeem methods
This commit is contained in:
parent
b92fd4b0bb
commit
1b3b00443a
@ -394,7 +394,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
|
||||
|
||||
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
return nil, providers.ErrMissingCode
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(host)
|
||||
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
||||
|
@ -1133,6 +1133,11 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
|
||||
Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||
err = test.SaveSession(startSession)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusOK, test.rw.Code)
|
||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||
assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes))
|
||||
}
|
||||
|
||||
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||
|
@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
||||
s = &sessions.SessionState{
|
||||
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
}
|
||||
return
|
||||
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
|
@ -126,8 +126,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err := errors.New("missing code")
|
||||
return nil, err
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
|
||||
claims := &jwt.StandardClaims{
|
||||
@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
|
||||
ss, err := token.SignedString(p.JWTKey)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
// check nonce here
|
||||
err = checkNonce(jsonResponse.IDToken, p)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the email address
|
||||
var email string
|
||||
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||
|
||||
// Store the data that we found in the session state
|
||||
s = &sessions.SessionState{
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
Email: email,
|
||||
}
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
||||
|
@ -19,18 +19,21 @@ var (
|
||||
// implementation method that doesn't have sensible defaults
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
|
||||
// ErrMissingCode is returned when a Redeem method is called with an empty
|
||||
// code
|
||||
ErrMissingCode = errors.New("missing code")
|
||||
|
||||
_ Provider = (*ProviderData)(nil)
|
||||
)
|
||||
|
||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
}
|
||||
err = result.UnmarshalInto(&jsonResponse)
|
||||
if err == nil {
|
||||
s = &sessions.SessionState{
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
}
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
||||
var v url.Values
|
||||
v, err = url.ParseQuery(string(result.Body()))
|
||||
values, err := url.ParseQuery(string(result.Body()))
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
if a := v.Get("access_token"); a != "" {
|
||||
if token := values.Get("access_token"); token != "" {
|
||||
created := time.Now()
|
||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", result.Body())
|
||||
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
|
||||
}
|
||||
return
|
||||
|
||||
return nil, fmt.Errorf("no access token found %s", result.Body())
|
||||
}
|
||||
|
||||
// GetLoginURL with typical oauth parameters
|
||||
@ -100,7 +100,7 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session
|
||||
|
||||
// Authorize performs global authorization on an authenticated session.
|
||||
// This is not used for fine-grained per route authorization rules.
|
||||
func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if len(p.AllowedGroups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user