1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-22 20:06:21 +02:00

Changes from master (from 5b36ce3612 to b3ec8e0fdd)

This commit is contained in:
toimtoimtoim 2023-07-22 23:25:34 +03:00
parent c2af0cf5a8
commit ec5b858dab
No known key found for this signature in database
GPG Key ID: 0443E21F7D9928AF
26 changed files with 1093 additions and 109 deletions

View File

@ -1,14 +1,101 @@
# Changelog
## v4.10.0 - 2022-xx-xx
## v4.11.1 - 2023-07-16
**Fixes**
* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481)
## v4.11.0 - 2023-07-14
**Fixes**
* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409)
* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411)
* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456)
* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477)
**Enhancements**
* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410)
* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424)
* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425)
* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429)
* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416)
* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436)
* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444)
* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414)
* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452)
* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267)
* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475)
## v4.10.2 - 2023-02-22
**Security**
This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
several vulnerabilities fixed in these libraries.
* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406)
* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405)
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
**Enhancements**
* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277)
## v4.10.1 - 2023-02-19
**Security**
* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402)
**Enhancements**
* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377)
* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385)
* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380)
* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394)
## v4.10.0 - 2022-12-27
**Security**
* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead.
JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using
which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain.
* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
several vulnerabilities fixed in these libraries.
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
**Enhancements**
* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305)
* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336)
* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316)
* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338)
* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315)
* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329)
* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340)
* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342)
* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343)
* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345)
* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182)
* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350)
* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341)
* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162)
* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358)
* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362)
* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366)
* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337)
## v4.9.1 - 2022-10-12

View File

@ -3,7 +3,7 @@
[![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge)
[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4)
[![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo)
[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo)
[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions)
[![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo)
[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions)
[![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack)
@ -11,12 +11,12 @@
## Supported Go versions
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with
older versions.
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
Therefore a Go version capable of understanding /vN suffixed imports is required:
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
way of using Echo going forward.
@ -85,9 +85,9 @@ func hello(c echo.Context) error {
Following list of middleware is maintained by Echo team.
| Repository | Description |
|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
| Repository | Description |
|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
# Third-party middleware repositories
@ -101,6 +101,7 @@ of middlewares in this list.
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface.
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |

View File

@ -1236,7 +1236,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, false, time.Second)
}
@ -1247,7 +1247,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, true, time.Second)
}
@ -1257,7 +1257,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, false, time.Millisecond)
}
@ -1268,7 +1268,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, true, time.Millisecond)
}
@ -1280,8 +1280,8 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, false, time.Nanosecond)
}
@ -1294,8 +1294,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, true, time.Nanosecond)
}

View File

