mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-08 04:03:58 +02:00
Merge pull request #746 from oauth2-proxy/fix-static
Fix conversion of static responses in upstreams
This commit is contained in:
commit
841bf77f7f
@ -9,6 +9,7 @@
|
||||
## Changes since v6.1.0
|
||||
|
||||
- [#729](https://github.com/oauth2-proxy/oauth2-proxy/pull/729) Use X-Forwarded-Host consistently when set (@NickMeves)
|
||||
- [#746](https://github.com/oauth2-proxy/oauth2-proxy/pull/746) Fix conversion of static responses in upstreams (@JoelSpeed)
|
||||
|
||||
# v6.1.0
|
||||
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
@ -87,6 +88,8 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) {
|
||||
if u.Fragment != "" {
|
||||
upstream.ID = u.Fragment
|
||||
upstream.Path = u.Fragment
|
||||
// Trim the fragment from the end of the URI
|
||||
upstream.URI = strings.SplitN(upstreamString, "#", 2)[0]
|
||||
}
|
||||
case "static":
|
||||
responseCode, err := strconv.Atoi(u.Host)
|
||||
@ -97,17 +100,18 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) {
|
||||
upstream.Static = true
|
||||
upstream.StaticCode = &responseCode
|
||||
|
||||
// These are not allowed to be empty and must be unique
|
||||
// This is not allowed to be empty and must be unique
|
||||
upstream.ID = upstreamString
|
||||
upstream.Path = upstreamString
|
||||
|
||||
// We only support the root path in the legacy config
|
||||
upstream.Path = "/"
|
||||
|
||||
// Force defaults compatible with static responses
|
||||
upstream.URI = ""
|
||||
upstream.InsecureSkipTLSVerify = false
|
||||
upstream.PassHostHeader = nil
|
||||
upstream.ProxyWebSockets = nil
|
||||
flush := 1 * time.Second
|
||||
upstream.FlushInterval = &flush
|
||||
upstream.FlushInterval = nil
|
||||
}
|
||||
|
||||
upstreams = append(upstreams, upstream)
|
||||
|
@ -21,9 +21,10 @@ var _ = Describe("Legacy Options", func() {
|
||||
legacyOpts.LegacyUpstreams.PassHostHeader = true
|
||||
legacyOpts.LegacyUpstreams.ProxyWebSockets = true
|
||||
legacyOpts.LegacyUpstreams.SSLUpstreamInsecureSkipVerify = true
|
||||
legacyOpts.LegacyUpstreams.Upstreams = []string{"http://foo.bar/baz", "file://var/lib/website#/bar"}
|
||||
legacyOpts.LegacyUpstreams.Upstreams = []string{"http://foo.bar/baz", "file:///var/lib/website#/bar", "static://204"}
|
||||
|
||||
truth := true
|
||||
staticCode := 204
|
||||
opts.UpstreamServers = Upstreams{
|
||||
{
|
||||
ID: "/baz",
|
||||
@ -37,12 +38,23 @@ var _ = Describe("Legacy Options", func() {
|
||||
{
|
||||
ID: "/bar",
|
||||
Path: "/bar",
|
||||
URI: "file://var/lib/website#/bar",
|
||||
URI: "file:///var/lib/website",
|
||||
FlushInterval: &flushInterval,
|
||||
InsecureSkipTLSVerify: true,
|
||||
PassHostHeader: &truth,
|
||||
ProxyWebSockets: &truth,
|
||||
},
|
||||
{
|
||||
ID: "static://204",
|
||||
Path: "/",
|
||||
URI: "",
|
||||
Static: true,
|
||||
StaticCode: &staticCode,
|
||||
FlushInterval: nil,
|
||||
InsecureSkipTLSVerify: false,
|
||||
PassHostHeader: nil,
|
||||
ProxyWebSockets: nil,
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := legacyOpts.ToOptions()
|
||||
@ -58,8 +70,6 @@ var _ = Describe("Legacy Options", func() {
|
||||
errMsg string
|
||||
}
|
||||
|
||||
defaultFlushInterval := 1 * time.Second
|
||||
|
||||
// Non defaults for these options
|
||||
skipVerify := true
|
||||
passHostHeader := false
|
||||
@ -90,11 +100,11 @@ var _ = Describe("Legacy Options", func() {
|
||||
FlushInterval: &flushInterval,
|
||||
}
|
||||
|
||||
validFileWithFragment := "file://var/lib/website#/bar"
|
||||
validFileWithFragment := "file:///var/lib/website#/bar"
|
||||
validFileWithFragmentUpstream := Upstream{
|
||||
ID: "/bar",
|
||||
Path: "/bar",
|
||||
URI: validFileWithFragment,
|
||||
URI: "file:///var/lib/website",
|
||||
InsecureSkipTLSVerify: skipVerify,
|
||||
PassHostHeader: &passHostHeader,
|
||||
ProxyWebSockets: &proxyWebSockets,
|
||||
@ -105,28 +115,28 @@ var _ = Describe("Legacy Options", func() {
|
||||
validStaticCode := 204
|
||||
validStaticUpstream := Upstream{
|
||||
ID: validStatic,
|
||||
Path: validStatic,
|
||||
Path: "/",
|
||||
URI: "",
|
||||
Static: true,
|
||||
StaticCode: &validStaticCode,
|
||||
InsecureSkipTLSVerify: false,
|
||||
PassHostHeader: nil,
|
||||
ProxyWebSockets: nil,
|
||||
FlushInterval: &defaultFlushInterval,
|
||||
FlushInterval: nil,
|
||||
}
|
||||
|
||||
invalidStatic := "static://abc"
|
||||
invalidStaticCode := 200
|
||||
invalidStaticUpstream := Upstream{
|
||||
ID: invalidStatic,
|
||||
Path: invalidStatic,
|
||||
Path: "/",
|
||||
URI: "",
|
||||
Static: true,
|
||||
StaticCode: &invalidStaticCode,
|
||||
InsecureSkipTLSVerify: false,
|
||||
PassHostHeader: nil,
|
||||
ProxyWebSockets: nil,
|
||||
FlushInterval: &defaultFlushInterval,
|
||||
FlushInterval: nil,
|
||||
}
|
||||
|
||||
invalidHTTP := ":foo"
|
||||
|
@ -56,6 +56,7 @@ func (m *multiUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request
|
||||
|
||||
// registerStaticResponseHandler registers a static response handler with at the given path.
|
||||
func (m *multiUpstreamProxy) registerStaticResponseHandler(upstream options.Upstream) {
|
||||
logger.Printf("mapping path %q => static response %d", upstream.Path, derefStaticCode(upstream.StaticCode))
|
||||
m.serveMux.Handle(upstream.Path, newStaticResponseHandler(upstream.ID, upstream.StaticCode))
|
||||
}
|
||||
|
||||
|
@ -10,12 +10,8 @@ const defaultStaticResponseCode = 200
|
||||
// newStaticResponseHandler creates a new staticResponseHandler that serves a
|
||||
// a static response code.
|
||||
func newStaticResponseHandler(upstream string, code *int) http.Handler {
|
||||
if code == nil {
|
||||
c := defaultStaticResponseCode
|
||||
code = &c
|
||||
}
|
||||
return &staticResponseHandler{
|
||||
code: *code,
|
||||
code: derefStaticCode(code),
|
||||
upstream: upstream,
|
||||
}
|
||||
}
|
||||
@ -32,3 +28,11 @@ func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Requ
|
||||
rw.WriteHeader(s.code)
|
||||
fmt.Fprintf(rw, "Authenticated")
|
||||
}
|
||||
|
||||
// derefStaticCode returns the derefenced value, or the default if the value is nil
|
||||
func derefStaticCode(code *int) int {
|
||||
if code != nil {
|
||||
return *code
|
||||
}
|
||||
return defaultStaticResponseCode
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user