diff --git a/go.mod b/go.mod index 3b56cac8..b6ce05cd 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20200707034311-ab3426394381 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d + golang.org/x/sync v0.0.0-20201207232520-09787c993a3a google.golang.org/api v0.20.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/square/go-jose.v2 v2.4.1 diff --git a/go.sum b/go.sum index 0c94e0ce..0b129b72 100644 --- a/go.sum +++ b/go.sum @@ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/pkg/http/http_suite_test.go b/pkg/http/http_suite_test.go new file mode 100644 index 00000000..13bd56e9 --- /dev/null +++ b/pkg/http/http_suite_test.go @@ -0,0 +1,88 @@ +package http + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/http" + "testing" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var certData []byte +var certDataSource, keyDataSource options.SecretSource +var client *http.Client + +func TestHTTPSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + logger.SetErrOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "HTTP") +} + +var _ = BeforeSuite(func() { + By("Generating a self-signed cert for TLS tests", func() { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + + keyOut := bytes.NewBuffer(nil) + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + Expect(err).ToNot(HaveOccurred()) + Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed()) + keyDataSource.Value = keyOut.Bytes() + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + Expect(err).ToNot(HaveOccurred()) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"OAuth2 Proxy Test Suite"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + Expect(err).ToNot(HaveOccurred()) + certData = certBytes + + certOut := bytes.NewBuffer(nil) + Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed()) + certDataSource.Value = certOut.Bytes() + }) + + By("Setting up a http client", func() { + cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value) + Expect(err).ToNot(HaveOccurred()) + + certificate, err := x509.ParseCertificate(cert.Certificate[0]) + Expect(err).ToNot(HaveOccurred()) + + certpool := x509.NewCertPool() + certpool.AddCert(certificate) + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig.RootCAs = certpool + + client = &http.Client{ + Transport: transport, + } + }) +}) diff --git a/pkg/http/server.go b/pkg/http/server.go new file mode 100644 index 00000000..e9a1d248 --- /dev/null +++ b/pkg/http/server.go @@ -0,0 +1,245 @@ +package http + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "golang.org/x/sync/errgroup" +) + +// Server represents an HTTP or HTTPS server. +type Server interface { + // Start blocks and runs the server. + Start(ctx context.Context) error +} + +// Opts contains the information required to set up the server. +type Opts struct { + // Handler is the http.Handler to be used to serve http pages by the server. + Handler http.Handler + + // BindAddress is the address the HTTP server should listen on. + BindAddress string + + // SecureBindAddress is the address the HTTPS server should listen on. + SecureBindAddress string + + // TLS is the TLS configuration for the server. + TLS *options.TLS +} + +// NewServer creates a new Server from the options given. +func NewServer(opts Opts) (Server, error) { + s := &server{ + handler: opts.Handler, + } + if err := s.setupListener(opts); err != nil { + return nil, fmt.Errorf("error setting up listener: %v", err) + } + if err := s.setupTLSListener(opts); err != nil { + return nil, fmt.Errorf("error setting up TLS listener: %v", err) + } + + return s, nil +} + +// server is an implementation of the Server interface. +type server struct { + handler http.Handler + + listener net.Listener + tlsListener net.Listener +} + +// setupListener sets the server listener if the HTTP server is enabled. +// The HTTP server can be disabled by setting the BindAddress to "-" or by +// leaving it empty. +func (s *server) setupListener(opts Opts) error { + if opts.BindAddress == "" || opts.BindAddress == "-" { + // No HTTP listener required + return nil + } + + networkType := getNetworkScheme(opts.BindAddress) + listenAddr := getListenAddress(opts.BindAddress) + + listener, err := net.Listen(networkType, listenAddr) + if err != nil { + return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err) + } + s.listener = listener + + return nil +} + +// setupTLSListener sets the server TLS listener if the HTTPS server is enabled. +// The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by +// leaving it empty. +func (s *server) setupTLSListener(opts Opts) error { + if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" { + // No HTTPS listener required + return nil + } + + config := &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + NextProtos: []string{"http/1.1"}, + } + if opts.TLS == nil { + return errors.New("no TLS config provided") + } + cert, err := getCertificate(opts.TLS) + if err != nil { + return fmt.Errorf("could not load certificate: %v", err) + } + config.Certificates = []tls.Certificate{cert} + + listenAddr := getListenAddress(opts.SecureBindAddress) + + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return fmt.Errorf("listen (%s) failed: %v", listenAddr, err) + } + + s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config) + return nil +} + +// Start starts the HTTP and HTTPS server if applicable. +// It will block until the context is cancelled. +// If any errors occur, only the first error will be returned. +func (s *server) Start(ctx context.Context) error { + g, groupCtx := errgroup.WithContext(ctx) + + if s.listener != nil { + g.Go(func() error { + if err := s.startServer(groupCtx, s.listener); err != nil { + return fmt.Errorf("error starting insecure server: %v", err) + } + return nil + }) + } + + if s.tlsListener != nil { + g.Go(func() error { + if err := s.startServer(groupCtx, s.tlsListener); err != nil { + return fmt.Errorf("error starting secure server: %v", err) + } + return nil + }) + } + + return g.Wait() +} + +// startServer creates and starts a new server with the given listener. +// When the given context is cancelled the server will be shutdown. +// If any errors occur, only the first error will be returned. +func (s *server) startServer(ctx context.Context, listener net.Listener) error { + srv := &http.Server{Handler: s.handler} + g, groupCtx := errgroup.WithContext(ctx) + + g.Go(func() error { + <-groupCtx.Done() + + if err := srv.Shutdown(context.Background()); err != nil { + return fmt.Errorf("error shutting down server: %v", err) + } + return nil + }) + + g.Go(func() error { + if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("could not start server: %v", err) + } + return nil + }) + + return g.Wait() +} + +// getNetworkScheme gets the scheme for the HTTP server. +func getNetworkScheme(addr string) string { + var scheme string + i := strings.Index(addr, "://") + if i > -1 { + scheme = addr[0:i] + } + + switch scheme { + case "", "http": + return "tcp" + default: + return scheme + } +} + +// getListenAddress gets the address for the HTTP server. +func getListenAddress(addr string) string { + slice := strings.SplitN(addr, "//", 2) + return slice[len(slice)-1] +} + +// getCertificate loads the certificate data from the TLS config. +func getCertificate(opts *options.TLS) (tls.Certificate, error) { + keyData, err := getSecretValue(opts.Key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err) + } + + certData, err := getSecretValue(opts.Cert) + if err != nil { + return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err) + } + + cert, err := tls.X509KeyPair(certData, keyData) + if err != nil { + return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err) + } + + return cert, nil +} + +// getSecretValue wraps util.GetSecretValue so that we can return an error if no +// source is provided. +func getSecretValue(src *options.SecretSource) ([]byte, error) { + if src == nil { + return nil, errors.New("no configuration provided") + } + return util.GetSecretValue(src) +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by so that dead TCP connections (e.g. closing laptop +// mid-download) eventually go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +// Accept implements the TCPListener interface. +// It sets the keep alive period to 3 minutes for each connection. +func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { + tc, err := ln.AcceptTCP() + if err != nil { + return nil, err + } + err = tc.SetKeepAlive(true) + if err != nil { + logger.Errorf("Error setting Keep-Alive: %v", err) + } + err = tc.SetKeepAlivePeriod(3 * time.Minute) + if err != nil { + logger.Printf("Error setting Keep-Alive period: %v", err) + } + return tc, nil +} diff --git a/pkg/http/server_test.go b/pkg/http/server_test.go new file mode 100644 index 00000000..a2b995a2 --- /dev/null +++ b/pkg/http/server_test.go @@ -0,0 +1,472 @@ +package http + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +const hello = "Hello World!" + +var _ = Describe("Server", func() { + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Write([]byte(hello)) + }) + + Context("NewServer", func() { + type newServerTableInput struct { + opts Opts + expectedErr error + expectHTTPListener bool + expectTLSListener bool + } + + DescribeTable("When creating the new server from the options", func(in *newServerTableInput) { + srv, err := NewServer(in.opts) + if in.expectedErr != nil { + Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error()))) + Expect(srv).To(BeNil()) + return + } + + Expect(err).ToNot(HaveOccurred()) + + s, ok := srv.(*server) + Expect(ok).To(BeTrue()) + + Expect(s.listener != nil).To(Equal(in.expectHTTPListener)) + if in.expectHTTPListener { + Expect(s.listener.Close()).To(Succeed()) + } + Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener)) + if in.expectTLSListener { + Expect(s.tlsListener.Close()).To(Succeed()) + } + }, + Entry("with a valid http bind address", &newServerTableInput{ + opts: Opts{ + Handler: handler, + BindAddress: "127.0.0.1:0", + }, + expectedErr: nil, + expectHTTPListener: true, + expectTLSListener: false, + }), + Entry("with a valid https bind address, with no TLS config", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:0", + }, + expectedErr: errors.New("error setting up TLS listener: no TLS config provided"), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &certDataSource, + }, + }, + expectedErr: nil, + expectHTTPListener: false, + expectTLSListener: true, + }), + Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{ + opts: Opts{ + Handler: handler, + BindAddress: "127.0.0.1:0", + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &certDataSource, + }, + }, + expectedErr: nil, + expectHTTPListener: true, + expectTLSListener: true, + }), + Entry("with a \"-\" for the bind address", &newServerTableInput{ + opts: Opts{ + Handler: handler, + BindAddress: "-", + }, + expectedErr: nil, + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with a \"-\" for the secure bind address", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "-", + }, + expectedErr: nil, + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with an invalid bind address scheme", &newServerTableInput{ + opts: Opts{ + Handler: handler, + BindAddress: "invalid://127.0.0.1:0", + }, + expectedErr: errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with an invalid secure bind address scheme", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "invalid://127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &certDataSource, + }, + }, + expectedErr: nil, + expectHTTPListener: false, + expectTLSListener: true, + }), + Entry("with an invalid bind address port", &newServerTableInput{ + opts: Opts{ + Handler: handler, + BindAddress: "127.0.0.1:a", + }, + expectedErr: errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with an invalid secure bind address port", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:a", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &certDataSource, + }, + }, + expectedErr: errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with an invalid TLS key", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Key: &options.SecretSource{ + Value: []byte("invalid"), + }, + Cert: &certDataSource, + }, + }, + expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with an invalid TLS cert", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &options.SecretSource{ + Value: []byte("invalid"), + }, + }, + }, + expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with no TLS key", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Cert: &certDataSource, + }, + }, + expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("with no TLS cert", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + }, + }, + expectedErr: errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"), + expectHTTPListener: false, + expectTLSListener: false, + }), + Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{ + opts: Opts{ + Handler: handler, + BindAddress: "http://127.0.0.1:0", + }, + expectedErr: nil, + expectHTTPListener: true, + expectTLSListener: false, + }), + Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{ + opts: Opts{ + Handler: handler, + SecureBindAddress: "https://127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &certDataSource, + }, + }, + expectedErr: nil, + expectHTTPListener: false, + expectTLSListener: true, + }), + ) + }) + + Context("Start", func() { + var srv Server + var ctx context.Context + var cancel context.CancelFunc + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + }) + + AfterEach(func() { + cancel() + }) + + Context("with an http server", func() { + var listenAddr string + + BeforeEach(func() { + var err error + srv, err = NewServer(Opts{ + Handler: handler, + BindAddress: "127.0.0.1:0", + }) + Expect(err).ToNot(HaveOccurred()) + + s, ok := srv.(*server) + Expect(ok).To(BeTrue()) + + listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String()) + }) + + It("Starts the server and serves the handler", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + resp, err := client.Get(listenAddr) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + body, err := ioutil.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(hello)) + }) + + It("Stops the server when the context is cancelled", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + _, err := client.Get(listenAddr) + Expect(err).ToNot(HaveOccurred()) + + cancel() + + Eventually(func() error { + _, err := client.Get(listenAddr) + return err + }).Should(HaveOccurred()) + }) + }) + + Context("with an https server", func() { + var secureListenAddr string + + BeforeEach(func() { + var err error + srv, err = NewServer(Opts{ + Handler: handler, + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &certDataSource, + }, + }) + Expect(err).ToNot(HaveOccurred()) + + s, ok := srv.(*server) + Expect(ok).To(BeTrue()) + + secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String()) + }) + + It("Starts the server and serves the handler", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + resp, err := client.Get(secureListenAddr) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + body, err := ioutil.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(hello)) + }) + + It("Stops the server when the context is cancelled", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + _, err := client.Get(secureListenAddr) + Expect(err).ToNot(HaveOccurred()) + + cancel() + + Eventually(func() error { + _, err := client.Get(secureListenAddr) + return err + }).Should(HaveOccurred()) + }) + + It("Serves the certificate provided", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + resp, err := client.Get(secureListenAddr) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + Expect(resp.TLS.VerifiedChains).Should(HaveLen(1)) + Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1)) + Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData)) + }) + }) + + Context("with both an http and an https server", func() { + var listenAddr, secureListenAddr string + + BeforeEach(func() { + var err error + srv, err = NewServer(Opts{ + Handler: handler, + BindAddress: "127.0.0.1:0", + SecureBindAddress: "127.0.0.1:0", + TLS: &options.TLS{ + Key: &keyDataSource, + Cert: &certDataSource, + }, + }) + Expect(err).ToNot(HaveOccurred()) + + s, ok := srv.(*server) + Expect(ok).To(BeTrue()) + + listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String()) + secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String()) + }) + + It("Starts the server and serves the handler on http", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + resp, err := client.Get(listenAddr) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + body, err := ioutil.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(hello)) + }) + + It("Starts the server and serves the handler on https", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + resp, err := client.Get(secureListenAddr) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + body, err := ioutil.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(hello)) + }) + + It("Stops both servers when the context is cancelled", func() { + go func() { + defer GinkgoRecover() + Expect(srv.Start(ctx)).To(Succeed()) + }() + + _, err := client.Get(listenAddr) + Expect(err).ToNot(HaveOccurred()) + _, err = client.Get(secureListenAddr) + Expect(err).ToNot(HaveOccurred()) + + cancel() + + Eventually(func() error { + _, err := client.Get(listenAddr) + return err + }).Should(HaveOccurred()) + Eventually(func() error { + _, err := client.Get(secureListenAddr) + return err + }).Should(HaveOccurred()) + }) + }) + }) + + Context("getNetworkScheme", func() { + DescribeTable("should return the scheme", func(in, expected string) { + Expect(getNetworkScheme(in)).To(Equal(expected)) + }, + Entry("with no scheme", "127.0.0.1:0", "tcp"), + Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"), + Entry("with a http scheme", "http://192.168.0.1:1", "tcp"), + Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"), + Entry("with a random scheme", "random://10.10.10.10:10", "random"), + ) + }) + + Context("getListenAddress", func() { + DescribeTable("should remove the scheme", func(in, expected string) { + Expect(getListenAddress(in)).To(Equal(expected)) + }, + Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"), + Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"), + Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"), + Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"), + Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"), + ) + }) +})