mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-02 23:27:22 +02:00
Streamline ErrMissingCode in provider Redeem methods
This commit is contained in:
parent
b92fd4b0bb
commit
1b3b00443a
@ -394,7 +394,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
|
|||||||
|
|
||||||
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, errors.New("missing code")
|
return nil, providers.ErrMissingCode
|
||||||
}
|
}
|
||||||
redirectURI := p.GetRedirectURI(host)
|
redirectURI := p.GetRedirectURI(host)
|
||||||
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
||||||
|
@ -1133,6 +1133,11 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
|
|||||||
Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||||
err = test.SaveSession(startSession)
|
err = test.SaveSession(startSession)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
|
assert.Equal(t, http.StatusOK, test.rw.Code)
|
||||||
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||||
|
assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||||
|
@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return
|
|
||||||
}
|
}
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
||||||
s = &sessions.SessionState{
|
|
||||||
|
return &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
ExpiresOn: &expires,
|
ExpiresOn: &expires,
|
||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
}
|
}, nil
|
||||||
return
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
|
@ -126,8 +126,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
|||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err := errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
claims := &jwt.StandardClaims{
|
claims := &jwt.StandardClaims{
|
||||||
@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
|||||||
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
|
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
|
||||||
ss, err := token.SignedString(p.JWTKey)
|
ss, err := token.SignedString(p.JWTKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
|||||||
// check nonce here
|
// check nonce here
|
||||||
err = checkNonce(jsonResponse.IDToken, p)
|
err = checkNonce(jsonResponse.IDToken, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the email address
|
// Get the email address
|
||||||
var email string
|
var email string
|
||||||
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
|
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||||
|
|
||||||
// Store the data that we found in the session state
|
// Store the data that we found in the session state
|
||||||
s = &sessions.SessionState{
|
return &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
ExpiresOn: &expires,
|
ExpiresOn: &expires,
|
||||||
Email: email,
|
Email: email,
|
||||||
}
|
}, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
||||||
|
@ -19,18 +19,21 @@ var (
|
|||||||
// implementation method that doesn't have sensible defaults
|
// implementation method that doesn't have sensible defaults
|
||||||
ErrNotImplemented = errors.New("not implemented")
|
ErrNotImplemented = errors.New("not implemented")
|
||||||
|
|
||||||
|
// ErrMissingCode is returned when a Redeem method is called with an empty
|
||||||
|
// code
|
||||||
|
ErrMissingCode = errors.New("missing code")
|
||||||
|
|
||||||
_ Provider = (*ProviderData)(nil)
|
_ Provider = (*ProviderData)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||||
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return
|
|
||||||
}
|
}
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
}
|
}
|
||||||
err = result.UnmarshalInto(&jsonResponse)
|
err = result.UnmarshalInto(&jsonResponse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s = &sessions.SessionState{
|
return &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
}
|
}, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var v url.Values
|
values, err := url.ParseQuery(string(result.Body()))
|
||||||
v, err = url.ParseQuery(string(result.Body()))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
if a := v.Get("access_token"); a != "" {
|
if token := values.Get("access_token"); token != "" {
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("no access token found %s", result.Body())
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
return nil, fmt.Errorf("no access token found %s", result.Body())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginURL with typical oauth parameters
|
// GetLoginURL with typical oauth parameters
|
||||||
@ -100,7 +100,7 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session
|
|||||||
|
|
||||||
// Authorize performs global authorization on an authenticated session.
|
// Authorize performs global authorization on an authenticated session.
|
||||||
// This is not used for fine-grained per route authorization rules.
|
// This is not used for fine-grained per route authorization rules.
|
||||||
func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
if len(p.AllowedGroups) == 0 {
|
if len(p.AllowedGroups) == 0 {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user