1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-27 22:01:28 +02:00

Merge pull request from oauth2-proxy/http-server

Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers
This commit is contained in:
Joel Speed 2021-03-07 21:12:02 +00:00 committed by GitHub
commit 6894738d97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1258 additions and 248 deletions

@ -8,6 +8,7 @@
## Changes since v7.0.1 ## Changes since v7.0.1
- [#1047](https://github.com/oauth2-proxy/oauth2-proxy/pull/1047) Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers (@JoelSpeed)
- [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) - [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves)
- [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich)
- [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed)

@ -117,6 +117,8 @@ They may change between releases without notice.
| `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.<br/>Once a user is authenticated, requests to the server will be proxied to<br/>these upstream servers based on the path mappings defined in this list. | | `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.<br/>Once a user is authenticated, requests to the server will be proxied to<br/>these upstream servers based on the path mappings defined in this list. |
| `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added<br/>to requests to upstream servers.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added<br/>to requests to upstream servers.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. |
| `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added<br/>to responses from the proxy.<br/>This is typically used when using the proxy as an external authentication<br/>provider in conjunction with another proxy such as NGINX and its<br/>auth_request module.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added<br/>to responses from the proxy.<br/>This is typically used when using the proxy as an external authentication<br/>provider in conjunction with another proxy such as NGINX and its<br/>auth_request module.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. |
| `server` | _[Server](#server)_ | Server is used to configure the HTTP(S) server for the proxy application.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. |
| `metricsServer` | _[Server](#server)_ | MetricsServer is used to configure the HTTP(S) server for metrics.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. |
### ClaimSource ### ClaimSource
@ -172,7 +174,7 @@ make up the header value
### SecretSource ### SecretSource
(**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue)) (**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue), [TLS](#tls))
SecretSource references an individual secret value. SecretSource references an individual secret value.
Only one source within the struct should be defined at any time. Only one source within the struct should be defined at any time.
@ -183,6 +185,29 @@ Only one source within the struct should be defined at any time.
| `fromEnv` | _string_ | FromEnv expects the name of an environment variable. | | `fromEnv` | _string_ | FromEnv expects the name of an environment variable. |
| `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. | | `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. |
### Server
(**Appears on:** [AlphaOptions](#alphaoptions))
Server represents the configuration for an HTTP(S) server
| Field | Type | Description |
| ----- | ---- | ----------- |
| `BindAddress` | _string_ | BindAddress is the the address on which to serve traffic.<br/>Leave blank or set to "-" to disable. |
| `SecureBindAddress` | _string_ | SecureBindAddress is the the address on which to serve secure traffic.<br/>Leave blank or set to "-" to disable. |
| `TLS` | _[TLS](#tls)_ | TLS contains the information for loading the certificate and key for the<br/>secure traffic. |
### TLS
(**Appears on:** [Server](#server))
TLS contains the information for loading a TLS certifcate and key.
| Field | Type | Description |
| ----- | ---- | ----------- |
| `Key` | _[SecretSource](#secretsource)_ | Key is the the TLS key data to use.<br/>Typically this will come from a file. |
| `Cert` | _[SecretSource](#secretsource)_ | Cert is the TLS certificate data to use.<br/>Typically this will come from a file. |
### Upstream ### Upstream
(**Appears on:** [Upstreams](#upstreams)) (**Appears on:** [Upstreams](#upstreams))

1
go.mod

@ -30,6 +30,7 @@ require (
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/net v0.0.0-20200707034311-ab3426394381 golang.org/x/net v0.0.0-20200707034311-ab3426394381
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
google.golang.org/api v0.20.0 google.golang.org/api v0.20.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/square/go-jose.v2 v2.4.1 gopkg.in/square/go-jose.v2 v2.4.1

3
go.sum

@ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

136
http.go

@ -1,136 +0,0 @@
package main
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)
// Server represents an HTTP server
type Server struct {
Handler http.Handler
Opts *options.Options
stop chan struct{} // channel for waiting shutdown
}
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
func (s *Server) ListenAndServe() {
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
s.ServeHTTPS()
} else {
s.ServeHTTP()
}
}
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
func (s *Server) ServeHTTP() {
HTTPAddress := s.Opts.HTTPAddress
var scheme string
i := strings.Index(HTTPAddress, "://")
if i > -1 {
scheme = HTTPAddress[0:i]
}
var networkType string
switch scheme {
case "", "http":
networkType = "tcp"
default:
networkType = scheme
}
slice := strings.SplitN(HTTPAddress, "//", 2)
listenAddr := slice[len(slice)-1]
listener, err := net.Listen(networkType, listenAddr)
if err != nil {
logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
}
logger.Printf("HTTP: listening on %s", listenAddr)
s.serve(listener)
logger.Printf("HTTP: closing %s", listener.Addr())
}
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
func (s *Server) ServeHTTPS() {
addr := s.Opts.HTTPSAddress
config := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
}
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1"}
}
var err error
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
if err != nil {
logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
}
ln, err := net.Listen("tcp", addr)
if err != nil {
logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
}
logger.Printf("HTTPS: listening on %s", ln.Addr())
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
s.serve(tlsListener)
logger.Printf("HTTPS: closing %s", tlsListener.Addr())
}
func (s *Server) serve(listener net.Listener) {
srv := &http.Server{Handler: s.Handler}
// See https://golang.org/pkg/net/http/#Server.Shutdown
idleConnsClosed := make(chan struct{})
go func() {
<-s.stop // wait notification for stopping server
// We received an interrupt signal, shut down.
if err := srv.Shutdown(context.Background()); err != nil {
// Error from closing listeners, or context timeout:
logger.Printf("HTTP server Shutdown: %v", err)
}
close(idleConnsClosed)
}()
err := srv.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Errorf("ERROR: http.Serve() - %s", err)
}
<-idleConnsClosed
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
err = tc.SetKeepAlive(true)
if err != nil {
logger.Printf("Error setting Keep-Alive: %v", err)
}
err = tc.SetKeepAlivePeriod(3 * time.Minute)
if err != nil {
logger.Printf("Error setting Keep-Alive period: %v", err)
}
return tc, nil
}

@ -1,39 +0,0 @@
package main
import (
"net/http"
"sync"
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/stretchr/testify/assert"
)
func TestGracefulShutdown(t *testing.T) {
opts := options.NewOptions()
stop := make(chan struct{}, 1)
srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.ServeHTTP()
}()
stop <- struct{}{} // emulate catching signals
// An idiomatic for sync.WaitGroup with timeout
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
case <-time.After(1 * time.Second):
t.Fatal("Server should return gracefully but timeout has occurred")
}
assert.Len(t, stop, 0) // check if stop chan is empty
}

54
main.go

@ -1,20 +1,15 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"net/http"
"os" "os"
"os/signal"
"runtime" "runtime"
"syscall"
"time" "time"
"github.com/ghodss/yaml" "github.com/ghodss/yaml"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
"github.com/spf13/pflag" "github.com/spf13/pflag"
) )
@ -67,54 +62,9 @@ func main() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
oauthProxyStop := make(chan struct{}, 1) if err := oauthproxy.Start(); err != nil {
metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop) logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
s := &Server{
Handler: oauthproxy,
Opts: opts,
stop: oauthProxyStop,
} }
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
s.stop <- struct{}{} // notify having caught signal stop oauthproxy
close(metricsStop) // and the metrics endpoint
}()
s.ListenAndServe()
}
// startMetricsServer will start the metrics server on the specified address.
// It always return a channel to signal stop even when it does not run.
func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} {
stop := make(chan struct{}, 1)
// Attempt to setup the metrics endpoint if we have an address
if address != "" {
s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler}
go func() {
// ListenAndServe always returns a non-nil error. After Shutdown or
// Close, the returned error is ErrServerClosed
if err := s.ListenAndServe(); err != http.ErrServerClosed {
logger.Println(err)
// Stop the metrics shutdown go routine
close(stop)
// Stop the oauthproxy server, we have encounter an unexpected error
close(oauthProxyStop)
}
}()
go func() {
<-stop
if err := s.Shutdown(context.Background()); err != nil {
logger.Print(err)
}
}()
}
return stop
} }
// loadConfiguration will load in the user's configuration. // loadConfiguration will load in the user's configuration.

