diff --git a/pkg/middleware/basic_session.go b/pkg/middleware/basic_session.go new file mode 100644 index 00000000..734f4d83 --- /dev/null +++ b/pkg/middleware/basic_session.go @@ -0,0 +1,88 @@ +package middleware + +import ( + "fmt" + "net/http" + + "github.com/justinas/alice" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor { + return func(next http.Handler) http.Handler { + return loadBasicAuthSession(validator, next) + } +} + +// loadBasicAuthSession attmepts to load a session from basic auth credentials +// stored in an Authorization header within the request. +// If no authorization header is found, or the header is invalid, no session +// will be loaded and the request will be passed to the next handler. +// If a session was loaded by a previous handler, it will not be replaced. +func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + if scope.Session != nil { + // The session was already loaded, pass to the next handler + next.ServeHTTP(rw, req) + return + } + + session, err := getBasicSession(validator, req) + if err != nil { + logger.Printf("Error retrieving session from token in Authorization header: %v", err) + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +// getBasicSession attempts to load a basic session from the request. +// If the credentials in the request exist within the htpasswdMap, +// a new session will be created. +func getBasicSession(validator basic.Validator, req *http.Request) (*sessionsapi.SessionState, error) { + auth := req.Header.Get("Authorization") + if auth == "" { + // No auth header provided, so don't attempt to load a session + return nil, nil + } + + user, password, err := findBasicCredentialsFromHeader(auth) + if err != nil { + return nil, err + } + + if validator.Validate(user, password) { + logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") + return &sessionsapi.SessionState{User: user}, nil + } + + logger.PrintAuthf(user, req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") + return nil, nil +} + +// findBasicCredentialsFromHeader finds basic auth credneitals from the +// Authorization header of a given request. +func findBasicCredentialsFromHeader(header string) (string, string, error) { + tokenType, token, err := splitAuthHeader(header) + if err != nil { + return "", "", err + } + + if tokenType != "Basic" { + return "", "", fmt.Errorf("invalid Authorization header: %q", header) + } + + user, password, err := getBasicAuthCredentials(token) + if err != nil { + return "", "", fmt.Errorf("error decoding basic auth credentials: %v", err) + } + + return user, password, nil +} diff --git a/pkg/middleware/basic_session_test.go b/pkg/middleware/basic_session_test.go new file mode 100644 index 00000000..109b09a5 --- /dev/null +++ b/pkg/middleware/basic_session_test.go @@ -0,0 +1,132 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +const ( + adminUser = "admin" + adminPassword = "Adm1n1str$t0r" + user1 = "user1" + user1Password = "UsErOn3P455" + user2 = "user2" + user2Password = "us3r2P455W0Rd!" +) + +var _ = Describe("Basic Auth Session Suite", func() { + Context("BasicAuthSessionLoader", func() { + + type basicAuthSessionLoaderTableInput struct { + authorizationHeader string + existingSession *sessionsapi.SessionState + expectedSession *sessionsapi.SessionState + } + + DescribeTable("with an authorization header", + func(in basicAuthSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the authorization header and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header.Set("Authorization", in.authorizationHeader) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + + validator := fakeBasicValidator{ + users: map[string]string{ + adminUser: adminPassword, + user1: user1Password, + user2: user2Password, + }, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + })) + handler.ServeHTTP(rw, req) + + Expect(gotSession).To(Equal(in.expectedSession)) + }, + Entry("", basicAuthSessionLoaderTableInput{ + authorizationHeader: "", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef", basicAuthSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef (with existing session)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Bearer ", basicAuthSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", adminPassword), + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic ", basicAuthSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Basic %s", adminPassword), + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic Base64(:) (with existing session)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic OlVzRXJPbjNQNDU1", + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Basic Base64(user1:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic dXNlcjE6VXNFck9uM1A0NTU=", + existingSession: nil, + expectedSession: &sessionsapi.SessionState{User: "user1"}, + }), + Entry("Basic Base64(user2:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic dXNlcjI6VXNFck9uM1A0NTU=", + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic Base64(user2:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic dXNlcjI6dXMzcjJQNDU1VzBSZCE=", + existingSession: nil, + expectedSession: &sessionsapi.SessionState{User: "user2"}, + }), + Entry("Basic Base64(admin:)", basicAuthSessionLoaderTableInput{ + authorizationHeader: "Basic YWRtaW46QWRtMW4xc3RyJHQwcg==", + existingSession: nil, + expectedSession: &sessionsapi.SessionState{User: "admin"}, + }), + ) + }) +}) + +type fakeBasicValidator struct { + users map[string]string +} + +func (f fakeBasicValidator) Validate(user, password string) bool { + if f.users == nil { + return false + } + if realPassword, ok := f.users[user]; ok { + return realPassword == password + } + return false +} diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go index 4c3087c7..3a0b65fc 100644 --- a/pkg/middleware/jwt_session.go +++ b/pkg/middleware/jwt_session.go @@ -2,11 +2,9 @@ package middleware import ( "context" - "encoding/base64" "fmt" "net/http" "regexp" - "strings" "github.com/coreos/go-oidc" "github.com/justinas/alice" @@ -115,17 +113,11 @@ func (j *jwtSessionLoader) findBearerTokenFromHeader(header string) (string, err // getBasicToken tries to extract a token from the basic value provided. func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { - b, err := base64.StdEncoding.DecodeString(token) + user, password, err := getBasicAuthCredentials(token) if err != nil { - return "", fmt.Errorf("invalid basic auth token: %v", err) + return "", err } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return "", fmt.Errorf("invalid format: %q", b) - } - user, password := pair[0], pair[1] - // check user, user+password, or just password for a token if j.jwtRegex.MatchString(user) { // Support blank passwords or magic `x-oauth-basic` passwords - nothing else @@ -140,16 +132,6 @@ func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { return "", fmt.Errorf("invalid basic auth token found in authorization header") } -// splitAuthHeader takes the auth header value and splits it into the token type -// and the token value. -func splitAuthHeader(header string) (string, string, error) { - s := strings.Split(header, " ") - if len(s) != 2 { - return "", "", fmt.Errorf("invalid authorization header: %q", header) - } - return s[0], s[1], nil -} - // createSessionStateFromBearerToken is a default implementation for converting // a JWT into a session state. func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) { diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go index f8adf2d3..5148ad28 100644 --- a/pkg/middleware/jwt_session_test.go +++ b/pkg/middleware/jwt_session_test.go @@ -381,46 +381,6 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` ) }) - Context("splitAuthHeader", func() { - type splitAuthTableInput struct { - header string - expectedErr error - expectedTokenType string - expectedTokenValue string - } - - DescribeTable("with a header value", - func(in splitAuthTableInput) { - tt, tv, err := splitAuthHeader(in.header) - if in.expectedErr != nil { - Expect(err).To(MatchError(in.expectedErr)) - } else { - Expect(err).ToNot(HaveOccurred()) - } - Expect(tt).To(Equal(in.expectedTokenType)) - Expect(tv).To(Equal(in.expectedTokenValue)) - }, - Entry("Bearer abcdef", splitAuthTableInput{ - header: "Bearer abcdef", - expectedErr: nil, - expectedTokenType: "Bearer", - expectedTokenValue: "abcdef", - }), - Entry("Bearer", splitAuthTableInput{ - header: "Bearer", - expectedErr: errors.New("invalid authorization header: \"Bearer\""), - expectedTokenType: "", - expectedTokenValue: "", - }), - Entry("Bearer abc def", splitAuthTableInput{ - header: "Bearer abc def", - expectedErr: errors.New("invalid authorization header: \"Bearer abc def\""), - expectedTokenType: "", - expectedTokenValue: "", - }), - ) - }) - Context("createSessionStateFromBearerToken", func() { ctx := context.Background() expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) diff --git a/pkg/middleware/session_utils.go b/pkg/middleware/session_utils.go new file mode 100644 index 00000000..00277d4f --- /dev/null +++ b/pkg/middleware/session_utils.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "encoding/base64" + "fmt" + "strings" +) + +// splitAuthHeader takes the auth header value and splits it into the token type +// and the token value. +func splitAuthHeader(header string) (string, string, error) { + s := strings.Split(header, " ") + if len(s) != 2 { + return "", "", fmt.Errorf("invalid authorization header: %q", header) + } + return s[0], s[1], nil +} + +// getBasicAuthCredentials decodes a basic auth token and extracts the user +// and password pair. +func getBasicAuthCredentials(token string) (string, string, error) { + b, err := base64.StdEncoding.DecodeString(token) + if err != nil { + return "", "", fmt.Errorf("invalid basic auth token: %v", err) + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return "", "", fmt.Errorf("invalid format: %q", b) + } + // user, password + return pair[0], pair[1], nil +} diff --git a/pkg/middleware/session_utils_test.go b/pkg/middleware/session_utils_test.go new file mode 100644 index 00000000..26eb84bb --- /dev/null +++ b/pkg/middleware/session_utils_test.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "errors" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Session utilities suite", func() { + Context("splitAuthHeader", func() { + type splitAuthTableInput struct { + header string + expectedErr error + expectedTokenType string + expectedTokenValue string + } + + DescribeTable("with a header value", + func(in splitAuthTableInput) { + tt, tv, err := splitAuthHeader(in.header) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(tt).To(Equal(in.expectedTokenType)) + Expect(tv).To(Equal(in.expectedTokenValue)) + }, + Entry("Bearer abcdef", splitAuthTableInput{ + header: "Bearer abcdef", + expectedErr: nil, + expectedTokenType: "Bearer", + expectedTokenValue: "abcdef", + }), + Entry("Bearer", splitAuthTableInput{ + header: "Bearer", + expectedErr: errors.New("invalid authorization header: \"Bearer\""), + expectedTokenType: "", + expectedTokenValue: "", + }), + Entry("Bearer abc def", splitAuthTableInput{ + header: "Bearer abc def", + expectedErr: errors.New("invalid authorization header: \"Bearer abc def\""), + expectedTokenType: "", + expectedTokenValue: "", + }), + ) + }) + + Context("getBasicAuthCredentials", func() { + type getBasicAuthCredentialsTableInput struct { + token string + expectedErr error + expectedUser string + expectedPassword string + } + + DescribeTable("from token", + func(in getBasicAuthCredentialsTableInput) { + user, password, err := getBasicAuthCredentials(in.token) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(user).To(Equal(in.expectedUser)) + Expect(password).To(Equal(in.expectedPassword)) + }, + Entry("", getBasicAuthCredentialsTableInput{ + token: "", + expectedErr: errors.New("invalid format: \"\""), + expectedUser: "", + expectedPassword: "", + }), + Entry("invalid-base64", getBasicAuthCredentialsTableInput{ + token: "invalid-base64", + expectedErr: errors.New("invalid basic auth token: illegal base64 data at input byte 7"), + expectedUser: "", + expectedPassword: "", + }), + Entry("Base64(some-user:some-password)", getBasicAuthCredentialsTableInput{ + token: "c29tZS11c2VyOnNvbWUtcGFzc3dvcmQ=", + expectedErr: nil, + expectedUser: "some-user", + expectedPassword: "some-password", + }), + Entry("Base64(no-password:)", getBasicAuthCredentialsTableInput{ + token: "bm8tcGFzc3dvcmQ6", + expectedErr: nil, + expectedUser: "no-password", + expectedPassword: "", + }), + Entry("Base64(:no-user)", getBasicAuthCredentialsTableInput{ + token: "Om5vLXVzZXI=", + expectedErr: nil, + expectedUser: "", + expectedPassword: "no-user", + }), + ) + }) +})