1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-07 23:01:56 +02:00

Bring over changes from master (latest commit 135c511f5d)

This commit is contained in:
toimtoimtoim 2022-12-04 22:17:48 +02:00
parent 74022662be
commit 74b8c4368c
No known key found for this signature in database
GPG Key ID: 0443E21F7D9928AF
39 changed files with 946 additions and 364 deletions

View File

@ -27,9 +27,10 @@ jobs:
matrix: matrix:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases # Echo tests with last four major releases (unless there are pressing vulnerabilities)
# except v5 starts from 1.17 until there is last four major releases after that # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when
go: [1.17, 1.18] # we derive from last four major releases promise.
go: [1.17, 1.18, 1.19]
name: ${{ matrix.os }} @ Go ${{ matrix.go }} name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@ -43,19 +44,23 @@ jobs:
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install Dependencies - name: Run Tests
run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
- name: Install dependencies for checks
run: | run: |
go install golang.org/x/lint/golint@latest go install golang.org/x/lint/golint@latest
go install honnef.co/go/tools/cmd/staticcheck@latest go install honnef.co/go/tools/cmd/staticcheck@latest
- name: Run Tests - name: Run golint
run: | run: golint -set_exit_status ./...
golint -set_exit_status ./...
staticcheck ./... - name: Run staticcheck
go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... run: staticcheck ./...
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' if: success() && matrix.go == 1.19 && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v1 uses: codecov/codecov-action@v3
with: with:
token: token:
fail_ci_if_error: false fail_ci_if_error: false
@ -64,7 +69,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
go: [1.18] go: [1.19]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@ -91,10 +96,12 @@ jobs:
run: | run: |
cd previous cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New) - name: Run Benchmark (New)
run: | run: |
cd new cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat - name: Run Benchstat
run: | run: |
benchstat previous/benchmark.txt new/benchmark.txt benchstat previous/benchmark.txt new/benchmark.txt

View File

@ -1,5 +1,74 @@
# Changelog # Changelog
## v4.10.0 - 2022-xx-xx
**Security**
This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
several vulnerabilities fixed in these libraries.
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
## v4.9.1 - 2022-10-12
**Fixes**
* Fix logger panicing (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295)
**Enhancements**
* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272)
* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291)
* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254)
* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275)
* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301)
## v4.9.0 - 2022-09-04
**Security**
* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260)
**Enhancements**
* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257)
* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247)
## v4.8.0 - 2022-08-10
**Most notable things**
You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237)
```go
e.Add("COPY", "/*", func(c echo.Context) error
return c.String(http.StatusOK, "OK COPY")
})
```
You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217)
```go
e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
g := e.Group("/images")
g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
```
**Enhancements**
* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127)
* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145)
* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187)
* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191)
* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176)
* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209)
* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217)
* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227)
* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237)
## v4.7.2 - 2022-03-16 ## v4.7.2 - 2022-03-16
**Fixes** **Fixes**

View File

@ -13,6 +13,7 @@ init:
@go install honnef.co/go/tools/cmd/staticcheck@latest @go install honnef.co/go/tools/cmd/staticcheck@latest
lint: ## Lint the files lint: ## Lint the files
@staticcheck ${PKG_LIST}
@golint -set_exit_status ${PKG_LIST} @golint -set_exit_status ${PKG_LIST}
vet: ## Vet the files vet: ## Vet the files

View File