@ -15,6 +15,7 @@ import (
var _ = Describe("Configuration Loading Suite", func() { var _ = Describe("Configuration Loading Suite", func() {
const testLegacyConfig = ` const testLegacyConfig = `
http_address="127.0.0.1:4180"
upstreams="http://httpbin" upstreams="http://httpbin"
set_basic_auth="true" set_basic_auth="true"
basic_auth_password="super-secret-password" basic_auth_password="super-secret-password"
@ -54,10 +55,11 @@ injectResponseHeaders:
prefix: "Basic " prefix: "Basic "
basicAuthPassword: basicAuthPassword:
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
server:
bindAddress: "127.0.0.1:4180"
` `
const testCoreConfig = ` const testCoreConfig = `
http_address="0.0.0.0:4180"
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
provider="oidc" provider="oidc"
email_domains="example.com" email_domains="example.com"
@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback"
opts, err := options.NewLegacyOptions().ToOptions() opts, err := options.NewLegacyOptions().ToOptions()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
opts.HTTPAddress = "0.0.0.0:4180"
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
opts.ProviderType = "oidc" opts.ProviderType = "oidc"
opts.EmailDomains = []string{"example.com"} opts.EmailDomains = []string{"example.com"}
@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
configContent: testCoreConfig, configContent: testCoreConfig,
alphaConfigContent: testAlphaConfig + ":", alphaConfigContent: testAlphaConfig + ":",
expectedOptions: func() *options.Options { return nil }, expectedOptions: func() *options.Options { return nil },
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"), expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"),
}), }),
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{ Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
configContent: testCoreConfig + "unknown_field=\"something\"", configContent: testCoreConfig + "unknown_field=\"something\"",

@ -8,8 +8,11 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"os/signal"
"regexp" "regexp"
"strings" "strings"
"syscall"
"time" "time"
"github.com/justinas/alice" "github.com/justinas/alice"
@ -21,6 +24,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
@ -102,6 +106,7 @@ type OAuthProxy struct {
headersChain alice.Chain headersChain alice.Chain
preAuthChain alice.Chain preAuthChain alice.Chain
pageWriter pagewriter.Writer pageWriter pagewriter.Writer
server proxyhttp.Server
} }
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, fmt.Errorf("could not build headers chain: %v", err) return nil, fmt.Errorf("could not build headers chain: %v", err)
} }
return &OAuthProxy{ p := &OAuthProxy{
CookieName: opts.Cookie.Name, CookieName: opts.Cookie.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
CookieSeed: opts.Cookie.Secret, CookieSeed: opts.Cookie.Secret,
@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
headersChain: headersChain, headersChain: headersChain,
preAuthChain: preAuthChain, preAuthChain: preAuthChain,
pageWriter: pageWriter, pageWriter: pageWriter,
}, nil }
if err := p.setupServer(opts); err != nil {
return nil, fmt.Errorf("error setting up server: %v", err)
}
return p, nil
}
func (p *OAuthProxy) Start() error {
if p.server == nil {
// We have to call setupServer before Start is called.
// If this doesn't happen it's a programming error.
panic("server has not been initialised")
}
ctx, cancel := context.WithCancel(context.Background())
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
cancel() // cancel the context
}()
return p.server.Start(ctx)
}
func (p *OAuthProxy) setupServer(opts *options.Options) error {
serverOpts := proxyhttp.Opts{
Handler: p,
BindAddress: opts.Server.BindAddress,
SecureBindAddress: opts.Server.SecureBindAddress,
TLS: opts.Server.TLS,
}
appServer, err := proxyhttp.NewServer(serverOpts)
if err != nil {
return fmt.Errorf("could not build app server: %v", err)
}
metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{
Handler: middleware.DefaultMetricsHandler,
BindAddress: opts.MetricsServer.BindAddress,
SecureBindAddress: opts.MetricsServer.BindAddress,
TLS: opts.MetricsServer.TLS,
})
if err != nil {
return fmt.Errorf("could not build metrics server: %v", err)
}
p.server = proxyhttp.NewServerGroup(appServer, metricsServer)
return nil
} }
// buildPreAuthChain constructs a chain that should process every request before // buildPreAuthChain constructs a chain that should process every request before
@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts.ReverseProxy)) chain := alice.New(middleware.NewScope(opts.ReverseProxy))
if opts.ForceHTTPS { if opts.ForceHTTPS {
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) _, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress)
if err != nil { if err != nil {
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err) return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err)
} }
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
} }

