mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-21 21:47:11 +02:00
Integrate sessions middlewares
This commit is contained in:
parent
034f057b60
commit
eb234011eb
222
oauthproxy.go
222
oauthproxy.go
@ -15,7 +15,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
|
"github.com/justinas/alice"
|
||||||
ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip"
|
ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic"
|
||||||
@ -23,6 +25,7 @@ import (
|
|||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/ip"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/ip"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/providers"
|
"github.com/oauth2-proxy/oauth2-proxy/providers"
|
||||||
@ -98,6 +101,8 @@ type OAuthProxy struct {
|
|||||||
trustedIPs *ip.NetSet
|
trustedIPs *ip.NetSet
|
||||||
Banner string
|
Banner string
|
||||||
Footer string
|
Footer string
|
||||||
|
|
||||||
|
sessionChain alice.Chain
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||||
@ -156,6 +161,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator)
|
||||||
|
|
||||||
return &OAuthProxy{
|
return &OAuthProxy{
|
||||||
CookieName: opts.Cookie.Name,
|
CookieName: opts.Cookie.Name,
|
||||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
|
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
|
||||||
@ -209,9 +216,45 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
|
|
||||||
basicAuthValidator: basicAuthValidator,
|
basicAuthValidator: basicAuthValidator,
|
||||||
displayHtpasswdForm: basicAuthValidator != nil,
|
displayHtpasswdForm: basicAuthValidator != nil,
|
||||||
|
sessionChain: sessionChain,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain {
|
||||||
|
chain := alice.New(middleware.NewScope())
|
||||||
|
|
||||||
|
if opts.SkipJwtBearerTokens {
|
||||||
|
sessionLoaders := []middlewareapi.TokenToSessionLoader{}
|
||||||
|
if opts.GetOIDCVerifier() != nil {
|
||||||
|
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
|
||||||
|
Verifier: opts.GetOIDCVerifier(),
|
||||||
|
TokenToSession: opts.GetProvider().CreateSessionStateFromBearerToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, verifier := range opts.GetJWTBearerVerifiers() {
|
||||||
|
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
|
||||||
|
Verifier: verifier,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders))
|
||||||
|
}
|
||||||
|
|
||||||
|
if validator != nil {
|
||||||
|
chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator))
|
||||||
|
}
|
||||||
|
|
||||||
|
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||||
|
SessionStore: sessionStore,
|
||||||
|
RefreshPeriod: opts.Cookie.Refresh,
|
||||||
|
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
|
||||||
|
ValidateSessionState: opts.GetProvider().ValidateSessionState,
|
||||||
|
}))
|
||||||
|
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
||||||
// redirect clients to once authenticated
|
// redirect clients to once authenticated
|
||||||
func (p *OAuthProxy) GetRedirectURI(host string) string {
|
func (p *OAuthProxy) GetRedirectURI(host string) string {
|
||||||
@ -780,86 +823,20 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
|||||||
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
|
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
|
||||||
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
var session *sessionsapi.SessionState
|
var session *sessionsapi.SessionState
|
||||||
var err error
|
|
||||||
var saveSession, clearSession, revalidated bool
|
|
||||||
|
|
||||||
if p.skipJwtBearerTokens && req.Header.Get("Authorization") != "" {
|
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
session, err = p.GetJwtSession(req)
|
session = middleware.GetRequestScope(req).Session
|
||||||
if err != nil {
|
}))
|
||||||
logger.Printf("Error retrieving session from token in Authorization header: %s", err)
|
getSession.ServeHTTP(rw, req)
|
||||||
}
|
|
||||||
if session != nil {
|
|
||||||
saveSession = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAddr := ip.GetClientString(p.realClientIPParser, req, true)
|
|
||||||
if session == nil {
|
if session == nil {
|
||||||
session, err = p.LoadCookiedSession(req)
|
return nil, ErrNeedsLogin
|
||||||
if err != nil {
|
|
||||||
logger.Printf("Error loading cookied session: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if session != nil {
|
|
||||||
if session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
|
|
||||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh)
|
|
||||||
saveSession = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil {
|
|
||||||
logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
|
|
||||||
clearSession = true
|
|
||||||
session = nil
|
|
||||||
} else if ok {
|
|
||||||
saveSession = true
|
|
||||||
revalidated = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if session != nil && session.IsExpired() {
|
|
||||||
logger.Printf("Removing session: token expired %s", session)
|
|
||||||
session = nil
|
|
||||||
saveSession = false
|
|
||||||
clearSession = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if saveSession && !revalidated && session != nil && session.AccessToken != "" {
|
|
||||||
if !p.provider.ValidateSessionState(req.Context(), session) {
|
|
||||||
logger.Printf("Removing session: error validating %s", session)
|
|
||||||
saveSession = false
|
|
||||||
session = nil
|
|
||||||
clearSession = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if session != nil && session.Email != "" && !p.Validator(session.Email) {
|
if session != nil && session.Email != "" && !p.Validator(session.Email) {
|
||||||
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
||||||
session = nil
|
// Invalid session, clear it
|
||||||
saveSession = false
|
|
||||||
clearSession = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if saveSession && session != nil {
|
|
||||||
err = p.SaveSession(rw, req, session)
|
|
||||||
if err != nil {
|
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "Save session error %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if clearSession {
|
|
||||||
p.ClearSessionCookie(rw, req)
|
p.ClearSessionCookie(rw, req)
|
||||||
}
|
|
||||||
|
|
||||||
if session == nil {
|
|
||||||
session, err = p.CheckBasicAuth(req)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("Error during basic auth validation: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if session == nil {
|
|
||||||
return nil, ErrNeedsLogin
|
return nil, ErrNeedsLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -997,36 +974,6 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
|
||||||
// credentials and authenticates these against the proxies HtpasswdFile
|
|
||||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) {
|
|
||||||
if p.basicAuthValidator == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
auth := req.Header.Get("Authorization")
|
|
||||||
if auth == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
s := strings.SplitN(auth, " ", 2)
|
|
||||||
if len(s) != 2 || s[0] != "Basic" {
|
|
||||||
return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization"))
|
|
||||||
}
|
|
||||||
b, err := b64.StdEncoding.DecodeString(s[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pair := strings.SplitN(string(b), ":", 2)
|
|
||||||
if len(pair) != 2 {
|
|
||||||
return nil, fmt.Errorf("invalid format %s", b)
|
|
||||||
}
|
|
||||||
if p.basicAuthValidator.Validate(pair[0], pair[1]) {
|
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
|
||||||
return &sessionsapi.SessionState{User: pair[0]}, nil
|
|
||||||
}
|
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isAjax checks if a request is an ajax request
|
// isAjax checks if a request is an ajax request
|
||||||
func isAjax(req *http.Request) bool {
|
func isAjax(req *http.Request) bool {
|
||||||
acceptValues := req.Header.Values("Accept")
|
acceptValues := req.Header.Values("Accept")
|
||||||
@ -1044,74 +991,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
|||||||
rw.Header().Set("Content-Type", applicationJSON)
|
rw.Header().Set("Content-Type", applicationJSON)
|
||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetJwtSession loads a session based on a JWT token in the authorization header.
|
|
||||||
// (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers)
|
|
||||||
func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState, error) {
|
|
||||||
rawBearerToken, err := p.findBearerToken(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we are using an oidc provider, go ahead and try that provider first with its Verifier
|
|
||||||
// and Bearer Token -> Session converter
|
|
||||||
if p.mainJwtBearerVerifier != nil {
|
|
||||||
bearerToken, err := p.mainJwtBearerVerifier.Verify(req.Context(), rawBearerToken)
|
|
||||||
if err == nil {
|
|
||||||
return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, attempt to verify against the extra JWT issuers and use a more generic
|
|
||||||
// Bearer Token -> Session converter
|
|
||||||
for _, verifier := range p.extraJwtBearerVerifiers {
|
|
||||||
bearerToken, err := verifier.Verify(req.Context(), rawBearerToken)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return (*providers.ProviderData)(nil).CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// findBearerToken finds a valid JWT token from the Authorization header of a given request.
|
|
||||||
func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) {
|
|
||||||
auth := req.Header.Get("Authorization")
|
|
||||||
s := strings.SplitN(auth, " ", 2)
|
|
||||||
if len(s) != 2 {
|
|
||||||
return "", fmt.Errorf("invalid authorization header %s", auth)
|
|
||||||
}
|
|
||||||
jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`)
|
|
||||||
var rawBearerToken string
|
|
||||||
if s[0] == "Bearer" && jwtRegex.MatchString(s[1]) {
|
|
||||||
rawBearerToken = s[1]
|
|
||||||
} else if s[0] == "Basic" {
|
|
||||||
// Check if we have a Bearer token masquerading in Basic
|
|
||||||
b, err := b64.StdEncoding.DecodeString(s[1])
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
pair := strings.SplitN(string(b), ":", 2)
|
|
||||||
if len(pair) != 2 {
|
|
||||||
return "", fmt.Errorf("invalid format %s", b)
|
|
||||||
}
|
|
||||||
user, password := pair[0], pair[1]
|
|
||||||
|
|
||||||
// check user, user+password, or just password for a token
|
|
||||||
if jwtRegex.MatchString(user) {
|
|
||||||
// Support blank passwords or magic `x-oauth-basic` passwords - nothing else
|
|
||||||
if password == "" || password == "x-oauth-basic" {
|
|
||||||
rawBearerToken = user
|
|
||||||
}
|
|
||||||
} else if jwtRegex.MatchString(password) {
|
|
||||||
// support passwords and ignore user
|
|
||||||
rawBearerToken = password
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if rawBearerToken == "" {
|
|
||||||
return "", fmt.Errorf("no valid bearer token found in authorization header")
|
|
||||||
}
|
|
||||||
|
|
||||||
return rawBearerToken, nil
|
|
||||||
}
|
|
||||||
|
@ -1889,7 +1889,7 @@ func TestGetJwtSession(t *testing.T) {
|
|||||||
|
|
||||||
// Bearer
|
// Bearer
|
||||||
expires := time.Unix(1912151821, 0)
|
expires := time.Unix(1912151821, 0)
|
||||||
session, err := test.proxy.GetJwtSession(test.req)
|
session, err := test.proxy.getAuthenticatedSession(test.rw, test.req)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, session.User, "1234567890")
|
assert.Equal(t, session.User, "1234567890")
|
||||||
assert.Equal(t, session.Email, "john@example.com")
|
assert.Equal(t, session.Email, "john@example.com")
|
||||||
@ -1912,70 +1912,6 @@ func TestGetJwtSession(t *testing.T) {
|
|||||||
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
|
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindJwtBearerToken(t *testing.T) {
|
|
||||||
p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}}
|
|
||||||
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}
|
|
||||||
|
|
||||||
validToken := "eyJfoobar.eyJfoobar.12345asdf"
|
|
||||||
var token string
|
|
||||||
|
|
||||||
// Bearer
|
|
||||||
getReq.Header = map[string][]string{
|
|
||||||
"Authorization": {fmt.Sprintf("Bearer %s", validToken)},
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := p.findBearerToken(getReq)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, validToken, token)
|
|
||||||
|
|
||||||
// Basic - no password
|
|
||||||
getReq.SetBasicAuth(token, "")
|
|
||||||
token, err = p.findBearerToken(getReq)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, validToken, token)
|
|
||||||
|
|
||||||
// Basic - sentinel password
|
|
||||||
getReq.SetBasicAuth(token, "x-oauth-basic")
|
|
||||||
token, err = p.findBearerToken(getReq)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, validToken, token)
|
|
||||||
|
|
||||||
// Basic - any username, password matching jwt pattern
|
|
||||||
getReq.SetBasicAuth("any-username-you-could-wish-for", token)
|
|
||||||
token, err = p.findBearerToken(getReq)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, validToken, token)
|
|
||||||
|
|
||||||
failures := []string{
|
|
||||||
// Too many parts
|
|
||||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
||||||
// Not enough parts
|
|
||||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA",
|
|
||||||
// Invalid encrypted key
|
|
||||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA",
|
|
||||||
// Invalid IV
|
|
||||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA",
|
|
||||||
// Invalid ciphertext
|
|
||||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA",
|
|
||||||
// Invalid tag
|
|
||||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////",
|
|
||||||
// Invalid header
|
|
||||||
"W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
||||||
// Invalid header
|
|
||||||
"######.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
||||||
// Missing alc/enc params
|
|
||||||
"e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, failure := range failures {
|
|
||||||
getReq.Header = map[string][]string{
|
|
||||||
"Authorization": {fmt.Sprintf("Bearer %s", failure)},
|
|
||||||
}
|
|
||||||
_, err := p.findBearerToken(getReq)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_prepareNoCache(t *testing.T) {
|
func Test_prepareNoCache(t *testing.T) {
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
prepareNoCache(w)
|
prepareNoCache(w)
|
||||||
|
@ -126,35 +126,8 @@ func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateSessionStateFromBearerToken should be implemented to allow providers
|
||||||
|
// to convert ID tokens into sessions
|
||||||
func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||||
var claims struct {
|
return nil, errors.New("not implemented")
|
||||||
Subject string `json:"sub"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
Verified *bool `json:"email_verified"`
|
|
||||||
PreferredUsername string `json:"preferred_username"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims.Email == "" {
|
|
||||||
claims.Email = claims.Subject
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims.Verified != nil && !*claims.Verified {
|
|
||||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
|
||||||
}
|
|
||||||
|
|
||||||
newSession := &sessions.SessionState{
|
|
||||||
Email: claims.Email,
|
|
||||||
User: claims.Subject,
|
|
||||||
PreferredUsername: claims.PreferredUsername,
|
|
||||||
AccessToken: rawIDToken,
|
|
||||||
IDToken: rawIDToken,
|
|
||||||
RefreshToken: "",
|
|
||||||
ExpiresOn: &idToken.Expiry,
|
|
||||||
}
|
|
||||||
|
|
||||||
return newSession, nil
|
|
||||||
}
|
}
|
||||||
|
@ -2,15 +2,10 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
|
||||||
"github.com/dgrijalva/jwt-go"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@ -52,39 +47,3 @@ func TestAcrValuesConfigured(t *testing.T) {
|
|||||||
result := p.GetLoginURL("https://my.test.app/oauth", "")
|
result := p.GetLoginURL("https://my.test.app/oauth", "")
|
||||||
assert.Contains(t, result, "acr_values=testValue")
|
assert.Contains(t, result, "acr_values=testValue")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateSessionStateFromBearerToken(t *testing.T) {
|
|
||||||
minimalIDToken := jwt.StandardClaims{
|
|
||||||
Audience: "asdf1234",
|
|
||||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
|
||||||
Id: "id-some-id",
|
|
||||||
IssuedAt: time.Now().Unix(),
|
|
||||||
Issuer: "https://issuer.example.com",
|
|
||||||
NotBefore: 0,
|
|
||||||
Subject: "123456789",
|
|
||||||
}
|
|
||||||
// From oidc_test.go
|
|
||||||
verifier := oidc.NewVerifier(
|
|
||||||
"https://issuer.example.com",
|
|
||||||
fakeKeySetStub{},
|
|
||||||
&oidc.Config{ClientID: "asdf1234"},
|
|
||||||
)
|
|
||||||
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, minimalIDToken).SignedString(key)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
// Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below
|
|
||||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
session, err := (*ProviderData)(nil).CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, rawIDToken, session.AccessToken)
|
|
||||||
assert.Equal(t, rawIDToken, session.IDToken)
|
|
||||||
assert.Equal(t, "123456789", session.Email)
|
|
||||||
assert.Equal(t, "123456789", session.User)
|
|
||||||
assert.Empty(t, session.RefreshToken)
|
|
||||||
assert.Empty(t, session.PreferredUsername)
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user