mirror of
https://github.com/labstack/echo.git
synced 2025-01-01 22:09:21 +02:00
Changes from master (from 5b36ce3612
to b3ec8e0fdd
)
This commit is contained in:
parent
c2af0cf5a8
commit
ec5b858dab
95
CHANGELOG.md
95
CHANGELOG.md
@ -1,14 +1,101 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
|
||||||
## v4.10.0 - 2022-xx-xx
|
## v4.11.1 - 2023-07-16
|
||||||
|
|
||||||
|
**Fixes**
|
||||||
|
|
||||||
|
* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481)
|
||||||
|
|
||||||
|
|
||||||
|
## v4.11.0 - 2023-07-14
|
||||||
|
|
||||||
|
|
||||||
|
**Fixes**
|
||||||
|
|
||||||
|
* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409)
|
||||||
|
* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411)
|
||||||
|
* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456)
|
||||||
|
* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477)
|
||||||
|
|
||||||
|
|
||||||
|
**Enhancements**
|
||||||
|
|
||||||
|
* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410)
|
||||||
|
* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424)
|
||||||
|
* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425)
|
||||||
|
* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429)
|
||||||
|
* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416)
|
||||||
|
* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436)
|
||||||
|
* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444)
|
||||||
|
* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414)
|
||||||
|
* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452)
|
||||||
|
* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267)
|
||||||
|
* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475)
|
||||||
|
|
||||||
|
|
||||||
|
## v4.10.2 - 2023-02-22
|
||||||
|
|
||||||
**Security**
|
**Security**
|
||||||
|
|
||||||
This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
|
* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406)
|
||||||
several vulnerabilities fixed in these libraries.
|
* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405)
|
||||||
|
|
||||||
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
|
**Enhancements**
|
||||||
|
|
||||||
|
* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277)
|
||||||
|
|
||||||
|
|
||||||
|
## v4.10.1 - 2023-02-19
|
||||||
|
|
||||||
|
**Security**
|
||||||
|
|
||||||
|
* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402)
|
||||||
|
|
||||||
|
|
||||||
|
**Enhancements**
|
||||||
|
|
||||||
|
* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377)
|
||||||
|
* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385)
|
||||||
|
* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380)
|
||||||
|
* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394)
|
||||||
|
|
||||||
|
|
||||||
|
## v4.10.0 - 2022-12-27
|
||||||
|
|
||||||
|
**Security**
|
||||||
|
|
||||||
|
* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead.
|
||||||
|
|
||||||
|
JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using
|
||||||
|
which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain.
|
||||||
|
|
||||||
|
* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
|
||||||
|
several vulnerabilities fixed in these libraries.
|
||||||
|
|
||||||
|
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
|
||||||
|
|
||||||
|
|
||||||
|
**Enhancements**
|
||||||
|
|
||||||
|
* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305)
|
||||||
|
* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336)
|
||||||
|
* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316)
|
||||||
|
* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338)
|
||||||
|
* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315)
|
||||||
|
* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329)
|
||||||
|
* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340)
|
||||||
|
* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342)
|
||||||
|
* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343)
|
||||||
|
* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345)
|
||||||
|
* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182)
|
||||||
|
* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350)
|
||||||
|
* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341)
|
||||||
|
* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162)
|
||||||
|
* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358)
|
||||||
|
* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362)
|
||||||
|
* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366)
|
||||||
|
* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337)
|
||||||
|
|
||||||
|
|
||||||
## v4.9.1 - 2022-10-12
|
## v4.9.1 - 2022-10-12
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
[![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge)
|
[![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge)
|
||||||
[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4)
|
[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4)
|
||||||
[![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo)
|
[![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo)
|
||||||
[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo)
|
[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions)
|
||||||
[![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo)
|
[![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo)
|
||||||
[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions)
|
[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions)
|
||||||
[![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack)
|
[![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack)
|
||||||
@ -11,12 +11,12 @@
|
|||||||
|
|
||||||
## Supported Go versions
|
## Supported Go versions
|
||||||
|
|
||||||
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
|
Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with
|
||||||
|
older versions.
|
||||||
|
|
||||||
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
|
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
|
||||||
Therefore a Go version capable of understanding /vN suffixed imports is required:
|
Therefore a Go version capable of understanding /vN suffixed imports is required:
|
||||||
|
|
||||||
|
|
||||||
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
|
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
|
||||||
way of using Echo going forward.
|
way of using Echo going forward.
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ func hello(c echo.Context) error {
|
|||||||
Following list of middleware is maintained by Echo team.
|
Following list of middleware is maintained by Echo team.
|
||||||
|
|
||||||
| Repository | Description |
|
| Repository | Description |
|
||||||
|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
|
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
|
||||||
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
|
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
|
||||||
|
|
||||||
@ -101,6 +101,7 @@ of middlewares in this list.
|
|||||||
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
|
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
|
||||||
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
|
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
|
||||||
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
|
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
|
||||||
|
| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface.
|
||||||
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
|
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
|
||||||
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
|
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
|
||||||
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
|
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
|
||||||
|
16
binder.go
16
binder.go
@ -1236,7 +1236,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim
|
|||||||
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
|
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
|
||||||
//
|
//
|
||||||
// Note:
|
// Note:
|
||||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||||
func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
||||||
return b.unixTime(sourceParam, dest, false, time.Second)
|
return b.unixTime(sourceParam, dest, false, time.Second)
|
||||||
}
|
}
|
||||||
@ -1247,7 +1247,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder
|
|||||||
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
|
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
|
||||||
//
|
//
|
||||||
// Note:
|
// Note:
|
||||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||||
func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
||||||
return b.unixTime(sourceParam, dest, true, time.Second)
|
return b.unixTime(sourceParam, dest, true, time.Second)
|
||||||
}
|
}
|
||||||
@ -1257,7 +1257,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi
|
|||||||
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
||||||
//
|
//
|
||||||
// Note:
|
// Note:
|
||||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||||
func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
||||||
return b.unixTime(sourceParam, dest, false, time.Millisecond)
|
return b.unixTime(sourceParam, dest, false, time.Millisecond)
|
||||||
}
|
}
|
||||||
@ -1268,7 +1268,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB
|
|||||||
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
||||||
//
|
//
|
||||||
// Note:
|
// Note:
|
||||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||||
func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
||||||
return b.unixTime(sourceParam, dest, true, time.Millisecond)
|
return b.unixTime(sourceParam, dest, true, time.Millisecond)
|
||||||
}
|
}
|
||||||
@ -1280,8 +1280,8 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va
|
|||||||
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
|
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
|
||||||
//
|
//
|
||||||
// Note:
|
// Note:
|
||||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||||
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||||
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
||||||
return b.unixTime(sourceParam, dest, false, time.Nanosecond)
|
return b.unixTime(sourceParam, dest, false, time.Nanosecond)
|
||||||
}
|
}
|
||||||
@ -1294,8 +1294,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi
|
|||||||
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
|
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
|
||||||
//
|
//
|
||||||
// Note:
|
// Note:
|
||||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||||
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||||
func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
||||||
return b.unixTime(sourceParam, dest, true, time.Nanosecond)
|
return b.unixTime(sourceParam, dest, true, time.Nanosecond)
|
||||||
}
|
}
|
||||||
|
@ -113,8 +113,8 @@ type Context interface {
|
|||||||
// Set saves data in the context.
|
// Set saves data in the context.
|
||||||
Set(key string, val interface{})
|
Set(key string, val interface{})
|
||||||
|
|
||||||
// Bind binds the request body into provided type `i`. The default binder
|
// Bind binds path params, query params and the request body into provided type `i`. The default binder
|
||||||
// does it based on Content-Type header.
|
// binds body based on Content-Type header.
|
||||||
Bind(i interface{}) error
|
Bind(i interface{}) error
|
||||||
|
|
||||||
// Validate validates provided `i`. It is usually called after `Context#Bind()`.
|
// Validate validates provided `i`. It is usually called after `Context#Bind()`.
|
||||||
@ -536,8 +536,8 @@ func (c *DefaultContext) Set(key string, val interface{}) {
|
|||||||
c.store[key] = val
|
c.store[key] = val
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind binds the request body into provided type `i`. The default binder
|
// Bind binds path params, query params and the request body into provided type `i`. The default binder
|
||||||
// does it based on Content-Type header.
|
// binds body based on Content-Type header.
|
||||||
func (c *DefaultContext) Bind(i interface{}) error {
|
func (c *DefaultContext) Bind(i interface{}) error {
|
||||||
return c.echo.Binder.Bind(c, i)
|
return c.echo.Binder.Bind(c, i)
|
||||||
}
|
}
|
||||||
|
8
echo.go
8
echo.go
@ -40,6 +40,7 @@ package echo
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
stdContext "context"
|
stdContext "context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -336,12 +337,17 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
|
|||||||
// Issue #1426
|
// Issue #1426
|
||||||
code := he.Code
|
code := he.Code
|
||||||
message := he.Message
|
message := he.Message
|
||||||
if m, ok := he.Message.(string); ok {
|
switch m := he.Message.(type) {
|
||||||
|
case string:
|
||||||
if exposeError {
|
if exposeError {
|
||||||
message = Map{"message": m, "error": err.Error()}
|
message = Map{"message": m, "error": err.Error()}
|
||||||
} else {
|
} else {
|
||||||
message = Map{"message": m}
|
message = Map{"message": m}
|
||||||
}
|
}
|
||||||
|
case json.Marshaler:
|
||||||
|
// do nothing - this type knows how to format itself to JSON
|
||||||
|
case error:
|
||||||
|
message = Map{"message": m.Error()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send response
|
// Send response
|
||||||
|
25
echo_test.go
25
echo_test.go
@ -1198,6 +1198,18 @@ func request(method, path string, e *Echo) (int, string) {
|
|||||||
return rec.Code, rec.Body.String()
|
return rec.Code, rec.Body.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type customError struct {
|
||||||
|
s string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce *customError) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce *customError) Error() string {
|
||||||
|
return ce.s
|
||||||
|
}
|
||||||
|
|
||||||
func TestDefaultHTTPErrorHandler(t *testing.T) {
|
func TestDefaultHTTPErrorHandler(t *testing.T) {
|
||||||
var testCases = []struct {
|
var testCases = []struct {
|
||||||
name string
|
name string
|
||||||
@ -1263,6 +1275,19 @@ func TestDefaultHTTPErrorHandler(t *testing.T) {
|
|||||||
expectStatus: http.StatusInternalServerError,
|
expectStatus: http.StatusInternalServerError,
|
||||||
expectBody: ``,
|
expectBody: ``,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "ok, custom error implement MarshalJSON",
|
||||||
|
whenMethod: http.MethodGet,
|
||||||
|
whenError: NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}),
|
||||||
|
expectStatus: http.StatusBadRequest,
|
||||||
|
expectBody: "{\"x\":\"custom error msg\"}\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with Debug=false when httpError contains an error",
|
||||||
|
whenError: NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")),
|
||||||
|
expectStatus: http.StatusBadRequest,
|
||||||
|
expectBody: "{\"message\":\"error in httperror\"}\n",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
@ -754,3 +754,77 @@ func TestGroup_StaticPanic(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenCustom404 bool
|
||||||
|
whenURL string
|
||||||
|
expectBody interface{}
|
||||||
|
expectCode int
|
||||||
|
expectMiddlewareCalled bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, custom 404 handler is called with middleware",
|
||||||
|
givenCustom404: true,
|
||||||
|
whenURL: "/group/test3",
|
||||||
|
expectBody: "404 GET /group/*",
|
||||||
|
expectCode: http.StatusNotFound,
|
||||||
|
expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, default group 404 handler is not called with middleware",
|
||||||
|
givenCustom404: false,
|
||||||
|
whenURL: "/group/test3",
|
||||||
|
expectBody: "404 GET /*",
|
||||||
|
expectCode: http.StatusNotFound,
|
||||||
|
expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, (no slash) default group 404 handler is called with middleware",
|
||||||
|
givenCustom404: false,
|
||||||
|
whenURL: "/group",
|
||||||
|
expectBody: "404 GET /*",
|
||||||
|
expectCode: http.StatusNotFound,
|
||||||
|
expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
okHandler := func(c Context) error {
|
||||||
|
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
|
||||||
|
}
|
||||||
|
notFoundHandler := func(c Context) error {
|
||||||
|
return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path())
|
||||||
|
}
|
||||||
|
|
||||||
|
e := New()
|
||||||
|
e.GET("/test1", okHandler)
|
||||||
|
e.RouteNotFound("/*", notFoundHandler)
|
||||||
|
|
||||||
|
g := e.Group("/group")
|
||||||
|
g.GET("/test1", okHandler)
|
||||||
|
|
||||||
|
middlewareCalled := false
|
||||||
|
g.Use(func(next HandlerFunc) HandlerFunc {
|
||||||
|
return func(c Context) error {
|
||||||
|
middlewareCalled = true
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if tc.givenCustom404 {
|
||||||
|
g.RouteNotFound("/*", notFoundHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled)
|
||||||
|
assert.Equal(t, tc.expectCode, rec.Code)
|
||||||
|
assert.Equal(t, tc.expectBody, rec.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -20,7 +20,6 @@ type limitedReader struct {
|
|||||||
BodyLimitConfig
|
BodyLimitConfig
|
||||||
reader io.ReadCloser
|
reader io.ReadCloser
|
||||||
read int64
|
read int64
|
||||||
context echo.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BodyLimit returns a BodyLimit middleware.
|
// BodyLimit returns a BodyLimit middleware.
|
||||||
@ -65,7 +64,7 @@ func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
|
|
||||||
// Based on content read
|
// Based on content read
|
||||||
r := pool.Get().(*limitedReader)
|
r := pool.Get().(*limitedReader)
|
||||||
r.Reset(c, req.Body)
|
r.Reset(req.Body)
|
||||||
defer pool.Put(r)
|
defer pool.Put(r)
|
||||||
req.Body = r
|
req.Body = r
|
||||||
|
|
||||||
@ -87,8 +86,7 @@ func (r *limitedReader) Close() error {
|
|||||||
return r.reader.Close()
|
return r.reader.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
|
func (r *limitedReader) Reset(reader io.ReadCloser) {
|
||||||
r.reader = reader
|
r.reader = reader
|
||||||
r.context = context
|
|
||||||
r.read = 0
|
r.read = 0
|
||||||
}
|
}
|
||||||
|
@ -67,9 +67,6 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
func TestBodyLimitReader(t *testing.T) {
|
func TestBodyLimitReader(t *testing.T) {
|
||||||
hw := []byte("Hello, World!")
|
hw := []byte("Hello, World!")
|
||||||
e := echo.New()
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
|
|
||||||
config := BodyLimitConfig{
|
config := BodyLimitConfig{
|
||||||
Skipper: DefaultSkipper,
|
Skipper: DefaultSkipper,
|
||||||
@ -78,7 +75,6 @@ func TestBodyLimitReader(t *testing.T) {
|
|||||||
reader := &limitedReader{
|
reader := &limitedReader{
|
||||||
BodyLimitConfig: config,
|
BodyLimitConfig: config,
|
||||||
reader: io.NopCloser(bytes.NewReader(hw)),
|
reader: io.NopCloser(bytes.NewReader(hw)),
|
||||||
context: e.NewContext(req, rec),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// read all should return ErrStatusRequestEntityTooLarge
|
// read all should return ErrStatusRequestEntityTooLarge
|
||||||
@ -88,7 +84,7 @@ func TestBodyLimitReader(t *testing.T) {
|
|||||||
|
|
||||||
// reset reader and read two bytes must succeed
|
// reset reader and read two bytes must succeed
|
||||||
bt := make([]byte, 2)
|
bt := make([]byte, 2)
|
||||||
reader.Reset(e.NewContext(req, rec), io.NopCloser(bytes.NewReader(hw)))
|
reader.Reset(io.NopCloser(bytes.NewReader(hw)))
|
||||||
n, err := reader.Read(bt)
|
n, err := reader.Read(bt)
|
||||||
assert.Equal(t, 2, n)
|
assert.Equal(t, 2, n)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
@ -25,12 +26,30 @@ type GzipConfig struct {
|
|||||||
// Gzip compression level.
|
// Gzip compression level.
|
||||||
// Optional. Default value -1.
|
// Optional. Default value -1.
|
||||||
Level int
|
Level int
|
||||||
|
|
||||||
|
// Length threshold before gzip compression is applied.
|
||||||
|
// Optional. Default value 0.
|
||||||
|
//
|
||||||
|
// Most of the time you will not need to change the default. Compressing
|
||||||
|
// a short response might increase the transmitted data because of the
|
||||||
|
// gzip format overhead. Compressing the response will also consume CPU
|
||||||
|
// and time on the server and the client (for decompressing). Depending on
|
||||||
|
// your use case such a threshold might be useful.
|
||||||
|
//
|
||||||
|
// See also:
|
||||||
|
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
|
||||||
|
MinLength int
|
||||||
}
|
}
|
||||||
|
|
||||||
type gzipResponseWriter struct {
|
type gzipResponseWriter struct {
|
||||||
io.Writer
|
io.Writer
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
|
wroteHeader bool
|
||||||
wroteBody bool
|
wroteBody bool
|
||||||
|
minLength int
|
||||||
|
minLengthExceeded bool
|
||||||
|
buffer *bytes.Buffer
|
||||||
|
code int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
|
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||||
@ -54,8 +73,12 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
if config.Level == 0 {
|
if config.Level == 0 {
|
||||||
config.Level = -1
|
config.Level = -1
|
||||||
}
|
}
|
||||||
|
if config.MinLength < 0 {
|
||||||
|
config.MinLength = 0
|
||||||
|
}
|
||||||
|
|
||||||
pool := gzipCompressPool(config)
|
pool := gzipCompressPool(config)
|
||||||
|
bpool := bufferPool()
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
@ -66,7 +89,6 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
res := c.Response()
|
res := c.Response()
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
|
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
|
||||||
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
|
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
|
||||||
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
|
||||||
i := pool.Get()
|
i := pool.Get()
|
||||||
w, ok := i.(*gzip.Writer)
|
w, ok := i.(*gzip.Writer)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -74,19 +96,37 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
}
|
}
|
||||||
rw := res.Writer
|
rw := res.Writer
|
||||||
w.Reset(rw)
|
w.Reset(rw)
|
||||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
|
buf := bpool.Get().(*bytes.Buffer)
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
|
||||||
defer func() {
|
defer func() {
|
||||||
|
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
|
||||||
|
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
|
||||||
|
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
|
||||||
if !grw.wroteBody {
|
if !grw.wroteBody {
|
||||||
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
|
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
|
||||||
res.Header().Del(echo.HeaderContentEncoding)
|
res.Header().Del(echo.HeaderContentEncoding)
|
||||||
}
|
}
|
||||||
|
if grw.wroteHeader {
|
||||||
|
rw.WriteHeader(grw.code)
|
||||||
|
}
|
||||||
// We have to reset response to it's pristine state when
|
// We have to reset response to it's pristine state when
|
||||||
// nothing is written to body or error is returned.
|
// nothing is written to body or error is returned.
|
||||||
// See issue #424, #407.
|
// See issue #424, #407.
|
||||||
res.Writer = rw
|
res.Writer = rw
|
||||||
w.Reset(io.Discard)
|
w.Reset(io.Discard)
|
||||||
|
} else if !grw.minLengthExceeded {
|
||||||
|
// Write uncompressed response
|
||||||
|
res.Writer = rw
|
||||||
|
if grw.wroteHeader {
|
||||||
|
grw.ResponseWriter.WriteHeader(grw.code)
|
||||||
|
}
|
||||||
|
grw.buffer.WriteTo(rw)
|
||||||
|
w.Reset(io.Discard)
|
||||||
}
|
}
|
||||||
w.Close()
|
w.Close()
|
||||||
|
bpool.Put(buf)
|
||||||
pool.Put(w)
|
pool.Put(w)
|
||||||
}()
|
}()
|
||||||
res.Writer = grw
|
res.Writer = grw
|
||||||
@ -98,7 +138,11 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
|
|
||||||
func (w *gzipResponseWriter) WriteHeader(code int) {
|
func (w *gzipResponseWriter) WriteHeader(code int) {
|
||||||
w.Header().Del(echo.HeaderContentLength) // Issue #444
|
w.Header().Del(echo.HeaderContentLength) // Issue #444
|
||||||
w.ResponseWriter.WriteHeader(code)
|
|
||||||
|
w.wroteHeader = true
|
||||||
|
|
||||||
|
// Delay writing of the header until we know if we'll actually compress the response
|
||||||
|
w.code = code
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||||
@ -106,10 +150,40 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
|||||||
w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
|
w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
|
||||||
}
|
}
|
||||||
w.wroteBody = true
|
w.wroteBody = true
|
||||||
|
|
||||||
|
if !w.minLengthExceeded {
|
||||||
|
n, err := w.buffer.Write(b)
|
||||||
|
|
||||||
|
if w.buffer.Len() >= w.minLength {
|
||||||
|
w.minLengthExceeded = true
|
||||||
|
|
||||||
|
// The minimum length is exceeded, add Content-Encoding header and write the header
|
||||||
|
w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
||||||
|
if w.wroteHeader {
|
||||||
|
w.ResponseWriter.WriteHeader(w.code)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.Writer.Write(w.buffer.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
return w.Writer.Write(b)
|
return w.Writer.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *gzipResponseWriter) Flush() {
|
func (w *gzipResponseWriter) Flush() {
|
||||||
|
if !w.minLengthExceeded {
|
||||||
|
// Enforce compression because we will not know how much more data will come
|
||||||
|
w.minLengthExceeded = true
|
||||||
|
w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
||||||
|
if w.wroteHeader {
|
||||||
|
w.ResponseWriter.WriteHeader(w.code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Writer.Write(w.buffer.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
w.Writer.(*gzip.Writer).Flush()
|
w.Writer.(*gzip.Writer).Flush()
|
||||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@ -138,3 +212,12 @@ func gzipCompressPool(config GzipConfig) sync.Pool {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func bufferPool() sync.Pool {
|
||||||
|
return sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
return b
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@ -211,6 +212,137 @@ func TestGzipWithStatic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGzipWithMinLength(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
// Minimal response length
|
||||||
|
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
c.Response().Write([]byte("foobarfoobar"))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||||
|
r, err := gzip.NewReader(rec.Body)
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
defer r.Close()
|
||||||
|
buf.ReadFrom(r)
|
||||||
|
assert.Equal(t, "foobarfoobar", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGzipWithMinLengthTooShort(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
// Minimal response length
|
||||||
|
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
c.Response().Write([]byte("test"))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
|
||||||
|
assert.Contains(t, rec.Body.String(), "test")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGzipWithResponseWithoutBody(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
e.Use(Gzip())
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
return c.Redirect(http.StatusMovedPermanently, "http://localhost")
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
|
||||||
|
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGzipWithMinLengthChunked(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
// Gzip chunked
|
||||||
|
chunkBuf := make([]byte, 5)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
var r *gzip.Reader = nil
|
||||||
|
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
next := func(c echo.Context) error {
|
||||||
|
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Response().Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// Write and flush the first part of the data
|
||||||
|
c.Response().Write([]byte("test\n"))
|
||||||
|
c.Response().Flush()
|
||||||
|
|
||||||
|
// Read the first part of the data
|
||||||
|
assert.True(t, rec.Flushed)
|
||||||
|
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
r, err = gzip.NewReader(rec.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = io.ReadFull(r, chunkBuf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "test\n", string(chunkBuf))
|
||||||
|
|
||||||
|
// Write and flush the second part of the data
|
||||||
|
c.Response().Write([]byte("test\n"))
|
||||||
|
c.Response().Flush()
|
||||||
|
|
||||||
|
_, err = io.ReadFull(r, chunkBuf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "test\n", string(chunkBuf))
|
||||||
|
|
||||||
|
// Write the final part of the data and return
|
||||||
|
c.Response().Write([]byte("test"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, r)
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
|
buf.ReadFrom(r)
|
||||||
|
assert.Equal(t, "test", buf.String())
|
||||||
|
|
||||||
|
r.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGzipWithMinLengthNoContent(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error {
|
||||||
|
return c.NoContent(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
if assert.NoError(t, h(c)) {
|
||||||
|
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||||
|
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
|
||||||
|
assert.Equal(t, 0, len(rec.Body.Bytes()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGzip(b *testing.B) {
|
func BenchmarkGzip(b *testing.B) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
|
|
||||||
|
@ -151,8 +151,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
allowOriginPatterns := []string{}
|
allowOriginPatterns := []string{}
|
||||||
for _, origin := range config.AllowOrigins {
|
for _, origin := range config.AllowOrigins {
|
||||||
pattern := regexp.QuoteMeta(origin)
|
pattern := regexp.QuoteMeta(origin)
|
||||||
pattern = strings.Replace(pattern, "\\*", ".*", -1)
|
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
|
||||||
pattern = strings.Replace(pattern, "\\?", ".", -1)
|
pattern = strings.ReplaceAll(pattern, "\\?", ".")
|
||||||
pattern = "^" + pattern + "$"
|
pattern = "^" + pattern + "$"
|
||||||
allowOriginPatterns = append(allowOriginPatterns, pattern)
|
allowOriginPatterns = append(allowOriginPatterns, pattern)
|
||||||
}
|
}
|
||||||
|
@ -35,9 +35,9 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
|
|||||||
rulesRegex := map[*regexp.Regexp]string{}
|
rulesRegex := map[*regexp.Regexp]string{}
|
||||||
for k, v := range rewrite {
|
for k, v := range rewrite {
|
||||||
k = regexp.QuoteMeta(k)
|
k = regexp.QuoteMeta(k)
|
||||||
k = strings.Replace(k, `\*`, "(.*?)", -1)
|
k = strings.ReplaceAll(k, `\*`, "(.*?)")
|
||||||
if strings.HasPrefix(k, `\^`) {
|
if strings.HasPrefix(k, `\^`) {
|
||||||
k = strings.Replace(k, `\^`, "^", -1)
|
k = strings.ReplaceAll(k, `\^`, "^")
|
||||||
}
|
}
|
||||||
k = k + "$"
|
k = k + "$"
|
||||||
rulesRegex[regexp.MustCompile(k)] = v
|
rulesRegex[regexp.MustCompile(k)] = v
|
||||||
|
@ -29,6 +29,33 @@ type ProxyConfig struct {
|
|||||||
// Required.
|
// Required.
|
||||||
Balancer ProxyBalancer
|
Balancer ProxyBalancer
|
||||||
|
|
||||||
|
// RetryCount defines the number of times a failed proxied request should be retried
|
||||||
|
// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
|
||||||
|
RetryCount int
|
||||||
|
|
||||||
|
// RetryFilter defines a function used to determine if a failed request to a
|
||||||
|
// ProxyTarget should be retried. The RetryFilter will only be called when the number
|
||||||
|
// of previous retries is less than RetryCount. If the function returns true, the
|
||||||
|
// request will be retried. The provided error indicates the reason for the request
|
||||||
|
// failure. When the ProxyTarget is unavailable, the error will be an instance of
|
||||||
|
// echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
|
||||||
|
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
|
||||||
|
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
|
||||||
|
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
|
||||||
|
// only called when the request to the target fails, or an internal error in the Proxy
|
||||||
|
// middleware has occurred. Successful requests that return a non-200 response code cannot
|
||||||
|
// be retried.
|
||||||
|
RetryFilter func(c echo.Context, e error) bool
|
||||||
|
|
||||||
|
// ErrorHandler defines a function which can be used to return custom errors from
|
||||||
|
// the Proxy middleware. ErrorHandler is only invoked when there has been
|
||||||
|
// either an internal error in the Proxy middleware or the ProxyTarget is
|
||||||
|
// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
|
||||||
|
// when a ProxyTarget returns a non-200 response. In these cases, the response
|
||||||
|
// is already written so errors cannot be modified. ErrorHandler is only
|
||||||
|
// invoked after all retry attempts have been exhausted.
|
||||||
|
ErrorHandler func(c echo.Context, err error) error
|
||||||
|
|
||||||
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
|
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
|
||||||
// retrieved by index e.g. $1, $2 and so on.
|
// retrieved by index e.g. $1, $2 and so on.
|
||||||
// Examples:
|
// Examples:
|
||||||
@ -99,14 +126,14 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
in, _, err := c.Response().Hijack()
|
in, _, err := c.Response().Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
|
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer in.Close()
|
defer in.Close()
|
||||||
|
|
||||||
out, err := net.Dial("tcp", t.URL.Host)
|
out, err := net.Dial("tcp", t.URL.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
|
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer out.Close()
|
defer out.Close()
|
||||||
@ -114,7 +141,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
|||||||
// Write header
|
// Write header
|
||||||
err = r.Write(out)
|
err = r.Write(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
|
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,7 +155,7 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
|||||||
go cp(in, out)
|
go cp(in, out)
|
||||||
err = <-errCh
|
err = <-errCh
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
|
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -192,7 +219,12 @@ func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
|||||||
return b.targets[b.random.Intn(len(b.targets))], nil
|
return b.targets[b.random.Intn(len(b.targets))], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next returns an upstream target using round-robin technique.
|
// Next returns an upstream target using round-robin technique. In the case
|
||||||
|
// where a previously failed request is being retried, the round-robin
|
||||||
|
// balancer will attempt to use the next target relative to the original
|
||||||
|
// request. If the list of targets held by the balancer is modified while a
|
||||||
|
// failed request is being retried, it is possible that the balancer will
|
||||||
|
// return the original failed target.
|
||||||
//
|
//
|
||||||
// Note: `nil` is returned in case upstream target list is empty.
|
// Note: `nil` is returned in case upstream target list is empty.
|
||||||
func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
||||||
@ -203,13 +235,28 @@ func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
|||||||
} else if len(b.targets) == 1 {
|
} else if len(b.targets) == 1 {
|
||||||
return b.targets[0], nil
|
return b.targets[0], nil
|
||||||
}
|
}
|
||||||
// reset the index if out of bounds
|
|
||||||
|
var i int
|
||||||
|
const lastIdxKey = "_round_robin_last_index"
|
||||||
|
// This request is a retry, start from the index of the previous
|
||||||
|
// target to ensure we don't attempt to retry the request with
|
||||||
|
// the same failed target
|
||||||
|
if c.Get(lastIdxKey) != nil {
|
||||||
|
i = c.Get(lastIdxKey).(int)
|
||||||
|
i++
|
||||||
|
if i >= len(b.targets) {
|
||||||
|
i = 0
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// This is a first time request, use the global index
|
||||||
if b.i >= len(b.targets) {
|
if b.i >= len(b.targets) {
|
||||||
b.i = 0
|
b.i = 0
|
||||||
}
|
}
|
||||||
t := b.targets[b.i]
|
i = b.i
|
||||||
b.i++
|
b.i++
|
||||||
return t, nil
|
}
|
||||||
|
c.Set(lastIdxKey, i)
|
||||||
|
return b.targets[i], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy returns a Proxy middleware.
|
// Proxy returns a Proxy middleware.
|
||||||
@ -239,6 +286,19 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
if config.Balancer == nil {
|
if config.Balancer == nil {
|
||||||
return nil, errors.New("echo proxy middleware requires balancer")
|
return nil, errors.New("echo proxy middleware requires balancer")
|
||||||
}
|
}
|
||||||
|
if config.RetryFilter == nil {
|
||||||
|
config.RetryFilter = func(c echo.Context, e error) bool {
|
||||||
|
if httpErr, ok := e.(*echo.HTTPError); ok {
|
||||||
|
return httpErr.Code == http.StatusBadGateway
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if config.ErrorHandler == nil {
|
||||||
|
config.ErrorHandler = func(c echo.Context, err error) error {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if config.Rewrite != nil {
|
if config.Rewrite != nil {
|
||||||
if config.RegexRewrite == nil {
|
if config.RegexRewrite == nil {
|
||||||
@ -257,14 +317,8 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
|
|
||||||
req := c.Request()
|
req := c.Request()
|
||||||
res := c.Response()
|
res := c.Response()
|
||||||
tgt, err := config.Balancer.Next(c)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.Set(config.ContextKey, tgt)
|
|
||||||
|
|
||||||
if err := rewriteURL(config.RegexRewrite, req); err != nil {
|
if err := rewriteURL(config.RegexRewrite, req); err != nil {
|
||||||
return err
|
return config.ErrorHandler(c, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fix header
|
// Fix header
|
||||||
@ -280,6 +334,22 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
|
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
retries := config.RetryCount
|
||||||
|
for {
|
||||||
|
tgt, err := config.Balancer.Next(c)
|
||||||
|
if err != nil {
|
||||||
|
return config.ErrorHandler(c, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(config.ContextKey, tgt)
|
||||||
|
|
||||||
|
//If retrying a failed request, clear any previous errors from
|
||||||
|
//context here so that balancers have the option to check for
|
||||||
|
//errors that occurred using previous target
|
||||||
|
if retries < config.RetryCount {
|
||||||
|
c.Set("_error", nil)
|
||||||
|
}
|
||||||
|
|
||||||
// Proxy
|
// Proxy
|
||||||
switch {
|
switch {
|
||||||
case c.IsWebSocket():
|
case c.IsWebSocket():
|
||||||
@ -288,11 +358,19 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
default:
|
default:
|
||||||
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
|
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
|
||||||
}
|
}
|
||||||
if e, ok := c.Get("_error").(error); ok {
|
|
||||||
err = e
|
err, hasError := c.Get("_error").(error)
|
||||||
|
if !hasError {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
retry := retries > 0 && config.RetryFilter(c, err)
|
||||||
|
if !retry {
|
||||||
|
return config.ErrorHandler(c, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retries--
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -426,9 +427,14 @@ func TestFailNextTarget(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRandomBalancerWithNoTargets(t *testing.T) {
|
func TestRandomBalancerWithNoTargets(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
// Assert balancer with empty targets does return `nil` on `Next()`
|
// Assert balancer with empty targets does return `nil` on `Next()`
|
||||||
rb := NewRandomBalancer(nil)
|
rb := NewRandomBalancer(nil)
|
||||||
target, err := rb.Next(nil)
|
target, err := rb.Next(c)
|
||||||
assert.Nil(t, target)
|
assert.Nil(t, target)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
@ -436,7 +442,327 @@ func TestRandomBalancerWithNoTargets(t *testing.T) {
|
|||||||
func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
|
func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
|
||||||
// Assert balancer with empty targets does return `nil` on `Next()`
|
// Assert balancer with empty targets does return `nil` on `Next()`
|
||||||
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
|
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
|
||||||
target, err := rrb.Next(nil)
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
target, err := rrb.Next(c)
|
||||||
assert.Nil(t, target)
|
assert.Nil(t, target)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyRetries(t *testing.T) {
|
||||||
|
newServer := func(res int) (*url.URL, *httptest.Server) {
|
||||||
|
server := httptest.NewServer(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(res)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
targetURL, _ := url.Parse(server.URL)
|
||||||
|
return targetURL, server
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL, server := newServer(http.StatusOK)
|
||||||
|
defer server.Close()
|
||||||
|
goodTarget := &ProxyTarget{
|
||||||
|
Name: "Good",
|
||||||
|
URL: targetURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL, server = newServer(http.StatusBadRequest)
|
||||||
|
defer server.Close()
|
||||||
|
goodTargetWith40X := &ProxyTarget{
|
||||||
|
Name: "Good with 40X",
|
||||||
|
URL: targetURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL, _ = url.Parse("http://127.0.0.1:27121")
|
||||||
|
badTarget := &ProxyTarget{
|
||||||
|
Name: "Bad",
|
||||||
|
URL: targetURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
alwaysRetryFilter := func(c echo.Context, e error) bool { return true }
|
||||||
|
neverRetryFilter := func(c echo.Context, e error) bool { return false }
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
retryCount int
|
||||||
|
retryFilters []func(c echo.Context, e error) bool
|
||||||
|
targets []*ProxyTarget
|
||||||
|
expectedResponse int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retry count 0 does not attempt retry on fail",
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
badTarget,
|
||||||
|
goodTarget,
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusBadGateway,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retry count 1 does not attempt retry on success",
|
||||||
|
retryCount: 1,
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
goodTarget,
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retry count 1 does retry on handler return true",
|
||||||
|
retryCount: 1,
|
||||||
|
retryFilters: []func(c echo.Context, e error) bool{
|
||||||
|
alwaysRetryFilter,
|
||||||
|
},
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
badTarget,
|
||||||
|
goodTarget,
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retry count 1 does not retry on handler return false",
|
||||||
|
retryCount: 1,
|
||||||
|
retryFilters: []func(c echo.Context, e error) bool{
|
||||||
|
neverRetryFilter,
|
||||||
|
},
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
badTarget,
|
||||||
|
goodTarget,
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusBadGateway,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retry count 2 returns error when no more retries left",
|
||||||
|
retryCount: 2,
|
||||||
|
retryFilters: []func(c echo.Context, e error) bool{
|
||||||
|
alwaysRetryFilter,
|
||||||
|
alwaysRetryFilter,
|
||||||
|
},
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
badTarget,
|
||||||
|
badTarget,
|
||||||
|
badTarget,
|
||||||
|
goodTarget, //Should never be reached as only 2 retries
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusBadGateway,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retry count 2 returns error when retries left but handler returns false",
|
||||||
|
retryCount: 3,
|
||||||
|
retryFilters: []func(c echo.Context, e error) bool{
|
||||||
|
alwaysRetryFilter,
|
||||||
|
alwaysRetryFilter,
|
||||||
|
neverRetryFilter,
|
||||||
|
},
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
badTarget,
|
||||||
|
badTarget,
|
||||||
|
badTarget,
|
||||||
|
goodTarget, //Should never be reached as retry handler returns false on 2nd check
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusBadGateway,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retry count 3 succeeds",
|
||||||
|
retryCount: 3,
|
||||||
|
retryFilters: []func(c echo.Context, e error) bool{
|
||||||
|
alwaysRetryFilter,
|
||||||
|
alwaysRetryFilter,
|
||||||
|
alwaysRetryFilter,
|
||||||
|
},
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
badTarget,
|
||||||
|
badTarget,
|
||||||
|
badTarget,
|
||||||
|
goodTarget,
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "40x responses are not retried",
|
||||||
|
retryCount: 1,
|
||||||
|
targets: []*ProxyTarget{
|
||||||
|
goodTargetWith40X,
|
||||||
|
goodTarget,
|
||||||
|
},
|
||||||
|
expectedResponse: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
retryFilterCall := 0
|
||||||
|
retryFilter := func(c echo.Context, e error) bool {
|
||||||
|
if len(tc.retryFilters) == 0 {
|
||||||
|
assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
|
||||||
|
}
|
||||||
|
|
||||||
|
retryFilterCall++
|
||||||
|
|
||||||
|
nextRetryFilter := tc.retryFilters[0]
|
||||||
|
tc.retryFilters = tc.retryFilters[1:]
|
||||||
|
|
||||||
|
return nextRetryFilter(c, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
e.Use(ProxyWithConfig(
|
||||||
|
ProxyConfig{
|
||||||
|
Balancer: NewRoundRobinBalancer(tc.targets),
|
||||||
|
RetryCount: tc.retryCount,
|
||||||
|
RetryFilter: retryFilter,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectedResponse, rec.Code)
|
||||||
|
if len(tc.retryFilters) > 0 {
|
||||||
|
assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRetryWithBackendTimeout(t *testing.T) {
|
||||||
|
|
||||||
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
transport.ResponseHeaderTimeout = time.Millisecond * 500
|
||||||
|
|
||||||
|
timeoutBackend := httptest.NewServer(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
w.WriteHeader(404)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
defer timeoutBackend.Close()
|
||||||
|
|
||||||
|
timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
|
||||||
|
goodBackend := httptest.NewServer(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
defer goodBackend.Close()
|
||||||
|
|
||||||
|
goodTargetURL, _ := url.Parse(goodBackend.URL)
|
||||||
|
e := echo.New()
|
||||||
|
e.Use(ProxyWithConfig(
|
||||||
|
ProxyConfig{
|
||||||
|
Transport: transport,
|
||||||
|
Balancer: NewRoundRobinBalancer([]*ProxyTarget{
|
||||||
|
{
|
||||||
|
Name: "Timeout",
|
||||||
|
URL: timeoutTargetURL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Good",
|
||||||
|
URL: goodTargetURL,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
RetryCount: 1,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, 200, rec.Code)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyErrorHandler(t *testing.T) {
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
goodURL, _ := url.Parse(server.URL)
|
||||||
|
defer server.Close()
|
||||||
|
goodTarget := &ProxyTarget{
|
||||||
|
Name: "Good",
|
||||||
|
URL: goodURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
badURL, _ := url.Parse("http://127.0.0.1:27121")
|
||||||
|
badTarget := &ProxyTarget{
|
||||||
|
Name: "Bad",
|
||||||
|
URL: badURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
transformedError := errors.New("a new error")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
target *ProxyTarget
|
||||||
|
errorHandler func(c echo.Context, e error) error
|
||||||
|
expectFinalError func(t *testing.T, err error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Error handler not invoked when request success",
|
||||||
|
target: goodTarget,
|
||||||
|
errorHandler: func(c echo.Context, e error) error {
|
||||||
|
assert.FailNow(t, "error handler should not be invoked")
|
||||||
|
return e
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error handler invoked when request fails",
|
||||||
|
target: badTarget,
|
||||||
|
errorHandler: func(c echo.Context, e error) error {
|
||||||
|
httpErr, ok := e.(*echo.HTTPError)
|
||||||
|
assert.True(t, ok, "expected http error to be passed to handler")
|
||||||
|
assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
|
||||||
|
return transformedError
|
||||||
|
},
|
||||||
|
expectFinalError: func(t *testing.T, err error) {
|
||||||
|
assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
e.Use(ProxyWithConfig(
|
||||||
|
ProxyConfig{
|
||||||
|
Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
|
||||||
|
ErrorHandler: tc.errorHandler,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
errorHandlerCalled := false
|
||||||
|
dheh := echo.DefaultHTTPErrorHandler(false)
|
||||||
|
e.HTTPErrorHandler = func(c echo.Context, err error) {
|
||||||
|
errorHandlerCalled = true
|
||||||
|
tc.expectFinalError(t, err)
|
||||||
|
dheh(c, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if !errorHandlerCalled && tc.expectFinalError != nil {
|
||||||
|
t.Fatalf("error handler was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -156,6 +156,8 @@ type RateLimiterMemoryStore struct {
|
|||||||
burst int
|
burst int
|
||||||
expiresIn time.Duration
|
expiresIn time.Duration
|
||||||
lastCleanup time.Time
|
lastCleanup time.Time
|
||||||
|
|
||||||
|
timeNow func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visitor signifies a unique user's limiter details
|
// Visitor signifies a unique user's limiter details
|
||||||
@ -215,7 +217,8 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s
|
|||||||
store.burst = int(config.Rate)
|
store.burst = int(config.Rate)
|
||||||
}
|
}
|
||||||
store.visitors = make(map[string]*Visitor)
|
store.visitors = make(map[string]*Visitor)
|
||||||
store.lastCleanup = now()
|
store.timeNow = time.Now
|
||||||
|
store.lastCleanup = store.timeNow()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,12 +243,13 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
|
|||||||
limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst)
|
limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst)
|
||||||
store.visitors[identifier] = limiter
|
store.visitors[identifier] = limiter
|
||||||
}
|
}
|
||||||
limiter.lastSeen = now()
|
now := store.timeNow()
|
||||||
if now().Sub(store.lastCleanup) > store.expiresIn {
|
limiter.lastSeen = now
|
||||||
|
if now.Sub(store.lastCleanup) > store.expiresIn {
|
||||||
store.cleanupStaleVisitors()
|
store.cleanupStaleVisitors()
|
||||||
}
|
}
|
||||||
store.mutex.Unlock()
|
store.mutex.Unlock()
|
||||||
return limiter.AllowN(now(), 1), nil
|
return limiter.AllowN(store.timeNow(), 1), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -254,14 +258,9 @@ of users who haven't visited again after the configured expiry time has elapsed
|
|||||||
*/
|
*/
|
||||||
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
|
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
|
||||||
for id, visitor := range store.visitors {
|
for id, visitor := range store.visitors {
|
||||||
if now().Sub(visitor.lastSeen) > store.expiresIn {
|
if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn {
|
||||||
delete(store.visitors, id)
|
delete(store.visitors, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
store.lastCleanup = now()
|
store.lastCleanup = store.timeNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
actual time method which is mocked in test file
|
|
||||||
*/
|
|
||||||
var now = time.Now
|
|
||||||
|
@ -2,7 +2,6 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -361,7 +360,7 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) {
|
|||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond)
|
t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond)
|
||||||
now = func() time.Time {
|
inMemoryStore.timeNow = func() time.Time {
|
||||||
return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond)
|
return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond)
|
||||||
}
|
}
|
||||||
allowed, _ := inMemoryStore.Allow(tc.id)
|
allowed, _ := inMemoryStore.Allow(tc.id)
|
||||||
@ -371,24 +370,22 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) {
|
|||||||
|
|
||||||
func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
|
func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
|
||||||
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
now = time.Now
|
|
||||||
fmt.Println(now())
|
|
||||||
inMemoryStore.visitors = map[string]*Visitor{
|
inMemoryStore.visitors = map[string]*Visitor{
|
||||||
"A": {
|
"A": {
|
||||||
Limiter: rate.NewLimiter(1, 3),
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
lastSeen: now(),
|
lastSeen: time.Now(),
|
||||||
},
|
},
|
||||||
"B": {
|
"B": {
|
||||||
Limiter: rate.NewLimiter(1, 3),
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
lastSeen: now().Add(-1 * time.Minute),
|
lastSeen: time.Now().Add(-1 * time.Minute),
|
||||||
},
|
},
|
||||||
"C": {
|
"C": {
|
||||||
Limiter: rate.NewLimiter(1, 3),
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
lastSeen: now().Add(-5 * time.Minute),
|
lastSeen: time.Now().Add(-5 * time.Minute),
|
||||||
},
|
},
|
||||||
"D": {
|
"D": {
|
||||||
Limiter: rate.NewLimiter(1, 3),
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
lastSeen: now().Add(-10 * time.Minute),
|
lastSeen: time.Now().Add(-10 * time.Minute),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,7 +224,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
if config.Skipper == nil {
|
if config.Skipper == nil {
|
||||||
config.Skipper = DefaultSkipper
|
config.Skipper = DefaultSkipper
|
||||||
}
|
}
|
||||||
now = time.Now
|
now := time.Now
|
||||||
if config.timeNow != nil {
|
if config.timeNow != nil {
|
||||||
now = config.timeNow
|
now = config.timeNow
|
||||||
}
|
}
|
||||||
@ -256,7 +256,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
config.BeforeNextFunc(c)
|
config.BeforeNextFunc(c)
|
||||||
}
|
}
|
||||||
err := next(c)
|
err := next(c)
|
||||||
if config.HandleError {
|
if err != nil && config.HandleError {
|
||||||
// When global error handler writes the error to the client the Response gets "committed". This state can be
|
// When global error handler writes the error to the client the Response gets "committed". This state can be
|
||||||
// checked with `c.Response().Committed` field.
|
// checked with `c.Response().Committed` field.
|
||||||
c.Error(err)
|
c.Error(err)
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"fmt"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -77,17 +79,38 @@ func createRandomStringGenerator(length uint8) func() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func randomString(length uint8) string {
|
// https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls
|
||||||
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
var randomReaderPool = sync.Pool{New: func() interface{} {
|
||||||
|
return bufio.NewReader(rand.Reader)
|
||||||
|
}}
|
||||||
|
|
||||||
bytes := make([]byte, length)
|
const randomStringCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||||
_, err := rand.Read(bytes)
|
const randomStringCharsetLen = 52 // len(randomStringCharset)
|
||||||
|
const randomStringMaxByte = 255 - (256 % randomStringCharsetLen)
|
||||||
|
|
||||||
|
func randomString(length uint8) string {
|
||||||
|
reader := randomReaderPool.Get().(*bufio.Reader)
|
||||||
|
defer randomReaderPool.Put(reader)
|
||||||
|
|
||||||
|
b := make([]byte, length)
|
||||||
|
r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times
|
||||||
|
var i uint8 = 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, err := io.ReadFull(reader, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// we are out of random. let the request fail
|
panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)")
|
||||||
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
|
}
|
||||||
|
for _, rb := range r {
|
||||||
|
if rb > randomStringMaxByte {
|
||||||
|
// Skip this number to avoid bias.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b[i] = randomStringCharset[rb%randomStringCharsetLen]
|
||||||
|
i++
|
||||||
|
if i == length {
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for i, b := range bytes {
|
|
||||||
bytes[i] = charset[b%byte(len(charset))]
|
|
||||||
}
|
}
|
||||||
return string(bytes)
|
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -129,3 +130,31 @@ func TestRandomString(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRandomStringBias(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
const slen = 33
|
||||||
|
const loop = 100000
|
||||||
|
|
||||||
|
counts := make(map[rune]int)
|
||||||
|
var count int64
|
||||||
|
|
||||||
|
for i := 0; i < loop; i++ {
|
||||||
|
s := randomString(slen)
|
||||||
|
require.Equal(t, slen, len(s))
|
||||||
|
for _, b := range s {
|
||||||
|
counts[b]++
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, randomStringCharsetLen, len(counts))
|
||||||
|
|
||||||
|
avg := float64(count) / float64(len(counts))
|
||||||
|
for k, n := range counts {
|
||||||
|
diff := float64(n) / avg
|
||||||
|
if diff < 0.95 || diff > 1.05 {
|
||||||
|
t.Errorf("Bias on '%c': expected average %f, got %d", k, avg, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -95,6 +95,13 @@ func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|||||||
return r.Writer.(http.Hijacker).Hijack()
|
return r.Writer.(http.Hijacker).Hijack()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the original http.ResponseWriter.
|
||||||
|
// ResponseController can be used to access the original http.ResponseWriter.
|
||||||
|
// See [https://go.dev/blog/go1.20]
|
||||||
|
func (r *Response) Unwrap() http.ResponseWriter {
|
||||||
|
return r.Writer
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Response) reset(w http.ResponseWriter) {
|
func (r *Response) reset(w http.ResponseWriter) {
|
||||||
r.beforeFuncs = nil
|
r.beforeFuncs = nil
|
||||||
r.afterFuncs = nil
|
r.afterFuncs = nil
|
||||||
|
@ -72,3 +72,11 @@ func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponse_Unwrap(t *testing.T) {
|
||||||
|
e := New()
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
res := &Response{echo: e, Writer: rec}
|
||||||
|
|
||||||
|
assert.Equal(t, rec, res.Unwrap())
|
||||||
|
}
|
||||||
|
7
route.go
7
route.go
@ -81,7 +81,12 @@ func (r routeInfo) Reverse(params ...interface{}) string {
|
|||||||
ln := len(params)
|
ln := len(params)
|
||||||
n := 0
|
n := 0
|
||||||
for i, l := 0, len(r.path); i < l; i++ {
|
for i, l := 0, len(r.path); i < l; i++ {
|
||||||
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
|
hasBackslash := r.path[i] == '\\'
|
||||||
|
if hasBackslash && i+1 < l && r.path[i+1] == ':' {
|
||||||
|
i++ // backslash before colon escapes that colon. in that case skip backslash
|
||||||
|
}
|
||||||
|
if n < ln && (r.path[i] == anyLabel || (!hasBackslash && r.path[i] == paramLabel)) {
|
||||||
|
// in case of `*` wildcard or `:` (unescaped colon) param we replace everything till next slash or end of path
|
||||||
for ; i < l && r.path[i] != '/'; i++ {
|
for ; i < l && r.path[i] != '/'; i++ {
|
||||||
}
|
}
|
||||||
uri.WriteString(fmt.Sprintf("%v", params[n]))
|
uri.WriteString(fmt.Sprintf("%v", params[n]))
|
||||||
|
@ -421,3 +421,101 @@ func TestRoutes_FilterByName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteInfo_Reverse(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenParams []string
|
||||||
|
givenPath string
|
||||||
|
whenParams []interface{}
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok,static with no params",
|
||||||
|
givenPath: "/static",
|
||||||
|
expect: "/static",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok,static with non existent param",
|
||||||
|
givenParams: []string{"missing param"},
|
||||||
|
givenPath: "/static",
|
||||||
|
whenParams: []interface{}{"missing param"},
|
||||||
|
expect: "/static",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, wildcard with no params",
|
||||||
|
givenPath: "/static/*",
|
||||||
|
expect: "/static/*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, wildcard with params",
|
||||||
|
givenParams: []string{"foo.txt"},
|
||||||
|
givenPath: "/static/*",
|
||||||
|
whenParams: []interface{}{"foo.txt"},
|
||||||
|
expect: "/static/foo.txt",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, single param without param",
|
||||||
|
givenPath: "/params/:foo",
|
||||||
|
expect: "/params/:foo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, single param with param",
|
||||||
|
givenParams: []string{"one"},
|
||||||
|
givenPath: "/params/:foo",
|
||||||
|
whenParams: []interface{}{"one"},
|
||||||
|
expect: "/params/one",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multi param without params",
|
||||||
|
givenPath: "/params/:foo/bar/:qux",
|
||||||
|
expect: "/params/:foo/bar/:qux",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multi param with one param",
|
||||||
|
givenParams: []string{"one"},
|
||||||
|
givenPath: "/params/:foo/bar/:qux",
|
||||||
|
whenParams: []interface{}{"one"},
|
||||||
|
expect: "/params/one/bar/:qux",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multi param with all params",
|
||||||
|
givenParams: []string{"one", "two"},
|
||||||
|
givenPath: "/params/:foo/bar/:qux",
|
||||||
|
whenParams: []interface{}{"one", "two"},
|
||||||
|
expect: "/params/one/bar/two",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multi param + wildcard with all params",
|
||||||
|
givenParams: []string{"one", "two", "three"},
|
||||||
|
givenPath: "/params/:foo/bar/:qux/*",
|
||||||
|
whenParams: []interface{}{"one", "two", "three"},
|
||||||
|
expect: "/params/one/bar/two/three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, backslash is not escaped",
|
||||||
|
givenParams: []string{"test"},
|
||||||
|
givenPath: "/a\\b/:x",
|
||||||
|
whenParams: []interface{}{"test"},
|
||||||
|
expect: `/a\b/test`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, escaped colon verbs",
|
||||||
|
givenParams: []string{"PATCH"},
|
||||||
|
givenPath: "/params\\::customVerb",
|
||||||
|
whenParams: []interface{}{"PATCH"},
|
||||||
|
expect: `/params:PATCH`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := routeInfo{
|
||||||
|
path: tc.givenPath,
|
||||||
|
params: tc.givenParams,
|
||||||
|
name: tc.expect,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expect, r.Reverse(tc.whenParams...))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -96,6 +96,7 @@ type RouteInfo interface {
|
|||||||
Name() string
|
Name() string
|
||||||
|
|
||||||
Params() []string
|
Params() []string
|
||||||
|
// Reverse reverses route to URL string by replacing path parameters with given params values.
|
||||||
Reverse(params ...interface{}) string
|
Reverse(params ...interface{}) string
|
||||||
|
|
||||||
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
|
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
|
||||||
@ -1066,7 +1067,7 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.unescapePathParamValues && currentNode.kind != staticKind {
|
if r.unescapePathParamValues {
|
||||||
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
|
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
|
||||||
for i, p := range *pathParams {
|
for i, p := range *pathParams {
|
||||||
tmpVal, err := url.PathUnescape(p.Value)
|
tmpVal, err := url.PathUnescape(p.Value)
|
||||||
|
@ -3053,6 +3053,15 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) {
|
|||||||
{Name: "*", Value: " /with space"},
|
{Name: "*", Value: " /with space"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "ok, ending with static node, unescape = true",
|
||||||
|
givenUnescapePathParamValues: true,
|
||||||
|
whenURL: "/fourth/%20%2Fwith%20space/static",
|
||||||
|
expectPath: "/fourth/:id/static",
|
||||||
|
expectPathParams: PathParams{
|
||||||
|
{Name: "id", Value: " /with space"},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "ok, unescape = false",
|
name: "ok, unescape = false",
|
||||||
givenUnescapePathParamValues: false,
|
givenUnescapePathParamValues: false,
|
||||||
@ -3080,6 +3089,8 @@ func TestDefaultRouter_UnescapePathParamValues(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_, err = router.Add(Route{Method: http.MethodGet, Path: "/third/*", Handler: handlerFunc})
|
_, err = router.Add(Route{Method: http.MethodGet, Path: "/third/*", Handler: handlerFunc})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
_, err = router.Add(Route{Method: http.MethodGet, Path: "/fourth/:id/static", Handler: handlerFunc})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
target, _ := url.Parse(tc.whenURL)
|
target, _ := url.Parse(tc.whenURL)
|
||||||
req := httptest.NewRequest(http.MethodGet, target.String(), nil)
|
req := httptest.NewRequest(http.MethodGet, target.String(), nil)
|
||||||
|
Loading…
Reference in New Issue
Block a user