1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-17 01:43:02 +02:00

erge branch 'master' into routing_misses_performance_improvements

This commit is contained in:
Pablo Andres Fuente
2020-12-14 03:36:12 +00:00
19 changed files with 475 additions and 46 deletions

10
.github/stale.yml vendored
View File

@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale # Number of days of inactivity before an issue becomes stale
daysUntilStale: 60 daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed # Number of days of inactivity before a stale issue is closed
daysUntilClose: 7 daysUntilClose: 30
# Issues with these labels will never be considered stale # Issues with these labels will never be considered stale
exemptLabels: exemptLabels:
- pinned - pinned
- security - security
- bug
- enhancement
# Label to use when marking an issue as stale # Label to use when marking an issue as stale
staleLabel: wontfix staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable # Comment to post when marking an issue as stale. Set to `false` to disable
markComment: > markComment: >
This issue has been automatically marked as stale because it has not had This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you recent activity. It will be closed within a month if no further activity occurs.
for your contributions. Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable # Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false closeComment: false

View File

@ -9,6 +9,7 @@ on:
- 'go.*' - 'go.*'
- '_fixture/**' - '_fixture/**'
- '.github/**' - '.github/**'
- 'codecov.yml'
pull_request: pull_request:
branches: branches:
- master - master
@ -17,6 +18,7 @@ on:
- 'go.*' - 'go.*'
- '_fixture/**' - '_fixture/**'
- '.github/**' - '.github/**'
- 'codecov.yml'
jobs: jobs:
test: test:
@ -62,3 +64,55 @@ jobs:
with: with:
token: token:
fail_ci_if_error: false fail_ci_if_error: false
benchmark:
needs: test
strategy:
matrix:
os: [ubuntu-latest]
go: [1.15]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v1
with:
go-version: ${{ matrix.go }}
- name: Set GOPATH and PATH
run: |
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code (Previous)
uses: actions/checkout@v2
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
uses: actions/checkout@v2
with:
path: new
- name: Install Dependencies
run: go get -v golang.org/x/perf/cmd/benchstat
- name: Run Benchmark (Previous)
run: |
cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New)
run: |
cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat
run: |
benchstat previous/benchmark.txt new/benchmark.txt

View File

@ -0,0 +1 @@
This directory is used for the static middleware test

11
codecov.yml Normal file
View File

@ -0,0 +1,11 @@
coverage:
status:
project:
default:
threshold: 1%
patch:
default:
threshold: 1%
comment:
require_changes: true

View File

@ -572,7 +572,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes { for _, r := range e.router.routes {
if r.Name == name { if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ { for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < ln { if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ { for ; i < l && r.Path[i] != '/'; i++ {
} }
uri.WriteString(fmt.Sprintf("%v", params[n])) uri.WriteString(fmt.Sprintf("%v", params[n]))

View File

@ -277,10 +277,12 @@ func TestEchoURL(t *testing.T) {
e := New() e := New()
static := func(Context) error { return nil } static := func(Context) error { return nil }
getUser := func(Context) error { return nil } getUser := func(Context) error { return nil }
getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil } getFile := func(Context) error { return nil }
e.GET("/static/file", static) e.GET("/static/file", static)
e.GET("/users/:id", getUser) e.GET("/users/:id", getUser)
e.GET("/documents/*", getAny)
g := e.Group("/group") g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile) g.GET("/users/:uid/files/:fid", getFile)
@ -289,6 +291,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static)) assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser)) assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1")) assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
} }
@ -652,3 +657,28 @@ func TestEchoShutdown(t *testing.T) {
err := <-errCh err := <-errCh
assert.Equal(t, err.Error(), "http: Server closed") assert.Equal(t, err.Error(), "http: Server closed")
} }
func TestEchoReverse(t *testing.T) {
assert := assert.New(t)
e := New()
dummyHandler := func(Context) error { return nil }
e.GET("/static", dummyHandler).Name = "/static"
e.GET("/static/*", dummyHandler).Name = "/static/*"
e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

View File

@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level config.Level = DefaultGzipConfig.Level
} }
pool := gzipPool(config) pool := gzipCompressPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
return http.ErrNotSupported return http.ErrNotSupported
} }
func gzipPool(config GzipConfig) sync.Pool { func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{ return sync.Pool{
New: func() interface{} { New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)

View File

@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"` AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
@ -113,39 +120,49 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
} }
// Check allowed origins if config.AllowOriginFunc != nil {
for _, o := range config.AllowOrigins { allowed, err := config.AllowOriginFunc(origin)
if o == "*" && config.AllowCredentials { if err != nil {
return err
}
if allowed {
allowOrigin = origin allowOrigin = origin
break
} }
if o == "*" || o == origin { } else {
allowOrigin = o // Check allowed origins
break for _, o := range config.AllowOrigins {
} if o == "*" && config.AllowCredentials {
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin allowOrigin = origin
break break
} }
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
} }
} }

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) {
} }
} }
} }
func Test_allowOriginFunc(t *testing.T) {
returnTrue := func(origin string) (bool, error) {
return true, nil
}
returnFalse := func(origin string) (bool, error) {
return false, nil
}
returnError := func(origin string) (bool, error) {
return true, errors.New("this is a test error")
}
allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
}
const origin = "http://example.com"
e := echo.New()
for _, allowOriginFunc := range allowOriginFuncs {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}
if expected {
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}
}
}

View File

