1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-29 22:06:58 +02:00

Merge pull request from oauth2-proxy/https-redirect-middleware

Improve Redirect to HTTPs behaviour
This commit is contained in:
Joel Speed 2020-07-03 17:25:24 +01:00 committed by GitHub
commit c4cf15f3e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 222 additions and 70 deletions

@ -8,6 +8,7 @@
## Changes since v6.0.0
- [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed)
- [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed)
- [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed)
- [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed)

18
http.go

@ -9,7 +9,6 @@ import (
"strings"
"time"
"github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
)
@ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
func newRedirectToHTTPS(opts *options.Options) alice.Constructor {
return func(next http.Handler) http.Handler {
return redirectToHTTPS(opts, next)
}
}
func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proto := r.Header.Get("X-Forwarded-Proto")
if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) {
http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect)
}
h.ServeHTTP(w, r)
})
}

@ -2,7 +2,6 @@ package main
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@ -11,56 +10,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestRedirectToHTTPSTrue(t *testing.T) {
opts := options.NewOptions()
opts.ForceHTTPS = true
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
h.ServeHTTP(rw, r)
assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code)
}
func TestRedirectToHTTPSFalse(t *testing.T) {
opts := options.NewOptions()
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
h.ServeHTTP(rw, r)
assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code)
}
func TestRedirectNotWhenHTTPS(t *testing.T) {
opts := options.NewOptions()
opts.ForceHTTPS = true
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
s := httptest.NewTLSServer(h)
defer s.Close()
opts.HTTPSAddress = s.URL
client := s.Client()
res, err := client.Get(s.URL)
if err != nil {
t.Fatalf("request to test server failed with error: %v", err)
}
assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode)
}
func TestGracefulShutdown(t *testing.T) {
opts := options.NewOptions()
stop := make(chan struct{}, 1)

@ -3,6 +3,7 @@ package main
import (
"fmt"
"math/rand"
"net"
"os"
"os/signal"
"runtime"
@ -79,7 +80,11 @@ func main() {
chain := alice.New()
if opts.ForceHTTPS {
chain = chain.Append(newRedirectToHTTPS(opts))
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
if err != nil {
logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
}
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
}
healthCheckPaths := []string{opts.PingPath}

@ -1,6 +1,7 @@
package middleware
import (
"net/http"
"testing"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
@ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware")
}
func testHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte("test"))
})
}

@ -0,0 +1,50 @@
package middleware
import (
"net"
"net/http"
"net/url"
"strings"
"github.com/justinas/alice"
)
const httpsScheme = "https"
// NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect
// HTTP requests to HTTPS
func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
return func(next http.Handler) http.Handler {
return redirectToHTTPS(httpsPort, next)
}
}
// redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS
// if it is not already HTTPS.
// If the request is to a non standard port, the redirection request will be
// to the port from the httpsAddress given.
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
proto := req.Header.Get("X-Forwarded-Proto")
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
// Only care about the connection to us being HTTPS if the proto is empty,
// otherwise the proto is source of truth
next.ServeHTTP(rw, req)
return
}
// Copy the request URL
targetURL, _ := url.Parse(req.URL.String())
// Set the scheme to HTTPS
targetURL.Scheme = httpsScheme
// Overwrite the port if the original request was to a non-standard port
if targetURL.Port() != "" {
// If Port was not empty, this should be fine to ignore the error
host, _, _ := net.SplitHostPort(targetURL.Host)
targetURL.Host = net.JoinHostPort(host, httpsPort)
}
http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect)
})
}

@ -0,0 +1,158 @@
package middleware
import (
"crypto/tls"
"fmt"
"net/http/httptest"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
var _ = Describe("RedirectToHTTPS suite", func() {
const httpsPort = "8443"
var permanentRedirectBody = func(address string) string {
return fmt.Sprintf("<a href=\"%s\">Permanent Redirect</a>.\n\n", address)
}
type requestTableInput struct {
requestString string
useTLS bool
headers map[string]string
expectedStatus int
expectedBody string
expectedLocation string
}
DescribeTable("when serving a request",
func(in *requestTableInput) {
req := httptest.NewRequest("", in.requestString, nil)
for k, v := range in.headers {
req.Header.Add(k, v)
}
if in.useTLS {
req.TLS = &tls.ConnectionState{}
}
rw := httptest.NewRecorder()
handler := NewRedirectToHTTPS(httpsPort)(testHandler())
handler.ServeHTTP(rw, req)
Expect(rw.Code).To(Equal(in.expectedStatus))
Expect(rw.Body.String()).To(Equal(in.expectedBody))
if in.expectedLocation != "" {
Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation))
}
},
Entry("without TLS", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("with TLS", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "HTTPS",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "HTTPS",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "HTTP",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "HTTP",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{
requestString: "https://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "http",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "http",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("without TLS on a non-standard port", &requestTableInput{
requestString: "http://example.com:8080",
useTLS: false,
headers: map[string]string{},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com:8443"),
expectedLocation: "https://example.com:8443",
}),
Entry("with TLS on a non-standard port", &requestTableInput{
requestString: "https://example.com:8443",
useTLS: true,
headers: map[string]string{},
expectedStatus: 200,
expectedBody: "test",
}),
)
})