@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options {
}, },
}, },
} }
return opts return opts
} }

@ -28,6 +28,18 @@ type AlphaOptions struct {
// Headers may source values from either the authenticated user's session // Headers may source values from either the authenticated user's session
// or from a static secret value. // or from a static secret value.
InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"` InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"`
// Server is used to configure the HTTP(S) server for the proxy application.
// You may choose to run both HTTP and HTTPS servers simultaneously.
// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
// To use the secure server you must configure a TLS certificate and key.
Server Server `json:"server,omitempty"`
// MetricsServer is used to configure the HTTP(S) server for metrics.
// You may choose to run both HTTP and HTTPS servers simultaneously.
// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
// To use the secure server you must configure a TLS certificate and key.
MetricsServer Server `json:"metricsServer,omitempty"`
} }
// MergeInto replaces alpha options in the Options struct with the values // MergeInto replaces alpha options in the Options struct with the values
@ -36,6 +48,8 @@ func (a *AlphaOptions) MergeInto(opts *Options) {
opts.UpstreamServers = a.Upstreams opts.UpstreamServers = a.Upstreams
opts.InjectRequestHeaders = a.InjectRequestHeaders opts.InjectRequestHeaders = a.InjectRequestHeaders
opts.InjectResponseHeaders = a.InjectResponseHeaders opts.InjectResponseHeaders = a.InjectResponseHeaders
opts.Server = a.Server
opts.MetricsServer = a.MetricsServer
} }
// ExtractFrom populates the fields in the AlphaOptions with the values from // ExtractFrom populates the fields in the AlphaOptions with the values from
@ -44,4 +58,6 @@ func (a *AlphaOptions) ExtractFrom(opts *Options) {
a.Upstreams = opts.UpstreamServers a.Upstreams = opts.UpstreamServers
a.InjectRequestHeaders = opts.InjectRequestHeaders a.InjectRequestHeaders = opts.InjectRequestHeaders
a.InjectResponseHeaders = opts.InjectResponseHeaders a.InjectResponseHeaders = opts.InjectResponseHeaders
a.Server = opts.Server
a.MetricsServer = opts.MetricsServer
} }

@ -18,6 +18,9 @@ type LegacyOptions struct {
// Legacy options for injecting request/response headers // Legacy options for injecting request/response headers
LegacyHeaders LegacyHeaders `cfg:",squash"` LegacyHeaders LegacyHeaders `cfg:",squash"`
// Legacy options for the server address and TLS
LegacyServer LegacyServer `cfg:",squash"`
Options Options `cfg:",squash"` Options Options `cfg:",squash"`
} }
@ -35,6 +38,11 @@ func NewLegacyOptions() *LegacyOptions {
SkipAuthStripHeaders: true, SkipAuthStripHeaders: true,
}, },
LegacyServer: LegacyServer{
HTTPAddress: "127.0.0.1:4180",
HTTPSAddress: ":443",
},
Options: *NewOptions(), Options: *NewOptions(),
} }
} }
@ -44,6 +52,7 @@ func NewLegacyFlagSet() *pflag.FlagSet {
flagSet.AddFlagSet(legacyUpstreamsFlagSet()) flagSet.AddFlagSet(legacyUpstreamsFlagSet())
flagSet.AddFlagSet(legacyHeadersFlagSet()) flagSet.AddFlagSet(legacyHeadersFlagSet())
flagSet.AddFlagSet(legacyServerFlagset())
return flagSet return flagSet
} }
@ -56,6 +65,8 @@ func (l *LegacyOptions) ToOptions() (*Options, error) {
l.Options.UpstreamServers = upstreams l.Options.UpstreamServers = upstreams
l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert() l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert()
l.Options.Server, l.Options.MetricsServer = l.LegacyServer.convert()
return &l.Options, nil return &l.Options, nil
} }
@ -403,3 +414,69 @@ func getXAuthRequestAccessTokenHeader() Header {
}, },
} }
} }
type LegacyServer struct {
MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_address"`
MetricsTLSCertFile string `flag:"metrics-tls-cert-file" cfg:"tls_cert_file"`
MetricsTLSKeyFile string `flag:"metrics-tls-key-file" cfg:"tls_key_file"`
HTTPAddress string `flag:"http-address" cfg:"http_address"`
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
}
func legacyServerFlagset() *pflag.FlagSet {
flagSet := pflag.NewFlagSet("server", pflag.ExitOnError)
flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
flagSet.String("metrics-secure-address", "", "the address /metrics will be served on for HTTPS clients (e.g. \":9100\")")
flagSet.String("metrics-tls-cert-file", "", "path to certificate file for secure metrics server")
flagSet.String("metrics-tls-key-file", "", "path to private key file for secure metrics server")
flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients")
flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients")
flagSet.String("tls-cert-file", "", "path to certificate file")
flagSet.String("tls-key-file", "", "path to private key file")
return flagSet
}
func (l LegacyServer) convert() (Server, Server) {
appServer := Server{
BindAddress: l.HTTPAddress,
SecureBindAddress: l.HTTPSAddress,
}
if l.TLSKeyFile != "" || l.TLSCertFile != "" {
appServer.TLS = &TLS{
Key: &SecretSource{
FromFile: l.TLSKeyFile,
},
Cert: &SecretSource{
FromFile: l.TLSCertFile,
},
}
// Preserve backwards compatibility, only run one server
appServer.BindAddress = ""
} else {
// Disable the HTTPS server if there's no certificates.
// This preserves backwards compatibility.
appServer.SecureBindAddress = ""
}
metricsServer := Server{
BindAddress: l.MetricsAddress,
SecureBindAddress: l.MetricsSecureAddress,
}
if l.MetricsTLSKeyFile != "" || l.MetricsTLSCertFile != "" {
metricsServer.TLS = &TLS{
Key: &SecretSource{
FromFile: l.MetricsTLSKeyFile,
},
Cert: &SecretSource{
FromFile: l.MetricsTLSCertFile,
},
}
}
return appServer, metricsServer
}

