diff --git a/Makefile b/Makefile index dfcb6c02..c369913a 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,27 @@ +PKG := "github.com/labstack/echo" +PKG_LIST := $(shell go list ${PKG}/...) + tag: @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'` @git tag|grep -v ^v + +.DEFAULT_GOAL := check +check: lint vet race ## Check project + +init: + @go get -u golang.org/x/lint/golint + +lint: ## Lint the files + @golint -set_exit_status ${PKG_LIST} + +vet: ## Vet the files + @go vet ${PKG_LIST} + +test: ## Run tests + @go test -short ${PKG_LIST} + +race: ## Run tests with data race detector + @go test -race ${PKG_LIST} + +help: ## Display this help screen + @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/bind.go b/bind.go index c7be242b..acd2beda 100644 --- a/bind.go +++ b/bind.go @@ -98,12 +98,20 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } // Bind implements the `Binder#Bind` function. +// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous +// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { if err := b.BindPathParams(c, i); err != nil { return err } - if err = b.BindQueryParams(c, i); err != nil { - return err + // Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH) + // Reasoning here is that parameters in query and bind destination struct could have UNEXPECTED matches and results due that. + // i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding. + // This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body + if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete { + if err = b.BindQueryParams(c, i); err != nil { + return err + } } return b.BindBody(c, i) } diff --git a/bind_test.go b/bind_test.go index 60c2f9e0..345fbdf1 100644 --- a/bind_test.go +++ b/bind_test.go @@ -559,7 +559,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { // binding is done in steps and one source could overwrite previous source binded data // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed - type Node struct { + type Opts struct { ID int `json:"id"` Node string `json:"node"` } @@ -575,41 +575,77 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { expectError string }{ { - name: "ok, POST bind to struct with: path param + query param + empty body", + name: "ok, POST bind to struct with: path param + query param + body", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1}`), - expect: &Node{ID: 1, Node: "xxx"}, // in current implementation query params has higher priority than path params + expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used, node is filled from path }, { - name: "ok, POST bind to struct with: path param + empty body", + name: "ok, PUT bind to struct with: path param + query param + body", + givenMethod: http.MethodPut, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used + }, + { + name: "ok, GET bind to struct with: path param + query param + body", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "xxx"}, // query overwrites previous path value + }, + { + name: "ok, GET bind to struct with: path param + query param + body", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 1, Node: "zzz"}, // body is binded last and overwrites previous (path,query) values + }, + { + name: "ok, DELETE bind to struct with: path param + query param + body", + givenMethod: http.MethodDelete, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is binded after query params + }, + { + name: "ok, POST bind to struct with: path param + body", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`{"id": 1}`), - expect: &Node{ID: 1, Node: "real_node"}, + expect: &Opts{ID: 1, Node: "node_from_path"}, }, { name: "ok, POST bind to struct with path + query + body = body has priority", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), - expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority + expect: &Opts{ID: 1, Node: "zzz"}, // field value from content has higher priority }, { name: "nok, POST body bind failure", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{`), - expect: &Node{ID: 0, Node: "xxx"}, // query binding has already modified bind target + expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", }, + { + name: "nok, GET with body bind failure when types are not convertible", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?id=nope", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target + expectError: "code=400, message=strconv.ParseInt: parsing \"nope\": invalid syntax, internal=strconv.ParseInt: parsing \"nope\": invalid syntax", + }, { name: "nok, GET body bind failure - trying to bind json array to struct", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - expect: &Node{ID: 0, Node: "xxx"}, // query binding has already modified bind target - expectError: "code=400, message=Unmarshal type error: expected=echo.Node, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Node", + expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target + expectError: "code=400, message=Unmarshal type error: expected=echo.Opts, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Opts", }, { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice name: "nok, GET query params bind failure - trying to bind json array to slice", @@ -617,17 +653,27 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathParams: true, - whenBindTarget: &[]Node{}, - expect: &[]Node{}, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{}, expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", }, + { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice + name: "ok, POST binding to slice should not be affected query params types", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?id=nope&node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathParams: true, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{{ID: 1}}, + expectError: "", + }, { // binding path params interferes with body. b.BindBody() should be used to bind only body to slice name: "nok, GET path params bind failure - trying to bind json array to slice", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - whenBindTarget: &[]Node{}, - expect: &[]Node{}, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{}, expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", }, { @@ -636,8 +682,8 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathParams: true, - whenBindTarget: &[]Node{}, - expect: &[]Node{{ID: 1, Node: ""}}, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{{ID: 1, Node: ""}}, expectError: "", }, } @@ -653,14 +699,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { if !tc.whenNoPathParams { c.SetParamNames("node") - c.SetParamValues("real_node") + c.SetParamValues("node_from_path") } var bindTarget interface{} if tc.whenBindTarget != nil { bindTarget = tc.whenBindTarget } else { - bindTarget = &Node{} + bindTarget = &Opts{} } b := new(DefaultBinder) diff --git a/echo.go b/echo.go index d284ff39..6db485d1 100644 --- a/echo.go +++ b/echo.go @@ -67,6 +67,9 @@ type ( // Echo is the top-level framework instance. Echo struct { common + // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get + // listener address info (on which interface/port was listener binded) without having data races. + startupMutex sync.RWMutex StdLogger *stdLog.Logger colorer *color.Color premiddleware []MiddlewareFunc @@ -500,8 +503,15 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl } return c.File(name) } - if prefix == "/" { - return get(prefix+"*", h) + // Handle added routes based on trailing slash: + // /prefix => exact route "/prefix" + any route "/prefix/*" + // /prefix/ => only any route "/prefix/*" + if prefix != "" { + if prefix[len(prefix)-1] == '/' { + // Only add any route for intentional trailing slash + return get(prefix+"*", h) + } + get(prefix, h) } return get(prefix+"/*", h) } @@ -643,21 +653,30 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Start starts an HTTP server. func (e *Echo) Start(address string) error { + e.startupMutex.Lock() e.Server.Addr = address - return e.StartServer(e.Server) + if err := e.configureServer(e.Server); err != nil { + e.startupMutex.Unlock() + return err + } + e.startupMutex.Unlock() + return e.serve() } // StartTLS starts an HTTPS server. // If `certFile` or `keyFile` is `string` the values are treated as file paths. // If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { + e.startupMutex.Lock() var cert []byte if cert, err = filepathOrContent(certFile); err != nil { + e.startupMutex.Unlock() return } var key []byte if key, err = filepathOrContent(keyFile); err != nil { + e.startupMutex.Unlock() return } @@ -665,10 +684,17 @@ func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err erro s.TLSConfig = new(tls.Config) s.TLSConfig.Certificates = make([]tls.Certificate, 1) if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { + e.startupMutex.Unlock() return } - return e.startTLS(address) + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMutex.Unlock() + return err + } + e.startupMutex.Unlock() + return s.Serve(e.TLSListener) } func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { @@ -684,24 +710,41 @@ func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { // StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. func (e *Echo) StartAutoTLS(address string) error { + e.startupMutex.Lock() s := e.TLSServer s.TLSConfig = new(tls.Config) s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - return e.startTLS(address) + + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMutex.Unlock() + return err + } + e.startupMutex.Unlock() + return s.Serve(e.TLSListener) } -func (e *Echo) startTLS(address string) error { +func (e *Echo) configureTLS(address string) { s := e.TLSServer s.Addr = address if !e.DisableHTTP2 { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") } - return e.StartServer(e.TLSServer) } // StartServer starts a custom http server. func (e *Echo) StartServer(s *http.Server) (err error) { + e.startupMutex.Lock() + if err := e.configureServer(s); err != nil { + e.startupMutex.Unlock() + return err + } + e.startupMutex.Unlock() + return e.serve() +} + +func (e *Echo) configureServer(s *http.Server) (err error) { // Setup e.colorer.SetOutput(e.Logger.Output()) s.ErrorLog = e.StdLogger @@ -724,7 +767,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) } - return s.Serve(e.Listener) + return nil } if e.TLSListener == nil { l, err := newListener(s.Addr, e.ListenerNetwork) @@ -736,11 +779,39 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if !e.HidePort { e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) } - return s.Serve(e.TLSListener) + return nil +} + +func (e *Echo) serve() error { + if e.TLSListener != nil { + return e.Server.Serve(e.TLSListener) + } + return e.Server.Serve(e.Listener) +} + +// ListenerAddr returns net.Addr for Listener +func (e *Echo) ListenerAddr() net.Addr { + e.startupMutex.RLock() + defer e.startupMutex.RUnlock() + if e.Listener == nil { + return nil + } + return e.Listener.Addr() +} + +// TLSListenerAddr returns net.Addr for TLSListener +func (e *Echo) TLSListenerAddr() net.Addr { + e.startupMutex.RLock() + defer e.startupMutex.RUnlock() + if e.TLSListener == nil { + return nil + } + return e.TLSListener.Addr() } // StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { + e.startupMutex.Lock() // Setup s := e.Server s.Addr = address @@ -758,18 +829,22 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { if e.Listener == nil { e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { + e.startupMutex.Unlock() return err } } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) } + e.startupMutex.Unlock() return s.Serve(e.Listener) } // Close immediately stops the server. // It internally calls `http.Server#Close()`. func (e *Echo) Close() error { + e.startupMutex.Lock() + defer e.startupMutex.Unlock() if err := e.TLSServer.Close(); err != nil { return err } @@ -779,6 +854,8 @@ func (e *Echo) Close() error { // Shutdown stops the server gracefully. // It internally calls `http.Server#Shutdown()`. func (e *Echo) Shutdown(ctx stdContext.Context) error { + e.startupMutex.Lock() + defer e.startupMutex.Unlock() if err := e.TLSServer.Shutdown(ctx); err != nil { return err } diff --git a/echo_test.go b/echo_test.go index 29edca10..781b901f 100644 --- a/echo_test.go +++ b/echo_test.go @@ -3,12 +3,14 @@ package echo import ( "bytes" stdContext "context" + "crypto/tls" "errors" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" + "os" "reflect" "strings" "testing" @@ -103,6 +105,32 @@ func TestEchoStatic(t *testing.T) { expectHeaderLocation: "/folder/", expectBodyStartsWith: "", }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, { name: "Directory with index.html", givenPrefix: "/", @@ -111,6 +139,22 @@ func TestEchoStatic(t *testing.T) { expectStatus: http.StatusOK, expectBodyStartsWith: "", }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, { name: "Sub-directory with index.html", givenPrefix: "/", @@ -162,6 +206,40 @@ func TestEchoStatic(t *testing.T) { } } +func TestEchoStaticRedirectIndex(t *testing.T) { + assert := assert.New(t) + e := New() + + // HandlerFunc + e.Static("/static", "_fixture") + + errCh := make(chan error) + + go func() { + errCh <- e.Start("127.0.0.1:1323") + }() + + time.Sleep(200 * time.Millisecond) + + if resp, err := http.Get("http://127.0.0.1:1323/static"); err == nil { + defer resp.Body.Close() + assert.Equal(http.StatusOK, resp.StatusCode) + + if body, err := ioutil.ReadAll(resp.Body); err == nil { + assert.Equal(true, strings.HasPrefix(string(body), "")) + } else { + assert.Fail(err.Error()) + } + + } else { + assert.Fail(err.Error()) + } + + if err := e.Close(); err != nil { + t.Fatal(err) + } +} + func TestEchoFile(t *testing.T) { e := New() e.File("/walle", "_fixture/images/walle.png") @@ -485,26 +563,125 @@ func TestEchoContext(t *testing.T) { e.ReleaseContext(c) } -func TestEchoStart(t *testing.T) { - e := New() - go func() { - assert.NoError(t, e.Start(":0")) - }() - time.Sleep(200 * time.Millisecond) +func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + var addr net.Addr + if isTLS { + addr = e.TLSListenerAddr() + } else { + addr = e.ListenerAddr() + } + if addr != nil && strings.Contains(addr.String(), ":") { + return nil // was started + } + case err := <-errChan: + if err == http.ErrServerClosed { + return nil + } + return err + } + } } -func TestEchoStartTLS(t *testing.T) { +func TestEchoStart(t *testing.T) { e := New() + errChan := make(chan error) + go func() { - err := e.StartTLS(":0", "_fixture/certs/cert.pem", "_fixture/certs/key.pem") - // Prevent the test to fail after closing the servers - if err != http.ErrServerClosed { - assert.NoError(t, err) + err := e.Start(":0") + if err != nil { + errChan <- err } }() - time.Sleep(200 * time.Millisecond) - e.Close() + err := waitForServerStart(e, errChan, false) + assert.NoError(t, err) + + assert.NoError(t, e.Close()) +} + +func TestEcho_StartTLS(t *testing.T) { + var testCases = []struct { + name string + addr string + certFile string + keyFile string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid certFile", + addr: ":0", + certFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, invalid keyFile", + addr: ":0", + keyFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, failed to create cert out of certFile and keyFile", + addr: ":0", + keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key + expectError: "tls: found a certificate rather than a key in the PEM for the private key", + }, + { + name: "nok, invalid tls address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + errChan := make(chan error) + + go func() { + certFile := "_fixture/certs/cert.pem" + if tc.certFile != "" { + certFile = tc.certFile + } + keyFile := "_fixture/certs/key.pem" + if tc.keyFile != "" { + keyFile = tc.keyFile + } + + err := e.StartTLS(tc.addr, certFile, keyFile) + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, true) + if tc.expectError != "" { + if _, ok := err.(*os.PathError); ok { + assert.Error(t, err) // error messages for unix and windows are different. so test only error type here + } else { + assert.EqualError(t, err, tc.expectError) + } + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } } func TestEchoStartTLSByteString(t *testing.T) { @@ -557,47 +734,103 @@ func TestEchoStartTLSByteString(t *testing.T) { e := New() e.HideBanner = true - go func() { - err := e.StartTLS(":0", test.cert, test.key) - if test.expectedErr != nil { - require.EqualError(t, err, test.expectedErr.Error()) - } else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers - require.NoError(t, err) - } - }() - time.Sleep(200 * time.Millisecond) + errChan := make(chan error, 0) - require.NoError(t, e.Close()) + go func() { + errChan <- e.StartTLS(":0", test.cert, test.key) + }() + + err := waitForServerStart(e, errChan, true) + if test.expectedErr != nil { + assert.EqualError(t, err, test.expectedErr.Error()) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) }) } } -func TestEchoStartAutoTLS(t *testing.T) { - e := New() - errChan := make(chan error, 0) +func TestEcho_StartAutoTLS(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } - go func() { - errChan <- e.StartAutoTLS(":0") - }() - time.Sleep(200 * time.Millisecond) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + errChan := make(chan error, 0) - select { - case err := <-errChan: - assert.NoError(t, err) - default: - assert.NoError(t, e.Close()) + go func() { + errChan <- e.StartAutoTLS(tc.addr) + }() + + err := waitForServerStart(e, errChan, true) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) } } -func TestEchoStartH2CServer(t *testing.T) { - e := New() - e.Debug = true - h2s := &http2.Server{} +func TestEcho_StartH2CServer(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } - go func() { - assert.NoError(t, e.StartH2CServer(":0", h2s)) - }() - time.Sleep(200 * time.Millisecond) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + h2s := &http2.Server{} + + errChan := make(chan error) + go func() { + err := e.StartH2CServer(tc.addr, h2s) + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, false) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } } func testMethod(t *testing.T, method, path string, e *Echo) { @@ -686,7 +919,8 @@ func TestEchoClose(t *testing.T) { errCh <- e.Start(":0") }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if err := e.Close(); err != nil { t.Fatal(err) @@ -694,7 +928,7 @@ func TestEchoClose(t *testing.T) { assert.NoError(t, e.Close()) - err := <-errCh + err = <-errCh assert.Equal(t, err.Error(), "http: Server closed") } @@ -706,7 +940,8 @@ func TestEchoShutdown(t *testing.T) { errCh <- e.Start(":0") }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if err := e.Close(); err != nil { t.Fatal(err) @@ -716,7 +951,7 @@ func TestEchoShutdown(t *testing.T) { defer cancel() assert.NoError(t, e.Shutdown(ctx)) - err := <-errCh + err = <-errCh assert.Equal(t, err.Error(), "http: Server closed") } @@ -764,7 +999,8 @@ func TestEchoListenerNetwork(t *testing.T) { errCh <- e.Start(tt.address) }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { defer resp.Body.Close() @@ -823,3 +1059,101 @@ func TestEchoReverse(t *testing.T) { assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) } + +func TestEcho_ListenerAddr(t *testing.T) { + e := New() + + addr := e.ListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.Start(":0") + }() + + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) +} + +func TestEcho_TLSListenerAddr(t *testing.T) { + cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := ioutil.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + + e := New() + + addr := e.TLSListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.StartTLS(":0", cert, key) + }() + + err = waitForServerStart(e, errCh, true) + assert.NoError(t, err) +} + +func TestEcho_StartServer(t *testing.T) { + cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := ioutil.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + certs, err := tls.X509KeyPair(cert, key) + require.NoError(t, err) + + var testCases = []struct { + name string + addr string + TLSConfig *tls.Config + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "ok, start with TLS", + addr: ":0", + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + { + name: "nok, invalid tls address", + addr: "nope", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + + server := new(http.Server) + server.Addr = tc.addr + if tc.TLSConfig != nil { + server.TLSConfig = tc.TLSConfig + } + + errCh := make(chan error) + go func() { + errCh <- e.StartServer(server) + }() + + err := waitForServerStart(e, errCh, tc.TLSConfig != nil) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.NoError(t, e.Close()) + }) + } +} diff --git a/middleware/static_test.go b/middleware/static_test.go index 407dd15c..8c0c97de 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -3,7 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" - "path/filepath" + "strings" "testing" "github.com/labstack/echo/v4" @@ -11,84 +11,269 @@ import ( ) func TestStatic(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - config := StaticConfig{ - Root: "../_fixture", + var testCases = []struct { + name string + givenConfig *StaticConfig + givenAttachedToGroup string + whenURL string + expectContains string + expectLength string + expectCode int + }{ + { + name: "ok, serve index with Echo message", + whenURL: "/", + expectCode: http.StatusOK, + expectContains: "Echo", + }, + { + name: "ok, serve file from subdirectory", + whenURL: "/images/walle.png", + expectCode: http.StatusOK, + expectLength: "219885", + }, + { + name: "ok, when html5 mode serve index for any static file that does not exist", + givenConfig: &StaticConfig{ + Root: "../_fixture", + HTML5: true, + }, + whenURL: "/random", + expectCode: http.StatusOK, + expectContains: "Echo", + }, + { + name: "ok, serve index as directory index listing files directory", + givenConfig: &StaticConfig{ + Root: "../_fixture/certs", + Browse: true, + }, + whenURL: "/", + expectCode: http.StatusOK, + expectContains: "cert.pem", + }, + { + name: "ok, serve directory index with IgnoreBase and browse", + givenConfig: &StaticConfig{ + Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored + IgnoreBase: true, + Browse: true, + }, + givenAttachedToGroup: "/_fixture", + whenURL: "/_fixture/", + expectCode: http.StatusOK, + expectContains: `README.md`, + }, + { + name: "ok, serve file with IgnoreBase", + givenConfig: &StaticConfig{ + Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored + IgnoreBase: true, + Browse: true, + }, + givenAttachedToGroup: "/_fixture", + whenURL: "/_fixture/README.md", + expectCode: http.StatusOK, + expectContains: "This directory is used for the static middleware test", + }, + { + name: "nok, file not found", + whenURL: "/none", + expectCode: http.StatusNotFound, + expectContains: "{\"message\":\"Not Found\"}\n", + }, + { + name: "nok, do not allow directory traversal (backslash - windows separator)", + whenURL: `/..\\middleware/basic_auth.go`, + expectCode: http.StatusNotFound, + expectContains: "{\"message\":\"Not Found\"}\n", + }, + { + name: "nok,do not allow directory traversal (slash - unix separator)", + whenURL: `/../middleware/basic_auth.go`, + expectCode: http.StatusNotFound, + expectContains: "{\"message\":\"Not Found\"}\n", + }, } - // Directory - h := StaticWithConfig(config)(echo.NotFoundHandler) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() - assert := assert.New(t) + config := StaticConfig{Root: "../_fixture"} + if tc.givenConfig != nil { + config = *tc.givenConfig + } + middlewareFunc := StaticWithConfig(config) + if tc.givenAttachedToGroup != "" { + // middleware is attached to group + subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc) + // group without http handlers (routes) does not do anything. + // Request is matched against http handlers (routes) that have group middleware attached to them + subGroup.GET("", echo.NotFoundHandler) + subGroup.GET("/*", echo.NotFoundHandler) + } else { + // middleware is on root level + e.Use(middlewareFunc) + } - if assert.NoError(h(c)) { - assert.Contains(rec.Body.String(), "Echo") + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + if tc.expectContains != "" { + responseBody := rec.Body.String() + assert.Contains(t, responseBody, tc.expectContains) + } + if tc.expectLength != "" { + assert.Equal(t, rec.Header().Get(echo.HeaderContentLength), tc.expectLength) + } + }) + } +} + +func TestStatic_GroupWithStatic(t *testing.T) { + var testCases = []struct { + name string + givenGroup string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "../_fixture/images", + whenURL: "/group/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "../_fixture/scripts", + whenURL: "/group/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory not found (no trailing slash)", + givenPrefix: "/images", + givenRoot: "../_fixture/images", + whenURL: "/group/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory redirect", + givenPrefix: "/", + givenRoot: "../_fixture", + whenURL: "/group/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/group/folder/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenGroup: "_fixture", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "../_fixture", + whenURL: "/_fixture/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenGroup: "_fixture", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "../_fixture", + whenURL: "/_fixture/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/_fixture/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "../_fixture", + whenURL: "/group/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "../_fixture", + whenURL: "/group/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "../_fixture", + whenURL: "/group/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "../_fixture", + whenURL: "/group/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "../_fixture/", + whenURL: `/group/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "../_fixture/", + whenURL: `/group/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + group := "/group" + if tc.givenGroup != "" { + group = tc.givenGroup + } + g := e.Group(group) + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Header().Get(echo.HeaderLocation)) + } else { + _, ok := rec.Result().Header[echo.HeaderLocation] + assert.False(t, ok) + } + }) } - - // File found - req = httptest.NewRequest(http.MethodGet, "/images/walle.png", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - if assert.NoError(h(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(rec.Header().Get(echo.HeaderContentLength), "219885") - } - - // File not found - req = httptest.NewRequest(http.MethodGet, "/none", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusNotFound, he.Code) - - // HTML5 - req = httptest.NewRequest(http.MethodGet, "/random", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - config.HTML5 = true - static := StaticWithConfig(config) - h = static(echo.NotFoundHandler) - if assert.NoError(h(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Contains(rec.Body.String(), "Echo") - } - - // Browse - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - config.Root = "../_fixture/certs" - config.Browse = true - static = StaticWithConfig(config) - h = static(echo.NotFoundHandler) - if assert.NoError(h(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Contains(rec.Body.String(), "cert.pem") - } - - // IgnoreBase - req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) - rec = httptest.NewRecorder() - config.Root = "../_fixture" - config.IgnoreBase = true - static = StaticWithConfig(config) - c.Echo().Group("_fixture", static) - e.ServeHTTP(rec, req) - - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122") - - req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) - rec = httptest.NewRecorder() - config.Root = "../_fixture" - config.IgnoreBase = false - static = StaticWithConfig(config) - c.Echo().Group("_fixture", static) - e.ServeHTTP(rec, req) - - assert.Equal(http.StatusOK, rec.Code) - assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture")) }