1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-21 21:47:11 +02:00

Merge pull request #1047 from oauth2-proxy/http-server

Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers
This commit is contained in:
Joel Speed 2021-03-07 21:12:02 +00:00 committed by GitHub
commit 6894738d97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1258 additions and 248 deletions

View File

@ -8,6 +8,7 @@
## Changes since v7.0.1
- [#1047](https://github.com/oauth2-proxy/oauth2-proxy/pull/1047) Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers (@JoelSpeed)
- [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves)
- [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich)
- [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed)

View File

@ -117,6 +117,8 @@ They may change between releases without notice.
| `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.<br/>Once a user is authenticated, requests to the server will be proxied to<br/>these upstream servers based on the path mappings defined in this list. |
| `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added<br/>to requests to upstream servers.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. |
| `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added<br/>to responses from the proxy.<br/>This is typically used when using the proxy as an external authentication<br/>provider in conjunction with another proxy such as NGINX and its<br/>auth_request module.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. |
| `server` | _[Server](#server)_ | Server is used to configure the HTTP(S) server for the proxy application.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. |
| `metricsServer` | _[Server](#server)_ | MetricsServer is used to configure the HTTP(S) server for metrics.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. |
### ClaimSource
@ -172,7 +174,7 @@ make up the header value
### SecretSource
(**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue))
(**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue), [TLS](#tls))
SecretSource references an individual secret value.
Only one source within the struct should be defined at any time.
@ -183,6 +185,29 @@ Only one source within the struct should be defined at any time.
| `fromEnv` | _string_ | FromEnv expects the name of an environment variable. |
| `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. |
### Server
(**Appears on:** [AlphaOptions](#alphaoptions))
Server represents the configuration for an HTTP(S) server
| Field | Type | Description |
| ----- | ---- | ----------- |
| `BindAddress` | _string_ | BindAddress is the the address on which to serve traffic.<br/>Leave blank or set to "-" to disable. |
| `SecureBindAddress` | _string_ | SecureBindAddress is the the address on which to serve secure traffic.<br/>Leave blank or set to "-" to disable. |
| `TLS` | _[TLS](#tls)_ | TLS contains the information for loading the certificate and key for the<br/>secure traffic. |
### TLS
(**Appears on:** [Server](#server))
TLS contains the information for loading a TLS certifcate and key.
| Field | Type | Description |
| ----- | ---- | ----------- |
| `Key` | _[SecretSource](#secretsource)_ | Key is the the TLS key data to use.<br/>Typically this will come from a file. |
| `Cert` | _[SecretSource](#secretsource)_ | Cert is the TLS certificate data to use.<br/>Typically this will come from a file. |
### Upstream
(**Appears on:** [Upstreams](#upstreams))

1
go.mod
View File

@ -30,6 +30,7 @@ require (
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/net v0.0.0-20200707034311-ab3426394381
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
google.golang.org/api v0.20.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/square/go-jose.v2 v2.4.1

3
go.sum
View File

@ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

136
http.go
View File

@ -1,136 +0,0 @@
package main
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)
// Server represents an HTTP server
type Server struct {
Handler http.Handler
Opts *options.Options
stop chan struct{} // channel for waiting shutdown
}
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
func (s *Server) ListenAndServe() {
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
s.ServeHTTPS()
} else {
s.ServeHTTP()
}
}
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
func (s *Server) ServeHTTP() {
HTTPAddress := s.Opts.HTTPAddress
var scheme string
i := strings.Index(HTTPAddress, "://")
if i > -1 {
scheme = HTTPAddress[0:i]
}
var networkType string
switch scheme {
case "", "http":
networkType = "tcp"
default:
networkType = scheme
}
slice := strings.SplitN(HTTPAddress, "//", 2)
listenAddr := slice[len(slice)-1]
listener, err := net.Listen(networkType, listenAddr)
if err != nil {
logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
}
logger.Printf("HTTP: listening on %s", listenAddr)
s.serve(listener)
logger.Printf("HTTP: closing %s", listener.Addr())
}
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
func (s *Server) ServeHTTPS() {
addr := s.Opts.HTTPSAddress
config := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
}
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1"}
}
var err error
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
if err != nil {
logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
}
ln, err := net.Listen("tcp", addr)
if err != nil {
logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
}
logger.Printf("HTTPS: listening on %s", ln.Addr())
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
s.serve(tlsListener)
logger.Printf("HTTPS: closing %s", tlsListener.Addr())
}
func (s *Server) serve(listener net.Listener) {
srv := &http.Server{Handler: s.Handler}
// See https://golang.org/pkg/net/http/#Server.Shutdown
idleConnsClosed := make(chan struct{})
go func() {
<-s.stop // wait notification for stopping server
// We received an interrupt signal, shut down.
if err := srv.Shutdown(context.Background()); err != nil {
// Error from closing listeners, or context timeout:
logger.Printf("HTTP server Shutdown: %v", err)
}
close(idleConnsClosed)
}()
err := srv.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Errorf("ERROR: http.Serve() - %s", err)
}
<-idleConnsClosed
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
err = tc.SetKeepAlive(true)
if err != nil {
logger.Printf("Error setting Keep-Alive: %v", err)
}
err = tc.SetKeepAlivePeriod(3 * time.Minute)
if err != nil {
logger.Printf("Error setting Keep-Alive period: %v", err)
}
return tc, nil
}

View File

@ -1,39 +0,0 @@
package main
import (
"net/http"
"sync"
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/stretchr/testify/assert"
)
func TestGracefulShutdown(t *testing.T) {
opts := options.NewOptions()
stop := make(chan struct{}, 1)
srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.ServeHTTP()
}()
stop <- struct{}{} // emulate catching signals
// An idiomatic for sync.WaitGroup with timeout
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
case <-time.After(1 * time.Second):
t.Fatal("Server should return gracefully but timeout has occurred")
}
assert.Len(t, stop, 0) // check if stop chan is empty
}

54
main.go
View File

@ -1,20 +1,15 @@
package main
import (
"context"
"fmt"
"math/rand"
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"github.com/ghodss/yaml"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
"github.com/spf13/pflag"
)
@ -67,54 +62,9 @@ func main() {
rand.Seed(time.Now().UnixNano())
oauthProxyStop := make(chan struct{}, 1)
metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop)
s := &Server{
Handler: oauthproxy,
Opts: opts,
stop: oauthProxyStop,
if err := oauthproxy.Start(); err != nil {
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
}
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
s.stop <- struct{}{} // notify having caught signal stop oauthproxy
close(metricsStop) // and the metrics endpoint
}()
s.ListenAndServe()
}
// startMetricsServer will start the metrics server on the specified address.
// It always return a channel to signal stop even when it does not run.
func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} {
stop := make(chan struct{}, 1)
// Attempt to setup the metrics endpoint if we have an address
if address != "" {
s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler}
go func() {
// ListenAndServe always returns a non-nil error. After Shutdown or
// Close, the returned error is ErrServerClosed
if err := s.ListenAndServe(); err != http.ErrServerClosed {
logger.Println(err)
// Stop the metrics shutdown go routine
close(stop)
// Stop the oauthproxy server, we have encounter an unexpected error
close(oauthProxyStop)
}
}()
go func() {
<-stop
if err := s.Shutdown(context.Background()); err != nil {
logger.Print(err)
}
}()
}
return stop
}
// loadConfiguration will load in the user's configuration.

View File

@ -15,6 +15,7 @@ import (
var _ = Describe("Configuration Loading Suite", func() {
const testLegacyConfig = `
http_address="127.0.0.1:4180"
upstreams="http://httpbin"
set_basic_auth="true"
basic_auth_password="super-secret-password"
@ -54,10 +55,11 @@ injectResponseHeaders:
prefix: "Basic "
basicAuthPassword:
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
server:
bindAddress: "127.0.0.1:4180"
`
const testCoreConfig = `
http_address="0.0.0.0:4180"
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
provider="oidc"
email_domains="example.com"
@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback"
opts, err := options.NewLegacyOptions().ToOptions()
Expect(err).ToNot(HaveOccurred())
opts.HTTPAddress = "0.0.0.0:4180"
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
opts.ProviderType = "oidc"
opts.EmailDomains = []string{"example.com"}
@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
configContent: testCoreConfig,
alphaConfigContent: testAlphaConfig + ":",
expectedOptions: func() *options.Options { return nil },
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"),
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"),
}),
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
configContent: testCoreConfig + "unknown_field=\"something\"",

View File

@ -8,8 +8,11 @@ import (
"net"
"net/http"
"net/url"
"os"
"os/signal"
"regexp"
"strings"
"syscall"
"time"
"github.com/justinas/alice"
@ -21,6 +24,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
@ -102,6 +106,7 @@ type OAuthProxy struct {
headersChain alice.Chain
preAuthChain alice.Chain
pageWriter pagewriter.Writer
server proxyhttp.Server
}
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, fmt.Errorf("could not build headers chain: %v", err)
}
return &OAuthProxy{
p := &OAuthProxy{
CookieName: opts.Cookie.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
CookieSeed: opts.Cookie.Secret,
@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
}, nil
}
if err := p.setupServer(opts); err != nil {
return nil, fmt.Errorf("error setting up server: %v", err)
}
return p, nil
}
func (p *OAuthProxy) Start() error {
if p.server == nil {
// We have to call setupServer before Start is called.
// If this doesn't happen it's a programming error.
panic("server has not been initialised")
}
ctx, cancel := context.WithCancel(context.Background())
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
cancel() // cancel the context
}()
return p.server.Start(ctx)
}
func (p *OAuthProxy) setupServer(opts *options.Options) error {
serverOpts := proxyhttp.Opts{
Handler: p,
BindAddress: opts.Server.BindAddress,
SecureBindAddress: opts.Server.SecureBindAddress,
TLS: opts.Server.TLS,
}
appServer, err := proxyhttp.NewServer(serverOpts)
if err != nil {
return fmt.Errorf("could not build app server: %v", err)
}
metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{
Handler: middleware.DefaultMetricsHandler,
BindAddress: opts.MetricsServer.BindAddress,
SecureBindAddress: opts.MetricsServer.BindAddress,
TLS: opts.MetricsServer.TLS,
})
if err != nil {
return fmt.Errorf("could not build metrics server: %v", err)
}
p.server = proxyhttp.NewServerGroup(appServer, metricsServer)
return nil
}
// buildPreAuthChain constructs a chain that should process every request before
@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
if opts.ForceHTTPS {
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
_, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress)
if err != nil {
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err)
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err)
}
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
}

View File

@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options {
},
},
}
return opts
}

View File

@ -28,6 +28,18 @@ type AlphaOptions struct {
// Headers may source values from either the authenticated user's session
// or from a static secret value.
InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"`
// Server is used to configure the HTTP(S) server for the proxy application.
// You may choose to run both HTTP and HTTPS servers simultaneously.
// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
// To use the secure server you must configure a TLS certificate and key.
Server Server `json:"server,omitempty"`
// MetricsServer is used to configure the HTTP(S) server for metrics.
// You may choose to run both HTTP and HTTPS servers simultaneously.
// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
// To use the secure server you must configure a TLS certificate and key.
MetricsServer Server `json:"metricsServer,omitempty"`
}
// MergeInto replaces alpha options in the Options struct with the values
@ -36,6 +48,8 @@ func (a *AlphaOptions) MergeInto(opts *Options) {
opts.UpstreamServers = a.Upstreams
opts.InjectRequestHeaders = a.InjectRequestHeaders
opts.InjectResponseHeaders = a.InjectResponseHeaders
opts.Server = a.Server
opts.MetricsServer = a.MetricsServer
}
// ExtractFrom populates the fields in the AlphaOptions with the values from
@ -44,4 +58,6 @@ func (a *AlphaOptions) ExtractFrom(opts *Options) {
a.Upstreams = opts.UpstreamServers
a.InjectRequestHeaders = opts.InjectRequestHeaders
a.InjectResponseHeaders = opts.InjectResponseHeaders
a.Server = opts.Server
a.MetricsServer = opts.MetricsServer
}

View File

@ -18,6 +18,9 @@ type LegacyOptions struct {
// Legacy options for injecting request/response headers
LegacyHeaders LegacyHeaders `cfg:",squash"`
// Legacy options for the server address and TLS
LegacyServer LegacyServer `cfg:",squash"`
Options Options `cfg:",squash"`
}
@ -35,6 +38,11 @@ func NewLegacyOptions() *LegacyOptions {
SkipAuthStripHeaders: true,
},
LegacyServer: LegacyServer{
HTTPAddress: "127.0.0.1:4180",
HTTPSAddress: ":443",
},
Options: *NewOptions(),
}
}
@ -44,6 +52,7 @@ func NewLegacyFlagSet() *pflag.FlagSet {
flagSet.AddFlagSet(legacyUpstreamsFlagSet())
flagSet.AddFlagSet(legacyHeadersFlagSet())
flagSet.AddFlagSet(legacyServerFlagset())
return flagSet
}
@ -56,6 +65,8 @@ func (l *LegacyOptions) ToOptions() (*Options, error) {
l.Options.UpstreamServers = upstreams
l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert()
l.Options.Server, l.Options.MetricsServer = l.LegacyServer.convert()
return &l.Options, nil
}
@ -403,3 +414,69 @@ func getXAuthRequestAccessTokenHeader() Header {
},
}
}
type LegacyServer struct {
MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_address"`
MetricsTLSCertFile string `flag:"metrics-tls-cert-file" cfg:"tls_cert_file"`
MetricsTLSKeyFile string `flag:"metrics-tls-key-file" cfg:"tls_key_file"`
HTTPAddress string `flag:"http-address" cfg:"http_address"`
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
}
func legacyServerFlagset() *pflag.FlagSet {
flagSet := pflag.NewFlagSet("server", pflag.ExitOnError)
flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
flagSet.String("metrics-secure-address", "", "the address /metrics will be served on for HTTPS clients (e.g. \":9100\")")
flagSet.String("metrics-tls-cert-file", "", "path to certificate file for secure metrics server")
flagSet.String("metrics-tls-key-file", "", "path to private key file for secure metrics server")
flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients")
flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients")
flagSet.String("tls-cert-file", "", "path to certificate file")
flagSet.String("tls-key-file", "", "path to private key file")
return flagSet
}
func (l LegacyServer) convert() (Server, Server) {
appServer := Server{
BindAddress: l.HTTPAddress,
SecureBindAddress: l.HTTPSAddress,
}
if l.TLSKeyFile != "" || l.TLSCertFile != "" {
appServer.TLS = &TLS{
Key: &SecretSource{
FromFile: l.TLSKeyFile,
},
Cert: &SecretSource{
FromFile: l.TLSCertFile,
},
}
// Preserve backwards compatibility, only run one server
appServer.BindAddress = ""
} else {
// Disable the HTTPS server if there's no certificates.
// This preserves backwards compatibility.
appServer.SecureBindAddress = ""
}
metricsServer := Server{
BindAddress: l.MetricsAddress,
SecureBindAddress: l.MetricsSecureAddress,
}
if l.MetricsTLSKeyFile != "" || l.MetricsTLSCertFile != "" {
metricsServer.TLS = &TLS{
Key: &SecretSource{
FromFile: l.MetricsTLSKeyFile,
},
Cert: &SecretSource{
FromFile: l.MetricsTLSCertFile,
},
}
}
return appServer, metricsServer
}