@ -106,6 +106,10 @@ var _ = Describe("Legacy Options", func() {
opts.InjectResponseHeaders = []Header{} opts.InjectResponseHeaders = []Header{}
opts.Server = Server{
BindAddress: "127.0.0.1:4180",
}
converted, err := legacyOpts.ToOptions() converted, err := legacyOpts.ToOptions()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(converted).To(Equal(opts)) Expect(converted).To(Equal(opts))
@ -759,4 +763,93 @@ var _ = Describe("Legacy Options", func() {
}), }),
) )
}) })
Context("Legacy Servers", func() {
type legacyServersTableInput struct {
legacyServer LegacyServer
expectedAppServer Server
expectedMetricsServer Server
}
const (
insecureAddr = "127.0.0.1:8080"
insecureMetricsAddr = ":9090"
secureAddr = ":443"
secureMetricsAddr = ":9443"
crtPath = "tls.crt"
keyPath = "tls.key"
)
var tlsConfig = &TLS{
Cert: &SecretSource{
FromFile: crtPath,
},
Key: &SecretSource{
FromFile: keyPath,
},
}
DescribeTable("should convert to app and metrics servers",
func(in legacyServersTableInput) {
appServer, metricsServer := in.legacyServer.convert()
Expect(appServer).To(Equal(in.expectedAppServer))
Expect(metricsServer).To(Equal(in.expectedMetricsServer))
},
Entry("with default options only starts app HTTP server", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
},
expectedAppServer: Server{
BindAddress: insecureAddr,
},
}),
Entry("with TLS options specified only starts app HTTPS server", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
TLSKeyFile: keyPath,
TLSCertFile: crtPath,
},
expectedAppServer: Server{
SecureBindAddress: secureAddr,
TLS: tlsConfig,
},
}),
Entry("with metrics HTTP and HTTPS addresses", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
MetricsAddress: insecureMetricsAddr,
MetricsSecureAddress: secureMetricsAddr,
},
expectedAppServer: Server{
BindAddress: insecureAddr,
},
expectedMetricsServer: Server{
BindAddress: insecureMetricsAddr,
SecureBindAddress: secureMetricsAddr,
},
}),
Entry("with metrics HTTPS and tls cert/key", legacyServersTableInput{
legacyServer: LegacyServer{
HTTPAddress: insecureAddr,
HTTPSAddress: secureAddr,
MetricsAddress: insecureMetricsAddr,
MetricsSecureAddress: secureMetricsAddr,
MetricsTLSKeyFile: keyPath,
MetricsTLSCertFile: crtPath,
},
expectedAppServer: Server{
BindAddress: insecureAddr,
},
expectedMetricsServer: Server{
BindAddress: insecureMetricsAddr,
SecureBindAddress: secureMetricsAddr,
TLS: tlsConfig,
},
}),
)
})
}) })

