mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-29 22:06:58 +02:00
Merge pull request #619 from oauth2-proxy/https-redirect-middleware
Improve Redirect to HTTPs behaviour
This commit is contained in:
commit
c4cf15f3e1
@ -8,6 +8,7 @@
|
||||
|
||||
## Changes since v6.0.0
|
||||
|
||||
- [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed)
|
||||
- [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed)
|
||||
- [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed)
|
||||
- [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed)
|
||||
|
18
http.go
18
http.go
@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
)
|
||||
@ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func newRedirectToHTTPS(opts *options.Options) alice.Constructor {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return redirectToHTTPS(opts, next)
|
||||
}
|
||||
}
|
||||
|
||||
func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) {
|
||||
http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
51
http_test.go
51
http_test.go
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -11,56 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRedirectToHTTPSTrue(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
opts.ForceHTTPS = true
|
||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
|
||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
||||
rw := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(rw, r)
|
||||
|
||||
assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code)
|
||||
}
|
||||
|
||||
func TestRedirectToHTTPSFalse(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
|
||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
||||
rw := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(rw, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code)
|
||||
}
|
||||
|
||||
func TestRedirectNotWhenHTTPS(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
opts.ForceHTTPS = true
|
||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
|
||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
||||
s := httptest.NewTLSServer(h)
|
||||
defer s.Close()
|
||||
|
||||
opts.HTTPSAddress = s.URL
|
||||
client := s.Client()
|
||||
res, err := client.Get(s.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("request to test server failed with error: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
stop := make(chan struct{}, 1)
|
||||
|
7
main.go
7
main.go
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
@ -79,7 +80,11 @@ func main() {
|
||||
chain := alice.New()
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
chain = chain.Append(newRedirectToHTTPS(opts))
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
||||
}
|
||||
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||
}
|
||||
|
||||
healthCheckPaths := []string{opts.PingPath}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
@ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware")
|
||||
}
|
||||
|
||||
func testHandler() http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte("test"))
|
||||
})
|
||||
}
|
||||
|
50
pkg/middleware/redirect_to_https.go
Normal file
50
pkg/middleware/redirect_to_https.go
Normal file
@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
)
|
||||
|
||||
const httpsScheme = "https"
|
||||
|
||||
// NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect
|
||||
// HTTP requests to HTTPS
|
||||
func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return redirectToHTTPS(httpsPort, next)
|
||||
}
|
||||
}
|
||||
|
||||
// redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS
|
||||
// if it is not already HTTPS.
|
||||
// If the request is to a non standard port, the redirection request will be
|
||||
// to the port from the httpsAddress given.
|
||||
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
|
||||
// Only care about the connection to us being HTTPS if the proto is empty,
|
||||
// otherwise the proto is source of truth
|
||||
next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the request URL
|
||||
targetURL, _ := url.Parse(req.URL.String())
|
||||
// Set the scheme to HTTPS
|
||||
targetURL.Scheme = httpsScheme
|
||||
|
||||
// Overwrite the port if the original request was to a non-standard port
|
||||
if targetURL.Port() != "" {
|
||||
// If Port was not empty, this should be fine to ignore the error
|
||||
host, _, _ := net.SplitHostPort(targetURL.Host)
|
||||
targetURL.Host = net.JoinHostPort(host, httpsPort)
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect)
|
||||
})
|
||||
}
|
158
pkg/middleware/redirect_to_https_test.go
Normal file
158
pkg/middleware/redirect_to_https_test.go
Normal file
@ -0,0 +1,158 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
const httpsPort = "8443"
|
||||
|
||||
var permanentRedirectBody = func(address string) string {
|
||||
return fmt.Sprintf("<a href=\"%s\">Permanent Redirect</a>.\n\n", address)
|
||||
}
|
||||
|
||||
type requestTableInput struct {
|
||||
requestString string
|
||||
useTLS bool
|
||||
headers map[string]string
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
expectedLocation string
|
||||
}
|
||||
|
||||
DescribeTable("when serving a request",
|
||||
func(in *requestTableInput) {
|
||||
req := httptest.NewRequest("", in.requestString, nil)
|
||||
for k, v := range in.headers {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
if in.useTLS {
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler := NewRedirectToHTTPS(httpsPort)(testHandler())
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
Expect(rw.Code).To(Equal(in.expectedStatus))
|
||||
Expect(rw.Body.String()).To(Equal(in.expectedBody))
|
||||
|
||||
if in.expectedLocation != "" {
|
||||
Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation))
|
||||
}
|
||||
},
|
||||
Entry("without TLS", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("without TLS on a non-standard port", &requestTableInput{
|
||||
requestString: "http://example.com:8080",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com:8443"),
|
||||
expectedLocation: "https://example.com:8443",
|
||||
}),
|
||||
Entry("with TLS on a non-standard port", &requestTableInput{
|
||||
requestString: "https://example.com:8443",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
)
|
||||
})
|
Loading…
x
Reference in New Issue
Block a user