mirror of
https://github.com/labstack/echo.git
synced 2025-01-07 23:01:56 +02:00
Changes from master (from 70acd57105
to 5b36ce3612
)
This commit is contained in:
parent
0d62f0065f
commit
13a733fdf9
2
.github/workflows/checks.yml
vendored
2
.github/workflows/checks.yml
vendored
@ -14,7 +14,7 @@ permissions:
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
# run static analysis only with the latest Go version
|
# run static analysis only with the latest Go version
|
||||||
LATEST_GO_VERSION: 1.19
|
LATEST_GO_VERSION: "1.20"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check:
|
check:
|
||||||
|
4
.github/workflows/echo.yml
vendored
4
.github/workflows/echo.yml
vendored
@ -14,7 +14,7 @@ permissions:
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
# run coverage and benchmarks only with the latest Go version
|
# run coverage and benchmarks only with the latest Go version
|
||||||
LATEST_GO_VERSION: 1.19
|
LATEST_GO_VERSION: "1.20"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@ -25,7 +25,7 @@ jobs:
|
|||||||
# Echo tests with last four major releases (unless there are pressing vulnerabilities)
|
# Echo tests with last four major releases (unless there are pressing vulnerabilities)
|
||||||
# As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when
|
# As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when
|
||||||
# we derive from last four major releases promise.
|
# we derive from last four major releases promise.
|
||||||
go: [1.18, 1.19]
|
go: ["1.19", "1.20"]
|
||||||
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
|
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
|
4
Makefile
4
Makefile
@ -31,6 +31,6 @@ benchmark: ## Run benchmarks
|
|||||||
help: ## Display this help screen
|
help: ## Display this help screen
|
||||||
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
goversion ?= "1.17"
|
goversion ?= "1.20"
|
||||||
test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17
|
test_version: ## Run tests inside Docker with given version (defaults to 1.20 oldest supported). Example: make test_version goversion=1.20
|
||||||
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
|
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
|
||||||
|
33
README.md
33
README.md
@ -81,18 +81,29 @@ func hello(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
# Third-party middlewares
|
# Official middleware repositories
|
||||||
|
|
||||||
| Repository | Description |
|
Following list of middleware is maintained by Echo team.
|
||||||
|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|
||||||
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
|
| Repository | Description |
|
||||||
| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
|
|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
|
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
|
||||||
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
|
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
|
||||||
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
|
|
||||||
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
|
# Third-party middleware repositories
|
||||||
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
|
|
||||||
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
|
Be careful when adding 3rd party middleware. Echo teams does not have time or manpower to guarantee safety and quality
|
||||||
|
of middlewares in this list.
|
||||||
|
|
||||||
|
| Repository | Description |
|
||||||
|
|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
|
||||||
|
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
|
||||||
|
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
|
||||||
|
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
|
||||||
|
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
|
||||||
|
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
|
||||||
|
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
|
||||||
|
|
||||||
Please send a PR to add your own library here.
|
Please send a PR to add your own library here.
|
||||||
|
|
||||||
|
6
go.mod
6
go.mod
@ -3,9 +3,9 @@ module github.com/labstack/echo/v5
|
|||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/stretchr/testify v1.8.1
|
github.com/stretchr/testify v1.8.2
|
||||||
github.com/valyala/fasttemplate v1.2.2
|
github.com/valyala/fasttemplate v1.2.2
|
||||||
golang.org/x/net v0.4.0
|
golang.org/x/net v0.7.0
|
||||||
golang.org/x/time v0.3.0
|
golang.org/x/time v0.3.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,6 +13,6 @@ require (
|
|||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
golang.org/x/text v0.5.0 // indirect
|
golang.org/x/text v0.7.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
12
go.sum
12
go.sum
@ -8,16 +8,16 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
|
|||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
|
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
|
||||||
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||||
golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU=
|
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
|
||||||
golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
|
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||||
golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM=
|
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
|
||||||
golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
67
middleware/context_timeout.go
Normal file
67
middleware/context_timeout.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"github.com/labstack/echo/v5"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContextTimeoutConfig defines the config for ContextTimeout middleware.
|
||||||
|
type ContextTimeoutConfig struct {
|
||||||
|
// Skipper defines a function to skip middleware.
|
||||||
|
Skipper Skipper
|
||||||
|
|
||||||
|
// ErrorHandler is a function when error aries in middeware execution.
|
||||||
|
ErrorHandler func(c echo.Context, err error) error
|
||||||
|
|
||||||
|
// Timeout configures a timeout for the middleware
|
||||||
|
Timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client
|
||||||
|
// when underlying method returns context.DeadlineExceeded error.
|
||||||
|
func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
|
||||||
|
return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContextTimeoutWithConfig returns a Timeout middleware with config.
|
||||||
|
func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
|
||||||
|
return toMiddlewareOrPanic(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMiddleware converts Config to middleware.
|
||||||
|
func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||||
|
if config.Timeout == 0 {
|
||||||
|
return nil, errors.New("timeout must be set")
|
||||||
|
}
|
||||||
|
if config.Skipper == nil {
|
||||||
|
config.Skipper = DefaultSkipper
|
||||||
|
}
|
||||||
|
if config.ErrorHandler == nil {
|
||||||
|
config.ErrorHandler = func(c echo.Context, err error) error {
|
||||||
|
if err != nil && errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return echo.ErrServiceUnavailable.WithInternal(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
if config.Skipper(c) {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c.SetRequest(c.Request().WithContext(timeoutContext))
|
||||||
|
|
||||||
|
if err := next(c); err != nil {
|
||||||
|
return config.ErrorHandler(c, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}, nil
|
||||||
|
}
|
225
middleware/context_timeout_test.go
Normal file
225
middleware/context_timeout_test.go
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"github.com/labstack/echo/v5"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextTimeoutSkipper(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
|
||||||
|
Skipper: func(context echo.Context) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
Timeout: 10 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
err := m(func(c echo.Context) error {
|
||||||
|
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("response from handler")
|
||||||
|
})(c)
|
||||||
|
|
||||||
|
// if not skipped we would have not returned error due context timeout logic
|
||||||
|
assert.EqualError(t, err, "response from handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextTimeoutWithTimeout0(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
ContextTimeout(time.Duration(0))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextTimeoutErrorOutInHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
|
||||||
|
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
|
||||||
|
Timeout: 10 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
rec.Code = 1 // we want to be sure that even 200 will not be sent
|
||||||
|
err := m(func(c echo.Context) error {
|
||||||
|
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
|
||||||
|
// to handle returned error and this can be done only then handler has not yet committed (written status code)
|
||||||
|
// the response.
|
||||||
|
return echo.NewHTTPError(http.StatusTeapot, "err")
|
||||||
|
})(c)
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.EqualError(t, err, "code=418, message=err")
|
||||||
|
assert.Equal(t, 1, rec.Code)
|
||||||
|
assert.Equal(t, "", rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextTimeoutSuccessfulRequest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
|
||||||
|
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
|
||||||
|
Timeout: 10 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
err := m(func(c echo.Context) error {
|
||||||
|
return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
|
||||||
|
})(c)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||||
|
assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextTimeoutTestRequestClone(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
|
||||||
|
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
|
||||||
|
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
|
||||||
|
Timeout: 1 * time.Second,
|
||||||
|
})
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
err := m(func(c echo.Context) error {
|
||||||
|
// Cookie test
|
||||||
|
cookie, err := c.Request().Cookie("cookie")
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.EqualValues(t, "cookie", cookie.Name)
|
||||||
|
assert.EqualValues(t, "value", cookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Form values
|
||||||
|
if assert.NoError(t, c.Request().ParseForm()) {
|
||||||
|
assert.EqualValues(t, "value", c.Request().FormValue("form"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query string
|
||||||
|
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
|
||||||
|
return nil
|
||||||
|
})(c)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
timeout := 10 * time.Millisecond
|
||||||
|
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
|
||||||
|
Timeout: timeout,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
err := m(func(c echo.Context) error {
|
||||||
|
if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.String(http.StatusOK, "Hello, World!")
|
||||||
|
})(c)
|
||||||
|
|
||||||
|
assert.IsType(t, &echo.HTTPError{}, err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
|
||||||
|
assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
timeoutErrorHandler := func(c echo.Context, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return &echo.HTTPError{
|
||||||
|
Code: http.StatusServiceUnavailable,
|
||||||
|
Message: "Timeout! change me",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := 50 * time.Millisecond
|
||||||
|
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
|
||||||
|
Timeout: timeout,
|
||||||
|
ErrorHandler: timeoutErrorHandler,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
err := m(func(c echo.Context) error {
|
||||||
|
// NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order
|
||||||
|
// for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky.
|
||||||
|
|
||||||
|
if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The Request Context should have a Deadline set by http.ContextTimeoutHandler
|
||||||
|
if _, ok := c.Request().Context().Deadline(); !ok {
|
||||||
|
assert.Fail(t, "No timeout set on Request Context")
|
||||||
|
}
|
||||||
|
return c.String(http.StatusOK, "Hello, World!")
|
||||||
|
})(c)
|
||||||
|
|
||||||
|
assert.IsType(t, &echo.HTTPError{}, err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
|
||||||
|
assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = timer.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
@ -78,6 +78,15 @@ type CORSConfig struct {
|
|||||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
|
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
|
||||||
AllowCredentials bool
|
AllowCredentials bool
|
||||||
|
|
||||||
|
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
|
||||||
|
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
|
||||||
|
//
|
||||||
|
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
|
||||||
|
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
|
||||||
|
//
|
||||||
|
// Optional. Default value is false.
|
||||||
|
UnsafeWildcardOriginWithAllowCredentials bool
|
||||||
|
|
||||||
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
|
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
|
||||||
// defines a list of headers that clients are allowed to access.
|
// defines a list of headers that clients are allowed to access.
|
||||||
//
|
//
|
||||||
@ -204,7 +213,7 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
} else {
|
} else {
|
||||||
// Check allowed origins
|
// Check allowed origins
|
||||||
for _, o := range config.AllowOrigins {
|
for _, o := range config.AllowOrigins {
|
||||||
if o == "*" && config.AllowCredentials {
|
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
|
||||||
allowOrigin = origin
|
allowOrigin = origin
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -11,106 +11,190 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCORS(t *testing.T) {
|
func TestCORS(t *testing.T) {
|
||||||
e := echo.New()
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenMW echo.MiddlewareFunc
|
||||||
|
whenMethod string
|
||||||
|
whenHeaders map[string]string
|
||||||
|
expectHeaders map[string]string
|
||||||
|
notExpectHeaders map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, wildcard origin",
|
||||||
|
whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
|
||||||
|
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, wildcard AllowedOrigin with no Origin header in request",
|
||||||
|
notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, specific AllowOrigins and AllowCredentials",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"localhost"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAge: 3600,
|
||||||
|
}),
|
||||||
|
whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
|
||||||
|
expectHeaders: map[string]string{
|
||||||
|
echo.HeaderAccessControlAllowOrigin: "localhost",
|
||||||
|
echo.HeaderAccessControlAllowCredentials: "true",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, preflight request with matching origin for `AllowOrigins`",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"localhost"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAge: 3600,
|
||||||
|
}),
|
||||||
|
whenMethod: http.MethodOptions,
|
||||||
|
whenHeaders: map[string]string{
|
||||||
|
echo.HeaderOrigin: "localhost",
|
||||||
|
echo.HeaderContentType: echo.MIMEApplicationJSON,
|
||||||
|
},
|
||||||
|
expectHeaders: map[string]string{
|
||||||
|
echo.HeaderAccessControlAllowOrigin: "localhost",
|
||||||
|
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||||
|
echo.HeaderAccessControlAllowCredentials: "true",
|
||||||
|
echo.HeaderAccessControlMaxAge: "3600",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"*"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAge: 3600,
|
||||||
|
}),
|
||||||
|
whenMethod: http.MethodOptions,
|
||||||
|
whenHeaders: map[string]string{
|
||||||
|
echo.HeaderOrigin: "localhost",
|
||||||
|
echo.HeaderContentType: echo.MIMEApplicationJSON,
|
||||||
|
},
|
||||||
|
expectHeaders: map[string]string{
|
||||||
|
echo.HeaderAccessControlAllowOrigin: "*", // Note: browsers will ignore and complain about responses having `*`
|
||||||
|
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||||
|
echo.HeaderAccessControlAllowCredentials: "true",
|
||||||
|
echo.HeaderAccessControlMaxAge: "3600",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"*"},
|
||||||
|
AllowCredentials: false, // important for this testcase
|
||||||
|
MaxAge: 3600,
|
||||||
|
}),
|
||||||
|
whenMethod: http.MethodOptions,
|
||||||
|
whenHeaders: map[string]string{
|
||||||
|
echo.HeaderOrigin: "localhost",
|
||||||
|
echo.HeaderContentType: echo.MIMEApplicationJSON,
|
||||||
|
},
|
||||||
|
expectHeaders: map[string]string{
|
||||||
|
echo.HeaderAccessControlAllowOrigin: "*",
|
||||||
|
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||||
|
echo.HeaderAccessControlMaxAge: "3600",
|
||||||
|
},
|
||||||
|
notExpectHeaders: map[string]string{
|
||||||
|
echo.HeaderAccessControlAllowCredentials: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"*"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
UnsafeWildcardOriginWithAllowCredentials: true, // important for this testcase
|
||||||
|
MaxAge: 3600,
|
||||||
|
}),
|
||||||
|
whenMethod: http.MethodOptions,
|
||||||
|
whenHeaders: map[string]string{
|
||||||
|
echo.HeaderOrigin: "localhost",
|
||||||
|
echo.HeaderContentType: echo.MIMEApplicationJSON,
|
||||||
|
},
|
||||||
|
expectHeaders: map[string]string{
|
||||||
|
echo.HeaderAccessControlAllowOrigin: "localhost", // This could end up as cross-origin attack
|
||||||
|
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||||
|
echo.HeaderAccessControlAllowCredentials: "true",
|
||||||
|
echo.HeaderAccessControlMaxAge: "3600",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, preflight request with Access-Control-Request-Headers",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"*"},
|
||||||
|
}),
|
||||||
|
whenMethod: http.MethodOptions,
|
||||||
|
whenHeaders: map[string]string{
|
||||||
|
echo.HeaderOrigin: "localhost",
|
||||||
|
echo.HeaderContentType: echo.MIMEApplicationJSON,
|
||||||
|
echo.HeaderAccessControlRequestHeaders: "Special-Request-Header",
|
||||||
|
},
|
||||||
|
expectHeaders: map[string]string{
|
||||||
|
echo.HeaderAccessControlAllowOrigin: "*",
|
||||||
|
echo.HeaderAccessControlAllowHeaders: "Special-Request-Header",
|
||||||
|
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"http://*.example.com"},
|
||||||
|
}),
|
||||||
|
whenMethod: http.MethodOptions,
|
||||||
|
whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"},
|
||||||
|
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *",
|
||||||
|
givenMW: CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"http://*.example.com"},
|
||||||
|
}),
|
||||||
|
whenMethod: http.MethodOptions,
|
||||||
|
whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"},
|
||||||
|
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
// Wildcard origin
|
mw := CORS()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
if tc.givenMW != nil {
|
||||||
rec := httptest.NewRecorder()
|
mw = tc.givenMW
|
||||||
c := e.NewContext(req, rec)
|
}
|
||||||
h := CORS()(func(c echo.Context) error { return echo.ErrNotFound })
|
h := mw(func(c echo.Context) error {
|
||||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
return nil
|
||||||
h(c)
|
})
|
||||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
|
||||||
|
|
||||||
// Wildcard AllowedOrigin with no Origin header in request
|
method := http.MethodGet
|
||||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
if tc.whenMethod != "" {
|
||||||
rec = httptest.NewRecorder()
|
method = tc.whenMethod
|
||||||
c = e.NewContext(req, rec)
|
}
|
||||||
h = CORS()(func(c echo.Context) error { return echo.ErrNotFound })
|
req := httptest.NewRequest(method, "/", nil)
|
||||||
h(c)
|
rec := httptest.NewRecorder()
|
||||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
c := e.NewContext(req, rec)
|
||||||
|
for k, v := range tc.whenHeaders {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
// Allow origins
|
err := h(c)
|
||||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
rec = httptest.NewRecorder()
|
|
||||||
c = e.NewContext(req, rec)
|
|
||||||
h = CORSWithConfig(CORSConfig{
|
|
||||||
AllowOrigins: []string{"localhost"},
|
|
||||||
AllowCredentials: true,
|
|
||||||
MaxAge: 3600,
|
|
||||||
})(func(c echo.Context) error { return echo.ErrNotFound })
|
|
||||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
|
||||||
h(c)
|
|
||||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
|
||||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
|
||||||
|
|
||||||
// Preflight request
|
assert.NoError(t, err)
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
header := rec.Header()
|
||||||
rec = httptest.NewRecorder()
|
for k, v := range tc.expectHeaders {
|
||||||
c = e.NewContext(req, rec)
|
assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v)
|
||||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
}
|
||||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
for k, v := range tc.notExpectHeaders {
|
||||||
cors := CORSWithConfig(CORSConfig{
|
if v == "" {
|
||||||
AllowOrigins: []string{"localhost"},
|
assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k)
|
||||||
AllowCredentials: true,
|
} else {
|
||||||
MaxAge: 3600,
|
assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v)
|
||||||
})
|
}
|
||||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
}
|
||||||
h(c)
|
})
|
||||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
}
|
||||||
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
|
||||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
|
||||||
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
|
||||||
|
|
||||||
// Preflight request with `AllowOrigins` *
|
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
|
||||||
rec = httptest.NewRecorder()
|
|
||||||
c = e.NewContext(req, rec)
|
|
||||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
|
||||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
|
||||||
cors = CORSWithConfig(CORSConfig{
|
|
||||||
AllowOrigins: []string{"*"},
|
|
||||||
AllowCredentials: true,
|
|
||||||
MaxAge: 3600,
|
|
||||||
})
|
|
||||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
|
||||||
h(c)
|
|
||||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
|
||||||
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
|
||||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
|
||||||
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
|
||||||
|
|
||||||
// Preflight request with Access-Control-Request-Headers
|
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
|
||||||
rec = httptest.NewRecorder()
|
|
||||||
c = e.NewContext(req, rec)
|
|
||||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
|
||||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
|
||||||
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
|
|
||||||
cors = CORSWithConfig(CORSConfig{
|
|
||||||
AllowOrigins: []string{"*"},
|
|
||||||
})
|
|
||||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
|
||||||
h(c)
|
|
||||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
|
||||||
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
|
|
||||||
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
|
||||||
|
|
||||||
// Preflight request with `AllowOrigins` which allow all subdomains with *
|
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
|
||||||
rec = httptest.NewRecorder()
|
|
||||||
c = e.NewContext(req, rec)
|
|
||||||
req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com")
|
|
||||||
cors = CORSWithConfig(CORSConfig{
|
|
||||||
AllowOrigins: []string{"http://*.example.com"},
|
|
||||||
})
|
|
||||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
|
||||||
h(c)
|
|
||||||
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
|
||||||
|
|
||||||
req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com")
|
|
||||||
h(c)
|
|
||||||
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_allowOriginScheme(t *testing.T) {
|
func Test_allowOriginScheme(t *testing.T) {
|
||||||
|
@ -121,9 +121,9 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
config.CookieSecure = true
|
config.CookieSecure = true
|
||||||
}
|
}
|
||||||
|
|
||||||
extractors, err := createExtractors(config.TokenLookup)
|
extractors, cErr := createExtractors(config.TokenLookup)
|
||||||
if err != nil {
|
if cErr != nil {
|
||||||
return nil, err
|
return nil, cErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
@ -27,8 +27,6 @@ const (
|
|||||||
ExtractorSourceCookie ExtractorSource = "cookie"
|
ExtractorSourceCookie ExtractorSource = "cookie"
|
||||||
// ExtractorSourceForm means value was extracted from request form values
|
// ExtractorSourceForm means value was extracted from request form values
|
||||||
ExtractorSourceForm ExtractorSource = "form"
|
ExtractorSourceForm ExtractorSource = "form"
|
||||||
// ExtractorSourceCustom means value was extracted by custom extractor
|
|
||||||
ExtractorSourceCustom ExtractorSource = "custom"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
|
// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
|
||||||
|
@ -99,9 +99,9 @@ func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
return nil, errors.New("echo key-auth middleware requires a validator function")
|
return nil, errors.New("echo key-auth middleware requires a validator function")
|
||||||
}
|
}
|
||||||
|
|
||||||
extractors, err := createExtractors(config.KeyLookup)
|
extractors, cErr := createExtractors(config.KeyLookup)
|
||||||
if err != nil {
|
if cErr != nil {
|
||||||
return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err)
|
return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr)
|
||||||
}
|
}
|
||||||
if len(extractors) == 0 {
|
if len(extractors) == 0 {
|
||||||
return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
|
return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
|
||||||
|
@ -13,7 +13,6 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v5"
|
"github.com/labstack/echo/v5"
|
||||||
@ -74,19 +73,20 @@ type ProxyBalancer interface {
|
|||||||
|
|
||||||
type commonBalancer struct {
|
type commonBalancer struct {
|
||||||
targets []*ProxyTarget
|
targets []*ProxyTarget
|
||||||
mutex sync.RWMutex
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// RandomBalancer implements a random load balancing technique.
|
// RandomBalancer implements a random load balancing technique.
|
||||||
type randomBalancer struct {
|
type randomBalancer struct {
|
||||||
*commonBalancer
|
commonBalancer
|
||||||
random *rand.Rand
|
random *rand.Rand
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundRobinBalancer implements a round-robin load balancing technique.
|
// RoundRobinBalancer implements a round-robin load balancing technique.
|
||||||
type roundRobinBalancer struct {
|
type roundRobinBalancer struct {
|
||||||
*commonBalancer
|
commonBalancer
|
||||||
i uint32
|
// tracking the index on `targets` slice for the next `*ProxyTarget` to be used
|
||||||
|
i int
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultProxyConfig is the default Proxy middleware config.
|
// DefaultProxyConfig is the default Proxy middleware config.
|
||||||
@ -135,32 +135,37 @@ func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
|||||||
|
|
||||||
// NewRandomBalancer returns a random proxy balancer.
|
// NewRandomBalancer returns a random proxy balancer.
|
||||||
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
|
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
|
||||||
b := &randomBalancer{commonBalancer: new(commonBalancer)}
|
b := randomBalancer{}
|
||||||
b.targets = targets
|
b.targets = targets
|
||||||
return b
|
b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
||||||
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRoundRobinBalancer returns a round-robin proxy balancer.
|
// NewRoundRobinBalancer returns a round-robin proxy balancer.
|
||||||
func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
|
func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
|
||||||
b := &roundRobinBalancer{commonBalancer: new(commonBalancer)}
|
b := roundRobinBalancer{}
|
||||||
b.targets = targets
|
b.targets = targets
|
||||||
return b
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTarget adds an upstream target to the list.
|
// AddTarget adds an upstream target to the list and returns `true`.
|
||||||
|
//
|
||||||
|
// However, if a target with the same name already exists then the operation is aborted returning `false`.
|
||||||
func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
|
func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
|
||||||
|
b.mutex.Lock()
|
||||||
|
defer b.mutex.Unlock()
|
||||||
for _, t := range b.targets {
|
for _, t := range b.targets {
|
||||||
if t.Name == target.Name {
|
if t.Name == target.Name {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
b.mutex.Lock()
|
|
||||||
defer b.mutex.Unlock()
|
|
||||||
b.targets = append(b.targets, target)
|
b.targets = append(b.targets, target)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveTarget removes an upstream target from the list.
|
// RemoveTarget removes an upstream target from the list by name.
|
||||||
|
//
|
||||||
|
// Returns `true` on success, `false` if no target with the name is found.
|
||||||
func (b *commonBalancer) RemoveTarget(name string) bool {
|
func (b *commonBalancer) RemoveTarget(name string) bool {
|
||||||
b.mutex.Lock()
|
b.mutex.Lock()
|
||||||
defer b.mutex.Unlock()
|
defer b.mutex.Unlock()
|
||||||
@ -174,20 +179,36 @@ func (b *commonBalancer) RemoveTarget(name string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Next randomly returns an upstream target.
|
// Next randomly returns an upstream target.
|
||||||
|
//
|
||||||
|
// Note: `nil` is returned in case upstream target list is empty.
|
||||||
func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
||||||
if b.random == nil {
|
b.mutex.Lock()
|
||||||
b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
defer b.mutex.Unlock()
|
||||||
|
if len(b.targets) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
} else if len(b.targets) == 1 {
|
||||||
|
return b.targets[0], nil
|
||||||
}
|
}
|
||||||
b.mutex.RLock()
|
|
||||||
defer b.mutex.RUnlock()
|
|
||||||
return b.targets[b.random.Intn(len(b.targets))], nil
|
return b.targets[b.random.Intn(len(b.targets))], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next returns an upstream target using round-robin technique.
|
// Next returns an upstream target using round-robin technique.
|
||||||
|
//
|
||||||
|
// Note: `nil` is returned in case upstream target list is empty.
|
||||||
func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
|
||||||
b.i = b.i % uint32(len(b.targets))
|
b.mutex.Lock()
|
||||||
|
defer b.mutex.Unlock()
|
||||||
|
if len(b.targets) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
} else if len(b.targets) == 1 {
|
||||||
|
return b.targets[0], nil
|
||||||
|
}
|
||||||
|
// reset the index if out of bounds
|
||||||
|
if b.i >= len(b.targets) {
|
||||||
|
b.i = 0
|
||||||
|
}
|
||||||
t := b.targets[b.i]
|
t := b.targets[b.i]
|
||||||
atomic.AddUint32(&b.i, 1)
|
b.i++
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -381,7 +381,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type testProvider struct {
|
type testProvider struct {
|
||||||
*commonBalancer
|
commonBalancer
|
||||||
target *ProxyTarget
|
target *ProxyTarget
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
@ -398,7 +398,7 @@ func TestTargetProvider(t *testing.T) {
|
|||||||
url1, _ := url.Parse(t1.URL)
|
url1, _ := url.Parse(t1.URL)
|
||||||
|
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
tp := &testProvider{commonBalancer: new(commonBalancer)}
|
tp := &testProvider{}
|
||||||
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
|
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
|
||||||
e.Use(Proxy(tp))
|
e.Use(Proxy(tp))
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@ -413,7 +413,7 @@ func TestFailNextTarget(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
tp := &testProvider{commonBalancer: new(commonBalancer)}
|
tp := &testProvider{}
|
||||||
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
|
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
|
||||||
tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
|
tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
|
||||||
|
|
||||||
@ -424,3 +424,19 @@ func TestFailNextTarget(t *testing.T) {
|
|||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
|
assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRandomBalancerWithNoTargets(t *testing.T) {
|
||||||
|
// Assert balancer with empty targets does return `nil` on `Next()`
|
||||||
|
rb := NewRandomBalancer(nil)
|
||||||
|
target, err := rb.Next(nil)
|
||||||
|
assert.Nil(t, target)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
|
||||||
|
// Assert balancer with empty targets does return `nil` on `Next()`
|
||||||
|
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
|
||||||
|
target, err := rrb.Next(nil)
|
||||||
|
assert.Nil(t, target)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -165,13 +164,13 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
config.DirectoryListTemplate = directoryListHTMLTemplate
|
config.DirectoryListTemplate = directoryListHTMLTemplate
|
||||||
}
|
}
|
||||||
|
|
||||||
dirListTemplate, err := template.New("index").Parse(config.DirectoryListTemplate)
|
dirListTemplate, tErr := template.New("index").Parse(config.DirectoryListTemplate)
|
||||||
if err != nil {
|
if tErr != nil {
|
||||||
return nil, fmt.Errorf("echo static middleware directory list template parsing error: %w", err)
|
return nil, fmt.Errorf("echo static middleware directory list template parsing error: %w", tErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) (err error) {
|
||||||
if config.Skipper(c) {
|
if config.Skipper(c) {
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
@ -188,7 +187,7 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
|
name := path.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
|
||||||
|
|
||||||
if config.IgnoreBase {
|
if config.IgnoreBase {
|
||||||
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
|
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
|
||||||
@ -204,13 +203,13 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
currentFS = c.Echo().Filesystem
|
currentFS = c.Echo().Filesystem
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := openFile(currentFS, name)
|
file, err := currentFS.Open(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !isIgnorableOpenFileError(err) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// file with that path did not exist, so we continue down in middleware/handler chain, hoping that we end up in
|
||||||
// when file does not exist let handler to handle that request. if it succeeds then we are done
|
// handler that is meant to handle this request
|
||||||
err = next(c)
|
err = next(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -221,7 +220,7 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// is case HTML5 mode is enabled + echo 404 we serve index to the client
|
// is case HTML5 mode is enabled + echo 404 we serve index to the client
|
||||||
file, err = openFile(currentFS, filepath.Join(config.Root, config.Index))
|
file, err = currentFS.Open(path.Join(config.Root, config.Index))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -235,15 +234,13 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if info.IsDir() {
|
if info.IsDir() {
|
||||||
index, err := openFile(currentFS, filepath.Join(name, config.Index))
|
index, err := currentFS.Open(path.Join(name, config.Index))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if config.Browse {
|
if config.Browse {
|
||||||
return listDir(dirListTemplate, name, currentFS, file, c.Response())
|
return listDir(dirListTemplate, name, currentFS, file, c.Response())
|
||||||
}
|
}
|
||||||
|
|
||||||
if os.IsNotExist(err) {
|
return next(c)
|
||||||
return next(c)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer index.Close()
|
defer index.Close()
|
||||||
@ -261,11 +258,6 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openFile(fs fs.FS, name string) (fs.File, error) {
|
|
||||||
pathWithSlashes := filepath.ToSlash(name)
|
|
||||||
return fs.Open(pathWithSlashes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func serveFile(c echo.Context, file fs.File, info os.FileInfo) error {
|
func serveFile(c echo.Context, file fs.File, info os.FileInfo) error {
|
||||||
ff, ok := file.(io.ReadSeeker)
|
ff, ok := file.(io.ReadSeeker)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
12
middleware/static_other.go
Normal file
12
middleware/static_other.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// We ignore these errors as there could be handler that matches request path.
|
||||||
|
func isIgnorableOpenFileError(err error) bool {
|
||||||
|
return os.IsNotExist(err)
|
||||||
|
}
|
23
middleware/static_windows.go
Normal file
23
middleware/static_windows.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// We ignore these errors as there could be handler that matches request path.
|
||||||
|
//
|
||||||
|
// As of Go 1.20 filepath.Clean has different behaviour on OS related filesystems so we need to use path.Clean
|
||||||
|
// on Windows which has some caveats. The Open methods might return different errors than earlier versions and
|
||||||
|
// as of 1.20 path checks are more strict on the provided path and considers [UNC](https://en.wikipedia.org/wiki/Path_(computing)#UNC)
|
||||||
|
// paths with missing host etc parts as invalid. Previously it would result you `fs.ErrNotExist`.
|
||||||
|
//
|
||||||
|
// For 1.20@Windows we need to treat those errors the same as `fs.ErrNotExists` so we can continue handling
|
||||||
|
// errors in the middleware/handler chain. Otherwise we might end up with status 500 instead of finding a route
|
||||||
|
// or return 404 not found.
|
||||||
|
func isIgnorableOpenFileError(err error) bool {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
errTxt := err.Error()
|
||||||
|
return errTxt == "http: invalid or unsafe file path" || errTxt == "invalid path"
|
||||||
|
}
|
@ -2798,6 +2798,19 @@ func TestRouter_Routes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouterNoRoutablePath(t *testing.T) {
|
||||||
|
e := New()
|
||||||
|
|
||||||
|
e.router.Add(Route{Path: "/static", Name: "/static", Method: http.MethodGet, Handler: func(Context) error { return nil }})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/notFound", nil)
|
||||||
|
c := e.NewContext(req, nil)
|
||||||
|
|
||||||
|
e.router.Route(c.(RoutableContext))
|
||||||
|
// No routable path, don't set Path.
|
||||||
|
assert.Equal(t, "", c.Path())
|
||||||
|
}
|
||||||
|
|
||||||
func benchmarkRouterRoutes(b *testing.B, routes []testRoute, routesToFind []testRoute) {
|
func benchmarkRouterRoutes(b *testing.B, routes []testRoute, routesToFind []testRoute) {
|
||||||
e := New()
|
e := New()
|
||||||
r := e.router
|
r := e.router
|
||||||
|
Loading…
Reference in New Issue
Block a user