@ -113,8 +113,8 @@ type Context interface {
// Set saves data in the context.
Set(key string, val interface{})
// Bind binds the request body into provided type `i`. The default binder
// does it based on Content-Type header.
// Bind binds path params, query params and the request body into provided type `i`. The default binder
// binds body based on Content-Type header.
Bind(i interface{}) error
// Validate validates provided `i`. It is usually called after `Context#Bind()`.
@ -536,8 +536,8 @@ func (c *DefaultContext) Set(key string, val interface{}) {
c.store[key] = val
}
// Bind binds the request body into provided type `i`. The default binder
// does it based on Content-Type header.
// Bind binds path params, query params and the request body into provided type `i`. The default binder
// binds body based on Content-Type header.
func (c *DefaultContext) Bind(i interface{}) error {
return c.echo.Binder.Bind(c, i)
}

View File

@ -40,6 +40,7 @@ package echo
import (
stdContext "context"
"encoding/json"
"errors"
"fmt"
"io"
@ -336,12 +337,17 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
// Issue #1426
code := he.Code
message := he.Message
if m, ok := he.Message.(string); ok {
switch m := he.Message.(type) {
case string:
if exposeError {
message = Map{"message": m, "error": err.Error()}
} else {
message = Map{"message": m}
}
case json.Marshaler:
// do nothing - this type knows how to format itself to JSON
case error:
message = Map{"message": m.Error()}
}
// Send response

View File

@ -1198,6 +1198,18 @@ func request(method, path string, e *Echo) (int, string) {
return rec.Code, rec.Body.String()
}
type customError struct {
s string
}
func (ce *customError) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil
}
func (ce *customError) Error() string {
return ce.s
}
func TestDefaultHTTPErrorHandler(t *testing.T) {
var testCases = []struct {
name string
@ -1263,6 +1275,19 @@ func TestDefaultHTTPErrorHandler(t *testing.T) {
expectStatus: http.StatusInternalServerError,
expectBody: ``,
},
{
name: "ok, custom error implement MarshalJSON",
whenMethod: http.MethodGet,
whenError: NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}),
expectStatus: http.StatusBadRequest,
expectBody: "{\"x\":\"custom error msg\"}\n",
},
{
name: "with Debug=false when httpError contains an error",
whenError: NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")),
expectStatus: http.StatusBadRequest,
expectBody: "{\"message\":\"error in httperror\"}\n",
},
}
for _, tc := range testCases {

View File

@ -754,3 +754,77 @@ func TestGroup_StaticPanic(t *testing.T) {
})
}
}
func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
var testCases = []struct {
name string
givenCustom404 bool
whenURL string
expectBody interface{}
expectCode int
expectMiddlewareCalled bool
}{
{
name: "ok, custom 404 handler is called with middleware",
givenCustom404: true,
whenURL: "/group/test3",
expectBody: "404 GET /group/*",
expectCode: http.StatusNotFound,
expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added
},
{
name: "ok, default group 404 handler is not called with middleware",
givenCustom404: false,
whenURL: "/group/test3",
expectBody: "404 GET /*",
expectCode: http.StatusNotFound,
expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
},
{
name: "ok, (no slash) default group 404 handler is called with middleware",
givenCustom404: false,
whenURL: "/group",
expectBody: "404 GET /*",
expectCode: http.StatusNotFound,
expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
notFoundHandler := func(c Context) error {
return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path())
}
e := New()
e.GET("/test1", okHandler)
e.RouteNotFound("/*", notFoundHandler)
g := e.Group("/group")
g.GET("/test1", okHandler)
middlewareCalled := false
g.Use(func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
middlewareCalled = true
return next(c)
}
})
if tc.givenCustom404 {
g.RouteNotFound("/*", notFoundHandler)
}
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled)
assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectBody, rec.Body.String())
})
}
}

View File

