mirror of
https://github.com/labstack/echo.git
synced 2025-01-24 03:16:14 +02:00
refactor Echo server startup to allow data race free access to listener address
This commit is contained in:
parent
b065180250
commit
734e313f71
84
echo.go
84
echo.go
@ -67,6 +67,9 @@ type (
|
||||
// Echo is the top-level framework instance.
|
||||
Echo struct {
|
||||
common
|
||||
// startupMu is mutex to lock Echo instance access during server configuration and startup. Useful for to get
|
||||
// listener address info (on which interface/port was listener binded) without having data races.
|
||||
startupMu sync.RWMutex
|
||||
StdLogger *stdLog.Logger
|
||||
colorer *color.Color
|
||||
premiddleware []MiddlewareFunc
|
||||
@ -643,21 +646,30 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Start starts an HTTP server.
|
||||
func (e *Echo) Start(address string) error {
|
||||
e.startupMu.Lock()
|
||||
e.Server.Addr = address
|
||||
return e.StartServer(e.Server)
|
||||
if err := e.configureServer(e.Server); err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return err
|
||||
}
|
||||
e.startupMu.Unlock()
|
||||
return e.serve()
|
||||
}
|
||||
|
||||
// StartTLS starts an HTTPS server.
|
||||
// If `certFile` or `keyFile` is `string` the values are treated as file paths.
|
||||
// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
|
||||
func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) {
|
||||
e.startupMu.Lock()
|
||||
var cert []byte
|
||||
if cert, err = filepathOrContent(certFile); err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
var key []byte
|
||||
if key, err = filepathOrContent(keyFile); err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
@ -665,10 +677,17 @@ func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err erro
|
||||
s.TLSConfig = new(tls.Config)
|
||||
s.TLSConfig.Certificates = make([]tls.Certificate, 1)
|
||||
if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
return e.startTLS(address)
|
||||
e.configureTLS(address)
|
||||
if err := e.configureServer(s); err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return err
|
||||
}
|
||||
e.startupMu.Unlock()
|
||||
return s.Serve(e.TLSListener)
|
||||
}
|
||||
|
||||
func filepathOrContent(fileOrContent interface{}) (content []byte, err error) {
|
||||
@ -684,24 +703,41 @@ func filepathOrContent(fileOrContent interface{}) (content []byte, err error) {
|
||||
|
||||
// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org.
|
||||
func (e *Echo) StartAutoTLS(address string) error {
|
||||
e.startupMu.Lock()
|
||||
s := e.TLSServer
|
||||
s.TLSConfig = new(tls.Config)
|
||||
s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate
|
||||
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto)
|
||||
return e.startTLS(address)
|
||||
|
||||
e.configureTLS(address)
|
||||
if err := e.configureServer(s); err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return err
|
||||
}
|
||||
e.startupMu.Unlock()
|
||||
return s.Serve(e.TLSListener)
|
||||
}
|
||||
|
||||
func (e *Echo) startTLS(address string) error {
|
||||
func (e *Echo) configureTLS(address string) {
|
||||
s := e.TLSServer
|
||||
s.Addr = address
|
||||
if !e.DisableHTTP2 {
|
||||
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2")
|
||||
}
|
||||
return e.StartServer(e.TLSServer)
|
||||
}
|
||||
|
||||
// StartServer starts a custom http server.
|
||||
func (e *Echo) StartServer(s *http.Server) (err error) {
|
||||
e.startupMu.Lock()
|
||||
if err := e.configureServer(s); err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return err
|
||||
}
|
||||
e.startupMu.Unlock()
|
||||
return e.serve()
|
||||
}
|
||||
|
||||
func (e *Echo) configureServer(s *http.Server) (err error) {
|
||||
// Setup
|
||||
e.colorer.SetOutput(e.Logger.Output())
|
||||
s.ErrorLog = e.StdLogger
|
||||
@ -724,7 +760,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
|
||||
if !e.HidePort {
|
||||
e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr()))
|
||||
}
|
||||
return s.Serve(e.Listener)
|
||||
return nil
|
||||
}
|
||||
if e.TLSListener == nil {
|
||||
l, err := newListener(s.Addr, e.ListenerNetwork)
|
||||
@ -736,11 +772,39 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
|
||||
if !e.HidePort {
|
||||
e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr()))
|
||||
}
|
||||
return s.Serve(e.TLSListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Echo) serve() error {
|
||||
if e.TLSListener != nil {
|
||||
return e.Server.Serve(e.TLSListener)
|
||||
}
|
||||
return e.Server.Serve(e.Listener)
|
||||
}
|
||||
|
||||
// ListenerAddr returns net.Addr for Listener
|
||||
func (e *Echo) ListenerAddr() net.Addr {
|
||||
e.startupMu.RLock()
|
||||
defer e.startupMu.RUnlock()
|
||||
if e.Listener == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Listener.Addr()
|
||||
}
|
||||
|
||||
// TLSListenerAddr returns net.Addr for TLSListener
|
||||
func (e *Echo) TLSListenerAddr() net.Addr {
|
||||
e.startupMu.RLock()
|
||||
defer e.startupMu.RUnlock()
|
||||
if e.TLSListener == nil {
|
||||
return nil
|
||||
}
|
||||
return e.TLSListener.Addr()
|
||||
}
|
||||
|
||||
// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext).
|
||||
func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
|
||||
e.startupMu.Lock()
|
||||
// Setup
|
||||
s := e.Server
|
||||
s.Addr = address
|
||||
@ -758,18 +822,22 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
|
||||
if e.Listener == nil {
|
||||
e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
|
||||
if err != nil {
|
||||
e.startupMu.Unlock()
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !e.HidePort {
|
||||
e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr()))
|
||||
}
|
||||
e.startupMu.Unlock()
|
||||
return s.Serve(e.Listener)
|
||||
}
|
||||
|
||||
// Close immediately stops the server.
|
||||
// It internally calls `http.Server#Close()`.
|
||||
func (e *Echo) Close() error {
|
||||
e.startupMu.Lock()
|
||||
defer e.startupMu.Unlock()
|
||||
if err := e.TLSServer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -779,6 +847,8 @@ func (e *Echo) Close() error {
|
||||
// Shutdown stops the server gracefully.
|
||||
// It internally calls `http.Server#Shutdown()`.
|
||||
func (e *Echo) Shutdown(ctx stdContext.Context) error {
|
||||
e.startupMu.Lock()
|
||||
defer e.startupMu.Unlock()
|
||||
if err := e.TLSServer.Shutdown(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
354
echo_test.go
354
echo_test.go
@ -3,12 +3,14 @@ package echo
|
||||
import (
|
||||
"bytes"
|
||||
stdContext "context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -485,26 +487,125 @@ func TestEchoContext(t *testing.T) {
|
||||
e.ReleaseContext(c)
|
||||
}
|
||||
|
||||
func TestEchoStart(t *testing.T) {
|
||||
e := New()
|
||||
go func() {
|
||||
assert.NoError(t, e.Start(":0"))
|
||||
}()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error {
|
||||
ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
var addr net.Addr
|
||||
if isTLS {
|
||||
addr = e.TLSListenerAddr()
|
||||
} else {
|
||||
addr = e.ListenerAddr()
|
||||
}
|
||||
if addr != nil && strings.Contains(addr.String(), ":") {
|
||||
return nil // was started
|
||||
}
|
||||
case err := <-errChan:
|
||||
if err == http.ErrServerClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoStartTLS(t *testing.T) {
|
||||
func TestEchoStart(t *testing.T) {
|
||||
e := New()
|
||||
errChan := make(chan error)
|
||||
|
||||
go func() {
|
||||
err := e.StartTLS(":0", "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
|
||||
// Prevent the test to fail after closing the servers
|
||||
if err != http.ErrServerClosed {
|
||||
assert.NoError(t, err)
|
||||
err := e.Start(":0")
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
e.Close()
|
||||
err := waitForServerStart(e, errChan, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NoError(t, e.Close())
|
||||
}
|
||||
|
||||
func TestEcho_StartTLS(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
addr string
|
||||
certFile string
|
||||
keyFile string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
addr: ":0",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid certFile",
|
||||
addr: ":0",
|
||||
certFile: "not existing",
|
||||
expectError: "open not existing: no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid keyFile",
|
||||
addr: ":0",
|
||||
keyFile: "not existing",
|
||||
expectError: "open not existing: no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "nok, failed to create cert out of certFile and keyFile",
|
||||
addr: ":0",
|
||||
keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
|
||||
expectError: "tls: found a certificate rather than a key in the PEM for the private key",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid tls address",
|
||||
addr: "nope",
|
||||
expectError: "listen tcp: address nope: missing port in address",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
errChan := make(chan error)
|
||||
|
||||
go func() {
|
||||
certFile := "_fixture/certs/cert.pem"
|
||||
if tc.certFile != "" {
|
||||
certFile = tc.certFile
|
||||
}
|
||||
keyFile := "_fixture/certs/key.pem"
|
||||
if tc.keyFile != "" {
|
||||
keyFile = tc.keyFile
|
||||
}
|
||||
|
||||
err := e.StartTLS(tc.addr, certFile, keyFile)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
err := waitForServerStart(e, errChan, true)
|
||||
if tc.expectError != "" {
|
||||
if _, ok := err.(*os.PathError); ok {
|
||||
assert.Error(t, err) // error messages for unix and windows are different. so test only error type here
|
||||
} else {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.NoError(t, e.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoStartTLSByteString(t *testing.T) {
|
||||
@ -557,47 +658,103 @@ func TestEchoStartTLSByteString(t *testing.T) {
|
||||
e := New()
|
||||
e.HideBanner = true
|
||||
|
||||
go func() {
|
||||
err := e.StartTLS(":0", test.cert, test.key)
|
||||
if test.expectedErr != nil {
|
||||
require.EqualError(t, err, test.expectedErr.Error())
|
||||
} else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
errChan := make(chan error, 0)
|
||||
|
||||
require.NoError(t, e.Close())
|
||||
go func() {
|
||||
errChan <- e.StartTLS(":0", test.cert, test.key)
|
||||
}()
|
||||
|
||||
err := waitForServerStart(e, errChan, true)
|
||||
if test.expectedErr != nil {
|
||||
assert.EqualError(t, err, test.expectedErr.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.NoError(t, e.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoStartAutoTLS(t *testing.T) {
|
||||
e := New()
|
||||
errChan := make(chan error, 0)
|
||||
func TestEcho_StartAutoTLS(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
addr string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
addr: ":0",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid address",
|
||||
addr: "nope",
|
||||
expectError: "listen tcp: address nope: missing port in address",
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
errChan <- e.StartAutoTLS(":0")
|
||||
}()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
errChan := make(chan error, 0)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
assert.NoError(t, err)
|
||||
default:
|
||||
assert.NoError(t, e.Close())
|
||||
go func() {
|
||||
errChan <- e.StartAutoTLS(tc.addr)
|
||||
}()
|
||||
|
||||
err := waitForServerStart(e, errChan, true)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.NoError(t, e.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoStartH2CServer(t *testing.T) {
|
||||
e := New()
|
||||
e.Debug = true
|
||||
h2s := &http2.Server{}
|
||||
func TestEcho_StartH2CServer(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
addr string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
addr: ":0",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid address",
|
||||
addr: "nope",
|
||||
expectError: "listen tcp: address nope: missing port in address",
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
assert.NoError(t, e.StartH2CServer(":0", h2s))
|
||||
}()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
e.Debug = true
|
||||
h2s := &http2.Server{}
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
err := e.StartH2CServer(tc.addr, h2s)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
err := waitForServerStart(e, errChan, false)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.NoError(t, e.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testMethod(t *testing.T, method, path string, e *Echo) {
|
||||
@ -686,7 +843,8 @@ func TestEchoClose(t *testing.T) {
|
||||
errCh <- e.Start(":0")
|
||||
}()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
err := waitForServerStart(e, errCh, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if err := e.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -694,7 +852,7 @@ func TestEchoClose(t *testing.T) {
|
||||
|
||||
assert.NoError(t, e.Close())
|
||||
|
||||
err := <-errCh
|
||||
err = <-errCh
|
||||
assert.Equal(t, err.Error(), "http: Server closed")
|
||||
}
|
||||
|
||||
@ -706,7 +864,8 @@ func TestEchoShutdown(t *testing.T) {
|
||||
errCh <- e.Start(":0")
|
||||
}()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
err := waitForServerStart(e, errCh, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if err := e.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -716,7 +875,7 @@ func TestEchoShutdown(t *testing.T) {
|
||||
defer cancel()
|
||||
assert.NoError(t, e.Shutdown(ctx))
|
||||
|
||||
err := <-errCh
|
||||
err = <-errCh
|
||||
assert.Equal(t, err.Error(), "http: Server closed")
|
||||
}
|
||||
|
||||
@ -764,7 +923,8 @@ func TestEchoListenerNetwork(t *testing.T) {
|
||||
errCh <- e.Start(tt.address)
|
||||
}()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
err := waitForServerStart(e, errCh, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
|
||||
defer resp.Body.Close()
|
||||
@ -823,3 +983,101 @@ func TestEchoReverse(t *testing.T) {
|
||||
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
|
||||
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
|
||||
}
|
||||
|
||||
func TestEcho_ListenerAddr(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
addr := e.ListenerAddr()
|
||||
assert.Nil(t, addr)
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- e.Start(":0")
|
||||
}()
|
||||
|
||||
err := waitForServerStart(e, errCh, false)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEcho_TLSListenerAddr(t *testing.T) {
|
||||
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
|
||||
require.NoError(t, err)
|
||||
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
|
||||
require.NoError(t, err)
|
||||
|
||||
e := New()
|
||||
|
||||
addr := e.TLSListenerAddr()
|
||||
assert.Nil(t, addr)
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- e.StartTLS(":0", cert, key)
|
||||
}()
|
||||
|
||||
err = waitForServerStart(e, errCh, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEcho_StartServer(t *testing.T) {
|
||||
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
|
||||
require.NoError(t, err)
|
||||
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
|
||||
require.NoError(t, err)
|
||||
certs, err := tls.X509KeyPair(cert, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
addr string
|
||||
TLSConfig *tls.Config
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
addr: ":0",
|
||||
},
|
||||
{
|
||||
name: "ok, start with TLS",
|
||||
addr: ":0",
|
||||
TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}},
|
||||
},
|
||||
{
|
||||
name: "nok, invalid address",
|
||||
addr: "nope",
|
||||
expectError: "listen tcp: address nope: missing port in address",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid tls address",
|
||||
addr: "nope",
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
expectError: "listen tcp: address nope: missing port in address",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
e.Debug = true
|
||||
|
||||
server := new(http.Server)
|
||||
server.Addr = tc.addr
|
||||
if tc.TLSConfig != nil {
|
||||
server.TLSConfig = tc.TLSConfig
|
||||
}
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- e.StartServer(server)
|
||||
}()
|
||||
|
||||
err := waitForServerStart(e, errCh, tc.TLSConfig != nil)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.NoError(t, e.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user