View File

@ -106,6 +106,10 @@ var _ = Describe("Legacy Options", func() {
opts.InjectResponseHeaders = []Header{}
opts.Server = Server{
BindAddress: "127.0.0.1:4180",
}
converted, err := legacyOpts.ToOptions()
Expect(err).ToNot(HaveOccurred())
Expect(converted).To(Equal(opts))
@ -759,4 +763,93 @@ var _ = Describe("Legacy Options", func() {
}),
)
})
Context("Legacy Servers", func() {
type legacyServersTableInput struct {
legacyServer LegacyServer
expectedAppServer Server
expectedMetricsServer Server
}
const (
insecureAddr = "127.0.0.1:8080"
insecureMetricsAddr = ":9090"
secureAddr = ":443"
secureMetricsAddr = ":9443"
crtPath = "tls.crt"
keyPath = "tls.key"
)
var tlsConfig = &TLS{
Cert: &SecretSource{
FromFile: crtPath,
},
Key: &SecretSource{
FromFile: keyPath,
},
}
DescribeTable("should convert to app and metrics servers",
func(in legacyServersTableInput) {
appServer, metricsServer := in.legacyServer.convert()
Expect(appServer).To(Equal(in.expectedAppServer))
Expect(metricsServer).To(Equal(in.expectedMetricsServer))
},
Entry("with default options only starts app HTTP server", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
},
expectedAppServer: Server{
BindAddress: insecureAddr,
},
}),
Entry("with TLS options specified only starts app HTTPS server", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
TLSKeyFile: keyPath,
TLSCertFile: crtPath,
},
expectedAppServer: Server{
SecureBindAddress: secureAddr,
TLS: tlsConfig,
},
}),
Entry("with metrics HTTP and HTTPS addresses", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
MetricsAddress: insecureMetricsAddr,
MetricsSecureAddress: secureMetricsAddr,
},
expectedAppServer: Server{
BindAddress: insecureAddr,
},
expectedMetricsServer: Server{
BindAddress: insecureMetricsAddr,
SecureBindAddress: secureMetricsAddr,
},
}),
Entry("with metrics HTTPS and tls cert/key", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
MetricsAddress: insecureMetricsAddr,
MetricsSecureAddress: secureMetricsAddr,
MetricsTLSKeyFile: keyPath,
MetricsTLSCertFile: crtPath,
},
expectedAppServer: Server{
BindAddress: insecureAddr,
},
expectedMetricsServer: Server{
BindAddress: insecureMetricsAddr,
SecureBindAddress: secureMetricsAddr,
TLS: tlsConfig,
},
}),
)
})
})