@ -22,9 +22,6 @@ type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"` ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"`
PingPath string `flag:"ping-path" cfg:"ping_path"` PingPath string `flag:"ping-path" cfg:"ping_path"`
PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"` PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"`
MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"`
HTTPAddress string `flag:"http-address" cfg:"http_address"`
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"` ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"`
RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"` RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"`
TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"` TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"`
@ -33,8 +30,6 @@ type Options struct {
ClientID string `flag:"client-id" cfg:"client_id"` ClientID string `flag:"client-id" cfg:"client_id"`
ClientSecret string `flag:"client-secret" cfg:"client_secret"` ClientSecret string `flag:"client-secret" cfg:"client_secret"`
ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file"` ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file"`
TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"`
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"` KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"`
@ -68,6 +63,9 @@ type Options struct {
InjectRequestHeaders []Header `cfg:",internal"` InjectRequestHeaders []Header `cfg:",internal"`
InjectResponseHeaders []Header `cfg:",internal"` InjectResponseHeaders []Header `cfg:",internal"`
Server Server `cfg:",internal"`
MetricsServer Server `cfg:",internal"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"` SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"`
SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"`
@ -136,10 +134,7 @@ func NewOptions() *Options {
return &Options{ return &Options{
ProxyPrefix: "/oauth2", ProxyPrefix: "/oauth2",
ProviderType: "google", ProviderType: "google",
MetricsAddress: "",
PingPath: "/ping", PingPath: "/ping",
HTTPAddress: "127.0.0.1:4180",
HTTPSAddress: ":443",
RealClientIPHeader: "X-Real-IP", RealClientIPHeader: "X-Real-IP",
ForceHTTPS: false, ForceHTTPS: false,
Cookie: cookieDefaults(), Cookie: cookieDefaults(),
@ -162,14 +157,10 @@ func NewOptions() *Options {
func NewFlagSet() *pflag.FlagSet { func NewFlagSet() *pflag.FlagSet {
flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError) flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError)
flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients")
flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients")
flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted") flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted")
flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)") flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)")
flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.") flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.")
flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests") flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests")
flagSet.String("tls-cert-file", "", "path to certificate file")
flagSet.String("tls-key-file", "", "path to private key file")
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)") flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)")
flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods") flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods")
@ -204,7 +195,6 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)")
flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks")
flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks")
flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")")
flagSet.String("session-store-type", "cookie", "the session storage provider to use") flagSet.String("session-store-type", "cookie", "the session storage provider to use")
flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)")
flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])")