@ -18,9 +18,8 @@ type BodyLimitConfig struct {
type limitedReader struct {
BodyLimitConfig
reader io.ReadCloser
read int64
context echo.Context
reader io.ReadCloser
read int64
}
// BodyLimit returns a BodyLimit middleware.
@ -65,7 +64,7 @@ func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Based on content read
r := pool.Get().(*limitedReader)
r.Reset(c, req.Body)
r.Reset(req.Body)
defer pool.Put(r)
req.Body = r
@ -87,8 +86,7 @@ func (r *limitedReader) Close() error {
return r.reader.Close()
}
func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
func (r *limitedReader) Reset(reader io.ReadCloser) {
r.reader = reader
r.context = context
r.read = 0
}

View File

@ -67,9 +67,6 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
@ -78,7 +75,6 @@ func TestBodyLimitReader(t *testing.T) {
reader := &limitedReader{
BodyLimitConfig: config,
reader: io.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
@ -88,7 +84,7 @@ func TestBodyLimitReader(t *testing.T) {
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(e.NewContext(req, rec), io.NopCloser(bytes.NewReader(hw)))
reader.Reset(io.NopCloser(bytes.NewReader(hw)))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)

View File

@ -2,6 +2,7 @@ package middleware
import (
"bufio"
"bytes"
"compress/gzip"
"errors"
"io"
@ -25,12 +26,30 @@ type GzipConfig struct {
// Gzip compression level.
// Optional. Default value -1.
Level int
// Length threshold before gzip compression is applied.
// Optional. Default value 0.
//
// Most of the time you will not need to change the default. Compressing
// a short response might increase the transmitted data because of the
// gzip format overhead. Compressing the response will also consume CPU
// and time on the server and the client (for decompressing). Depending on
// your use case such a threshold might be useful.
//
// See also:
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
MinLength int
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
wroteBody bool
wroteHeader bool
wroteBody bool
minLength int
minLengthExceeded bool
buffer *bytes.Buffer
code int
}
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
@ -54,8 +73,12 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Level == 0 {
config.Level = -1
}
if config.MinLength < 0 {
config.MinLength = 0
}
pool := gzipCompressPool(config)
bpool := bufferPool()
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
@ -66,7 +89,6 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
res := c.Response()
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
i := pool.Get()
w, ok := i.(*gzip.Writer)
if !ok {
@ -74,19 +96,37 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
rw := res.Writer
w.Reset(rw)
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
buf := bpool.Get().(*bytes.Buffer)
buf.Reset()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
defer func() {
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
if !grw.wroteBody {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
res.Header().Del(echo.HeaderContentEncoding)
}
if grw.wroteHeader {
rw.WriteHeader(grw.code)
}
// We have to reset response to it's pristine state when
// nothing is written to body or error is returned.
// See issue #424, #407.
res.Writer = rw
w.Reset(io.Discard)
} else if !grw.minLengthExceeded {
// Write uncompressed response
res.Writer = rw
if grw.wroteHeader {
grw.ResponseWriter.WriteHeader(grw.code)
}
grw.buffer.WriteTo(rw)
w.Reset(io.Discard)
}
w.Close()
bpool.Put(buf)
pool.Put(w)
}()
res.Writer = grw
@ -98,7 +138,11 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
func (w *gzipResponseWriter) WriteHeader(code int) {
w.Header().Del(echo.HeaderContentLength) // Issue #444
w.ResponseWriter.WriteHeader(code)
w.wroteHeader = true
// Delay writing of the header until we know if we'll actually compress the response
w.code = code
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
@ -106,10 +150,40 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) {
w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
}
w.wroteBody = true
if !w.minLengthExceeded {
n, err := w.buffer.Write(b)
if w.buffer.Len() >= w.minLength {
w.minLengthExceeded = true
// The minimum length is exceeded, add Content-Encoding header and write the header
w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
return w.Writer.Write(w.buffer.Bytes())
}
return n, err
}
return w.Writer.Write(b)
}
func (w *gzipResponseWriter) Flush() {
if !w.minLengthExceeded {
// Enforce compression because we will not know how much more data will come
w.minLengthExceeded = true
w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
w.Writer.Write(w.buffer.Bytes())
}
w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
@ -138,3 +212,12 @@ func gzipCompressPool(config GzipConfig) sync.Pool {
},
}
}
func bufferPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
b := &bytes.Buffer{}
return b
},
}
}

View File

@ -3,6 +3,7 @@ package middleware
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"os"
@ -211,6 +212,137 @@ func TestGzipWithStatic(t *testing.T) {
}
}
func TestGzipWithMinLength(t *testing.T) {
e := echo.New()
// Minimal response length
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("foobarfoobar"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r, err := gzip.NewReader(rec.Body)
if assert.NoError(t, err) {
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal(t, "foobarfoobar", buf.String())
}
}
func TestGzipWithMinLengthTooShort(t *testing.T) {
e := echo.New()
// Minimal response length
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Body.String(), "test")
}
func TestGzipWithResponseWithoutBody(t *testing.T) {
e := echo.New()
e.Use(Gzip())
e.GET("/", func(c echo.Context) error {
return c.Redirect(http.StatusMovedPermanently, "http://localhost")
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
}
func TestGzipWithMinLengthChunked(t *testing.T) {
e := echo.New()
// Gzip chunked
chunkBuf := make([]byte, 5)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
var r *gzip.Reader = nil
c := e.NewContext(req, rec)
next := func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data
c.Response().Write([]byte("test\n"))
c.Response().Flush()
// Read the first part of the data
assert.True(t, rec.Flushed)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
var err error
r, err = gzip.NewReader(rec.Body)
assert.NoError(t, err)
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(t, err)
assert.Equal(t, "test\n", string(chunkBuf))
// Write and flush the second part of the data
c.Response().Write([]byte("test\n"))
c.Response().Flush()
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(t, err)
assert.Equal(t, "test\n", string(chunkBuf))
// Write the final part of the data and return
c.Response().Write([]byte("test"))
return nil
}
err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c)
assert.NoError(t, err)
assert.NotNil(t, r)
buf := new(bytes.Buffer)
buf.ReadFrom(r)
assert.Equal(t, "test", buf.String())
r.Close()
}
func TestGzipWithMinLengthNoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
}
}
func BenchmarkGzip(b *testing.B) {
e := echo.New()

View File

@ -151,8 +151,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
allowOriginPatterns := []string{}
for _, origin := range config.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.Replace(pattern, "\\*", ".*", -1)
pattern = strings.Replace(pattern, "\\?", ".", -1)
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
pattern = strings.ReplaceAll(pattern, "\\?", ".")
pattern = "^" + pattern + "$"
allowOriginPatterns = append(allowOriginPatterns, pattern)
}