View File

@ -22,9 +22,6 @@ type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"`
PingPath string `flag:"ping-path" cfg:"ping_path"`
PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"`
MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
HTTPAddress string `flag:"http-address" cfg:"http_address"`
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"`
RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"`
TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"`
@ -33,8 +30,6 @@ type Options struct {
ClientID string `flag:"client-id" cfg:"client_id"`
ClientSecret string `flag:"client-secret" cfg:"client_secret"`
ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file"`
TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"`
@ -68,6 +63,9 @@ type Options struct {
InjectRequestHeaders []Header `cfg:",internal"`
InjectResponseHeaders []Header `cfg:",internal"`
Server Server `cfg:",internal"`
MetricsServer Server `cfg:",internal"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"`
SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"`
@ -136,10 +134,7 @@ func NewOptions() *Options {
return &Options{
ProxyPrefix: "/oauth2",
ProviderType: "google",
MetricsAddress: "",
PingPath: "/ping",
HTTPAddress: "127.0.0.1:4180",
HTTPSAddress: ":443",
RealClientIPHeader: "X-Real-IP",
ForceHTTPS: false,
Cookie: cookieDefaults(),
@ -162,14 +157,10 @@ func NewOptions() *Options {
func NewFlagSet() *pflag.FlagSet {
flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError)
flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients")
flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients")
flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted")
flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)")
flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.")
flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests")
flagSet.String("tls-cert-file", "", "path to certificate file")
flagSet.String("tls-key-file", "", "path to private key file")
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)")
flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods")
@ -204,7 +195,6 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)")
flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks")
flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks")
flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
flagSet.String("session-store-type", "cookie", "the session storage provider to use")
flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)")
flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])")

