diff --git a/.github/stale.yml b/.github/stale.yml index d9f65632..04dd169c 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -1,17 +1,19 @@ # Number of days of inactivity before an issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 +daysUntilClose: 30 # Issues with these labels will never be considered stale exemptLabels: - pinned - security + - bug + - enhancement # Label to use when marking an issue as stale -staleLabel: wontfix +staleLabel: stale # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. + recent activity. It will be closed within a month if no further activity occurs. + Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable -closeComment: false \ No newline at end of file +closeComment: false diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 1d7508b9..2aec272d 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -4,20 +4,28 @@ on: push: branches: - master + paths: + - '**.go' + - 'go.*' + - '_fixture/**' + - '.github/**' + - 'codecov.yml' pull_request: branches: - master - -env: - GO111MODULE: on - GOPROXY: https://proxy.golang.org + paths: + - '**.go' + - 'go.*' + - '_fixture/**' + - '.github/**' + - 'codecov.yml' jobs: test: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go: [1.12, 1.13, 1.14] + go: [1.12, 1.13, 1.14, 1.15] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -28,10 +36,15 @@ jobs: - name: Set GOPATH and PATH run: | - echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)" - echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin" + echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV + echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH shell: bash + - name: Set build variables + run: | + echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV + echo "GO111MODULE=on" >> $GITHUB_ENV + - name: Checkout Code uses: actions/checkout@v1 with: @@ -51,3 +64,55 @@ jobs: with: token: fail_ci_if_error: false + benchmark: + needs: test + strategy: + matrix: + os: [ubuntu-latest] + go: [1.15] + name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} + runs-on: ${{ matrix.os }} + steps: + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go }} + + - name: Set GOPATH and PATH + run: | + echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV + echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH + shell: bash + + - name: Set build variables + run: | + echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV + echo "GO111MODULE=on" >> $GITHUB_ENV + + - name: Checkout Code (Previous) + uses: actions/checkout@v2 + with: + ref: ${{ github.base_ref }} + path: previous + + - name: Checkout Code (New) + uses: actions/checkout@v2 + with: + path: new + + - name: Install Dependencies + run: go get -v golang.org/x/perf/cmd/benchstat + + - name: Run Benchmark (Previous) + run: | + cd previous + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + + - name: Run Benchmark (New) + run: | + cd new + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + + - name: Run Benchstat + run: | + benchstat previous/benchmark.txt new/benchmark.txt diff --git a/.travis.yml b/.travis.yml index ef826e95..67d45ad7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,7 @@ +arch: + - amd64 + - ppc64le + language: go go: - 1.14.x diff --git a/README.md b/README.md index 03ad4dca..deba54f4 100644 --- a/README.md +++ b/README.md @@ -42,11 +42,14 @@ For older versions, please use the latest v3 tag. ## Benchmarks -Date: 2018/03/15
+Date: 2020/11/11
Source: https://github.com/vishr/web-framework-benchmark
Lower is better! - + + + +The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz ## [Guide](https://echo.labstack.com/guide) diff --git a/_fixture/_fixture/README.md b/_fixture/_fixture/README.md new file mode 100644 index 00000000..21a78585 --- /dev/null +++ b/_fixture/_fixture/README.md @@ -0,0 +1 @@ +This directory is used for the static middleware test \ No newline at end of file diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..0fa3a3f1 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,11 @@ +coverage: + status: + project: + default: + threshold: 1% + patch: + default: + threshold: 1% + +comment: + require_changes: true \ No newline at end of file diff --git a/context.go b/context.go index 12aea7a2..380a52e1 100644 --- a/context.go +++ b/context.go @@ -385,7 +385,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) { if err != nil { return nil, err } - defer f.Close() + f.Close() return fh, nil } diff --git a/echo.go b/echo.go index 128f84fd..d284ff39 100644 --- a/echo.go +++ b/echo.go @@ -49,7 +49,6 @@ import ( "net/http" "net/url" "os" - "path" "path/filepath" "reflect" "runtime" @@ -92,6 +91,7 @@ type ( Renderer Renderer Logger Logger IPExtractor IPExtractor + ListenerNetwork string } // Route contains a handler and information for matching against requests. @@ -281,6 +281,7 @@ var ( ErrInvalidRedirectCode = errors.New("invalid redirect status code") ErrCookieNotFound = errors.New("cookie not found") ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") ) // Error handlers @@ -302,9 +303,10 @@ func New() (e *Echo) { AutoTLSManager: autocert.Manager{ Prompt: autocert.AcceptTOS, }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), + Logger: log.New("echo"), + colorer: color.New(), + maxParam: new(int), + ListenerNetwork: "tcp", } e.Server.Handler = e e.TLSServer.Handler = e @@ -362,10 +364,12 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { // Issue #1426 code := he.Code message := he.Message - if e.Debug { - message = err.Error() - } else if m, ok := message.(string); ok { - message = Map{"message": m} + if m, ok := he.Message.(string); ok { + if e.Debug { + message = Map{"message": m, "error": err.Error()} + } else { + message = Map{"message": m} + } } // Send response @@ -481,7 +485,7 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl return err } - name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security + name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security fi, err := os.Stat(name) if err != nil { // The access path does not exist @@ -570,7 +574,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string { for _, r := range e.router.routes { if r.Name == name { for i, l := 0, len(r.Path); i < l; i++ { - if r.Path[i] == ':' && n < ln { + if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { for ; i < l && r.Path[i] != '/'; i++ { } uri.WriteString(fmt.Sprintf("%v", params[n])) @@ -712,7 +716,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if s.TLSConfig == nil { if e.Listener == nil { - e.Listener, err = newListener(s.Addr) + e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -723,7 +727,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { return s.Serve(e.Listener) } if e.TLSListener == nil { - l, err := newListener(s.Addr) + l, err := newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -752,7 +756,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { } if e.Listener == nil { - e.Listener, err = newListener(s.Addr) + e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -873,8 +877,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { return } -func newListener(address string) (*tcpKeepAliveListener, error) { - l, err := net.Listen("tcp", address) +func newListener(address, network string) (*tcpKeepAliveListener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, ErrInvalidListenerNetwork + } + l, err := net.Listen(network, address) if err != nil { return nil, err } diff --git a/echo_test.go b/echo_test.go index e1706eff..a6071e12 100644 --- a/echo_test.go +++ b/echo_test.go @@ -4,6 +4,7 @@ import ( "bytes" stdContext "context" "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -59,45 +60,105 @@ func TestEcho(t *testing.T) { } func TestEchoStatic(t *testing.T) { - e := New() + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } - assert := assert.New(t) - - // OK - e.Static("/images", "_fixture/images") - c, b := request(http.MethodGet, "/images/walle.png", e) - assert.Equal(http.StatusOK, c) - assert.NotEmpty(b) - - // No file - e.Static("/images", "_fixture/scripts") - c, _ = request(http.MethodGet, "/images/bolt.png", e) - assert.Equal(http.StatusNotFound, c) - - // Directory - e.Static("/images", "_fixture/images") - c, _ = request(http.MethodGet, "/images/", e) - assert.Equal(http.StatusNotFound, c) - - // Directory Redirect - e.Static("/", "_fixture") - req := httptest.NewRequest(http.MethodGet, "/folder", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(http.StatusMovedPermanently, rec.Code) - assert.Equal("/folder/", rec.HeaderMap["Location"][0]) - - // Directory with index.html - e.Static("/", "_fixture") - c, r := request(http.MethodGet, "/", e) - assert.Equal(http.StatusOK, c) - assert.Equal(true, strings.HasPrefix(r, "")) - - // Sub-directory with index.html - c, r = request(http.MethodGet, "/folder/", e) - assert.Equal(http.StatusOK, c) - assert.Equal(true, strings.HasPrefix(r, "")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Static(tc.givenPrefix, tc.givenRoot) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } } func TestEchoFile(t *testing.T) { @@ -277,10 +338,12 @@ func TestEchoURL(t *testing.T) { e := New() static := func(Context) error { return nil } getUser := func(Context) error { return nil } + getAny := func(Context) error { return nil } getFile := func(Context) error { return nil } e.GET("/static/file", static) e.GET("/users/:id", getUser) + e.GET("/documents/*", getAny) g := e.Group("/group") g.GET("/users/:uid/files/:fid", getFile) @@ -289,6 +352,9 @@ func TestEchoURL(t *testing.T) { assert.Equal("/static/file", e.URL(static)) assert.Equal("/users/:id", e.URL(getUser)) assert.Equal("/users/1", e.URL(getUser, "1")) + assert.Equal("/users/1", e.URL(getUser, "1")) + assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) + assert.Equal("/documents/*", e.URL(getAny)) assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) } @@ -568,6 +634,49 @@ func TestHTTPError(t *testing.T) { }) } +func TestDefaultHTTPErrorHandler(t *testing.T) { + e := New() + e.Debug = true + e.Any("/plain", func(c Context) error { + return errors.New("An error occurred") + }) + e.Any("/badrequest", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, "Invalid request") + }) + e.Any("/servererror", func(c Context) error { + return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ + "code": 33, + "message": "Something bad happened", + "error": "stackinfo", + }) + }) + // With Debug=true plain response contains error message + c, b := request(http.MethodGet, "/plain", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) + // and special handling for HTTPError + c, b = request(http.MethodGet, "/badrequest", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) + // complex errors are serialized to pretty JSON + c, b = request(http.MethodGet, "/servererror", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) + + e.Debug = false + // With Debug=false the error response is shortened + c, b = request(http.MethodGet, "/plain", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) + c, b = request(http.MethodGet, "/badrequest", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) + // No difference for error response with non plain string errors + c, b = request(http.MethodGet, "/servererror", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) +} + func TestEchoClose(t *testing.T) { e := New() errCh := make(chan error) @@ -609,3 +718,91 @@ func TestEchoShutdown(t *testing.T) { err := <-errCh assert.Equal(t, err.Error(), "http: Server closed") } + +var listenerNetworkTests = []struct { + test string + network string + address string +}{ + {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, + {"tcp ipv6 address", "tcp", "[::1]:1323"}, + {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, + {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, +} + +func TestEchoListenerNetwork(t *testing.T) { + for _, tt := range listenerNetworkTests { + t.Run(tt.test, func(t *testing.T) { + e := New() + e.ListenerNetwork = tt.network + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errCh := make(chan error) + + go func() { + errCh <- e.Start(tt.address) + }() + + time.Sleep(200 * time.Millisecond) + + if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + if body, err := ioutil.ReadAll(resp.Body); err == nil { + assert.Equal(t, "OK", string(body)) + } else { + assert.Fail(t, err.Error()) + } + + } else { + assert.Fail(t, err.Error()) + } + + if err := e.Close(); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestEchoListenerNetworkInvalid(t *testing.T) { + e := New() + e.ListenerNetwork = "unix" + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) +} + +func TestEchoReverse(t *testing.T) { + assert := assert.New(t) + + e := New() + dummyHandler := func(Context) error { return nil } + + e.GET("/static", dummyHandler).Name = "/static" + e.GET("/static/*", dummyHandler).Name = "/static/*" + e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" + e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" + e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" + + assert.Equal("/static", e.Reverse("/static")) + assert.Equal("/static", e.Reverse("/static", "missing param")) + assert.Equal("/static/*", e.Reverse("/static/*")) + assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) + + assert.Equal("/params/:foo", e.Reverse("/params/:foo")) + assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) + assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) + assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +} diff --git a/middleware/compress.go b/middleware/compress.go index dd97d983..6ae19745 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strings" + "sync" "github.com/labstack/echo/v4" ) @@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { config.Level = DefaultGzipConfig.Level } + pool := gzipCompressPool(config) + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { @@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 - rw := res.Writer - w, err := gzip.NewWriterLevel(rw, config.Level) - if err != nil { - return err + i := pool.Get() + w, ok := i.(*gzip.Writer) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) } + rw := res.Writer + w.Reset(rw) defer func() { if res.Size == 0 { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { @@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { w.Reset(ioutil.Discard) } w.Close() + pool.Put(w) }() grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} res.Writer = grw @@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { } return http.ErrNotSupported } + +func gzipCompressPool(config GzipConfig) sync.Pool { + return sync.Pool{ + New: func() interface{} { + w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) + if err != nil { + return err + } + return w + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index ac5b6c3b..d16ffca4 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } +func TestGzipErrorReturnedInvalidConfig(t *testing.T) { + e := echo.New() + // Invalid level + e.Use(GzipWithConfig(GzipConfig{Level: 12})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "gzip") +} + // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() @@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) { } } } + +func BenchmarkGzip(b *testing.B) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + + h := Gzip()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Gzip + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} diff --git a/middleware/cors.go b/middleware/cors.go index c263f731..d6ef8964 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -19,6 +19,13 @@ type ( // Optional. Default value []string{"*"}. AllowOrigins []string `yaml:"allow_origins"` + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. @@ -102,45 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { origin := req.Header.Get(echo.HeaderOrigin) allowOrigin := "" - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { - allowOrigin = origin - break - } - if o == "*" || o == origin { - allowOrigin = o - break - } - if matchSubdomain(origin, o) { - allowOrigin = origin - break + preflight := req.Method == http.MethodOptions + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + + // No Origin provided + if origin == "" { + if !preflight { + return next(c) } + return c.NoContent(http.StatusNoContent) } - // Check allowed origin patterns - for _, re := range allowOriginPatterns { - if allowOrigin == "" { - didx := strings.Index(origin, "://") - if didx == -1 { - continue - } - domAuth := origin[didx+3:] - // to avoid regex cost by invalid long domain - if len(domAuth) > 253 { + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { + allowOrigin = origin + } + } else { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials { + allowOrigin = origin break } - - if match, _ := regexp.MatchString(re, origin); match { + if o == "*" || o == origin { + allowOrigin = o + break + } + if matchSubdomain(origin, o) { allowOrigin = origin break } } + + // Check allowed origin patterns + for _, re := range allowOriginPatterns { + if allowOrigin == "" { + didx := strings.Index(origin, "://") + if didx == -1 { + continue + } + domAuth := origin[didx+3:] + // to avoid regex cost by invalid long domain + if len(domAuth) > 253 { + break + } + + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } + } + } + + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) } // Simple request - if req.Method != http.MethodOptions { - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + if !preflight { res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") @@ -152,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } // Preflight request - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index ca922321..717abe49 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -17,19 +18,31 @@ func TestCORS(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := CORS()(echo.NotFoundHandler) + req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + // Wildcard AllowedOrigin with no Origin header in request + req = httptest.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h = CORS()(echo.NotFoundHandler) + h(c) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + // Allow origins req = httptest.NewRequest(http.MethodGet, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, })(echo.NotFoundHandler) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) // Preflight request req = httptest.NewRequest(http.MethodOptions, "/", nil) @@ -67,6 +80,22 @@ func TestCORS(t *testing.T) { assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) + // Preflight request with Access-Control-Request-Headers + req = httptest.NewRequest(http.MethodOptions, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, "localhost") + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header") + cors = CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + }) + h = cors(echo.NotFoundHandler) + h(c) + assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) + assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) + // Preflight request with `AllowOrigins` which allow all subdomains with * req = httptest.NewRequest(http.MethodOptions, "/", nil) rec = httptest.NewRecorder() @@ -126,7 +155,7 @@ func Test_allowOriginScheme(t *testing.T) { if tt.expected { assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { - assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) } } } @@ -216,6 +245,163 @@ func Test_allowOriginSubdomain(t *testing.T) { if tt.expected { assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + } + } +} + +func TestCorsHeaders(t *testing.T) { + tests := []struct { + domain, allowedOrigin, method string + expected bool + }{ + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodOptions, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: true, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(tt.method, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tt.domain != "" { + req.Header.Set(echo.HeaderOrigin, tt.domain) + } + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.allowedOrigin}, + //AllowCredentials: true, + //MaxAge: 3600, + }) + h := cors(echo.NotFoundHandler) + h(c) + + assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + + expectedAllowOrigin := "" + if tt.allowedOrigin == "*" { + expectedAllowOrigin = "*" + } else { + expectedAllowOrigin = tt.domain + } + + switch { + case tt.expected && tt.method == http.MethodOptions: + assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) + case tt.expected && tt.method == http.MethodGet: + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + default: + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + } + + if tt.method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, rec.Code) + } + } +} + +func Test_allowOriginFunc(t *testing.T) { + returnTrue := func(origin string) (bool, error) { + return true, nil + } + returnFalse := func(origin string) (bool, error) { + return false, nil + } + returnError := func(origin string) (bool, error) { + return true, errors.New("this is a test error") + } + + allowOriginFuncs := []func(origin string) (bool, error){ + returnTrue, + returnFalse, + returnError, + } + + const origin = "http://example.com" + + e := echo.New() + for _, allowOriginFunc := range allowOriginFuncs { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, origin) + cors := CORSWithConfig(CORSConfig{ + AllowOriginFunc: allowOriginFunc, + }) + h := cors(echo.NotFoundHandler) + err := h(c) + + expected, expectedErr := allowOriginFunc(origin) + if expectedErr != nil { + assert.Equal(t, expectedErr, err) + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + continue + } + + if expected { + assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } diff --git a/middleware/decompress.go b/middleware/decompress.go new file mode 100644 index 00000000..c046359a --- /dev/null +++ b/middleware/decompress.go @@ -0,0 +1,120 @@ +package middleware + +import ( + "bytes" + "compress/gzip" + "io" + "io/ioutil" + "net/http" + "sync" + + "github.com/labstack/echo/v4" +) + +type ( + // DecompressConfig defines the config for Decompress middleware. + DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor + } +) + +//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +const GZIPEncoding string = "gzip" + +// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers +type Decompressor interface { + gzipDecompressPool() sync.Pool +} + +var ( + //DefaultDecompressConfig defines the config for decompress middleware + DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, + } +) + +// DefaultGzipDecompressPool is the default implementation of Decompressor interface +type DefaultGzipDecompressPool struct { +} + +func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + // create with an empty reader (but with GZIP header) + w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) + if err != nil { + return err + } + + b := new(bytes.Buffer) + w.Reset(b) + w.Flush() + w.Close() + + r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) + if err != nil { + return err + } + return r + }, + } +} + +//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +func Decompress() echo.MiddlewareFunc { + return DecompressWithConfig(DefaultDecompressConfig) +} + +//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultGzipConfig.Skipper + } + if config.GzipDecompressPool == nil { + config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + pool := config.GzipDecompressPool.gzipDecompressPool() + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + switch c.Request().Header.Get(echo.HeaderContentEncoding) { + case GZIPEncoding: + b := c.Request().Body + + i := pool.Get() + gr, ok := i.(*gzip.Reader) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + } + + if err := gr.Reset(b); err != nil { + pool.Put(gr) + if err == io.EOF { //ignore if body is empty + return next(c) + } + return err + } + var buf bytes.Buffer + io.Copy(&buf, gr) + + gr.Close() + pool.Put(gr) + + b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here + + r := ioutil.NopCloser(&buf) + c.Request().Body = r + } + return next(c) + } + } +} diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go new file mode 100644 index 00000000..51fa6b0f --- /dev/null +++ b/middleware/decompress_test.go @@ -0,0 +1,209 @@ +package middleware + +import ( + "bytes" + "compress/gzip" + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestDecompress(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Skip if no Content-Encoding header + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + +func TestDecompressDefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + +func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { + e := echo.New() + body := `{"name":"echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(t, err) + assert.NotEqual(t, b, body) + assert.Equal(t, b, gz) +} + +func TestDecompressNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Decompress()(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestDecompressErrorReturned(t *testing.T) { + e := echo.New() + e.Use(Decompress()) + e.GET("/", func(c echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestDecompressSkipper(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: func(c echo.Context) bool { + return c.Request().URL.Path == "/skip" + }, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) +} + +type TestDecompressPoolWithError struct { +} + +func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + return errors.New("pool error") + }, + } +} + +func TestDecompressPoolError(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &TestDecompressPoolWithError{}, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) + assert.Equal(t, rec.Code, http.StatusInternalServerError) +} + +func BenchmarkDecompress(b *testing.B) { + e := echo.New() + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte(body)) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Decompress + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} + +func gzipString(body string) ([]byte, error) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + + _, err := gz.Write([]byte(body)) + if err != nil { + return nil, err + } + + if err := gz.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/middleware/jwt.go b/middleware/jwt.go index bab00c9f..da00ea56 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -57,6 +57,7 @@ type ( // - "query:" // - "param:" // - "cookie:" + // - "form:" TokenLookup string // AuthScheme to be used in the Authorization header. @@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { extractor = jwtFromParam(parts[1]) case "cookie": extractor = jwtFromCookie(parts[1]) + case "form": + extractor = jwtFromForm(parts[1]) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor { return cookie.Value, nil } } + +// jwtFromForm returns a `jwtExtractor` that extracts token from the form field. +func jwtFromForm(name string) jwtExtractor { + return func(c echo.Context) (string, error) { + field := c.FormValue(name) + if field == "" { + return "", ErrJWTMissing + } + return field, nil + } +} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index ce44f9c9..205721ae 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -3,6 +3,8 @@ package middleware import ( "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "github.com/dgrijalva/jwt-go" @@ -75,6 +77,7 @@ func TestJWT(t *testing.T) { reqURL string // "/" if empty hdrAuth string hdrCookie string // test.Request doesn't provide SetCookie(); use name=val + formValues map[string]string info string }{ { @@ -192,12 +195,48 @@ func TestJWT(t *testing.T) { expErrCode: http.StatusBadRequest, info: "Empty cookie", }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + formValues: map[string]string{"jwt": token}, + info: "Valid form method", + }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + expErrCode: http.StatusUnauthorized, + formValues: map[string]string{"jwt": "invalid"}, + info: "Invalid token with form method", + }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + expErrCode: http.StatusBadRequest, + info: "Empty form field", + }, } { if tc.reqURL == "" { tc.reqURL = "/" } - req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + var req *http.Request + if len(tc.formValues) > 0 { + form := url.Values{} + for k, v := range tc.formValues { + form.Set(k, v) + } + req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) + req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") + req.ParseForm() + } else { + req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + } res := httptest.NewRecorder() req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderCookie, tc.hdrCookie) diff --git a/middleware/middleware.go b/middleware/middleware.go index 12260ddb..60834b50 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,7 +2,6 @@ package middleware import ( "net/http" - "net/url" "regexp" "strconv" "strings" @@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return strings.NewReplacer(replace...) } -//rewritePath sets request url path and raw path -func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error { - replacerRawPath := replacer.Replace(target) - replacerPath, err := url.PathUnescape(replacerRawPath) - if err != nil { - return err +func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { + // Initialize + rulesRegex := map[*regexp.Regexp]string{} + for k, v := range rewrite { + k = regexp.QuoteMeta(k) + k = strings.Replace(k, `\*`, "(.*)", -1) + if strings.HasPrefix(k, `\^`) { + k = strings.Replace(k, `\^`, "^", -1) + } + k = k + "$" + rulesRegex[regexp.MustCompile(k)] = v + } + return rulesRegex +} + +func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { + for k, v := range rewriteRegex { + replacerRawPath := captureTokens(k, req.URL.EscapedPath()) + if replacerRawPath != nil { + replacerPath := captureTokens(k, req.URL.Path) + req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v) + } } - req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath - return nil } // DefaultSkipper returns false which processes the middleware. diff --git a/middleware/proxy.go b/middleware/proxy.go index cd50b76a..1b972eb1 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" "regexp" - "strings" "sync" "sync/atomic" "time" @@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { if config.Balancer == nil { panic("echo: proxy middleware requires balancer") } - config.rewriteRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rewrite { - k = strings.Replace(k, "*", "(\\S*)", -1) - config.rewriteRegex[regexp.MustCompile(k)] = v - } + config.rewriteRegex = rewriteRulesRegex(config.Rewrite) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { tgt := config.Balancer.Next(c) c.Set(config.ContextKey, tgt) - // Rewrite - for k, v := range config.rewriteRegex { - //use req.URL.Path here or else we will have double escaping - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - if err := rewritePath(replacer, v, req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid url") - } - } - } + // Set rewrite path and raw path + rewritePath(config.rewriteRegex, req) // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. @@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } } + + diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 4bb74648..534e45f4 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -94,36 +94,35 @@ func TestProxy(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/js/main.js" + req.URL, _ = url.Parse( "/js/main.js") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/old" + req.URL, _ = url.Parse("/old") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jack/orders/1" + req.URL, _ = url.Parse( "/users/jack/orders/1") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jill/orders/%%%%" + req.URL, _ = url.Parse("/api/new users") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusBadRequest, rec.Code) - + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) // ModifyResponse e = echo.New() e.Use(ProxyWithConfig(ProxyConfig{ diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 30eecdef..86eec8c3 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -31,3 +31,20 @@ func TestRequestID(t *testing.T) { h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") } + +func TestRequestID_IDNotAltered(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRequestID, "") + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{}) + h := rid(handler) + _ = h(c) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "") +} diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 855c8633..0965e313 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,11 +1,8 @@ package middleware import ( - "net/http" - "regexp" - "strings" - "github.com/labstack/echo/v4" + "regexp" ) type ( @@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultBodyDumpConfig.Skipper } - config.rulesRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rules { - k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*)", -1) - if strings.HasPrefix(k, `\^`) { - k = strings.Replace(k, `\^`, "^", -1) - } - k = k + "$" - config.rulesRegex[regexp.MustCompile(k)] = v - } + config.rulesRegex = rewriteRulesRegex(config.Rules) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } req := c.Request() - // Rewrite - for k, v := range config.rulesRegex { - //use req.URL.Path here or else we will have double escaping - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - if err := rewritePath(replacer, v, req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid url") - } - break - } - } + // Set rewrite path and raw path + rewritePath(config.rulesRegex, req) return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index a9b3437c..abf11b2f 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/labstack/echo/v4" @@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) { })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") e.ServeHTTP(rec, req) assert.Equal(t, "/users", req.URL.EscapedPath()) - req.URL.Path = "/js/main.js" + req.URL, _ = url.Parse("/js/main.js") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) - req.URL.Path = "/old" + req.URL, _ = url.Parse("/old") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) - req.URL.Path = "/users/jack/orders/1" + req.URL, _ = url.Parse("/users/jack/orders/1") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) - req.URL.Path = "/api/new users" - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new%20users", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/%%%%" - rec = httptest.NewRecorder() + req.URL, _ = url.Parse("/api/new users") e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) } // Issue #1086 @@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) { r := e.Router() // Rewrite old url to new one - e.Pre(RewriteWithConfig(RewriteConfig{ - Rules: map[string]string{ + e.Pre(Rewrite(map[string]string{ "/old": "/new", }, - })) + )) // Route r.Add(http.MethodGet, "/new", func(c echo.Context) error { diff --git a/middleware/static.go b/middleware/static.go index bc2087a7..ae79cb5f 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -36,6 +36,12 @@ type ( // Enable directory browsing. // Optional. Default value false. Browse bool `yaml:"browse"` + + // Enable ignoring of the base of the URL path. + // Example: when assigning a static middleware to a non root path group, + // the filesystem path is not doubled + // Optional. Default value false. + IgnoreBase bool `yaml:"ignoreBase"` } ) @@ -161,7 +167,16 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { if err != nil { return } - name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security + name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security + + if config.IgnoreBase { + routePath := path.Base(strings.TrimRight(c.Path(), "/*")) + baseURLPath := path.Base(p) + if baseURLPath == routePath { + i := strings.LastIndex(name, routePath) + name = name[:i] + strings.Replace(name[i:], routePath, "", 1) + } + } fi, err := os.Stat(name) if err != nil { diff --git a/middleware/static_test.go b/middleware/static_test.go index 0d695d3d..407dd15c 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "path/filepath" "testing" "github.com/labstack/echo/v4" @@ -67,4 +68,27 @@ func TestStatic(t *testing.T) { assert.Equal(http.StatusOK, rec.Code) assert.Contains(rec.Body.String(), "cert.pem") } + + // IgnoreBase + req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) + rec = httptest.NewRecorder() + config.Root = "../_fixture" + config.IgnoreBase = true + static = StaticWithConfig(config) + c.Echo().Group("_fixture", static) + e.ServeHTTP(rec, req) + + assert.Equal(http.StatusOK, rec.Code) + assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122") + + req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) + rec = httptest.NewRecorder() + config.Root = "../_fixture" + config.IgnoreBase = false + static = StaticWithConfig(config) + c.Echo().Group("_fixture", static) + e.ServeHTTP(rec, req) + + assert.Equal(http.StatusOK, rec.Code) + assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture")) } diff --git a/response.go b/response.go index ca7405c5..84f7c9e7 100644 --- a/response.go +++ b/response.go @@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) { r.echo.Logger.Warn("response already committed") return } + r.Status = code for _, fn := range r.beforeFuncs { fn() } - r.Status = code - r.Writer.WriteHeader(code) + r.Writer.WriteHeader(r.Status) r.Committed = true } diff --git a/response_test.go b/response_test.go index 7a9c51c6..d95e079f 100644 --- a/response_test.go +++ b/response_test.go @@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) { res.Flush() assert.True(t, rec.Flushed) } + +func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + res := &Response{echo: e, Writer: rec} + + res.Before(func() { + if 200 < res.Status && res.Status < 300 { + res.Status = 200 + } + }) + + res.WriteHeader(209) + + assert.Equal(t, http.StatusOK, rec.Code) +} diff --git a/router.go b/router.go index ed728d6a..4c3898c4 100644 --- a/router.go +++ b/router.go @@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) { pos := strings.IndexByte(ns, '/') if pos == -1 { // If no slash is remaining in search string set param value - pvalues[len(cn.pnames)-1] = search + if len(cn.pnames) > 0 { + pvalues[len(cn.pnames)-1] = search + } break } else if pos > 0 { // Otherwise continue route processing with restored next node diff --git a/router_test.go b/router_test.go index d0972720..947e0dad 100644 --- a/router_test.go +++ b/router_test.go @@ -1335,6 +1335,40 @@ func TestRouterFindNotPanicOrLoopsWhenContextSetParamValuesIsCalledWithLessValue assert.Equal(t, 1, c.Get("i")) } +// Issue #1653 +func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1)) + r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error { + return nil + }) + r.Add(http.MethodGet, "/users/:id/active", func(c Context) error { + return nil + }) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/alice/edit", c) + assert.Equal(t, "alice", c.Param("id")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/bob/active", c) + assert.Equal(t, "bob", c.Param("id")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/create", c) + c.Handler()(c) + assert.Equal(t, 1, c.Get("create")) + assert.Equal(t, "/users/create", c.Get("path")) + + //This panic before the fix for Issue #1653 + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/createNotFound", c) + he := c.Handler()(c).(*HTTPError) + assert.Equal(t, http.StatusNotFound, he.Code) +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route) { e := New() r := e.router