From ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 22 Jul 2023 23:25:34 +0300 Subject: [PATCH] Changes from master (from 5b36ce36127b2c011e6f0b905958d2544eef8820 to b3ec8e0fdd9d904aa5b1b95479da20c4961a59eb) --- CHANGELOG.md | 95 ++++++++- README.md | 13 +- binder.go | 16 +- context.go | 8 +- echo.go | 8 +- echo_test.go | 25 +++ group_test.go | 74 +++++++ middleware/body_limit.go | 10 +- middleware/body_limit_test.go | 6 +- middleware/compress.go | 91 ++++++++- middleware/compress_test.go | 132 +++++++++++++ middleware/cors.go | 4 +- middleware/middleware.go | 4 +- middleware/proxy.go | 138 ++++++++++--- middleware/proxy_test.go | 330 +++++++++++++++++++++++++++++++- middleware/rate_limiter.go | 21 +- middleware/rate_limiter_test.go | 13 +- middleware/request_logger.go | 4 +- middleware/util.go | 47 +++-- middleware/util_test.go | 29 +++ response.go | 7 + response_test.go | 8 + route.go | 7 +- route_test.go | 98 ++++++++++ router.go | 3 +- router_test.go | 11 ++ 26 files changed, 1093 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cdb6bd78..59430042 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,101 @@ # Changelog -## v4.10.0 - 2022-xx-xx +## v4.11.1 - 2023-07-16 + +**Fixes** + +* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481) + + +## v4.11.0 - 2023-07-14 + + +**Fixes** + +* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409) +* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411) +* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456) +* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477) + + +**Enhancements** + +* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410) +* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424) +* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425) +* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429) +* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416) +* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436) +* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444) +* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414) +* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452) +* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267) +* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475) + + +## v4.10.2 - 2023-02-22 **Security** -This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are -several vulnerabilities fixed in these libraries. +* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406) +* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405) -Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. +**Enhancements** + +* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277) + + +## v4.10.1 - 2023-02-19 + +**Security** + +* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402) + + +**Enhancements** + +* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377) +* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385) +* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380) +* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394) + + +## v4.10.0 - 2022-12-27 + +**Security** + +* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead. + + JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using + which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain. + +* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are + several vulnerabilities fixed in these libraries. + + Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. + + +**Enhancements** + +* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305) +* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336) +* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316) +* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338) +* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315) +* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329) +* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340) +* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342) +* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343) +* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345) +* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182) +* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350) +* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341) +* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162) +* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358) +* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362) +* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366) +* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337) ## v4.9.1 - 2022-10-12 diff --git a/README.md b/README.md index 132a96fe..267ce4d0 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) -[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) +[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) @@ -11,12 +11,12 @@ ## Supported Go versions -Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with +older versions. As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: - Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. @@ -85,9 +85,9 @@ func hello(c echo.Context) error { Following list of middleware is maintained by Echo team. -| Repository | Description | -|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | +| Repository | Description | +|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | | [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | # Third-party middleware repositories @@ -101,6 +101,7 @@ of middlewares in this list. | [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | | [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | | [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | UberĀ“s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. | [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | | [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | | [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | diff --git a/binder.go b/binder.go index e022a7e8..5d357859 100644 --- a/binder.go +++ b/binder.go @@ -1236,7 +1236,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Second) } @@ -1247,7 +1247,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Second) } @@ -1257,7 +1257,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Millisecond) } @@ -1268,7 +1268,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Millisecond) } @@ -1280,8 +1280,8 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal -// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Nanosecond) } @@ -1294,8 +1294,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal -// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Nanosecond) } diff --git a/context.go b/context.go index 67673b50..50616aa5 100644 --- a/context.go +++ b/context.go @@ -113,8 +113,8 @@ type Context interface { // Set saves data in the context. Set(key string, val interface{}) - // Bind binds the request body into provided type `i`. The default binder - // does it based on Content-Type header. + // Bind binds path params, query params and the request body into provided type `i`. The default binder + // binds body based on Content-Type header. Bind(i interface{}) error // Validate validates provided `i`. It is usually called after `Context#Bind()`. @@ -536,8 +536,8 @@ func (c *DefaultContext) Set(key string, val interface{}) { c.store[key] = val } -// Bind binds the request body into provided type `i`. The default binder -// does it based on Content-Type header. +// Bind binds path params, query params and the request body into provided type `i`. The default binder +// binds body based on Content-Type header. func (c *DefaultContext) Bind(i interface{}) error { return c.echo.Binder.Bind(c, i) } diff --git a/echo.go b/echo.go index f38e66c8..c76d9d80 100644 --- a/echo.go +++ b/echo.go @@ -40,6 +40,7 @@ package echo import ( stdContext "context" + "encoding/json" "errors" "fmt" "io" @@ -336,12 +337,17 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { // Issue #1426 code := he.Code message := he.Message - if m, ok := he.Message.(string); ok { + switch m := he.Message.(type) { + case string: if exposeError { message = Map{"message": m, "error": err.Error()} } else { message = Map{"message": m} } + case json.Marshaler: + // do nothing - this type knows how to format itself to JSON + case error: + message = Map{"message": m.Error()} } // Send response diff --git a/echo_test.go b/echo_test.go index 48e97b0f..3adbd8c9 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1198,6 +1198,18 @@ func request(method, path string, e *Echo) (int, string) { return rec.Code, rec.Body.String() } +type customError struct { + s string +} + +func (ce *customError) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil +} + +func (ce *customError) Error() string { + return ce.s +} + func TestDefaultHTTPErrorHandler(t *testing.T) { var testCases = []struct { name string @@ -1263,6 +1275,19 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { expectStatus: http.StatusInternalServerError, expectBody: ``, }, + { + name: "ok, custom error implement MarshalJSON", + whenMethod: http.MethodGet, + whenError: NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}), + expectStatus: http.StatusBadRequest, + expectBody: "{\"x\":\"custom error msg\"}\n", + }, + { + name: "with Debug=false when httpError contains an error", + whenError: NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")), + expectStatus: http.StatusBadRequest, + expectBody: "{\"message\":\"error in httperror\"}\n", + }, } for _, tc := range testCases { diff --git a/group_test.go b/group_test.go index d6a05e54..6c08b2b6 100644 --- a/group_test.go +++ b/group_test.go @@ -754,3 +754,77 @@ func TestGroup_StaticPanic(t *testing.T) { }) } } + +func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { + var testCases = []struct { + name string + givenCustom404 bool + whenURL string + expectBody interface{} + expectCode int + expectMiddlewareCalled bool + }{ + { + name: "ok, custom 404 handler is called with middleware", + givenCustom404: true, + whenURL: "/group/test3", + expectBody: "404 GET /group/*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added + }, + { + name: "ok, default group 404 handler is not called with middleware", + givenCustom404: false, + whenURL: "/group/test3", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + }, + { + name: "ok, (no slash) default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path()) + } + + e := New() + e.GET("/test1", okHandler) + e.RouteNotFound("/*", notFoundHandler) + + g := e.Group("/group") + g.GET("/test1", okHandler) + + middlewareCalled := false + g.Use(func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + middlewareCalled = true + return next(c) + } + }) + if tc.givenCustom404 { + g.RouteNotFound("/*", notFoundHandler) + } + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled) + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) + }) + } +} diff --git a/middleware/body_limit.go b/middleware/body_limit.go index 6b32f9d4..f43556c7 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -18,9 +18,8 @@ type BodyLimitConfig struct { type limitedReader struct { BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context + reader io.ReadCloser + read int64 } // BodyLimit returns a BodyLimit middleware. @@ -65,7 +64,7 @@ func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Based on content read r := pool.Get().(*limitedReader) - r.Reset(c, req.Body) + r.Reset(req.Body) defer pool.Put(r) req.Body = r @@ -87,8 +86,7 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) { +func (r *limitedReader) Reset(reader io.ReadCloser) { r.reader = reader - r.context = context r.read = 0 } diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 47982255..4981918a 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -67,9 +67,6 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) { func TestBodyLimitReader(t *testing.T) { hw := []byte("Hello, World!") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() config := BodyLimitConfig{ Skipper: DefaultSkipper, @@ -78,7 +75,6 @@ func TestBodyLimitReader(t *testing.T) { reader := &limitedReader{ BodyLimitConfig: config, reader: io.NopCloser(bytes.NewReader(hw)), - context: e.NewContext(req, rec), } // read all should return ErrStatusRequestEntityTooLarge @@ -88,7 +84,7 @@ func TestBodyLimitReader(t *testing.T) { // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(e.NewContext(req, rec), io.NopCloser(bytes.NewReader(hw))) + reader.Reset(io.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) diff --git a/middleware/compress.go b/middleware/compress.go index 241da957..fb606aee 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -2,6 +2,7 @@ package middleware import ( "bufio" + "bytes" "compress/gzip" "errors" "io" @@ -25,12 +26,30 @@ type GzipConfig struct { // Gzip compression level. // Optional. Default value -1. Level int + + // Length threshold before gzip compression is applied. + // Optional. Default value 0. + // + // Most of the time you will not need to change the default. Compressing + // a short response might increase the transmitted data because of the + // gzip format overhead. Compressing the response will also consume CPU + // and time on the server and the client (for decompressing). Depending on + // your use case such a threshold might be useful. + // + // See also: + // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits + MinLength int } type gzipResponseWriter struct { io.Writer http.ResponseWriter - wroteBody bool + wroteHeader bool + wroteBody bool + minLength int + minLengthExceeded bool + buffer *bytes.Buffer + code int } // Gzip returns a middleware which compresses HTTP response using gzip compression scheme. @@ -54,8 +73,12 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Level == 0 { config.Level = -1 } + if config.MinLength < 0 { + config.MinLength = 0 + } pool := gzipCompressPool(config) + bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -66,7 +89,6 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { - res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 i := pool.Get() w, ok := i.(*gzip.Writer) if !ok { @@ -74,19 +96,37 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } rw := res.Writer w.Reset(rw) - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} + buf := bpool.Get().(*bytes.Buffer) + buf.Reset() + + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} defer func() { + // There are different reasons for cases when we have not yet written response to the client and now need to do so. + // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. + // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { res.Header().Del(echo.HeaderContentEncoding) } + if grw.wroteHeader { + rw.WriteHeader(grw.code) + } // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. res.Writer = rw w.Reset(io.Discard) + } else if !grw.minLengthExceeded { + // Write uncompressed response + res.Writer = rw + if grw.wroteHeader { + grw.ResponseWriter.WriteHeader(grw.code) + } + grw.buffer.WriteTo(rw) + w.Reset(io.Discard) } w.Close() + bpool.Put(buf) pool.Put(w) }() res.Writer = grw @@ -98,7 +138,11 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { func (w *gzipResponseWriter) WriteHeader(code int) { w.Header().Del(echo.HeaderContentLength) // Issue #444 - w.ResponseWriter.WriteHeader(code) + + w.wroteHeader = true + + // Delay writing of the header until we know if we'll actually compress the response + w.code = code } func (w *gzipResponseWriter) Write(b []byte) (int, error) { @@ -106,10 +150,40 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } w.wroteBody = true + + if !w.minLengthExceeded { + n, err := w.buffer.Write(b) + + if w.buffer.Len() >= w.minLength { + w.minLengthExceeded = true + + // The minimum length is exceeded, add Content-Encoding header and write the header + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + return w.Writer.Write(w.buffer.Bytes()) + } + + return n, err + } + return w.Writer.Write(b) } func (w *gzipResponseWriter) Flush() { + if !w.minLengthExceeded { + // Enforce compression because we will not know how much more data will come + w.minLengthExceeded = true + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + w.Writer.Write(w.buffer.Bytes()) + } + w.Writer.(*gzip.Writer).Flush() if flusher, ok := w.ResponseWriter.(http.Flusher); ok { flusher.Flush() @@ -138,3 +212,12 @@ func gzipCompressPool(config GzipConfig) sync.Pool { }, } } + +func bufferPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + b := &bytes.Buffer{} + return b + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 3da3a105..551f1852 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "compress/gzip" + "io" "net/http" "net/http/httptest" "os" @@ -211,6 +212,137 @@ func TestGzipWithStatic(t *testing.T) { } } +func TestGzipWithMinLength(t *testing.T) { + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("foobarfoobar")) + return nil + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "foobarfoobar", buf.String()) + } +} + +func TestGzipWithMinLengthTooShort(t *testing.T) { + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Body.String(), "test") +} + +func TestGzipWithResponseWithoutBody(t *testing.T) { + e := echo.New() + + e.Use(Gzip()) + e.GET("/", func(c echo.Context) error { + return c.Redirect(http.StatusMovedPermanently, "http://localhost") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestGzipWithMinLengthChunked(t *testing.T) { + e := echo.New() + + // Gzip chunked + chunkBuf := make([]byte, 5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + var r *gzip.Reader = nil + + c := e.NewContext(req, rec) + next := func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Transfer-Encoding", "chunked") + + // Write and flush the first part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + // Read the first part of the data + assert.True(t, rec.Flushed) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + var err error + r, err = gzip.NewReader(rec.Body) + assert.NoError(t, err) + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) + + // Write and flush the second part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) + + // Write the final part of the data and return + c.Response().Write([]byte("test")) + return nil + } + err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c) + + assert.NoError(t, err) + assert.NotNil(t, r) + + buf := new(bytes.Buffer) + + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) + + r.Close() +} + +func TestGzipWithMinLengthNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + func BenchmarkGzip(b *testing.B) { e := echo.New() diff --git a/middleware/cors.go b/middleware/cors.go index 6854b35a..74ec5673 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -151,8 +151,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { allowOriginPatterns := []string{} for _, origin := range config.AllowOrigins { pattern := regexp.QuoteMeta(origin) - pattern = strings.Replace(pattern, "\\*", ".*", -1) - pattern = strings.Replace(pattern, "\\?", ".", -1) + pattern = strings.ReplaceAll(pattern, "\\*", ".*") + pattern = strings.ReplaceAll(pattern, "\\?", ".") pattern = "^" + pattern + "$" allowOriginPatterns = append(allowOriginPatterns, pattern) } diff --git a/middleware/middleware.go b/middleware/middleware.go index 2f8c8b5c..0f99d6d6 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -35,9 +35,9 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { rulesRegex := map[*regexp.Regexp]string{} for k, v := range rewrite { k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*?)", -1) + k = strings.ReplaceAll(k, `\*`, "(.*?)") if strings.HasPrefix(k, `\^`) { - k = strings.Replace(k, `\^`, "^", -1) + k = strings.ReplaceAll(k, `\^`, "^") } k = k + "$" rulesRegex[regexp.MustCompile(k)] = v diff --git a/middleware/proxy.go b/middleware/proxy.go index 68f059b2..d1183d6f 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -29,6 +29,33 @@ type ProxyConfig struct { // Required. Balancer ProxyBalancer + // RetryCount defines the number of times a failed proxied request should be retried + // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. + RetryCount int + + // RetryFilter defines a function used to determine if a failed request to a + // ProxyTarget should be retried. The RetryFilter will only be called when the number + // of previous retries is less than RetryCount. If the function returns true, the + // request will be retried. The provided error indicates the reason for the request + // failure. When the ProxyTarget is unavailable, the error will be an instance of + // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error + // will indicate an internal error in the Proxy middleware. When a RetryFilter is not + // specified, all requests that fail with http.StatusBadGateway will be retried. A custom + // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is + // only called when the request to the target fails, or an internal error in the Proxy + // middleware has occurred. Successful requests that return a non-200 response code cannot + // be retried. + RetryFilter func(c echo.Context, e error) bool + + // ErrorHandler defines a function which can be used to return custom errors from + // the Proxy middleware. ErrorHandler is only invoked when there has been + // either an internal error in the Proxy middleware or the ProxyTarget is + // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked + // when a ProxyTarget returns a non-200 response. In these cases, the response + // is already written so errors cannot be modified. ErrorHandler is only + // invoked after all retry attempts have been exhausted. + ErrorHandler func(c echo.Context, err error) error + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Examples: @@ -99,14 +126,14 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { - c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() out, err := net.Dial("tcp", t.URL.Host) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } defer out.Close() @@ -114,7 +141,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler { // Write header err = r.Write(out) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL))) return } @@ -128,7 +155,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler { go cp(in, out) err = <-errCh if err != nil && err != io.EOF { - c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL)) } }) } @@ -192,7 +219,12 @@ func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) { return b.targets[b.random.Intn(len(b.targets))], nil } -// Next returns an upstream target using round-robin technique. +// Next returns an upstream target using round-robin technique. In the case +// where a previously failed request is being retried, the round-robin +// balancer will attempt to use the next target relative to the original +// request. If the list of targets held by the balancer is modified while a +// failed request is being retried, it is possible that the balancer will +// return the original failed target. // // Note: `nil` is returned in case upstream target list is empty. func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) { @@ -203,13 +235,28 @@ func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) { } else if len(b.targets) == 1 { return b.targets[0], nil } - // reset the index if out of bounds - if b.i >= len(b.targets) { - b.i = 0 + + var i int + const lastIdxKey = "_round_robin_last_index" + // This request is a retry, start from the index of the previous + // target to ensure we don't attempt to retry the request with + // the same failed target + if c.Get(lastIdxKey) != nil { + i = c.Get(lastIdxKey).(int) + i++ + if i >= len(b.targets) { + i = 0 + } + } else { + // This is a first time request, use the global index + if b.i >= len(b.targets) { + b.i = 0 + } + i = b.i + b.i++ } - t := b.targets[b.i] - b.i++ - return t, nil + c.Set(lastIdxKey, i) + return b.targets[i], nil } // Proxy returns a Proxy middleware. @@ -239,6 +286,19 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Balancer == nil { return nil, errors.New("echo proxy middleware requires balancer") } + if config.RetryFilter == nil { + config.RetryFilter = func(c echo.Context, e error) bool { + if httpErr, ok := e.(*echo.HTTPError); ok { + return httpErr.Code == http.StatusBadGateway + } + return false + } + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(c echo.Context, err error) error { + return err + } + } if config.Rewrite != nil { if config.RegexRewrite == nil { @@ -257,14 +317,8 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { req := c.Request() res := c.Response() - tgt, err := config.Balancer.Next(c) - if err != nil { - return err - } - c.Set(config.ContextKey, tgt) - if err := rewriteURL(config.RegexRewrite, req); err != nil { - return err + return config.ErrorHandler(c, err) } // Fix header @@ -280,19 +334,43 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } - // Proxy - switch { - case c.IsWebSocket(): - proxyRaw(c, tgt).ServeHTTP(res, req) - case req.Header.Get(echo.HeaderAccept) == "text/event-stream": - default: - proxyHTTP(c, tgt, config).ServeHTTP(res, req) - } - if e, ok := c.Get("_error").(error); ok { - err = e - } + retries := config.RetryCount + for { + tgt, err := config.Balancer.Next(c) + if err != nil { + return config.ErrorHandler(c, err) + } - return + c.Set(config.ContextKey, tgt) + + //If retrying a failed request, clear any previous errors from + //context here so that balancers have the option to check for + //errors that occurred using previous target + if retries < config.RetryCount { + c.Set("_error", nil) + } + + // Proxy + switch { + case c.IsWebSocket(): + proxyRaw(c, tgt).ServeHTTP(res, req) + case req.Header.Get(echo.HeaderAccept) == "text/event-stream": + default: + proxyHTTP(c, tgt, config).ServeHTTP(res, req) + } + + err, hasError := c.Get("_error").(error) + if !hasError { + return nil + } + + retry := retries > 0 && config.RetryFilter(c, err) + if !retry { + return config.ErrorHandler(c, err) + } + + retries-- + } } }, nil } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 92ebccf4..c1732840 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -426,9 +427,14 @@ func TestFailNextTarget(t *testing.T) { } func TestRandomBalancerWithNoTargets(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + // Assert balancer with empty targets does return `nil` on `Next()` rb := NewRandomBalancer(nil) - target, err := rb.Next(nil) + target, err := rb.Next(c) assert.Nil(t, target) assert.NoError(t, err) } @@ -436,7 +442,327 @@ func TestRandomBalancerWithNoTargets(t *testing.T) { func TestRoundRobinBalancerWithNoTargets(t *testing.T) { // Assert balancer with empty targets does return `nil` on `Next()` rrb := NewRoundRobinBalancer([]*ProxyTarget{}) - target, err := rrb.Next(nil) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + target, err := rrb.Next(c) assert.Nil(t, target) assert.NoError(t, err) } + +func TestProxyRetries(t *testing.T) { + newServer := func(res int) (*url.URL, *httptest.Server) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(res) + }), + ) + targetURL, _ := url.Parse(server.URL) + return targetURL, server + } + + targetURL, server := newServer(http.StatusOK) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: targetURL, + } + + targetURL, server = newServer(http.StatusBadRequest) + defer server.Close() + goodTargetWith40X := &ProxyTarget{ + Name: "Good with 40X", + URL: targetURL, + } + + targetURL, _ = url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: targetURL, + } + + alwaysRetryFilter := func(c echo.Context, e error) bool { return true } + neverRetryFilter := func(c echo.Context, e error) bool { return false } + + testCases := []struct { + name string + retryCount int + retryFilters []func(c echo.Context, e error) bool + targets []*ProxyTarget + expectedResponse int + }{ + { + name: "retry count 0 does not attempt retry on fail", + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 1 does not attempt retry on success", + retryCount: 1, + targets: []*ProxyTarget{ + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does retry on handler return true", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does not retry on handler return false", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when no more retries left", + retryCount: 2, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as only 2 retries + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when retries left but handler returns false", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as retry handler returns false on 2nd check + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 3 succeeds", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "40x responses are not retried", + retryCount: 1, + targets: []*ProxyTarget{ + goodTargetWith40X, + goodTarget, + }, + expectedResponse: http.StatusBadRequest, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + retryFilterCall := 0 + retryFilter := func(c echo.Context, e error) bool { + if len(tc.retryFilters) == 0 { + assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall)) + } + + retryFilterCall++ + + nextRetryFilter := tc.retryFilters[0] + tc.retryFilters = tc.retryFilters[1:] + + return nextRetryFilter(c, e) + } + + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer(tc.targets), + RetryCount: tc.retryCount, + RetryFilter: retryFilter, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedResponse, rec.Code) + if len(tc.retryFilters) > 0 { + assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters))) + } + }) + } +} + +func TestProxyRetryWithBackendTimeout(t *testing.T) { + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.ResponseHeaderTimeout = time.Millisecond * 500 + + timeoutBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.WriteHeader(404) + }), + ) + defer timeoutBackend.Close() + + timeoutTargetURL, _ := url.Parse(timeoutBackend.URL) + goodBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }), + ) + defer goodBackend.Close() + + goodTargetURL, _ := url.Parse(goodBackend.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Transport: transport, + Balancer: NewRoundRobinBalancer([]*ProxyTarget{ + { + Name: "Timeout", + URL: timeoutTargetURL, + }, + { + Name: "Good", + URL: goodTargetURL, + }, + }), + RetryCount: 1, + }, + )) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, 200, rec.Code) + }() + } + + wg.Wait() + +} + +func TestProxyErrorHandler(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + goodURL, _ := url.Parse(server.URL) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: goodURL, + } + + badURL, _ := url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: badURL, + } + + transformedError := errors.New("a new error") + + testCases := []struct { + name string + target *ProxyTarget + errorHandler func(c echo.Context, e error) error + expectFinalError func(t *testing.T, err error) + }{ + { + name: "Error handler not invoked when request success", + target: goodTarget, + errorHandler: func(c echo.Context, e error) error { + assert.FailNow(t, "error handler should not be invoked") + return e + }, + }, + { + name: "Error handler invoked when request fails", + target: badTarget, + errorHandler: func(c echo.Context, e error) error { + httpErr, ok := e.(*echo.HTTPError) + assert.True(t, ok, "expected http error to be passed to handler") + assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") + return transformedError + }, + expectFinalError: func(t *testing.T, err error) { + assert.Equal(t, transformedError, err, "transformed error not returned from proxy") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}), + ErrorHandler: tc.errorHandler, + }, + )) + + errorHandlerCalled := false + dheh := echo.DefaultHTTPErrorHandler(false) + e.HTTPErrorHandler = func(c echo.Context, err error) { + errorHandlerCalled = true + tc.expectFinalError(t, err) + dheh(c, err) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + if !errorHandlerCalled && tc.expectFinalError != nil { + t.Fatalf("error handler was not called") + } + + }) + } +} diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 99445f31..5b30b612 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -156,6 +156,8 @@ type RateLimiterMemoryStore struct { burst int expiresIn time.Duration lastCleanup time.Time + + timeNow func() time.Time } // Visitor signifies a unique user's limiter details @@ -215,7 +217,8 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s store.burst = int(config.Rate) } store.visitors = make(map[string]*Visitor) - store.lastCleanup = now() + store.timeNow = time.Now + store.lastCleanup = store.timeNow() return } @@ -240,12 +243,13 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst) store.visitors[identifier] = limiter } - limiter.lastSeen = now() - if now().Sub(store.lastCleanup) > store.expiresIn { + now := store.timeNow() + limiter.lastSeen = now + if now.Sub(store.lastCleanup) > store.expiresIn { store.cleanupStaleVisitors() } store.mutex.Unlock() - return limiter.AllowN(now(), 1), nil + return limiter.AllowN(store.timeNow(), 1), nil } /* @@ -254,14 +258,9 @@ of users who haven't visited again after the configured expiry time has elapsed */ func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { for id, visitor := range store.visitors { - if now().Sub(visitor.lastSeen) > store.expiresIn { + if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { delete(store.visitors, id) } } - store.lastCleanup = now() + store.lastCleanup = store.timeNow() } - -/* -actual time method which is mocked in test file -*/ -var now = time.Now diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 878fcb78..7a63fb26 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -2,7 +2,6 @@ package middleware import ( "errors" - "fmt" "math/rand" "net/http" "net/http/httptest" @@ -361,7 +360,7 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) { for i, tc := range testCases { t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) - now = func() time.Time { + inMemoryStore.timeNow = func() time.Time { return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) } allowed, _ := inMemoryStore.Allow(tc.id) @@ -371,24 +370,22 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) { func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - now = time.Now - fmt.Println(now()) inMemoryStore.visitors = map[string]*Visitor{ "A": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now(), + lastSeen: time.Now(), }, "B": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-1 * time.Minute), + lastSeen: time.Now().Add(-1 * time.Minute), }, "C": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-5 * time.Minute), + lastSeen: time.Now().Add(-5 * time.Minute), }, "D": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-10 * time.Minute), + lastSeen: time.Now().Add(-10 * time.Minute), }, } diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 13ab851a..46539e6a 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -224,7 +224,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } - now = time.Now + now := time.Now if config.timeNow != nil { now = config.timeNow } @@ -256,7 +256,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.BeforeNextFunc(c) } err := next(c) - if config.HandleError { + if err != nil && config.HandleError { // When global error handler writes the error to the client the Response gets "committed". This state can be // checked with `c.Response().Committed` field. c.Error(err) diff --git a/middleware/util.go b/middleware/util.go index 40f383bc..ffdc66b1 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -1,9 +1,11 @@ package middleware import ( + "bufio" "crypto/rand" - "fmt" + "io" "strings" + "sync" ) const ( @@ -77,17 +79,38 @@ func createRandomStringGenerator(length uint8) func() string { } } -func randomString(length uint8) string { - charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +// https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls +var randomReaderPool = sync.Pool{New: func() interface{} { + return bufio.NewReader(rand.Reader) +}} - bytes := make([]byte, length) - _, err := rand.Read(bytes) - if err != nil { - // we are out of random. let the request fail - panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err)) +const randomStringCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +const randomStringCharsetLen = 52 // len(randomStringCharset) +const randomStringMaxByte = 255 - (256 % randomStringCharsetLen) + +func randomString(length uint8) string { + reader := randomReaderPool.Get().(*bufio.Reader) + defer randomReaderPool.Put(reader) + + b := make([]byte, length) + r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times + var i uint8 = 0 + + for { + _, err := io.ReadFull(reader, r) + if err != nil { + panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)") + } + for _, rb := range r { + if rb > randomStringMaxByte { + // Skip this number to avoid bias. + continue + } + b[i] = randomStringCharset[rb%randomStringCharsetLen] + i++ + if i == length { + return string(b) + } + } } - for i, b := range bytes { - bytes[i] = charset[b%byte(len(charset))] - } - return string(bytes) } diff --git a/middleware/util_test.go b/middleware/util_test.go index dbf634e1..ceeae0d1 100644 --- a/middleware/util_test.go +++ b/middleware/util_test.go @@ -2,6 +2,7 @@ package middleware import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "io" "testing" ) @@ -129,3 +130,31 @@ func TestRandomString(t *testing.T) { }) } } + +func TestRandomStringBias(t *testing.T) { + t.Parallel() + const slen = 33 + const loop = 100000 + + counts := make(map[rune]int) + var count int64 + + for i := 0; i < loop; i++ { + s := randomString(slen) + require.Equal(t, slen, len(s)) + for _, b := range s { + counts[b]++ + count++ + } + } + + require.Equal(t, randomStringCharsetLen, len(counts)) + + avg := float64(count) / float64(len(counts)) + for k, n := range counts { + diff := float64(n) / avg + if diff < 0.95 || diff > 1.05 { + t.Errorf("Bias on '%c': expected average %f, got %d", k, avg, n) + } + } +} diff --git a/response.go b/response.go index 293870d9..8ae234ea 100644 --- a/response.go +++ b/response.go @@ -95,6 +95,13 @@ func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { return r.Writer.(http.Hijacker).Hijack() } +// Unwrap returns the original http.ResponseWriter. +// ResponseController can be used to access the original http.ResponseWriter. +// See [https://go.dev/blog/go1.20] +func (r *Response) Unwrap() http.ResponseWriter { + return r.Writer +} + func (r *Response) reset(w http.ResponseWriter) { r.beforeFuncs = nil r.afterFuncs = nil diff --git a/response_test.go b/response_test.go index d95e079f..e4fd636d 100644 --- a/response_test.go +++ b/response_test.go @@ -72,3 +72,11 @@ func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) } + +func TestResponse_Unwrap(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + res := &Response{echo: e, Writer: rec} + + assert.Equal(t, rec, res.Unwrap()) +} diff --git a/route.go b/route.go index 4aa1df5c..bec85c8d 100644 --- a/route.go +++ b/route.go @@ -81,7 +81,12 @@ func (r routeInfo) Reverse(params ...interface{}) string { ln := len(params) n := 0 for i, l := 0, len(r.path); i < l; i++ { - if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln { + hasBackslash := r.path[i] == '\\' + if hasBackslash && i+1 < l && r.path[i+1] == ':' { + i++ // backslash before colon escapes that colon. in that case skip backslash + } + if n < ln && (r.path[i] == anyLabel || (!hasBackslash && r.path[i] == paramLabel)) { + // in case of `*` wildcard or `:` (unescaped colon) param we replace everything till next slash or end of path for ; i < l && r.path[i] != '/'; i++ { } uri.WriteString(fmt.Sprintf("%v", params[n])) diff --git a/route_test.go b/route_test.go index a4651a56..3b96be05 100644 --- a/route_test.go +++ b/route_test.go @@ -421,3 +421,101 @@ func TestRoutes_FilterByName(t *testing.T) { }) } } + +func TestRouteInfo_Reverse(t *testing.T) { + var testCases = []struct { + name string + givenParams []string + givenPath string + whenParams []interface{} + expect string + }{ + { + name: "ok,static with no params", + givenPath: "/static", + expect: "/static", + }, + { + name: "ok,static with non existent param", + givenParams: []string{"missing param"}, + givenPath: "/static", + whenParams: []interface{}{"missing param"}, + expect: "/static", + }, + { + name: "ok, wildcard with no params", + givenPath: "/static/*", + expect: "/static/*", + }, + { + name: "ok, wildcard with params", + givenParams: []string{"foo.txt"}, + givenPath: "/static/*", + whenParams: []interface{}{"foo.txt"}, + expect: "/static/foo.txt", + }, + { + name: "ok, single param without param", + givenPath: "/params/:foo", + expect: "/params/:foo", + }, + { + name: "ok, single param with param", + givenParams: []string{"one"}, + givenPath: "/params/:foo", + whenParams: []interface{}{"one"}, + expect: "/params/one", + }, + { + name: "ok, multi param without params", + givenPath: "/params/:foo/bar/:qux", + expect: "/params/:foo/bar/:qux", + }, + { + name: "ok, multi param with one param", + givenParams: []string{"one"}, + givenPath: "/params/:foo/bar/:qux", + whenParams: []interface{}{"one"}, + expect: "/params/one/bar/:qux", + }, + { + name: "ok, multi param with all params", + givenParams: []string{"one", "two"}, + givenPath: "/params/:foo/bar/:qux", + whenParams: []interface{}{"one", "two"}, + expect: "/params/one/bar/two", + }, + { + name: "ok, multi param + wildcard with all params", + givenParams: []string{"one", "two", "three"}, + givenPath: "/params/:foo/bar/:qux/*", + whenParams: []interface{}{"one", "two", "three"}, + expect: "/params/one/bar/two/three", + }, + { + name: "ok, backslash is not escaped", + givenParams: []string{"test"}, + givenPath: "/a\\b/:x", + whenParams: []interface{}{"test"}, + expect: `/a\b/test`, + }, + { + name: "ok, escaped colon verbs", + givenParams: []string{"PATCH"}, + givenPath: "/params\\::customVerb", + whenParams: []interface{}{"PATCH"}, + expect: `/params:PATCH`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := routeInfo{ + path: tc.givenPath, + params: tc.givenParams, + name: tc.expect, + } + + assert.Equal(t, tc.expect, r.Reverse(tc.whenParams...)) + }) + } +} diff --git a/router.go b/router.go index 6516a0bf..d51fc21e 100644 --- a/router.go +++ b/router.go @@ -96,6 +96,7 @@ type RouteInfo interface { Name() string Params() []string + // Reverse reverses route to URL string by replacing path parameters with given params values. Reverse(params ...interface{}) string // NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore @@ -1066,7 +1067,7 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc { } } - if r.unescapePathParamValues && currentNode.kind != staticKind { + if r.unescapePathParamValues { // See issue #1531, #1258 - there are cases when path parameter need to be unescaped for i, p := range *pathParams { tmpVal, err := url.PathUnescape(p.Value) diff --git a/router_test.go b/router_test.go index 1cf61aba..3ff0d628 100644 --- a/router_test.go +++ b/router_test.go @@ -3053,6 +3053,15 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) { {Name: "*", Value: " /with space"}, }, }, + { + name: "ok, ending with static node, unescape = true", + givenUnescapePathParamValues: true, + whenURL: "/fourth/%20%2Fwith%20space/static", + expectPath: "/fourth/:id/static", + expectPathParams: PathParams{ + {Name: "id", Value: " /with space"}, + }, + }, { name: "ok, unescape = false", givenUnescapePathParamValues: false, @@ -3080,6 +3089,8 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) { assert.NoError(t, err) _, err = router.Add(Route{Method: http.MethodGet, Path: "/third/*", Handler: handlerFunc}) assert.NoError(t, err) + _, err = router.Add(Route{Method: http.MethodGet, Path: "/fourth/:id/static", Handler: handlerFunc}) + assert.NoError(t, err) target, _ := url.Parse(tc.whenURL) req := httptest.NewRequest(http.MethodGet, target.String(), nil)