diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7ba0dd4e..02fb5e5a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@
## Changes since v7.0.1
+- [#1047](https://github.com/oauth2-proxy/oauth2-proxy/pull/1047) Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers (@JoelSpeed)
- [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves)
- [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich)
- [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed)
diff --git a/docs/docs/configuration/alpha_config.md b/docs/docs/configuration/alpha_config.md
index a17d3515..1fe3f1f7 100644
--- a/docs/docs/configuration/alpha_config.md
+++ b/docs/docs/configuration/alpha_config.md
@@ -117,6 +117,8 @@ They may change between releases without notice.
| `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.
Once a user is authenticated, requests to the server will be proxied to
these upstream servers based on the path mappings defined in this list. |
| `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added
to requests to upstream servers.
Headers may source values from either the authenticated user's session
or from a static secret value. |
| `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added
to responses from the proxy.
This is typically used when using the proxy as an external authentication
provider in conjunction with another proxy such as NGINX and its
auth_request module.
Headers may source values from either the authenticated user's session
or from a static secret value. |
+| `server` | _[Server](#server)_ | Server is used to configure the HTTP(S) server for the proxy application.
You may choose to run both HTTP and HTTPS servers simultaneously.
This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
To use the secure server you must configure a TLS certificate and key. |
+| `metricsServer` | _[Server](#server)_ | MetricsServer is used to configure the HTTP(S) server for metrics.
You may choose to run both HTTP and HTTPS servers simultaneously.
This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
To use the secure server you must configure a TLS certificate and key. |
### ClaimSource
@@ -172,7 +174,7 @@ make up the header value
### SecretSource
-(**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue))
+(**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue), [TLS](#tls))
SecretSource references an individual secret value.
Only one source within the struct should be defined at any time.
@@ -183,6 +185,29 @@ Only one source within the struct should be defined at any time.
| `fromEnv` | _string_ | FromEnv expects the name of an environment variable. |
| `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. |
+### Server
+
+(**Appears on:** [AlphaOptions](#alphaoptions))
+
+Server represents the configuration for an HTTP(S) server
+
+| Field | Type | Description |
+| ----- | ---- | ----------- |
+| `BindAddress` | _string_ | BindAddress is the the address on which to serve traffic.
Leave blank or set to "-" to disable. |
+| `SecureBindAddress` | _string_ | SecureBindAddress is the the address on which to serve secure traffic.
Leave blank or set to "-" to disable. |
+| `TLS` | _[TLS](#tls)_ | TLS contains the information for loading the certificate and key for the
secure traffic. |
+
+### TLS
+
+(**Appears on:** [Server](#server))
+
+TLS contains the information for loading a TLS certifcate and key.
+
+| Field | Type | Description |
+| ----- | ---- | ----------- |
+| `Key` | _[SecretSource](#secretsource)_ | Key is the the TLS key data to use.
Typically this will come from a file. |
+| `Cert` | _[SecretSource](#secretsource)_ | Cert is the TLS certificate data to use.
Typically this will come from a file. |
+
### Upstream
(**Appears on:** [Upstreams](#upstreams))
diff --git a/go.mod b/go.mod
index 3b56cac8..b6ce05cd 100644
--- a/go.mod
+++ b/go.mod
@@ -30,6 +30,7 @@ require (
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/net v0.0.0-20200707034311-ab3426394381
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
+ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
google.golang.org/api v0.20.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/square/go-jose.v2 v2.4.1
diff --git a/go.sum b/go.sum
index 0c94e0ce..0b129b72 100644
--- a/go.sum
+++ b/go.sum
@@ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs=
+golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
diff --git a/http.go b/http.go
deleted file mode 100644
index 34700380..00000000
--- a/http.go
+++ /dev/null
@@ -1,136 +0,0 @@
-package main
-
-import (
- "context"
- "crypto/tls"
- "errors"
- "net"
- "net/http"
- "strings"
- "time"
-
- "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
- "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
-)
-
-// Server represents an HTTP server
-type Server struct {
- Handler http.Handler
- Opts *options.Options
- stop chan struct{} // channel for waiting shutdown
-}
-
-// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
-func (s *Server) ListenAndServe() {
- if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
- s.ServeHTTPS()
- } else {
- s.ServeHTTP()
- }
-}
-
-// ServeHTTP constructs a net.Listener and starts handling HTTP requests
-func (s *Server) ServeHTTP() {
- HTTPAddress := s.Opts.HTTPAddress
- var scheme string
-
- i := strings.Index(HTTPAddress, "://")
- if i > -1 {
- scheme = HTTPAddress[0:i]
- }
-
- var networkType string
- switch scheme {
- case "", "http":
- networkType = "tcp"
- default:
- networkType = scheme
- }
-
- slice := strings.SplitN(HTTPAddress, "//", 2)
- listenAddr := slice[len(slice)-1]
-
- listener, err := net.Listen(networkType, listenAddr)
- if err != nil {
- logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
- }
- logger.Printf("HTTP: listening on %s", listenAddr)
- s.serve(listener)
- logger.Printf("HTTP: closing %s", listener.Addr())
-}
-
-// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
-func (s *Server) ServeHTTPS() {
- addr := s.Opts.HTTPSAddress
- config := &tls.Config{
- MinVersion: tls.VersionTLS12,
- MaxVersion: tls.VersionTLS13,
- }
- if config.NextProtos == nil {
- config.NextProtos = []string{"http/1.1"}
- }
-
- var err error
- config.Certificates = make([]tls.Certificate, 1)
- config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
- if err != nil {
- logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
- }
-
- ln, err := net.Listen("tcp", addr)
- if err != nil {
- logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
- }
- logger.Printf("HTTPS: listening on %s", ln.Addr())
-
- tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
- s.serve(tlsListener)
- logger.Printf("HTTPS: closing %s", tlsListener.Addr())
-}
-
-func (s *Server) serve(listener net.Listener) {
- srv := &http.Server{Handler: s.Handler}
-
- // See https://golang.org/pkg/net/http/#Server.Shutdown
- idleConnsClosed := make(chan struct{})
- go func() {
- <-s.stop // wait notification for stopping server
-
- // We received an interrupt signal, shut down.
- if err := srv.Shutdown(context.Background()); err != nil {
- // Error from closing listeners, or context timeout:
- logger.Printf("HTTP server Shutdown: %v", err)
- }
- close(idleConnsClosed)
- }()
-
- err := srv.Serve(listener)
- if err != nil && !errors.Is(err, http.ErrServerClosed) {
- logger.Errorf("ERROR: http.Serve() - %s", err)
- }
- <-idleConnsClosed
-}
-
-// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
-// connections. It's used by ListenAndServe and ListenAndServeTLS so
-// dead TCP connections (e.g. closing laptop mid-download) eventually
-// go away.
-type tcpKeepAliveListener struct {
- *net.TCPListener
-}
-
-func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
- tc, err := ln.AcceptTCP()
- if err != nil {
- return nil, err
- }
- err = tc.SetKeepAlive(true)
- if err != nil {
- logger.Printf("Error setting Keep-Alive: %v", err)
- }
- err = tc.SetKeepAlivePeriod(3 * time.Minute)
- if err != nil {
- logger.Printf("Error setting Keep-Alive period: %v", err)
- }
- return tc, nil
-}
diff --git a/http_test.go b/http_test.go
deleted file mode 100644
index f4e12843..00000000
--- a/http_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package main
-
-import (
- "net/http"
- "sync"
- "testing"
- "time"
-
- "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
- "github.com/stretchr/testify/assert"
-)
-
-func TestGracefulShutdown(t *testing.T) {
- opts := options.NewOptions()
- stop := make(chan struct{}, 1)
- srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
- var wg sync.WaitGroup
- wg.Add(1)
- go func() {
- defer wg.Done()
- srv.ServeHTTP()
- }()
-
- stop <- struct{}{} // emulate catching signals
-
- // An idiomatic for sync.WaitGroup with timeout
- c := make(chan struct{})
- go func() {
- defer close(c)
- wg.Wait()
- }()
- select {
- case <-c:
- case <-time.After(1 * time.Second):
- t.Fatal("Server should return gracefully but timeout has occurred")
- }
-
- assert.Len(t, stop, 0) // check if stop chan is empty
-}
diff --git a/main.go b/main.go
index 924e875e..97f2c5a0 100644
--- a/main.go
+++ b/main.go
@@ -1,20 +1,15 @@
package main
import (
- "context"
"fmt"
"math/rand"
- "net/http"
"os"
- "os/signal"
"runtime"
- "syscall"
"time"
"github.com/ghodss/yaml"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
- "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
"github.com/spf13/pflag"
)
@@ -67,54 +62,9 @@ func main() {
rand.Seed(time.Now().UnixNano())
- oauthProxyStop := make(chan struct{}, 1)
- metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop)
-
- s := &Server{
- Handler: oauthproxy,
- Opts: opts,
- stop: oauthProxyStop,
+ if err := oauthproxy.Start(); err != nil {
+ logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
}
- // Observe signals in background goroutine.
- go func() {
- sigint := make(chan os.Signal, 1)
- signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
- <-sigint
- s.stop <- struct{}{} // notify having caught signal stop oauthproxy
- close(metricsStop) // and the metrics endpoint
- }()
- s.ListenAndServe()
-}
-
-// startMetricsServer will start the metrics server on the specified address.
-// It always return a channel to signal stop even when it does not run.
-func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} {
- stop := make(chan struct{}, 1)
-
- // Attempt to setup the metrics endpoint if we have an address
- if address != "" {
- s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler}
- go func() {
- // ListenAndServe always returns a non-nil error. After Shutdown or
- // Close, the returned error is ErrServerClosed
- if err := s.ListenAndServe(); err != http.ErrServerClosed {
- logger.Println(err)
- // Stop the metrics shutdown go routine
- close(stop)
- // Stop the oauthproxy server, we have encounter an unexpected error
- close(oauthProxyStop)
- }
- }()
-
- go func() {
- <-stop
- if err := s.Shutdown(context.Background()); err != nil {
- logger.Print(err)
- }
- }()
- }
-
- return stop
}
// loadConfiguration will load in the user's configuration.
diff --git a/main_test.go b/main_test.go
index a91940e7..3273abfa 100644
--- a/main_test.go
+++ b/main_test.go
@@ -15,6 +15,7 @@ import (
var _ = Describe("Configuration Loading Suite", func() {
const testLegacyConfig = `
+http_address="127.0.0.1:4180"
upstreams="http://httpbin"
set_basic_auth="true"
basic_auth_password="super-secret-password"
@@ -54,10 +55,11 @@ injectResponseHeaders:
prefix: "Basic "
basicAuthPassword:
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
+server:
+ bindAddress: "127.0.0.1:4180"
`
const testCoreConfig = `
-http_address="0.0.0.0:4180"
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
provider="oidc"
email_domains="example.com"
@@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback"
opts, err := options.NewLegacyOptions().ToOptions()
Expect(err).ToNot(HaveOccurred())
- opts.HTTPAddress = "0.0.0.0:4180"
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
opts.ProviderType = "oidc"
opts.EmailDomains = []string{"example.com"}
@@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
configContent: testCoreConfig,
alphaConfigContent: testAlphaConfig + ":",
expectedOptions: func() *options.Options { return nil },
- expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"),
+ expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"),
}),
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
configContent: testCoreConfig + "unknown_field=\"something\"",
diff --git a/oauthproxy.go b/oauthproxy.go
index 43c64525..0a9669f3 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -8,8 +8,11 @@ import (
"net"
"net/http"
"net/url"
+ "os"
+ "os/signal"
"regexp"
"strings"
+ "syscall"
"time"
"github.com/justinas/alice"
@@ -21,6 +24,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
+ proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
@@ -102,6 +106,7 @@ type OAuthProxy struct {
headersChain alice.Chain
preAuthChain alice.Chain
pageWriter pagewriter.Writer
+ server proxyhttp.Server
}
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, fmt.Errorf("could not build headers chain: %v", err)
}
- return &OAuthProxy{
+ p := &OAuthProxy{
CookieName: opts.Cookie.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
CookieSeed: opts.Cookie.Secret,
@@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
- }, nil
+ }
+
+ if err := p.setupServer(opts); err != nil {
+ return nil, fmt.Errorf("error setting up server: %v", err)
+ }
+
+ return p, nil
+}
+
+func (p *OAuthProxy) Start() error {
+ if p.server == nil {
+ // We have to call setupServer before Start is called.
+ // If this doesn't happen it's a programming error.
+ panic("server has not been initialised")
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Observe signals in background goroutine.
+ go func() {
+ sigint := make(chan os.Signal, 1)
+ signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
+ <-sigint
+ cancel() // cancel the context
+ }()
+
+ return p.server.Start(ctx)
+}
+
+func (p *OAuthProxy) setupServer(opts *options.Options) error {
+ serverOpts := proxyhttp.Opts{
+ Handler: p,
+ BindAddress: opts.Server.BindAddress,
+ SecureBindAddress: opts.Server.SecureBindAddress,
+ TLS: opts.Server.TLS,
+ }
+
+ appServer, err := proxyhttp.NewServer(serverOpts)
+ if err != nil {
+ return fmt.Errorf("could not build app server: %v", err)
+ }
+
+ metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{
+ Handler: middleware.DefaultMetricsHandler,
+ BindAddress: opts.MetricsServer.BindAddress,
+ SecureBindAddress: opts.MetricsServer.BindAddress,
+ TLS: opts.MetricsServer.TLS,
+ })
+ if err != nil {
+ return fmt.Errorf("could not build metrics server: %v", err)
+ }
+
+ p.server = proxyhttp.NewServerGroup(appServer, metricsServer)
+ return nil
}
// buildPreAuthChain constructs a chain that should process every request before
@@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
if opts.ForceHTTPS {
- _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
+ _, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress)
if err != nil {
- return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err)
+ return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err)
}
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
}
diff --git a/oauthproxy_test.go b/oauthproxy_test.go
index 06014c9b..a2805c34 100644
--- a/oauthproxy_test.go
+++ b/oauthproxy_test.go
@@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options {
},
},
}
+
return opts
}
diff --git a/pkg/apis/options/alpha_options.go b/pkg/apis/options/alpha_options.go
index 6016bac0..c9a86310 100644
--- a/pkg/apis/options/alpha_options.go
+++ b/pkg/apis/options/alpha_options.go
@@ -28,6 +28,18 @@ type AlphaOptions struct {
// Headers may source values from either the authenticated user's session
// or from a static secret value.
InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"`
+
+ // Server is used to configure the HTTP(S) server for the proxy application.
+ // You may choose to run both HTTP and HTTPS servers simultaneously.
+ // This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
+ // To use the secure server you must configure a TLS certificate and key.
+ Server Server `json:"server,omitempty"`
+
+ // MetricsServer is used to configure the HTTP(S) server for metrics.
+ // You may choose to run both HTTP and HTTPS servers simultaneously.
+ // This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
+ // To use the secure server you must configure a TLS certificate and key.
+ MetricsServer Server `json:"metricsServer,omitempty"`
}
// MergeInto replaces alpha options in the Options struct with the values
@@ -36,6 +48,8 @@ func (a *AlphaOptions) MergeInto(opts *Options) {
opts.UpstreamServers = a.Upstreams
opts.InjectRequestHeaders = a.InjectRequestHeaders
opts.InjectResponseHeaders = a.InjectResponseHeaders
+ opts.Server = a.Server
+ opts.MetricsServer = a.MetricsServer
}
// ExtractFrom populates the fields in the AlphaOptions with the values from
@@ -44,4 +58,6 @@ func (a *AlphaOptions) ExtractFrom(opts *Options) {
a.Upstreams = opts.UpstreamServers
a.InjectRequestHeaders = opts.InjectRequestHeaders
a.InjectResponseHeaders = opts.InjectResponseHeaders
+ a.Server = opts.Server
+ a.MetricsServer = opts.MetricsServer
}
diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go
index d3fabd58..66205de0 100644
--- a/pkg/apis/options/legacy_options.go
+++ b/pkg/apis/options/legacy_options.go
@@ -18,6 +18,9 @@ type LegacyOptions struct {
// Legacy options for injecting request/response headers
LegacyHeaders LegacyHeaders `cfg:",squash"`
+ // Legacy options for the server address and TLS
+ LegacyServer LegacyServer `cfg:",squash"`
+
Options Options `cfg:",squash"`
}
@@ -35,6 +38,11 @@ func NewLegacyOptions() *LegacyOptions {
SkipAuthStripHeaders: true,
},
+ LegacyServer: LegacyServer{
+ HTTPAddress: "127.0.0.1:4180",
+ HTTPSAddress: ":443",
+ },
+
Options: *NewOptions(),
}
}
@@ -44,6 +52,7 @@ func NewLegacyFlagSet() *pflag.FlagSet {
flagSet.AddFlagSet(legacyUpstreamsFlagSet())
flagSet.AddFlagSet(legacyHeadersFlagSet())
+ flagSet.AddFlagSet(legacyServerFlagset())
return flagSet
}
@@ -56,6 +65,8 @@ func (l *LegacyOptions) ToOptions() (*Options, error) {
l.Options.UpstreamServers = upstreams
l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert()
+ l.Options.Server, l.Options.MetricsServer = l.LegacyServer.convert()
+
return &l.Options, nil
}
@@ -403,3 +414,69 @@ func getXAuthRequestAccessTokenHeader() Header {
},
}
}
+
+type LegacyServer struct {
+ MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
+ MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_address"`
+ MetricsTLSCertFile string `flag:"metrics-tls-cert-file" cfg:"tls_cert_file"`
+ MetricsTLSKeyFile string `flag:"metrics-tls-key-file" cfg:"tls_key_file"`
+ HTTPAddress string `flag:"http-address" cfg:"http_address"`
+ HTTPSAddress string `flag:"https-address" cfg:"https_address"`
+ TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
+ TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
+}
+
+func legacyServerFlagset() *pflag.FlagSet {
+ flagSet := pflag.NewFlagSet("server", pflag.ExitOnError)
+
+ flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
+ flagSet.String("metrics-secure-address", "", "the address /metrics will be served on for HTTPS clients (e.g. \":9100\")")
+ flagSet.String("metrics-tls-cert-file", "", "path to certificate file for secure metrics server")
+ flagSet.String("metrics-tls-key-file", "", "path to private key file for secure metrics server")
+ flagSet.String("http-address", "127.0.0.1:4180", "[http://]: or unix:// to listen on for HTTP clients")
+ flagSet.String("https-address", ":443", ": to listen on for HTTPS clients")
+ flagSet.String("tls-cert-file", "", "path to certificate file")
+ flagSet.String("tls-key-file", "", "path to private key file")
+
+ return flagSet
+}
+
+func (l LegacyServer) convert() (Server, Server) {
+ appServer := Server{
+ BindAddress: l.HTTPAddress,
+ SecureBindAddress: l.HTTPSAddress,
+ }
+ if l.TLSKeyFile != "" || l.TLSCertFile != "" {
+ appServer.TLS = &TLS{
+ Key: &SecretSource{
+ FromFile: l.TLSKeyFile,
+ },
+ Cert: &SecretSource{
+ FromFile: l.TLSCertFile,
+ },
+ }
+ // Preserve backwards compatibility, only run one server
+ appServer.BindAddress = ""
+ } else {
+ // Disable the HTTPS server if there's no certificates.
+ // This preserves backwards compatibility.
+ appServer.SecureBindAddress = ""
+ }
+
+ metricsServer := Server{
+ BindAddress: l.MetricsAddress,
+ SecureBindAddress: l.MetricsSecureAddress,
+ }
+ if l.MetricsTLSKeyFile != "" || l.MetricsTLSCertFile != "" {
+ metricsServer.TLS = &TLS{
+ Key: &SecretSource{
+ FromFile: l.MetricsTLSKeyFile,
+ },
+ Cert: &SecretSource{
+ FromFile: l.MetricsTLSCertFile,
+ },
+ }
+ }
+
+ return appServer, metricsServer
+}
diff --git a/pkg/apis/options/legacy_options_test.go b/pkg/apis/options/legacy_options_test.go
index dbac5793..9f397f6f 100644
--- a/pkg/apis/options/legacy_options_test.go
+++ b/pkg/apis/options/legacy_options_test.go
@@ -106,6 +106,10 @@ var _ = Describe("Legacy Options", func() {
opts.InjectResponseHeaders = []Header{}
+ opts.Server = Server{
+ BindAddress: "127.0.0.1:4180",
+ }
+
converted, err := legacyOpts.ToOptions()
Expect(err).ToNot(HaveOccurred())
Expect(converted).To(Equal(opts))
@@ -759,4 +763,93 @@ var _ = Describe("Legacy Options", func() {
}),
)
})
+
+ Context("Legacy Servers", func() {
+ type legacyServersTableInput struct {
+ legacyServer LegacyServer
+ expectedAppServer Server
+ expectedMetricsServer Server
+ }
+
+ const (
+ insecureAddr = "127.0.0.1:8080"
+ insecureMetricsAddr = ":9090"
+ secureAddr = ":443"
+ secureMetricsAddr = ":9443"
+ crtPath = "tls.crt"
+ keyPath = "tls.key"
+ )
+
+ var tlsConfig = &TLS{
+ Cert: &SecretSource{
+ FromFile: crtPath,
+ },
+ Key: &SecretSource{
+ FromFile: keyPath,
+ },
+ }
+
+ DescribeTable("should convert to app and metrics servers",
+ func(in legacyServersTableInput) {
+ appServer, metricsServer := in.legacyServer.convert()
+ Expect(appServer).To(Equal(in.expectedAppServer))
+ Expect(metricsServer).To(Equal(in.expectedMetricsServer))
+ },
+ Entry("with default options only starts app HTTP server", legacyServersTableInput{
+ legacyServer: LegacyServer{
+ HTTPAddress: insecureAddr,
+ HTTPSAddress: secureAddr,
+ },
+ expectedAppServer: Server{
+ BindAddress: insecureAddr,
+ },
+ }),
+ Entry("with TLS options specified only starts app HTTPS server", legacyServersTableInput{
+ legacyServer: LegacyServer{
+ HTTPAddress: insecureAddr,
+ HTTPSAddress: secureAddr,
+ TLSKeyFile: keyPath,
+ TLSCertFile: crtPath,
+ },
+ expectedAppServer: Server{
+ SecureBindAddress: secureAddr,
+ TLS: tlsConfig,
+ },
+ }),
+ Entry("with metrics HTTP and HTTPS addresses", legacyServersTableInput{
+ legacyServer: LegacyServer{
+ HTTPAddress: insecureAddr,
+ HTTPSAddress: secureAddr,
+ MetricsAddress: insecureMetricsAddr,
+ MetricsSecureAddress: secureMetricsAddr,
+ },
+ expectedAppServer: Server{
+ BindAddress: insecureAddr,
+ },
+ expectedMetricsServer: Server{
+ BindAddress: insecureMetricsAddr,
+ SecureBindAddress: secureMetricsAddr,
+ },
+ }),
+ Entry("with metrics HTTPS and tls cert/key", legacyServersTableInput{
+ legacyServer: LegacyServer{
+ HTTPAddress: insecureAddr,
+ HTTPSAddress: secureAddr,
+ MetricsAddress: insecureMetricsAddr,
+ MetricsSecureAddress: secureMetricsAddr,
+ MetricsTLSKeyFile: keyPath,
+ MetricsTLSCertFile: crtPath,
+ },
+ expectedAppServer: Server{
+ BindAddress: insecureAddr,
+ },
+ expectedMetricsServer: Server{
+ BindAddress: insecureMetricsAddr,
+ SecureBindAddress: secureMetricsAddr,
+ TLS: tlsConfig,
+ },
+ }),
+ )
+
+ })
})
diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go
index ef0d090f..34cb75d5 100644
--- a/pkg/apis/options/options.go
+++ b/pkg/apis/options/options.go
@@ -22,9 +22,6 @@ type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"`
PingPath string `flag:"ping-path" cfg:"ping_path"`
PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"`
- MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
- HTTPAddress string `flag:"http-address" cfg:"http_address"`
- HTTPSAddress string `flag:"https-address" cfg:"https_address"`
ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"`
RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"`
TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"`
@@ -33,8 +30,6 @@ type Options struct {
ClientID string `flag:"client-id" cfg:"client_id"`
ClientSecret string `flag:"client-secret" cfg:"client_secret"`
ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file"`
- TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
- TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"`
@@ -68,6 +63,9 @@ type Options struct {
InjectRequestHeaders []Header `cfg:",internal"`
InjectResponseHeaders []Header `cfg:",internal"`
+ Server Server `cfg:",internal"`
+ MetricsServer Server `cfg:",internal"`
+
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"`
SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"`
@@ -136,10 +134,7 @@ func NewOptions() *Options {
return &Options{
ProxyPrefix: "/oauth2",
ProviderType: "google",
- MetricsAddress: "",
PingPath: "/ping",
- HTTPAddress: "127.0.0.1:4180",
- HTTPSAddress: ":443",
RealClientIPHeader: "X-Real-IP",
ForceHTTPS: false,
Cookie: cookieDefaults(),
@@ -162,14 +157,10 @@ func NewOptions() *Options {
func NewFlagSet() *pflag.FlagSet {
flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError)
- flagSet.String("http-address", "127.0.0.1:4180", "[http://]: or unix:// to listen on for HTTP clients")
- flagSet.String("https-address", ":443", ": to listen on for HTTPS clients")
flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted")
flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)")
flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.")
flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests")
- flagSet.String("tls-cert-file", "", "path to certificate file")
- flagSet.String("tls-key-file", "", "path to private key file")
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)")
flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods")
@@ -204,7 +195,6 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. //sign_in)")
flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks")
flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks")
- flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
flagSet.String("session-store-type", "cookie", "the session storage provider to use")
flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)")
flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])")
diff --git a/pkg/apis/options/server.go b/pkg/apis/options/server.go
new file mode 100644
index 00000000..a100d9c5
--- /dev/null
+++ b/pkg/apis/options/server.go
@@ -0,0 +1,27 @@
+package options
+
+// Server represents the configuration for an HTTP(S) server
+type Server struct {
+ // BindAddress is the the address on which to serve traffic.
+ // Leave blank or set to "-" to disable.
+ BindAddress string
+
+ // SecureBindAddress is the the address on which to serve secure traffic.
+ // Leave blank or set to "-" to disable.
+ SecureBindAddress string
+
+ // TLS contains the information for loading the certificate and key for the
+ // secure traffic.
+ TLS *TLS
+}
+
+// TLS contains the information for loading a TLS certifcate and key.
+type TLS struct {
+ // Key is the the TLS key data to use.
+ // Typically this will come from a file.
+ Key *SecretSource
+
+ // Cert is the TLS certificate data to use.
+ // Typically this will come from a file.
+ Cert *SecretSource
+}
diff --git a/pkg/http/http_suite_test.go b/pkg/http/http_suite_test.go
new file mode 100644
index 00000000..13bd56e9
--- /dev/null
+++ b/pkg/http/http_suite_test.go
@@ -0,0 +1,88 @@
+package http
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "math/big"
+ "net"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var certData []byte
+var certDataSource, keyDataSource options.SecretSource
+var client *http.Client
+
+func TestHTTPSuite(t *testing.T) {
+ logger.SetOutput(GinkgoWriter)
+ logger.SetErrOutput(GinkgoWriter)
+
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "HTTP")
+}
+
+var _ = BeforeSuite(func() {
+ By("Generating a self-signed cert for TLS tests", func() {
+ priv, err := rsa.GenerateKey(rand.Reader, 2048)
+ Expect(err).ToNot(HaveOccurred())
+
+ keyOut := bytes.NewBuffer(nil)
+ privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed())
+ keyDataSource.Value = keyOut.Bytes()
+
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+ Expect(err).ToNot(HaveOccurred())
+
+ template := x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{
+ Organization: []string{"OAuth2 Proxy Test Suite"},
+ },
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(time.Hour),
+ IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ }
+
+ certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
+ Expect(err).ToNot(HaveOccurred())
+ certData = certBytes
+
+ certOut := bytes.NewBuffer(nil)
+ Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed())
+ certDataSource.Value = certOut.Bytes()
+ })
+
+ By("Setting up a http client", func() {
+ cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value)
+ Expect(err).ToNot(HaveOccurred())
+
+ certificate, err := x509.ParseCertificate(cert.Certificate[0])
+ Expect(err).ToNot(HaveOccurred())
+
+ certpool := x509.NewCertPool()
+ certpool.AddCert(certificate)
+
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+ transport.TLSClientConfig.RootCAs = certpool
+
+ client = &http.Client{
+ Transport: transport,
+ }
+ })
+})
diff --git a/pkg/http/server.go b/pkg/http/server.go
new file mode 100644
index 00000000..e9a1d248
--- /dev/null
+++ b/pkg/http/server.go
@@ -0,0 +1,245 @@
+package http
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
+ "golang.org/x/sync/errgroup"
+)
+
+// Server represents an HTTP or HTTPS server.
+type Server interface {
+ // Start blocks and runs the server.
+ Start(ctx context.Context) error
+}
+
+// Opts contains the information required to set up the server.
+type Opts struct {
+ // Handler is the http.Handler to be used to serve http pages by the server.
+ Handler http.Handler
+
+ // BindAddress is the address the HTTP server should listen on.
+ BindAddress string
+
+ // SecureBindAddress is the address the HTTPS server should listen on.
+ SecureBindAddress string
+
+ // TLS is the TLS configuration for the server.
+ TLS *options.TLS
+}
+
+// NewServer creates a new Server from the options given.
+func NewServer(opts Opts) (Server, error) {
+ s := &server{
+ handler: opts.Handler,
+ }
+ if err := s.setupListener(opts); err != nil {
+ return nil, fmt.Errorf("error setting up listener: %v", err)
+ }
+ if err := s.setupTLSListener(opts); err != nil {
+ return nil, fmt.Errorf("error setting up TLS listener: %v", err)
+ }
+
+ return s, nil
+}
+
+// server is an implementation of the Server interface.
+type server struct {
+ handler http.Handler
+
+ listener net.Listener
+ tlsListener net.Listener
+}
+
+// setupListener sets the server listener if the HTTP server is enabled.
+// The HTTP server can be disabled by setting the BindAddress to "-" or by
+// leaving it empty.
+func (s *server) setupListener(opts Opts) error {
+ if opts.BindAddress == "" || opts.BindAddress == "-" {
+ // No HTTP listener required
+ return nil
+ }
+
+ networkType := getNetworkScheme(opts.BindAddress)
+ listenAddr := getListenAddress(opts.BindAddress)
+
+ listener, err := net.Listen(networkType, listenAddr)
+ if err != nil {
+ return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err)
+ }
+ s.listener = listener
+
+ return nil
+}
+
+// setupTLSListener sets the server TLS listener if the HTTPS server is enabled.
+// The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by
+// leaving it empty.
+func (s *server) setupTLSListener(opts Opts) error {
+ if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" {
+ // No HTTPS listener required
+ return nil
+ }
+
+ config := &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ MaxVersion: tls.VersionTLS13,
+ NextProtos: []string{"http/1.1"},
+ }
+ if opts.TLS == nil {
+ return errors.New("no TLS config provided")
+ }
+ cert, err := getCertificate(opts.TLS)
+ if err != nil {
+ return fmt.Errorf("could not load certificate: %v", err)
+ }
+ config.Certificates = []tls.Certificate{cert}
+
+ listenAddr := getListenAddress(opts.SecureBindAddress)
+
+ listener, err := net.Listen("tcp", listenAddr)
+ if err != nil {
+ return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
+ }
+
+ s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
+ return nil
+}
+
+// Start starts the HTTP and HTTPS server if applicable.
+// It will block until the context is cancelled.
+// If any errors occur, only the first error will be returned.
+func (s *server) Start(ctx context.Context) error {
+ g, groupCtx := errgroup.WithContext(ctx)
+
+ if s.listener != nil {
+ g.Go(func() error {
+ if err := s.startServer(groupCtx, s.listener); err != nil {
+ return fmt.Errorf("error starting insecure server: %v", err)
+ }
+ return nil
+ })
+ }
+
+ if s.tlsListener != nil {
+ g.Go(func() error {
+ if err := s.startServer(groupCtx, s.tlsListener); err != nil {
+ return fmt.Errorf("error starting secure server: %v", err)
+ }
+ return nil
+ })
+ }
+
+ return g.Wait()
+}
+
+// startServer creates and starts a new server with the given listener.
+// When the given context is cancelled the server will be shutdown.
+// If any errors occur, only the first error will be returned.
+func (s *server) startServer(ctx context.Context, listener net.Listener) error {
+ srv := &http.Server{Handler: s.handler}
+ g, groupCtx := errgroup.WithContext(ctx)
+
+ g.Go(func() error {
+ <-groupCtx.Done()
+
+ if err := srv.Shutdown(context.Background()); err != nil {
+ return fmt.Errorf("error shutting down server: %v", err)
+ }
+ return nil
+ })
+
+ g.Go(func() error {
+ if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ return fmt.Errorf("could not start server: %v", err)
+ }
+ return nil
+ })
+
+ return g.Wait()
+}
+
+// getNetworkScheme gets the scheme for the HTTP server.
+func getNetworkScheme(addr string) string {
+ var scheme string
+ i := strings.Index(addr, "://")
+ if i > -1 {
+ scheme = addr[0:i]
+ }
+
+ switch scheme {
+ case "", "http":
+ return "tcp"
+ default:
+ return scheme
+ }
+}
+
+// getListenAddress gets the address for the HTTP server.
+func getListenAddress(addr string) string {
+ slice := strings.SplitN(addr, "//", 2)
+ return slice[len(slice)-1]
+}
+
+// getCertificate loads the certificate data from the TLS config.
+func getCertificate(opts *options.TLS) (tls.Certificate, error) {
+ keyData, err := getSecretValue(opts.Key)
+ if err != nil {
+ return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err)
+ }
+
+ certData, err := getSecretValue(opts.Cert)
+ if err != nil {
+ return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err)
+ }
+
+ cert, err := tls.X509KeyPair(certData, keyData)
+ if err != nil {
+ return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err)
+ }
+
+ return cert, nil
+}
+
+// getSecretValue wraps util.GetSecretValue so that we can return an error if no
+// source is provided.
+func getSecretValue(src *options.SecretSource) ([]byte, error) {
+ if src == nil {
+ return nil, errors.New("no configuration provided")
+ }
+ return util.GetSecretValue(src)
+}
+
+// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
+// connections. It's used by so that dead TCP connections (e.g. closing laptop
+// mid-download) eventually go away.
+type tcpKeepAliveListener struct {
+ *net.TCPListener
+}
+
+// Accept implements the TCPListener interface.
+// It sets the keep alive period to 3 minutes for each connection.
+func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
+ tc, err := ln.AcceptTCP()
+ if err != nil {
+ return nil, err
+ }
+ err = tc.SetKeepAlive(true)
+ if err != nil {
+ logger.Errorf("Error setting Keep-Alive: %v", err)
+ }
+ err = tc.SetKeepAlivePeriod(3 * time.Minute)
+ if err != nil {
+ logger.Printf("Error setting Keep-Alive period: %v", err)
+ }
+ return tc, nil
+}
diff --git a/pkg/http/server_group.go b/pkg/http/server_group.go
new file mode 100644
index 00000000..261d9837
--- /dev/null
+++ b/pkg/http/server_group.go
@@ -0,0 +1,35 @@
+package http
+
+import (
+ "context"
+
+ "golang.org/x/sync/errgroup"
+)
+
+// NewServerGroup creates a new Server to start and gracefully stop a collection
+// of Servers.
+func NewServerGroup(servers ...Server) Server {
+ return &serverGroup{
+ servers: servers,
+ }
+}
+
+// serverGroup manages the starting and graceful shutdown of a collection of
+// servers.
+type serverGroup struct {
+ servers []Server
+}
+
+// Start runs the servers in the server group.
+func (s *serverGroup) Start(ctx context.Context) error {
+ g, groupCtx := errgroup.WithContext(ctx)
+
+ for _, server := range s.servers {
+ srv := server
+ g.Go(func() error {
+ return srv.Start(groupCtx)
+ })
+ }
+
+ return g.Wait()
+}
diff --git a/pkg/http/server_group_test.go b/pkg/http/server_group_test.go
new file mode 100644
index 00000000..97748b4d
--- /dev/null
+++ b/pkg/http/server_group_test.go
@@ -0,0 +1,102 @@
+package http
+
+import (
+ "context"
+ "errors"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Server Group", func() {
+ var m1, m2, m3 *mockServer
+ var ctx context.Context
+ var cancel context.CancelFunc
+ var group Server
+
+ BeforeEach(func() {
+ ctx, cancel = context.WithCancel(context.Background())
+
+ m1 = newMockServer()
+ m2 = newMockServer()
+ m3 = newMockServer()
+ group = NewServerGroup(m1, m2, m3)
+ })
+
+ AfterEach(func() {
+ cancel()
+ })
+
+ It("starts each server in the group", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(group.Start(ctx)).To(Succeed())
+ }()
+
+ Eventually(m1.started).Should(BeClosed(), "mock server 1 not started")
+ Eventually(m2.started).Should(BeClosed(), "mock server 2 not started")
+ Eventually(m3.started).Should(BeClosed(), "mock server 3 not started")
+ })
+
+ It("stop each server in the group when the context is cancelled", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(group.Start(ctx)).To(Succeed())
+ }()
+
+ cancel()
+ Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
+ Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
+ Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
+ })
+
+ It("stop each server in the group when the an error occurs", func() {
+ err := errors.New("server error")
+ go func() {
+ defer GinkgoRecover()
+ Expect(group.Start(ctx)).To(MatchError(err))
+ }()
+
+ m2.errors <- err
+ Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
+ Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
+ Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
+ })
+})
+
+// mockServer is used to test the server group can start
+// and stop multiple servers simultaneously.
+type mockServer struct {
+ started chan struct{}
+ startClosed bool
+ stopped chan struct{}
+ stopClosed bool
+ errors chan error
+}
+
+func newMockServer() *mockServer {
+ return &mockServer{
+ started: make(chan struct{}),
+ stopped: make(chan struct{}),
+ errors: make(chan error),
+ }
+}
+
+func (m *mockServer) Start(ctx context.Context) error {
+ if !m.startClosed {
+ close(m.started)
+ m.startClosed = true
+ }
+ defer func() {
+ if !m.stopClosed {
+ close(m.stopped)
+ m.stopClosed = true
+ }
+ }()
+ select {
+ case <-ctx.Done():
+ return nil
+ case err := <-m.errors:
+ return err
+ }
+}
diff --git a/pkg/http/server_test.go b/pkg/http/server_test.go
new file mode 100644
index 00000000..a2b995a2
--- /dev/null
+++ b/pkg/http/server_test.go
@@ -0,0 +1,472 @@
+package http
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+
+ "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/ginkgo/extensions/table"
+ . "github.com/onsi/gomega"
+)
+
+const hello = "Hello World!"
+
+var _ = Describe("Server", func() {
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Write([]byte(hello))
+ })
+
+ Context("NewServer", func() {
+ type newServerTableInput struct {
+ opts Opts
+ expectedErr error
+ expectHTTPListener bool
+ expectTLSListener bool
+ }
+
+ DescribeTable("When creating the new server from the options", func(in *newServerTableInput) {
+ srv, err := NewServer(in.opts)
+ if in.expectedErr != nil {
+ Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error())))
+ Expect(srv).To(BeNil())
+ return
+ }
+
+ Expect(err).ToNot(HaveOccurred())
+
+ s, ok := srv.(*server)
+ Expect(ok).To(BeTrue())
+
+ Expect(s.listener != nil).To(Equal(in.expectHTTPListener))
+ if in.expectHTTPListener {
+ Expect(s.listener.Close()).To(Succeed())
+ }
+ Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener))
+ if in.expectTLSListener {
+ Expect(s.tlsListener.Close()).To(Succeed())
+ }
+ },
+ Entry("with a valid http bind address", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ BindAddress: "127.0.0.1:0",
+ },
+ expectedErr: nil,
+ expectHTTPListener: true,
+ expectTLSListener: false,
+ }),
+ Entry("with a valid https bind address, with no TLS config", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:0",
+ },
+ expectedErr: errors.New("error setting up TLS listener: no TLS config provided"),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &certDataSource,
+ },
+ },
+ expectedErr: nil,
+ expectHTTPListener: false,
+ expectTLSListener: true,
+ }),
+ Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ BindAddress: "127.0.0.1:0",
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &certDataSource,
+ },
+ },
+ expectedErr: nil,
+ expectHTTPListener: true,
+ expectTLSListener: true,
+ }),
+ Entry("with a \"-\" for the bind address", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ BindAddress: "-",
+ },
+ expectedErr: nil,
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with a \"-\" for the secure bind address", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "-",
+ },
+ expectedErr: nil,
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with an invalid bind address scheme", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ BindAddress: "invalid://127.0.0.1:0",
+ },
+ expectedErr: errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with an invalid secure bind address scheme", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "invalid://127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &certDataSource,
+ },
+ },
+ expectedErr: nil,
+ expectHTTPListener: false,
+ expectTLSListener: true,
+ }),
+ Entry("with an invalid bind address port", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ BindAddress: "127.0.0.1:a",
+ },
+ expectedErr: errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with an invalid secure bind address port", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:a",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &certDataSource,
+ },
+ },
+ expectedErr: errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with an invalid TLS key", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &options.SecretSource{
+ Value: []byte("invalid"),
+ },
+ Cert: &certDataSource,
+ },
+ },
+ expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with an invalid TLS cert", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &options.SecretSource{
+ Value: []byte("invalid"),
+ },
+ },
+ },
+ expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with no TLS key", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Cert: &certDataSource,
+ },
+ },
+ expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("with no TLS cert", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ },
+ },
+ expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"),
+ expectHTTPListener: false,
+ expectTLSListener: false,
+ }),
+ Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ BindAddress: "http://127.0.0.1:0",
+ },
+ expectedErr: nil,
+ expectHTTPListener: true,
+ expectTLSListener: false,
+ }),
+ Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{
+ opts: Opts{
+ Handler: handler,
+ SecureBindAddress: "https://127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &certDataSource,
+ },
+ },
+ expectedErr: nil,
+ expectHTTPListener: false,
+ expectTLSListener: true,
+ }),
+ )
+ })
+
+ Context("Start", func() {
+ var srv Server
+ var ctx context.Context
+ var cancel context.CancelFunc
+
+ BeforeEach(func() {
+ ctx, cancel = context.WithCancel(context.Background())
+ })
+
+ AfterEach(func() {
+ cancel()
+ })
+
+ Context("with an http server", func() {
+ var listenAddr string
+
+ BeforeEach(func() {
+ var err error
+ srv, err = NewServer(Opts{
+ Handler: handler,
+ BindAddress: "127.0.0.1:0",
+ })
+ Expect(err).ToNot(HaveOccurred())
+
+ s, ok := srv.(*server)
+ Expect(ok).To(BeTrue())
+
+ listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
+ })
+
+ It("Starts the server and serves the handler", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ resp, err := client.Get(listenAddr)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+ body, err := ioutil.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(string(body)).To(Equal(hello))
+ })
+
+ It("Stops the server when the context is cancelled", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ _, err := client.Get(listenAddr)
+ Expect(err).ToNot(HaveOccurred())
+
+ cancel()
+
+ Eventually(func() error {
+ _, err := client.Get(listenAddr)
+ return err
+ }).Should(HaveOccurred())
+ })
+ })
+
+ Context("with an https server", func() {
+ var secureListenAddr string
+
+ BeforeEach(func() {
+ var err error
+ srv, err = NewServer(Opts{
+ Handler: handler,
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &certDataSource,
+ },
+ })
+ Expect(err).ToNot(HaveOccurred())
+
+ s, ok := srv.(*server)
+ Expect(ok).To(BeTrue())
+
+ secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
+ })
+
+ It("Starts the server and serves the handler", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ resp, err := client.Get(secureListenAddr)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+ body, err := ioutil.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(string(body)).To(Equal(hello))
+ })
+
+ It("Stops the server when the context is cancelled", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ _, err := client.Get(secureListenAddr)
+ Expect(err).ToNot(HaveOccurred())
+
+ cancel()
+
+ Eventually(func() error {
+ _, err := client.Get(secureListenAddr)
+ return err
+ }).Should(HaveOccurred())
+ })
+
+ It("Serves the certificate provided", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ resp, err := client.Get(secureListenAddr)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+ Expect(resp.TLS.VerifiedChains).Should(HaveLen(1))
+ Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1))
+ Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData))
+ })
+ })
+
+ Context("with both an http and an https server", func() {
+ var listenAddr, secureListenAddr string
+
+ BeforeEach(func() {
+ var err error
+ srv, err = NewServer(Opts{
+ Handler: handler,
+ BindAddress: "127.0.0.1:0",
+ SecureBindAddress: "127.0.0.1:0",
+ TLS: &options.TLS{
+ Key: &keyDataSource,
+ Cert: &certDataSource,
+ },
+ })
+ Expect(err).ToNot(HaveOccurred())
+
+ s, ok := srv.(*server)
+ Expect(ok).To(BeTrue())
+
+ listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
+ secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
+ })
+
+ It("Starts the server and serves the handler on http", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ resp, err := client.Get(listenAddr)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+ body, err := ioutil.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(string(body)).To(Equal(hello))
+ })
+
+ It("Starts the server and serves the handler on https", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ resp, err := client.Get(secureListenAddr)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+ body, err := ioutil.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(string(body)).To(Equal(hello))
+ })
+
+ It("Stops both servers when the context is cancelled", func() {
+ go func() {
+ defer GinkgoRecover()
+ Expect(srv.Start(ctx)).To(Succeed())
+ }()
+
+ _, err := client.Get(listenAddr)
+ Expect(err).ToNot(HaveOccurred())
+ _, err = client.Get(secureListenAddr)
+ Expect(err).ToNot(HaveOccurred())
+
+ cancel()
+
+ Eventually(func() error {
+ _, err := client.Get(listenAddr)
+ return err
+ }).Should(HaveOccurred())
+ Eventually(func() error {
+ _, err := client.Get(secureListenAddr)
+ return err
+ }).Should(HaveOccurred())
+ })
+ })
+ })
+
+ Context("getNetworkScheme", func() {
+ DescribeTable("should return the scheme", func(in, expected string) {
+ Expect(getNetworkScheme(in)).To(Equal(expected))
+ },
+ Entry("with no scheme", "127.0.0.1:0", "tcp"),
+ Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"),
+ Entry("with a http scheme", "http://192.168.0.1:1", "tcp"),
+ Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"),
+ Entry("with a random scheme", "random://10.10.10.10:10", "random"),
+ )
+ })
+
+ Context("getListenAddress", func() {
+ DescribeTable("should remove the scheme", func(in, expected string) {
+ Expect(getListenAddress(in)).To(Equal(expected))
+ },
+ Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"),
+ Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"),
+ Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"),
+ Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"),
+ Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"),
+ )
+ })
+})