View File

@ -0,0 +1,27 @@
package options
// Server represents the configuration for an HTTP(S) server
type Server struct {
// BindAddress is the the address on which to serve traffic.
// Leave blank or set to "-" to disable.
BindAddress string
// SecureBindAddress is the the address on which to serve secure traffic.
// Leave blank or set to "-" to disable.
SecureBindAddress string
// TLS contains the information for loading the certificate and key for the
// secure traffic.
TLS *TLS
}
// TLS contains the information for loading a TLS certifcate and key.
type TLS struct {
// Key is the the TLS key data to use.
// Typically this will come from a file.
Key *SecretSource
// Cert is the TLS certificate data to use.
// Typically this will come from a file.
Cert *SecretSource
}

View File

@ -0,0 +1,88 @@
package http
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/http"
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var certData []byte
var certDataSource, keyDataSource options.SecretSource
var client *http.Client
func TestHTTPSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "HTTP")
}
var _ = BeforeSuite(func() {
By("Generating a self-signed cert for TLS tests", func() {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
keyOut := bytes.NewBuffer(nil)
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
Expect(err).ToNot(HaveOccurred())
Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed())
keyDataSource.Value = keyOut.Bytes()
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
Expect(err).ToNot(HaveOccurred())
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"OAuth2 Proxy Test Suite"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
Expect(err).ToNot(HaveOccurred())
certData = certBytes
certOut := bytes.NewBuffer(nil)
Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed())
certDataSource.Value = certOut.Bytes()
})
By("Setting up a http client", func() {
cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value)
Expect(err).ToNot(HaveOccurred())
certificate, err := x509.ParseCertificate(cert.Certificate[0])
Expect(err).ToNot(HaveOccurred())
certpool := x509.NewCertPool()
certpool.AddCert(certificate)
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig.RootCAs = certpool
client = &http.Client{
Transport: transport,
}
})
})

