1
0
mirror of https://github.com/labstack/echo.git synced 2025-03-29 21:56:53 +02:00
This commit is contained in:
Vadim Sabirov 2020-12-11 12:34:52 +03:00
commit 53b38de143
12 changed files with 278 additions and 18 deletions

12
.github/stale.yml vendored

@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 7
daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
- bug
- enhancement
# Label to use when marking an issue as stale
staleLabel: wontfix
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
recent activity. It will be closed within a month if no further activity occurs.
Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false
closeComment: false

@ -572,7 +572,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes {
if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < ln {
if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))

@ -277,10 +277,12 @@ func TestEchoURL(t *testing.T) {
e := New()
static := func(Context) error { return nil }
getUser := func(Context) error { return nil }
getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil }
e.GET("/static/file", static)
e.GET("/users/:id", getUser)
e.GET("/documents/*", getAny)
g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile)
@ -289,6 +291,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
}
@ -652,3 +657,28 @@ func TestEchoShutdown(t *testing.T) {
err := <-errCh
assert.Equal(t, err.Error(), "http: Server closed")
}
func TestEchoReverse(t *testing.T) {
assert := assert.New(t)
e := New()
dummyHandler := func(Context) error { return nil }
e.GET("/static", dummyHandler).Name = "/static"
e.GET("/static/*", dummyHandler).Name = "/static/*"
e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level
}
pool := gzipPool(config)
pool := gzipCompressPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
return http.ErrNotSupported
}
func gzipPool(config GzipConfig) sync.Pool {
func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)

@ -3,9 +3,12 @@ package middleware
import (
"bytes"
"compress/gzip"
"github.com/labstack/echo/v4"
"io"
"io/ioutil"
"net/http"
"sync"
"github.com/labstack/echo/v4"
)
type (
@ -13,17 +16,55 @@ type (
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
@ -31,25 +72,46 @@ func Decompress() echo.MiddlewareFunc {
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
gr, err := gzip.NewReader(c.Request().Body)
if err != nil {
b := c.Request().Body
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
defer gr.Close()
var buf bytes.Buffer
io.Copy(&buf, gr)
gr.Close()
pool.Put(gr)
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
r := ioutil.NopCloser(&buf)
defer r.Close()
c.Request().Body = r
}
return next(c)

@ -3,10 +3,12 @@ package middleware
import (
"bytes"
"compress/gzip"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/labstack/echo/v4"
@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) {
assert.Equal(body, string(b))
}
func TestDecompressDefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New()
body := `{"name":"echo"}`
@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) {
assert.Equal(t, body, string(reqBody))
}
type TestDecompressPoolWithError struct {
}
func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
return errors.New("pool error")
},
}
}
func TestDecompressPoolError(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &TestDecompressPoolWithError{},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
assert.Equal(t, rec.Code, http.StatusInternalServerError)
}
func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`

@ -57,6 +57,7 @@ type (
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
TokenLookup string
// AuthScheme to be used in the Authorization header.
@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1])
case "cookie":
extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

@ -3,6 +3,8 @@ package middleware
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/dgrijalva/jwt-go"
@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
}{
{
@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest,
info: "Empty cookie",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
} {
if tc.reqURL == "" {
tc.reqURL = "/"
}
req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)

@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed")
return
}
r.Status = code
for _, fn := range r.beforeFuncs {
fn()
}
r.Status = code
r.Writer.WriteHeader(code)
r.Writer.WriteHeader(r.Status)
r.Committed = true
}

@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush()
assert.True(t, rec.Flushed)
}
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
res.Before(func() {
if 200 < res.Status && res.Status < 300 {
res.Status = 200
}
})
res.WriteHeader(209)
assert.Equal(t, http.StatusOK, rec.Code)
}

@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) {
pos := strings.IndexByte(ns, '/')
if pos == -1 {
// If no slash is remaining in search string set param value
pvalues[len(cn.pnames)-1] = search
if len(cn.pnames) > 0 {
pvalues[len(cn.pnames)-1] = search
}
break
} else if pos > 0 {
// Otherwise continue route processing with restored next node

@ -1298,6 +1298,40 @@ func TestRouterParam1466(t *testing.T) {
assert.Equal(t, 0, c.response.Status)
}
// Issue #1653
func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1))
r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error {
return nil
})
r.Add(http.MethodGet, "/users/:id/active", func(c Context) error {
return nil
})
c := e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/alice/edit", c)
assert.Equal(t, "alice", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/bob/active", c)
assert.Equal(t, "bob", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/create", c)
c.Handler()(c)
assert.Equal(t, 1, c.Get("create"))
assert.Equal(t, "/users/create", c.Get("path"))
//This panic before the fix for Issue #1653
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/createNotFound", c)
he := c.Handler()(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.Code)
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route) {
e := New()
r := e.router