View File

@ -35,9 +35,9 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
rulesRegex := map[*regexp.Regexp]string{}
for k, v := range rewrite {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*?)", -1)
k = strings.ReplaceAll(k, `\*`, "(.*?)")
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
k = strings.ReplaceAll(k, `\^`, "^")
}
k = k + "$"
rulesRegex[regexp.MustCompile(k)] = v

View File

@ -29,6 +29,33 @@ type ProxyConfig struct {
// Required.
Balancer ProxyBalancer
// RetryCount defines the number of times a failed proxied request should be retried
// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
RetryCount int
// RetryFilter defines a function used to determine if a failed request to a
// ProxyTarget should be retried. The RetryFilter will only be called when the number
// of previous retries is less than RetryCount. If the function returns true, the
// request will be retried. The provided error indicates the reason for the request
// failure. When the ProxyTarget is unavailable, the error will be an instance of
// echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
// only called when the request to the target fails, or an internal error in the Proxy
// middleware has occurred. Successful requests that return a non-200 response code cannot
// be retried.
RetryFilter func(c echo.Context, e error) bool
// ErrorHandler defines a function which can be used to return custom errors from
// the Proxy middleware. ErrorHandler is only invoked when there has been
// either an internal error in the Proxy middleware or the ProxyTarget is
// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
// when a ProxyTarget returns a non-200 response. In these cases, the response
// is already written so errors cannot be modified. ErrorHandler is only
// invoked after all retry attempts have been exhausted.
ErrorHandler func(c echo.Context, err error) error
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Examples:
@ -99,14 +126,14 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
return
}
defer in.Close()
out, err := net.Dial("tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
}
defer out.Close()
@ -114,7 +141,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
// Write header
err = r.Write(out)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
return
}
@ -128,7 +155,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
go cp(in, out)
err = <-errCh
if err != nil && err != io.EOF {
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL))
}
})
}
@ -192,7 +219,12 @@ func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) {
return b.targets[b.random.Intn(len(b.targets))], nil
}
// Next returns an upstream target using round-robin technique.
// Next returns an upstream target using round-robin technique. In the case
// where a previously failed request is being retried, the round-robin
// balancer will attempt to use the next target relative to the original
// request. If the list of targets held by the balancer is modified while a
// failed request is being retried, it is possible that the balancer will
// return the original failed target.
//
// Note: `nil` is returned in case upstream target list is empty.
func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
@ -203,13 +235,28 @@ func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
} else if len(b.targets) == 1 {
return b.targets[0], nil
}
// reset the index if out of bounds
if b.i >= len(b.targets) {
b.i = 0
var i int
const lastIdxKey = "_round_robin_last_index"
// This request is a retry, start from the index of the previous
// target to ensure we don't attempt to retry the request with
// the same failed target
if c.Get(lastIdxKey) != nil {
i = c.Get(lastIdxKey).(int)
i++
if i >= len(b.targets) {
i = 0
}
} else {
// This is a first time request, use the global index
if b.i >= len(b.targets) {
b.i = 0
}
i = b.i
b.i++
}
t := b.targets[b.i]
b.i++
return t, nil
c.Set(lastIdxKey, i)
return b.targets[i], nil
}
// Proxy returns a Proxy middleware.
@ -239,6 +286,19 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Balancer == nil {
return nil, errors.New("echo proxy middleware requires balancer")
}
if config.RetryFilter == nil {
config.RetryFilter = func(c echo.Context, e error) bool {
if httpErr, ok := e.(*echo.HTTPError); ok {
return httpErr.Code == http.StatusBadGateway
}
return false
}
}
if config.ErrorHandler == nil {
config.ErrorHandler = func(c echo.Context, err error) error {
return err
}
}
if config.Rewrite != nil {
if config.RegexRewrite == nil {
@ -257,14 +317,8 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
req := c.Request()
res := c.Response()
tgt, err := config.Balancer.Next(c)
if err != nil {
return err
}
c.Set(config.ContextKey, tgt)
if err := rewriteURL(config.RegexRewrite, req); err != nil {
return err
return config.ErrorHandler(c, err)
}
// Fix header
@ -280,19 +334,43 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(c, tgt).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
}
if e, ok := c.Get("_error").(error); ok {
err = e
}
retries := config.RetryCount
for {
tgt, err := config.Balancer.Next(c)
if err != nil {
return config.ErrorHandler(c, err)
}
return
c.Set(config.ContextKey, tgt)
//If retrying a failed request, clear any previous errors from
//context here so that balancers have the option to check for
//errors that occurred using previous target
if retries < config.RetryCount {
c.Set("_error", nil)
}
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(c, tgt).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
}
err, hasError := c.Get("_error").(error)
if !hasError {
return nil
}
retry := retries > 0 && config.RetryFilter(c, err)
if !retry {
return config.ErrorHandler(c, err)
}
retries--
}
}
}, nil
}

