You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-07-15 01:44:22 +02:00
Add new http server implementation
This commit is contained in:
1
go.mod
1
go.mod
@ -30,6 +30,7 @@ require (
|
|||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
|
||||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381
|
golang.org/x/net v0.0.0-20200707034311-ab3426394381
|
||||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
|
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
|
||||||
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a
|
||||||
google.golang.org/api v0.20.0
|
google.golang.org/api v0.20.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||||
gopkg.in/square/go-jose.v2 v2.4.1
|
gopkg.in/square/go-jose.v2 v2.4.1
|
||||||
|
3
go.sum
3
go.sum
@ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
|
||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs=
|
||||||
|
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
88
pkg/http/http_suite_test.go
Normal file
88
pkg/http/http_suite_test.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var certData []byte
|
||||||
|
var certDataSource, keyDataSource options.SecretSource
|
||||||
|
var client *http.Client
|
||||||
|
|
||||||
|
func TestHTTPSuite(t *testing.T) {
|
||||||
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
logger.SetErrOutput(GinkgoWriter)
|
||||||
|
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "HTTP")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = BeforeSuite(func() {
|
||||||
|
By("Generating a self-signed cert for TLS tests", func() {
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
keyOut := bytes.NewBuffer(nil)
|
||||||
|
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed())
|
||||||
|
keyDataSource.Value = keyOut.Bytes()
|
||||||
|
|
||||||
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"OAuth2 Proxy Test Suite"},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
|
||||||
|
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
certData = certBytes
|
||||||
|
|
||||||
|
certOut := bytes.NewBuffer(nil)
|
||||||
|
Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed())
|
||||||
|
certDataSource.Value = certOut.Bytes()
|
||||||
|
})
|
||||||
|
|
||||||
|
By("Setting up a http client", func() {
|
||||||
|
cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
certificate, err := x509.ParseCertificate(cert.Certificate[0])
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
certpool := x509.NewCertPool()
|
||||||
|
certpool.AddCert(certificate)
|
||||||
|
|
||||||
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
transport.TLSClientConfig.RootCAs = certpool
|
||||||
|
|
||||||
|
client = &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
245
pkg/http/server.go
Normal file
245
pkg/http/server.go
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server represents an HTTP or HTTPS server.
|
||||||
|
type Server interface {
|
||||||
|
// Start blocks and runs the server.
|
||||||
|
Start(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Opts contains the information required to set up the server.
|
||||||
|
type Opts struct {
|
||||||
|
// Handler is the http.Handler to be used to serve http pages by the server.
|
||||||
|
Handler http.Handler
|
||||||
|
|
||||||
|
// BindAddress is the address the HTTP server should listen on.
|
||||||
|
BindAddress string
|
||||||
|
|
||||||
|
// SecureBindAddress is the address the HTTPS server should listen on.
|
||||||
|
SecureBindAddress string
|
||||||
|
|
||||||
|
// TLS is the TLS configuration for the server.
|
||||||
|
TLS *options.TLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new Server from the options given.
|
||||||
|
func NewServer(opts Opts) (Server, error) {
|
||||||
|
s := &server{
|
||||||
|
handler: opts.Handler,
|
||||||
|
}
|
||||||
|
if err := s.setupListener(opts); err != nil {
|
||||||
|
return nil, fmt.Errorf("error setting up listener: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.setupTLSListener(opts); err != nil {
|
||||||
|
return nil, fmt.Errorf("error setting up TLS listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// server is an implementation of the Server interface.
|
||||||
|
type server struct {
|
||||||
|
handler http.Handler
|
||||||
|
|
||||||
|
listener net.Listener
|
||||||
|
tlsListener net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupListener sets the server listener if the HTTP server is enabled.
|
||||||
|
// The HTTP server can be disabled by setting the BindAddress to "-" or by
|
||||||
|
// leaving it empty.
|
||||||
|
func (s *server) setupListener(opts Opts) error {
|
||||||
|
if opts.BindAddress == "" || opts.BindAddress == "-" {
|
||||||
|
// No HTTP listener required
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
networkType := getNetworkScheme(opts.BindAddress)
|
||||||
|
listenAddr := getListenAddress(opts.BindAddress)
|
||||||
|
|
||||||
|
listener, err := net.Listen(networkType, listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err)
|
||||||
|
}
|
||||||
|
s.listener = listener
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTLSListener sets the server TLS listener if the HTTPS server is enabled.
|
||||||
|
// The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by
|
||||||
|
// leaving it empty.
|
||||||
|
func (s *server) setupTLSListener(opts Opts) error {
|
||||||
|
if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" {
|
||||||
|
// No HTTPS listener required
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
MaxVersion: tls.VersionTLS13,
|
||||||
|
NextProtos: []string{"http/1.1"},
|
||||||
|
}
|
||||||
|
if opts.TLS == nil {
|
||||||
|
return errors.New("no TLS config provided")
|
||||||
|
}
|
||||||
|
cert, err := getCertificate(opts.TLS)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not load certificate: %v", err)
|
||||||
|
}
|
||||||
|
config.Certificates = []tls.Certificate{cert}
|
||||||
|
|
||||||
|
listenAddr := getListenAddress(opts.SecureBindAddress)
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the HTTP and HTTPS server if applicable.
|
||||||
|
// It will block until the context is cancelled.
|
||||||
|
// If any errors occur, only the first error will be returned.
|
||||||
|
func (s *server) Start(ctx context.Context) error {
|
||||||
|
g, groupCtx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
if s.listener != nil {
|
||||||
|
g.Go(func() error {
|
||||||
|
if err := s.startServer(groupCtx, s.listener); err != nil {
|
||||||
|
return fmt.Errorf("error starting insecure server: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.tlsListener != nil {
|
||||||
|
g.Go(func() error {
|
||||||
|
if err := s.startServer(groupCtx, s.tlsListener); err != nil {
|
||||||
|
return fmt.Errorf("error starting secure server: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return g.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// startServer creates and starts a new server with the given listener.
|
||||||
|
// When the given context is cancelled the server will be shutdown.
|
||||||
|
// If any errors occur, only the first error will be returned.
|
||||||
|
func (s *server) startServer(ctx context.Context, listener net.Listener) error {
|
||||||
|
srv := &http.Server{Handler: s.handler}
|
||||||
|
g, groupCtx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
g.Go(func() error {
|
||||||
|
<-groupCtx.Done()
|
||||||
|
|
||||||
|
if err := srv.Shutdown(context.Background()); err != nil {
|
||||||
|
return fmt.Errorf("error shutting down server: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
g.Go(func() error {
|
||||||
|
if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return fmt.Errorf("could not start server: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return g.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNetworkScheme gets the scheme for the HTTP server.
|
||||||
|
func getNetworkScheme(addr string) string {
|
||||||
|
var scheme string
|
||||||
|
i := strings.Index(addr, "://")
|
||||||
|
if i > -1 {
|
||||||
|
scheme = addr[0:i]
|
||||||
|
}
|
||||||
|
|
||||||
|
switch scheme {
|
||||||
|
case "", "http":
|
||||||
|
return "tcp"
|
||||||
|
default:
|
||||||
|
return scheme
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getListenAddress gets the address for the HTTP server.
|
||||||
|
func getListenAddress(addr string) string {
|
||||||
|
slice := strings.SplitN(addr, "//", 2)
|
||||||
|
return slice[len(slice)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCertificate loads the certificate data from the TLS config.
|
||||||
|
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
|
||||||
|
keyData, err := getSecretValue(opts.Key)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certData, err := getSecretValue(opts.Cert)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair(certData, keyData)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSecretValue wraps util.GetSecretValue so that we can return an error if no
|
||||||
|
// source is provided.
|
||||||
|
func getSecretValue(src *options.SecretSource) ([]byte, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, errors.New("no configuration provided")
|
||||||
|
}
|
||||||
|
return util.GetSecretValue(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||||
|
// connections. It's used by so that dead TCP connections (e.g. closing laptop
|
||||||
|
// mid-download) eventually go away.
|
||||||
|
type tcpKeepAliveListener struct {
|
||||||
|
*net.TCPListener
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept implements the TCPListener interface.
|
||||||
|
// It sets the keep alive period to 3 minutes for each connection.
|
||||||
|
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||||
|
tc, err := ln.AcceptTCP()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = tc.SetKeepAlive(true)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error setting Keep-Alive: %v", err)
|
||||||
|
}
|
||||||
|
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("Error setting Keep-Alive period: %v", err)
|
||||||
|
}
|
||||||
|
return tc, nil
|
||||||
|
}
|
472
pkg/http/server_test.go
Normal file
472
pkg/http/server_test.go
Normal file
@ -0,0 +1,472 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
const hello = "Hello World!"
|
||||||
|
|
||||||
|
var _ = Describe("Server", func() {
|
||||||
|
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Write([]byte(hello))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("NewServer", func() {
|
||||||
|
type newServerTableInput struct {
|
||||||
|
opts Opts
|
||||||
|
expectedErr error
|
||||||
|
expectHTTPListener bool
|
||||||
|
expectTLSListener bool
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("When creating the new server from the options", func(in *newServerTableInput) {
|
||||||
|
srv, err := NewServer(in.opts)
|
||||||
|
if in.expectedErr != nil {
|
||||||
|
Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error())))
|
||||||
|
Expect(srv).To(BeNil())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
s, ok := srv.(*server)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
|
||||||
|
Expect(s.listener != nil).To(Equal(in.expectHTTPListener))
|
||||||
|
if in.expectHTTPListener {
|
||||||
|
Expect(s.listener.Close()).To(Succeed())
|
||||||
|
}
|
||||||
|
Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener))
|
||||||
|
if in.expectTLSListener {
|
||||||
|
Expect(s.tlsListener.Close()).To(Succeed())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Entry("with a valid http bind address", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: true,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with a valid https bind address, with no TLS config", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up TLS listener: no TLS config provided"),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: true,
|
||||||
|
}),
|
||||||
|
Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "127.0.0.1:0",
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: true,
|
||||||
|
expectTLSListener: true,
|
||||||
|
}),
|
||||||
|
Entry("with a \"-\" for the bind address", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "-",
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with a \"-\" for the secure bind address", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "-",
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with an invalid bind address scheme", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "invalid://127.0.0.1:0",
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with an invalid secure bind address scheme", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "invalid://127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: true,
|
||||||
|
}),
|
||||||
|
Entry("with an invalid bind address port", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "127.0.0.1:a",
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with an invalid secure bind address port", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:a",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with an invalid TLS key", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &options.SecretSource{
|
||||||
|
Value: []byte("invalid"),
|
||||||
|
},
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with an invalid TLS cert", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &options.SecretSource{
|
||||||
|
Value: []byte("invalid"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with no TLS key", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("with no TLS cert", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"),
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "http://127.0.0.1:0",
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: true,
|
||||||
|
expectTLSListener: false,
|
||||||
|
}),
|
||||||
|
Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{
|
||||||
|
opts: Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "https://127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectHTTPListener: false,
|
||||||
|
expectTLSListener: true,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("Start", func() {
|
||||||
|
var srv Server
|
||||||
|
var ctx context.Context
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with an http server", func() {
|
||||||
|
var listenAddr string
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
srv, err = NewServer(Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "127.0.0.1:0",
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
s, ok := srv.(*server)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
|
||||||
|
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Starts the server and serves the handler", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
resp, err := client.Get(listenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(string(body)).To(Equal(hello))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Stops the server when the context is cancelled", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := client.Get(listenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
Eventually(func() error {
|
||||||
|
_, err := client.Get(listenAddr)
|
||||||
|
return err
|
||||||
|
}).Should(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with an https server", func() {
|
||||||
|
var secureListenAddr string
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
srv, err = NewServer(Opts{
|
||||||
|
Handler: handler,
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
s, ok := srv.(*server)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
|
||||||
|
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Starts the server and serves the handler", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
resp, err := client.Get(secureListenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(string(body)).To(Equal(hello))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Stops the server when the context is cancelled", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := client.Get(secureListenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
Eventually(func() error {
|
||||||
|
_, err := client.Get(secureListenAddr)
|
||||||
|
return err
|
||||||
|
}).Should(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Serves the certificate provided", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
resp, err := client.Get(secureListenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
Expect(resp.TLS.VerifiedChains).Should(HaveLen(1))
|
||||||
|
Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1))
|
||||||
|
Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with both an http and an https server", func() {
|
||||||
|
var listenAddr, secureListenAddr string
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
srv, err = NewServer(Opts{
|
||||||
|
Handler: handler,
|
||||||
|
BindAddress: "127.0.0.1:0",
|
||||||
|
SecureBindAddress: "127.0.0.1:0",
|
||||||
|
TLS: &options.TLS{
|
||||||
|
Key: &keyDataSource,
|
||||||
|
Cert: &certDataSource,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
s, ok := srv.(*server)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
|
||||||
|
listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String())
|
||||||
|
secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Starts the server and serves the handler on http", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
resp, err := client.Get(listenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(string(body)).To(Equal(hello))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Starts the server and serves the handler on https", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
resp, err := client.Get(secureListenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(string(body)).To(Equal(hello))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("Stops both servers when the context is cancelled", func() {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(srv.Start(ctx)).To(Succeed())
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := client.Get(listenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = client.Get(secureListenAddr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
Eventually(func() error {
|
||||||
|
_, err := client.Get(listenAddr)
|
||||||
|
return err
|
||||||
|
}).Should(HaveOccurred())
|
||||||
|
Eventually(func() error {
|
||||||
|
_, err := client.Get(secureListenAddr)
|
||||||
|
return err
|
||||||
|
}).Should(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("getNetworkScheme", func() {
|
||||||
|
DescribeTable("should return the scheme", func(in, expected string) {
|
||||||
|
Expect(getNetworkScheme(in)).To(Equal(expected))
|
||||||
|
},
|
||||||
|
Entry("with no scheme", "127.0.0.1:0", "tcp"),
|
||||||
|
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"),
|
||||||
|
Entry("with a http scheme", "http://192.168.0.1:1", "tcp"),
|
||||||
|
Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"),
|
||||||
|
Entry("with a random scheme", "random://10.10.10.10:10", "random"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("getListenAddress", func() {
|
||||||
|
DescribeTable("should remove the scheme", func(in, expected string) {
|
||||||
|
Expect(getListenAddress(in)).To(Equal(expected))
|
||||||
|
},
|
||||||
|
Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"),
|
||||||
|
Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"),
|
||||||
|
Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"),
|
||||||
|
Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"),
|
||||||
|
Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
Reference in New Issue
Block a user