@ -11,14 +11,11 @@
## Supported Go versions ## Supported Go versions
Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that. Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
Therefore a Go version capable of understanding /vN suffixed imports is required: Therefore a Go version capable of understanding /vN suffixed imports is required:
- 1.9.7+
- 1.10.3+
- 1.16+
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
way of using Echo going forward. way of using Echo going forward.
@ -95,6 +92,7 @@ func hello(c echo.Context) error {
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | | [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | | [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | | [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
Please send a PR to add your own library here. Please send a PR to add your own library here.

View File

@ -190,44 +190,39 @@ func TestToMultipleFields(t *testing.T) {
} }
func TestBindJSON(t *testing.T) { func TestBindJSON(t *testing.T) {
assert := assert.New(t) testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON)
testBindOkay(assert, strings.NewReader(userJSON), nil, MIMEApplicationJSON) testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON)
testBindOkay(assert, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON)
testBindArrayOkay(assert, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON)
testBindArrayOkay(assert, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{})
testBindError(assert, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{})
} }
func TestBindXML(t *testing.T) { func TestBindXML(t *testing.T) {
assert := assert.New(t) testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
testBindOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
testBindArrayOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New(""))
testBindArrayOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{})
testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{})
testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML)
testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML)
testBindOkay(assert, strings.NewReader(userXML), nil, MIMETextXML) testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New(""))
testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMETextXML) testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{})
testBindError(assert, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{})
testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{})
testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{})
} }
func TestBindForm(t *testing.T) { func TestBindForm(t *testing.T) {
assert := assert.New(t) testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm)
testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm)
testBindOkay(assert, strings.NewReader(userForm), nil, MIMEApplicationForm)
testBindOkay(assert, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm)
e := New() e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm)) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, MIMEApplicationForm) req.Header.Set(HeaderContentType, MIMEApplicationForm)
err := c.Bind(&[]struct{ Field string }{}) err := c.Bind(&[]struct{ Field string }{})
assert.Error(err) assert.Error(t, err)
} }
func TestBindQueryParams(t *testing.T) { func TestBindQueryParams(t *testing.T) {
@ -363,14 +358,13 @@ func TestBindUnmarshalParam(t *testing.T) {
err := c.Bind(&result) err := c.Bind(&result)
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
assert := assert.New(t) if assert.NoError(t, err) {
if assert.NoError(err) {
// assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) // assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
assert.Equal(ts, result.T) assert.Equal(t, ts, result.T)
assert.Equal(StringArray([]string{"one", "two", "three"}), result.SA) assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
assert.Equal([]Timestamp{ts, ts}, result.TA) assert.Equal(t, []Timestamp{ts, ts}, result.TA)
assert.Equal(Struct{""}, result.ST) // child struct does not have a field with matching tag assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag
assert.Equal("baz", result.StWithTag.Foo) // child struct has field with matching tag assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag
} }
} }
@ -472,37 +466,34 @@ func TestBindMultipartForm(t *testing.T) {
mw.Close() mw.Close()
body := bodyBuffer.Bytes() body := bodyBuffer.Bytes()
assert := assert.New(t) testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType())
testBindOkay(assert, bytes.NewReader(body), nil, mw.FormDataContentType()) testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType())
testBindOkay(assert, bytes.NewReader(body), dummyQuery, mw.FormDataContentType())
} }
func TestBindUnsupportedMediaType(t *testing.T) { func TestBindUnsupportedMediaType(t *testing.T) {
assert := assert.New(t) testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
} }
func TestBindbindData(t *testing.T) { func TestBindbindData(t *testing.T) {
a := assert.New(t)
ts := new(bindTestStruct) ts := new(bindTestStruct)
err := bindData(ts, values, "form") err := bindData(ts, values, "form")
a.NoError(err) assert.NoError(t, err)
a.Equal(0, ts.I) assert.Equal(t, 0, ts.I)
a.Equal(int8(0), ts.I8) assert.Equal(t, int8(0), ts.I8)
a.Equal(int16(0), ts.I16) assert.Equal(t, int16(0), ts.I16)
a.Equal(int32(0), ts.I32) assert.Equal(t, int32(0), ts.I32)
a.Equal(int64(0), ts.I64) assert.Equal(t, int64(0), ts.I64)
a.Equal(uint(0), ts.UI) assert.Equal(t, uint(0), ts.UI)
a.Equal(uint8(0), ts.UI8) assert.Equal(t, uint8(0), ts.UI8)
a.Equal(uint16(0), ts.UI16) assert.Equal(t, uint16(0), ts.UI16)
a.Equal(uint32(0), ts.UI32) assert.Equal(t, uint32(0), ts.UI32)
a.Equal(uint64(0), ts.UI64) assert.Equal(t, uint64(0), ts.UI64)
a.Equal(false, ts.B) assert.Equal(t, false, ts.B)
a.Equal(float32(0), ts.F32) assert.Equal(t, float32(0), ts.F32)
a.Equal(float64(0), ts.F64) assert.Equal(t, float64(0), ts.F64)
a.Equal("", ts.S) assert.Equal(t, "", ts.S)
a.Equal("", ts.cantSet) assert.Equal(t, "", ts.cantSet)
} }
func TestBindParam(t *testing.T) { func TestBindParam(t *testing.T) {
@ -580,7 +571,6 @@ func TestBindUnmarshalTypeError(t *testing.T) {
} }
func TestBindSetWithProperType(t *testing.T) { func TestBindSetWithProperType(t *testing.T) {
assert := assert.New(t)
ts := new(bindTestStruct) ts := new(bindTestStruct)
typ := reflect.TypeOf(ts).Elem() typ := reflect.TypeOf(ts).Elem()
val := reflect.ValueOf(ts).Elem() val := reflect.ValueOf(ts).Elem()
@ -595,9 +585,9 @@ func TestBindSetWithProperType(t *testing.T) {
} }
val := values[typeField.Name][0] val := values[typeField.Name][0]
err := setWithProperType(typeField.Type.Kind(), val, structField) err := setWithProperType(typeField.Type.Kind(), val, structField)
assert.NoError(err) assert.NoError(t, err)
} }
assertBindTestStruct(assert, ts) assertBindTestStruct(t, ts)
type foo struct { type foo struct {
Bar bytes.Buffer Bar bytes.Buffer
@ -605,7 +595,7 @@ func TestBindSetWithProperType(t *testing.T) {
v := &foo{} v := &foo{}
typ = reflect.TypeOf(v).Elem() typ = reflect.TypeOf(v).Elem()
val = reflect.ValueOf(v).Elem() val = reflect.ValueOf(v).Elem()
assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
} }
func TestSetIntField(t *testing.T) { func TestSetIntField(t *testing.T) {
@ -730,28 +720,28 @@ func BenchmarkBindbindDataWithTags(b *testing.B) {
err = bindData(ts, values, "form") err = bindData(ts, values, "form")
} }
assert.NoError(err) assert.NoError(err)
assertBindTestStruct(assert, (*bindTestStruct)(ts)) assertBindTestStruct(b, (*bindTestStruct)(ts))
} }
func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) { func assertBindTestStruct(t testing.TB, ts *bindTestStruct) {
a.Equal(0, ts.I) assert.Equal(t, 0, ts.I)
a.Equal(int8(8), ts.I8) assert.Equal(t, int8(8), ts.I8)
a.Equal(int16(16), ts.I16) assert.Equal(t, int16(16), ts.I16)
a.Equal(int32(32), ts.I32) assert.Equal(t, int32(32), ts.I32)
a.Equal(int64(64), ts.I64) assert.Equal(t, int64(64), ts.I64)
a.Equal(uint(0), ts.UI) assert.Equal(t, uint(0), ts.UI)
a.Equal(uint8(8), ts.UI8) assert.Equal(t, uint8(8), ts.UI8)
a.Equal(uint16(16), ts.UI16) assert.Equal(t, uint16(16), ts.UI16)
a.Equal(uint32(32), ts.UI32) assert.Equal(t, uint32(32), ts.UI32)
a.Equal(uint64(64), ts.UI64) assert.Equal(t, uint64(64), ts.UI64)
a.Equal(true, ts.B) assert.Equal(t, true, ts.B)
a.Equal(float32(32.5), ts.F32) assert.Equal(t, float32(32.5), ts.F32)
a.Equal(float64(64.5), ts.F64) assert.Equal(t, float64(64.5), ts.F64)
a.Equal("test", ts.S) assert.Equal(t, "test", ts.S)
a.Equal("", ts.GetCantSet()) assert.Equal(t, "", ts.GetCantSet())
} }
func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { func testBindOkay(t testing.TB, r io.Reader, query url.Values, ctype string) {
e := New() e := New()
path := "/" path := "/"
if len(query) > 0 { if len(query) > 0 {
@ -763,13 +753,13 @@ func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctyp
req.Header.Set(HeaderContentType, ctype) req.Header.Set(HeaderContentType, ctype)
u := new(user) u := new(user)
err := c.Bind(u) err := c.Bind(u)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(1, u.ID) assert.Equal(t, 1, u.ID)
assert.Equal("Jon Snow", u.Name) assert.Equal(t, "Jon Snow", u.Name)
} }
} }
func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) {
e := New() e := New()
path := "/" path := "/"
if len(query) > 0 { if len(query) > 0 {
@ -781,14 +771,14 @@ func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values,
req.Header.Set(HeaderContentType, ctype) req.Header.Set(HeaderContentType, ctype)
u := []user{} u := []user{}
err := c.Bind(&u) err := c.Bind(&u)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(1, len(u)) assert.Equal(t, 1, len(u))
assert.Equal(1, u[0].ID) assert.Equal(t, 1, u[0].ID)
assert.Equal("Jon Snow", u[0].Name) assert.Equal(t, "Jon Snow", u[0].Name)
} }
} }
func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expectedInternal error) { func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) {
e := New() e := New()
req := httptest.NewRequest(http.MethodPost, "/", r) req := httptest.NewRequest(http.MethodPost, "/", r)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -800,14 +790,14 @@ func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expecte
switch { switch {
case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML), case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML),
strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
if assert.IsType(new(HTTPError), err) { if assert.IsType(t, new(HTTPError), err) {
assert.Equal(http.StatusBadRequest, err.(*HTTPError).Code) assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code)
assert.IsType(expectedInternal, err.(*HTTPError).Internal) assert.IsType(t, expectedInternal, err.(*HTTPError).Internal)
} }
default: default:
if assert.IsType(new(HTTPError), err) { if assert.IsType(t, new(HTTPError), err) {
assert.Equal(ErrUnsupportedMediaType, err) assert.Equal(t, ErrUnsupportedMediaType, err)
assert.IsType(expectedInternal, err.(*HTTPError).Internal) assert.IsType(t, expectedInternal, err.(*HTTPError).Internal)
} }
} }
} }

View File

