mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-21 21:47:11 +02:00
Merge pull request #1047 from oauth2-proxy/http-server
Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers
This commit is contained in:
commit
6894738d97
@ -8,6 +8,7 @@
|
||||
|
||||
## Changes since v7.0.1
|
||||
|
||||
- [#1047](https://github.com/oauth2-proxy/oauth2-proxy/pull/1047) Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers (@JoelSpeed)
|
||||
- [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves)
|
||||
- [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich)
|
||||
- [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed)
|
||||
|
@ -117,6 +117,8 @@ They may change between releases without notice.
|
||||
| `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.<br/>Once a user is authenticated, requests to the server will be proxied to<br/>these upstream servers based on the path mappings defined in this list. |
|
||||
| `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added<br/>to requests to upstream servers.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. |
|
||||
| `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added<br/>to responses from the proxy.<br/>This is typically used when using the proxy as an external authentication<br/>provider in conjunction with another proxy such as NGINX and its<br/>auth_request module.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. |
|
||||
| `server` | _[Server](#server)_ | Server is used to configure the HTTP(S) server for the proxy application.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. |
|
||||
| `metricsServer` | _[Server](#server)_ | MetricsServer is used to configure the HTTP(S) server for metrics.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. |
|
||||
|
||||
### ClaimSource
|
||||
|
||||
@ -172,7 +174,7 @@ make up the header value
|
||||
|
||||
### SecretSource
|
||||
|
||||
(**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue))
|
||||
(**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue), [TLS](#tls))
|
||||
|
||||
SecretSource references an individual secret value.
|
||||
Only one source within the struct should be defined at any time.
|
||||
@ -183,6 +185,29 @@ Only one source within the struct should be defined at any time.
|
||||
| `fromEnv` | _string_ | FromEnv expects the name of an environment variable. |
|
||||
| `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. |
|
||||
|
||||
### Server
|
||||
|
||||
(**Appears on:** [AlphaOptions](#alphaoptions))
|
||||
|
||||
Server represents the configuration for an HTTP(S) server
|
||||
|
||||
| Field | Type | Description |
|
||||
| ----- | ---- | ----------- |
|
||||
| `BindAddress` | _string_ | BindAddress is the the address on which to serve traffic.<br/>Leave blank or set to "-" to disable. |
|
||||
| `SecureBindAddress` | _string_ | SecureBindAddress is the the address on which to serve secure traffic.<br/>Leave blank or set to "-" to disable. |
|
||||
| `TLS` | _[TLS](#tls)_ | TLS contains the information for loading the certificate and key for the<br/>secure traffic. |
|
||||
|
||||
### TLS
|
||||
|
||||
(**Appears on:** [Server](#server))
|
||||
|
||||
TLS contains the information for loading a TLS certifcate and key.
|
||||
|
||||
| Field | Type | Description |
|
||||
| ----- | ---- | ----------- |
|
||||
| `Key` | _[SecretSource](#secretsource)_ | Key is the the TLS key data to use.<br/>Typically this will come from a file. |
|
||||
| `Cert` | _[SecretSource](#secretsource)_ | Cert is the TLS certificate data to use.<br/>Typically this will come from a file. |
|
||||
|
||||
### Upstream
|
||||
|
||||
(**Appears on:** [Upstreams](#upstreams))
|
||||
|
1
go.mod
1
go.mod
@ -30,6 +30,7 @@ require (
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
|
||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
|
||||
google.golang.org/api v0.20.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
gopkg.in/square/go-jose.v2 v2.4.1
|
||||
|
3
go.sum
3
go.sum
@ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
136
http.go
136
http.go
@ -1,136 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
)
|
||||
|
||||
// Server represents an HTTP server
|
||||
type Server struct {
|
||||
Handler http.Handler
|
||||
Opts *options.Options
|
||||
stop chan struct{} // channel for waiting shutdown
|
||||
}
|
||||
|
||||
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
|
||||
func (s *Server) ListenAndServe() {
|
||||
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
|
||||
s.ServeHTTPS()
|
||||
} else {
|
||||
s.ServeHTTP()
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
|
||||
func (s *Server) ServeHTTP() {
|
||||
HTTPAddress := s.Opts.HTTPAddress
|
||||
var scheme string
|
||||
|
||||
i := strings.Index(HTTPAddress, "://")
|
||||
if i > -1 {
|
||||
scheme = HTTPAddress[0:i]
|
||||
}
|
||||
|
||||
var networkType string
|
||||
switch scheme {
|
||||
case "", "http":
|
||||
networkType = "tcp"
|
||||
default:
|
||||
networkType = scheme
|
||||
}
|
||||
|
||||
slice := strings.SplitN(HTTPAddress, "//", 2)
|
||||
listenAddr := slice[len(slice)-1]
|
||||
|
||||
listener, err := net.Listen(networkType, listenAddr)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
|
||||
}
|
||||
logger.Printf("HTTP: listening on %s", listenAddr)
|
||||
s.serve(listener)
|
||||
logger.Printf("HTTP: closing %s", listener.Addr())
|
||||
}
|
||||
|
||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
|
||||
func (s *Server) ServeHTTPS() {
|
||||
addr := s.Opts.HTTPSAddress
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
if config.NextProtos == nil {
|
||||
config.NextProtos = []string{"http/1.1"}
|
||||
}
|
||||
|
||||
var err error
|
||||
config.Certificates = make([]tls.Certificate, 1)
|
||||
config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
|
||||
}
|
||||
logger.Printf("HTTPS: listening on %s", ln.Addr())
|
||||
|
||||
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
|
||||
s.serve(tlsListener)
|
||||
logger.Printf("HTTPS: closing %s", tlsListener.Addr())
|
||||
}
|
||||
|
||||
func (s *Server) serve(listener net.Listener) {
|
||||
srv := &http.Server{Handler: s.Handler}
|
||||
|
||||
// See https://golang.org/pkg/net/http/#Server.Shutdown
|
||||
idleConnsClosed := make(chan struct{})
|
||||
go func() {
|
||||
<-s.stop // wait notification for stopping server
|
||||
|
||||
// We received an interrupt signal, shut down.
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
// Error from closing listeners, or context timeout:
|
||||
logger.Printf("HTTP server Shutdown: %v", err)
|
||||
}
|
||||
close(idleConnsClosed)
|
||||
}()
|
||||
|
||||
err := srv.Serve(listener)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Errorf("ERROR: http.Serve() - %s", err)
|
||||
}
|
||||
<-idleConnsClosed
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
// go away.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tc.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
logger.Printf("Error setting Keep-Alive: %v", err)
|
||||
}
|
||||
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
if err != nil {
|
||||
logger.Printf("Error setting Keep-Alive period: %v", err)
|
||||
}
|
||||
return tc, nil
|
||||
}
|
39
http_test.go
39
http_test.go
@ -1,39 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
stop := make(chan struct{}, 1)
|
||||
srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srv.ServeHTTP()
|
||||
}()
|
||||
|
||||
stop <- struct{}{} // emulate catching signals
|
||||
|
||||
// An idiomatic for sync.WaitGroup with timeout
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
defer close(c)
|
||||
wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Server should return gracefully but timeout has occurred")
|
||||
}
|
||||
|
||||
assert.Len(t, stop, 0) // check if stop chan is empty
|
||||
}
|
54
main.go
54
main.go
@ -1,20 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ghodss/yaml"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
@ -67,54 +62,9 @@ func main() {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
oauthProxyStop := make(chan struct{}, 1)
|
||||
metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop)
|
||||
|
||||
s := &Server{
|
||||
Handler: oauthproxy,
|
||||
Opts: opts,
|
||||
stop: oauthProxyStop,
|
||||
if err := oauthproxy.Start(); err != nil {
|
||||
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
|
||||
}
|
||||
// Observe signals in background goroutine.
|
||||
go func() {
|
||||
sigint := make(chan os.Signal, 1)
|
||||
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
||||
<-sigint
|
||||
s.stop <- struct{}{} // notify having caught signal stop oauthproxy
|
||||
close(metricsStop) // and the metrics endpoint
|
||||
}()
|
||||
s.ListenAndServe()
|
||||
}
|
||||
|
||||
// startMetricsServer will start the metrics server on the specified address.
|
||||
// It always return a channel to signal stop even when it does not run.
|
||||
func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} {
|
||||
stop := make(chan struct{}, 1)
|
||||
|
||||
// Attempt to setup the metrics endpoint if we have an address
|
||||
if address != "" {
|
||||
s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler}
|
||||
go func() {
|
||||
// ListenAndServe always returns a non-nil error. After Shutdown or
|
||||
// Close, the returned error is ErrServerClosed
|
||||
if err := s.ListenAndServe(); err != http.ErrServerClosed {
|
||||
logger.Println(err)
|
||||
// Stop the metrics shutdown go routine
|
||||
close(stop)
|
||||
// Stop the oauthproxy server, we have encounter an unexpected error
|
||||
close(oauthProxyStop)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-stop
|
||||
if err := s.Shutdown(context.Background()); err != nil {
|
||||
logger.Print(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return stop
|
||||
}
|
||||
|
||||
// loadConfiguration will load in the user's configuration.
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
var _ = Describe("Configuration Loading Suite", func() {
|
||||
const testLegacyConfig = `
|
||||
http_address="127.0.0.1:4180"
|
||||
upstreams="http://httpbin"
|
||||
set_basic_auth="true"
|
||||
basic_auth_password="super-secret-password"
|
||||
@ -54,10 +55,11 @@ injectResponseHeaders:
|
||||
prefix: "Basic "
|
||||
basicAuthPassword:
|
||||
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
|
||||
server:
|
||||
bindAddress: "127.0.0.1:4180"
|
||||
`
|
||||
|
||||
const testCoreConfig = `
|
||||
http_address="0.0.0.0:4180"
|
||||
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
||||
provider="oidc"
|
||||
email_domains="example.com"
|
||||
@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback"
|
||||
opts, err := options.NewLegacyOptions().ToOptions()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
opts.HTTPAddress = "0.0.0.0:4180"
|
||||
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
||||
opts.ProviderType = "oidc"
|
||||
opts.EmailDomains = []string{"example.com"}
|
||||
@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
|
||||
configContent: testCoreConfig,
|
||||
alphaConfigContent: testAlphaConfig + ":",
|
||||
expectedOptions: func() *options.Options { return nil },
|
||||
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"),
|
||||
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"),
|
||||
}),
|
||||
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
|
||||
configContent: testCoreConfig + "unknown_field=\"something\"",
|
||||
|
@ -8,8 +8,11 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
@ -21,6 +24,7 @@ import (
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||
@ -102,6 +106,7 @@ type OAuthProxy struct {
|
||||
headersChain alice.Chain
|
||||
preAuthChain alice.Chain
|
||||
pageWriter pagewriter.Writer
|
||||
server proxyhttp.Server
|
||||
}
|
||||
|
||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||
@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
return nil, fmt.Errorf("could not build headers chain: %v", err)
|
||||
}
|
||||
|
||||
return &OAuthProxy{
|
||||
p := &OAuthProxy{
|
||||
CookieName: opts.Cookie.Name,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
|
||||
CookieSeed: opts.Cookie.Secret,
|
||||
@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
headersChain: headersChain,
|
||||
preAuthChain: preAuthChain,
|
||||
pageWriter: pageWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := p.setupServer(opts); err != nil {
|
||||
return nil, fmt.Errorf("error setting up server: %v", err)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) Start() error {
|
||||
if p.server == nil {
|
||||
// We have to call setupServer before Start is called.
|
||||
// If this doesn't happen it's a programming error.
|
||||
panic("server has not been initialised")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Observe signals in background goroutine.
|
||||
go func() {
|
||||
sigint := make(chan os.Signal, 1)
|
||||
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
||||
<-sigint
|
||||
cancel() // cancel the context
|
||||
}()
|
||||
|
||||
return p.server.Start(ctx)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) setupServer(opts *options.Options) error {
|
||||
serverOpts := proxyhttp.Opts{
|
||||
Handler: p,
|
||||
BindAddress: opts.Server.BindAddress,
|
||||
SecureBindAddress: opts.Server.SecureBindAddress,
|
||||
TLS: opts.Server.TLS,
|
||||
}
|
||||
|
||||
appServer, err := proxyhttp.NewServer(serverOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not build app server: %v", err)
|
||||
}
|
||||
|
||||
metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{
|
||||
Handler: middleware.DefaultMetricsHandler,
|
||||
BindAddress: opts.MetricsServer.BindAddress,
|
||||
SecureBindAddress: opts.MetricsServer.BindAddress,
|
||||
TLS: opts.MetricsServer.TLS,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not build metrics server: %v", err)
|
||||
}
|
||||
|
||||
p.server = proxyhttp.NewServerGroup(appServer, metricsServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPreAuthChain constructs a chain that should process every request before
|
||||
@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
||||
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
_, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress)
|
||||
if err != nil {
|
||||
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
||||
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err)
|
||||
}
|
||||
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||
}
|
||||
|
@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,18 @@ type AlphaOptions struct {
|
||||
// Headers may source values from either the authenticated user's session
|
||||
// or from a static secret value.
|
||||
InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"`
|
||||
|
||||
// Server is used to configure the HTTP(S) server for the proxy application.
|
||||
// You may choose to run both HTTP and HTTPS servers simultaneously.
|
||||
// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
|
||||
// To use the secure server you must configure a TLS certificate and key.
|
||||
Server Server `json:"server,omitempty"`
|
||||
|
||||
// MetricsServer is used to configure the HTTP(S) server for metrics.
|
||||
// You may choose to run both HTTP and HTTPS servers simultaneously.
|
||||
// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
|
||||
// To use the secure server you must configure a TLS certificate and key.
|
||||
MetricsServer Server `json:"metricsServer,omitempty"`
|
||||
}
|
||||
|
||||
// MergeInto replaces alpha options in the Options struct with the values
|
||||
@ -36,6 +48,8 @@ func (a *AlphaOptions) MergeInto(opts *Options) {
|
||||
opts.UpstreamServers = a.Upstreams
|
||||
opts.InjectRequestHeaders = a.InjectRequestHeaders
|
||||
opts.InjectResponseHeaders = a.InjectResponseHeaders
|
||||
opts.Server = a.Server
|
||||
opts.MetricsServer = a.MetricsServer
|
||||
}
|
||||
|
||||
// ExtractFrom populates the fields in the AlphaOptions with the values from
|
||||
@ -44,4 +58,6 @@ func (a *AlphaOptions) ExtractFrom(opts *Options) {
|
||||
a.Upstreams = opts.UpstreamServers
|
||||
a.InjectRequestHeaders = opts.InjectRequestHeaders
|
||||
a.InjectResponseHeaders = opts.InjectResponseHeaders
|
||||
a.Server = opts.Server
|
||||
a.MetricsServer = opts.MetricsServer
|
||||
}
|
||||
|
@ -18,6 +18,9 @@ type LegacyOptions struct {
|
||||
// Legacy options for injecting request/response headers
|
||||
LegacyHeaders LegacyHeaders `cfg:",squash"`
|
||||
|
||||
// Legacy options for the server address and TLS
|
||||
LegacyServer LegacyServer `cfg:",squash"`
|
||||
|
||||
Options Options `cfg:",squash"`
|
||||
}
|
||||
|
||||
@ -35,6 +38,11 @@ func NewLegacyOptions() *LegacyOptions {
|
||||
SkipAuthStripHeaders: true,
|
||||
},
|
||||
|
||||
LegacyServer: LegacyServer{
|
||||
HTTPAddress: "127.0.0.1:4180",
|
||||
HTTPSAddress: ":443",
|
||||
},
|
||||
|
||||
Options: *NewOptions(),
|
||||
}
|
||||
}
|
||||
@ -44,6 +52,7 @@ func NewLegacyFlagSet() *pflag.FlagSet {
|
||||
|
||||
flagSet.AddFlagSet(legacyUpstreamsFlagSet())
|
||||
flagSet.AddFlagSet(legacyHeadersFlagSet())
|
||||
flagSet.AddFlagSet(legacyServerFlagset())
|
||||
|
||||
return flagSet
|
||||
}
|
||||
@ -56,6 +65,8 @@ func (l *LegacyOptions) ToOptions() (*Options, error) {
|
||||
l.Options.UpstreamServers = upstreams
|
||||
|
||||
l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert()
|
||||
l.Options.Server, l.Options.MetricsServer = l.LegacyServer.convert()
|
||||
|
||||
return &l.Options, nil
|
||||
}
|
||||
|
||||
@ -403,3 +414,69 @@ func getXAuthRequestAccessTokenHeader() Header {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type LegacyServer struct {
|
||||
MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
|
||||
MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_address"`
|
||||
MetricsTLSCertFile string `flag:"metrics-tls-cert-file" cfg:"tls_cert_file"`
|
||||
MetricsTLSKeyFile string `flag:"metrics-tls-key-file" cfg:"tls_key_file"`
|
||||
HTTPAddress string `flag:"http-address" cfg:"http_address"`
|
||||
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
|
||||
TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
|
||||
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
|
||||
}
|
||||
|
||||
func legacyServerFlagset() *pflag.FlagSet {
|
||||
flagSet := pflag.NewFlagSet("server", pflag.ExitOnError)
|
||||
|
||||
flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
|
||||
flagSet.String("metrics-secure-address", "", "the address /metrics will be served on for HTTPS clients (e.g. \":9100\")")
|
||||
flagSet.String("metrics-tls-cert-file", "", "path to certificate file for secure metrics server")
|
||||
flagSet.String("metrics-tls-key-file", "", "path to private key file for secure metrics server")
|
||||
flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients")
|
||||
flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients")
|
||||
flagSet.String("tls-cert-file", "", "path to certificate file")
|
||||
flagSet.String("tls-key-file", "", "path to private key file")
|
||||
|
||||
return flagSet
|
||||
}
|
||||
|
||||
func (l LegacyServer) convert() (Server, Server) {
|
||||
appServer := Server{
|
||||
BindAddress: l.HTTPAddress,
|
||||
SecureBindAddress: l.HTTPSAddress,
|
||||
}
|
||||
if l.TLSKeyFile != "" || l.TLSCertFile != "" {
|
||||
appServer.TLS = &TLS{
|
||||
Key: &SecretSource{
|
||||
FromFile: l.TLSKeyFile,
|
||||
},
|
||||
Cert: &SecretSource{
|
||||
FromFile: l.TLSCertFile,
|
||||
},
|
||||
}
|
||||
// Preserve backwards compatibility, only run one server
|
||||
appServer.BindAddress = ""
|
||||
} else {
|
||||
// Disable the HTTPS server if there's no certificates.
|
||||
// This preserves backwards compatibility.
|
||||
appServer.SecureBindAddress = ""
|
||||
}
|
||||
|
||||
metricsServer := Server{
|
||||
BindAddress: l.MetricsAddress,
|
||||
SecureBindAddress: l.MetricsSecureAddress,
|
||||
}
|
||||
if l.MetricsTLSKeyFile != "" || l.MetricsTLSCertFile != "" {
|
||||
metricsServer.TLS = &TLS{
|
||||
Key: &SecretSource{
|
||||
FromFile: l.MetricsTLSKeyFile,
|
||||
},
|
||||
Cert: &SecretSource{
|
||||
FromFile: l.MetricsTLSCertFile,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return appServer, metricsServer
|
||||
}
|
||||
|
@ -106,6 +106,10 @@ var _ = Describe("Legacy Options", func() {
|
||||
|
||||
opts.InjectResponseHeaders = []Header{}
|
||||
|
||||
opts.Server = Server{
|
||||
BindAddress: "127.0.0.1:4180",
|
||||
}
|
||||
|
||||
converted, err := legacyOpts.ToOptions()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(converted).To(Equal(opts))
|
||||
@ -759,4 +763,93 @@ var _ = Describe("Legacy Options", func() {
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
Context("Legacy Servers", func() {
|
||||
type legacyServersTableInput struct {
|
||||
legacyServer LegacyServer
|
||||
expectedAppServer Server
|
||||
expectedMetricsServer Server
|
||||
}
|
||||
|
||||
const (
|
||||
insecureAddr = "127.0.0.1:8080"
|
||||
insecureMetricsAddr = ":9090"
|
||||
secureAddr = ":443"
|
||||
secureMetricsAddr = ":9443"
|
||||
crtPath = "tls.crt"
|
||||
keyPath = "tls.key"
|
||||
)
|
||||
|
||||
var tlsConfig = &TLS{
|
||||
Cert: &SecretSource{
|
||||
FromFile: crtPath,
|
||||
},
|
||||
Key: &SecretSource{
|
||||
FromFile: keyPath,
|
||||
},
|
||||
}
|
||||
|
||||
DescribeTable("should convert to app and metrics servers",
|
||||
func(in legacyServersTableInput) {
|
||||
appServer, metricsServer := in.legacyServer.convert()
|
||||
Expect(appServer).To(Equal(in.expectedAppServer))
|
||||
Expect(metricsServer).To(Equal(in.expectedMetricsServer))
|
||||
},
|
||||
Entry("with default options only starts app HTTP server", legacyServersTableInput{
|
||||
legacyServer: LegacyServer{
|
||||
HTTPAddress: insecureAddr,
|
||||
HTTPSAddress: secureAddr,
|
||||
},
|
||||
expectedAppServer: Server{
|
||||
BindAddress: insecureAddr,
|
||||
},
|
||||
}),
|
||||
Entry("with TLS options specified only starts app HTTPS server", legacyServersTableInput{
|
||||
legacyServer: LegacyServer{
|
||||
HTTPAddress: insecureAddr,
|
||||
HTTPSAddress: secureAddr,
|
||||
TLSKeyFile: keyPath,
|
||||
TLSCertFile: crtPath,
|
||||
},
|
||||
expectedAppServer: Server{
|
||||
SecureBindAddress: secureAddr,
|
||||
TLS: tlsConfig,
|
||||
},
|
||||
}),
|
||||
Entry("with metrics HTTP and HTTPS addresses", legacyServersTableInput{
|
||||
legacyServer: LegacyServer{
|
||||
HTTPAddress: insecureAddr,
|
||||
HTTPSAddress: secureAddr,
|
||||
MetricsAddress: insecureMetricsAddr,
|
||||
MetricsSecureAddress: secureMetricsAddr,
|
||||
},
|
||||
expectedAppServer: Server{
|
||||
BindAddress: insecureAddr,
|
||||
},
|
||||
expectedMetricsServer: Server{
|
||||
BindAddress: insecureMetricsAddr,
|
||||
SecureBindAddress: secureMetricsAddr,
|
||||
},
|
||||
}),
|
||||
Entry("with metrics HTTPS and tls cert/key", legacyServersTableInput{
|
||||
legacyServer: LegacyServer{
|
||||
HTTPAddress: insecureAddr,
|
||||
HTTPSAddress: secureAddr,
|
||||
MetricsAddress: insecureMetricsAddr,
|
||||
MetricsSecureAddress: secureMetricsAddr,
|
||||
MetricsTLSKeyFile: keyPath,
|
||||
MetricsTLSCertFile: crtPath,
|
||||
},
|
||||
expectedAppServer: Server{
|
||||
BindAddress: insecureAddr,
|
||||
},
|
||||
expectedMetricsServer: Server{
|
||||
BindAddress: insecureMetricsAddr,
|
||||
SecureBindAddress: secureMetricsAddr,
|
||||
TLS: tlsConfig,
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
})
|
||||
})
|
||||
|
@ -22,9 +22,6 @@ type Options struct {
|
||||
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"`
|
||||
PingPath string `flag:"ping-path" cfg:"ping_path"`
|
||||
PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"`
|
||||
MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
|
||||
HTTPAddress string `flag:"http-address" cfg:"http_address"`
|
||||
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
|
||||
ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"`
|
||||
RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"`
|
||||
TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"`
|
||||
@ -33,8 +30,6 @@ type Options struct {
|
||||
ClientID string `flag:"client-id" cfg:"client_id"`
|
||||
ClientSecret string `flag:"client-secret" cfg:"client_secret"`
|
||||
ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file"`
|
||||
TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
|
||||
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
|
||||
|
||||
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
||||
KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"`
|
||||
@ -68,6 +63,9 @@ type Options struct {
|
||||
InjectRequestHeaders []Header `cfg:",internal"`
|
||||
InjectResponseHeaders []Header `cfg:",internal"`
|
||||
|
||||
Server Server `cfg:",internal"`
|
||||
MetricsServer Server `cfg:",internal"`
|
||||
|
||||
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
|
||||
SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"`
|
||||
SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"`
|
||||
@ -136,10 +134,7 @@ func NewOptions() *Options {
|
||||
return &Options{
|
||||
ProxyPrefix: "/oauth2",
|
||||
ProviderType: "google",
|
||||
MetricsAddress: "",
|
||||
PingPath: "/ping",
|
||||
HTTPAddress: "127.0.0.1:4180",
|
||||
HTTPSAddress: ":443",
|
||||
RealClientIPHeader: "X-Real-IP",
|
||||
ForceHTTPS: false,
|
||||
Cookie: cookieDefaults(),
|
||||
@ -162,14 +157,10 @@ func NewOptions() *Options {
|
||||
func NewFlagSet() *pflag.FlagSet {
|
||||
flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError)
|
||||
|
||||
flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients")
|
||||
flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients")
|
||||
flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted")
|
||||
flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)")
|
||||
flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.")
|
||||
flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests")
|
||||
flagSet.String("tls-cert-file", "", "path to certificate file")
|
||||
flagSet.String("tls-key-file", "", "path to private key file")
|
||||
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
|
||||
flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)")
|
||||
flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods")
|
||||
@ -204,7 +195,6 @@ func NewFlagSet() *pflag.FlagSet {
|
||||
flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)")
|
||||
flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks")
|
||||
flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks")
|
||||
flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
|
||||
flagSet.String("session-store-type", "cookie", "the session storage provider to use")
|
||||
flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)")
|
||||
flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])")
|
||||
|
27
pkg/apis/options/server.go
Normal file
27
pkg/apis/options/server.go
Normal file
@ -0,0 +1,27 @@
|
||||
package options
|
||||
|
||||
// Server represents the configuration for an HTTP(S) server
|
||||
type Server struct {
|
||||
// BindAddress is the the address on which to serve traffic.
|
||||
// Leave blank or set to "-" to disable.
|
||||
BindAddress string
|
||||
|
||||
// SecureBindAddress is the the address on which to serve secure traffic.
|
||||
// Leave blank or set to "-" to disable.
|
||||
SecureBindAddress string
|
||||
|
||||
// TLS contains the information for loading the certificate and key for the
|
||||
// secure traffic.
|
||||
TLS *TLS
|
||||
}
|
||||
|
||||
// TLS contains the information for loading a TLS certifcate and key.
|
||||
type TLS struct {
|
||||
// Key is the the TLS key data to use.
|
||||
// Typically this will come from a file.
|
||||
Key *SecretSource
|
||||
|
||||
// Cert is the TLS certificate data to use.
|
||||
// Typically this will come from a file.
|
||||
Cert *SecretSource
|
||||
}
|
88
pkg/http/http_suite_test.go
Normal file
88
pkg/http/http_suite_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var certData []byte
|
||||
var certDataSource, keyDataSource options.SecretSource
|
||||
var client *http.Client
|
||||
|
||||
func TestHTTPSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
logger.SetErrOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "HTTP")
|
||||
}
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
By("Generating a self-signed cert for TLS tests", func() {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
keyOut := bytes.NewBuffer(nil)
|
||||
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed())
|
||||
keyDataSource.Value = keyOut.Bytes()
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"OAuth2 Proxy Test Suite"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
certData = certBytes
|
||||
|
||||
certOut := bytes.NewBuffer(nil)
|
||||
Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed())
|
||||
certDataSource.Value = certOut.Bytes()
|
||||
})
|
||||
|
||||
By("Setting up a http client", func() {
|
||||
cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
certificate, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
certpool := x509.NewCertPool()
|
||||
certpool.AddCert(certificate)
|
||||
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig.RootCAs = certpool
|
||||
|
||||
client = &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
})
|
||||
})
|
245
pkg/http/server.go
Normal file
245
pkg/http/server.go
Normal file
@ -0,0 +1,245 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Server represents an HTTP or HTTPS server.
|
||||
type Server interface {
|
||||
// Start blocks and runs the server.
|
||||
Start(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Opts contains the information required to set up the server.
|
||||
type Opts struct {
|
||||
// Handler is the http.Handler to be used to serve http pages by the server.
|
||||
Handler http.Handler
|
||||
|
||||
// BindAddress is the address the HTTP server should listen on.
|
||||
BindAddress string
|
||||
|
||||
// SecureBindAddress is the address the HTTPS server should listen on.
|
||||
SecureBindAddress string
|
||||
|
||||
// TLS is the TLS configuration for the server.
|
||||
TLS *options.TLS
|
||||
}
|
||||
|
||||
// NewServer creates a new Server from the options given.
|
||||
func NewServer(opts Opts) (Server, error) {
|
||||
s := &server{
|
||||
handler: opts.Handler,
|
||||
}
|
||||
if err := s.setupListener(opts); err != nil {
|
||||
return nil, fmt.Errorf("error setting up listener: %v", err)
|
||||
}
|
||||
if err := s.setupTLSListener(opts); err != nil {
|
||||
return nil, fmt.Errorf("error setting up TLS listener: %v", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// server is an implementation of the Server interface.
|
||||
type server struct {
|
||||
handler http.Handler
|
||||
|
||||
listener net.Listener
|
||||
tlsListener net.Listener
|
||||
}
|
||||
|
||||
// setupListener sets the server listener if the HTTP server is enabled.
|
||||
// The HTTP server can be disabled by setting the BindAddress to "-" or by
|
||||
// leaving it empty.
|
||||
func (s *server) setupListener(opts Opts) error {
|
||||
if opts.BindAddress == "" || opts.BindAddress == "-" {
|
||||
// No HTTP listener required
|
||||
return nil
|
||||
}
|
||||
|
||||
networkType := getNetworkScheme(opts.BindAddress)
|
||||
listenAddr := getListenAddress(opts.BindAddress)
|
||||
|
||||
listener, err := net.Listen(networkType, listenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err)
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupTLSListener sets the server TLS listener if the HTTPS server is enabled.
|
||||
// The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by
|
||||
// leaving it empty.
|
||||
func (s *server) setupTLSListener(opts Opts) error {
|
||||
if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" {
|
||||
// No HTTPS listener required
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
if opts.TLS == nil {
|
||||
return errors.New("no TLS config provided")
|
||||
}
|
||||
cert, err := getCertificate(opts.TLS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not load certificate: %v", err)
|
||||
}
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
|
||||
listenAddr := getListenAddress(opts.SecureBindAddress)
|
||||
|
||||
listener, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
|
||||
}
|
||||
|
||||
s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start starts the HTTP and HTTPS server if applicable.
|
||||
// It will block until the context is cancelled.
|
||||
// If any errors occur, only the first error will be returned.
|
||||
func (s *server) Start(ctx context.Context) error {
|
||||
g, groupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
if s.listener != nil {
|
||||
g.Go(func() error {
|
||||
if err := s.startServer(groupCtx, s.listener); err != nil {
|
||||
return fmt.Errorf("error starting insecure server: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if s.tlsListener != nil {
|
||||
g.Go(func() error {
|
||||
if err := s.startServer(groupCtx, s.tlsListener); err != nil {
|
||||
return fmt.Errorf("error starting secure server: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// startServer creates and starts a new server with the given listener.
|
||||
// When the given context is cancelled the server will be shutdown.
|
||||
// If any errors occur, only the first error will be returned.
|
||||
func (s *server) startServer(ctx context.Context, listener net.Listener) error {
|
||||
srv := &http.Server{Handler: s.handler}
|
||||
g, groupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
<-groupCtx.Done()
|
||||
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
return fmt.Errorf("error shutting down server: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("could not start server: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// getNetworkScheme gets the scheme for the HTTP server.
|
||||
func getNetworkScheme(addr string) string {
|
||||
var scheme string
|
||||
i := strings.Index(addr, "://")
|
||||
if i > -1 {
|
||||
scheme = addr[0:i]
|
||||
}
|
||||
|
||||
switch scheme {
|
||||
case "", "http":
|
||||
return "tcp"
|
||||
default:
|
||||
return scheme
|
||||
}
|
||||
}
|
||||
|
||||
// getListenAddress gets the address for the HTTP server.
|
||||
func getListenAddress(addr string) string {
|
||||
slice := strings.SplitN(addr, "//", 2)
|
||||
return slice[len(slice)-1]
|
||||
}
|
||||
|
||||
// getCertificate loads the certificate data from the TLS config.
|
||||
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
|
||||
keyData, err := getSecretValue(opts.Key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err)
|
||||
}
|
||||
|
||||
certData, err := getSecretValue(opts.Cert)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err)
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(certData, keyData)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err)
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// getSecretValue wraps util.GetSecretValue so that we can return an error if no
|
||||
// source is provided.
|
||||
func getSecretValue(src *options.SecretSource) ([]byte, error) {
|
||||
if src == nil {
|
||||
return nil, errors.New("no configuration provided")
|
||||
}
|
||||
return util.GetSecretValue(src)
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by so that dead TCP connections (e.g. closing laptop
|
||||
// mid-download) eventually go away.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
// Accept implements the TCPListener interface.
|
||||
// It sets the keep alive period to 3 minutes for each connection.
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tc.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
logger.Errorf("Error setting Keep-Alive: %v", err)
|
||||
}
|
||||
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
if err != nil {
|
||||
logger.Printf("Error setting Keep-Alive period: %v", err)
|
||||
}
|
||||
return tc, nil
|
||||
}
|
35
pkg/http/server_group.go
Normal file
35
pkg/http/server_group.go
Normal file
@ -0,0 +1,35 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// NewServerGroup creates a new Server to start and gracefully stop a collection
|
||||
// of Servers.
|
||||
func NewServerGroup(servers ...Server) Server {
|
||||
return &serverGroup{
|
||||
servers: servers,
|
||||
}
|
||||
}
|
||||
|
||||
// serverGroup manages the starting and graceful shutdown of a collection of
|
||||
// servers.
|
||||
type serverGroup struct {
|
||||
servers []Server
|
||||
}
|
||||
|
||||
// Start runs the servers in the server group.
|
||||
func (s *serverGroup) Start(ctx context.Context) error {
|
||||
g, groupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
for _, server := range s.servers {
|
||||
srv := server
|
||||
g.Go(func() error {
|
||||
return srv.Start(groupCtx)
|
||||
})
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
}
|
102
pkg/http/server_group_test.go
Normal file
102
pkg/http/server_group_test.go
Normal file
@ -0,0 +1,102 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Server Group", func() {
|
||||
var m1, m2, m3 *mockServer
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
var group Server
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
m1 = newMockServer()
|
||||
m2 = newMockServer()
|
||||
m3 = newMockServer()
|
||||
group = NewServerGroup(m1, m2, m3)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
It("starts each server in the group", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(group.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
Eventually(m1.started).Should(BeClosed(), "mock server 1 not started")
|
||||
Eventually(m2.started).Should(BeClosed(), "mock server 2 not started")
|
||||
Eventually(m3.started).Should(BeClosed(), "mock server 3 not started")
|
||||
})
|
||||
|
||||
It("stop each server in the group when the context is cancelled", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(group.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
cancel()
|
||||
Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
|
||||
Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
|
||||
Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
|
||||
})
|
||||
|
||||
It("stop each server in the group when the an error occurs", func() {
|
||||
err := errors.New("server error")
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(group.Start(ctx)).To(MatchError(err))
|
||||
}()
|
||||
|
||||
m2.errors <- err
|
||||
Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
|
||||
Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
|
||||
Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
|
||||
})
|
||||
})
|
||||
|
||||
// mockServer is used to test the server group can start
|
||||
// and stop multiple servers simultaneously.
|
||||
type mockServer struct {
|
||||
started chan struct{}
|
||||
startClosed bool
|
||||
stopped chan struct{}
|
||||
stopClosed bool
|
||||
errors chan error
|
||||
}
|
||||
|
||||
func newMockServer() *mockServer {
|
||||
return &mockServer{
|
||||
started: make(chan struct{}),
|
||||
stopped: make(chan struct{}),
|
||||
errors: make(chan error),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockServer) Start(ctx context.Context) error {
|
||||
if !m.startClosed {
|
||||
close(m.started)
|
||||
m.startClosed = true
|
||||
}
|
||||
defer func() {
|
||||
if !m.stopClosed {
|
||||
close(m.stopped)
|
||||
m.stopClosed = true
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case err := <-m.errors:
|
||||
return err
|
||||
}
|
||||
}
|
472
pkg/http/server_test.go
Normal file
472
pkg/http/server_test.go
Normal file
@ -0,0 +1,472 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const hello = "Hello World!"
|
||||
|
||||
var _ = Describe("Server", func() {
|
||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte(hello))
|
||||
})
|
||||
|
||||
Context("NewServer", func() {
|
||||
type newServerTableInput struct {
|
||||
opts Opts
|
||||
expectedErr error
|
||||
expectHTTPListener bool
|
||||
expectTLSListener bool
|
||||
}
|
||||
|
||||
DescribeTable("When creating the new server from the options", func(in *newServerTableInput) {
|
||||
srv, err := NewServer(in.opts)
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error())))
|
||||
Expect(srv).To(BeNil())
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
s, ok := srv.(*server)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
||||
Expect(s.listener != nil).To(Equal(in.expectHTTPListener))
|
||||
if in.expectHTTPListener {
|
||||
Expect(s.listener.Close()).To(Succeed())
|
||||
}
|
||||
Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener))
|
||||
if in.expectTLSListener {
|
||||
Expect(s.tlsListener.Close()).To(Succeed())
|
||||
}
|
||||
},
|
||||
Entry("with a valid http bind address", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "127.0.0.1:0",
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: true,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with a valid https bind address, with no TLS config", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
},
|
||||
expectedErr: errors.New("error setting up TLS listener: no TLS config provided"),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: true,
|
||||
}),
|
||||
Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "127.0.0.1:0",
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: true,
|
||||
expectTLSListener: true,
|
||||
}),
|
||||
Entry("with a \"-\" for the bind address", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "-",
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with a \"-\" for the secure bind address", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "-",
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with an invalid bind address scheme", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "invalid://127.0.0.1:0",
|
||||
},
|
||||
expectedErr: errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with an invalid secure bind address scheme", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "invalid://127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: true,
|
||||
}),
|
||||
Entry("with an invalid bind address port", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "127.0.0.1:a",
|
||||
},
|
||||
expectedErr: errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with an invalid secure bind address port", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:a",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with an invalid TLS key", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &options.SecretSource{
|
||||
Value: []byte("invalid"),
|
||||
},
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with an invalid TLS cert", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &options.SecretSource{
|
||||
Value: []byte("invalid"),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with no TLS key", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("with no TLS cert", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"),
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "http://127.0.0.1:0",
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: true,
|
||||
expectTLSListener: false,
|
||||
}),
|
||||
Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{
|
||||
opts: Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "https://127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectHTTPListener: false,
|
||||
expectTLSListener: true,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
Context("Start", func() {
|
||||
var srv Server
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
Context("with an http server", func() {
|
||||
var listenAddr string
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
srv, err = NewServer(Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "127.0.0.1:0",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
s, ok := srv.(*server)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
||||
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
|
||||
})
|
||||
|
||||
It("Starts the server and serves the handler", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
resp, err := client.Get(listenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(hello))
|
||||
})
|
||||
|
||||
It("Stops the server when the context is cancelled", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
_, err := client.Get(listenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cancel()
|
||||
|
||||
Eventually(func() error {
|
||||
_, err := client.Get(listenAddr)
|
||||
return err
|
||||
}).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with an https server", func() {
|
||||
var secureListenAddr string
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
srv, err = NewServer(Opts{
|
||||
Handler: handler,
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
s, ok := srv.(*server)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
||||
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
|
||||
})
|
||||
|
||||
It("Starts the server and serves the handler", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
resp, err := client.Get(secureListenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(hello))
|
||||
})
|
||||
|
||||
It("Stops the server when the context is cancelled", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
_, err := client.Get(secureListenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cancel()
|
||||
|
||||
Eventually(func() error {
|
||||
_, err := client.Get(secureListenAddr)
|
||||
return err
|
||||
}).Should(HaveOccurred())
|
||||
})
|
||||
|
||||
It("Serves the certificate provided", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
resp, err := client.Get(secureListenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
Expect(resp.TLS.VerifiedChains).Should(HaveLen(1))
|
||||
Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1))
|
||||
Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with both an http and an https server", func() {
|
||||
var listenAddr, secureListenAddr string
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
srv, err = NewServer(Opts{
|
||||
Handler: handler,
|
||||
BindAddress: "127.0.0.1:0",
|
||||
SecureBindAddress: "127.0.0.1:0",
|
||||
TLS: &options.TLS{
|
||||
Key: &keyDataSource,
|
||||
Cert: &certDataSource,
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
s, ok := srv.(*server)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
||||
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
|
||||
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
|
||||
})
|
||||
|
||||
It("Starts the server and serves the handler on http", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
resp, err := client.Get(listenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(hello))
|
||||
})
|
||||
|
||||
It("Starts the server and serves the handler on https", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
resp, err := client.Get(secureListenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(hello))
|
||||
})
|
||||
|
||||
It("Stops both servers when the context is cancelled", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Expect(srv.Start(ctx)).To(Succeed())
|
||||
}()
|
||||
|
||||
_, err := client.Get(listenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.Get(secureListenAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cancel()
|
||||
|
||||
Eventually(func() error {
|
||||
_, err := client.Get(listenAddr)
|
||||
return err
|
||||
}).Should(HaveOccurred())
|
||||
Eventually(func() error {
|
||||
_, err := client.Get(secureListenAddr)
|
||||
return err
|
||||
}).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("getNetworkScheme", func() {
|
||||
DescribeTable("should return the scheme", func(in, expected string) {
|
||||
Expect(getNetworkScheme(in)).To(Equal(expected))
|
||||
},
|
||||
Entry("with no scheme", "127.0.0.1:0", "tcp"),
|
||||
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"),
|
||||
Entry("with a http scheme", "http://192.168.0.1:1", "tcp"),
|
||||
Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"),
|
||||
Entry("with a random scheme", "random://10.10.10.10:10", "random"),
|
||||
)
|
||||
})
|
||||
|
||||
Context("getListenAddress", func() {
|
||||
DescribeTable("should remove the scheme", func(in, expected string) {
|
||||
Expect(getListenAddress(in)).To(Equal(expected))
|
||||
},
|
||||
Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"),
|
||||
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"),
|
||||
Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"),
|
||||
Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"),
|
||||
Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"),
|
||||
)
|
||||
})
|
||||
})
|
Loading…
x
Reference in New Issue
Block a user