From 734e313f711bd06067759bcfcfb2ba73c3a4dde5 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 29 Dec 2020 11:46:09 +0200 Subject: [PATCH] refactor Echo server startup to allow data race free access to listener address --- echo.go | 84 +++++++++++- echo_test.go | 354 ++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 383 insertions(+), 55 deletions(-) diff --git a/echo.go b/echo.go index d284ff39..4c4e7d47 100644 --- a/echo.go +++ b/echo.go @@ -67,6 +67,9 @@ type ( // Echo is the top-level framework instance. Echo struct { common + // startupMu is mutex to lock Echo instance access during server configuration and startup. Useful for to get + // listener address info (on which interface/port was listener binded) without having data races. + startupMu sync.RWMutex StdLogger *stdLog.Logger colorer *color.Color premiddleware []MiddlewareFunc @@ -643,21 +646,30 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Start starts an HTTP server. func (e *Echo) Start(address string) error { + e.startupMu.Lock() e.Server.Addr = address - return e.StartServer(e.Server) + if err := e.configureServer(e.Server); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return e.serve() } // StartTLS starts an HTTPS server. // If `certFile` or `keyFile` is `string` the values are treated as file paths. // If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { + e.startupMu.Lock() var cert []byte if cert, err = filepathOrContent(certFile); err != nil { + e.startupMu.Unlock() return } var key []byte if key, err = filepathOrContent(keyFile); err != nil { + e.startupMu.Unlock() return } @@ -665,10 +677,17 @@ func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err erro s.TLSConfig = new(tls.Config) s.TLSConfig.Certificates = make([]tls.Certificate, 1) if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { + e.startupMu.Unlock() return } - return e.startTLS(address) + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return s.Serve(e.TLSListener) } func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { @@ -684,24 +703,41 @@ func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { // StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. func (e *Echo) StartAutoTLS(address string) error { + e.startupMu.Lock() s := e.TLSServer s.TLSConfig = new(tls.Config) s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - return e.startTLS(address) + + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return s.Serve(e.TLSListener) } -func (e *Echo) startTLS(address string) error { +func (e *Echo) configureTLS(address string) { s := e.TLSServer s.Addr = address if !e.DisableHTTP2 { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") } - return e.StartServer(e.TLSServer) } // StartServer starts a custom http server. func (e *Echo) StartServer(s *http.Server) (err error) { + e.startupMu.Lock() + if err := e.configureServer(s); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return e.serve() +} + +func (e *Echo) configureServer(s *http.Server) (err error) { // Setup e.colorer.SetOutput(e.Logger.Output()) s.ErrorLog = e.StdLogger @@ -724,7 +760,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) } - return s.Serve(e.Listener) + return nil } if e.TLSListener == nil { l, err := newListener(s.Addr, e.ListenerNetwork) @@ -736,11 +772,39 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if !e.HidePort { e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) } - return s.Serve(e.TLSListener) + return nil +} + +func (e *Echo) serve() error { + if e.TLSListener != nil { + return e.Server.Serve(e.TLSListener) + } + return e.Server.Serve(e.Listener) +} + +// ListenerAddr returns net.Addr for Listener +func (e *Echo) ListenerAddr() net.Addr { + e.startupMu.RLock() + defer e.startupMu.RUnlock() + if e.Listener == nil { + return nil + } + return e.Listener.Addr() +} + +// TLSListenerAddr returns net.Addr for TLSListener +func (e *Echo) TLSListenerAddr() net.Addr { + e.startupMu.RLock() + defer e.startupMu.RUnlock() + if e.TLSListener == nil { + return nil + } + return e.TLSListener.Addr() } // StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { + e.startupMu.Lock() // Setup s := e.Server s.Addr = address @@ -758,18 +822,22 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { if e.Listener == nil { e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { + e.startupMu.Unlock() return err } } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) } + e.startupMu.Unlock() return s.Serve(e.Listener) } // Close immediately stops the server. // It internally calls `http.Server#Close()`. func (e *Echo) Close() error { + e.startupMu.Lock() + defer e.startupMu.Unlock() if err := e.TLSServer.Close(); err != nil { return err } @@ -779,6 +847,8 @@ func (e *Echo) Close() error { // Shutdown stops the server gracefully. // It internally calls `http.Server#Shutdown()`. func (e *Echo) Shutdown(ctx stdContext.Context) error { + e.startupMu.Lock() + defer e.startupMu.Unlock() if err := e.TLSServer.Shutdown(ctx); err != nil { return err } diff --git a/echo_test.go b/echo_test.go index 29edca10..7f359742 100644 --- a/echo_test.go +++ b/echo_test.go @@ -3,12 +3,14 @@ package echo import ( "bytes" stdContext "context" + "crypto/tls" "errors" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" + "os" "reflect" "strings" "testing" @@ -485,26 +487,125 @@ func TestEchoContext(t *testing.T) { e.ReleaseContext(c) } -func TestEchoStart(t *testing.T) { - e := New() - go func() { - assert.NoError(t, e.Start(":0")) - }() - time.Sleep(200 * time.Millisecond) +func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + var addr net.Addr + if isTLS { + addr = e.TLSListenerAddr() + } else { + addr = e.ListenerAddr() + } + if addr != nil && strings.Contains(addr.String(), ":") { + return nil // was started + } + case err := <-errChan: + if err == http.ErrServerClosed { + return nil + } + return err + } + } } -func TestEchoStartTLS(t *testing.T) { +func TestEchoStart(t *testing.T) { e := New() + errChan := make(chan error) + go func() { - err := e.StartTLS(":0", "_fixture/certs/cert.pem", "_fixture/certs/key.pem") - // Prevent the test to fail after closing the servers - if err != http.ErrServerClosed { - assert.NoError(t, err) + err := e.Start(":0") + if err != nil { + errChan <- err } }() - time.Sleep(200 * time.Millisecond) - e.Close() + err := waitForServerStart(e, errChan, false) + assert.NoError(t, err) + + assert.NoError(t, e.Close()) +} + +func TestEcho_StartTLS(t *testing.T) { + var testCases = []struct { + name string + addr string + certFile string + keyFile string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid certFile", + addr: ":0", + certFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, invalid keyFile", + addr: ":0", + keyFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, failed to create cert out of certFile and keyFile", + addr: ":0", + keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key + expectError: "tls: found a certificate rather than a key in the PEM for the private key", + }, + { + name: "nok, invalid tls address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + errChan := make(chan error) + + go func() { + certFile := "_fixture/certs/cert.pem" + if tc.certFile != "" { + certFile = tc.certFile + } + keyFile := "_fixture/certs/key.pem" + if tc.keyFile != "" { + keyFile = tc.keyFile + } + + err := e.StartTLS(tc.addr, certFile, keyFile) + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, true) + if tc.expectError != "" { + if _, ok := err.(*os.PathError); ok { + assert.Error(t, err) // error messages for unix and windows are different. so test only error type here + } else { + assert.EqualError(t, err, tc.expectError) + } + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } } func TestEchoStartTLSByteString(t *testing.T) { @@ -557,47 +658,103 @@ func TestEchoStartTLSByteString(t *testing.T) { e := New() e.HideBanner = true - go func() { - err := e.StartTLS(":0", test.cert, test.key) - if test.expectedErr != nil { - require.EqualError(t, err, test.expectedErr.Error()) - } else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers - require.NoError(t, err) - } - }() - time.Sleep(200 * time.Millisecond) + errChan := make(chan error, 0) - require.NoError(t, e.Close()) + go func() { + errChan <- e.StartTLS(":0", test.cert, test.key) + }() + + err := waitForServerStart(e, errChan, true) + if test.expectedErr != nil { + assert.EqualError(t, err, test.expectedErr.Error()) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) }) } } -func TestEchoStartAutoTLS(t *testing.T) { - e := New() - errChan := make(chan error, 0) +func TestEcho_StartAutoTLS(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } - go func() { - errChan <- e.StartAutoTLS(":0") - }() - time.Sleep(200 * time.Millisecond) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + errChan := make(chan error, 0) - select { - case err := <-errChan: - assert.NoError(t, err) - default: - assert.NoError(t, e.Close()) + go func() { + errChan <- e.StartAutoTLS(tc.addr) + }() + + err := waitForServerStart(e, errChan, true) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) } } -func TestEchoStartH2CServer(t *testing.T) { - e := New() - e.Debug = true - h2s := &http2.Server{} +func TestEcho_StartH2CServer(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } - go func() { - assert.NoError(t, e.StartH2CServer(":0", h2s)) - }() - time.Sleep(200 * time.Millisecond) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + h2s := &http2.Server{} + + errChan := make(chan error) + go func() { + err := e.StartH2CServer(tc.addr, h2s) + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, false) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } } func testMethod(t *testing.T, method, path string, e *Echo) { @@ -686,7 +843,8 @@ func TestEchoClose(t *testing.T) { errCh <- e.Start(":0") }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if err := e.Close(); err != nil { t.Fatal(err) @@ -694,7 +852,7 @@ func TestEchoClose(t *testing.T) { assert.NoError(t, e.Close()) - err := <-errCh + err = <-errCh assert.Equal(t, err.Error(), "http: Server closed") } @@ -706,7 +864,8 @@ func TestEchoShutdown(t *testing.T) { errCh <- e.Start(":0") }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if err := e.Close(); err != nil { t.Fatal(err) @@ -716,7 +875,7 @@ func TestEchoShutdown(t *testing.T) { defer cancel() assert.NoError(t, e.Shutdown(ctx)) - err := <-errCh + err = <-errCh assert.Equal(t, err.Error(), "http: Server closed") } @@ -764,7 +923,8 @@ func TestEchoListenerNetwork(t *testing.T) { errCh <- e.Start(tt.address) }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { defer resp.Body.Close() @@ -823,3 +983,101 @@ func TestEchoReverse(t *testing.T) { assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) } + +func TestEcho_ListenerAddr(t *testing.T) { + e := New() + + addr := e.ListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.Start(":0") + }() + + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) +} + +func TestEcho_TLSListenerAddr(t *testing.T) { + cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := ioutil.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + + e := New() + + addr := e.TLSListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.StartTLS(":0", cert, key) + }() + + err = waitForServerStart(e, errCh, true) + assert.NoError(t, err) +} + +func TestEcho_StartServer(t *testing.T) { + cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := ioutil.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + certs, err := tls.X509KeyPair(cert, key) + require.NoError(t, err) + + var testCases = []struct { + name string + addr string + TLSConfig *tls.Config + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "ok, start with TLS", + addr: ":0", + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + { + name: "nok, invalid tls address", + addr: "nope", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + + server := new(http.Server) + server.Addr = tc.addr + if tc.TLSConfig != nil { + server.TLSConfig = tc.TLSConfig + } + + errCh := make(chan error) + go func() { + errCh <- e.StartServer(server) + }() + + err := waitForServerStart(e, errCh, tc.TLSConfig != nil) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.NoError(t, e.Close()) + }) + } +}