diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index af6a5a05..e35e7f10 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -27,9 +27,10 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy - # Echo tests with last four major releases - # except v5 starts from 1.17 until there is last four major releases after that - go: [1.17, 1.18] + # Echo tests with last four major releases (unless there are pressing vulnerabilities) + # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when + # we derive from last four major releases promise. + go: [1.17, 1.18, 1.19] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -43,19 +44,23 @@ jobs: with: go-version: ${{ matrix.go }} - - name: Install Dependencies + - name: Run Tests + run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... + + - name: Install dependencies for checks run: | go install golang.org/x/lint/golint@latest go install honnef.co/go/tools/cmd/staticcheck@latest - - name: Run Tests - run: | - golint -set_exit_status ./... - staticcheck ./... - go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... + - name: Run golint + run: golint -set_exit_status ./... + + - name: Run staticcheck + run: staticcheck ./... + - name: Upload coverage to Codecov - if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v1 + if: success() && matrix.go == 1.19 && matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v3 with: token: fail_ci_if_error: false @@ -64,7 +69,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go: [1.18] + go: [1.19] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -91,10 +96,12 @@ jobs: run: | cd previous go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + - name: Run Benchmark (New) run: | cd new go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + - name: Run Benchstat run: | benchstat previous/benchmark.txt new/benchmark.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index ba75d71f..cdb6bd78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,74 @@ # Changelog + +## v4.10.0 - 2022-xx-xx + +**Security** + +This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are +several vulnerabilities fixed in these libraries. + +Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. + + +## v4.9.1 - 2022-10-12 + +**Fixes** + +* Fix logger panicing (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295) + +**Enhancements** + +* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272) +* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291) +* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254) +* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275) +* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301) + +## v4.9.0 - 2022-09-04 + +**Security** + +* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260) + +**Enhancements** + +* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257) +* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247) + + +## v4.8.0 - 2022-08-10 + +**Most notable things** + +You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237) +```go +e.Add("COPY", "/*", func(c echo.Context) error + return c.String(http.StatusOK, "OK COPY") +}) +``` + +You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217) +```go +e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) + +g := e.Group("/images") +g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) +``` + +**Enhancements** + +* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127) +* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145) +* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187) +* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191) +* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176) +* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209) +* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217) +* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227) +* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237) + + ## v4.7.2 - 2022-03-16 **Fixes** diff --git a/Makefile b/Makefile index 8149aeba..7aa624cc 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,7 @@ init: @go install honnef.co/go/tools/cmd/staticcheck@latest lint: ## Lint the files + @staticcheck ${PKG_LIST} @golint -set_exit_status ${PKG_LIST} vet: ## Vet the files diff --git a/README.md b/README.md index b9cb69e3..af5f493a 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,11 @@ ## Supported Go versions -Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that. +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: -- 1.9.7+ -- 1.10.3+ -- 1.16+ Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. @@ -95,6 +92,7 @@ func hello(c echo.Context) error { | [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | UberĀ“s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | | [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | | [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | +| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | Please send a PR to add your own library here. diff --git a/bind_test.go b/bind_test.go index 8f711c4f..e499cbff 100644 --- a/bind_test.go +++ b/bind_test.go @@ -190,44 +190,39 @@ func TestToMultipleFields(t *testing.T) { } func TestBindJSON(t *testing.T) { - assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userJSON), nil, MIMEApplicationJSON) - testBindOkay(assert, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) - testBindArrayOkay(assert, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) - testBindArrayOkay(assert, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) - testBindError(assert, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) + testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON) + testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) } func TestBindXML(t *testing.T) { - assert := assert.New(t) - - testBindOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) - testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) - testBindArrayOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) - testBindArrayOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) - testBindOkay(assert, strings.NewReader(userXML), nil, MIMETextXML) - testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMETextXML) - testBindError(assert, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML) + testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) } func TestBindForm(t *testing.T) { - assert := assert.New(t) - - testBindOkay(assert, strings.NewReader(userForm), nil, MIMEApplicationForm) - testBindOkay(assert, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, MIMEApplicationForm) err := c.Bind(&[]struct{ Field string }{}) - assert.Error(err) + assert.Error(t, err) } func TestBindQueryParams(t *testing.T) { @@ -363,14 +358,13 @@ func TestBindUnmarshalParam(t *testing.T) { err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) - assert := assert.New(t) - if assert.NoError(err) { + if assert.NoError(t, err) { // assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) - assert.Equal(ts, result.T) - assert.Equal(StringArray([]string{"one", "two", "three"}), result.SA) - assert.Equal([]Timestamp{ts, ts}, result.TA) - assert.Equal(Struct{""}, result.ST) // child struct does not have a field with matching tag - assert.Equal("baz", result.StWithTag.Foo) // child struct has field with matching tag + assert.Equal(t, ts, result.T) + assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) + assert.Equal(t, []Timestamp{ts, ts}, result.TA) + assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag + assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag } } @@ -472,37 +466,34 @@ func TestBindMultipartForm(t *testing.T) { mw.Close() body := bodyBuffer.Bytes() - assert := assert.New(t) - testBindOkay(assert, bytes.NewReader(body), nil, mw.FormDataContentType()) - testBindOkay(assert, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) } func TestBindUnsupportedMediaType(t *testing.T) { - assert := assert.New(t) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) } func TestBindbindData(t *testing.T) { - a := assert.New(t) ts := new(bindTestStruct) err := bindData(ts, values, "form") - a.NoError(err) + assert.NoError(t, err) - a.Equal(0, ts.I) - a.Equal(int8(0), ts.I8) - a.Equal(int16(0), ts.I16) - a.Equal(int32(0), ts.I32) - a.Equal(int64(0), ts.I64) - a.Equal(uint(0), ts.UI) - a.Equal(uint8(0), ts.UI8) - a.Equal(uint16(0), ts.UI16) - a.Equal(uint32(0), ts.UI32) - a.Equal(uint64(0), ts.UI64) - a.Equal(false, ts.B) - a.Equal(float32(0), ts.F32) - a.Equal(float64(0), ts.F64) - a.Equal("", ts.S) - a.Equal("", ts.cantSet) + assert.Equal(t, 0, ts.I) + assert.Equal(t, int8(0), ts.I8) + assert.Equal(t, int16(0), ts.I16) + assert.Equal(t, int32(0), ts.I32) + assert.Equal(t, int64(0), ts.I64) + assert.Equal(t, uint(0), ts.UI) + assert.Equal(t, uint8(0), ts.UI8) + assert.Equal(t, uint16(0), ts.UI16) + assert.Equal(t, uint32(0), ts.UI32) + assert.Equal(t, uint64(0), ts.UI64) + assert.Equal(t, false, ts.B) + assert.Equal(t, float32(0), ts.F32) + assert.Equal(t, float64(0), ts.F64) + assert.Equal(t, "", ts.S) + assert.Equal(t, "", ts.cantSet) } func TestBindParam(t *testing.T) { @@ -580,7 +571,6 @@ func TestBindUnmarshalTypeError(t *testing.T) { } func TestBindSetWithProperType(t *testing.T) { - assert := assert.New(t) ts := new(bindTestStruct) typ := reflect.TypeOf(ts).Elem() val := reflect.ValueOf(ts).Elem() @@ -595,9 +585,9 @@ func TestBindSetWithProperType(t *testing.T) { } val := values[typeField.Name][0] err := setWithProperType(typeField.Type.Kind(), val, structField) - assert.NoError(err) + assert.NoError(t, err) } - assertBindTestStruct(assert, ts) + assertBindTestStruct(t, ts) type foo struct { Bar bytes.Buffer @@ -605,7 +595,7 @@ func TestBindSetWithProperType(t *testing.T) { v := &foo{} typ = reflect.TypeOf(v).Elem() val = reflect.ValueOf(v).Elem() - assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) + assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } func TestSetIntField(t *testing.T) { @@ -730,28 +720,28 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { err = bindData(ts, values, "form") } assert.NoError(err) - assertBindTestStruct(assert, (*bindTestStruct)(ts)) + assertBindTestStruct(b, (*bindTestStruct)(ts)) } -func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) { - a.Equal(0, ts.I) - a.Equal(int8(8), ts.I8) - a.Equal(int16(16), ts.I16) - a.Equal(int32(32), ts.I32) - a.Equal(int64(64), ts.I64) - a.Equal(uint(0), ts.UI) - a.Equal(uint8(8), ts.UI8) - a.Equal(uint16(16), ts.UI16) - a.Equal(uint32(32), ts.UI32) - a.Equal(uint64(64), ts.UI64) - a.Equal(true, ts.B) - a.Equal(float32(32.5), ts.F32) - a.Equal(float64(64.5), ts.F64) - a.Equal("test", ts.S) - a.Equal("", ts.GetCantSet()) +func assertBindTestStruct(t testing.TB, ts *bindTestStruct) { + assert.Equal(t, 0, ts.I) + assert.Equal(t, int8(8), ts.I8) + assert.Equal(t, int16(16), ts.I16) + assert.Equal(t, int32(32), ts.I32) + assert.Equal(t, int64(64), ts.I64) + assert.Equal(t, uint(0), ts.UI) + assert.Equal(t, uint8(8), ts.UI8) + assert.Equal(t, uint16(16), ts.UI16) + assert.Equal(t, uint32(32), ts.UI32) + assert.Equal(t, uint64(64), ts.UI64) + assert.Equal(t, true, ts.B) + assert.Equal(t, float32(32.5), ts.F32) + assert.Equal(t, float64(64.5), ts.F64) + assert.Equal(t, "test", ts.S) + assert.Equal(t, "", ts.GetCantSet()) } -func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { +func testBindOkay(t testing.TB, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { @@ -763,13 +753,13 @@ func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctyp req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) - if assert.NoError(err) { - assert.Equal(1, u.ID) - assert.Equal("Jon Snow", u.Name) + if assert.NoError(t, err) { + assert.Equal(t, 1, u.ID) + assert.Equal(t, "Jon Snow", u.Name) } } -func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { +func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { @@ -781,14 +771,14 @@ func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, req.Header.Set(HeaderContentType, ctype) u := []user{} err := c.Bind(&u) - if assert.NoError(err) { - assert.Equal(1, len(u)) - assert.Equal(1, u[0].ID) - assert.Equal("Jon Snow", u[0].Name) + if assert.NoError(t, err) { + assert.Equal(t, 1, len(u)) + assert.Equal(t, 1, u[0].ID) + assert.Equal(t, "Jon Snow", u[0].Name) } } -func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expectedInternal error) { +func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) { e := New() req := httptest.NewRequest(http.MethodPost, "/", r) rec := httptest.NewRecorder() @@ -800,14 +790,14 @@ func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expecte switch { case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML), strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - if assert.IsType(new(HTTPError), err) { - assert.Equal(http.StatusBadRequest, err.(*HTTPError).Code) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } default: - if assert.IsType(new(HTTPError), err) { - assert.Equal(ErrUnsupportedMediaType, err) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, ErrUnsupportedMediaType, err) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } } } diff --git a/context.go b/context.go index 5f4a2010..67673b50 100644 --- a/context.go +++ b/context.go @@ -185,6 +185,14 @@ type Context interface { // Redirect redirects the request to a provided URL with status code. Redirect(code int, url string) error + // Error invokes the registered global HTTP error handler. Generally used by middleware. + // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and + // middlewares up in chain can not change Response status code or Response body anymore. + // + // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. + // Instead of calling this method in handler return your error and let it be handled by middlewares or global error handler. + Error(err error) + // Echo returns the `Echo` instance. // // WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them @@ -337,11 +345,16 @@ func (c *DefaultContext) RealIP() string { if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { i := strings.IndexAny(ip, ",") if i > 0 { - return strings.TrimSpace(ip[:i]) + xffip := strings.TrimSpace(ip[:i]) + xffip = strings.TrimPrefix(xffip, "[") + xffip = strings.TrimSuffix(xffip, "]") + return xffip } return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { + ip = strings.TrimPrefix(ip, "[") + ip = strings.TrimSuffix(ip, "]") return ip } ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) @@ -757,6 +770,16 @@ func (c *DefaultContext) Redirect(code int, url string) error { return nil } +// Error invokes the registered global HTTP error handler. Generally used by middleware. +// A side-effect of calling global error handler is that now Response has been committed (sent to the client) and +// middlewares up in chain can not change Response status code or Response body anymore. +// +// Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. +// Instead of calling this method in handler return your error and let it be handled by middlewares or global error handler. +func (c *DefaultContext) Error(err error) { + c.echo.HTTPErrorHandler(c, err) +} + // Echo returns the `Echo` instance. func (c *DefaultContext) Echo() *Echo { return c.echo diff --git a/context_test.go b/context_test.go index 0df10539..9e06b3e9 100644 --- a/context_test.go +++ b/context_test.go @@ -377,6 +377,19 @@ func TestContext(t *testing.T) { assert.Equal(t, 0, len(c.QueryParams())) } +func TestContext_Error(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + c.Error(errors.New("error")) + + assert.True(t, c.Response().Committed) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, `{"message":"Internal Server Error"}`+"\n", rec.Body.String()) +} + func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -1061,6 +1074,30 @@ func TestContext_RealIP(t *testing.T) { }, "127.0.0.1", }, + { + &DefaultContext{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &DefaultContext{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &DefaultContext{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, { &DefaultContext{ request: &http.Request{ @@ -1071,6 +1108,17 @@ func TestContext_RealIP(t *testing.T) { }, "192.168.0.1", }, + { + &DefaultContext{ + request: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{"[2001:db8::1]"}, + }, + }, + }, + "2001:db8::1", + }, + { &DefaultContext{ request: &http.Request{ diff --git a/echo.go b/echo.go index c43d26e2..5dfa225f 100644 --- a/echo.go +++ b/echo.go @@ -3,36 +3,36 @@ Package echo implements high performance, minimalist Go web framework. Example: - package main + package main - import ( - "github.com/labstack/echo/v5" - "github.com/labstack/echo/v5/middleware" - "log" - "net/http" - ) + import ( + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + "log" + "net/http" + ) - // Handler - func hello(c echo.Context) error { - return c.String(http.StatusOK, "Hello, World!") - } - - func main() { - // Echo instance - e := echo.New() - - // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) - - // Routes - e.GET("/", hello) - - // Start server - if err := e.Start(":8080"); err != http.ErrServerClosed { - log.Fatal(err) + // Handler + func hello(c echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") + } + + func main() { + // Echo instance + e := echo.New() + + // Middleware + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) + + // Routes + e.GET("/", hello) + + // Start server + if err := e.Start(":8080"); err != http.ErrServerClosed { + log.Fatal(err) + } } - } Learn more at https://echo.labstack.com */ @@ -49,7 +49,6 @@ import ( "os" "os/signal" "path/filepath" - "runtime" "strings" "sync" ) @@ -420,8 +419,11 @@ func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) Ro return e.Add(RouteNotFound, path, h, m...) } -// Any registers a new route for all supported HTTP methods and path with matching handler -// in the router with optional route-level middleware. Panics on error. +// Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler +// in the router with optional route-level middleware. +// +// Note: this method only adds specific set of supported HTTP methods as handler and is not true +// "catch-any-arbitrary-method" way of matching requests. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { errs := make([]error, 0) ris := make(Routes, 0) @@ -515,7 +517,7 @@ func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) Handle p = c.Request().URL.Path // path must not be empty. if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") + return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) } return fsFile(c, name, fileSystem) } @@ -625,7 +627,7 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { c = e.contextPool.Get().(*DefaultContext) } c.Reset(r, w) - var h func(Context) error + var h HandlerFunc if e.premiddleware == nil { h = applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) @@ -654,12 +656,15 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // options. // // In need of customization use: -// sc := echo.StartConfig{Address: ":8080"} +// +// sc := echo.StartConfig{Address: ":8080"} // if err := sc.Start(e); err != http.ErrServerClosed { // log.Fatal(err) // } +// // // or standard library `http.Server` -// s := http.Server{Addr: ":8080", Handler: e} +// +// s := http.Server{Addr: ":8080", Handler: e} // if err := s.ListenAndServe(); err != http.ErrServerClosed { // log.Fatal(err) // } @@ -741,7 +746,7 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs - if isRelativePath(root) { + if !filepath.IsAbs(root) { root = filepath.Join(dFS.prefix, root) } return &defaultFS{ @@ -752,21 +757,6 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { return fs.Sub(currentFs, root) } -func isRelativePath(path string) bool { - if path == "" { - return true - } - if path[0] == '/' { - return false - } - if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { - // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names - // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats - return false - } - return true -} - // MustSubFS creates sub FS from current filesystem or panic on failure. // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. // @@ -780,3 +770,12 @@ func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { } return subFs } + +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) + } + return uri +} diff --git a/echo_test.go b/echo_test.go index b75bd253..9181f317 100644 --- a/echo_test.go +++ b/echo_test.go @@ -187,6 +187,15 @@ func TestEcho_StaticFS(t *testing.T) { expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, + { + name: "open redirect vulnerability", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/open.redirect.hackercom%2f..", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad + expectBodyStartsWith: "", + }, } for _, tc := range testCases { @@ -1163,7 +1172,7 @@ func TestEcho_customContext(t *testing.T) { func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) u := req.URL w := httptest.NewRecorder() diff --git a/go.mod b/go.mod index 339d24ef..d4fb334b 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,17 @@ module github.com/labstack/echo/v5 go 1.17 require ( - github.com/golang-jwt/jwt/v4 v4.2.0 - github.com/stretchr/testify v1.7.0 - github.com/valyala/fasttemplate v1.2.1 - golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 - golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 + github.com/golang-jwt/jwt/v4 v4.4.3 + github.com/stretchr/testify v1.8.1 + github.com/valyala/fasttemplate v1.2.2 + golang.org/x/net v0.2.0 + golang.org/x/time v0.3.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/text v0.3.3 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + golang.org/x/text v0.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3290b99b..f137dbf1 100644 --- a/go.sum +++ b/go.sum @@ -1,31 +1,54 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= -github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= -github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= -github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= -github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= -golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/group_test.go b/group_test.go index bd215726..d6a05e54 100644 --- a/group_test.go +++ b/group_test.go @@ -3,7 +3,6 @@ package echo import ( "github.com/stretchr/testify/assert" "io/fs" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -71,7 +70,7 @@ func TestGroupFile(t *testing.T) { e := New() g := e.Group("/group") g.File("/walle", "_fixture/images/walle.png") - expectedData, err := ioutil.ReadFile("_fixture/images/walle.png") + expectedData, err := os.ReadFile("_fixture/images/walle.png") assert.Nil(t, err) req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) rec := httptest.NewRecorder() diff --git a/ip.go b/ip.go index 46d464cf..1bcd756a 100644 --- a/ip.go +++ b/ip.go @@ -227,6 +227,8 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { return func(req *http.Request) string { realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { + realIP = strings.TrimPrefix(realIP, "[") + realIP = strings.TrimSuffix(realIP, "]") if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } @@ -248,7 +250,10 @@ func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { } ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP) for i := len(ips) - 1; i >= 0; i-- { - ip := net.ParseIP(strings.TrimSpace(ips[i])) + ips[i] = strings.TrimSpace(ips[i]) + ips[i] = strings.TrimPrefix(ips[i], "[") + ips[i] = strings.TrimSuffix(ips[i], "]") + ip := net.ParseIP(ips[i]) if ip == nil { // Unable to parse IP; cannot trust entire records return directIP diff --git a/ip_test.go b/ip_test.go index 755900d3..38c4a1ca 100644 --- a/ip_test.go +++ b/ip_test.go @@ -459,6 +459,7 @@ func TestExtractIPDirect(t *testing.T) { func TestExtractIPFromRealIPHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { name string @@ -493,6 +494,16 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.1", }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:1", + }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" @@ -506,6 +517,19 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.199", }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" @@ -520,6 +544,20 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.199", }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, } for _, tc := range testCases { @@ -532,6 +570,7 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { func TestExtractIPFromXFFHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { name string @@ -566,6 +605,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "127.0.0.3", }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", + }, + expectIP: "fe80::3", + }, { name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", whenRequest: http.Request{ @@ -576,6 +625,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "203.0.113.1", }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::2]:8080", + }, + expectIP: "2001:db8::2", + }, { name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", givenTrustOptions: []TrustOption{ @@ -595,6 +654,25 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "203.0.100.100", // this is first trusted IP in XFF chain }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed) + // 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs) + // 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office) + // 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) fe80::1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP + }, + expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain + }, } for _, tc := range testCases { diff --git a/json_test.go b/json_test.go index ac64d289..1d1483d2 100644 --- a/json_test.go +++ b/json_test.go @@ -1,7 +1,7 @@ package echo import ( - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "strings" @@ -16,16 +16,14 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*DefaultContext) - assert := testify.New(t) - // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -34,16 +32,16 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { enc := new(DefaultJSONSerializer) err := enc.Serialize(c, user{1, "Jon Snow"}, "") - if assert.NoError(err) { - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSON+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*DefaultContext) err = enc.Serialize(c, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } @@ -55,16 +53,14 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*DefaultContext) - assert := testify.New(t) - // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -74,8 +70,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var u = user{} err := enc.Deserialize(c, &u) - if assert.NoError(err) { - assert.Equal(u, user{ID: 1, Name: "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"}) } var userUnmarshalSyntaxError = user{} @@ -83,8 +79,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalSyntaxError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") var userUnmarshalTypeError = struct { ID string `json:"id"` @@ -95,7 +91,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalTypeError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index 390c37d6..a26dd8e7 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -5,7 +5,6 @@ import ( "bytes" "errors" "io" - "io/ioutil" "net" "net/http" @@ -62,9 +61,9 @@ func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Request reqBody := []byte{} if c.Request().Body != nil { - reqBody, _ = ioutil.ReadAll(c.Request().Body) + reqBody, _ = io.ReadAll(c.Request().Body) } - c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset + c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset // Response resBody := new(bytes.Buffer) diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index 323f46c1..fd608167 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -2,7 +2,7 @@ package middleware import ( "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -19,7 +19,7 @@ func TestBodyDump(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 6e7778ea..47982255 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -2,7 +2,7 @@ package middleware import ( "bytes" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -18,7 +18,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } @@ -77,18 +77,18 @@ func TestBodyLimitReader(t *testing.T) { } reader := &limitedReader{ BodyLimitConfig: config, - reader: ioutil.NopCloser(bytes.NewReader(hw)), + reader: io.NopCloser(bytes.NewReader(hw)), context: e.NewContext(req, rec), } // read all should return ErrStatusRequestEntityTooLarge - _, err := ioutil.ReadAll(reader) + _, err := io.ReadAll(reader) he := err.(*echo.HTTPError) assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw))) + reader.Reset(e.NewContext(req, rec), io.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) @@ -97,7 +97,7 @@ func TestBodyLimitReader(t *testing.T) { func TestBodyLimit_skipper(t *testing.T) { e := echo.New() h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } @@ -129,7 +129,7 @@ func TestBodyLimitWithConfig(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } @@ -151,7 +151,7 @@ func TestBodyLimit(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } diff --git a/middleware/compress.go b/middleware/compress.go index d383cac6..241da957 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -5,7 +5,6 @@ import ( "compress/gzip" "errors" "io" - "io/ioutil" "net" "net/http" "strings" @@ -71,7 +70,7 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { i := pool.Get() w, ok := i.(*gzip.Writer) if !ok { - return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + return echo.NewHTTPErrorWithInternal(http.StatusInternalServerError, i.(error)) } rw := res.Writer w.Reset(rw) @@ -85,7 +84,7 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // nothing is written to body or error is returned. // See issue #424, #407. res.Writer = rw - w.Reset(ioutil.Discard) + w.Reset(io.Discard) } w.Close() pool.Put(w) @@ -131,7 +130,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ New: func() interface{} { - w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) + w, err := gzip.NewWriterLevel(io.Discard, config.Level) if err != nil { return err } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index d6b4f60e..3da3a105 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -3,7 +3,6 @@ package middleware import ( "bytes" "compress/gzip" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -203,7 +202,7 @@ func TestGzipWithStatic(t *testing.T) { r, err := gzip.NewReader(rec.Body) if assert.NoError(t, err) { defer r.Close() - want, err := ioutil.ReadFile("../_fixture/images/walle.png") + want, err := os.ReadFile("../_fixture/images/walle.png") if assert.NoError(t, err) { buf := new(bytes.Buffer) buf.ReadFrom(r) diff --git a/middleware/cors.go b/middleware/cors.go index 78b44975..74b79c5e 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -14,46 +14,85 @@ type CORSConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // AllowOrigin defines a list of origins that may access the resource. + // AllowOrigins determines the value of the Access-Control-Allow-Origin + // response header. This header defines a list of origins that may access the + // resource. The wildcard characters '*' and '?' are supported and are + // converted to regex fragments '.*' and '.' accordingly. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // // Optional. Default value []string{"*"}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin AllowOrigins []string // AllowOriginFunc is a custom function to validate the origin. It takes the // origin as an argument and returns true if allowed or false otherwise. If // an error is returned, it is returned by the handler. If this option is // set, AllowOrigins is ignored. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // // Optional. AllowOriginFunc func(origin string) (bool, error) - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. + // AllowMethods determines the value of the Access-Control-Allow-Methods + // response header. This header specified the list of methods allowed when + // accessing the resource. This is used in response to a preflight request. + // // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value + // If `allowMethods` is left empty, this middleware will fill for preflight + // request `Access-Control-Allow-Methods` header value // from `Allow` header that echo.Router set into context. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods AllowMethods []string - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This is in response to a preflight request. + // AllowHeaders determines the value of the Access-Control-Allow-Headers + // response header. This header is used in response to a preflight request to + // indicate which HTTP headers can be used when making the actual request. + // // Optional. Default value []string{}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers AllowHeaders []string - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. + // AllowCredentials determines the value of the + // Access-Control-Allow-Credentials response header. This header indicates + // whether or not the response to the request can be exposed when the + // credentials mode (Request.credentials) is true. When used as part of a + // response to a preflight request, this indicates whether or not the actual + // request can be made using credentials. See also + // [MDN: Access-Control-Allow-Credentials]. + // + // Optional. Default value false, in which case the header is not set. + // // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // See "Exploiting CORS misconfigurations for Bitcoins and bounties", + // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials AllowCredentials bool - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which + // defines a list of headers that clients are allowed to access. + // + // Optional. Default value []string{}, in which case the header is not set. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header ExposeHeaders []string - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. + // MaxAge determines the value of the Access-Control-Max-Age response header. + // This header indicates how long (in seconds) the results of a preflight + // request can be cached. + // + // Optional. Default value 0. The header is set only if MaxAge > 0. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age MaxAge int } @@ -65,13 +104,22 @@ var DefaultCORSConfig = CORSConfig{ } // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. -// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// See also [MDN: Cross-Origin Resource Sharing (CORS)]. +// +// Security: Poorly configured CORS can compromise security because it allows +// relaxation of the browser's Same-Origin policy. See [Exploiting CORS +// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin +// resource sharing (CORS)] for more details. +// +// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html +// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors func CORS() echo.MiddlewareFunc { return CORSWithConfig(DefaultCORSConfig) } // CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. -// See: `CORS()`. +// See: [CORS]. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } diff --git a/middleware/csrf.go b/middleware/csrf.go index 895a9c63..1b891e4f 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -63,6 +63,9 @@ type CSRFConfig struct { // Indicates SameSite mode of the CSRF cookie. // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler func(c echo.Context, err error) error } // ErrCSRFInvalid is returned when CSRF check fails @@ -159,10 +162,17 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { lastTokenErr = ErrCSRFInvalid } } + var finalErr error if lastTokenErr != nil { - return lastTokenErr + finalErr = lastTokenErr } else if lastExtractorErr != nil { - return echo.ErrBadRequest.WithInternal(lastExtractorErr) + finalErr = echo.ErrBadRequest.WithInternal(lastExtractorErr) + } + if finalErr != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(c, finalErr) + } + return finalErr } } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index f8af5e9c..de97cd6c 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -392,3 +392,25 @@ func TestCSRFConfig_skipper(t *testing.T) { }) } } + +func TestCSRFErrorHandling(t *testing.T) { + cfg := CSRFConfig{ + ErrorHandler: func(c echo.Context, err error) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") + }, + } + + e := echo.New() + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(CSRFWithConfig(cfg)) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) +} diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 5651a493..53b51e24 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -4,7 +4,7 @@ import ( "bytes" "compress/gzip" "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -35,7 +35,7 @@ func TestDecompress(t *testing.T) { assert.NoError(t, err) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) + b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } @@ -97,7 +97,7 @@ func TestDecompressWithConfig_DefaultConfig(t *testing.T) { assert.NoError(t, err) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) + b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } @@ -114,7 +114,7 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) + b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.NotEqual(t, b, body) assert.Equal(t, b, gz) @@ -171,7 +171,7 @@ func TestDecompressSkipper(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) - reqBody, err := ioutil.ReadAll(c.Request().Body) + reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) } @@ -202,7 +202,7 @@ func TestDecompressPoolError(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - reqBody, err := ioutil.ReadAll(c.Request().Body) + reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) assert.Equal(t, rec.Code, http.StatusInternalServerError) diff --git a/middleware/extractor.go b/middleware/extractor.go index 4ad676f6..95c5a861 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -51,6 +51,26 @@ var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error) +// CreateExtractors creates ValuesExtractors from given lookups. +// Lookups is a string in the form of ":" or ":,:" that is used +// to extract key from the request. +// Possible values: +// - "header:" or "header::" +// `` is argument value to cut/trim prefix of the extracted value. This is useful if header +// value has static prefix like `Authorization: ` where part that we +// want to cut is ` ` note the space at the end. +// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. +// - "query:" +// - "param:" +// - "form:" +// - "cookie:" +// +// Multiple sources example: +// - "header:Authorization,header:X-Api-Key" +func CreateExtractors(lookups string) ([]ValuesExtractor, error) { + return createExtractors(lookups) +} + func createExtractors(lookups string) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index afa776ec..7b8b3d4f 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -100,7 +100,7 @@ func TestCreateExtractors(t *testing.T) { c.SetRawPathParams(&tc.givenPathParams) } - extractors, err := createExtractors(tc.whenLoopups) + extractors, err := CreateExtractors(tc.whenLoopups) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return diff --git a/middleware/logger.go b/middleware/logger.go index 0e525e74..9da7ec41 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "encoding/json" + "errors" "fmt" "io" "strconv" @@ -34,6 +35,7 @@ type LoggerConfig struct { // - host // - method // - path + // - route // - protocol // - referer // - user_agent @@ -46,6 +48,7 @@ type LoggerConfig struct { // - header: // - query: // - form: + // - custom (see CustomTagFunc field) // // Example "${remote_ip} ${status}" // @@ -55,6 +58,11 @@ type LoggerConfig struct { // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. CustomTimeFormat string + // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. + // Make sure that outputted text creates valid JSON string with other logged tags. + // Optional. + CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) + // Output is a writer where logs in JSON format are written. // Optional. Default destination `echo.Logger.Infof()` Output io.Writer @@ -111,6 +119,11 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { start := time.Now() err := next(c) + if err = next(c); err != nil { + // When global error handler writes the error to the client the Response gets "committed". This state can be + // checked with `c.Response().Committed` field. + c.Error(err) + } stop := time.Now() buf := config.pool.Get().(*bytes.Buffer) @@ -119,20 +132,25 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { switch tag { + case "custom": + if config.CustomTagFunc == nil { + return 0, nil + } + return config.CustomTagFunc(c, buf) case "time_unix": - return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) + return buf.WriteString(strconv.FormatInt(stop.Unix(), 10)) case "time_unix_milli": - return buf.WriteString(strconv.FormatInt(time.Now().UnixMilli(), 10)) + return buf.WriteString(strconv.FormatInt(stop.UnixMilli(), 10)) case "time_unix_micro": - return buf.WriteString(strconv.FormatInt(time.Now().UnixMicro(), 10)) + return buf.WriteString(strconv.FormatInt(stop.UnixMicro(), 10)) case "time_unix_nano": - return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) + return buf.WriteString(strconv.FormatInt(stop.UnixNano(), 10)) case "time_rfc3339": - return buf.WriteString(time.Now().Format(time.RFC3339)) + return buf.WriteString(stop.Format(time.RFC3339)) case "time_rfc3339_nano": - return buf.WriteString(time.Now().Format(time.RFC3339Nano)) + return buf.WriteString(stop.Format(time.RFC3339Nano)) case "time_custom": - return buf.WriteString(time.Now().Format(config.CustomTimeFormat)) + return buf.WriteString(stop.Format(config.CustomTimeFormat)) case "id": id := req.Header.Get(echo.HeaderXRequestID) if id == "" { @@ -153,6 +171,8 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { p = "/" } return buf.WriteString(p) + case "route": + return buf.WriteString(c.Path()) case "protocol": return buf.WriteString(req.Proto) case "referer": @@ -162,7 +182,8 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { case "status": status := res.Status if err != nil { - if httpErr, ok := err.(*echo.HTTPError); ok { + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { status = httpErr.Code } } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 455520f9..d311da15 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -92,17 +92,17 @@ func TestLoggerTemplate(t *testing.T) { e.Use(LoggerWithConfig(LoggerConfig{ Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "route":"${route}", "referer":"${referer}",` + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", Output: buf, })) - e.GET("/", func(c echo.Context) error { + e.GET("/users/:id", func(c echo.Context) error { return c.String(http.StatusOK, "Header Logged") }) - req := httptest.NewRequest(http.MethodGet, "/?username=apagano-param&password=secret", nil) + req := httptest.NewRequest(http.MethodGet, "/users/1?username=apagano-param&password=secret", nil) req.RequestURI = "/" req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") req.Header.Add("Referer", "google.com") @@ -127,7 +127,8 @@ func TestLoggerTemplate(t *testing.T) { "hexvalue": false, "GET": true, "127.0.0.1": true, - "\"path\":\"/\"": true, + "\"path\":\"/users/1\"": true, + "\"route\":\"/users/:id\"": true, "\"uri\":\"/\"": true, "\"status\":200": true, "\"bytes_in\":0": true, @@ -291,3 +292,25 @@ func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) { buf.Reset() } } + +func TestLoggerCustomTagFunc(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `{"method":"${method}",${custom}}` + "\n", + CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { + return buf.WriteString(`"tag":"my-value"`) + }, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "custom time stamp test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, `{"method":"GET","tag":"my-value"}`+"\n", buf.String()) +} diff --git a/middleware/proxy.go b/middleware/proxy.go index 1efbc243..5d902c72 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -69,7 +69,7 @@ type ProxyTarget struct { type ProxyBalancer interface { AddTarget(*ProxyTarget) bool RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget + Next(echo.Context) (*ProxyTarget, error) } type commonBalancer struct { @@ -174,21 +174,21 @@ func (b *commonBalancer) RemoveTarget(name string) bool { } // Next randomly returns an upstream target. -func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { +func (b *randomBalancer) Next(c echo.Context) (*ProxyTarget, error) { if b.random == nil { b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) } b.mutex.RLock() defer b.mutex.RUnlock() - return b.targets[b.random.Intn(len(b.targets))] + return b.targets[b.random.Intn(len(b.targets))], nil } // Next returns an upstream target using round-robin technique. -func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { +func (b *roundRobinBalancer) Next(c echo.Context) (*ProxyTarget, error) { b.i = b.i % uint32(len(b.targets)) t := b.targets[b.i] atomic.AddUint32(&b.i, 1) - return t + return t, nil } // Proxy returns a Proxy middleware. @@ -236,7 +236,10 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { req := c.Request() res := c.Response() - tgt := config.Balancer.Next(c) + tgt, err := config.Balancer.Next(c) + if err != nil { + return err + } c.Set(config.ContextKey, tgt) if err := rewriteURL(config.RegexRewrite, req); err != nil { diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index c2ae7755..18a9e41a 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "fmt" - "io/ioutil" + "io" "net" "net/http" "net/http/httptest" @@ -18,7 +18,7 @@ import ( "github.com/stretchr/testify/assert" ) -//Assert expected with url.EscapedPath method to obtain the path. +// Assert expected with url.EscapedPath method to obtain the path. func TestProxy(t *testing.T) { // Setup t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -94,7 +94,7 @@ func TestProxy(t *testing.T) { e.Use(ProxyWithConfig(ProxyConfig{ Balancer: rrb, ModifyResponse: func(res *http.Response) error { - res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified"))) + res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified"))) res.Header.Set("X-Modified", "1") return nil }, @@ -379,3 +379,48 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { timeoutStop.Done() assert.Equal(t, 499, rec.Code) } + +type testProvider struct { + *commonBalancer + target *ProxyTarget + err error +} + +func (p *testProvider) Next(c echo.Context) (*ProxyTarget, error) { + return p.target, p.err +} + +func TestTargetProvider(t *testing.T) { + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + + e := echo.New() + tp := &testProvider{commonBalancer: new(commonBalancer)} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "target 1", body) +} + +func TestFailNextTarget(t *testing.T) { + url1, err := url.Parse("http://dummy:8080") + assert.Nil(t, err) + + e := echo.New() + tp := &testProvider{commonBalancer: new(commonBalancer)} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") + + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) +} diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 63b6402f..13ab851a 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -9,10 +9,16 @@ import ( // Example for `fmt.Printf` // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogStatus: true, -// LogURI: true, +// LogStatus: true, +// LogURI: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// if v.Error == nil { +// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// } else { +// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error) +// } // return nil // }, // })) @@ -20,15 +26,23 @@ import ( // Example for Zerolog (https://github.com/rs/zerolog) // logger := zerolog.New(os.Stdout) // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// logger.Info(). -// Date("request_start", v.StartTime). -// Str("URI", v.URI). -// Int("status", v.Status). -// Msg("request") -// +// if v.Error == nil { +// logger.Info(). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request") +// } else { +// logger.Error(). +// Err(v.Error). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request error") +// } // return nil // }, // })) @@ -36,31 +50,47 @@ import ( // Example for Zap (https://github.com/uber-go/zap) // logger, _ := zap.NewProduction() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// logger.Info("request", -// zap.Time("request_start", v.StartTime), -// zap.String("URI", v.URI), -// zap.Int("status", v.Status), -// ) -// +// if v.Error == nil { +// logger.Info("request", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// ) +// } else { +// logger.Error("request error", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// zap.Error(v.Error), +// ) +// } // return nil // }, // })) // // Example for Logrus (https://github.com/sirupsen/logrus) -// log := logrus.New() +// log := logrus.New() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, -// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { -// log.WithFields(logrus.Fields{ -// "request_start": values.StartTime, -// "URI": values.URI, -// "status": values.Status, -// }).Info("request") -// +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// }).Info("request") +// } else { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// "error": v.Error, +// }).Error("request error") +// } // return nil // }, // })) @@ -76,6 +106,13 @@ type RequestLoggerConfig struct { // Mandatory. LogValuesFunc func(c echo.Context, v RequestLoggerValues) error + // HandleError instructs logger to call global error handler when next middleware/handler returns an error. + // This is useful when you have custom error handler that can decide to use different status codes. + // + // A side-effect of calling global error handler is that now Response has been committed and sent to the client + // and middlewares up in chain can not change Response status code or response body. + HandleError bool + // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call). LogLatency bool // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`) @@ -219,6 +256,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.BeforeNextFunc(c) } err := next(c) + if config.HandleError { + // When global error handler writes the error to the client the Response gets "committed". This state can be + // checked with `c.Response().Committed` field. + c.Error(err) + } v := RequestLoggerValues{ StartTime: start, @@ -266,8 +308,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } if config.LogStatus { v.Status = res.Status - if err != nil { - if httpErr, ok := err.(*echo.HTTPError); ok { + if err != nil && !config.HandleError { + // this block should not be executed in case of HandleError=true as the global error handler will decide + // the status code. In that case status code could be different from what err contains. + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { v.Status = httpErr.Code } } @@ -310,7 +355,10 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { return errOnLog } - + // in case of HandleError=true we are returning the error that we already have handled with global error handler + // this is deliberate as this error could be useful for upstream middlewares and default global error handler + // will ignore that error when it bubbles up in middleware chain. + // Committed response can be checked in custom error handler with `c.Response().Committed` field return err } }, nil diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index c5ddced7..3049d592 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -103,12 +103,12 @@ func TestRequestLogger_beforeNextFunc(t *testing.T) { func TestRequestLogger_logError(t *testing.T) { e := echo.New() - var expect RequestLoggerValues + var actual RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogError: true, LogStatus: true, LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { - expect = values + actual = values return nil }, })) @@ -123,8 +123,52 @@ func TestRequestLogger_logError(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotAcceptable, rec.Code) - assert.Equal(t, http.StatusNotAcceptable, expect.Status) - assert.EqualError(t, expect.Error, "code=406, message=nope") + assert.Equal(t, http.StatusNotAcceptable, actual.Status) + assert.EqualError(t, actual.Error, "code=406, message=nope") +} + +func TestRequestLogger_HandleError(t *testing.T) { + e := echo.New() + + var actual RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + timeNow: func() time.Time { + return time.Unix(1631045377, 0).UTC() + }, + HandleError: true, + LogError: true, + LogStatus: true, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + actual = values + return nil + }, + })) + + // to see if "HandleError" works we create custom error handler that uses its own status codes + e.HTTPErrorHandler = func(c echo.Context, err error) { + if c.Response().Committed { + return + } + c.JSON(http.StatusTeapot, "custom error handler") + } + + e.GET("/test", func(c echo.Context) error { + return echo.NewHTTPError(http.StatusForbidden, "nope") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + + expect := RequestLoggerValues{ + StartTime: time.Unix(1631045377, 0).UTC(), + Status: http.StatusTeapot, + Error: echo.NewHTTPError(http.StatusForbidden, "nope"), + } + assert.Equal(t, expect, actual) } func TestRequestLogger_LogValuesFuncError(t *testing.T) { diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 1f3419f0..c4044dcc 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -1,7 +1,7 @@ package middleware import ( - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" @@ -195,7 +195,7 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) defer rec.Result().Body.Close() - bodyBytes, _ := ioutil.ReadAll(rec.Result().Body) + bodyBytes, _ := io.ReadAll(rec.Result().Body) assert.Equal(t, "hosts", string(bodyBytes)) } } diff --git a/middleware/slash.go b/middleware/slash.go index 5826a9f0..eab3d820 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -27,7 +27,7 @@ func AddTrailingSlash() echo.MiddlewareFunc { return AddTrailingSlashWithConfig(AddTrailingSlashConfig{}) } -// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config or panics on invalid configuration. +// AddTrailingSlashWithConfig returns an AddTrailingSlash middleware with config or panics on invalid configuration. func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } diff --git a/middleware/static.go b/middleware/static.go index 2ad8dbb4..98e63b80 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -216,8 +216,8 @@ func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return nil } - he, ok := err.(*echo.HTTPError) - if !(ok && config.HTML5 && he.Code == http.StatusNotFound) { + var he *echo.HTTPError + if !(errors.As(err, &he) && config.HTML5 && he.Code == http.StatusNotFound) { return err } // is case HTML5 mode is enabled + echo 404 we serve index to the client diff --git a/middleware/static_test.go b/middleware/static_test.go index 6844d384..44ee74c9 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -257,6 +257,15 @@ func TestStatic_GroupWithStatic(t *testing.T) { expectHeaderLocation: "/group/folder/", expectBodyStartsWith: "", }, + { + name: "Directory redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/group/folder%2f..", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/group/folder/../", + expectBodyStartsWith: "", + }, { name: "Prefixed directory 404 (request URL without slash)", givenGroup: "_fixture", diff --git a/router.go b/router.go index 42d3ec90..c4a0cba1 100644 --- a/router.go +++ b/router.go @@ -10,13 +10,13 @@ import ( // Router is interface for routing request contexts to registered routes. // // Contract between Echo/Context instance and the router: -// * all routes must be added through methods on echo.Echo instance. -// Reason: Echo instance uses RouteInfo.Params() length to allocate slice for paths parameters (see `Echo.contextPathParamAllocSize`). -// * Router must populate Context during Router.Route call with: -// * RoutableContext.SetPath -// * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns) -// * RoutableContext.SetRouteInfo -// And optionally can set additional information to Context with RoutableContext.Set +// - all routes must be added through methods on echo.Echo instance. +// Reason: Echo instance uses RouteInfo.Params() length to allocate slice for paths parameters (see `Echo.contextPathParamAllocSize`). +// - Router must populate Context during Router.Route call with: +// - RoutableContext.SetPath +// - RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns) +// - RoutableContext.SetRouteInfo +// And optionally can set additional information to Context with RoutableContext.Set type Router interface { // Add registers Routable with the Router and returns registered RouteInfo Add(routable Routable) (RouteInfo, error) @@ -344,7 +344,7 @@ func (m *routeMethods) updateAllowHeader() { if m.report != nil { buf.WriteString(", REPORT") } - for method := range m.anyOther { + for method := range m.anyOther { // for simplicity, we use map and therefore order is not deterministic here buf.WriteString(", ") buf.WriteString(method) } diff --git a/router_test.go b/router_test.go index 83d3b1b2..05b53804 100644 --- a/router_test.go +++ b/router_test.go @@ -669,7 +669,7 @@ func checkUnusedParamValues(t *testing.T, c *DefaultContext, expectParam map[str func TestRouterStatic(t *testing.T) { path := "/folders/a/files/echo.gif" - req := httptest.NewRequest("GET", path, nil) + req := httptest.NewRequest(http.MethodGet, path, nil) rec := httptest.NewRecorder() e := New() @@ -711,7 +711,7 @@ func TestRouterParam(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := e.NewContext(nil, nil).(*DefaultContext) - c.SetRequest(httptest.NewRequest("GET", tc.whenURL, nil)) + c.SetRequest(httptest.NewRequest(http.MethodGet, tc.whenURL, nil)) _ = e.router.Route(c) assert.Equal(t, tc.expectRoute, c.Path()) @@ -725,8 +725,11 @@ func TestRouterParam(t *testing.T) { func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) { var testCases = []struct { - name string - whenMethod string + name string + givenNoAddRoute bool + whenMethod string + expectPath string + expectError string }{ {name: "ok, CONNECT", whenMethod: http.MethodConnect}, {name: "ok, DELETE", whenMethod: http.MethodDelete}, @@ -740,6 +743,13 @@ func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) { {name: "ok, TRACE", whenMethod: http.MethodTrace}, {name: "ok, REPORT", whenMethod: REPORT}, {name: "ok, NON_TRADITIONAL_METHOD", whenMethod: "NON_TRADITIONAL_METHOD"}, + { + name: "ok, NOT_EXISTING_METHOD", + whenMethod: "NOT_EXISTING_METHOD", + givenNoAddRoute: true, + expectPath: "/*", + expectError: "code=405, message=Method Not Allowed", + }, } for _, tc := range testCases { @@ -747,7 +757,9 @@ func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) { e := New() e.GET("/*", handlerFunc) - e.Add(tc.whenMethod, "/my/*", handlerFunc) + if !tc.givenNoAddRoute { + e.Add(tc.whenMethod, "/my/*", handlerFunc) + } req := httptest.NewRequest(tc.whenMethod, "/my/some-url", nil) rec := httptest.NewRecorder() @@ -756,12 +768,45 @@ func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) { handler := e.router.Route(c) err := handler(c) - assert.NoError(t, err) - assert.Equal(t, "/my/*", c.Path()) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + expectPath := "/my/*" + if tc.expectPath != "" { + expectPath = tc.expectPath + } + assert.Equal(t, expectPath, c.Path()) }) } } +func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) { + e := New() + r := e.router + + _, err := r.Add(Route{Method: http.MethodGet, Path: "/users", Handler: handlerFunc}) + assert.NoError(t, err) + _, err = r.Add(Route{Method: "COPY", Path: "/users", Handler: handlerFunc}) + assert.NoError(t, err) + _, err = r.Add(Route{Method: "LOCK", Path: "/users", Handler: handlerFunc}) + assert.NoError(t, err) + + req := httptest.NewRequest("TEST", "/users", nil) + rec := httptest.NewRecorder() + + //r.Find("TEST", "/users", c) + c := e.NewContext(req, rec).(*DefaultContext) + + handler := e.router.Route(c) + err = handler(c) + + assert.EqualError(t, err, "code=405, message=Method Not Allowed") + assert.ElementsMatch(t, []string{"COPY", "GET", "LOCK", "OPTIONS"}, strings.Split(c.Response().Header().Get(HeaderAllow), ", ")) +} + func TestMethodNotAllowedAndNotFound(t *testing.T) { e := New() @@ -970,19 +1015,22 @@ func TestRouterParamWithSlash(t *testing.T) { // Searching route for "/a/c/f" should match "/a/*/f" // When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f" // -// +----------+ -// +-----+ "/" root +--------------------+--------------------------+ -// | +----------+ | | -// | | | -// +-------v-------+ +---v---------+ +-------v---+ -// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | -// +-+----------+--+ | +-----------+-+ +-----------+ -// | | | | +// +----------+ +// +-----+ "/" root +--------------------+--------------------------+ +// | +----------+ | | +// | | | +// +-------v-------+ +---v---------+ +-------v---+ +// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | +// +-+----------+--+ | +-----------+-+ +-----------+ +// | | | | +// // +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ // | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | // +---------+------+ +--------+----+ +----------++ +-----------------+ -// | | | -// | | | +// +// | | | +// | | | +// // +---------v----+ +------v--------+ +------v--------+ // | "f" (static) | | "/c" (static) | | "/f" (static) | // +--------------+ +---------------+ +---------------+ @@ -1052,22 +1100,22 @@ func TestRouteMultiLevelBacktracking(t *testing.T) { // // Request for "/a/c/f" should match "/:e/c/f" // -// +-0,7--------+ -// | "/" (root) |----------------------------------+ -// +------------+ | -// | | | -// | | | -// +-1,6-----------+ | | +-8-----------+ +------v----+ -// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | -// +---------------+ +-------------+ +-----------+ -// | | | -// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ -// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | -// +----------------+ +-------------+ +-----------------+ -// | -// +-4--v----------+ -// | "/c" (static) | -// +---------------+ +// +-0,7--------+ +// | "/" (root) |----------------------------------+ +// +------------+ | +// | | | +// | | | +// +-1,6-----------+ | | +-8-----------+ +------v----+ +// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | +// +---------------+ +-------------+ +-----------+ +// | | | +// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ +// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | +// +----------------+ +-------------+ +-----------------+ +// | +// +-4--v----------+ +// | "/c" (static) | +// +---------------+ func TestRouteMultiLevelBacktracking2(t *testing.T) { e := New() @@ -2753,7 +2801,7 @@ func TestRouter_Routes(t *testing.T) { func benchmarkRouterRoutes(b *testing.B, routes []testRoute, routesToFind []testRoute) { e := New() r := e.router - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) b.ReportAllocs() // Add routes diff --git a/server_test.go b/server_test.go index fa6107dd..d5b48579 100644 --- a/server_test.go +++ b/server_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/http2" "io" - "io/ioutil" "log" "net" "net/http" @@ -67,7 +66,7 @@ func doGet(url string) (int, string, error) { } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return resp.StatusCode, "", err } @@ -427,9 +426,9 @@ func TestStartConfig_StartTLSAndStart(t *testing.T) { } func TestFilepathOrContent(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + cert, err := os.ReadFile("_fixture/certs/cert.pem") require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") + key, err := os.ReadFile("_fixture/certs/key.pem") require.NoError(t, err) testCases := []struct { @@ -796,7 +795,7 @@ func TestWithDisableHTTP2(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { log.Fatalf("Failed reading response body: %s", err) }