View File

@ -3,6 +3,7 @@ package middleware
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
@ -426,9 +427,14 @@ func TestFailNextTarget(t *testing.T) {
}
func TestRandomBalancerWithNoTargets(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Assert balancer with empty targets does return `nil` on `Next()`
rb := NewRandomBalancer(nil)
target, err := rb.Next(nil)
target, err := rb.Next(c)
assert.Nil(t, target)
assert.NoError(t, err)
}
@ -436,7 +442,327 @@ func TestRandomBalancerWithNoTargets(t *testing.T) {
func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
// Assert balancer with empty targets does return `nil` on `Next()`
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
target, err := rrb.Next(nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
target, err := rrb.Next(c)
assert.Nil(t, target)
assert.NoError(t, err)
}
func TestProxyRetries(t *testing.T) {
newServer := func(res int) (*url.URL, *httptest.Server) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(res)
}),
)
targetURL, _ := url.Parse(server.URL)
return targetURL, server
}
targetURL, server := newServer(http.StatusOK)
defer server.Close()
goodTarget := &ProxyTarget{
Name: "Good",
URL: targetURL,
}
targetURL, server = newServer(http.StatusBadRequest)
defer server.Close()
goodTargetWith40X := &ProxyTarget{
Name: "Good with 40X",
URL: targetURL,
}
targetURL, _ = url.Parse("http://127.0.0.1:27121")
badTarget := &ProxyTarget{
Name: "Bad",
URL: targetURL,
}
alwaysRetryFilter := func(c echo.Context, e error) bool { return true }
neverRetryFilter := func(c echo.Context, e error) bool { return false }
testCases := []struct {
name string
retryCount int
retryFilters []func(c echo.Context, e error) bool
targets []*ProxyTarget
expectedResponse int
}{
{
name: "retry count 0 does not attempt retry on fail",
targets: []*ProxyTarget{
badTarget,
goodTarget,
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 1 does not attempt retry on success",
retryCount: 1,
targets: []*ProxyTarget{
goodTarget,
},
expectedResponse: http.StatusOK,
},
{
name: "retry count 1 does retry on handler return true",
retryCount: 1,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
goodTarget,
},
expectedResponse: http.StatusOK,
},
{
name: "retry count 1 does not retry on handler return false",
retryCount: 1,
retryFilters: []func(c echo.Context, e error) bool{
neverRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
goodTarget,
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 2 returns error when no more retries left",
retryCount: 2,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
badTarget,
badTarget,
goodTarget, //Should never be reached as only 2 retries
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 2 returns error when retries left but handler returns false",
retryCount: 3,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
neverRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
badTarget,
badTarget,
goodTarget, //Should never be reached as retry handler returns false on 2nd check
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 3 succeeds",
retryCount: 3,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
alwaysRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
badTarget,
badTarget,
goodTarget,
},
expectedResponse: http.StatusOK,
},
{
name: "40x responses are not retried",
retryCount: 1,
targets: []*ProxyTarget{
goodTargetWith40X,
goodTarget,
},
expectedResponse: http.StatusBadRequest,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
retryFilterCall := 0
retryFilter := func(c echo.Context, e error) bool {
if len(tc.retryFilters) == 0 {
assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
}
retryFilterCall++
nextRetryFilter := tc.retryFilters[0]
tc.retryFilters = tc.retryFilters[1:]
return nextRetryFilter(c, e)
}
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Balancer: NewRoundRobinBalancer(tc.targets),
RetryCount: tc.retryCount,
RetryFilter: retryFilter,
},
))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectedResponse, rec.Code)
if len(tc.retryFilters) > 0 {
assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
}
})
}
}
func TestProxyRetryWithBackendTimeout(t *testing.T) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.ResponseHeaderTimeout = time.Millisecond * 500
timeoutBackend := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.WriteHeader(404)
}),
)
defer timeoutBackend.Close()
timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
goodBackend := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}),
)
defer goodBackend.Close()
goodTargetURL, _ := url.Parse(goodBackend.URL)
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Transport: transport,
Balancer: NewRoundRobinBalancer([]*ProxyTarget{
{
Name: "Timeout",
URL: timeoutTargetURL,
},
{
Name: "Good",
URL: goodTargetURL,
},
}),
RetryCount: 1,
},
))
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, 200, rec.Code)
}()
}
wg.Wait()
}
func TestProxyErrorHandler(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
goodURL, _ := url.Parse(server.URL)
defer server.Close()
goodTarget := &ProxyTarget{
Name: "Good",
URL: goodURL,
}
badURL, _ := url.Parse("http://127.0.0.1:27121")
badTarget := &ProxyTarget{
Name: "Bad",
URL: badURL,
}
transformedError := errors.New("a new error")
testCases := []struct {
name string
target *ProxyTarget
errorHandler func(c echo.Context, e error) error
expectFinalError func(t *testing.T, err error)
}{
{
name: "Error handler not invoked when request success",
target: goodTarget,
errorHandler: func(c echo.Context, e error) error {
assert.FailNow(t, "error handler should not be invoked")
return e
},
},
{
name: "Error handler invoked when request fails",
target: badTarget,
errorHandler: func(c echo.Context, e error) error {
httpErr, ok := e.(*echo.HTTPError)
assert.True(t, ok, "expected http error to be passed to handler")
assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
return transformedError
},
expectFinalError: func(t *testing.T, err error) {
assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
ErrorHandler: tc.errorHandler,
},
))
errorHandlerCalled := false
dheh := echo.DefaultHTTPErrorHandler(false)
e.HTTPErrorHandler = func(c echo.Context, err error) {
errorHandlerCalled = true
tc.expectFinalError(t, err)
dheh(c, err)
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
if !errorHandlerCalled && tc.expectFinalError != nil {
t.Fatalf("error handler was not called")
}
})
}
}

