mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-21 21:47:11 +02:00
Integrate sessions middlewares
This commit is contained in:
parent
034f057b60
commit
eb234011eb
222
oauthproxy.go
222
oauthproxy.go
@ -15,7 +15,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/justinas/alice"
|
||||
ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic"
|
||||
@ -23,6 +25,7 @@ import (
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/ip"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/providers"
|
||||
@ -98,6 +101,8 @@ type OAuthProxy struct {
|
||||
trustedIPs *ip.NetSet
|
||||
Banner string
|
||||
Footer string
|
||||
|
||||
sessionChain alice.Chain
|
||||
}
|
||||
|
||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||
@ -156,6 +161,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
}
|
||||
}
|
||||
|
||||
sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator)
|
||||
|
||||
return &OAuthProxy{
|
||||
CookieName: opts.Cookie.Name,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
|
||||
@ -209,9 +216,45 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
|
||||
basicAuthValidator: basicAuthValidator,
|
||||
displayHtpasswdForm: basicAuthValidator != nil,
|
||||
sessionChain: sessionChain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain {
|
||||
chain := alice.New(middleware.NewScope())
|
||||
|
||||
if opts.SkipJwtBearerTokens {
|
||||
sessionLoaders := []middlewareapi.TokenToSessionLoader{}
|
||||
if opts.GetOIDCVerifier() != nil {
|
||||
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
|
||||
Verifier: opts.GetOIDCVerifier(),
|
||||
TokenToSession: opts.GetProvider().CreateSessionStateFromBearerToken,
|
||||
})
|
||||
}
|
||||
|
||||
for _, verifier := range opts.GetJWTBearerVerifiers() {
|
||||
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
|
||||
Verifier: verifier,
|
||||
})
|
||||
}
|
||||
|
||||
chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders))
|
||||
}
|
||||
|
||||
if validator != nil {
|
||||
chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator))
|
||||
}
|
||||
|
||||
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||
SessionStore: sessionStore,
|
||||
RefreshPeriod: opts.Cookie.Refresh,
|
||||
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
|
||||
ValidateSessionState: opts.GetProvider().ValidateSessionState,
|
||||
}))
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
||||
// redirect clients to once authenticated
|
||||
func (p *OAuthProxy) GetRedirectURI(host string) string {
|
||||
@ -780,86 +823,20 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
|
||||
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
var session *sessionsapi.SessionState
|
||||
var err error
|
||||
var saveSession, clearSession, revalidated bool
|
||||
|
||||
if p.skipJwtBearerTokens && req.Header.Get("Authorization") != "" {
|
||||
session, err = p.GetJwtSession(req)
|
||||
if err != nil {
|
||||
logger.Printf("Error retrieving session from token in Authorization header: %s", err)
|
||||
}
|
||||
if session != nil {
|
||||
saveSession = false
|
||||
}
|
||||
}
|
||||
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
session = middleware.GetRequestScope(req).Session
|
||||
}))
|
||||
getSession.ServeHTTP(rw, req)
|
||||
|
||||
remoteAddr := ip.GetClientString(p.realClientIPParser, req, true)
|
||||
if session == nil {
|
||||
session, err = p.LoadCookiedSession(req)
|
||||
if err != nil {
|
||||
logger.Printf("Error loading cookied session: %s", err)
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
if session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
|
||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh)
|
||||
saveSession = true
|
||||
}
|
||||
|
||||
if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil {
|
||||
logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
|
||||
clearSession = true
|
||||
session = nil
|
||||
} else if ok {
|
||||
saveSession = true
|
||||
revalidated = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if session != nil && session.IsExpired() {
|
||||
logger.Printf("Removing session: token expired %s", session)
|
||||
session = nil
|
||||
saveSession = false
|
||||
clearSession = true
|
||||
}
|
||||
|
||||
if saveSession && !revalidated && session != nil && session.AccessToken != "" {
|
||||
if !p.provider.ValidateSessionState(req.Context(), session) {
|
||||
logger.Printf("Removing session: error validating %s", session)
|
||||
saveSession = false
|
||||
session = nil
|
||||
clearSession = true
|
||||
}
|
||||
return nil, ErrNeedsLogin
|
||||
}
|
||||
|
||||
if session != nil && session.Email != "" && !p.Validator(session.Email) {
|
||||
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
||||
session = nil
|
||||
saveSession = false
|
||||
clearSession = true
|
||||
}
|
||||
|
||||
if saveSession && session != nil {
|
||||
err = p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "Save session error %s", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if clearSession {
|
||||
// Invalid session, clear it
|
||||
p.ClearSessionCookie(rw, req)
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
session, err = p.CheckBasicAuth(req)
|
||||
if err != nil {
|
||||
logger.Printf("Error during basic auth validation: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
return nil, ErrNeedsLogin
|
||||
}
|
||||
|
||||
@ -997,36 +974,6 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
||||
// credentials and authenticates these against the proxies HtpasswdFile
|
||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
if p.basicAuthValidator == nil {
|
||||
return nil, nil
|
||||
}
|
||||
auth := req.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
return nil, nil
|
||||
}
|
||||
s := strings.SplitN(auth, " ", 2)
|
||||
if len(s) != 2 || s[0] != "Basic" {
|
||||
return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization"))
|
||||
}
|
||||
b, err := b64.StdEncoding.DecodeString(s[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pair := strings.SplitN(string(b), ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return nil, fmt.Errorf("invalid format %s", b)
|
||||
}
|
||||
if p.basicAuthValidator.Validate(pair[0], pair[1]) {
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
||||
return &sessionsapi.SessionState{User: pair[0]}, nil
|
||||
}
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// isAjax checks if a request is an ajax request
|
||||
func isAjax(req *http.Request) bool {
|
||||
acceptValues := req.Header.Values("Accept")
|
||||
@ -1044,74 +991,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
||||
rw.Header().Set("Content-Type", applicationJSON)
|
||||
rw.WriteHeader(code)
|
||||
}
|
||||
|
||||
// GetJwtSession loads a session based on a JWT token in the authorization header.
|
||||
// (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers)
|
||||
func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
rawBearerToken, err := p.findBearerToken(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we are using an oidc provider, go ahead and try that provider first with its Verifier
|
||||
// and Bearer Token -> Session converter
|
||||
if p.mainJwtBearerVerifier != nil {
|
||||
bearerToken, err := p.mainJwtBearerVerifier.Verify(req.Context(), rawBearerToken)
|
||||
if err == nil {
|
||||
return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, attempt to verify against the extra JWT issuers and use a more generic
|
||||
// Bearer Token -> Session converter
|
||||
for _, verifier := range p.extraJwtBearerVerifiers {
|
||||
bearerToken, err := verifier.Verify(req.Context(), rawBearerToken)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return (*providers.ProviderData)(nil).CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
|
||||
}
|
||||
return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
// findBearerToken finds a valid JWT token from the Authorization header of a given request.
|
||||
func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) {
|
||||
auth := req.Header.Get("Authorization")
|
||||
s := strings.SplitN(auth, " ", 2)
|
||||
if len(s) != 2 {
|
||||
return "", fmt.Errorf("invalid authorization header %s", auth)
|
||||
}
|
||||
jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`)
|
||||
var rawBearerToken string
|
||||
if s[0] == "Bearer" && jwtRegex.MatchString(s[1]) {
|
||||
rawBearerToken = s[1]
|
||||
} else if s[0] == "Basic" {
|
||||
// Check if we have a Bearer token masquerading in Basic
|
||||
b, err := b64.StdEncoding.DecodeString(s[1])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pair := strings.SplitN(string(b), ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return "", fmt.Errorf("invalid format %s", b)
|
||||
}
|
||||
user, password := pair[0], pair[1]
|
||||
|
||||
// check user, user+password, or just password for a token
|
||||
if jwtRegex.MatchString(user) {
|
||||
// Support blank passwords or magic `x-oauth-basic` passwords - nothing else
|
||||
if password == "" || password == "x-oauth-basic" {
|
||||
rawBearerToken = user
|
||||
}
|
||||
} else if jwtRegex.MatchString(password) {
|
||||
// support passwords and ignore user
|
||||
rawBearerToken = password
|
||||
}
|
||||
}
|
||||
if rawBearerToken == "" {
|
||||
return "", fmt.Errorf("no valid bearer token found in authorization header")
|
||||
}
|
||||
|
||||
return rawBearerToken, nil
|
||||
}
|
||||
|
@ -1889,7 +1889,7 @@ func TestGetJwtSession(t *testing.T) {
|
||||
|
||||
// Bearer
|
||||
expires := time.Unix(1912151821, 0)
|
||||
session, err := test.proxy.GetJwtSession(test.req)
|
||||
session, err := test.proxy.getAuthenticatedSession(test.rw, test.req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, session.User, "1234567890")
|
||||
assert.Equal(t, session.Email, "john@example.com")
|
||||
@ -1912,70 +1912,6 @@ func TestGetJwtSession(t *testing.T) {
|
||||
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
|
||||
}
|
||||
|
||||
func TestFindJwtBearerToken(t *testing.T) {
|
||||
p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}}
|
||||
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}
|
||||
|
||||
validToken := "eyJfoobar.eyJfoobar.12345asdf"
|
||||
var token string
|
||||
|
||||
// Bearer
|
||||
getReq.Header = map[string][]string{
|
||||
"Authorization": {fmt.Sprintf("Bearer %s", validToken)},
|
||||
}
|
||||
|
||||
token, err := p.findBearerToken(getReq)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, validToken, token)
|
||||
|
||||
// Basic - no password
|
||||
getReq.SetBasicAuth(token, "")
|
||||
token, err = p.findBearerToken(getReq)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, validToken, token)
|
||||
|
||||
// Basic - sentinel password
|
||||
getReq.SetBasicAuth(token, "x-oauth-basic")
|
||||
token, err = p.findBearerToken(getReq)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, validToken, token)
|
||||
|
||||
// Basic - any username, password matching jwt pattern
|
||||
getReq.SetBasicAuth("any-username-you-could-wish-for", token)
|
||||
token, err = p.findBearerToken(getReq)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, validToken, token)
|
||||
|
||||
failures := []string{
|
||||
// Too many parts
|
||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
||||
// Not enough parts
|
||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA",
|
||||
// Invalid encrypted key
|
||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA",
|
||||
// Invalid IV
|
||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA",
|
||||
// Invalid ciphertext
|
||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA",
|
||||
// Invalid tag
|
||||
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////",
|
||||
// Invalid header
|
||||
"W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
||||
// Invalid header
|
||||
"######.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
||||
// Missing alc/enc params
|
||||
"e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
||||
}
|
||||
|
||||
for _, failure := range failures {
|
||||
getReq.Header = map[string][]string{
|
||||
"Authorization": {fmt.Sprintf("Bearer %s", failure)},
|
||||
}
|
||||
_, err := p.findBearerToken(getReq)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prepareNoCache(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
prepareNoCache(w)
|
||||
|
@ -126,35 +126,8 @@ func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// CreateSessionStateFromBearerToken should be implemented to allow providers
|
||||
// to convert ID tokens into sessions
|
||||
func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
var claims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Verified *bool `json:"email_verified"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
}
|
||||
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
|
||||
}
|
||||
|
||||
if claims.Email == "" {
|
||||
claims.Email = claims.Subject
|
||||
}
|
||||
|
||||
if claims.Verified != nil && !*claims.Verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||
}
|
||||
|
||||
newSession := &sessions.SessionState{
|
||||
Email: claims.Email,
|
||||
User: claims.Subject,
|
||||
PreferredUsername: claims.PreferredUsername,
|
||||
AccessToken: rawIDToken,
|
||||
IDToken: rawIDToken,
|
||||
RefreshToken: "",
|
||||
ExpiresOn: &idToken.Expiry,
|
||||
}
|
||||
|
||||
return newSession, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
@ -2,15 +2,10 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@ -52,39 +47,3 @@ func TestAcrValuesConfigured(t *testing.T) {
|
||||
result := p.GetLoginURL("https://my.test.app/oauth", "")
|
||||
assert.Contains(t, result, "acr_values=testValue")
|
||||
}
|
||||
|
||||
func TestCreateSessionStateFromBearerToken(t *testing.T) {
|
||||
minimalIDToken := jwt.StandardClaims{
|
||||
Audience: "asdf1234",
|
||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
||||
Id: "id-some-id",
|
||||
IssuedAt: time.Now().Unix(),
|
||||
Issuer: "https://issuer.example.com",
|
||||
NotBefore: 0,
|
||||
Subject: "123456789",
|
||||
}
|
||||
// From oidc_test.go
|
||||
verifier := oidc.NewVerifier(
|
||||
"https://issuer.example.com",
|
||||
fakeKeySetStub{},
|
||||
&oidc.Config{ClientID: "asdf1234"},
|
||||
)
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, minimalIDToken).SignedString(key)
|
||||
assert.NoError(t, err)
|
||||
// Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
assert.NoError(t, err)
|
||||
|
||||
session, err := (*ProviderData)(nil).CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, rawIDToken, session.AccessToken)
|
||||
assert.Equal(t, rawIDToken, session.IDToken)
|
||||
assert.Equal(t, "123456789", session.Email)
|
||||
assert.Equal(t, "123456789", session.User)
|
||||
assert.Empty(t, session.RefreshToken)
|
||||
assert.Empty(t, session.PreferredUsername)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user