diff --git a/main.go b/main.go index b42c2c05..24d48072 100644 --- a/main.go +++ b/main.go @@ -32,13 +32,19 @@ func main() { return } - opts := options.NewOptions() - err := options.Load(*config, flagSet, opts) + legacyOpts := options.NewLegacyOptions() + err := options.Load(*config, flagSet, legacyOpts) if err != nil { logger.Printf("ERROR: Failed to load config: %v", err) os.Exit(1) } + opts, err := legacyOpts.ToOptions() + if err != nil { + logger.Printf("ERROR: Failed to convert config: %v", err) + os.Exit(1) + } + err = validation.Validate(opts) if err != nil { logger.Printf("%s", err) diff --git a/oauthproxy.go b/oauthproxy.go index 4b310b9b..034fe6a3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/tls" b64 "encoding/base64" "encoding/json" "errors" @@ -10,15 +9,12 @@ import ( "html/template" "net" "net/http" - "net/http/httputil" "net/url" "regexp" - "strconv" "strings" "time" "github.com/coreos/go-oidc" - "github.com/mbland/hmacauth" ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" @@ -28,37 +24,17 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/providers" - "github.com/yhat/wsutil" ) const ( - // SignatureHeader is the name of the request header containing the GAP Signature - // Part of hmacauth - SignatureHeader = "GAP-Signature" - httpScheme = "http" httpsScheme = "https" applicationJSON = "application/json" ) -// SignatureHeaders contains the headers to be signed by the hmac algorithm -// Part of hmacauth -var SignatureHeaders = []string{ - "Content-Length", - "Content-Md5", - "Content-Type", - "Date", - "Authorization", - "X-Forwarded-User", - "X-Forwarded-Email", - "X-Forwarded-Preferred-User", - "X-Forwarded-Access-Token", - "Cookie", - "Gap-Auth", -} - var ( // ErrNeedsLogin means the user should be redirected to the login page ErrNeedsLogin = errors.New("redirect to login page") @@ -124,116 +100,6 @@ type OAuthProxy struct { Footer string } -// UpstreamProxy represents an upstream server to proxy to -type UpstreamProxy struct { - upstream string - handler http.Handler - wsHandler http.Handler - auth hmacauth.HmacAuth -} - -// ServeHTTP proxies requests to the upstream provider while signing the -// request headers -func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Header().Set("GAP-Upstream-Address", u.upstream) - if u.auth != nil { - r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) - u.auth.SignRequest(r) - } - if u.wsHandler != nil && strings.EqualFold(r.Header.Get("Connection"), "upgrade") && r.Header.Get("Upgrade") == "websocket" { - u.wsHandler.ServeHTTP(w, r) - } else { - u.handler.ServeHTTP(w, r) - } - -} - -// NewReverseProxy creates a new reverse proxy for proxying requests to upstream -// servers -func NewReverseProxy(target *url.URL, opts *options.Options) (proxy *httputil.ReverseProxy) { - proxy = httputil.NewSingleHostReverseProxy(target) - proxy.FlushInterval = opts.FlushInterval - if opts.SSLUpstreamInsecureSkipVerify { - proxy.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - } - setProxyErrorHandler(proxy, opts) - return proxy -} - -func setProxyErrorHandler(proxy *httputil.ReverseProxy, opts *options.Options) { - templates := loadTemplates(opts.CustomTemplatesDir) - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, proxyErr error) { - logger.Printf("Error proxying to upstream server: %v", proxyErr) - w.WriteHeader(http.StatusBadGateway) - data := struct { - Title string - Message string - ProxyPrefix string - }{ - Title: "Bad Gateway", - Message: "Error proxying to upstream server", - ProxyPrefix: opts.ProxyPrefix, - } - templates.ExecuteTemplate(w, "error.html", data) - } -} - -func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { - director := proxy.Director - proxy.Director = func(req *http.Request) { - director(req) - // use RequestURI so that we aren't unescaping encoded slashes in the request path - req.Host = target.Host - req.URL.Opaque = req.RequestURI - req.URL.RawQuery = "" - } -} - -func setProxyDirector(proxy *httputil.ReverseProxy) { - director := proxy.Director - proxy.Director = func(req *http.Request) { - director(req) - // use RequestURI so that we aren't unescaping encoded slashes in the request path - req.URL.Opaque = req.RequestURI - req.URL.RawQuery = "" - } -} - -// NewFileServer creates a http.Handler to serve files from the filesystem -func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { - return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) -} - -// NewWebSocketOrRestReverseProxy creates a reverse proxy for REST or websocket based on url -func NewWebSocketOrRestReverseProxy(u *url.URL, opts *options.Options, auth hmacauth.HmacAuth) http.Handler { - u.Path = "" - proxy := NewReverseProxy(u, opts) - if !opts.PassHostHeader { - setProxyUpstreamHostHeader(proxy, u) - } else { - setProxyDirector(proxy) - } - - // this should give us a wss:// scheme if the url is https:// based. - var wsProxy *wsutil.ReverseProxy - if opts.ProxyWebSockets { - wsScheme := "ws" + strings.TrimPrefix(u.Scheme, "http") - wsURL := &url.URL{Scheme: wsScheme, Host: u.Host} - wsProxy = wsutil.NewSingleHostReverseProxy(wsURL) - if opts.SSLUpstreamInsecureSkipVerify { - wsProxy.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - } - return &UpstreamProxy{ - upstream: u.Host, - handler: proxy, - wsHandler: wsProxy, - auth: auth, - } -} - // NewOAuthProxy creates a new instance of OAuthProxy from the options provided func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) @@ -241,48 +107,13 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("error initialising session store: %v", err) } - serveMux := http.NewServeMux() - var auth hmacauth.HmacAuth - if sigData := opts.GetSignatureData(); sigData != nil { - auth = hmacauth.NewHmacAuth(sigData.Hash, []byte(sigData.Key), - SignatureHeader, SignatureHeaders) + templates := loadTemplates(opts.CustomTemplatesDir) + proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) + upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) + if err != nil { + return nil, fmt.Errorf("error initialising upstream proxy: %v", err) } - for _, u := range opts.GetProxyURLs() { - path := u.Path - host := u.Host - switch u.Scheme { - case httpScheme, httpsScheme: - logger.Printf("mapping path %q => upstream %q", path, u) - proxy := NewWebSocketOrRestReverseProxy(u, opts, auth) - serveMux.Handle(path, proxy) - case "static": - responseCode, err := strconv.Atoi(host) - if err != nil { - logger.Printf("unable to convert %q to int, use default \"200\"", host) - responseCode = 200 - } - serveMux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(responseCode) - fmt.Fprintf(rw, "Authenticated") - }) - case "file": - if u.Fragment != "" { - path = u.Fragment - } - logger.Printf("mapping path %q => file system %q", path, u.Path) - proxy := NewFileServer(path, u.Path) - uProxy := UpstreamProxy{ - upstream: path, - handler: proxy, - wsHandler: nil, - auth: nil, - } - serveMux.Handle(path, &uProxy) - default: - panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) - } - } for _, u := range opts.GetCompiledRegex() { logger.Printf("compiled skip-auth-regex => %q", u) } @@ -350,7 +181,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr provider: opts.GetProvider(), providerNameOverride: opts.ProviderName, sessionStore: sessionStore, - serveMux: serveMux, + serveMux: upstreamProxy, redirectURL: redirectURL, whitelistDomains: opts.WhitelistDomains, skipAuthRegex: opts.SkipAuthRegex, @@ -371,7 +202,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr PassAuthorization: opts.PassAuthorization, PreferEmailToUser: opts.PreferEmailToUser, SkipProviderButton: opts.SkipProviderButton, - templates: loadTemplates(opts.CustomTemplatesDir), + templates: templates, trustedIPs: trustedIPs, Banner: opts.Banner, Footer: opts.Footer, diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 40b2d9d1..00c33ff6 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "io/ioutil" - "net" "net/http" "net/http/httptest" "net/url" @@ -24,11 +23,11 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" + "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/net/websocket" ) const ( @@ -44,143 +43,6 @@ func init() { logger.SetFlags(logger.Lshortfile) } -type WebSocketOrRestHandler struct { - restHandler http.Handler - wsHandler http.Handler -} - -func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - h.wsHandler.ServeHTTP(w, r) - } else { - h.restHandler.ServeHTTP(w, r) - } -} - -func TestWebSocketProxy(t *testing.T) { - handler := WebSocketOrRestHandler{ - restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - hostname, _, _ := net.SplitHostPort(r.Host) - _, err := w.Write([]byte(hostname)) - if err != nil { - t.Fatal(err) - } - }), - wsHandler: websocket.Handler(func(ws *websocket.Conn) { - defer func(t *testing.T) { - if err := ws.Close(); err != nil { - t.Fatal(err) - } - }(t) - var data []byte - err := websocket.Message.Receive(ws, &data) - if err != nil { - t.Fatal(err) - } - err = websocket.Message.Send(ws, data) - if err != nil { - t.Fatal(err) - } - }), - } - backend := httptest.NewServer(&handler) - t.Cleanup(backend.Close) - - backendURL, _ := url.Parse(backend.URL) - - opts := baseTestOptions() - var auth hmacauth.HmacAuth - opts.PassHostHeader = true - proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth) - frontend := httptest.NewServer(proxyHandler) - t.Cleanup(frontend.Close) - - frontendURL, _ := url.Parse(frontend.URL) - frontendWSURL := "ws://" + frontendURL.Host + "/" - - ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/") - if err != nil { - t.Fatal(err) - } - request := []byte("hello, world!") - err = websocket.Message.Send(ws, request) - if err != nil { - t.Fatal(err) - } - var response = make([]byte, 1024) - err = websocket.Message.Receive(ws, &response) - if err != nil { - t.Fatal(err) - } - if g, e := string(request), string(response); g != e { - t.Errorf("got body %q; expected %q", g, e) - } - - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - res, _ := http.DefaultClient.Do(getReq) - bodyBytes, _ := ioutil.ReadAll(res.Body) - backendHostname, _, _ := net.SplitHostPort(backendURL.Host) - if g, e := string(bodyBytes), backendHostname; g != e { - t.Errorf("got body %q; expected %q", g, e) - } -} - -func TestNewReverseProxy(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - hostname, _, _ := net.SplitHostPort(r.Host) - _, err := w.Write([]byte(hostname)) - if err != nil { - t.Fatal(err) - } - })) - t.Cleanup(backend.Close) - - backendURL, _ := url.Parse(backend.URL) - backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) - backendHost := net.JoinHostPort(backendHostname, backendPort) - proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") - - proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second}) - setProxyUpstreamHostHeader(proxyHandler, proxyURL) - frontend := httptest.NewServer(proxyHandler) - t.Cleanup(frontend.Close) - - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - res, _ := http.DefaultClient.Do(getReq) - bodyBytes, _ := ioutil.ReadAll(res.Body) - if g, e := string(bodyBytes), backendHostname; g != e { - t.Errorf("got body %q; expected %q", g, e) - } -} - -func TestEncodedSlashes(t *testing.T) { - var seen string - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - seen = r.RequestURI - })) - t.Cleanup(backend.Close) - - b, _ := url.Parse(backend.URL) - proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second}) - setProxyDirector(proxyHandler) - frontend := httptest.NewServer(proxyHandler) - t.Cleanup(frontend.Close) - - f, _ := url.Parse(frontend.URL) - encodedPath := "/a%2Fb/?c=1" - getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} - _, err := http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - if seen != encodedPath { - t.Errorf("got bad request %q expected %q", seen, encodedPath) - } -} - func TestRobotsTxt(t *testing.T) { opts := baseTestOptions() err := validation.Validate(opts) @@ -562,7 +424,14 @@ func TestBasicAuthPassword(t *testing.T) { } })) opts := baseTestOptions() - opts.Upstreams = append(opts.Upstreams, providerServer.URL) + opts.UpstreamServers = options.Upstreams{ + { + ID: providerServer.URL, + Path: "/", + URI: providerServer.URL, + }, + } + opts.Cookie.Secure = false opts.PassBasicAuth = true opts.SetBasicAuth = true @@ -867,7 +736,7 @@ type PassAccessTokenTest struct { type PassAccessTokenTestOptions struct { PassAccessToken bool - ProxyUpstream string + ProxyUpstream options.Upstream } func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) { @@ -893,10 +762,17 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe })) patt.opts = baseTestOptions() - patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL) - if opts.ProxyUpstream != "" { - patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream) + patt.opts.UpstreamServers = options.Upstreams{ + { + ID: patt.providerServer.URL, + Path: "/", + URI: patt.providerServer.URL, + }, } + if opts.ProxyUpstream.ID != "" { + patt.opts.UpstreamServers = append(patt.opts.UpstreamServers, opts.ProxyUpstream) + } + patt.opts.Cookie.Secure = false patt.opts.PassAccessToken = opts.PassAccessToken err := validation.Validate(patt.opts) @@ -999,7 +875,11 @@ func TestForwardAccessTokenUpstream(t *testing.T) { func TestStaticProxyUpstream(t *testing.T) { patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, - ProxyUpstream: "static://200/static-proxy", + ProxyUpstream: options.Upstream{ + ID: "static-proxy", + Path: "/static-proxy", + Static: true, + }, }) if err != nil { t.Fatal(err) @@ -1572,7 +1452,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { t.Cleanup(upstream.Close) opts := baseTestOptions() - opts.Upstreams = append(opts.Upstreams, upstream.URL) + opts.UpstreamServers = options.Upstreams{ + { + ID: upstream.URL, + Path: "/", + URI: upstream.URL, + }, + } opts.SkipAuthPreflight = true err := validation.Validate(opts) assert.NoError(t, err) @@ -1641,7 +1527,13 @@ func NewSignatureTest() (*SignatureTest, error) { if err != nil { return nil, err } - opts.Upstreams = append(opts.Upstreams, upstream.URL) + opts.UpstreamServers = options.Upstreams{ + { + ID: upstream.URL, + Path: "/", + URI: upstream.URL, + }, + } providerHandler := func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte(`{"access_token": "my_auth_token"}`)) @@ -1716,7 +1608,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er } // This is used by the upstream to validate the signature. st.authenticator.auth = hmacauth.NewHmacAuth( - crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) + crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders) proxy.ServeHTTP(st.rw, req) return nil @@ -2110,7 +2002,13 @@ func Test_noCacheHeaders(t *testing.T) { t.Cleanup(upstream.Close) opts := baseTestOptions() - opts.Upstreams = []string{upstream.URL} + opts.UpstreamServers = options.Upstreams{ + { + ID: upstream.URL, + Path: "/", + URI: upstream.URL, + }, + } opts.SkipAuthRegex = []string{".*"} err := validation.Validate(opts) assert.NoError(t, err) @@ -2335,7 +2233,13 @@ func TestTrustedIPs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := baseTestOptions() - opts.Upstreams = []string{"static://200"} + opts.UpstreamServers = options.Upstreams{ + { + ID: "static", + Path: "/", + Static: true, + }, + } opts.TrustedIPs = tt.trustedIPs opts.ReverseProxy = tt.reverseProxy opts.RealClientIPHeader = tt.realClientIPHeader