@ -0,0 +1,27 @@
package options
// Server represents the configuration for an HTTP(S) server
type Server struct {
// BindAddress is the the address on which to serve traffic.
// Leave blank or set to "-" to disable.
BindAddress string
// SecureBindAddress is the the address on which to serve secure traffic.
// Leave blank or set to "-" to disable.
SecureBindAddress string
// TLS contains the information for loading the certificate and key for the
// secure traffic.
TLS *TLS
}
// TLS contains the information for loading a TLS certifcate and key.
type TLS struct {
// Key is the the TLS key data to use.
// Typically this will come from a file.
Key *SecretSource
// Cert is the TLS certificate data to use.
// Typically this will come from a file.
Cert *SecretSource
}

@ -0,0 +1,88 @@
package http
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/http"
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var certData []byte
var certDataSource, keyDataSource options.SecretSource
var client *http.Client
func TestHTTPSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "HTTP")
}
var _ = BeforeSuite(func() {
By("Generating a self-signed cert for TLS tests", func() {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
keyOut := bytes.NewBuffer(nil)
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
Expect(err).ToNot(HaveOccurred())
Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed())
keyDataSource.Value = keyOut.Bytes()
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
Expect(err).ToNot(HaveOccurred())
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"OAuth2 Proxy Test Suite"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
Expect(err).ToNot(HaveOccurred())
certData = certBytes
certOut := bytes.NewBuffer(nil)
Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed())
certDataSource.Value = certOut.Bytes()
})
By("Setting up a http client", func() {
cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value)
Expect(err).ToNot(HaveOccurred())
certificate, err := x509.ParseCertificate(cert.Certificate[0])
Expect(err).ToNot(HaveOccurred())
certpool := x509.NewCertPool()
certpool.AddCert(certificate)
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig.RootCAs = certpool
client = &http.Client{
Transport: transport,
}
})
})

245
pkg/http/server.go Normal file

