diff --git a/oauthproxy.go b/oauthproxy.go index d11040c3..a1339415 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -919,7 +919,8 @@ func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.Sess // and optional authorization). func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { session, err := p.getAuthenticatedSession(rw, req) - if err != nil { + if err != nil || session == nil { + // If there's no session, or an error retrieving it, we need to return 401 to trigger the OAuth2 flow. http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 0d8bc91a..0eac9dab 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1401,6 +1401,24 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries") } +func TestAuthOnlyEndpointRejectPreflighRequests(t *testing.T) { + skipPreflight := func(opts *options.Options) { + opts.SkipAuthPreflight = true + } + + test, err := NewAuthOnlyEndpointTest("", skipPreflight) + if err != nil { + t.Fatal(err) + } + + test.req.Method = http.MethodOptions + + test.proxy.ServeHTTP(test.rw, test.req) + assert.Equal(t, http.StatusUnauthorized, test.rw.Code) + bodyBytes, _ := io.ReadAll(test.rw.Body) + assert.Equal(t, "Unauthorized\n", string(bodyBytes)) +} + func TestAuthSkippedForPreflightRequests(t *testing.T) { upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -3057,14 +3075,14 @@ func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) { method: "OPTIONS", ip: "1.2.3.5:43670", withSession: false, - expectedStatusCode: http.StatusAccepted, + expectedStatusCode: http.StatusUnauthorized, }, { name: "UserWithoutSessionTrustedIp", method: "GET", ip: "1.2.3.4:43670", withSession: false, - expectedStatusCode: http.StatusAccepted, + expectedStatusCode: http.StatusUnauthorized, }, }