View File

@ -156,6 +156,8 @@ type RateLimiterMemoryStore struct {
burst int
expiresIn time.Duration
lastCleanup time.Time
timeNow func() time.Time
}
// Visitor signifies a unique user's limiter details
@ -215,7 +217,8 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s
store.burst = int(config.Rate)
}
store.visitors = make(map[string]*Visitor)
store.lastCleanup = now()
store.timeNow = time.Now
store.lastCleanup = store.timeNow()
return
}
@ -240,12 +243,13 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst)
store.visitors[identifier] = limiter
}
limiter.lastSeen = now()
if now().Sub(store.lastCleanup) > store.expiresIn {
now := store.timeNow()
limiter.lastSeen = now
if now.Sub(store.lastCleanup) > store.expiresIn {
store.cleanupStaleVisitors()
}
store.mutex.Unlock()
return limiter.AllowN(now(), 1), nil
return limiter.AllowN(store.timeNow(), 1), nil
}
/*
@ -254,14 +258,9 @@ of users who haven't visited again after the configured expiry time has elapsed
*/
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
for id, visitor := range store.visitors {
if now().Sub(visitor.lastSeen) > store.expiresIn {
if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn {
delete(store.visitors, id)
}
}
store.lastCleanup = now()
store.lastCleanup = store.timeNow()
}
/*
actual time method which is mocked in test file
*/
var now = time.Now

View File