@ -0,0 +1,245 @@
package http
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"golang.org/x/sync/errgroup"
)
// Server represents an HTTP or HTTPS server.
type Server interface {
// Start blocks and runs the server.
Start(ctx context.Context) error
}
// Opts contains the information required to set up the server.
type Opts struct {
// Handler is the http.Handler to be used to serve http pages by the server.
Handler http.Handler
// BindAddress is the address the HTTP server should listen on.
BindAddress string
// SecureBindAddress is the address the HTTPS server should listen on.
SecureBindAddress string
// TLS is the TLS configuration for the server.
TLS *options.TLS
}
// NewServer creates a new Server from the options given.
func NewServer(opts Opts) (Server, error) {
s := &server{
handler: opts.Handler,
}
if err := s.setupListener(opts); err != nil {
return nil, fmt.Errorf("error setting up listener: %v", err)
}
if err := s.setupTLSListener(opts); err != nil {
return nil, fmt.Errorf("error setting up TLS listener: %v", err)
}
return s, nil
}
// server is an implementation of the Server interface.
type server struct {
handler http.Handler
listener net.Listener
tlsListener net.Listener
}
// setupListener sets the server listener if the HTTP server is enabled.
// The HTTP server can be disabled by setting the BindAddress to "-" or by
// leaving it empty.
func (s *server) setupListener(opts Opts) error {
if opts.BindAddress == "" || opts.BindAddress == "-" {
// No HTTP listener required
return nil
}
networkType := getNetworkScheme(opts.BindAddress)
listenAddr := getListenAddress(opts.BindAddress)
listener, err := net.Listen(networkType, listenAddr)
if err != nil {
return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err)
}
s.listener = listener
return nil
}
// setupTLSListener sets the server TLS listener if the HTTPS server is enabled.
// The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by
// leaving it empty.
func (s *server) setupTLSListener(opts Opts) error {
if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" {
// No HTTPS listener required
return nil
}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
NextProtos: []string{"http/1.1"},
}
if opts.TLS == nil {
return errors.New("no TLS config provided")
}
cert, err := getCertificate(opts.TLS)
if err != nil {
return fmt.Errorf("could not load certificate: %v", err)
}
config.Certificates = []tls.Certificate{cert}
listenAddr := getListenAddress(opts.SecureBindAddress)
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
}
s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
return nil
}
// Start starts the HTTP and HTTPS server if applicable.
// It will block until the context is cancelled.
// If any errors occur, only the first error will be returned.
func (s *server) Start(ctx context.Context) error {
g, groupCtx := errgroup.WithContext(ctx)
if s.listener != nil {
g.Go(func() error {
if err := s.startServer(groupCtx, s.listener); err != nil {
return fmt.Errorf("error starting insecure server: %v", err)
}
return nil
})
}
if s.tlsListener != nil {
g.Go(func() error {
if err := s.startServer(groupCtx, s.tlsListener); err != nil {
return fmt.Errorf("error starting secure server: %v", err)
}
return nil
})
}
return g.Wait()
}
// startServer creates and starts a new server with the given listener.
// When the given context is cancelled the server will be shutdown.
// If any errors occur, only the first error will be returned.
func (s *server) startServer(ctx context.Context, listener net.Listener) error {
srv := &http.Server{Handler: s.handler}
g, groupCtx := errgroup.WithContext(ctx)
g.Go(func() error {
<-groupCtx.Done()
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("error shutting down server: %v", err)
}
return nil
})
g.Go(func() error {
if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("could not start server: %v", err)
}
return nil
})
return g.Wait()
}
// getNetworkScheme gets the scheme for the HTTP server.
func getNetworkScheme(addr string) string {
var scheme string
i := strings.Index(addr, "://")
if i > -1 {
scheme = addr[0:i]
}
switch scheme {
case "", "http":
return "tcp"
default:
return scheme
}
}
// getListenAddress gets the address for the HTTP server.
func getListenAddress(addr string) string {
slice := strings.SplitN(addr, "//", 2)
return slice[len(slice)-1]
}
// getCertificate loads the certificate data from the TLS config.
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
keyData, err := getSecretValue(opts.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err)
}
certData, err := getSecretValue(opts.Cert)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err)
}
cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err)
}
return cert, nil
}
// getSecretValue wraps util.GetSecretValue so that we can return an error if no
// source is provided.
func getSecretValue(src *options.SecretSource) ([]byte, error) {
if src == nil {
return nil, errors.New("no configuration provided")
}
return util.GetSecretValue(src)
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by so that dead TCP connections (e.g. closing laptop
// mid-download) eventually go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
// Accept implements the TCPListener interface.
// It sets the keep alive period to 3 minutes for each connection.
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
err = tc.SetKeepAlive(true)
if err != nil {
logger.Errorf("Error setting Keep-Alive: %v", err)
}
err = tc.SetKeepAlivePeriod(3 * time.Minute)
if err != nil {
logger.Printf("Error setting Keep-Alive period: %v", err)
}
return tc, nil
}

35
pkg/http/server_group.go Normal file

@ -0,0 +1,35 @@
package http
import (
"context"
"golang.org/x/sync/errgroup"
)
// NewServerGroup creates a new Server to start and gracefully stop a collection
// of Servers.
func NewServerGroup(servers ...Server) Server {
return &serverGroup{
servers: servers,
}
}
// serverGroup manages the starting and graceful shutdown of a collection of
// servers.
type serverGroup struct {
servers []Server
}
// Start runs the servers in the server group.
func (s *serverGroup) Start(ctx context.Context) error {
g, groupCtx := errgroup.WithContext(ctx)
for _, server := range s.servers {
srv := server
g.Go(func() error {
return srv.Start(groupCtx)
})
}
return g.Wait()
}

@ -0,0 +1,102 @@
package http
import (
"context"
"errors"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Server Group", func() {
var m1, m2, m3 *mockServer
var ctx context.Context
var cancel context.CancelFunc
var group Server
BeforeEach(func() {
ctx, cancel = context.WithCancel(context.Background())
m1 = newMockServer()
m2 = newMockServer()
m3 = newMockServer()
group = NewServerGroup(m1, m2, m3)
})
AfterEach(func() {
cancel()
})
It("starts each server in the group", func() {
go func() {
defer GinkgoRecover()
Expect(group.Start(ctx)).To(Succeed())
}()
Eventually(m1.started).Should(BeClosed(), "mock server 1 not started")
Eventually(m2.started).Should(BeClosed(), "mock server 2 not started")
Eventually(m3.started).Should(BeClosed(), "mock server 3 not started")
})
It("stop each server in the group when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(group.Start(ctx)).To(Succeed())
}()
cancel()
Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
})
It("stop each server in the group when the an error occurs", func() {
err := errors.New("server error")
go func() {
defer GinkgoRecover()
Expect(group.Start(ctx)).To(MatchError(err))
}()
m2.errors <- err
Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped")
Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped")
Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped")
})
})
// mockServer is used to test the server group can start
// and stop multiple servers simultaneously.
type mockServer struct {
started chan struct{}
startClosed bool
stopped chan struct{}
stopClosed bool
errors chan error
}
func newMockServer() *mockServer {
return &mockServer{
started: make(chan struct{}),
stopped: make(chan struct{}),
errors: make(chan error),
}
}
func (m *mockServer) Start(ctx context.Context) error {
if !m.startClosed {
close(m.started)
m.startClosed = true
}
defer func() {
if !m.stopClosed {
close(m.stopped)
m.stopClosed = true
}
}()
select {
case <-ctx.Done():
return nil
case err := <-m.errors:
return err
}
}

472
pkg/http/server_test.go Normal file