245
pkg/http/server.go Normal file
View File

@ -0,0 +1,245 @@
package http
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"golang.org/x/sync/errgroup"
)
// Server represents an HTTP or HTTPS server.
type Server interface {
// Start blocks and runs the server.
Start(ctx context.Context) error
}
// Opts contains the information required to set up the server.
type Opts struct {
// Handler is the http.Handler to be used to serve http pages by the server.
Handler http.Handler
// BindAddress is the address the HTTP server should listen on.
BindAddress string
// SecureBindAddress is the address the HTTPS server should listen on.
SecureBindAddress string
// TLS is the TLS configuration for the server.
TLS *options.TLS
}
// NewServer creates a new Server from the options given.
func NewServer(opts Opts) (Server, error) {
s := &server{
handler: opts.Handler,
}
if err := s.setupListener(opts); err != nil {
return nil, fmt.Errorf("error setting up listener: %v", err)
}
if err := s.setupTLSListener(opts); err != nil {
return nil, fmt.Errorf("error setting up TLS listener: %v", err)
}
return s, nil
}
// server is an implementation of the Server interface.
type server struct {
handler http.Handler
listener net.Listener
tlsListener net.Listener
}
// setupListener sets the server listener if the HTTP server is enabled.
// The HTTP server can be disabled by setting the BindAddress to "-" or by
// leaving it empty.
func (s *server) setupListener(opts Opts) error {
if opts.BindAddress == "" || opts.BindAddress == "-" {
// No HTTP listener required
return nil
}
networkType := getNetworkScheme(opts.BindAddress)
listenAddr := getListenAddress(opts.BindAddress)
listener, err := net.Listen(networkType, listenAddr)
if err != nil {
return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err)
}
s.listener = listener
return nil
}
// setupTLSListener sets the server TLS listener if the HTTPS server is enabled.
// The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by
// leaving it empty.
func (s *server) setupTLSListener(opts Opts) error {
if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" {
// No HTTPS listener required
return nil
}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
NextProtos: []string{"http/1.1"},
}
if opts.TLS == nil {
return errors.New("no TLS config provided")
}
cert, err := getCertificate(opts.TLS)
if err != nil {
return fmt.Errorf("could not load certificate: %v", err)
}
config.Certificates = []tls.Certificate{cert}
listenAddr := getListenAddress(opts.SecureBindAddress)
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
}
s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
return nil
}
// Start starts the HTTP and HTTPS server if applicable.
// It will block until the context is cancelled.
// If any errors occur, only the first error will be returned.
func (s *server) Start(ctx context.Context) error {
g, groupCtx := errgroup.WithContext(ctx)
if s.listener != nil {
g.Go(func() error {
if err := s.startServer(groupCtx, s.listener); err != nil {
return fmt.Errorf("error starting insecure server: %v", err)
}
return nil
})
}
if s.tlsListener != nil {
g.Go(func() error {
if err := s.startServer(groupCtx, s.tlsListener); err != nil {
return fmt.Errorf("error starting secure server: %v", err)
}
return nil
})
}
return g.Wait()
}
// startServer creates and starts a new server with the given listener.
// When the given context is cancelled the server will be shutdown.
// If any errors occur, only the first error will be returned.
func (s *server) startServer(ctx context.Context, listener net.Listener) error {
srv := &http.Server{Handler: s.handler}
g, groupCtx := errgroup.WithContext(ctx)
g.Go(func() error {
<-groupCtx.Done()
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("error shutting down server: %v", err)
}
return nil
})
g.Go(func() error {
if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("could not start server: %v", err)
}
return nil
})
return g.Wait()
}
// getNetworkScheme gets the scheme for the HTTP server.
func getNetworkScheme(addr string) string {
var scheme string
i := strings.Index(addr, "://")
if i > -1 {
scheme = addr[0:i]
}
switch scheme {
case "", "http":
return "tcp"
default:
return scheme
}
}
// getListenAddress gets the address for the HTTP server.
func getListenAddress(addr string) string {
slice := strings.SplitN(addr, "//", 2)
return slice[len(slice)-1]
}
// getCertificate loads the certificate data from the TLS config.
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
keyData, err := getSecretValue(opts.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err)
}
certData, err := getSecretValue(opts.Cert)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err)
}
cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err)
}
return cert, nil
}
// getSecretValue wraps util.GetSecretValue so that we can return an error if no
// source is provided.
func getSecretValue(src *options.SecretSource) ([]byte, error) {
if src == nil {
return nil, errors.New("no configuration provided")
}
return util.GetSecretValue(src)
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by so that dead TCP connections (e.g. closing laptop
// mid-download) eventually go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
// Accept implements the TCPListener interface.
// It sets the keep alive period to 3 minutes for each connection.
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
err = tc.SetKeepAlive(true)
if err != nil {
logger.Errorf("Error setting Keep-Alive: %v", err)
}
err = tc.SetKeepAlivePeriod(3 * time.Minute)
if err != nil {
logger.Printf("Error setting Keep-Alive period: %v", err)
}
return tc, nil
}

35
pkg/http/server_group.go Normal file
View File

@ -0,0 +1,35 @@
package http
import (
"context"
"golang.org/x/sync/errgroup"
)
// NewServerGroup creates a new Server to start and gracefully stop a collection
// of Servers.
func NewServerGroup(servers ...Server) Server {
return &serverGroup{
servers: servers,
}
}
// serverGroup manages the starting and graceful shutdown of a collection of
// servers.
type serverGroup struct {
servers []Server
}
// Start runs the servers in the server group.
func (s *serverGroup) Start(ctx context.Context) error {
g, groupCtx := errgroup.WithContext(ctx)
for _, server := range s.servers {
srv := server
g.Go(func() error {
return srv.Start(groupCtx)
})
}
return g.Wait()
}

View File

@ -0,0 +1,102 @@
package http
import (
"context"
"errors"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Server Group", func() {
var m1, m2, m3 *mockServer
var ctx context.Context
var cancel context.CancelFunc
var group Server
BeforeEach(func() {
ctx, cancel = context.WithCancel(context.Background())
m1 = newMockServer()
m2 = newMockServer()
m3 = newMockServer()
group = NewServerGroup(m1, m2, m3)
})
AfterEach(func() {
cancel()
})
It("starts each server in the group", func() {
go func() {
defer GinkgoRecover()
Expect(group.Start(ctx)).To(Succeed())
}()
Eventually(m1.started).Should(BeClosed(), "mock server 1 not started")
Eventually(m2.started).Should(BeClosed(), "mock server 2 not started")
Eventually(m3.started).Should(BeClosed(), "mock server 3 not started")
})
It("stop each server in the group when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(group.Start(ctx)).To(Succeed())
}()
cancel()
Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
})
It("stop each server in the group when the an error occurs", func() {
err := errors.New("server error")
go func() {
defer GinkgoRecover()
Expect(group.Start(ctx)).To(MatchError(err))
}()
m2.errors <- err
Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
})
})
// mockServer is used to test the server group can start
// and stop multiple servers simultaneously.
type mockServer struct {
started chan struct{}
startClosed bool
stopped chan struct{}
stopClosed bool
errors chan error
}
func newMockServer() *mockServer {
return &mockServer{
started: make(chan struct{}),
stopped: make(chan struct{}),
errors: make(chan error),
}
}
func (m *mockServer) Start(ctx context.Context) error {
if !m.startClosed {
close(m.started)
m.startClosed = true
}
defer func() {
if !m.stopClosed {
close(m.stopped)
m.stopClosed = true
}
}()
select {
case <-ctx.Done():
return nil
case err := <-m.errors:
return err
}
}

