mirror of
https://github.com/labstack/echo.git
synced 2024-12-20 19:52:47 +02:00
Changes from master (from 5b36ce3612
to b3ec8e0fdd
)
This commit is contained in:
parent
c2af0cf5a8
commit
ec5b858dab
95
CHANGELOG.md
95
CHANGELOG.md
@ -1,14 +1,101 @@
|
||||
# Changelog
|
||||
|
||||
|
||||
## v4.10.0 - 2022-xx-xx
|
||||
## v4.11.1 - 2023-07-16
|
||||
|
||||
**Fixes**
|
||||
|
||||
* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481)
|
||||
|
||||
|
||||
## v4.11.0 - 2023-07-14
|
||||
|
||||
|
||||
**Fixes**
|
||||
|
||||
* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409)
|
||||
* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411)
|
||||
* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456)
|
||||
* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477)
|
||||
|
||||
|
||||
**Enhancements**
|
||||
|
||||
* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410)
|
||||
* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424)
|
||||
* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425)
|
||||
* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429)
|
||||
* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416)
|
||||
* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436)
|
||||
* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444)
|
||||
* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414)
|
||||
* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452)
|
||||
* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267)
|
||||
* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475)
|
||||
|
||||
|
||||
## v4.10.2 - 2023-02-22
|
||||
|
||||
**Security**
|
||||
|
||||
This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
|
||||
several vulnerabilities fixed in these libraries.
|
||||
* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406)
|
||||
* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405)
|
||||
|
||||
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
|
||||
**Enhancements**
|
||||
|
||||
* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277)
|
||||
|
||||
|
||||
## v4.10.1 - 2023-02-19
|
||||
|
||||
**Security**
|
||||
|
||||
* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402)
|
||||
|
||||
|
||||
**Enhancements**
|
||||
|
||||
* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377)
|
||||
* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385)
|
||||
* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380)
|
||||
* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394)
|
||||
|
||||
|
||||
## v4.10.0 - 2022-12-27
|
||||
|
||||
**Security**
|
||||
|
||||
* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead.
|
||||
|
||||
JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using
|
||||
which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain.
|
||||
|
||||
* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
|
||||
several vulnerabilities fixed in these libraries.
|
||||
|
||||
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
|
||||
|
||||
|
||||
**Enhancements**
|
||||
|
||||
* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305)
|
||||
* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336)
|
||||
* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316)
|
||||
* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338)
|
||||
* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315)
|
||||
* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329)
|
||||
* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340)
|
||||
* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342)
|
||||
* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343)
|
||||
* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345)
|
||||
* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182)
|
||||
* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350)
|
||||
* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341)
|
||||
* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162)
|
||||
* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358)
|
||||
* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362)
|
||||
* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366)
|
||||
* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337)
|
||||
|
||||
|
||||
## v4.9.1 - 2022-10-12
|
||||
|
13
README.md
13
README.md
@ -3,7 +3,7 @@
|
||||
[![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge)
|
||||
[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4)
|
||||
[![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo)
|
||||
[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo)
|
||||
[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions)
|
||||
[![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo)
|
||||
[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions)
|
||||
[![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack)
|
||||
@ -11,12 +11,12 @@
|
||||
|
||||
## Supported Go versions
|
||||
|
||||
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
|
||||
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with
|
||||
older versions.
|
||||
|
||||
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
|
||||
Therefore a Go version capable of understanding /vN suffixed imports is required:
|
||||
|
||||
|
||||
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
|
||||
way of using Echo going forward.
|
||||
|
||||
@ -85,9 +85,9 @@ func hello(c echo.Context) error {
|
||||
|
||||
Following list of middleware is maintained by Echo team.
|
||||
|
||||
| Repository | Description |
|
||||
|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
|
||||
| Repository | Description |
|
||||
|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
|
||||
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
|
||||
|
||||
# Third-party middleware repositories
|
||||
@ -101,6 +101,7 @@ of middlewares in this list.
|
||||
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
|
||||
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
|
||||
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
|
||||
| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface.
|
||||
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
|
||||
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
|
||||
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
|
||||
|
16
binder.go
16
binder.go
@ -1236,7 +1236,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim
|
||||
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, false, time.Second)
|
||||
}
|
||||
@ -1247,7 +1247,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder
|
||||
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, true, time.Second)
|
||||
}
|
||||
@ -1257,7 +1257,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi
|
||||
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, false, time.Millisecond)
|
||||
}
|
||||
@ -1268,7 +1268,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB
|
||||
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, true, time.Millisecond)
|
||||
}
|
||||
@ -1280,8 +1280,8 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va
|
||||
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, false, time.Nanosecond)
|
||||
}
|
||||
@ -1294,8 +1294,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi
|
||||
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||
func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, true, time.Nanosecond)
|
||||
}
|
||||
|
@ -113,8 +113,8 @@ type Context interface {
|
||||
// Set saves data in the context.
|
||||
Set(key string, val interface{})
|
||||
|
||||
// Bind binds the request body into provided type `i`. The default binder
|
||||
// does it based on Content-Type header.
|
||||
// Bind binds path params, query params and the request body into provided type `i`. The default binder
|
||||
// binds body based on Content-Type header.
|
||||
Bind(i interface{}) error
|
||||
|
||||
// Validate validates provided `i`. It is usually called after `Context#Bind()`.
|
||||
@ -536,8 +536,8 @@ func (c *DefaultContext) Set(key string, val interface{}) {
|
||||
c.store[key] = val
|
||||
}
|
||||
|
||||
// Bind binds the request body into provided type `i`. The default binder
|
||||
// does it based on Content-Type header.
|
||||
// Bind binds path params, query params and the request body into provided type `i`. The default binder
|
||||
// binds body based on Content-Type header.
|
||||
func (c *DefaultContext) Bind(i interface{}) error {
|
||||
return c.echo.Binder.Bind(c, i)
|
||||
}
|
||||
|
8
echo.go
8
echo.go
@ -40,6 +40,7 @@ package echo
|
||||
|
||||
import (
|
||||
stdContext "context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -336,12 +337,17 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
|
||||
// Issue #1426
|
||||
code := he.Code
|
||||
message := he.Message
|
||||
if m, ok := he.Message.(string); ok {
|
||||
switch m := he.Message.(type) {
|
||||
case string:
|
||||
if exposeError {
|
||||
message = Map{"message": m, "error": err.Error()}
|
||||
} else {
|
||||
message = Map{"message": m}
|
||||
}
|
||||
case json.Marshaler:
|
||||
// do nothing - this type knows how to format itself to JSON
|
||||
case error:
|
||||
message = Map{"message": m.Error()}
|
||||
}
|
||||
|
||||
// Send response
|
||||
|
25
echo_test.go
25
echo_test.go
@ -1198,6 +1198,18 @@ func request(method, path string, e *Echo) (int, string) {
|
||||
return rec.Code, rec.Body.String()
|
||||
}
|
||||
|
||||
type customError struct {
|
||||
s string
|
||||
}
|
||||
|
||||
func (ce *customError) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil
|
||||
}
|
||||
|
||||
func (ce *customError) Error() string {
|
||||
return ce.s
|
||||
}
|
||||
|
||||
func TestDefaultHTTPErrorHandler(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
@ -1263,6 +1275,19 @@ func TestDefaultHTTPErrorHandler(t *testing.T) {
|
||||
expectStatus: http.StatusInternalServerError,
|
||||
expectBody: ``,
|
||||
},
|
||||
{
|
||||
name: "ok, custom error implement MarshalJSON",
|
||||
whenMethod: http.MethodGet,
|
||||
whenError: NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}),
|
||||
expectStatus: http.StatusBadRequest,
|
||||
expectBody: "{\"x\":\"custom error msg\"}\n",
|
||||
},
|
||||
{
|
||||
name: "with Debug=false when httpError contains an error",
|
||||
whenError: NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")),
|
||||
expectStatus: http.StatusBadRequest,
|
||||
expectBody: "{\"message\":\"error in httperror\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
@ -754,3 +754,77 @@ func TestGroup_StaticPanic(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenCustom404 bool
|
||||
whenURL string
|
||||
expectBody interface{}
|
||||
expectCode int
|
||||
expectMiddlewareCalled bool
|
||||
}{
|
||||
{
|
||||
name: "ok, custom 404 handler is called with middleware",
|
||||
givenCustom404: true,
|
||||
whenURL: "/group/test3",
|
||||
expectBody: "404 GET /group/*",
|
||||
expectCode: http.StatusNotFound,
|
||||
expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added
|
||||
},
|
||||
{
|
||||
name: "ok, default group 404 handler is not called with middleware",
|
||||
givenCustom404: false,
|
||||
whenURL: "/group/test3",
|
||||
expectBody: "404 GET /*",
|
||||
expectCode: http.StatusNotFound,
|
||||
expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
|
||||
},
|
||||
{
|
||||
name: "ok, (no slash) default group 404 handler is called with middleware",
|
||||
givenCustom404: false,
|
||||
whenURL: "/group",
|
||||
expectBody: "404 GET /*",
|
||||
expectCode: http.StatusNotFound,
|
||||
expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
okHandler := func(c Context) error {
|
||||
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
|
||||
}
|
||||
notFoundHandler := func(c Context) error {
|
||||
return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path())
|
||||
}
|
||||
|
||||
e := New()
|
||||
e.GET("/test1", okHandler)
|
||||
e.RouteNotFound("/*", notFoundHandler)
|
||||
|
||||
g := e.Group("/group")
|
||||
g.GET("/test1", okHandler)
|
||||
|
||||
middlewareCalled := false
|
||||
g.Use(func(next HandlerFunc) HandlerFunc {
|
||||
return func(c Context) error {
|
||||
middlewareCalled = true
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
if tc.givenCustom404 {
|
||||
g.RouteNotFound("/*", notFoundHandler)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled)
|
||||
assert.Equal(t, tc.expectCode, rec.Code)
|
||||
assert.Equal(t, tc.expectBody, rec.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -18,9 +18,8 @@ type BodyLimitConfig struct {
|
||||
|
||||
type limitedReader struct {
|
||||
BodyLimitConfig
|
||||
reader io.ReadCloser
|
||||
read int64
|
||||
context echo.Context
|
||||
reader io.ReadCloser
|
||||
read int64
|
||||
}
|
||||
|
||||
// BodyLimit returns a BodyLimit middleware.
|
||||
@ -65,7 +64,7 @@ func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
|
||||
// Based on content read
|
||||
r := pool.Get().(*limitedReader)
|
||||
r.Reset(c, req.Body)
|
||||
r.Reset(req.Body)
|
||||
defer pool.Put(r)
|
||||
req.Body = r
|
||||
|
||||
@ -87,8 +86,7 @@ func (r *limitedReader) Close() error {
|
||||
return r.reader.Close()
|
||||
}
|
||||
|
||||
func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
|
||||
func (r *limitedReader) Reset(reader io.ReadCloser) {
|
||||
r.reader = reader
|
||||
r.context = context
|
||||
r.read = 0
|
||||
}
|
||||
|
@ -67,9 +67,6 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
|
||||
|
||||
func TestBodyLimitReader(t *testing.T) {
|
||||
hw := []byte("Hello, World!")
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
config := BodyLimitConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
@ -78,7 +75,6 @@ func TestBodyLimitReader(t *testing.T) {
|
||||
reader := &limitedReader{
|
||||
BodyLimitConfig: config,
|
||||
reader: io.NopCloser(bytes.NewReader(hw)),
|
||||
context: e.NewContext(req, rec),
|
||||
}
|
||||
|
||||
// read all should return ErrStatusRequestEntityTooLarge
|
||||
@ -88,7 +84,7 @@ func TestBodyLimitReader(t *testing.T) {
|
||||
|
||||
// reset reader and read two bytes must succeed
|
||||
bt := make([]byte, 2)
|
||||
reader.Reset(e.NewContext(req, rec), io.NopCloser(bytes.NewReader(hw)))
|
||||
reader.Reset(io.NopCloser(bytes.NewReader(hw)))
|
||||
n, err := reader.Read(bt)
|
||||
assert.Equal(t, 2, n)
|
||||
assert.Equal(t, nil, err)
|
||||
|
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"io"
|
||||
@ -25,12 +26,30 @@ type GzipConfig struct {
|
||||
// Gzip compression level.
|
||||
// Optional. Default value -1.
|
||||
Level int
|
||||
|
||||
// Length threshold before gzip compression is applied.
|
||||
// Optional. Default value 0.
|
||||
//
|
||||
// Most of the time you will not need to change the default. Compressing
|
||||
// a short response might increase the transmitted data because of the
|
||||
// gzip format overhead. Compressing the response will also consume CPU
|
||||
// and time on the server and the client (for decompressing). Depending on
|
||||
// your use case such a threshold might be useful.
|
||||
//
|
||||
// See also:
|
||||
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
|
||||
MinLength int
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
wroteBody bool
|
||||
wroteHeader bool
|
||||
wroteBody bool
|
||||
minLength int
|
||||
minLengthExceeded bool
|
||||
buffer *bytes.Buffer
|
||||
code int
|
||||
}
|
||||
|
||||
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||
@ -54,8 +73,12 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Level == 0 {
|
||||
config.Level = -1
|
||||
}
|
||||
if config.MinLength < 0 {
|
||||
config.MinLength = 0
|
||||
}
|
||||
|
||||
pool := gzipCompressPool(config)
|
||||
bpool := bufferPool()
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
@ -66,7 +89,6 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
res := c.Response()
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
|
||||
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
|
||||
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
||||
i := pool.Get()
|
||||
w, ok := i.(*gzip.Writer)
|
||||
if !ok {
|
||||
@ -74,19 +96,37 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
}
|
||||
rw := res.Writer
|
||||
w.Reset(rw)
|
||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
|
||||
buf := bpool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
|
||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
|
||||
defer func() {
|
||||
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
|
||||
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
|
||||
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
|
||||
if !grw.wroteBody {
|
||||
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
|
||||
res.Header().Del(echo.HeaderContentEncoding)
|
||||
}
|
||||
if grw.wroteHeader {
|
||||
rw.WriteHeader(grw.code)
|
||||
}
|
||||
// We have to reset response to it's pristine state when
|
||||
// nothing is written to body or error is returned.
|
||||
// See issue #424, #407.
|
||||
res.Writer = rw
|
||||
w.Reset(io.Discard)
|
||||
} else if !grw.minLengthExceeded {
|
||||
// Write uncompressed response
|
||||
res.Writer = rw
|
||||
if grw.wroteHeader {
|
||||
grw.ResponseWriter.WriteHeader(grw.code)
|
||||
}
|
||||
grw.buffer.WriteTo(rw)
|
||||
w.Reset(io.Discard)
|
||||
}
|
||||
w.Close()
|
||||
bpool.Put(buf)
|
||||
pool.Put(w)
|
||||
}()
|
||||
res.Writer = grw
|
||||
@ -98,7 +138,11 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
|
||||
func (w *gzipResponseWriter) WriteHeader(code int) {
|
||||
w.Header().Del(echo.HeaderContentLength) // Issue #444
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
|
||||
w.wroteHeader = true
|
||||
|
||||
// Delay writing of the header until we know if we'll actually compress the response
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
@ -106,10 +150,40 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
|
||||
}
|
||||
w.wroteBody = true
|
||||
|
||||
if !w.minLengthExceeded {
|
||||
n, err := w.buffer.Write(b)
|
||||
|
||||
if w.buffer.Len() >= w.minLength {
|
||||
w.minLengthExceeded = true
|
||||
|
||||
// The minimum length is exceeded, add Content-Encoding header and write the header
|
||||
w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
return w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return w.Writer.Write(b)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Flush() {
|
||||
if !w.minLengthExceeded {
|
||||
// Enforce compression because we will not know how much more data will come
|
||||
w.minLengthExceeded = true
|
||||
w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
w.Writer.(*gzip.Writer).Flush()
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
@ -138,3 +212,12 @@ func gzipCompressPool(config GzipConfig) sync.Pool {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func bufferPool() sync.Pool {
|
||||
return sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := &bytes.Buffer{}
|
||||
return b
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@ -211,6 +212,137 @@ func TestGzipWithStatic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipWithMinLength(t *testing.T) {
|
||||
e := echo.New()
|
||||
// Minimal response length
|
||||
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
c.Response().Write([]byte("foobarfoobar"))
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
r, err := gzip.NewReader(rec.Body)
|
||||
if assert.NoError(t, err) {
|
||||
buf := new(bytes.Buffer)
|
||||
defer r.Close()
|
||||
buf.ReadFrom(r)
|
||||
assert.Equal(t, "foobarfoobar", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipWithMinLengthTooShort(t *testing.T) {
|
||||
e := echo.New()
|
||||
// Minimal response length
|
||||
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test"))
|
||||
return nil
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
|
||||
assert.Contains(t, rec.Body.String(), "test")
|
||||
}
|
||||
|
||||
func TestGzipWithResponseWithoutBody(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.Use(Gzip())
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.Redirect(http.StatusMovedPermanently, "http://localhost")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
|
||||
}
|
||||
|
||||
func TestGzipWithMinLengthChunked(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
// Gzip chunked
|
||||
chunkBuf := make([]byte, 5)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
var r *gzip.Reader = nil
|
||||
|
||||
c := e.NewContext(req, rec)
|
||||
next := func(c echo.Context) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Transfer-Encoding", "chunked")
|
||||
|
||||
// Write and flush the first part of the data
|
||||
c.Response().Write([]byte("test\n"))
|
||||
c.Response().Flush()
|
||||
|
||||
// Read the first part of the data
|
||||
assert.True(t, rec.Flushed)
|
||||
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
|
||||
var err error
|
||||
r, err = gzip.NewReader(rec.Body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = io.ReadFull(r, chunkBuf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test\n", string(chunkBuf))
|
||||
|
||||
// Write and flush the second part of the data
|
||||
c.Response().Write([]byte("test\n"))
|
||||
c.Response().Flush()
|
||||
|
||||
_, err = io.ReadFull(r, chunkBuf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test\n", string(chunkBuf))
|
||||
|
||||
// Write the final part of the data and return
|
||||
c.Response().Write([]byte("test"))
|
||||
return nil
|
||||
}
|
||||
err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
buf.ReadFrom(r)
|
||||
assert.Equal(t, "test", buf.String())
|
||||
|
||||
r.Close()
|
||||
}
|
||||
|
||||
func TestGzipWithMinLengthNoContent(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
})
|
||||
if assert.NoError(t, h(c)) {
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
|
||||
assert.Equal(t, 0, len(rec.Body.Bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGzip(b *testing.B) {
|
||||
e := echo.New()
|
||||
|
||||
|
@ -151,8 +151,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
allowOriginPatterns := []string{}
|
||||
for _, origin := range config.AllowOrigins {
|
||||
pattern := regexp.QuoteMeta(origin)
|
||||
pattern = strings.Replace(pattern, "\\*", ".*", -1)
|
||||
pattern = strings.Replace(pattern, "\\?", ".", -1)
|
||||
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
|
||||
pattern = strings.ReplaceAll(pattern, "\\?", ".")
|
||||
pattern = "^" + pattern + "$"
|
||||
allowOriginPatterns = append(allowOriginPatterns, pattern)
|
||||
}
|
||||
|
@ -35,9 +35,9 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
|
||||
rulesRegex := map[*regexp.Regexp]string{}
|
||||
for k, v := range rewrite {
|
||||
k = regexp.QuoteMeta(k)
|
||||
k = strings.Replace(k, `\*`, "(.*?)", -1)
|
||||
k = strings.ReplaceAll(k, `\*`, "(.*?)")
|
||||
if strings.HasPrefix(k, `\^`) {
|
||||
k = strings.Replace(k, `\^`, "^", -1)
|
||||
k = strings.ReplaceAll(k, `\^`, "^")
|
||||
}
|
||||
k = k + "$"
|
||||
rulesRegex[regexp.MustCompile(k)] = v
|
||||
|
@ -29,6 +29,33 @@ type ProxyConfig struct {
|
||||
// Required.
|
||||
Balancer ProxyBalancer
|
||||
|
||||
// RetryCount defines the number of times a failed proxied request should be retried
|
||||
// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
|
||||
RetryCount int
|
||||
|
||||
// RetryFilter defines a function used to determine if a failed request to a
|
||||
// ProxyTarget should be retried. The RetryFilter will only be called when the number
|
||||
// of previous retries is less than RetryCount. If the function returns true, the
|
||||
// request will be retried. The provided error indicates the reason for the request
|
||||
// failure. When the ProxyTarget is unavailable, the error will be an instance of
|
||||
// echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
|
||||
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
|
||||
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
|
||||
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
|
||||
// only called when the request to the target fails, or an internal error in the Proxy
|
||||
// middleware has occurred. Successful requests that return a non-200 response code cannot
|
||||
// be retried.
|
||||
RetryFilter func(c echo.Context, e error) bool
|
||||
|
||||
// ErrorHandler defines a function which can be used to return custom errors from
|
||||
// the Proxy middleware. ErrorHandler is only invoked when there has been
|
||||
// either an internal error in the Proxy middleware or the ProxyTarget is
|
||||
// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
|
||||
// when a ProxyTarget returns a non-200 response. In these cases, the response
|
||||
// is already written so errors cannot be modified. ErrorHandler is only
|
||||
// invoked after all retry attempts have been exhausted.
|
||||
ErrorHandler func(c echo.Context, err error) error
|
||||
|
||||
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Examples:
|
||||
@ -99,14 +126,14 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
in, _, err := c.Response().Hijack()
|
||||
if err != nil {
|
||||
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
|
||||
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
|
||||
return
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := net.Dial("tcp", t.URL.Host)
|
||||
if err != nil {
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
|
||||
return
|
||||
}
|
||||
defer out.Close()
|
||||
@ -114,7 +141,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
||||
// Write header
|
||||
err = r.Write(out)
|
||||
if err != nil {
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
|
||||
return
|
||||
}
|
||||
|
||||
@ -128,7 +155,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
||||
go cp(in, out)
|
||||
err = <-errCh
|
||||
if err != nil && err != io.EOF {
|
||||
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
|
||||
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -192,7 +219,12 @@ func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
||||
return b.targets[b.random.Intn(len(b.targets))], nil
|
||||
}
|
||||
|
||||
// Next returns an upstream target using round-robin technique.
|
||||
// Next returns an upstream target using round-robin technique. In the case
|
||||
// where a previously failed request is being retried, the round-robin
|
||||
// balancer will attempt to use the next target relative to the original
|
||||
// request. If the list of targets held by the balancer is modified while a
|
||||
// failed request is being retried, it is possible that the balancer will
|
||||
// return the original failed target.
|
||||
//
|
||||
// Note: `nil` is returned in case upstream target list is empty.
|
||||
func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
||||
@ -203,13 +235,28 @@ func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
||||
} else if len(b.targets) == 1 {
|
||||
return b.targets[0], nil
|
||||
}
|
||||
// reset the index if out of bounds
|
||||
if b.i >= len(b.targets) {
|
||||
b.i = 0
|
||||
|
||||
var i int
|
||||
const lastIdxKey = "_round_robin_last_index"
|
||||
// This request is a retry, start from the index of the previous
|
||||
// target to ensure we don't attempt to retry the request with
|
||||
// the same failed target
|
||||
if c.Get(lastIdxKey) != nil {
|
||||
i = c.Get(lastIdxKey).(int)
|
||||
i++
|
||||
if i >= len(b.targets) {
|
||||
i = 0
|
||||
}
|
||||
} else {
|
||||
// This is a first time request, use the global index
|
||||
if b.i >= len(b.targets) {
|
||||
b.i = 0
|
||||
}
|
||||
i = b.i
|
||||
b.i++
|
||||
}
|
||||
t := b.targets[b.i]
|
||||
b.i++
|
||||
return t, nil
|
||||
c.Set(lastIdxKey, i)
|
||||
return b.targets[i], nil
|
||||
}
|
||||
|
||||
// Proxy returns a Proxy middleware.
|
||||
@ -239,6 +286,19 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Balancer == nil {
|
||||
return nil, errors.New("echo proxy middleware requires balancer")
|
||||
}
|
||||
if config.RetryFilter == nil {
|
||||
config.RetryFilter = func(c echo.Context, e error) bool {
|
||||
if httpErr, ok := e.(*echo.HTTPError); ok {
|
||||
return httpErr.Code == http.StatusBadGateway
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
if config.ErrorHandler == nil {
|
||||
config.ErrorHandler = func(c echo.Context, err error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if config.Rewrite != nil {
|
||||
if config.RegexRewrite == nil {
|
||||
@ -257,14 +317,8 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
tgt, err := config.Balancer.Next(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Set(config.ContextKey, tgt)
|
||||
|
||||
if err := rewriteURL(config.RegexRewrite, req); err != nil {
|
||||
return err
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
// Fix header
|
||||
@ -280,19 +334,43 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
|
||||
}
|
||||
|
||||
// Proxy
|
||||
switch {
|
||||
case c.IsWebSocket():
|
||||
proxyRaw(c, tgt).ServeHTTP(res, req)
|
||||
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
|
||||
default:
|
||||
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
|
||||
}
|
||||
if e, ok := c.Get("_error").(error); ok {
|
||||
err = e
|
||||
}
|
||||
retries := config.RetryCount
|
||||
for {
|
||||
tgt, err := config.Balancer.Next(c)
|
||||
if err != nil {
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
return
|
||||
c.Set(config.ContextKey, tgt)
|
||||
|
||||
//If retrying a failed request, clear any previous errors from
|
||||
//context here so that balancers have the option to check for
|
||||
//errors that occurred using previous target
|
||||
if retries < config.RetryCount {
|
||||
c.Set("_error", nil)
|
||||
}
|
||||
|
||||
// Proxy
|
||||
switch {
|
||||
case c.IsWebSocket():
|
||||
proxyRaw(c, tgt).ServeHTTP(res, req)
|
||||
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
|
||||
default:
|
||||
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
|
||||
}
|
||||
|
||||
err, hasError := c.Get("_error").(error)
|
||||
if !hasError {
|
||||
return nil
|
||||
}
|
||||
|
||||
retry := retries > 0 && config.RetryFilter(c, err)
|
||||
if !retry {
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
retries--
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -426,9 +427,14 @@ func TestFailNextTarget(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRandomBalancerWithNoTargets(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Assert balancer with empty targets does return `nil` on `Next()`
|
||||
rb := NewRandomBalancer(nil)
|
||||
target, err := rb.Next(nil)
|
||||
target, err := rb.Next(c)
|
||||
assert.Nil(t, target)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@ -436,7 +442,327 @@ func TestRandomBalancerWithNoTargets(t *testing.T) {
|
||||
func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
|
||||
// Assert balancer with empty targets does return `nil` on `Next()`
|
||||
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
|
||||
target, err := rrb.Next(nil)
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
target, err := rrb.Next(c)
|
||||
assert.Nil(t, target)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestProxyRetries(t *testing.T) {
|
||||
newServer := func(res int) (*url.URL, *httptest.Server) {
|
||||
server := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(res)
|
||||
}),
|
||||
)
|
||||
targetURL, _ := url.Parse(server.URL)
|
||||
return targetURL, server
|
||||
}
|
||||
|
||||
targetURL, server := newServer(http.StatusOK)
|
||||
defer server.Close()
|
||||
goodTarget := &ProxyTarget{
|
||||
Name: "Good",
|
||||
URL: targetURL,
|
||||
}
|
||||
|
||||
targetURL, server = newServer(http.StatusBadRequest)
|
||||
defer server.Close()
|
||||
goodTargetWith40X := &ProxyTarget{
|
||||
Name: "Good with 40X",
|
||||
URL: targetURL,
|
||||
}
|
||||
|
||||
targetURL, _ = url.Parse("http://127.0.0.1:27121")
|
||||
badTarget := &ProxyTarget{
|
||||
Name: "Bad",
|
||||
URL: targetURL,
|
||||
}
|
||||
|
||||
alwaysRetryFilter := func(c echo.Context, e error) bool { return true }
|
||||
neverRetryFilter := func(c echo.Context, e error) bool { return false }
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
retryCount int
|
||||
retryFilters []func(c echo.Context, e error) bool
|
||||
targets []*ProxyTarget
|
||||
expectedResponse int
|
||||
}{
|
||||
{
|
||||
name: "retry count 0 does not attempt retry on fail",
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 1 does not attempt retry on success",
|
||||
retryCount: 1,
|
||||
targets: []*ProxyTarget{
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "retry count 1 does retry on handler return true",
|
||||
retryCount: 1,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "retry count 1 does not retry on handler return false",
|
||||
retryCount: 1,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
neverRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 2 returns error when no more retries left",
|
||||
retryCount: 2,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
badTarget,
|
||||
badTarget,
|
||||
goodTarget, //Should never be reached as only 2 retries
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 2 returns error when retries left but handler returns false",
|
||||
retryCount: 3,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
neverRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
badTarget,
|
||||
badTarget,
|
||||
goodTarget, //Should never be reached as retry handler returns false on 2nd check
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 3 succeeds",
|
||||
retryCount: 3,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
badTarget,
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "40x responses are not retried",
|
||||
retryCount: 1,
|
||||
targets: []*ProxyTarget{
|
||||
goodTargetWith40X,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
retryFilterCall := 0
|
||||
retryFilter := func(c echo.Context, e error) bool {
|
||||
if len(tc.retryFilters) == 0 {
|
||||
assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
|
||||
}
|
||||
|
||||
retryFilterCall++
|
||||
|
||||
nextRetryFilter := tc.retryFilters[0]
|
||||
tc.retryFilters = tc.retryFilters[1:]
|
||||
|
||||
return nextRetryFilter(c, e)
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.Use(ProxyWithConfig(
|
||||
ProxyConfig{
|
||||
Balancer: NewRoundRobinBalancer(tc.targets),
|
||||
RetryCount: tc.retryCount,
|
||||
RetryFilter: retryFilter,
|
||||
},
|
||||
))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectedResponse, rec.Code)
|
||||
if len(tc.retryFilters) > 0 {
|
||||
assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyRetryWithBackendTimeout(t *testing.T) {
|
||||
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.ResponseHeaderTimeout = time.Millisecond * 500
|
||||
|
||||
timeoutBackend := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1 * time.Second)
|
||||
w.WriteHeader(404)
|
||||
}),
|
||||
)
|
||||
defer timeoutBackend.Close()
|
||||
|
||||
timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
|
||||
goodBackend := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}),
|
||||
)
|
||||
defer goodBackend.Close()
|
||||
|
||||
goodTargetURL, _ := url.Parse(goodBackend.URL)
|
||||
e := echo.New()
|
||||
e.Use(ProxyWithConfig(
|
||||
ProxyConfig{
|
||||
Transport: transport,
|
||||
Balancer: NewRoundRobinBalancer([]*ProxyTarget{
|
||||
{
|
||||
Name: "Timeout",
|
||||
URL: timeoutTargetURL,
|
||||
},
|
||||
{
|
||||
Name: "Good",
|
||||
URL: goodTargetURL,
|
||||
},
|
||||
}),
|
||||
RetryCount: 1,
|
||||
},
|
||||
))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, 200, rec.Code)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
func TestProxyErrorHandler(t *testing.T) {
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
goodURL, _ := url.Parse(server.URL)
|
||||
defer server.Close()
|
||||
goodTarget := &ProxyTarget{
|
||||
Name: "Good",
|
||||
URL: goodURL,
|
||||
}
|
||||
|
||||
badURL, _ := url.Parse("http://127.0.0.1:27121")
|
||||
badTarget := &ProxyTarget{
|
||||
Name: "Bad",
|
||||
URL: badURL,
|
||||
}
|
||||
|
||||
transformedError := errors.New("a new error")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
target *ProxyTarget
|
||||
errorHandler func(c echo.Context, e error) error
|
||||
expectFinalError func(t *testing.T, err error)
|
||||
}{
|
||||
{
|
||||
name: "Error handler not invoked when request success",
|
||||
target: goodTarget,
|
||||
errorHandler: func(c echo.Context, e error) error {
|
||||
assert.FailNow(t, "error handler should not be invoked")
|
||||
return e
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error handler invoked when request fails",
|
||||
target: badTarget,
|
||||
errorHandler: func(c echo.Context, e error) error {
|
||||
httpErr, ok := e.(*echo.HTTPError)
|
||||
assert.True(t, ok, "expected http error to be passed to handler")
|
||||
assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
|
||||
return transformedError
|
||||
},
|
||||
expectFinalError: func(t *testing.T, err error) {
|
||||
assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Use(ProxyWithConfig(
|
||||
ProxyConfig{
|
||||
Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
|
||||
ErrorHandler: tc.errorHandler,
|
||||
},
|
||||
))
|
||||
|
||||
errorHandlerCalled := false
|
||||
dheh := echo.DefaultHTTPErrorHandler(false)
|
||||
e.HTTPErrorHandler = func(c echo.Context, err error) {
|
||||
errorHandlerCalled = true
|
||||
tc.expectFinalError(t, err)
|
||||
dheh(c, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
if !errorHandlerCalled && tc.expectFinalError != nil {
|
||||
t.Fatalf("error handler was not called")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -156,6 +156,8 @@ type RateLimiterMemoryStore struct {
|
||||
burst int
|
||||
expiresIn time.Duration
|
||||
lastCleanup time.Time
|
||||
|
||||
timeNow func() time.Time
|
||||
}
|
||||
|
||||
// Visitor signifies a unique user's limiter details
|
||||
@ -215,7 +217,8 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s
|
||||
store.burst = int(config.Rate)
|
||||
}
|
||||
store.visitors = make(map[string]*Visitor)
|
||||
store.lastCleanup = now()
|
||||
store.timeNow = time.Now
|
||||
store.lastCleanup = store.timeNow()
|
||||
return
|
||||
}
|
||||
|
||||
@ -240,12 +243,13 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
|
||||
limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst)
|
||||
store.visitors[identifier] = limiter
|
||||
}
|
||||
limiter.lastSeen = now()
|
||||
if now().Sub(store.lastCleanup) > store.expiresIn {
|
||||
now := store.timeNow()
|
||||
limiter.lastSeen = now
|
||||
if now.Sub(store.lastCleanup) > store.expiresIn {
|
||||
store.cleanupStaleVisitors()
|
||||
}
|
||||
store.mutex.Unlock()
|
||||
return limiter.AllowN(now(), 1), nil
|
||||
return limiter.AllowN(store.timeNow(), 1), nil
|
||||
}
|
||||
|
||||
/*
|
||||
@ -254,14 +258,9 @@ of users who haven't visited again after the configured expiry time has elapsed
|
||||
*/
|
||||
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
|
||||
for id, visitor := range store.visitors {
|
||||
if now().Sub(visitor.lastSeen) > store.expiresIn {
|
||||
if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn {
|
||||
delete(store.visitors, id)
|
||||
}
|
||||
}
|
||||
store.lastCleanup = now()
|
||||
store.lastCleanup = store.timeNow()
|
||||
}
|
||||
|
||||
/*
|
||||
actual time method which is mocked in test file
|
||||
*/
|
||||
var now = time.Now
|
||||
|
@ -2,7 +2,6 @@ package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -361,7 +360,7 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) {
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond)
|
||||
now = func() time.Time {
|
||||
inMemoryStore.timeNow = func() time.Time {
|
||||
return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond)
|
||||
}
|
||||
allowed, _ := inMemoryStore.Allow(tc.id)
|
||||
@ -371,24 +370,22 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) {
|
||||
|
||||
func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
|
||||
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||
now = time.Now
|
||||
fmt.Println(now())
|
||||
inMemoryStore.visitors = map[string]*Visitor{
|
||||
"A": {
|
||||
Limiter: rate.NewLimiter(1, 3),
|
||||
lastSeen: now(),
|
||||
lastSeen: time.Now(),
|
||||
},
|
||||
"B": {
|
||||
Limiter: rate.NewLimiter(1, 3),
|
||||
lastSeen: now().Add(-1 * time.Minute),
|
||||
lastSeen: time.Now().Add(-1 * time.Minute),
|
||||
},
|
||||
"C": {
|
||||
Limiter: rate.NewLimiter(1, 3),
|
||||
lastSeen: now().Add(-5 * time.Minute),
|
||||
lastSeen: time.Now().Add(-5 * time.Minute),
|
||||
},
|
||||
"D": {
|
||||
Limiter: rate.NewLimiter(1, 3),
|
||||
lastSeen: now().Add(-10 * time.Minute),
|
||||
lastSeen: time.Now().Add(-10 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -224,7 +224,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
now = time.Now
|
||||
now := time.Now
|
||||
if config.timeNow != nil {
|
||||
now = config.timeNow
|
||||
}
|
||||
@ -256,7 +256,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
config.BeforeNextFunc(c)
|
||||
}
|
||||
err := next(c)
|
||||
if config.HandleError {
|
||||
if err != nil && config.HandleError {
|
||||
// When global error handler writes the error to the client the Response gets "committed". This state can be
|
||||
// checked with `c.Response().Committed` field.
|
||||
c.Error(err)
|
||||
|
@ -1,9 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -77,17 +79,38 @@ func createRandomStringGenerator(length uint8) func() string {
|
||||
}
|
||||
}
|
||||
|
||||
func randomString(length uint8) string {
|
||||
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
// https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls
|
||||
var randomReaderPool = sync.Pool{New: func() interface{} {
|
||||
return bufio.NewReader(rand.Reader)
|
||||
}}
|
||||
|
||||
bytes := make([]byte, length)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
// we are out of random. let the request fail
|
||||
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
|
||||
const randomStringCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
const randomStringCharsetLen = 52 // len(randomStringCharset)
|
||||
const randomStringMaxByte = 255 - (256 % randomStringCharsetLen)
|
||||
|
||||
func randomString(length uint8) string {
|
||||
reader := randomReaderPool.Get().(*bufio.Reader)
|
||||
defer randomReaderPool.Put(reader)
|
||||
|
||||
b := make([]byte, length)
|
||||
r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times
|
||||
var i uint8 = 0
|
||||
|
||||
for {
|
||||
_, err := io.ReadFull(reader, r)
|
||||
if err != nil {
|
||||
panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)")
|
||||
}
|
||||
for _, rb := range r {
|
||||
if rb > randomStringMaxByte {
|
||||
// Skip this number to avoid bias.
|
||||
continue
|
||||
}
|
||||
b[i] = randomStringCharset[rb%randomStringCharsetLen]
|
||||
i++
|
||||
if i == length {
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
for i, b := range bytes {
|
||||
bytes[i] = charset[b%byte(len(charset))]
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
@ -129,3 +130,31 @@ func TestRandomString(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomStringBias(t *testing.T) {
|
||||
t.Parallel()
|
||||
const slen = 33
|
||||
const loop = 100000
|
||||
|
||||
counts := make(map[rune]int)
|
||||
var count int64
|
||||
|
||||
for i := 0; i < loop; i++ {
|
||||
s := randomString(slen)
|
||||
require.Equal(t, slen, len(s))
|
||||
for _, b := range s {
|
||||
counts[b]++
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, randomStringCharsetLen, len(counts))
|
||||
|
||||
avg := float64(count) / float64(len(counts))
|
||||
for k, n := range counts {
|
||||
diff := float64(n) / avg
|
||||
if diff < 0.95 || diff > 1.05 {
|
||||
t.Errorf("Bias on '%c': expected average %f, got %d", k, avg, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -95,6 +95,13 @@ func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return r.Writer.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// Unwrap returns the original http.ResponseWriter.
|
||||
// ResponseController can be used to access the original http.ResponseWriter.
|
||||
// See [https://go.dev/blog/go1.20]
|
||||
func (r *Response) Unwrap() http.ResponseWriter {
|
||||
return r.Writer
|
||||
}
|
||||
|
||||
func (r *Response) reset(w http.ResponseWriter) {
|
||||
r.beforeFuncs = nil
|
||||
r.afterFuncs = nil
|
||||
|
@ -72,3 +72,11 @@ func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestResponse_Unwrap(t *testing.T) {
|
||||
e := New()
|
||||
rec := httptest.NewRecorder()
|
||||
res := &Response{echo: e, Writer: rec}
|
||||
|
||||
assert.Equal(t, rec, res.Unwrap())
|
||||
}
|
||||
|
7
route.go
7
route.go
@ -81,7 +81,12 @@ func (r routeInfo) Reverse(params ...interface{}) string {
|
||||
ln := len(params)
|
||||
n := 0
|
||||
for i, l := 0, len(r.path); i < l; i++ {
|
||||
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
|
||||
hasBackslash := r.path[i] == '\\'
|
||||
if hasBackslash && i+1 < l && r.path[i+1] == ':' {
|
||||
i++ // backslash before colon escapes that colon. in that case skip backslash
|
||||
}
|
||||
if n < ln && (r.path[i] == anyLabel || (!hasBackslash && r.path[i] == paramLabel)) {
|
||||
// in case of `*` wildcard or `:` (unescaped colon) param we replace everything till next slash or end of path
|
||||
for ; i < l && r.path[i] != '/'; i++ {
|
||||
}
|
||||
uri.WriteString(fmt.Sprintf("%v", params[n]))
|
||||
|
@ -421,3 +421,101 @@ func TestRoutes_FilterByName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteInfo_Reverse(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenParams []string
|
||||
givenPath string
|
||||
whenParams []interface{}
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "ok,static with no params",
|
||||
givenPath: "/static",
|
||||
expect: "/static",
|
||||
},
|
||||
{
|
||||
name: "ok,static with non existent param",
|
||||
givenParams: []string{"missing param"},
|
||||
givenPath: "/static",
|
||||
whenParams: []interface{}{"missing param"},
|
||||
expect: "/static",
|
||||
},
|
||||
{
|
||||
name: "ok, wildcard with no params",
|
||||
givenPath: "/static/*",
|
||||
expect: "/static/*",
|
||||
},
|
||||
{
|
||||
name: "ok, wildcard with params",
|
||||
givenParams: []string{"foo.txt"},
|
||||
givenPath: "/static/*",
|
||||
whenParams: []interface{}{"foo.txt"},
|
||||
expect: "/static/foo.txt",
|
||||
},
|
||||
{
|
||||
name: "ok, single param without param",
|
||||
givenPath: "/params/:foo",
|
||||
expect: "/params/:foo",
|
||||
},
|
||||
{
|
||||
name: "ok, single param with param",
|
||||
givenParams: []string{"one"},
|
||||
givenPath: "/params/:foo",
|
||||
whenParams: []interface{}{"one"},
|
||||
expect: "/params/one",
|
||||
},
|
||||
{
|
||||
name: "ok, multi param without params",
|
||||
givenPath: "/params/:foo/bar/:qux",
|
||||
expect: "/params/:foo/bar/:qux",
|
||||
},
|
||||
{
|
||||
name: "ok, multi param with one param",
|
||||
givenParams: []string{"one"},
|
||||
givenPath: "/params/:foo/bar/:qux",
|
||||
whenParams: []interface{}{"one"},
|
||||
expect: "/params/one/bar/:qux",
|
||||
},
|
||||
{
|
||||
name: "ok, multi param with all params",
|
||||
givenParams: []string{"one", "two"},
|
||||
givenPath: "/params/:foo/bar/:qux",
|
||||
whenParams: []interface{}{"one", "two"},
|
||||
expect: "/params/one/bar/two",
|
||||
},
|
||||
{
|
||||
name: "ok, multi param + wildcard with all params",
|
||||
givenParams: []string{"one", "two", "three"},
|
||||
givenPath: "/params/:foo/bar/:qux/*",
|
||||
whenParams: []interface{}{"one", "two", "three"},
|
||||
expect: "/params/one/bar/two/three",
|
||||
},
|
||||
{
|
||||
name: "ok, backslash is not escaped",
|
||||
givenParams: []string{"test"},
|
||||
givenPath: "/a\\b/:x",
|
||||
whenParams: []interface{}{"test"},
|
||||
expect: `/a\b/test`,
|
||||
},
|
||||
{
|
||||
name: "ok, escaped colon verbs",
|
||||
givenParams: []string{"PATCH"},
|
||||
givenPath: "/params\\::customVerb",
|
||||
whenParams: []interface{}{"PATCH"},
|
||||
expect: `/params:PATCH`,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := routeInfo{
|
||||
path: tc.givenPath,
|
||||
params: tc.givenParams,
|
||||
name: tc.expect,
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expect, r.Reverse(tc.whenParams...))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -96,6 +96,7 @@ type RouteInfo interface {
|
||||
Name() string
|
||||
|
||||
Params() []string
|
||||
// Reverse reverses route to URL string by replacing path parameters with given params values.
|
||||
Reverse(params ...interface{}) string
|
||||
|
||||
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
|
||||
@ -1066,7 +1067,7 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
if r.unescapePathParamValues && currentNode.kind != staticKind {
|
||||
if r.unescapePathParamValues {
|
||||
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
|
||||
for i, p := range *pathParams {
|
||||
tmpVal, err := url.PathUnescape(p.Value)
|
||||
|
@ -3053,6 +3053,15 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) {
|
||||
{Name: "*", Value: " /with space"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, ending with static node, unescape = true",
|
||||
givenUnescapePathParamValues: true,
|
||||
whenURL: "/fourth/%20%2Fwith%20space/static",
|
||||
expectPath: "/fourth/:id/static",
|
||||
expectPathParams: PathParams{
|
||||
{Name: "id", Value: " /with space"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, unescape = false",
|
||||
givenUnescapePathParamValues: false,
|
||||
@ -3080,6 +3089,8 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
_, err = router.Add(Route{Method: http.MethodGet, Path: "/third/*", Handler: handlerFunc})
|
||||
assert.NoError(t, err)
|
||||
_, err = router.Add(Route{Method: http.MethodGet, Path: "/fourth/:id/static", Handler: handlerFunc})
|
||||
assert.NoError(t, err)
|
||||
|
||||
target, _ := url.Parse(tc.whenURL)
|
||||
req := httptest.NewRequest(http.MethodGet, target.String(), nil)
|
||||
|
Loading…
Reference in New Issue
Block a user