@ -3,9 +3,12 @@ package middleware
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"github.com/labstack/echo/v4"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"sync"
"github.com/labstack/echo/v4"
) )
type ( type (
@ -13,17 +16,55 @@ type (
DecompressConfig struct { DecompressConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
} }
) )
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. //GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip" const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var ( var (
//DefaultDecompressConfig defines the config for decompress middleware //DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper} DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
) )
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config //Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc { func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig) return DecompressWithConfig(DefaultDecompressConfig)
@ -31,25 +72,46 @@ func Decompress() echo.MiddlewareFunc {
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config //DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
switch c.Request().Header.Get(echo.HeaderContentEncoding) { switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding: case GZIPEncoding:
gr, err := gzip.NewReader(c.Request().Body) b := c.Request().Body
if err != nil {
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty if err == io.EOF { //ignore if body is empty
return next(c) return next(c)
} }
return err return err
} }
defer gr.Close()
var buf bytes.Buffer var buf bytes.Buffer
io.Copy(&buf, gr) io.Copy(&buf, gr)
gr.Close()
pool.Put(gr)
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
r := ioutil.NopCloser(&buf) r := ioutil.NopCloser(&buf)
defer r.Close()
c.Request().Body = r c.Request().Body = r
} }
return next(c) return next(c)

View File

@ -3,10 +3,12 @@ package middleware
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"errors"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) {
assert.Equal(body, string(b)) assert.Equal(body, string(b))
} }
func TestDecompressDefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New() e := echo.New()
body := `{"name":"echo"}` body := `{"name":"echo"}`
@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) {
assert.Equal(t, body, string(reqBody)) assert.Equal(t, body, string(reqBody))
} }
type TestDecompressPoolWithError struct {
}
func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
return errors.New("pool error")
},
}
}
func TestDecompressPoolError(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &TestDecompressPoolWithError{},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
assert.Equal(t, rec.Code, http.StatusInternalServerError)
}
func BenchmarkDecompress(b *testing.B) { func BenchmarkDecompress(b *testing.B) {
e := echo.New() e := echo.New()
body := `{"name": "echo"}` body := `{"name": "echo"}`

View File

@ -57,6 +57,7 @@ type (
// - "query:<name>" // - "query:<name>"
// - "param:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// - "form:<name>"
TokenLookup string TokenLookup string
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1]) extractor = jwtFromParam(parts[1])
case "cookie": case "cookie":
extractor = jwtFromCookie(parts[1]) extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil return cookie.Value, nil
} }
} }
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -3,6 +3,8 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings"
"testing" "testing"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty reqURL string // "/" if empty
hdrAuth string hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string info string
}{ }{
{ {
@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Empty cookie", info: "Empty cookie",
}, },
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
} { } {
if tc.reqURL == "" { if tc.reqURL == "" {
tc.reqURL = "/" tc.reqURL = "/"
} }
req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil) var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder() res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie) req.Header.Set(echo.HeaderCookie, tc.hdrCookie)

View File

@ -36,6 +36,12 @@ type (
// Enable directory browsing. // Enable directory browsing.
// Optional. Default value false. // Optional. Default value false.
Browse bool `yaml:"browse"` Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
} }
) )
@ -163,6 +169,15 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
} }
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
if config.IgnoreBase {
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
baseURLPath := path.Base(p)
if baseURLPath == routePath {
i := strings.LastIndex(name, routePath)
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
}
}
fi, err := os.Stat(name) fi, err := os.Stat(name)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), "cert.pem") assert.Contains(rec.Body.String(), "cert.pem")
} }
// IgnoreBase
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = true
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = false
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
} }

View File

@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed") r.echo.Logger.Warn("response already committed")
return return
} }
r.Status = code
for _, fn := range r.beforeFuncs { for _, fn := range r.beforeFuncs {
fn() fn()
} }
r.Status = code r.Writer.WriteHeader(r.Status)
r.Writer.WriteHeader(code)
r.Committed = true r.Committed = true
} }

View File

@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush() res.Flush()
assert.True(t, rec.Flushed) assert.True(t, rec.Flushed)
} }
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
res.Before(func() {
if 200 < res.Status && res.Status < 300 {
res.Status = 200
}
})
res.WriteHeader(209)
assert.Equal(t, http.StatusOK, rec.Code)
}

View File

@ -449,7 +449,9 @@ func (r *Router) Find(method, path string, c Context) {
pos := strings.IndexByte(ns, '/') pos := strings.IndexByte(ns, '/')
if pos == -1 { if pos == -1 {
// If no slash is remaining in search string set param value // If no slash is remaining in search string set param value
pvalues[len(cn.pnames)-1] = search if len(cn.pnames) > 0 {
pvalues[len(cn.pnames)-1] = search
}
break break
} else if pos > 0 { } else if pos > 0 {
// Otherwise continue route processing with restored next node // Otherwise continue route processing with restored next node

View File

@ -1430,6 +1430,40 @@ func TestRouterParam1466(t *testing.T) {
assert.Equal(t, 0, c.response.Status) assert.Equal(t, 0, c.response.Status)
} }
// Issue #1653
func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1))
r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error {
return nil
})
r.Add(http.MethodGet, "/users/:id/active", func(c Context) error {
return nil
})
c := e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/alice/edit", c)
assert.Equal(t, "alice", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/bob/active", c)
assert.Equal(t, "bob", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/create", c)
c.Handler()(c)
assert.Equal(t, 1, c.Get("create"))
assert.Equal(t, "/users/create", c.Get("path"))
//This panic before the fix for Issue #1653
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/createNotFound", c)
he := c.Handler()(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.Code)
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) {
e := New() e := New()
r := e.router r := e.router