diff --git a/.github/stale.yml b/.github/stale.yml index d9f65632..04dd169c 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -1,17 +1,19 @@ # Number of days of inactivity before an issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 +daysUntilClose: 30 # Issues with these labels will never be considered stale exemptLabels: - pinned - security + - bug + - enhancement # Label to use when marking an issue as stale -staleLabel: wontfix +staleLabel: stale # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. + recent activity. It will be closed within a month if no further activity occurs. + Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable -closeComment: false \ No newline at end of file +closeComment: false diff --git a/echo.go b/echo.go index 29b88b70..c4399801 100644 --- a/echo.go +++ b/echo.go @@ -572,7 +572,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string { for _, r := range e.router.routes { if r.Name == name { for i, l := 0, len(r.Path); i < l; i++ { - if r.Path[i] == ':' && n < ln { + if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { for ; i < l && r.Path[i] != '/'; i++ { } uri.WriteString(fmt.Sprintf("%v", params[n])) diff --git a/echo_test.go b/echo_test.go index 0368dbd7..71dc1ac0 100644 --- a/echo_test.go +++ b/echo_test.go @@ -277,10 +277,12 @@ func TestEchoURL(t *testing.T) { e := New() static := func(Context) error { return nil } getUser := func(Context) error { return nil } + getAny := func(Context) error { return nil } getFile := func(Context) error { return nil } e.GET("/static/file", static) e.GET("/users/:id", getUser) + e.GET("/documents/*", getAny) g := e.Group("/group") g.GET("/users/:uid/files/:fid", getFile) @@ -289,6 +291,9 @@ func TestEchoURL(t *testing.T) { assert.Equal("/static/file", e.URL(static)) assert.Equal("/users/:id", e.URL(getUser)) assert.Equal("/users/1", e.URL(getUser, "1")) + assert.Equal("/users/1", e.URL(getUser, "1")) + assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) + assert.Equal("/documents/*", e.URL(getAny)) assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) } @@ -652,3 +657,28 @@ func TestEchoShutdown(t *testing.T) { err := <-errCh assert.Equal(t, err.Error(), "http: Server closed") } + +func TestEchoReverse(t *testing.T) { + assert := assert.New(t) + + e := New() + dummyHandler := func(Context) error { return nil } + + e.GET("/static", dummyHandler).Name = "/static" + e.GET("/static/*", dummyHandler).Name = "/static/*" + e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" + e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" + e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" + + assert.Equal("/static", e.Reverse("/static")) + assert.Equal("/static", e.Reverse("/static", "missing param")) + assert.Equal("/static/*", e.Reverse("/static/*")) + assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) + + assert.Equal("/params/:foo", e.Reverse("/params/:foo")) + assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) + assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) + assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +} diff --git a/middleware/compress.go b/middleware/compress.go index e4f9fc51..6ae19745 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { config.Level = DefaultGzipConfig.Level } - pool := gzipPool(config) + pool := gzipCompressPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { return http.ErrNotSupported } -func gzipPool(config GzipConfig) sync.Pool { +func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ New: func() interface{} { w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) diff --git a/middleware/decompress.go b/middleware/decompress.go index 99eaf066..c046359a 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -3,9 +3,12 @@ package middleware import ( "bytes" "compress/gzip" - "github.com/labstack/echo/v4" "io" "io/ioutil" + "net/http" + "sync" + + "github.com/labstack/echo/v4" ) type ( @@ -13,17 +16,55 @@ type ( DecompressConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor } ) //GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" +// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers +type Decompressor interface { + gzipDecompressPool() sync.Pool +} + var ( //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper} + DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, + } ) +// DefaultGzipDecompressPool is the default implementation of Decompressor interface +type DefaultGzipDecompressPool struct { +} + +func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + // create with an empty reader (but with GZIP header) + w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) + if err != nil { + return err + } + + b := new(bytes.Buffer) + w.Reset(b) + w.Flush() + w.Close() + + r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) + if err != nil { + return err + } + return r + }, + } +} + //Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { return DecompressWithConfig(DefaultDecompressConfig) @@ -31,25 +72,46 @@ func Decompress() echo.MiddlewareFunc { //DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultGzipConfig.Skipper + } + if config.GzipDecompressPool == nil { + config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + pool := config.GzipDecompressPool.gzipDecompressPool() return func(c echo.Context) error { if config.Skipper(c) { return next(c) } switch c.Request().Header.Get(echo.HeaderContentEncoding) { case GZIPEncoding: - gr, err := gzip.NewReader(c.Request().Body) - if err != nil { + b := c.Request().Body + + i := pool.Get() + gr, ok := i.(*gzip.Reader) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + } + + if err := gr.Reset(b); err != nil { + pool.Put(gr) if err == io.EOF { //ignore if body is empty return next(c) } return err } - defer gr.Close() var buf bytes.Buffer io.Copy(&buf, gr) + + gr.Close() + pool.Put(gr) + + b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here + r := ioutil.NopCloser(&buf) - defer r.Close() c.Request().Body = r } return next(c) diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 772c14f6..51fa6b0f 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -3,10 +3,12 @@ package middleware import ( "bytes" "compress/gzip" + "errors" "io/ioutil" "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/labstack/echo/v4" @@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) { assert.Equal(body, string(b)) } +func TestDecompressDefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e := echo.New() body := `{"name":"echo"}` @@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) { assert.Equal(t, body, string(reqBody)) } +type TestDecompressPoolWithError struct { +} + +func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + return errors.New("pool error") + }, + } +} + +func TestDecompressPoolError(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &TestDecompressPoolWithError{}, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) + assert.Equal(t, rec.Code, http.StatusInternalServerError) +} + func BenchmarkDecompress(b *testing.B) { e := echo.New() body := `{"name": "echo"}` diff --git a/middleware/jwt.go b/middleware/jwt.go index bab00c9f..da00ea56 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -57,6 +57,7 @@ type ( // - "query:" // - "param:" // - "cookie:" + // - "form:" TokenLookup string // AuthScheme to be used in the Authorization header. @@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { extractor = jwtFromParam(parts[1]) case "cookie": extractor = jwtFromCookie(parts[1]) + case "form": + extractor = jwtFromForm(parts[1]) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor { return cookie.Value, nil } } + +// jwtFromForm returns a `jwtExtractor` that extracts token from the form field. +func jwtFromForm(name string) jwtExtractor { + return func(c echo.Context) (string, error) { + field := c.FormValue(name) + if field == "" { + return "", ErrJWTMissing + } + return field, nil + } +} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index ce44f9c9..205721ae 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -3,6 +3,8 @@ package middleware import ( "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "github.com/dgrijalva/jwt-go" @@ -75,6 +77,7 @@ func TestJWT(t *testing.T) { reqURL string // "/" if empty hdrAuth string hdrCookie string // test.Request doesn't provide SetCookie(); use name=val + formValues map[string]string info string }{ { @@ -192,12 +195,48 @@ func TestJWT(t *testing.T) { expErrCode: http.StatusBadRequest, info: "Empty cookie", }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + formValues: map[string]string{"jwt": token}, + info: "Valid form method", + }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + expErrCode: http.StatusUnauthorized, + formValues: map[string]string{"jwt": "invalid"}, + info: "Invalid token with form method", + }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + expErrCode: http.StatusBadRequest, + info: "Empty form field", + }, } { if tc.reqURL == "" { tc.reqURL = "/" } - req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + var req *http.Request + if len(tc.formValues) > 0 { + form := url.Values{} + for k, v := range tc.formValues { + form.Set(k, v) + } + req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) + req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") + req.ParseForm() + } else { + req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + } res := httptest.NewRecorder() req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderCookie, tc.hdrCookie) diff --git a/response.go b/response.go index ca7405c5..84f7c9e7 100644 --- a/response.go +++ b/response.go @@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) { r.echo.Logger.Warn("response already committed") return } + r.Status = code for _, fn := range r.beforeFuncs { fn() } - r.Status = code - r.Writer.WriteHeader(code) + r.Writer.WriteHeader(r.Status) r.Committed = true } diff --git a/response_test.go b/response_test.go index 7a9c51c6..d95e079f 100644 --- a/response_test.go +++ b/response_test.go @@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) { res.Flush() assert.True(t, rec.Flushed) } + +func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + res := &Response{echo: e, Writer: rec} + + res.Before(func() { + if 200 < res.Status && res.Status < 300 { + res.Status = 200 + } + }) + + res.WriteHeader(209) + + assert.Equal(t, http.StatusOK, rec.Code) +} diff --git a/router.go b/router.go index ed728d6a..4c3898c4 100644 --- a/router.go +++ b/router.go @@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) { pos := strings.IndexByte(ns, '/') if pos == -1 { // If no slash is remaining in search string set param value - pvalues[len(cn.pnames)-1] = search + if len(cn.pnames) > 0 { + pvalues[len(cn.pnames)-1] = search + } break } else if pos > 0 { // Otherwise continue route processing with restored next node diff --git a/router_test.go b/router_test.go index 0e883233..fca3a79b 100644 --- a/router_test.go +++ b/router_test.go @@ -1298,6 +1298,40 @@ func TestRouterParam1466(t *testing.T) { assert.Equal(t, 0, c.response.Status) } +// Issue #1653 +func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1)) + r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error { + return nil + }) + r.Add(http.MethodGet, "/users/:id/active", func(c Context) error { + return nil + }) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/alice/edit", c) + assert.Equal(t, "alice", c.Param("id")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/bob/active", c) + assert.Equal(t, "bob", c.Param("id")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/create", c) + c.Handler()(c) + assert.Equal(t, 1, c.Get("create")) + assert.Equal(t, "/users/create", c.Get("path")) + + //This panic before the fix for Issue #1653 + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/createNotFound", c) + he := c.Handler()(c).(*HTTPError) + assert.Equal(t, http.StatusNotFound, he.Code) +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route) { e := New() r := e.router