@ -185,6 +185,14 @@ type Context interface {
// Redirect redirects the request to a provided URL with status code. // Redirect redirects the request to a provided URL with status code.
Redirect(code int, url string) error Redirect(code int, url string) error
// Error invokes the registered global HTTP error handler. Generally used by middleware.
// A side-effect of calling global error handler is that now Response has been committed (sent to the client) and
// middlewares up in chain can not change Response status code or Response body anymore.
//
// Avoid using this method in handlers as no middleware will be able to effectively handle errors after that.
// Instead of calling this method in handler return your error and let it be handled by middlewares or global error handler.
Error(err error)
// Echo returns the `Echo` instance. // Echo returns the `Echo` instance.
// //
// WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them // WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them
@ -337,11 +345,16 @@ func (c *DefaultContext) RealIP() string {
if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
i := strings.IndexAny(ip, ",") i := strings.IndexAny(ip, ",")
if i > 0 { if i > 0 {
return strings.TrimSpace(ip[:i]) xffip := strings.TrimSpace(ip[:i])
xffip = strings.TrimPrefix(xffip, "[")
xffip = strings.TrimSuffix(xffip, "]")
return xffip
} }
return ip return ip
} }
if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
ip = strings.TrimPrefix(ip, "[")
ip = strings.TrimSuffix(ip, "]")
return ip return ip
} }
ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) ra, _, _ := net.SplitHostPort(c.request.RemoteAddr)
@ -757,6 +770,16 @@ func (c *DefaultContext) Redirect(code int, url string) error {
return nil return nil
} }
// Error invokes the registered global HTTP error handler. Generally used by middleware.
// A side-effect of calling global error handler is that now Response has been committed (sent to the client) and
// middlewares up in chain can not change Response status code or Response body anymore.
//
// Avoid using this method in handlers as no middleware will be able to effectively handle errors after that.
// Instead of calling this method in handler return your error and let it be handled by middlewares or global error handler.
func (c *DefaultContext) Error(err error) {
c.echo.HTTPErrorHandler(c, err)
}
// Echo returns the `Echo` instance. // Echo returns the `Echo` instance.
func (c *DefaultContext) Echo() *Echo { func (c *DefaultContext) Echo() *Echo {
return c.echo return c.echo

View File

@ -377,6 +377,19 @@ func TestContext(t *testing.T) {
assert.Equal(t, 0, len(c.QueryParams())) assert.Equal(t, 0, len(c.QueryParams()))
} }
func TestContext_Error(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Error(errors.New("error"))
assert.True(t, c.Response().Committed)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, `{"message":"Internal Server Error"}`+"\n", rec.Body.String())
}
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -1061,6 +1074,30 @@ func TestContext_RealIP(t *testing.T) {
}, },
"127.0.0.1", "127.0.0.1",
}, },
{
&DefaultContext{
request: &http.Request{
Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}},
},
},
"2001:db8:85a3:8d3:1319:8a2e:370:7348",
},
{
&DefaultContext{
request: &http.Request{
Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}},
},
},
"2001:db8:85a3:8d3:1319:8a2e:370:7348",
},
{
&DefaultContext{
request: &http.Request{
Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}},
},
},
"2001:db8:85a3:8d3:1319:8a2e:370:7348",
},
{ {
&DefaultContext{ &DefaultContext{
request: &http.Request{ request: &http.Request{
@ -1071,6 +1108,17 @@ func TestContext_RealIP(t *testing.T) {
}, },
"192.168.0.1", "192.168.0.1",
}, },
{
&DefaultContext{
request: &http.Request{
Header: http.Header{
"X-Real-Ip": []string{"[2001:db8::1]"},
},
},
},
"2001:db8::1",
},
{ {
&DefaultContext{ &DefaultContext{
request: &http.Request{ request: &http.Request{

99
echo.go
View File

@ -3,36 +3,36 @@ Package echo implements high performance, minimalist Go web framework.
Example: Example:
package main package main
import ( import (
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware" "github.com/labstack/echo/v5/middleware"
"log" "log"
"net/http" "net/http"
) )
// Handler // Handler
func hello(c echo.Context) error { func hello(c echo.Context) error {
return c.String(http.StatusOK, "Hello, World!") return c.String(http.StatusOK, "Hello, World!")
} }
func main() { func main() {
// Echo instance // Echo instance
e := echo.New() e := echo.New()
// Middleware // Middleware
e.Use(middleware.Logger()) e.Use(middleware.Logger())
e.Use(middleware.Recover()) e.Use(middleware.Recover())
// Routes // Routes
e.GET("/", hello) e.GET("/", hello)
// Start server // Start server
if err := e.Start(":8080"); err != http.ErrServerClosed { if err := e.Start(":8080"); err != http.ErrServerClosed {
log.Fatal(err) log.Fatal(err)
}
} }
}
Learn more at https://echo.labstack.com Learn more at https://echo.labstack.com
*/ */
@ -49,7 +49,6 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"sync" "sync"
) )
@ -420,8 +419,11 @@ func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) Ro
return e.Add(RouteNotFound, path, h, m...) return e.Add(RouteNotFound, path, h, m...)
} }
// Any registers a new route for all supported HTTP methods and path with matching handler // Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler
// in the router with optional route-level middleware. Panics on error. // in the router with optional route-level middleware.
//
// Note: this method only adds specific set of supported HTTP methods as handler and is not true
// "catch-any-arbitrary-method" way of matching requests.
func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
errs := make([]error, 0) errs := make([]error, 0)
ris := make(Routes, 0) ris := make(Routes, 0)
@ -515,7 +517,7 @@ func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) Handle
p = c.Request().URL.Path // path must not be empty. p = c.Request().URL.Path // path must not be empty.
if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
// Redirect to ends with "/" // Redirect to ends with "/"
return c.Redirect(http.StatusMovedPermanently, p+"/") return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/"))
} }
return fsFile(c, name, fileSystem) return fsFile(c, name, fileSystem)
} }
@ -625,7 +627,7 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c = e.contextPool.Get().(*DefaultContext) c = e.contextPool.Get().(*DefaultContext)
} }
c.Reset(r, w) c.Reset(r, w)
var h func(Context) error var h HandlerFunc
if e.premiddleware == nil { if e.premiddleware == nil {
h = applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) h = applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...)
@ -654,12 +656,15 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// options. // options.
// //
// In need of customization use: // In need of customization use:
// sc := echo.StartConfig{Address: ":8080"} //
// sc := echo.StartConfig{Address: ":8080"}
// if err := sc.Start(e); err != http.ErrServerClosed { // if err := sc.Start(e); err != http.ErrServerClosed {
// log.Fatal(err) // log.Fatal(err)
// } // }
//
// // or standard library `http.Server` // // or standard library `http.Server`
// s := http.Server{Addr: ":8080", Handler: e} //
// s := http.Server{Addr: ":8080", Handler: e}
// if err := s.ListenAndServe(); err != http.ErrServerClosed { // if err := s.ListenAndServe(); err != http.ErrServerClosed {
// log.Fatal(err) // log.Fatal(err)
// } // }
@ -741,7 +746,7 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) {
// we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
// fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
// were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
if isRelativePath(root) { if !filepath.IsAbs(root) {
root = filepath.Join(dFS.prefix, root) root = filepath.Join(dFS.prefix, root)
} }
return &defaultFS{ return &defaultFS{
@ -752,21 +757,6 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) {
return fs.Sub(currentFs, root) return fs.Sub(currentFs, root)
} }
func isRelativePath(path string) bool {
if path == "" {
return true
}
if path[0] == '/' {
return false
}
if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 {
// https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names
// https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats
return false
}
return true
}
// MustSubFS creates sub FS from current filesystem or panic on failure. // MustSubFS creates sub FS from current filesystem or panic on failure.
// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
// //
@ -780,3 +770,12 @@ func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
} }
return subFs return subFs
} }
func sanitizeURI(uri string) string {
// double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
// we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
uri = "/" + strings.TrimLeft(uri, `/\`)
}
return uri
}

View File

@ -187,6 +187,15 @@ func TestEcho_StaticFS(t *testing.T) {
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
}, },
{
name: "open redirect vulnerability",
givenPrefix: "/",
givenFs: os.DirFS("_fixture/"),
whenURL: "/open.redirect.hackercom%2f..",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad
expectBodyStartsWith: "",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -1163,7 +1172,7 @@ func TestEcho_customContext(t *testing.T) {
func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
e := New() e := New()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
u := req.URL u := req.URL
w := httptest.NewRecorder() w := httptest.NewRecorder()

14
go.mod
View File

@ -3,17 +3,17 @@ module github.com/labstack/echo/v5
go 1.17 go 1.17
require ( require (
github.com/golang-jwt/jwt/v4 v4.2.0 github.com/golang-jwt/jwt/v4 v4.4.3
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.8.1
github.com/valyala/fasttemplate v1.2.1 github.com/valyala/fasttemplate v1.2.2
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 golang.org/x/net v0.2.0
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 golang.org/x/time v0.3.0
) )
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/text v0.3.3 // indirect golang.org/x/text v0.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

55
go.sum
View File

@ -1,31 +1,54 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU=
github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -3,7 +3,6 @@ package echo
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io/fs" "io/fs"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -71,7 +70,7 @@ func TestGroupFile(t *testing.T) {
e := New() e := New()
g := e.Group("/group") g := e.Group("/group")
g.File("/walle", "_fixture/images/walle.png") g.File("/walle", "_fixture/images/walle.png")
expectedData, err := ioutil.ReadFile("_fixture/images/walle.png") expectedData, err := os.ReadFile("_fixture/images/walle.png")
assert.Nil(t, err) assert.Nil(t, err)
req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) req := httptest.NewRequest(http.MethodGet, "/group/walle", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()

7
ip.go
View File

@ -227,6 +227,8 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
return func(req *http.Request) string { return func(req *http.Request) string {
realIP := req.Header.Get(HeaderXRealIP) realIP := req.Header.Get(HeaderXRealIP)
if realIP != "" { if realIP != "" {
realIP = strings.TrimPrefix(realIP, "[")
realIP = strings.TrimSuffix(realIP, "]")
if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) {
return realIP return realIP
} }
@ -248,7 +250,10 @@ func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
} }
ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP) ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP)
for i := len(ips) - 1; i >= 0; i-- { for i := len(ips) - 1; i >= 0; i-- {
ip := net.ParseIP(strings.TrimSpace(ips[i])) ips[i] = strings.TrimSpace(ips[i])
ips[i] = strings.TrimPrefix(ips[i], "[")
ips[i] = strings.TrimSuffix(ips[i], "]")
ip := net.ParseIP(ips[i])
if ip == nil { if ip == nil {
// Unable to parse IP; cannot trust entire records // Unable to parse IP; cannot trust entire records
return directIP return directIP

View File

@ -459,6 +459,7 @@ func TestExtractIPDirect(t *testing.T) {
func TestExtractIPFromRealIPHeader(t *testing.T) { func TestExtractIPFromRealIPHeader(t *testing.T) {
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
_, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
var testCases = []struct { var testCases = []struct {
name string name string
@ -493,6 +494,16 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
}, },
expectIP: "203.0.113.1", expectIP: "203.0.113.1",
}, },
{
name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
expectIP: "2001:db8::113:1",
},
{ {
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
@ -506,6 +517,19 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
}, },
expectIP: "203.0.113.199", expectIP: "203.0.113.199",
}, },
{
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"[2001:db8::113:199]"},
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
expectIP: "2001:db8::113:199",
},
{ {
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
@ -520,6 +544,20 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
}, },
expectIP: "203.0.113.199", expectIP: "203.0.113.199",
}, },
{
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"[2001:db8::113:199]"},
HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
expectIP: "2001:db8::113:199",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -532,6 +570,7 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
func TestExtractIPFromXFFHeader(t *testing.T) { func TestExtractIPFromXFFHeader(t *testing.T) {
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
_, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
var testCases = []struct { var testCases = []struct {
name string name string
@ -566,6 +605,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) {
}, },
expectIP: "127.0.0.3", expectIP: "127.0.0.3",
}, },
{
name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"},
},
RemoteAddr: "[fe80::1]:8080",
},
expectIP: "fe80::3",
},
{ {
name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
whenRequest: http.Request{ whenRequest: http.Request{
@ -576,6 +625,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) {
}, },
expectIP: "203.0.113.1", expectIP: "203.0.113.1",
}, },
{
name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted
},
RemoteAddr: "[2001:db8::2]:8080",
},
expectIP: "2001:db8::2",
},
{ {
name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
givenTrustOptions: []TrustOption{ givenTrustOptions: []TrustOption{
@ -595,6 +654,25 @@ func TestExtractIPFromXFFHeader(t *testing.T) {
}, },
expectIP: "203.0.100.100", // this is first trusted IP in XFF chain expectIP: "203.0.100.100", // this is first trusted IP in XFF chain
}, },
{
name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
givenTrustOptions: []TrustOption{
TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
},
// from request its seems that request has been proxied through 6 servers.
// 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed)
// 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs)
// 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office)
// 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products)
// 5) fe80::1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing)
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"},
},
RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP
},
expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain
},
} }
for _, tc := range testCases { for _, tc := range testCases {

View File

@ -1,7 +1,7 @@
package echo package echo
import ( import (
testify "github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -16,16 +16,14 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*DefaultContext) c := e.NewContext(req, rec).(*DefaultContext)
assert := testify.New(t)
// Echo // Echo
assert.Equal(e, c.Echo()) assert.Equal(t, e, c.Echo())
// Request // Request
assert.NotNil(c.Request()) assert.NotNil(t, c.Request())
// Response // Response
assert.NotNil(c.Response()) assert.NotNil(t, c.Response())
//-------- //--------
// Default JSON encoder // Default JSON encoder
@ -34,16 +32,16 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
enc := new(DefaultJSONSerializer) enc := new(DefaultJSONSerializer)
err := enc.Serialize(c, user{1, "Jon Snow"}, "") err := enc.Serialize(c, user{1, "Jon Snow"}, "")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(userJSON+"\n", rec.Body.String()) assert.Equal(t, userJSON+"\n", rec.Body.String())
} }
req = httptest.NewRequest(http.MethodPost, "/", nil) req = httptest.NewRequest(http.MethodPost, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*DefaultContext) c = e.NewContext(req, rec).(*DefaultContext)
err = enc.Serialize(c, user{1, "Jon Snow"}, " ") err = enc.Serialize(c, user{1, "Jon Snow"}, " ")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(userJSONPretty+"\n", rec.Body.String()) assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
} }
} }
@ -55,16 +53,14 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*DefaultContext) c := e.NewContext(req, rec).(*DefaultContext)
assert := testify.New(t)
// Echo // Echo
assert.Equal(e, c.Echo()) assert.Equal(t, e, c.Echo())
// Request // Request
assert.NotNil(c.Request()) assert.NotNil(t, c.Request())
// Response // Response
assert.NotNil(c.Response()) assert.NotNil(t, c.Response())
//-------- //--------
// Default JSON encoder // Default JSON encoder
@ -74,8 +70,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
var u = user{} var u = user{}
err := enc.Deserialize(c, &u) err := enc.Deserialize(c, &u)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(u, user{ID: 1, Name: "Jon Snow"}) assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"})
} }
var userUnmarshalSyntaxError = user{} var userUnmarshalSyntaxError = user{}
@ -83,8 +79,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*DefaultContext) c = e.NewContext(req, rec).(*DefaultContext)
err = enc.Deserialize(c, &userUnmarshalSyntaxError) err = enc.Deserialize(c, &userUnmarshalSyntaxError)
assert.IsType(&HTTPError{}, err) assert.IsType(t, &HTTPError{}, err)
assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value")
var userUnmarshalTypeError = struct { var userUnmarshalTypeError = struct {
ID string `json:"id"` ID string `json:"id"`
@ -95,7 +91,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*DefaultContext) c = e.NewContext(req, rec).(*DefaultContext)
err = enc.Deserialize(c, &userUnmarshalTypeError) err = enc.Deserialize(c, &userUnmarshalTypeError)
assert.IsType(&HTTPError{}, err) assert.IsType(t, &HTTPError{}, err)
assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string")
} }

View File

@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
@ -62,9 +61,9 @@ func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Request // Request
reqBody := []byte{} reqBody := []byte{}
if c.Request().Body != nil { if c.Request().Body != nil {
reqBody, _ = ioutil.ReadAll(c.Request().Body) reqBody, _ = io.ReadAll(c.Request().Body)
} }
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset
// Response // Response
resBody := new(bytes.Buffer) resBody := new(bytes.Buffer)

View File

@ -2,7 +2,7 @@ package middleware
import ( import (
"errors" "errors"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -19,7 +19,7 @@ func TestBodyDump(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := func(c echo.Context) error { h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body) body, err := io.ReadAll(c.Request().Body)
if err != nil { if err != nil {
return err return err
} }

View File

@ -2,7 +2,7 @@ package middleware
import ( import (
"bytes" "bytes"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -18,7 +18,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := func(c echo.Context) error { h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body) body, err := io.ReadAll(c.Request().Body)
if err != nil { if err != nil {
return err return err
} }
@ -77,18 +77,18 @@ func TestBodyLimitReader(t *testing.T) {
} }
reader := &limitedReader{ reader := &limitedReader{
BodyLimitConfig: config, BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)), reader: io.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec), context: e.NewContext(req, rec),
} }
// read all should return ErrStatusRequestEntityTooLarge // read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader) _, err := io.ReadAll(reader)
he := err.(*echo.HTTPError) he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed // reset reader and read two bytes must succeed
bt := make([]byte, 2) bt := make([]byte, 2)
reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw))) reader.Reset(e.NewContext(req, rec), io.NopCloser(bytes.NewReader(hw)))
n, err := reader.Read(bt) n, err := reader.Read(bt)
assert.Equal(t, 2, n) assert.Equal(t, 2, n)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
@ -97,7 +97,7 @@ func TestBodyLimitReader(t *testing.T) {
func TestBodyLimit_skipper(t *testing.T) { func TestBodyLimit_skipper(t *testing.T) {
e := echo.New() e := echo.New()
h := func(c echo.Context) error { h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body) body, err := io.ReadAll(c.Request().Body)
if err != nil { if err != nil {
return err return err
} }
@ -129,7 +129,7 @@ func TestBodyLimitWithConfig(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := func(c echo.Context) error { h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body) body, err := io.ReadAll(c.Request().Body)
if err != nil { if err != nil {
return err return err
} }
@ -151,7 +151,7 @@ func TestBodyLimit(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := func(c echo.Context) error { h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body) body, err := io.ReadAll(c.Request().Body)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,7 +5,6 @@ import (
"compress/gzip" "compress/gzip"
"errors" "errors"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -71,7 +70,7 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
i := pool.Get() i := pool.Get()
w, ok := i.(*gzip.Writer) w, ok := i.(*gzip.Writer)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) return echo.NewHTTPErrorWithInternal(http.StatusInternalServerError, i.(error))
} }
rw := res.Writer rw := res.Writer
w.Reset(rw) w.Reset(rw)
@ -85,7 +84,7 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// nothing is written to body or error is returned. // nothing is written to body or error is returned.
// See issue #424, #407. // See issue #424, #407.
res.Writer = rw res.Writer = rw
w.Reset(ioutil.Discard) w.Reset(io.Discard)
} }
w.Close() w.Close()
pool.Put(w) pool.Put(w)
@ -131,7 +130,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
func gzipCompressPool(config GzipConfig) sync.Pool { func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{ return sync.Pool{
New: func() interface{} { New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) w, err := gzip.NewWriterLevel(io.Discard, config.Level)
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,7 +3,6 @@ package middleware
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -203,7 +202,7 @@ func TestGzipWithStatic(t *testing.T) {
r, err := gzip.NewReader(rec.Body) r, err := gzip.NewReader(rec.Body)
if assert.NoError(t, err) { if assert.NoError(t, err) {
defer r.Close() defer r.Close()
want, err := ioutil.ReadFile("../_fixture/images/walle.png") want, err := os.ReadFile("../_fixture/images/walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.ReadFrom(r) buf.ReadFrom(r)

View File

@ -14,46 +14,85 @@ type CORSConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource. // AllowOrigins determines the value of the Access-Control-Allow-Origin
// response header. This header defines a list of origins that may access the
// resource. The wildcard characters '*' and '?' are supported and are
// converted to regex fragments '.*' and '.' accordingly.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
AllowOrigins []string AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the // AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If // origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is // an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored. // set, AllowOrigins is ignored.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional. // Optional.
AllowOriginFunc func(origin string) (bool, error) AllowOriginFunc func(origin string) (bool, error)
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods determines the value of the Access-Control-Allow-Methods
// This is used in response to a preflight request. // response header. This header specified the list of methods allowed when
// accessing the resource. This is used in response to a preflight request.
//
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
// If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value // If `allowMethods` is left empty, this middleware will fill for preflight
// request `Access-Control-Allow-Methods` header value
// from `Allow` header that echo.Router set into context. // from `Allow` header that echo.Router set into context.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
AllowMethods []string AllowMethods []string
// AllowHeaders defines a list of request headers that can be used when // AllowHeaders determines the value of the Access-Control-Allow-Headers
// making the actual request. This is in response to a preflight request. // response header. This header is used in response to a preflight request to
// indicate which HTTP headers can be used when making the actual request.
//
// Optional. Default value []string{}. // Optional. Default value []string{}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
AllowHeaders []string AllowHeaders []string
// AllowCredentials indicates whether or not the response to the request // AllowCredentials determines the value of the
// can be exposed when the credentials flag is true. When used as part of // Access-Control-Allow-Credentials response header. This header indicates
// a response to a preflight request, this indicates whether or not the // whether or not the response to the request can be exposed when the
// actual request can be made using credentials. // credentials mode (Request.credentials) is true. When used as part of a
// Optional. Default value false. // response to a preflight request, this indicates whether or not the actual
// request can be made using credentials. See also
// [MDN: Access-Control-Allow-Credentials].
//
// Optional. Default value false, in which case the header is not set.
//
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // See "Exploiting CORS misconfigurations for Bitcoins and bounties",
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
AllowCredentials bool AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to // ExposeHeaders determines the value of Access-Control-Expose-Headers, which
// access. // defines a list of headers that clients are allowed to access.
// Optional. Default value []string{}. //
// Optional. Default value []string{}, in which case the header is not set.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
ExposeHeaders []string ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request // MaxAge determines the value of the Access-Control-Max-Age response header.
// can be cached. // This header indicates how long (in seconds) the results of a preflight
// Optional. Default value 0. // request can be cached.
//
// Optional. Default value 0. The header is set only if MaxAge > 0.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
MaxAge int MaxAge int
} }
@ -65,13 +104,22 @@ var DefaultCORSConfig = CORSConfig{
} }
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS // See also [MDN: Cross-Origin Resource Sharing (CORS)].
//
// Security: Poorly configured CORS can compromise security because it allows
// relaxation of the browser's Same-Origin policy. See [Exploiting CORS
// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin
// resource sharing (CORS)] for more details.
//
// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors
func CORS() echo.MiddlewareFunc { func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig) return CORSWithConfig(DefaultCORSConfig)
} }
// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. // CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
// See: `CORS()`. // See: [CORS].
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config) return toMiddlewareOrPanic(config)
} }

View File

@ -63,6 +63,9 @@ type CSRFConfig struct {
// Indicates SameSite mode of the CSRF cookie. // Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode. // Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite CookieSameSite http.SameSite
// ErrorHandler defines a function which is executed for returning custom errors.
ErrorHandler func(c echo.Context, err error) error
} }
// ErrCSRFInvalid is returned when CSRF check fails // ErrCSRFInvalid is returned when CSRF check fails
@ -159,10 +162,17 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
lastTokenErr = ErrCSRFInvalid lastTokenErr = ErrCSRFInvalid
} }
} }
var finalErr error
if lastTokenErr != nil { if lastTokenErr != nil {
return lastTokenErr finalErr = lastTokenErr
} else if lastExtractorErr != nil { } else if lastExtractorErr != nil {
return echo.ErrBadRequest.WithInternal(lastExtractorErr) finalErr = echo.ErrBadRequest.WithInternal(lastExtractorErr)
}
if finalErr != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(c, finalErr)
}
return finalErr
} }
} }

View File

@ -392,3 +392,25 @@ func TestCSRFConfig_skipper(t *testing.T) {
}) })
} }
} }
func TestCSRFErrorHandling(t *testing.T) {
cfg := CSRFConfig{
ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
},
}
e := echo.New()
e.POST("/", func(c echo.Context) error {
return c.String(http.StatusNotImplemented, "should not end up here")
})
e.Use(CSRFWithConfig(cfg))
req := httptest.NewRequest(http.MethodPost, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
}

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"errors" "errors"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -35,7 +35,7 @@ func TestDecompress(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body) b, err := io.ReadAll(req.Body)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, body, string(b)) assert.Equal(t, body, string(b))
} }
@ -97,7 +97,7 @@ func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body) b, err := io.ReadAll(req.Body)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, body, string(b)) assert.Equal(t, body, string(b))
} }
@ -114,7 +114,7 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body) b, err := io.ReadAll(req.Body)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, b, body) assert.NotEqual(t, b, body)
assert.Equal(t, b, gz) assert.Equal(t, b, gz)
@ -171,7 +171,7 @@ func TestDecompressSkipper(t *testing.T) {
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body) reqBody, err := io.ReadAll(c.Request().Body)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, body, string(reqBody)) assert.Equal(t, body, string(reqBody))
} }
@ -202,7 +202,7 @@ func TestDecompressPoolError(t *testing.T) {
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body) reqBody, err := io.ReadAll(c.Request().Body)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, body, string(reqBody)) assert.Equal(t, body, string(reqBody))
assert.Equal(t, rec.Code, http.StatusInternalServerError) assert.Equal(t, rec.Code, http.StatusInternalServerError)

View File

@ -51,6 +51,26 @@ var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error) type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error)
// CreateExtractors creates ValuesExtractors from given lookups.
// Lookups is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key from the request.
// Possible values:
// - "header:<name>" or "header:<name>:<cut-prefix>"
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
// want to cut is `<auth-scheme> ` note the space at the end.
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
// - "query:<name>"
// - "param:<name>"
// - "form:<name>"
// - "cookie:<name>"
//
// Multiple sources example:
// - "header:Authorization,header:X-Api-Key"
func CreateExtractors(lookups string) ([]ValuesExtractor, error) {
return createExtractors(lookups)
}
func createExtractors(lookups string) ([]ValuesExtractor, error) { func createExtractors(lookups string) ([]ValuesExtractor, error) {
if lookups == "" { if lookups == "" {
return nil, nil return nil, nil

View File

@ -100,7 +100,7 @@ func TestCreateExtractors(t *testing.T) {
c.SetRawPathParams(&tc.givenPathParams) c.SetRawPathParams(&tc.givenPathParams)
} }
extractors, err := createExtractors(tc.whenLoopups) extractors, err := CreateExtractors(tc.whenLoopups)
if tc.expectCreateError != "" { if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError) assert.EqualError(t, err, tc.expectCreateError)
return return

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"strconv" "strconv"
@ -34,6 +35,7 @@ type LoggerConfig struct {
// - host // - host
// - method // - method
// - path // - path
// - route
// - protocol // - protocol
// - referer // - referer
// - user_agent // - user_agent
@ -46,6 +48,7 @@ type LoggerConfig struct {
// - header:<NAME> // - header:<NAME>
// - query:<NAME> // - query:<NAME>
// - form:<NAME> // - form:<NAME>
// - custom (see CustomTagFunc field)
// //
// Example "${remote_ip} ${status}" // Example "${remote_ip} ${status}"
// //
@ -55,6 +58,11 @@ type LoggerConfig struct {
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat. // Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string CustomTimeFormat string
// CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf.
// Make sure that outputted text creates valid JSON string with other logged tags.
// Optional.
CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error)
// Output is a writer where logs in JSON format are written. // Output is a writer where logs in JSON format are written.
// Optional. Default destination `echo.Logger.Infof()` // Optional. Default destination `echo.Logger.Infof()`
Output io.Writer Output io.Writer
@ -111,6 +119,11 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
start := time.Now() start := time.Now()
err := next(c) err := next(c)
if err = next(c); err != nil {
// When global error handler writes the error to the client the Response gets "committed". This state can be
// checked with `c.Response().Committed` field.
c.Error(err)
}
stop := time.Now() stop := time.Now()
buf := config.pool.Get().(*bytes.Buffer) buf := config.pool.Get().(*bytes.Buffer)
@ -119,20 +132,25 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
_, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "custom":
if config.CustomTagFunc == nil {
return 0, nil
}
return config.CustomTagFunc(c, buf)
case "time_unix": case "time_unix":
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) return buf.WriteString(strconv.FormatInt(stop.Unix(), 10))
case "time_unix_milli": case "time_unix_milli":
return buf.WriteString(strconv.FormatInt(time.Now().UnixMilli(), 10)) return buf.WriteString(strconv.FormatInt(stop.UnixMilli(), 10))
case "time_unix_micro": case "time_unix_micro":
return buf.WriteString(strconv.FormatInt(time.Now().UnixMicro(), 10)) return buf.WriteString(strconv.FormatInt(stop.UnixMicro(), 10))
case "time_unix_nano": case "time_unix_nano":
return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) return buf.WriteString(strconv.FormatInt(stop.UnixNano(), 10))
case "time_rfc3339": case "time_rfc3339":
return buf.WriteString(time.Now().Format(time.RFC3339)) return buf.WriteString(stop.Format(time.RFC3339))
case "time_rfc3339_nano": case "time_rfc3339_nano":
return buf.WriteString(time.Now().Format(time.RFC3339Nano)) return buf.WriteString(stop.Format(time.RFC3339Nano))
case "time_custom": case "time_custom":
return buf.WriteString(time.Now().Format(config.CustomTimeFormat)) return buf.WriteString(stop.Format(config.CustomTimeFormat))
case "id": case "id":
id := req.Header.Get(echo.HeaderXRequestID) id := req.Header.Get(echo.HeaderXRequestID)
if id == "" { if id == "" {
@ -153,6 +171,8 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
p = "/" p = "/"
} }
return buf.WriteString(p) return buf.WriteString(p)
case "route":
return buf.WriteString(c.Path())
case "protocol": case "protocol":
return buf.WriteString(req.Proto) return buf.WriteString(req.Proto)
case "referer": case "referer":
@ -162,7 +182,8 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
case "status": case "status":
status := res.Status status := res.Status
if err != nil { if err != nil {
if httpErr, ok := err.(*echo.HTTPError); ok { var httpErr *echo.HTTPError
if errors.As(err, &httpErr) {
status = httpErr.Code status = httpErr.Code
} }
} }

View File

@ -92,17 +92,17 @@ func TestLoggerTemplate(t *testing.T) {
e.Use(LoggerWithConfig(LoggerConfig{ e.Use(LoggerWithConfig(LoggerConfig{
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` +
`"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
`"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "route":"${route}", "referer":"${referer}",` +
`"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` +
`"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
Output: buf, Output: buf,
})) }))
e.GET("/", func(c echo.Context) error { e.GET("/users/:id", func(c echo.Context) error {
return c.String(http.StatusOK, "Header Logged") return c.String(http.StatusOK, "Header Logged")
}) })
req := httptest.NewRequest(http.MethodGet, "/?username=apagano-param&password=secret", nil) req := httptest.NewRequest(http.MethodGet, "/users/1?username=apagano-param&password=secret", nil)
req.RequestURI = "/" req.RequestURI = "/"
req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
req.Header.Add("Referer", "google.com") req.Header.Add("Referer", "google.com")
@ -127,7 +127,8 @@ func TestLoggerTemplate(t *testing.T) {
"hexvalue": false, "hexvalue": false,
"GET": true, "GET": true,
"127.0.0.1": true, "127.0.0.1": true,
"\"path\":\"/\"": true, "\"path\":\"/users/1\"": true,
"\"route\":\"/users/:id\"": true,
"\"uri\":\"/\"": true, "\"uri\":\"/\"": true,
"\"status\":200": true, "\"status\":200": true,
"\"bytes_in\":0": true, "\"bytes_in\":0": true,
@ -291,3 +292,25 @@ func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) {
buf.Reset() buf.Reset()
} }
} }
func TestLoggerCustomTagFunc(t *testing.T) {
e := echo.New()
buf := new(bytes.Buffer)
e.Use(LoggerWithConfig(LoggerConfig{
Format: `{"method":"${method}",${custom}}` + "\n",
CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) {
return buf.WriteString(`"tag":"my-value"`)
},
Output: buf,
}))
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "custom time stamp test")
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, `{"method":"GET","tag":"my-value"}`+"\n", buf.String())
}

View File

@ -69,7 +69,7 @@ type ProxyTarget struct {
type ProxyBalancer interface { type ProxyBalancer interface {
AddTarget(*ProxyTarget) bool AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool RemoveTarget(string) bool
Next(echo.Context) *ProxyTarget Next(echo.Context) (*ProxyTarget, error)
} }
type commonBalancer struct { type commonBalancer struct {
@ -174,21 +174,21 @@ func (b *commonBalancer) RemoveTarget(name string) bool {
} }
// Next randomly returns an upstream target. // Next randomly returns an upstream target.
func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) {
if b.random == nil { if b.random == nil {
b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
} }
b.mutex.RLock() b.mutex.RLock()
defer b.mutex.RUnlock() defer b.mutex.RUnlock()
return b.targets[b.random.Intn(len(b.targets))] return b.targets[b.random.Intn(len(b.targets))], nil
} }
// Next returns an upstream target using round-robin technique. // Next returns an upstream target using round-robin technique.
func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) {
b.i = b.i % uint32(len(b.targets)) b.i = b.i % uint32(len(b.targets))
t := b.targets[b.i] t := b.targets[b.i]
atomic.AddUint32(&b.i, 1) atomic.AddUint32(&b.i, 1)
return t return t, nil
} }
// Proxy returns a Proxy middleware. // Proxy returns a Proxy middleware.
@ -236,7 +236,10 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
tgt := config.Balancer.Next(c) tgt, err := config.Balancer.Next(c)
if err != nil {
return err
}
c.Set(config.ContextKey, tgt) c.Set(config.ContextKey, tgt)
if err := rewriteURL(config.RegexRewrite, req); err != nil { if err := rewriteURL(config.RegexRewrite, req); err != nil {

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -18,7 +18,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
//Assert expected with url.EscapedPath method to obtain the path. // Assert expected with url.EscapedPath method to obtain the path.
func TestProxy(t *testing.T) { func TestProxy(t *testing.T) {
// Setup // Setup
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -94,7 +94,7 @@ func TestProxy(t *testing.T) {
e.Use(ProxyWithConfig(ProxyConfig{ e.Use(ProxyWithConfig(ProxyConfig{
Balancer: rrb, Balancer: rrb,
ModifyResponse: func(res *http.Response) error { ModifyResponse: func(res *http.Response) error {
res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified"))) res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified")))
res.Header.Set("X-Modified", "1") res.Header.Set("X-Modified", "1")
return nil return nil
}, },
@ -379,3 +379,48 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
timeoutStop.Done() timeoutStop.Done()
assert.Equal(t, 499, rec.Code) assert.Equal(t, 499, rec.Code)
} }
type testProvider struct {
*commonBalancer
target *ProxyTarget
err error
}
func (p *testProvider) Next(c echo.Context) (*ProxyTarget, error) {
return p.target, p.err
}
func TestTargetProvider(t *testing.T) {
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "target 1")
}))
defer t1.Close()
url1, _ := url.Parse(t1.URL)
e := echo.New()
tp := &testProvider{commonBalancer: new(commonBalancer)}
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
e.Use(Proxy(tp))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
e.ServeHTTP(rec, req)
body := rec.Body.String()
assert.Equal(t, "target 1", body)
}
func TestFailNextTarget(t *testing.T) {
url1, err := url.Parse("http://dummy:8080")
assert.Nil(t, err)
e := echo.New()
tp := &testProvider{commonBalancer: new(commonBalancer)}
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
e.Use(Proxy(tp))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
e.ServeHTTP(rec, req)
body := rec.Body.String()
assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
}

View File

@ -9,10 +9,16 @@ import (
// Example for `fmt.Printf` // Example for `fmt.Printf`
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogStatus: true, // LogStatus: true,
// LogURI: true, // LogURI: true,
// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) // if v.Error == nil {
// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status)
// } else {
// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error)
// }
// return nil // return nil
// }, // },
// })) // }))
@ -20,15 +26,23 @@ import (
// Example for Zerolog (https://github.com/rs/zerolog) // Example for Zerolog (https://github.com/rs/zerolog)
// logger := zerolog.New(os.Stdout) // logger := zerolog.New(os.Stdout)
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true, // LogURI: true,
// LogStatus: true, // LogStatus: true,
// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info(). // if v.Error == nil {
// Date("request_start", v.StartTime). // logger.Info().
// Str("URI", v.URI). // Str("URI", v.URI).
// Int("status", v.Status). // Int("status", v.Status).
// Msg("request") // Msg("request")
// // } else {
// logger.Error().
// Err(v.Error).
// Str("URI", v.URI).
// Int("status", v.Status).
// Msg("request error")
// }
// return nil // return nil
// }, // },
// })) // }))
@ -36,31 +50,47 @@ import (
// Example for Zap (https://github.com/uber-go/zap) // Example for Zap (https://github.com/uber-go/zap)
// logger, _ := zap.NewProduction() // logger, _ := zap.NewProduction()
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true, // LogURI: true,
// LogStatus: true, // LogStatus: true,
// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info("request", // if v.Error == nil {
// zap.Time("request_start", v.StartTime), // logger.Info("request",
// zap.String("URI", v.URI), // zap.String("URI", v.URI),
// zap.Int("status", v.Status), // zap.Int("status", v.Status),
// ) // )
// // } else {
// logger.Error("request error",
// zap.String("URI", v.URI),
// zap.Int("status", v.Status),
// zap.Error(v.Error),
// )
// }
// return nil // return nil
// }, // },
// })) // }))
// //
// Example for Logrus (https://github.com/sirupsen/logrus) // Example for Logrus (https://github.com/sirupsen/logrus)
// log := logrus.New() // log := logrus.New()
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true, // LogURI: true,
// LogStatus: true, // LogStatus: true,
// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { // LogError: true,
// log.WithFields(logrus.Fields{ // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// "request_start": values.StartTime, // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// "URI": values.URI, // if v.Error == nil {
// "status": values.Status, // log.WithFields(logrus.Fields{
// }).Info("request") // "URI": v.URI,
// // "status": v.Status,
// }).Info("request")
// } else {
// log.WithFields(logrus.Fields{
// "URI": v.URI,
// "status": v.Status,
// "error": v.Error,
// }).Error("request error")
// }
// return nil // return nil
// }, // },
// })) // }))
@ -76,6 +106,13 @@ type RequestLoggerConfig struct {
// Mandatory. // Mandatory.
LogValuesFunc func(c echo.Context, v RequestLoggerValues) error LogValuesFunc func(c echo.Context, v RequestLoggerValues) error
// HandleError instructs logger to call global error handler when next middleware/handler returns an error.
// This is useful when you have custom error handler that can decide to use different status codes.
//
// A side-effect of calling global error handler is that now Response has been committed and sent to the client
// and middlewares up in chain can not change Response status code or response body.
HandleError bool
// LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call). // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call).
LogLatency bool LogLatency bool
// LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`) // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`)
@ -219,6 +256,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
config.BeforeNextFunc(c) config.BeforeNextFunc(c)
} }
err := next(c) err := next(c)
if config.HandleError {
// When global error handler writes the error to the client the Response gets "committed". This state can be
// checked with `c.Response().Committed` field.
c.Error(err)
}
v := RequestLoggerValues{ v := RequestLoggerValues{
StartTime: start, StartTime: start,
@ -266,8 +308,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
} }
if config.LogStatus { if config.LogStatus {
v.Status = res.Status v.Status = res.Status
if err != nil { if err != nil && !config.HandleError {
if httpErr, ok := err.(*echo.HTTPError); ok { // this block should not be executed in case of HandleError=true as the global error handler will decide
// the status code. In that case status code could be different from what err contains.
var httpErr *echo.HTTPError
if errors.As(err, &httpErr) {
v.Status = httpErr.Code v.Status = httpErr.Code
} }
} }
@ -310,7 +355,10 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil {
return errOnLog return errOnLog
} }
// in case of HandleError=true we are returning the error that we already have handled with global error handler
// this is deliberate as this error could be useful for upstream middlewares and default global error handler
// will ignore that error when it bubbles up in middleware chain.
// Committed response can be checked in custom error handler with `c.Response().Committed` field
return err return err
} }
}, nil }, nil

