1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00
This commit is contained in:
Vadim Sabirov 2020-12-11 12:34:52 +03:00
commit 53b38de143
12 changed files with 278 additions and 18 deletions

12
.github/stale.yml vendored
View File

@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 7
daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
- bug
- enhancement
# Label to use when marking an issue as stale
staleLabel: wontfix
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
recent activity. It will be closed within a month if no further activity occurs.
Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false
closeComment: false

View File

@ -572,7 +572,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes {
if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < ln {
if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))

View File

@ -277,10 +277,12 @@ func TestEchoURL(t *testing.T) {
e := New()
static := func(Context) error { return nil }
getUser := func(Context) error { return nil }
getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil }
e.GET("/static/file", static)
e.GET("/users/:id", getUser)
e.GET("/documents/*", getAny)
g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile)
@ -289,6 +291,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
}
@ -652,3 +657,28 @@ func TestEchoShutdown(t *testing.T) {
err := <-errCh
assert.Equal(t, err.Error(), "http: Server closed")
}
func TestEchoReverse(t *testing.T) {
assert := assert.New(t)
e := New()
dummyHandler := func(Context) error { return nil }
e.GET("/static", dummyHandler).Name = "/static"
e.GET("/static/*", dummyHandler).Name = "/static/*"
e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

View File

@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level
}
pool := gzipPool(config)
pool := gzipCompressPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
return http.ErrNotSupported
}
func gzipPool(config GzipConfig) sync.Pool {
func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)

View File

@ -3,9 +3,12 @@ package middleware
import (
"bytes"
"compress/gzip"
"github.com/labstack/echo/v4"
"io"
"io/ioutil"
"net/http"
"sync"
"github.com/labstack/echo/v4"
)
type (
@ -13,17 +16,55 @@ type (
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
@ -31,25 +72,46 @@ func Decompress() echo.MiddlewareFunc {
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
gr, err := gzip.NewReader(c.Request().Body)
if err != nil {
b := c.Request().Body
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
defer gr.Close()
var buf bytes.Buffer
io.Copy(&buf, gr)
gr.Close()
pool.Put(gr)
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
r := ioutil.NopCloser(&buf)
defer r.Close()
c.Request().Body = r
}
return next(c)

View File

@ -3,10 +3,12 @@ package middleware
import (
"bytes"
"compress/gzip"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/labstack/echo/v4"
@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) {
assert.Equal(body, string(b))
}
func TestDecompressDefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New()
body := `{"name":"echo"}`
@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) {
assert.Equal(t, body, string(reqBody))
}
type TestDecompressPoolWithError struct {
}
func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
return errors.New("pool error")
},
}
}
func TestDecompressPoolError(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &TestDecompressPoolWithError{},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
assert.Equal(t, rec.Code, http.StatusInternalServerError)
}
func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`

View File

@ -57,6 +57,7 @@ type (
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
TokenLookup string
// AuthScheme to be used in the Authorization header.
@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1])
case "cookie":
extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -3,6 +3,8 @@ package middleware
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/dgrijalva/jwt-go"
@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
}{
{
@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest,
info: "Empty cookie",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
} {
if tc.reqURL == "" {
tc.reqURL = "/"
}
req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)

View File

@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed")
return
}
r.Status = code
for _, fn := range r.beforeFuncs {
fn()
}
r.Status = code
r.Writer.WriteHeader(code)
r.Writer.WriteHeader(r.Status)
r.Committed = true
}

View File

@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush()
assert.True(t, rec.Flushed)
}
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
res.Before(func() {
if 200 < res.Status && res.Status < 300 {
res.Status = 200
}
})
res.WriteHeader(209)
assert.Equal(t, http.StatusOK, rec.Code)
}

View File

@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) {
pos := strings.IndexByte(ns, '/')
if pos == -1 {
// If no slash is remaining in search string set param value
pvalues[len(cn.pnames)-1] = search
if len(cn.pnames) > 0 {
pvalues[len(cn.pnames)-1] = search
}
break
} else if pos > 0 {
// Otherwise continue route processing with restored next node

View File

@ -1298,6 +1298,40 @@ func TestRouterParam1466(t *testing.T) {
assert.Equal(t, 0, c.response.Status)
}
// Issue #1653
func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1))
r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error {
return nil
})
r.Add(http.MethodGet, "/users/:id/active", func(c Context) error {
return nil
})
c := e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/alice/edit", c)
assert.Equal(t, "alice", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/bob/active", c)
assert.Equal(t, "bob", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/create", c)
c.Handler()(c)
assert.Equal(t, 1, c.Get("create"))
assert.Equal(t, "/users/create", c.Get("path"))
//This panic before the fix for Issue #1653
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/createNotFound", c)
he := c.Handler()(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.Code)
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route) {
e := New()
r := e.router