@ -2,7 +2,6 @@ package middleware
import (
"errors"
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
@ -361,7 +360,7 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) {
for i, tc := range testCases {
t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond)
now = func() time.Time {
inMemoryStore.timeNow = func() time.Time {
return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond)
}
allowed, _ := inMemoryStore.Allow(tc.id)
@ -371,24 +370,22 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) {
func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
now = time.Now
fmt.Println(now())
inMemoryStore.visitors = map[string]*Visitor{
"A": {
Limiter: rate.NewLimiter(1, 3),
lastSeen: now(),
lastSeen: time.Now(),
},
"B": {
Limiter: rate.NewLimiter(1, 3),
lastSeen: now().Add(-1 * time.Minute),
lastSeen: time.Now().Add(-1 * time.Minute),
},
"C": {
Limiter: rate.NewLimiter(1, 3),
lastSeen: now().Add(-5 * time.Minute),
lastSeen: time.Now().Add(-5 * time.Minute),
},
"D": {
Limiter: rate.NewLimiter(1, 3),
lastSeen: now().Add(-10 * time.Minute),
lastSeen: time.Now().Add(-10 * time.Minute),
},
}

View File

@ -224,7 +224,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
now = time.Now
now := time.Now
if config.timeNow != nil {
now = config.timeNow
}
@ -256,7 +256,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
config.BeforeNextFunc(c)
}
err := next(c)
if config.HandleError {
if err != nil && config.HandleError {
// When global error handler writes the error to the client the Response gets "committed". This state can be
// checked with `c.Response().Committed` field.
c.Error(err)

View File

@ -1,9 +1,11 @@
package middleware
import (
"bufio"
"crypto/rand"
"fmt"
"io"
"strings"
"sync"
)
const (
@ -77,17 +79,38 @@ func createRandomStringGenerator(length uint8) func() string {
}
}
func randomString(length uint8) string {
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
// https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls
var randomReaderPool = sync.Pool{New: func() interface{} {
return bufio.NewReader(rand.Reader)
}}
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
// we are out of random. let the request fail
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
const randomStringCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
const randomStringCharsetLen = 52 // len(randomStringCharset)
const randomStringMaxByte = 255 - (256 % randomStringCharsetLen)
func randomString(length uint8) string {
reader := randomReaderPool.Get().(*bufio.Reader)
defer randomReaderPool.Put(reader)
b := make([]byte, length)
r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times
var i uint8 = 0
for {
_, err := io.ReadFull(reader, r)
if err != nil {
panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)")
}
for _, rb := range r {
if rb > randomStringMaxByte {
// Skip this number to avoid bias.
continue
}
b[i] = randomStringCharset[rb%randomStringCharsetLen]
i++
if i == length {
return string(b)
}
}
}
for i, b := range bytes {
bytes[i] = charset[b%byte(len(charset))]
}
return string(bytes)
}

View File

@ -2,6 +2,7 @@ package middleware
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"testing"
)
@ -129,3 +130,31 @@ func TestRandomString(t *testing.T) {
})
}
}
func TestRandomStringBias(t *testing.T) {
t.Parallel()
const slen = 33
const loop = 100000
counts := make(map[rune]int)
var count int64
for i := 0; i < loop; i++ {
s := randomString(slen)
require.Equal(t, slen, len(s))
for _, b := range s {
counts[b]++
count++
}
}
require.Equal(t, randomStringCharsetLen, len(counts))
avg := float64(count) / float64(len(counts))
for k, n := range counts {
diff := float64(n) / avg
if diff < 0.95 || diff > 1.05 {
t.Errorf("Bias on '%c': expected average %f, got %d", k, avg, n)
}
}
}

View File