View File

@ -103,12 +103,12 @@ func TestRequestLogger_beforeNextFunc(t *testing.T) {
func TestRequestLogger_logError(t *testing.T) { func TestRequestLogger_logError(t *testing.T) {
e := echo.New() e := echo.New()
var expect RequestLoggerValues var actual RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
LogError: true, LogError: true,
LogStatus: true, LogStatus: true,
LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
expect = values actual = values
return nil return nil
}, },
})) }))
@ -123,8 +123,52 @@ func TestRequestLogger_logError(t *testing.T) {
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotAcceptable, rec.Code) assert.Equal(t, http.StatusNotAcceptable, rec.Code)
assert.Equal(t, http.StatusNotAcceptable, expect.Status) assert.Equal(t, http.StatusNotAcceptable, actual.Status)
assert.EqualError(t, expect.Error, "code=406, message=nope") assert.EqualError(t, actual.Error, "code=406, message=nope")
}
func TestRequestLogger_HandleError(t *testing.T) {
e := echo.New()
var actual RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
timeNow: func() time.Time {
return time.Unix(1631045377, 0).UTC()
},
HandleError: true,
LogError: true,
LogStatus: true,
LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
actual = values
return nil
},
}))
// to see if "HandleError" works we create custom error handler that uses its own status codes
e.HTTPErrorHandler = func(c echo.Context, err error) {
if c.Response().Committed {
return
}
c.JSON(http.StatusTeapot, "custom error handler")
}
e.GET("/test", func(c echo.Context) error {
return echo.NewHTTPError(http.StatusForbidden, "nope")
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTeapot, rec.Code)
expect := RequestLoggerValues{
StartTime: time.Unix(1631045377, 0).UTC(),
Status: http.StatusTeapot,
Error: echo.NewHTTPError(http.StatusForbidden, "nope"),
}
assert.Equal(t, expect, actual)
} }
func TestRequestLogger_LogValuesFuncError(t *testing.T) { func TestRequestLogger_LogValuesFuncError(t *testing.T) {

View File

@ -1,7 +1,7 @@
package middleware package middleware
import ( import (
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -195,7 +195,7 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
defer rec.Result().Body.Close() defer rec.Result().Body.Close()
bodyBytes, _ := ioutil.ReadAll(rec.Result().Body) bodyBytes, _ := io.ReadAll(rec.Result().Body)
assert.Equal(t, "hosts", string(bodyBytes)) assert.Equal(t, "hosts", string(bodyBytes))
} }
} }

