mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-13 00:07:26 +02:00
Merge pull request #619 from oauth2-proxy/https-redirect-middleware
Improve Redirect to HTTPs behaviour
This commit is contained in:
commit
c4cf15f3e1
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
## Changes since v6.0.0
|
## Changes since v6.0.0
|
||||||
|
|
||||||
|
- [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed)
|
||||||
- [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed)
|
- [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed)
|
||||||
- [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed)
|
- [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed)
|
||||||
- [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed)
|
- [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed)
|
||||||
|
18
http.go
18
http.go
@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
)
|
)
|
||||||
@ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
|||||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||||
return tc, nil
|
return tc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRedirectToHTTPS(opts *options.Options) alice.Constructor {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return redirectToHTTPS(opts, next)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
proto := r.Header.Get("X-Forwarded-Proto")
|
|
||||||
if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) {
|
|
||||||
http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
51
http_test.go
51
http_test.go
@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -11,56 +10,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRedirectToHTTPSTrue(t *testing.T) {
|
|
||||||
opts := options.NewOptions()
|
|
||||||
opts.ForceHTTPS = true
|
|
||||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
w.Write([]byte("test"))
|
|
||||||
}
|
|
||||||
|
|
||||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
r, _ := http.NewRequest("GET", "/", nil)
|
|
||||||
h.ServeHTTP(rw, r)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRedirectToHTTPSFalse(t *testing.T) {
|
|
||||||
opts := options.NewOptions()
|
|
||||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
w.Write([]byte("test"))
|
|
||||||
}
|
|
||||||
|
|
||||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
r, _ := http.NewRequest("GET", "/", nil)
|
|
||||||
h.ServeHTTP(rw, r)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRedirectNotWhenHTTPS(t *testing.T) {
|
|
||||||
opts := options.NewOptions()
|
|
||||||
opts.ForceHTTPS = true
|
|
||||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
w.Write([]byte("test"))
|
|
||||||
}
|
|
||||||
|
|
||||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
|
||||||
s := httptest.NewTLSServer(h)
|
|
||||||
defer s.Close()
|
|
||||||
|
|
||||||
opts.HTTPSAddress = s.URL
|
|
||||||
client := s.Client()
|
|
||||||
res, err := client.Get(s.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("request to test server failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGracefulShutdown(t *testing.T) {
|
func TestGracefulShutdown(t *testing.T) {
|
||||||
opts := options.NewOptions()
|
opts := options.NewOptions()
|
||||||
stop := make(chan struct{}, 1)
|
stop := make(chan struct{}, 1)
|
||||||
|
7
main.go
7
main.go
@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -79,7 +80,11 @@ func main() {
|
|||||||
chain := alice.New()
|
chain := alice.New()
|
||||||
|
|
||||||
if opts.ForceHTTPS {
|
if opts.ForceHTTPS {
|
||||||
chain = chain.Append(newRedirectToHTTPS(opts))
|
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
||||||
|
}
|
||||||
|
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||||
}
|
}
|
||||||
|
|
||||||
healthCheckPaths := []string{opts.PingPath}
|
healthCheckPaths := []string{opts.PingPath}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
@ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) {
|
|||||||
RegisterFailHandler(Fail)
|
RegisterFailHandler(Fail)
|
||||||
RunSpecs(t, "Middleware")
|
RunSpecs(t, "Middleware")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Write([]byte("test"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
50
pkg/middleware/redirect_to_https.go
Normal file
50
pkg/middleware/redirect_to_https.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
|
)
|
||||||
|
|
||||||
|
const httpsScheme = "https"
|
||||||
|
|
||||||
|
// NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect
|
||||||
|
// HTTP requests to HTTPS
|
||||||
|
func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return redirectToHTTPS(httpsPort, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS
|
||||||
|
// if it is not already HTTPS.
|
||||||
|
// If the request is to a non standard port, the redirection request will be
|
||||||
|
// to the port from the httpsAddress given.
|
||||||
|
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
proto := req.Header.Get("X-Forwarded-Proto")
|
||||||
|
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
|
||||||
|
// Only care about the connection to us being HTTPS if the proto is empty,
|
||||||
|
// otherwise the proto is source of truth
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the request URL
|
||||||
|
targetURL, _ := url.Parse(req.URL.String())
|
||||||
|
// Set the scheme to HTTPS
|
||||||
|
targetURL.Scheme = httpsScheme
|
||||||
|
|
||||||
|
// Overwrite the port if the original request was to a non-standard port
|
||||||
|
if targetURL.Port() != "" {
|
||||||
|
// If Port was not empty, this should be fine to ignore the error
|
||||||
|
host, _, _ := net.SplitHostPort(targetURL.Host)
|
||||||
|
targetURL.Host = net.JoinHostPort(host, httpsPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect)
|
||||||
|
})
|
||||||
|
}
|
158
pkg/middleware/redirect_to_https_test.go
Normal file
158
pkg/middleware/redirect_to_https_test.go
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net/http/httptest"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
|
const httpsPort = "8443"
|
||||||
|
|
||||||
|
var permanentRedirectBody = func(address string) string {
|
||||||
|
return fmt.Sprintf("<a href=\"%s\">Permanent Redirect</a>.\n\n", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestTableInput struct {
|
||||||
|
requestString string
|
||||||
|
useTLS bool
|
||||||
|
headers map[string]string
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody string
|
||||||
|
expectedLocation string
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("when serving a request",
|
||||||
|
func(in *requestTableInput) {
|
||||||
|
req := httptest.NewRequest("", in.requestString, nil)
|
||||||
|
for k, v := range in.headers {
|
||||||
|
req.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
if in.useTLS {
|
||||||
|
req.TLS = &tls.ConnectionState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewRedirectToHTTPS(httpsPort)(testHandler())
|
||||||
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
Expect(rw.Code).To(Equal(in.expectedStatus))
|
||||||
|
Expect(rw.Body.String()).To(Equal(in.expectedBody))
|
||||||
|
|
||||||
|
if in.expectedLocation != "" {
|
||||||
|
Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Entry("without TLS", &requestTableInput{
|
||||||
|
requestString: "http://example.com",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{},
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
|
expectedLocation: "https://example.com",
|
||||||
|
}),
|
||||||
|
Entry("with TLS", &requestTableInput{
|
||||||
|
requestString: "https://example.com",
|
||||||
|
useTLS: true,
|
||||||
|
headers: map[string]string{},
|
||||||
|
expectedStatus: 200,
|
||||||
|
expectedBody: "test",
|
||||||
|
}),
|
||||||
|
Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||||
|
requestString: "http://example.com",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "HTTPS",
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
expectedBody: "test",
|
||||||
|
}),
|
||||||
|
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||||
|
requestString: "https://example.com",
|
||||||
|
useTLS: true,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "HTTPS",
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
expectedBody: "test",
|
||||||
|
}),
|
||||||
|
Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{
|
||||||
|
requestString: "http://example.com",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
expectedBody: "test",
|
||||||
|
}),
|
||||||
|
Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{
|
||||||
|
requestString: "https://example.com",
|
||||||
|
useTLS: true,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
expectedBody: "test",
|
||||||
|
}),
|
||||||
|
Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
|
||||||
|
requestString: "http://example.com",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "HTTP",
|
||||||
|
},
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
|
expectedLocation: "https://example.com",
|
||||||
|
}),
|
||||||
|
Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
|
||||||
|
requestString: "https://example.com",
|
||||||
|
useTLS: true,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "HTTP",
|
||||||
|
},
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
|
expectedLocation: "https://example.com",
|
||||||
|
}),
|
||||||
|
Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{
|
||||||
|
requestString: "https://example.com",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "http",
|
||||||
|
},
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
|
expectedLocation: "https://example.com",
|
||||||
|
}),
|
||||||
|
Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{
|
||||||
|
requestString: "https://example.com",
|
||||||
|
useTLS: true,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "http",
|
||||||
|
},
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
|
expectedLocation: "https://example.com",
|
||||||
|
}),
|
||||||
|
Entry("without TLS on a non-standard port", &requestTableInput{
|
||||||
|
requestString: "http://example.com:8080",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{},
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://example.com:8443"),
|
||||||
|
expectedLocation: "https://example.com:8443",
|
||||||
|
}),
|
||||||
|
Entry("with TLS on a non-standard port", &requestTableInput{
|
||||||
|
requestString: "https://example.com:8443",
|
||||||
|
useTLS: true,
|
||||||
|
headers: map[string]string{},
|
||||||
|
expectedStatus: 200,
|
||||||
|
expectedBody: "test",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
Loading…
x
Reference in New Issue
Block a user