@ -0,0 +1,472 @@
package http
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
const hello = "Hello World!"
var _ = Describe("Server", func() {
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte(hello))
})
Context("NewServer", func() {
type newServerTableInput struct {
opts Opts
expectedErr error
expectHTTPListener bool
expectTLSListener bool
}
DescribeTable("When creating the new server from the options", func(in *newServerTableInput) {
srv, err := NewServer(in.opts)
if in.expectedErr != nil {
Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error())))
Expect(srv).To(BeNil())
return
}
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
Expect(s.listener != nil).To(Equal(in.expectHTTPListener))
if in.expectHTTPListener {
Expect(s.listener.Close()).To(Succeed())
}
Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener))
if in.expectTLSListener {
Expect(s.tlsListener.Close()).To(Succeed())
}
},
Entry("with a valid http bind address", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
},
expectedErr: nil,
expectHTTPListener: true,
expectTLSListener: false,
}),
Entry("with a valid https bind address, with no TLS config", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
},
expectedErr: errors.New("error setting up TLS listener: no TLS config provided"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: true,
}),
Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: true,
expectTLSListener: true,
}),
Entry("with a \"-\" for the bind address", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "-",
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with a \"-\" for the secure bind address", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "-",
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid bind address scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "invalid://127.0.0.1:0",
},
expectedErr: errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid secure bind address scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "invalid://127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: true,
}),
Entry("with an invalid bind address port", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "127.0.0.1:a",
},
expectedErr: errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid secure bind address port", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:a",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid TLS key", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &options.SecretSource{
Value: []byte("invalid"),
},
Cert: &certDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with an invalid TLS cert", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &options.SecretSource{
Value: []byte("invalid"),
},
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with no TLS key", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Cert: &certDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("with no TLS cert", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
},
},
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"),
expectHTTPListener: false,
expectTLSListener: false,
}),
Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
BindAddress: "http://127.0.0.1:0",
},
expectedErr: nil,
expectHTTPListener: true,
expectTLSListener: false,
}),
Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{
opts: Opts{
Handler: handler,
SecureBindAddress: "https://127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
},
expectedErr: nil,
expectHTTPListener: false,
expectTLSListener: true,
}),
)
})
Context("Start", func() {
var srv Server
var ctx context.Context
var cancel context.CancelFunc
BeforeEach(func() {
ctx, cancel = context.WithCancel(context.Background())
})
AfterEach(func() {
cancel()
})
Context("with an http server", func() {
var listenAddr string
BeforeEach(func() {
var err error
srv, err = NewServer(Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
})
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
})
It("Starts the server and serves the handler", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Stops the server when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
_, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(func() error {
_, err := client.Get(listenAddr)
return err
}).Should(HaveOccurred())
})
})
Context("with an https server", func() {
var secureListenAddr string
BeforeEach(func() {
var err error
srv, err = NewServer(Opts{
Handler: handler,
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
})
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
})
It("Starts the server and serves the handler", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Stops the server when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
_, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(func() error {
_, err := client.Get(secureListenAddr)
return err
}).Should(HaveOccurred())
})
It("Serves the certificate provided", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expect(resp.TLS.VerifiedChains).Should(HaveLen(1))
Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1))
Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData))
})
})
Context("with both an http and an https server", func() {
var listenAddr, secureListenAddr string
BeforeEach(func() {
var err error
srv, err = NewServer(Opts{
Handler: handler,
BindAddress: "127.0.0.1:0",
SecureBindAddress: "127.0.0.1:0",
TLS: &options.TLS{
Key: &keyDataSource,
Cert: &certDataSource,
},
})
Expect(err).ToNot(HaveOccurred())
s, ok := srv.(*server)
Expect(ok).To(BeTrue())
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
})
It("Starts the server and serves the handler on http", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Starts the server and serves the handler on https", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
resp, err := client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(hello))
})
It("Stops both servers when the context is cancelled", func() {
go func() {
defer GinkgoRecover()
Expect(srv.Start(ctx)).To(Succeed())
}()
_, err := client.Get(listenAddr)
Expect(err).ToNot(HaveOccurred())
_, err = client.Get(secureListenAddr)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(func() error {
_, err := client.Get(listenAddr)
return err
}).Should(HaveOccurred())
Eventually(func() error {
_, err := client.Get(secureListenAddr)
return err
}).Should(HaveOccurred())
})
})
})
Context("getNetworkScheme", func() {
DescribeTable("should return the scheme", func(in, expected string) {
Expect(getNetworkScheme(in)).To(Equal(expected))
},
Entry("with no scheme", "127.0.0.1:0", "tcp"),
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"),
Entry("with a http scheme", "http://192.168.0.1:1", "tcp"),
Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"),
Entry("with a random scheme", "random://10.10.10.10:10", "random"),
)
})
Context("getListenAddress", func() {
DescribeTable("should remove the scheme", func(in, expected string) {
Expect(getListenAddress(in)).To(Equal(expected))
},
Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"),
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"),
Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"),
Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"),
Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"),
)
})
})