View File

@ -27,7 +27,7 @@ func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(AddTrailingSlashConfig{}) return AddTrailingSlashWithConfig(AddTrailingSlashConfig{})
} }
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config or panics on invalid configuration. // AddTrailingSlashWithConfig returns an AddTrailingSlash middleware with config or panics on invalid configuration.
func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc { func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config) return toMiddlewareOrPanic(config)
} }

View File

@ -216,8 +216,8 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
return nil return nil
} }
he, ok := err.(*echo.HTTPError) var he *echo.HTTPError
if !(ok && config.HTML5 && he.Code == http.StatusNotFound) { if !(errors.As(err, &he) && config.HTML5 && he.Code == http.StatusNotFound) {
return err return err
} }
// is case HTML5 mode is enabled + echo 404 we serve index to the client // is case HTML5 mode is enabled + echo 404 we serve index to the client

View File

@ -257,6 +257,15 @@ func TestStatic_GroupWithStatic(t *testing.T) {
expectHeaderLocation: "/group/folder/", expectHeaderLocation: "/group/folder/",
expectBodyStartsWith: "", expectBodyStartsWith: "",
}, },
{
name: "Directory redirect",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/group/folder%2f..",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/group/folder/../",
expectBodyStartsWith: "",
},
{ {
name: "Prefixed directory 404 (request URL without slash)", name: "Prefixed directory 404 (request URL without slash)",
givenGroup: "_fixture", givenGroup: "_fixture",

View File

@ -10,13 +10,13 @@ import (
// Router is interface for routing request contexts to registered routes. // Router is interface for routing request contexts to registered routes.
// //
// Contract between Echo/Context instance and the router: // Contract between Echo/Context instance and the router:
// * all routes must be added through methods on echo.Echo instance. // - all routes must be added through methods on echo.Echo instance.
// Reason: Echo instance uses RouteInfo.Params() length to allocate slice for paths parameters (see `Echo.contextPathParamAllocSize`). // Reason: Echo instance uses RouteInfo.Params() length to allocate slice for paths parameters (see `Echo.contextPathParamAllocSize`).
// * Router must populate Context during Router.Route call with: // - Router must populate Context during Router.Route call with:
// * RoutableContext.SetPath // - RoutableContext.SetPath
// * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns) // - RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns)
// * RoutableContext.SetRouteInfo // - RoutableContext.SetRouteInfo
// And optionally can set additional information to Context with RoutableContext.Set // And optionally can set additional information to Context with RoutableContext.Set
type Router interface { type Router interface {
// Add registers Routable with the Router and returns registered RouteInfo // Add registers Routable with the Router and returns registered RouteInfo
Add(routable Routable) (RouteInfo, error) Add(routable Routable) (RouteInfo, error)
@ -344,7 +344,7 @@ func (m *routeMethods) updateAllowHeader() {
if m.report != nil { if m.report != nil {
buf.WriteString(", REPORT") buf.WriteString(", REPORT")
} }
for method := range m.anyOther { for method := range m.anyOther { // for simplicity, we use map and therefore order is not deterministic here
buf.WriteString(", ") buf.WriteString(", ")
buf.WriteString(method) buf.WriteString(method)
} }

View File

@ -669,7 +669,7 @@ func checkUnusedParamValues(t *testing.T, c *DefaultContext, expectParam map[str
func TestRouterStatic(t *testing.T) { func TestRouterStatic(t *testing.T) {
path := "/folders/a/files/echo.gif" path := "/folders/a/files/echo.gif"
req := httptest.NewRequest("GET", path, nil) req := httptest.NewRequest(http.MethodGet, path, nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e := New() e := New()
@ -711,7 +711,7 @@ func TestRouterParam(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
c := e.NewContext(nil, nil).(*DefaultContext) c := e.NewContext(nil, nil).(*DefaultContext)
c.SetRequest(httptest.NewRequest("GET", tc.whenURL, nil)) c.SetRequest(httptest.NewRequest(http.MethodGet, tc.whenURL, nil))
_ = e.router.Route(c) _ = e.router.Route(c)
assert.Equal(t, tc.expectRoute, c.Path()) assert.Equal(t, tc.expectRoute, c.Path())
@ -725,8 +725,11 @@ func TestRouterParam(t *testing.T) {
func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) { func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) {
var testCases = []struct { var testCases = []struct {
name string name string
whenMethod string givenNoAddRoute bool
whenMethod string
expectPath string
expectError string
}{ }{
{name: "ok, CONNECT", whenMethod: http.MethodConnect}, {name: "ok, CONNECT", whenMethod: http.MethodConnect},
{name: "ok, DELETE", whenMethod: http.MethodDelete}, {name: "ok, DELETE", whenMethod: http.MethodDelete},
@ -740,6 +743,13 @@ func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) {
{name: "ok, TRACE", whenMethod: http.MethodTrace}, {name: "ok, TRACE", whenMethod: http.MethodTrace},
{name: "ok, REPORT", whenMethod: REPORT}, {name: "ok, REPORT", whenMethod: REPORT},
{name: "ok, NON_TRADITIONAL_METHOD", whenMethod: "NON_TRADITIONAL_METHOD"}, {name: "ok, NON_TRADITIONAL_METHOD", whenMethod: "NON_TRADITIONAL_METHOD"},
{
name: "ok, NOT_EXISTING_METHOD",
whenMethod: "NOT_EXISTING_METHOD",
givenNoAddRoute: true,
expectPath: "/*",
expectError: "code=405, message=Method Not Allowed",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -747,7 +757,9 @@ func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) {
e := New() e := New()
e.GET("/*", handlerFunc) e.GET("/*", handlerFunc)
e.Add(tc.whenMethod, "/my/*", handlerFunc) if !tc.givenNoAddRoute {
e.Add(tc.whenMethod, "/my/*", handlerFunc)
}
req := httptest.NewRequest(tc.whenMethod, "/my/some-url", nil) req := httptest.NewRequest(tc.whenMethod, "/my/some-url", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -756,12 +768,45 @@ func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) {
handler := e.router.Route(c) handler := e.router.Route(c)
err := handler(c) err := handler(c)
assert.NoError(t, err) if tc.expectError != "" {
assert.Equal(t, "/my/*", c.Path()) assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
expectPath := "/my/*"
if tc.expectPath != "" {
expectPath = tc.expectPath
}
assert.Equal(t, expectPath, c.Path())
}) })
} }
} }
func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) {
e := New()
r := e.router
_, err := r.Add(Route{Method: http.MethodGet, Path: "/users", Handler: handlerFunc})
assert.NoError(t, err)
_, err = r.Add(Route{Method: "COPY", Path: "/users", Handler: handlerFunc})
assert.NoError(t, err)
_, err = r.Add(Route{Method: "LOCK", Path: "/users", Handler: handlerFunc})
assert.NoError(t, err)
req := httptest.NewRequest("TEST", "/users", nil)
rec := httptest.NewRecorder()
//r.Find("TEST", "/users", c)
c := e.NewContext(req, rec).(*DefaultContext)
handler := e.router.Route(c)
err = handler(c)
assert.EqualError(t, err, "code=405, message=Method Not Allowed")
assert.ElementsMatch(t, []string{"COPY", "GET", "LOCK", "OPTIONS"}, strings.Split(c.Response().Header().Get(HeaderAllow), ", "))
}
func TestMethodNotAllowedAndNotFound(t *testing.T) { func TestMethodNotAllowedAndNotFound(t *testing.T) {
e := New() e := New()
@ -970,19 +1015,22 @@ func TestRouterParamWithSlash(t *testing.T) {
// Searching route for "/a/c/f" should match "/a/*/f" // Searching route for "/a/c/f" should match "/a/*/f"
// When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f" // When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f"
// //
// +----------+ // +----------+
// +-----+ "/" root +--------------------+--------------------------+ // +-----+ "/" root +--------------------+--------------------------+
// | +----------+ | | // | +----------+ | |
// | | | // | | |
// +-------v-------+ +---v---------+ +-------v---+ // +-------v-------+ +---v---------+ +-------v---+
// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | // | "a/" (static) +---------------+ | ":" (param) | | "*" (any) |
// +-+----------+--+ | +-----------+-+ +-----------+ // +-+----------+--+ | +-----------+-+ +-----------+
// | | | | // | | | |
//
// +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ // +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+
// | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | // | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) |
// +---------+------+ +--------+----+ +----------++ +-----------------+ // +---------+------+ +--------+----+ +----------++ +-----------------+
// | | | //
// | | | // | | |
// | | |
//
// +---------v----+ +------v--------+ +------v--------+ // +---------v----+ +------v--------+ +------v--------+
// | "f" (static) | | "/c" (static) | | "/f" (static) | // | "f" (static) | | "/c" (static) | | "/f" (static) |
// +--------------+ +---------------+ +---------------+ // +--------------+ +---------------+ +---------------+
@ -1052,22 +1100,22 @@ func TestRouteMultiLevelBacktracking(t *testing.T) {
// //
// Request for "/a/c/f" should match "/:e/c/f" // Request for "/a/c/f" should match "/:e/c/f"
// //
// +-0,7--------+ // +-0,7--------+
// | "/" (root) |----------------------------------+ // | "/" (root) |----------------------------------+
// +------------+ | // +------------+ |
// | | | // | | |
// | | | // | | |
// +-1,6-----------+ | | +-8-----------+ +------v----+ // +-1,6-----------+ | | +-8-----------+ +------v----+
// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | // | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) |
// +---------------+ +-------------+ +-----------+ // +---------------+ +-------------+ +-----------+
// | | | // | | |
// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ // +-2--------v-----+ +v-3,5--------+ +-9------v--------+
// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | // | "c/d" (static) | | ":" (param) | | "/c/f" (static) |
// +----------------+ +-------------+ +-----------------+ // +----------------+ +-------------+ +-----------------+
// | // |
// +-4--v----------+ // +-4--v----------+
// | "/c" (static) | // | "/c" (static) |
// +---------------+ // +---------------+
func TestRouteMultiLevelBacktracking2(t *testing.T) { func TestRouteMultiLevelBacktracking2(t *testing.T) {
e := New() e := New()
@ -2753,7 +2801,7 @@ func TestRouter_Routes(t *testing.T) {
func benchmarkRouterRoutes(b *testing.B, routes []testRoute, routesToFind []testRoute) { func benchmarkRouterRoutes(b *testing.B, routes []testRoute, routesToFind []testRoute) {
e := New() e := New()
r := e.router r := e.router
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
b.ReportAllocs() b.ReportAllocs()
// Add routes // Add routes

View File

@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -67,7 +66,7 @@ func doGet(url string) (int, string, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return resp.StatusCode, "", err return resp.StatusCode, "", err
} }
@ -427,9 +426,9 @@ func TestStartConfig_StartTLSAndStart(t *testing.T) {
} }
func TestFilepathOrContent(t *testing.T) { func TestFilepathOrContent(t *testing.T) {
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") cert, err := os.ReadFile("_fixture/certs/cert.pem")
require.NoError(t, err) require.NoError(t, err)
key, err := ioutil.ReadFile("_fixture/certs/key.pem") key, err := os.ReadFile("_fixture/certs/key.pem")
require.NoError(t, err) require.NoError(t, err)
testCases := []struct { testCases := []struct {
@ -796,7 +795,7 @@ func TestWithDisableHTTP2(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
log.Fatalf("Failed reading response body: %s", err) log.Fatalf("Failed reading response body: %s", err)
} }