@ -95,6 +95,13 @@ func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.Writer.(http.Hijacker).Hijack()
}
// Unwrap returns the original http.ResponseWriter.
// ResponseController can be used to access the original http.ResponseWriter.
// See [https://go.dev/blog/go1.20]
func (r *Response) Unwrap() http.ResponseWriter {
return r.Writer
}
func (r *Response) reset(w http.ResponseWriter) {
r.beforeFuncs = nil
r.afterFuncs = nil

View File

@ -72,3 +72,11 @@ func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestResponse_Unwrap(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
assert.Equal(t, rec, res.Unwrap())
}

View File

@ -81,7 +81,12 @@ func (r routeInfo) Reverse(params ...interface{}) string {
ln := len(params)
n := 0
for i, l := 0, len(r.path); i < l; i++ {
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
hasBackslash := r.path[i] == '\\'
if hasBackslash && i+1 < l && r.path[i+1] == ':' {
i++ // backslash before colon escapes that colon. in that case skip backslash
}
if n < ln && (r.path[i] == anyLabel || (!hasBackslash && r.path[i] == paramLabel)) {
// in case of `*` wildcard or `:` (unescaped colon) param we replace everything till next slash or end of path
for ; i < l && r.path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))

View File

@ -421,3 +421,101 @@ func TestRoutes_FilterByName(t *testing.T) {
})
}
}
func TestRouteInfo_Reverse(t *testing.T) {
var testCases = []struct {
name string
givenParams []string
givenPath string
whenParams []interface{}
expect string
}{
{
name: "ok,static with no params",
givenPath: "/static",
expect: "/static",
},
{
name: "ok,static with non existent param",
givenParams: []string{"missing param"},
givenPath: "/static",
whenParams: []interface{}{"missing param"},
expect: "/static",
},
{
name: "ok, wildcard with no params",
givenPath: "/static/*",
expect: "/static/*",
},
{
name: "ok, wildcard with params",
givenParams: []string{"foo.txt"},
givenPath: "/static/*",
whenParams: []interface{}{"foo.txt"},
expect: "/static/foo.txt",
},
{
name: "ok, single param without param",
givenPath: "/params/:foo",
expect: "/params/:foo",
},
{
name: "ok, single param with param",
givenParams: []string{"one"},
givenPath: "/params/:foo",
whenParams: []interface{}{"one"},
expect: "/params/one",
},
{
name: "ok, multi param without params",
givenPath: "/params/:foo/bar/:qux",
expect: "/params/:foo/bar/:qux",
},
{
name: "ok, multi param with one param",
givenParams: []string{"one"},
givenPath: "/params/:foo/bar/:qux",
whenParams: []interface{}{"one"},
expect: "/params/one/bar/:qux",
},
{
name: "ok, multi param with all params",
givenParams: []string{"one", "two"},
givenPath: "/params/:foo/bar/:qux",
whenParams: []interface{}{"one", "two"},
expect: "/params/one/bar/two",
},
{
name: "ok, multi param + wildcard with all params",
givenParams: []string{"one", "two", "three"},
givenPath: "/params/:foo/bar/:qux/*",
whenParams: []interface{}{"one", "two", "three"},
expect: "/params/one/bar/two/three",
},
{
name: "ok, backslash is not escaped",
givenParams: []string{"test"},
givenPath: "/a\\b/:x",
whenParams: []interface{}{"test"},
expect: `/a\b/test`,
},
{
name: "ok, escaped colon verbs",
givenParams: []string{"PATCH"},
givenPath: "/params\\::customVerb",
whenParams: []interface{}{"PATCH"},
expect: `/params:PATCH`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := routeInfo{
path: tc.givenPath,
params: tc.givenParams,
name: tc.expect,
}
assert.Equal(t, tc.expect, r.Reverse(tc.whenParams...))
})
}
}

View File

@ -96,6 +96,7 @@ type RouteInfo interface {
Name() string
Params() []string
// Reverse reverses route to URL string by replacing path parameters with given params values.
Reverse(params ...interface{}) string
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
@ -1066,7 +1067,7 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
}
}
if r.unescapePathParamValues && currentNode.kind != staticKind {
if r.unescapePathParamValues {
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
for i, p := range *pathParams {
tmpVal, err := url.PathUnescape(p.Value)

View File

@ -3053,6 +3053,15 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) {
{Name: "*", Value: " /with space"},
},
},
{
name: "ok, ending with static node, unescape = true",
givenUnescapePathParamValues: true,
whenURL: "/fourth/%20%2Fwith%20space/static",
expectPath: "/fourth/:id/static",
expectPathParams: PathParams{
{Name: "id", Value: " /with space"},
},
},
{
name: "ok, unescape = false",
givenUnescapePathParamValues: false,
@ -3080,6 +3089,8 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) {
assert.NoError(t, err)
_, err = router.Add(Route{Method: http.MethodGet, Path: "/third/*", Handler: handlerFunc})
assert.NoError(t, err)
_, err = router.Add(Route{Method: http.MethodGet, Path: "/fourth/:id/static", Handler: handlerFunc})
assert.NoError(t, err)
target, _ := url.Parse(tc.whenURL)
req := httptest.NewRequest(http.MethodGet, target.String(), nil)