diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index cfa44e68..2aec272d 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -4,20 +4,28 @@ on: push: branches: - master + paths: + - '**.go' + - 'go.*' + - '_fixture/**' + - '.github/**' + - 'codecov.yml' pull_request: branches: - master - -env: - GO111MODULE: on - GOPROXY: https://proxy.golang.org + paths: + - '**.go' + - 'go.*' + - '_fixture/**' + - '.github/**' + - 'codecov.yml' jobs: test: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go: [1.11, 1.12, 1.13] + go: [1.12, 1.13, 1.14, 1.15] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -28,10 +36,15 @@ jobs: - name: Set GOPATH and PATH run: | - echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)" - echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin" + echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV + echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH shell: bash + - name: Set build variables + run: | + echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV + echo "GO111MODULE=on" >> $GITHUB_ENV + - name: Checkout Code uses: actions/checkout@v1 with: @@ -51,3 +64,55 @@ jobs: with: token: fail_ci_if_error: false + benchmark: + needs: test + strategy: + matrix: + os: [ubuntu-latest] + go: [1.15] + name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} + runs-on: ${{ matrix.os }} + steps: + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go }} + + - name: Set GOPATH and PATH + run: | + echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV + echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH + shell: bash + + - name: Set build variables + run: | + echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV + echo "GO111MODULE=on" >> $GITHUB_ENV + + - name: Checkout Code (Previous) + uses: actions/checkout@v2 + with: + ref: ${{ github.base_ref }} + path: previous + + - name: Checkout Code (New) + uses: actions/checkout@v2 + with: + path: new + + - name: Install Dependencies + run: go get -v golang.org/x/perf/cmd/benchstat + + - name: Run Benchmark (Previous) + run: | + cd previous + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + + - name: Run Benchmark (New) + run: | + cd new + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + + - name: Run Benchstat + run: | + benchstat previous/benchmark.txt new/benchmark.txt diff --git a/.travis.yml b/.travis.yml index a1fc8768..67d45ad7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,11 @@ +arch: + - amd64 + - ppc64le + language: go go: - - 1.12.x - - 1.13.x + - 1.14.x + - 1.15.x - tip env: - GO111MODULE=on diff --git a/README.md b/README.md index c57d478f..03ad4dca 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) -[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/labstack/echo) +[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) [![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) @@ -17,7 +17,7 @@ Therefore a Go version capable of understanding /vN suffixed imports is required - 1.9.7+ - 1.10.3+ -- 1.11+ +- 1.14+ Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. @@ -52,7 +52,7 @@ Lower is better! ### Installation -```go +```sh // go get github.com/labstack/echo/{version} go get github.com/labstack/echo/v4 ``` diff --git a/_fixture/_fixture/README.md b/_fixture/_fixture/README.md new file mode 100644 index 00000000..21a78585 --- /dev/null +++ b/_fixture/_fixture/README.md @@ -0,0 +1 @@ +This directory is used for the static middleware test \ No newline at end of file diff --git a/bind_test.go b/bind_test.go index 943cfd55..b9fb9de3 100644 --- a/bind_test.go +++ b/bind_test.go @@ -332,7 +332,6 @@ func TestBindbindData(t *testing.T) { func TestBindParam(t *testing.T) { e := New() - *e.maxParam = 2 req := httptest.NewRequest(GET, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -363,7 +362,6 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - *e2.maxParam = 2 req2 := httptest.NewRequest(POST, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..0fa3a3f1 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,11 @@ +coverage: + status: + project: + default: + threshold: 1% + patch: + default: + threshold: 1% + +comment: + require_changes: true \ No newline at end of file diff --git a/context.go b/context.go index dfcbe16c..8ba98477 100644 --- a/context.go +++ b/context.go @@ -276,7 +276,11 @@ func (c *context) RealIP() string { } // Fall back to legacy behavior if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { - return strings.Split(ip, ", ")[0] + i := strings.IndexAny(ip, ", ") + if i > 0 { + return ip[:i] + } + return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { return ip @@ -310,6 +314,7 @@ func (c *context) ParamNames() []string { func (c *context) SetParamNames(names ...string) { c.pnames = names + *c.echo.maxParam = len(names) } func (c *context) ParamValues() []string { @@ -317,10 +322,7 @@ func (c *context) ParamValues() []string { } func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times - for i, val := range values { - c.pvalues[i] = val - } + c.pvalues = values } func (c *context) QueryParam(name string) string { @@ -363,7 +365,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) { if err != nil { return nil, err } - defer f.Close() + f.Close() return fh, nil } diff --git a/context_test.go b/context_test.go index 866d0643..0044bf87 100644 --- a/context_test.go +++ b/context_test.go @@ -72,6 +72,15 @@ func BenchmarkAllocXML(b *testing.B) { } } +func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { + c := context{request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, + }} + for i := 0; i < b.N; i++ { + c.RealIP() + } +} + func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) error { return t.templates.ExecuteTemplate(w, name, data) } @@ -93,7 +102,6 @@ func (responseWriterErr) WriteHeader(statusCode int) { func TestContext(t *testing.T) { e := New() - *e.maxParam = 1 req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -472,7 +480,6 @@ func TestContextPath(t *testing.T) { func TestContextPathParam(t *testing.T) { e := New() - *e.maxParam = 2 req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, nil) @@ -491,7 +498,8 @@ func TestContextPathParam(t *testing.T) { func TestContextGetAndSetParam(t *testing.T) { e := New() - *e.maxParam = 2 + r := e.Router() + r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) c.SetParamNames("foo") @@ -848,6 +856,14 @@ func TestContext_RealIP(t *testing.T) { }, "127.0.0.1", }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, + }, + }, + "127.0.0.1", + }, { &context{ request: &http.Request{ diff --git a/echo.go b/echo.go index fa1c93ec..29b88b70 100644 --- a/echo.go +++ b/echo.go @@ -48,6 +48,7 @@ import ( "net" "net/http" "net/url" + "os" "path" "path/filepath" "reflect" @@ -230,7 +231,7 @@ const ( const ( // Version of Echo - Version = "4.1.15" + Version = "4.1.17" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` @@ -361,10 +362,12 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { // Issue #1426 code := he.Code message := he.Message - if e.Debug { - message = err.Error() - } else if m, ok := message.(string); ok { - message = Map{"message": m} + if m, ok := he.Message.(string); ok { + if e.Debug { + message = Map{"message": m, "error": err.Error()} + } else { + message = Map{"message": m} + } } // Send response @@ -479,7 +482,20 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl if err != nil { return err } + name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security + fi, err := os.Stat(name) + if err != nil { + // The access path does not exist + return NotFoundHandler(c) + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, p+"/") + } return c.File(name) } if prefix == "/" { @@ -504,11 +520,7 @@ func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware .. name := handlerName(handler) router := e.findRouter(host) router.Add(method, path, func(c Context) error { - h := handler - // Chain middleware - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) - } + h := applyMiddleware(handler, middleware...) return h(c) }) r := &Route{ @@ -602,16 +614,15 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Acquire context c := e.pool.Get().(*context) c.Reset(r, w) - h := NotFoundHandler if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, getPath(r), c) + e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c) h = c.Handler() h = applyMiddleware(h, e.middleware...) } else { h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, getPath(r), c) + e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c) h := c.Handler() h = applyMiddleware(h, e.middleware...) return h(c) @@ -783,6 +794,9 @@ func NewHTTPError(code int, message ...interface{}) *HTTPError { // Error makes it compatible with `error` interface. func (he *HTTPError) Error() string { + if he.Internal == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) + } return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) } @@ -792,6 +806,11 @@ func (he *HTTPError) SetInternal(err error) *HTTPError { return he } +// Unwrap satisfies the Go 1.13 error wrapper interface. +func (he *HTTPError) Unwrap() error { + return he.Internal +} + // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. func WrapHandler(h http.Handler) HandlerFunc { return func(c Context) error { @@ -814,14 +833,6 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -func getPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path - } - return path -} - func (e *Echo) findRouter(host string) *Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { diff --git a/echo_go1.13_test.go b/echo_go1.13_test.go new file mode 100644 index 00000000..3c488bc6 --- /dev/null +++ b/echo_go1.13_test.go @@ -0,0 +1,28 @@ +// +build go1.13 + +package echo + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err.SetInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) +} diff --git a/echo_test.go b/echo_test.go index 68c556f4..0368dbd7 100644 --- a/echo_test.go +++ b/echo_test.go @@ -76,9 +76,17 @@ func TestEchoStatic(t *testing.T) { // Directory e.Static("/images", "_fixture/images") - c, _ = request(http.MethodGet, "/images", e) + c, _ = request(http.MethodGet, "/images/", e) assert.Equal(http.StatusNotFound, c) + // Directory Redirect + e.Static("/", "_fixture") + req := httptest.NewRequest(http.MethodGet, "/folder", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(http.StatusMovedPermanently, rec.Code) + assert.Equal("/folder/", rec.HeaderMap["Location"][0]) + // Directory with index.html e.Static("/", "_fixture") c, r := request(http.MethodGet, "/", e) @@ -86,9 +94,10 @@ func TestEchoStatic(t *testing.T) { assert.Equal(true, strings.HasPrefix(r, "")) // Sub-directory with index.html - c, r = request(http.MethodGet, "/folder", e) + c, r = request(http.MethodGet, "/folder/", e) assert.Equal(http.StatusOK, c) assert.Equal(true, strings.HasPrefix(r, "")) + } func TestEchoFile(t *testing.T) { @@ -543,10 +552,63 @@ func request(method, path string, e *Echo) (int, string) { } func TestHTTPError(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Equal(t, "code=400, message=map[code:12]", err.Error()) }) - assert.Equal(t, "code=400, message=map[code:12], internal=", err.Error()) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err.SetInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) +} + +func TestDefaultHTTPErrorHandler(t *testing.T) { + e := New() + e.Debug = true + e.Any("/plain", func(c Context) error { + return errors.New("An error occurred") + }) + e.Any("/badrequest", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, "Invalid request") + }) + e.Any("/servererror", func(c Context) error { + return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ + "code": 33, + "message": "Something bad happened", + "error": "stackinfo", + }) + }) + // With Debug=true plain response contains error message + c, b := request(http.MethodGet, "/plain", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) + // and special handling for HTTPError + c, b = request(http.MethodGet, "/badrequest", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) + // complex errors are serialized to pretty JSON + c, b = request(http.MethodGet, "/servererror", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) + + e.Debug = false + // With Debug=false the error response is shortened + c, b = request(http.MethodGet, "/plain", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) + c, b = request(http.MethodGet, "/badrequest", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) + // No difference for error response with non plain string errors + c, b = request(http.MethodGet, "/servererror", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) } func TestEchoClose(t *testing.T) { diff --git a/go.mod b/go.mod index f981ba48..74c6a9ab 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,15 @@ module github.com/labstack/echo/v4 -go 1.14 +go 1.15 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/labstack/gommon v0.3.0 - github.com/mattn/go-colorable v0.1.6 // indirect + github.com/mattn/go-colorable v0.1.7 // indirect github.com/stretchr/testify v1.4.0 - github.com/valyala/fasttemplate v1.1.0 // indirect - golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d - golang.org/x/net v0.0.0-20200226121028-0de0cce0169b - golang.org/x/text v0.3.2 // indirect + github.com/valyala/fasttemplate v1.2.1 + golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a + golang.org/x/net v0.0.0-20200822124328-c89045814202 + golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect + golang.org/x/text v0.3.3 // indirect ) diff --git a/go.sum b/go.sum index 2f6d74d0..58c80c83 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,13 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= -github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= +github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= @@ -20,14 +22,15 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.0.1 h1:tY9CJiPnMXf1ERmG2EyK7gNUd+c6RKGD0IfU8WdUSz8= github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= -github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4= -github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= +github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= +github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -36,11 +39,15 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 h1:DvY3Zkh7KabQE/kfzMvYvKirSiguP9Q/veMtkYyf0o8= +golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/group.go b/group.go index 5d958253..426bef9e 100644 --- a/group.go +++ b/group.go @@ -109,7 +109,7 @@ func (g *Group) Static(prefix, root string) { // File implements `Echo#File()` for sub-routes within the Group. func (g *Group) File(path, file string) { - g.file(g.prefix+path, file, g.GET) + g.file(path, file, g.GET) } // Add implements `Echo#Add()` for sub-routes within the Group. diff --git a/group_test.go b/group_test.go index 342cd29e..c51fd91e 100644 --- a/group_test.go +++ b/group_test.go @@ -1,7 +1,9 @@ package echo import ( + "io/ioutil" "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -26,6 +28,19 @@ func TestGroup(t *testing.T) { g.File("/walle", "_fixture/images//walle.png") } +func TestGroupFile(t *testing.T) { + e := New() + g := e.Group("/group") + g.File("/walle", "_fixture/images/walle.png") + expectedData, err := ioutil.ReadFile("_fixture/images/walle.png") + assert.Nil(t, err) + req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, expectedData, rec.Body.Bytes()) +} + func TestGroupRouteMiddleware(t *testing.T) { // Ensure middleware slices are not re-used e := New() diff --git a/middleware/compress.go b/middleware/compress.go index 89da16ef..e4f9fc51 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strings" + "sync" "github.com/labstack/echo/v4" ) @@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { config.Level = DefaultGzipConfig.Level } + pool := gzipPool(config) + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { @@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 - rw := res.Writer - w, err := gzip.NewWriterLevel(rw, config.Level) - if err != nil { - return err + i := pool.Get() + w, ok := i.(*gzip.Writer) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) } + rw := res.Writer + w.Reset(rw) defer func() { if res.Size == 0 { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { @@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { w.Reset(ioutil.Discard) } w.Close() + pool.Put(w) }() grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} res.Writer = grw @@ -119,3 +125,22 @@ func (w *gzipResponseWriter) Flush() { func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return w.ResponseWriter.(http.Hijacker).Hijack() } + +func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { + if p, ok := w.ResponseWriter.(http.Pusher); ok { + return p.Push(target, opts) + } + return http.ErrNotSupported +} + +func gzipPool(config GzipConfig) sync.Pool { + return sync.Pool{ + New: func() interface{} { + w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) + if err != nil { + return err + } + return w + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index ac5b6c3b..d16ffca4 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } +func TestGzipErrorReturnedInvalidConfig(t *testing.T) { + e := echo.New() + // Invalid level + e.Use(GzipWithConfig(GzipConfig{Level: 12})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "gzip") +} + // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() @@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) { } } } + +func BenchmarkGzip(b *testing.B) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + + h := Gzip()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Gzip + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} diff --git a/middleware/cors.go b/middleware/cors.go index 5dfe31f9..d6ef8964 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "regexp" "strconv" "strings" @@ -18,6 +19,13 @@ type ( // Optional. Default value []string{"*"}. AllowOrigins []string `yaml:"allow_origins"` + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. @@ -76,6 +84,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { config.AllowMethods = DefaultCORSConfig.AllowMethods } + allowOriginPatterns := []string{} + for _, origin := range config.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.Replace(pattern, "\\*", ".*", -1) + pattern = strings.Replace(pattern, "\\?", ".", -1) + pattern = "^" + pattern + "$" + allowOriginPatterns = append(allowOriginPatterns, pattern) + } + allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") @@ -92,25 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { origin := req.Header.Get(echo.HeaderOrigin) allowOrigin := "" - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { - allowOrigin = origin - break + preflight := req.Method == http.MethodOptions + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + + // No Origin provided + if origin == "" { + if !preflight { + return next(c) } - if o == "*" || o == origin { - allowOrigin = o - break + return c.NoContent(http.StatusNoContent) + } + + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err } - if matchSubdomain(origin, o) { + if allowed { allowOrigin = origin - break + } + } else { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials { + allowOrigin = origin + break + } + if o == "*" || o == origin { + allowOrigin = o + break + } + if matchSubdomain(origin, o) { + allowOrigin = origin + break + } + } + + // Check allowed origin patterns + for _, re := range allowOriginPatterns { + if allowOrigin == "" { + didx := strings.Index(origin, "://") + if didx == -1 { + continue + } + domAuth := origin[didx+3:] + // to avoid regex cost by invalid long domain + if len(domAuth) > 253 { + break + } + + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } } } + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) + } + // Simple request - if req.Method != http.MethodOptions { - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + if !preflight { res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") @@ -122,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } // Preflight request - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 456ec7b3..717abe49 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -17,19 +18,31 @@ func TestCORS(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := CORS()(echo.NotFoundHandler) + req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + // Wildcard AllowedOrigin with no Origin header in request + req = httptest.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h = CORS()(echo.NotFoundHandler) + h(c) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + // Allow origins req = httptest.NewRequest(http.MethodGet, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, })(echo.NotFoundHandler) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) // Preflight request req = httptest.NewRequest(http.MethodOptions, "/", nil) @@ -67,6 +80,22 @@ func TestCORS(t *testing.T) { assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) + // Preflight request with Access-Control-Request-Headers + req = httptest.NewRequest(http.MethodOptions, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, "localhost") + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header") + cors = CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + }) + h = cors(echo.NotFoundHandler) + h(c) + assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) + assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) + // Preflight request with `AllowOrigins` which allow all subdomains with * req = httptest.NewRequest(http.MethodOptions, "/", nil) rec = httptest.NewRecorder() @@ -83,3 +112,298 @@ func TestCORS(t *testing.T) { h(c) assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } + +func Test_allowOriginScheme(t *testing.T) { + tests := []struct { + domain, pattern string + expected bool + }{ + { + domain: "http://example.com", + pattern: "http://example.com", + expected: true, + }, + { + domain: "https://example.com", + pattern: "https://example.com", + expected: true, + }, + { + domain: "http://example.com", + pattern: "https://example.com", + expected: false, + }, + { + domain: "https://example.com", + pattern: "http://example.com", + expected: false, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, tt.domain) + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.pattern}, + }) + h := cors(echo.NotFoundHandler) + h(c) + + if tt.expected { + assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + } + } +} + +func Test_allowOriginSubdomain(t *testing.T) { + tests := []struct { + domain, pattern string + expected bool + }{ + { + domain: "http://aaa.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://bbb.aaa.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://bbb.aaa.example.com", + pattern: "http://*.aaa.example.com", + expected: true, + }, + { + domain: "http://aaa.example.com:8080", + pattern: "http://*.example.com:8080", + expected: true, + }, + + { + domain: "http://fuga.hoge.com", + pattern: "http://*.example.com", + expected: false, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://*.aaa.example.com", + expected: false, + }, + { + domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, + pattern: "http://*.example.com", + expected: false, + }, + { + domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, + pattern: "http://*.example.com", + expected: false, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://example.com", + expected: false, + }, + { + domain: "https://prod-preview--aaa.bbb.com", + pattern: "https://*--aaa.bbb.com", + expected: true, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://foo.[a-z]*.example.com", + expected: false, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, tt.domain) + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.pattern}, + }) + h := cors(echo.NotFoundHandler) + h(c) + + if tt.expected { + assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + } + } +} + +func TestCorsHeaders(t *testing.T) { + tests := []struct { + domain, allowedOrigin, method string + expected bool + }{ + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodOptions, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: true, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(tt.method, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tt.domain != "" { + req.Header.Set(echo.HeaderOrigin, tt.domain) + } + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.allowedOrigin}, + //AllowCredentials: true, + //MaxAge: 3600, + }) + h := cors(echo.NotFoundHandler) + h(c) + + assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + + expectedAllowOrigin := "" + if tt.allowedOrigin == "*" { + expectedAllowOrigin = "*" + } else { + expectedAllowOrigin = tt.domain + } + + switch { + case tt.expected && tt.method == http.MethodOptions: + assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) + case tt.expected && tt.method == http.MethodGet: + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + default: + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + } + + if tt.method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, rec.Code) + } + } +} + +func Test_allowOriginFunc(t *testing.T) { + returnTrue := func(origin string) (bool, error) { + return true, nil + } + returnFalse := func(origin string) (bool, error) { + return false, nil + } + returnError := func(origin string) (bool, error) { + return true, errors.New("this is a test error") + } + + allowOriginFuncs := []func(origin string) (bool, error){ + returnTrue, + returnFalse, + returnError, + } + + const origin = "http://example.com" + + e := echo.New() + for _, allowOriginFunc := range allowOriginFuncs { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, origin) + cors := CORSWithConfig(CORSConfig{ + AllowOriginFunc: allowOriginFunc, + }) + h := cors(echo.NotFoundHandler) + err := h(c) + + expected, expectedErr := allowOriginFunc(origin) + if expectedErr != nil { + assert.Equal(t, expectedErr, err) + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + continue + } + + if expected { + assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } + } +} diff --git a/middleware/decompress.go b/middleware/decompress.go new file mode 100644 index 00000000..99eaf066 --- /dev/null +++ b/middleware/decompress.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "bytes" + "compress/gzip" + "github.com/labstack/echo/v4" + "io" + "io/ioutil" +) + +type ( + // DecompressConfig defines the config for Decompress middleware. + DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + } +) + +//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +const GZIPEncoding string = "gzip" + +var ( + //DefaultDecompressConfig defines the config for decompress middleware + DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper} +) + +//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +func Decompress() echo.MiddlewareFunc { + return DecompressWithConfig(DefaultDecompressConfig) +} + +//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + switch c.Request().Header.Get(echo.HeaderContentEncoding) { + case GZIPEncoding: + gr, err := gzip.NewReader(c.Request().Body) + if err != nil { + if err == io.EOF { //ignore if body is empty + return next(c) + } + return err + } + defer gr.Close() + var buf bytes.Buffer + io.Copy(&buf, gr) + r := ioutil.NopCloser(&buf) + defer r.Close() + c.Request().Body = r + } + return next(c) + } + } +} diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go new file mode 100644 index 00000000..772c14f6 --- /dev/null +++ b/middleware/decompress_test.go @@ -0,0 +1,148 @@ +package middleware + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestDecompress(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Skip if no Content-Encoding header + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + +func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { + e := echo.New() + body := `{"name":"echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(t, err) + assert.NotEqual(t, b, body) + assert.Equal(t, b, gz) +} + +func TestDecompressNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Decompress()(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestDecompressErrorReturned(t *testing.T) { + e := echo.New() + e.Use(Decompress()) + e.GET("/", func(c echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestDecompressSkipper(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: func(c echo.Context) bool { + return c.Request().URL.Path == "/skip" + }, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) +} + +func BenchmarkDecompress(b *testing.B) { + e := echo.New() + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte(body)) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Decompress + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} + +func gzipString(body string) ([]byte, error) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + + _, err := gz.Write([]byte(body)) + if err != nil { + return nil, err + } + + if err := gz.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/middleware/jwt.go b/middleware/jwt.go index 3c7c4868..bab00c9f 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -86,6 +86,7 @@ const ( // Errors var ( ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") + ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") ) var ( @@ -213,8 +214,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return config.ErrorHandlerWithContext(err, c) } return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "invalid or expired jwt", + Code: ErrJWTInvalid.Code, + Message: ErrJWTInvalid.Message, Internal: err, } } diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 1731d90f..ce44f9c9 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -60,8 +60,6 @@ func TestJWTRace(t *testing.T) { func TestJWT(t *testing.T) { e := echo.New() - r := e.Router() - r.Add("GET", "/:jwt", func(echo.Context) error { return nil }) handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } diff --git a/middleware/middleware.go b/middleware/middleware.go index d0b7153c..60834b50 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,6 +1,7 @@ package middleware import ( + "net/http" "regexp" "strconv" "strings" @@ -32,6 +33,31 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return strings.NewReplacer(replace...) } +func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { + // Initialize + rulesRegex := map[*regexp.Regexp]string{} + for k, v := range rewrite { + k = regexp.QuoteMeta(k) + k = strings.Replace(k, `\*`, "(.*)", -1) + if strings.HasPrefix(k, `\^`) { + k = strings.Replace(k, `\^`, "^", -1) + } + k = k + "$" + rulesRegex[regexp.MustCompile(k)] = v + } + return rulesRegex +} + +func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { + for k, v := range rewriteRegex { + replacerRawPath := captureTokens(k, req.URL.EscapedPath()) + if replacerRawPath != nil { + replacerPath := captureTokens(k, req.URL.Path) + req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v) + } + } +} + // DefaultSkipper returns false which processes the middleware. func DefaultSkipper(echo.Context) bool { return false diff --git a/middleware/proxy.go b/middleware/proxy.go index 1da370db..1b972eb1 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" "regexp" - "strings" "sync" "sync/atomic" "time" @@ -45,6 +44,9 @@ type ( // Examples: If custom TLS certificates are required. Transport http.RoundTripper + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error + rewriteRegex map[*regexp.Regexp]string } @@ -203,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { if config.Balancer == nil { panic("echo: proxy middleware requires balancer") } - config.rewriteRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rewrite { - k = strings.Replace(k, "*", "(\\S*)", -1) - config.rewriteRegex[regexp.MustCompile(k)] = v - } + config.rewriteRegex = rewriteRulesRegex(config.Rewrite) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -222,13 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { tgt := config.Balancer.Next(c) c.Set(config.ContextKey, tgt) - // Rewrite - for k, v := range config.rewriteRegex { - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - req.URL.Path = replacer.Replace(v) - } - } + // Set rewrite path and raw path + rewritePath(config.rewriteRegex, req) // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. @@ -259,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } } + + diff --git a/middleware/proxy_1_11.go b/middleware/proxy_1_11.go index 12b7568b..a4392781 100644 --- a/middleware/proxy_1_11.go +++ b/middleware/proxy_1_11.go @@ -20,5 +20,6 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))) } proxy.Transport = config.Transport + proxy.ModifyResponse = config.ModifyResponse return proxy } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 40d150cf..534e45f4 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -1,7 +1,9 @@ package middleware import ( + "bytes" "fmt" + "io/ioutil" "net" "net/http" "net/http/httptest" @@ -12,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" ) +//Assert expected with url.EscapedPath method to obtain the path. func TestProxy(t *testing.T) { // Setup t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -91,19 +94,49 @@ func TestProxy(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.Path) - req.URL.Path = "/js/main.js" - e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.Path) - req.URL.Path = "/old" - e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) - req.URL.Path = "/users/jack/orders/1" - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.Path) + assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse( "/js/main.js") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/old") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/new", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse( "/users/jack/orders/1") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/api/new users") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) + // ModifyResponse + e = echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + ModifyResponse: func(res *http.Response) error { + res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified"))) + res.Header.Set("X-Modified", "1") + return nil + }, + })) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "modified", rec.Body.String()) + assert.Equal(t, "1", rec.Header().Get("X-Modified")) // ProxyTarget is set in context contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc { diff --git a/middleware/recover.go b/middleware/recover.go index e87aaf32..0dbe740d 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,6 +5,7 @@ import ( "runtime" "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" ) type ( @@ -25,6 +26,10 @@ type ( // DisablePrintStack disables printing stack trace. // Optional. Default value as false. DisablePrintStack bool `yaml:"disable_print_stack"` + + // LogLevel is log level to printing stack trace. + // Optional. Default value 0 (Print). + LogLevel log.Lvl } ) @@ -35,6 +40,7 @@ var ( StackSize: 4 << 10, // 4 KB DisableStackAll: false, DisablePrintStack: false, + LogLevel: 0, } ) @@ -70,7 +76,21 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { stack := make([]byte, config.StackSize) length := runtime.Stack(stack, !config.DisableStackAll) if !config.DisablePrintStack { - c.Logger().Printf("[PANIC RECOVER] %v %s\n", err, stack[:length]) + msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) + switch config.LogLevel { + case log.DEBUG: + c.Logger().Debug(msg) + case log.INFO: + c.Logger().Info(msg) + case log.WARN: + c.Logger().Warn(msg) + case log.ERROR: + c.Logger().Error(msg) + case log.OFF: + // None. + default: + c.Logger().Print(msg) + } } c.Error(err) } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 37707c5c..64433297 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,11 +2,13 @@ package middleware import ( "bytes" + "fmt" "net/http" "net/http/httptest" "testing" "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" "github.com/stretchr/testify/assert" ) @@ -24,3 +26,58 @@ func TestRecover(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.Contains(t, buf.String(), "PANIC RECOVER") } + +func TestRecoverWithConfig_LogLevel(t *testing.T) { + tests := []struct { + logLevel log.Lvl + levelName string + }{{ + logLevel: log.DEBUG, + levelName: "DEBUG", + }, { + logLevel: log.INFO, + levelName: "INFO", + }, { + logLevel: log.WARN, + levelName: "WARN", + }, { + logLevel: log.ERROR, + levelName: "ERROR", + }, { + logLevel: log.OFF, + levelName: "OFF", + }} + + for _, tt := range tests { + tt := tt + t.Run(tt.levelName, func(t *testing.T) { + e := echo.New() + e.Logger.SetLevel(log.DEBUG) + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := DefaultRecoverConfig + config.LogLevel = tt.logLevel + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("test") + })) + + h(c) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + if tt.logLevel == log.OFF { + assert.Empty(t, output) + } else { + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) + } + }) + } +} diff --git a/middleware/rewrite.go b/middleware/rewrite.go index a64e10bb..0965e313 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,10 +1,8 @@ package middleware import ( - "regexp" - "strings" - "github.com/labstack/echo/v4" + "regexp" ) type ( @@ -53,14 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultBodyDumpConfig.Skipper } - config.rulesRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rules { - k = strings.Replace(k, "*", "(.*)", -1) - k = k + "$" - config.rulesRegex[regexp.MustCompile(k)] = v - } + config.rulesRegex = rewriteRulesRegex(config.Rules) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -69,15 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } req := c.Request() - - // Rewrite - for k, v := range config.rulesRegex { - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - req.URL.Path = replacer.Replace(v) - break - } - } + // Set rewrite path and raw path + rewritePath(config.rulesRegex, req) return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index eb5a46d8..abf11b2f 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -4,12 +4,14 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) +//Assert expected with url.EscapedPath method to obtain the path. func TestRewrite(t *testing.T) { e := echo.New() e.Use(RewriteWithConfig(RewriteConfig{ @@ -22,21 +24,28 @@ func TestRewrite(t *testing.T) { })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.Path) - req.URL.Path = "/js/main.js" + assert.Equal(t, "/users", req.URL.EscapedPath()) + req.URL, _ = url.Parse("/js/main.js") + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.Path) - req.URL.Path = "/old" + assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) + req.URL, _ = url.Parse("/old") + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) - req.URL.Path = "/users/jack/orders/1" + assert.Equal(t, "/new", req.URL.EscapedPath()) + req.URL, _ = url.Parse("/users/jack/orders/1") + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.Path) - req.URL.Path = "/api/new users" + assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new users", req.URL.Path) + assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) + req.URL, _ = url.Parse("/api/new users") + e.ServeHTTP(rec, req) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) } // Issue #1086 @@ -45,22 +54,21 @@ func TestEchoRewritePreMiddleware(t *testing.T) { r := e.Router() // Rewrite old url to new one - e.Pre(RewriteWithConfig(RewriteConfig{ - Rules: map[string]string{ + e.Pre(Rewrite(map[string]string{ "/old": "/new", }, - })) + )) // Route r.Add(http.MethodGet, "/new", func(c echo.Context) error { - return c.NoContent(200) + return c.NoContent(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/old", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) - assert.Equal(t, 200, rec.Code) + assert.Equal(t, "/new", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) } // Issue #1143 @@ -76,21 +84,48 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { })) r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { - return c.String(200, "hosts") + return c.String(http.StatusOK, "hosts") }) r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { - return c.String(200, "eng") + return c.String(http.StatusOK, "eng") }) for i := 0; i < 100; i++ { req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/api/v1/hosts/test", req.URL.Path) - assert.Equal(t, 200, rec.Code) + assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) defer rec.Result().Body.Close() bodyBytes, _ := ioutil.ReadAll(rec.Result().Body) assert.Equal(t, "hosts", string(bodyBytes)) } } + +// Issue #1573 +func TestEchoRewriteWithCaret(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/abc/*": "/v1/abc/$1", + }, + })) + + rec := httptest.NewRecorder() + + var req *http.Request + + req = httptest.NewRequest(http.MethodGet, "/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/abc/test", req.URL.Path) + + req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/abc/test", req.URL.Path) + + req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v2/abc/test", req.URL.Path) +} diff --git a/middleware/static.go b/middleware/static.go index bc2087a7..58b7890a 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -36,6 +36,12 @@ type ( // Enable directory browsing. // Optional. Default value false. Browse bool `yaml:"browse"` + + // Enable ignoring of the base of the URL path. + // Example: when assigning a static middleware to a non root path group, + // the filesystem path is not doubled + // Optional. Default value false. + IgnoreBase bool `yaml:"ignoreBase"` } ) @@ -163,6 +169,15 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security + if config.IgnoreBase { + routePath := path.Base(strings.TrimRight(c.Path(), "/*")) + baseURLPath := path.Base(p) + if baseURLPath == routePath { + i := strings.LastIndex(name, routePath) + name = name[:i] + strings.Replace(name[i:], routePath, "", 1) + } + } + fi, err := os.Stat(name) if err != nil { if os.IsNotExist(err) { diff --git a/middleware/static_test.go b/middleware/static_test.go index 0d695d3d..407dd15c 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "path/filepath" "testing" "github.com/labstack/echo/v4" @@ -67,4 +68,27 @@ func TestStatic(t *testing.T) { assert.Equal(http.StatusOK, rec.Code) assert.Contains(rec.Body.String(), "cert.pem") } + + // IgnoreBase + req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) + rec = httptest.NewRecorder() + config.Root = "../_fixture" + config.IgnoreBase = true + static = StaticWithConfig(config) + c.Echo().Group("_fixture", static) + e.ServeHTTP(rec, req) + + assert.Equal(http.StatusOK, rec.Code) + assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122") + + req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) + rec = httptest.NewRecorder() + config.Root = "../_fixture" + config.IgnoreBase = false + static = StaticWithConfig(config) + c.Echo().Group("_fixture", static) + e.ServeHTTP(rec, req) + + assert.Equal(http.StatusOK, rec.Code) + assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture")) } diff --git a/response_test.go b/response_test.go index bc570a50..7a9c51c6 100644 --- a/response_test.go +++ b/response_test.go @@ -19,8 +19,13 @@ func TestResponse(t *testing.T) { res.Before(func() { c.Response().Header().Set(HeaderServer, "echo") }) + // After + res.After(func() { + c.Response().Header().Set(HeaderXFrameOptions, "DENY") + }) res.Write([]byte("test")) assert.Equal(t, "echo", rec.Header().Get(HeaderServer)) + assert.Equal(t, "DENY", rec.Header().Get(HeaderXFrameOptions)) } func TestResponse_Write_FallsBackToDefaultStatus(t *testing.T) { @@ -41,3 +46,13 @@ func TestResponse_Write_UsesSetResponseCode(t *testing.T) { res.Write([]byte("test")) assert.Equal(t, http.StatusBadRequest, rec.Code) } + +func TestResponse_Flush(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + res := &Response{echo: e, Writer: rec} + + res.Write([]byte("test")) + res.Flush() + assert.True(t, rec.Flushed) +} diff --git a/router.go b/router.go index 15a3398f..ed728d6a 100644 --- a/router.go +++ b/router.go @@ -355,6 +355,10 @@ func (r *Router) Find(method, path string, c Context) { // Attempt to go back up the tree on no matching prefix or no remaining search if l != pl || search == "" { + // Handle special case of trailing slash route with existing any route (see #1526) + if path[len(path)-1] == '/' && cn.findChildByKind(akind) != nil { + goto Any + } if nn == nil { // Issue #1348 return // Not found } diff --git a/router_test.go b/router_test.go index 8c27b9f7..0e883233 100644 --- a/router_test.go +++ b/router_test.go @@ -608,7 +608,6 @@ func TestRouterMatchAny(t *testing.T) { return nil }) c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) assert.Equal(t, "", c.Param("*")) @@ -619,6 +618,78 @@ func TestRouterMatchAny(t *testing.T) { assert.Equal(t, "joe", c.Param("*")) } +// TestRouterMatchAnySlash shall verify finding the best route +// for any routes with trailing slash requests +func TestRouterMatchAnySlash(t *testing.T) { + e := New() + r := e.router + + handler := func(c Context) error { + c.Set("path", c.Path()) + return nil + } + + // Routes + r.Add(http.MethodGet, "/users", handler) + r.Add(http.MethodGet, "/users/*", handler) + r.Add(http.MethodGet, "/img/*", handler) + r.Add(http.MethodGet, "/img/load", handler) + r.Add(http.MethodGet, "/img/load/*", handler) + r.Add(http.MethodGet, "/assets/*", handler) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/", c) + assert.Equal(t, "", c.Param("*")) + + // Test trailing slash request for simple any route (see #1526) + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/", c) + c.handler(c) + assert.Equal(t, "/users/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/joe", c) + c.handler(c) + assert.Equal(t, "/users/*", c.Get("path")) + assert.Equal(t, "joe", c.Param("*")) + + // Test trailing slash request for nested any route (see #1526) + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/img/load", c) + c.handler(c) + assert.Equal(t, "/img/load", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/img/load/", c) + c.handler(c) + assert.Equal(t, "/img/load/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/img/load/ben", c) + c.handler(c) + assert.Equal(t, "/img/load/*", c.Get("path")) + assert.Equal(t, "ben", c.Param("*")) + + // Test /assets/* any route + // ... without trailing slash must not match + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/assets", c) + c.handler(c) + assert.Equal(t, nil, c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + // ... with trailing slash must match + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/assets/", c) + c.handler(c) + assert.Equal(t, "/assets/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + +} + func TestRouterMatchAnyMultiLevel(t *testing.T) { e := New() r := e.router