472
pkg/http/server_test.go Normal file
View File

@ -0,0 +1,472 @@
package http
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
const hello = "Hello World!"
var _ = Describe("Server", func() {
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte(hello))
})
Context("NewServer", func() {
type newServerTableInput struct {
opts Opts
expectedErr error
expectHTTPListener bool
expectTLSListener bool
}
DescribeTable("When creating the new server from the options", func(in *newServerTableInput) {
srv, err := NewServer(in.opts)
if in.expectedErr != nil {
Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error())))
Expect(srv).To(BeNil())
return
}
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
Expect(s.listener != nil).To(Equal(in.expectHTTPListener))
if in.expectHTTPListener {
Expect(s.listener.Close()).To(Succeed())
}
Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener))
if in.expectTLSListener {
Expect(s.tlsListener.Close()).To(Succeed())
}
},
Entry("with a valid http bind address", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
},
expectedErr: nil,
expectHTTPListener: true,
expectTLSListener: false,
}),
Entry("with a valid https bind address, with no TLS config", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
},
expectedErr: errors.New("error setting up TLS listener: no TLS config provided"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: true,
}),
Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: true,
expectTLSListener: true,
}),
Entry("with a \"-\" for the bind address", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "-",
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with a \"-\" for the secure bind address", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "-",
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid bind address scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "invalid://127.0.0.1:0",
},
expectedErr: errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid secure bind address scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "invalid://127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: true,
}),
Entry("with an invalid bind address port", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "127.0.0.1:a",
},
expectedErr: errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid secure bind address port", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:a",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid TLS key", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &options.SecretSource{
Value: []byte("invalid"),
},
Cert: &certDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid TLS cert", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &options.SecretSource{
Value: []byte("invalid"),
},
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with no TLS key", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Cert: &certDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with no TLS cert", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "http://127.0.0.1:0",
},
expectedErr: nil,
expectHTTPListener: true,
expectTLSListener: false,
}),
Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "https://127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: true,
}),
)
})
Context("Start", func() {
var srv Server
var ctx context.Context
var cancel context.CancelFunc
BeforeEach(func() {
ctx, cancel = context.WithCancel(context.Background())
})
AfterEach(func() {
cancel()
})
Context("with an http server", func() {
var listenAddr string
BeforeEach(func() {
var err error
srv, err = NewServer(Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
})
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
})
It("Starts the server and serves the handler", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Stops the server when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
_, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(func() error {
_, err := client.Get(listenAddr)
return err
}).Should(HaveOccurred())
})
})
Context("with an https server", func() {
var secureListenAddr string
BeforeEach(func() {
var err error
srv, err = NewServer(Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
})
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
})
It("Starts the server and serves the handler", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Stops the server when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
_, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(func() error {
_, err := client.Get(secureListenAddr)
return err
}).Should(HaveOccurred())
})
It("Serves the certificate provided", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expect(resp.TLS.VerifiedChains).Should(HaveLen(1))
Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1))
Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData))
})
})
Context("with both an http and an https server", func() {
var listenAddr, secureListenAddr string
BeforeEach(func() {
var err error
srv, err = NewServer(Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
})
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
})
It("Starts the server and serves the handler on http", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Starts the server and serves the handler on https", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Stops both servers when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
_, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
_, err = client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(func() error {
_, err := client.Get(listenAddr)
return err
}).Should(HaveOccurred())
Eventually(func() error {
_, err := client.Get(secureListenAddr)
return err
}).Should(HaveOccurred())
})
})
})
Context("getNetworkScheme", func() {
DescribeTable("should return the scheme", func(in, expected string) {
Expect(getNetworkScheme(in)).To(Equal(expected))
},
Entry("with no scheme", "127.0.0.1:0", "tcp"),
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"),
Entry("with a http scheme", "http://192.168.0.1:1", "tcp"),
Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"),
Entry("with a random scheme", "random://10.10.10.10:10", "random"),
)
})
Context("getListenAddress", func() {
DescribeTable("should remove the scheme", func(in, expected string) {
Expect(getListenAddress(in)).To(Equal(expected))
},
Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"),
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"),
Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"),
Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"),
Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"),
)
})
})