1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-17 01:43:02 +02:00

Merge branch 'master' into listener_network_configurable

This commit is contained in:
Pablo Andres Fuente
2020-12-10 04:10:13 +00:00
25 changed files with 851 additions and 130 deletions

10
.github/stale.yml vendored
View File

@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale # Number of days of inactivity before an issue becomes stale
daysUntilStale: 60 daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed # Number of days of inactivity before a stale issue is closed
daysUntilClose: 7 daysUntilClose: 30
# Issues with these labels will never be considered stale # Issues with these labels will never be considered stale
exemptLabels: exemptLabels:
- pinned - pinned
- security - security
- bug
- enhancement
# Label to use when marking an issue as stale # Label to use when marking an issue as stale
staleLabel: wontfix staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable # Comment to post when marking an issue as stale. Set to `false` to disable
markComment: > markComment: >
This issue has been automatically marked as stale because it has not had This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you recent activity. It will be closed within a month if no further activity occurs.
for your contributions. Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable # Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false closeComment: false

View File

@ -4,20 +4,28 @@ on:
push: push:
branches: branches:
- master - master
paths:
- '**.go'
- 'go.*'
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
pull_request: pull_request:
branches: branches:
- master - master
paths:
env: - '**.go'
GO111MODULE: on - 'go.*'
GOPROXY: https://proxy.golang.org - '_fixture/**'
- '.github/**'
- 'codecov.yml'
jobs: jobs:
test: test:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
go: [1.12, 1.13, 1.14] go: [1.12, 1.13, 1.14, 1.15]
name: ${{ matrix.os }} @ Go ${{ matrix.go }} name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@ -28,10 +36,15 @@ jobs:
- name: Set GOPATH and PATH - name: Set GOPATH and PATH
run: | run: |
echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)" echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin" echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v1 uses: actions/checkout@v1
with: with:
@ -51,3 +64,55 @@ jobs:
with: with:
token: token:
fail_ci_if_error: false fail_ci_if_error: false
benchmark:
needs: test
strategy:
matrix:
os: [ubuntu-latest]
go: [1.15]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v1
with:
go-version: ${{ matrix.go }}
- name: Set GOPATH and PATH
run: |
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code (Previous)
uses: actions/checkout@v2
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
uses: actions/checkout@v2
with:
path: new
- name: Install Dependencies
run: go get -v golang.org/x/perf/cmd/benchstat
- name: Run Benchmark (Previous)
run: |
cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New)
run: |
cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat
run: |
benchstat previous/benchmark.txt new/benchmark.txt

View File

@ -1,3 +1,7 @@
arch:
- amd64
- ppc64le
language: go language: go
go: go:
- 1.14.x - 1.14.x

View File

@ -0,0 +1 @@
This directory is used for the static middleware test

11
codecov.yml Normal file
View File

@ -0,0 +1,11 @@
coverage:
status:
project:
default:
threshold: 1%
patch:
default:
threshold: 1%
comment:
require_changes: true

View File

@ -365,7 +365,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer f.Close() f.Close()
return fh, nil return fh, nil
} }

View File

@ -365,11 +365,13 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
// Issue #1426 // Issue #1426
code := he.Code code := he.Code
message := he.Message message := he.Message
if m, ok := he.Message.(string); ok {
if e.Debug { if e.Debug {
message = err.Error() message = Map{"message": m, "error": err.Error()}
} else if m, ok := message.(string); ok { } else {
message = Map{"message": m} message = Map{"message": m}
} }
}
// Send response // Send response
if !c.Response().Committed { if !c.Response().Committed {
@ -573,7 +575,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes { for _, r := range e.router.routes {
if r.Name == name { if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ { for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < ln { if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ { for ; i < l && r.Path[i] != '/'; i++ {
} }
uri.WriteString(fmt.Sprintf("%v", params[n])) uri.WriteString(fmt.Sprintf("%v", params[n]))

View File

@ -278,10 +278,12 @@ func TestEchoURL(t *testing.T) {
e := New() e := New()
static := func(Context) error { return nil } static := func(Context) error { return nil }
getUser := func(Context) error { return nil } getUser := func(Context) error { return nil }
getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil } getFile := func(Context) error { return nil }
e.GET("/static/file", static) e.GET("/static/file", static)
e.GET("/users/:id", getUser) e.GET("/users/:id", getUser)
e.GET("/documents/*", getAny)
g := e.Group("/group") g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile) g.GET("/users/:uid/files/:fid", getFile)
@ -290,6 +292,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static)) assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser)) assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1")) assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
} }
@ -569,6 +574,49 @@ func TestHTTPError(t *testing.T) {
}) })
} }
func TestDefaultHTTPErrorHandler(t *testing.T) {
e := New()
e.Debug = true
e.Any("/plain", func(c Context) error {
return errors.New("An error occurred")
})
e.Any("/badrequest", func(c Context) error {
return NewHTTPError(http.StatusBadRequest, "Invalid request")
})
e.Any("/servererror", func(c Context) error {
return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{
"code": 33,
"message": "Something bad happened",
"error": "stackinfo",
})
})
// With Debug=true plain response contains error message
c, b := request(http.MethodGet, "/plain", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b)
// and special handling for HTTPError
c, b = request(http.MethodGet, "/badrequest", e)
assert.Equal(t, http.StatusBadRequest, c)
assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b)
// complex errors are serialized to pretty JSON
c, b = request(http.MethodGet, "/servererror", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b)
e.Debug = false
// With Debug=false the error response is shortened
c, b = request(http.MethodGet, "/plain", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b)
c, b = request(http.MethodGet, "/badrequest", e)
assert.Equal(t, http.StatusBadRequest, c)
assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b)
// No difference for error response with non plain string errors
c, b = request(http.MethodGet, "/servererror", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b)
}
func TestEchoClose(t *testing.T) { func TestEchoClose(t *testing.T) {
e := New() e := New()
errCh := make(chan error) errCh := make(chan error)
@ -673,3 +721,28 @@ func TestEchoListenerNetworkInvalid(t *testing.T) {
assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
} }
func TestEchoReverse(t *testing.T) {
assert := assert.New(t)
e := New()
dummyHandler := func(Context) error { return nil }
e.GET("/static", dummyHandler).Name = "/static"
e.GET("/static/*", dummyHandler).Name = "/static/*"
e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

View File

@ -8,6 +8,7 @@ import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
"sync"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level config.Level = DefaultGzipConfig.Level
} }
pool := gzipPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
rw := res.Writer i := pool.Get()
w, err := gzip.NewWriterLevel(rw, config.Level) w, ok := i.(*gzip.Writer)
if err != nil { if !ok {
return err return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
} }
rw := res.Writer
w.Reset(rw)
defer func() { defer func() {
if res.Size == 0 { if res.Size == 0 {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
w.Reset(ioutil.Discard) w.Reset(ioutil.Discard)
} }
w.Close() w.Close()
pool.Put(w)
}() }()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
res.Writer = grw res.Writer = grw
@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
} }
return http.ErrNotSupported return http.ErrNotSupported
} }
func gzipPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
}

View File

@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
} }
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
e := echo.New()
// Invalid level
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
}
// Issue #806 // Issue #806
func TestGzipWithStatic(t *testing.T) { func TestGzipWithStatic(t *testing.T) {
e := echo.New() e := echo.New()
@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) {
} }
} }
} }
func BenchmarkGzip(b *testing.B) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Gzip
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}

View File

@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"` AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
@ -102,6 +109,26 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin) origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := "" allowOrigin := ""
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided
if origin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins // Check allowed origins
for _, o := range config.AllowOrigins { for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials { if o == "*" && config.AllowCredentials {
@ -137,10 +164,18 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
} }
} }
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
// Simple request // Simple request
if req.Method != http.MethodOptions { if !preflight {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials { if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
@ -152,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
// Preflight request // Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -17,19 +18,31 @@ func TestCORS(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler) h := CORS()(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard AllowedOrigin with no Origin header in request
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
// Allow origins // Allow origins
req = httptest.NewRequest(http.MethodGet, "/", nil) req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = CORSWithConfig(CORSConfig{ h = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"}, AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler) })(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
// Preflight request // Preflight request
req = httptest.NewRequest(http.MethodOptions, "/", nil) req = httptest.NewRequest(http.MethodOptions, "/", nil)
@ -67,6 +80,22 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
// Preflight request with Access-Control-Request-Headers
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
// Preflight request with `AllowOrigins` which allow all subdomains with * // Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil) req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
@ -126,7 +155,7 @@ func Test_allowOriginScheme(t *testing.T) {
if tt.expected { if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else { } else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
} }
} }
} }
@ -216,6 +245,163 @@ func Test_allowOriginSubdomain(t *testing.T) {
if tt.expected { if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
func TestCorsHeaders(t *testing.T) {
tests := []struct {
domain, allowedOrigin, method string
expected bool
}{
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
},
}
e := echo.New()
for _, tt := range tests {
req := httptest.NewRequest(tt.method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tt.domain != "" {
req.Header.Set(echo.HeaderOrigin, tt.domain)
}
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
})
h := cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
expectedAllowOrigin := ""
if tt.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tt.domain
}
switch {
case tt.expected && tt.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tt.expected && tt.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}
if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
}
}
func Test_allowOriginFunc(t *testing.T) {
returnTrue := func(origin string) (bool, error) {
return true, nil
}
returnFalse := func(origin string) (bool, error) {
return false, nil
}
returnError := func(origin string) (bool, error) {
return true, errors.New("this is a test error")
}
allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
}
const origin = "http://example.com"
e := echo.New()
for _, allowOriginFunc := range allowOriginFuncs {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}
if expected {
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else { } else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} }

58
middleware/decompress.go Normal file
View File

@ -0,0 +1,58 @@
package middleware
import (
"bytes"
"compress/gzip"
"github.com/labstack/echo/v4"
"io"
"io/ioutil"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
)
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
gr, err := gzip.NewReader(c.Request().Body)
if err != nil {
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
defer gr.Close()
var buf bytes.Buffer
io.Copy(&buf, gr)
r := ioutil.NopCloser(&buf)
defer r.Close()
c.Request().Body = r
}
return next(c)
}
}
}

View File

@ -0,0 +1,148 @@
package middleware
import (
"bytes"
"compress/gzip"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestDecompress(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Skip if no Content-Encoding header
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New()
body := `{"name":"echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.NotEqual(t, b, body)
assert.Equal(t, b, gz)
}
func TestDecompressNoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
}
}
func TestDecompressErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Decompress())
e.GET("/", func(c echo.Context) error {
return echo.ErrNotFound
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestDecompressSkipper(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: func(c echo.Context) bool {
return c.Request().URL.Path == "/skip"
},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
}
func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte(body)) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Decompress
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}
func gzipString(body string) ([]byte, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte(body))
if err != nil {
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@ -57,6 +57,7 @@ type (
// - "query:<name>" // - "query:<name>"
// - "param:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// - "form:<name>"
TokenLookup string TokenLookup string
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1]) extractor = jwtFromParam(parts[1])
case "cookie": case "cookie":
extractor = jwtFromCookie(parts[1]) extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil return cookie.Value, nil
} }
} }
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -3,6 +3,8 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings"
"testing" "testing"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty reqURL string // "/" if empty
hdrAuth string hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string info string
}{ }{
{ {
@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Empty cookie", info: "Empty cookie",
}, },
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
} { } {
if tc.reqURL == "" { if tc.reqURL == "" {
tc.reqURL = "/" tc.reqURL = "/"
} }
req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil) var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder() res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie) req.Header.Set(echo.HeaderCookie, tc.hdrCookie)

View File

@ -2,7 +2,6 @@ package middleware
import ( import (
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
return strings.NewReplacer(replace...) return strings.NewReplacer(replace...)
} }
//rewritePath sets request url path and raw path func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error { // Initialize
replacerRawPath := replacer.Replace(target) rulesRegex := map[*regexp.Regexp]string{}
replacerPath, err := url.PathUnescape(replacerRawPath) for k, v := range rewrite {
if err != nil { k = regexp.QuoteMeta(k)
return err k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
rulesRegex[regexp.MustCompile(k)] = v
}
return rulesRegex
}
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
for k, v := range rewriteRegex {
replacerRawPath := captureTokens(k, req.URL.EscapedPath())
if replacerRawPath != nil {
replacerPath := captureTokens(k, req.URL.Path)
req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v)
}
} }
req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath
return nil
} }
// DefaultSkipper returns false which processes the middleware. // DefaultSkipper returns false which processes the middleware.

View File

@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"regexp" "regexp"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
if config.Balancer == nil { if config.Balancer == nil {
panic("echo: proxy middleware requires balancer") panic("echo: proxy middleware requires balancer")
} }
config.rewriteRegex = map[*regexp.Regexp]string{}
// Initialize config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
for k, v := range config.Rewrite {
k = strings.Replace(k, "*", "(\\S*)", -1)
config.rewriteRegex[regexp.MustCompile(k)] = v
}
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) (err error) {
@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
tgt := config.Balancer.Next(c) tgt := config.Balancer.Next(c)
c.Set(config.ContextKey, tgt) c.Set(config.ContextKey, tgt)
// Rewrite // Set rewrite path and raw path
for k, v := range config.rewriteRegex { rewritePath(config.rewriteRegex, req)
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
}
}
// Fix header // Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
} }
} }
} }

View File

@ -94,36 +94,35 @@ func TestProxy(t *testing.T) {
"/users/*/orders/*": "/user/$1/order/$2", "/users/*/orders/*": "/user/$1/order/$2",
}, },
})) }))
req.URL.Path = "/api/users" req.URL, _ = url.Parse("/api/users")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, "/users", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/js/main.js" req.URL, _ = url.Parse( "/js/main.js")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/old" req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, "/new", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jack/orders/1" req.URL, _ = url.Parse( "/users/jack/orders/1")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/%%%%" req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, "/new%20users", req.URL.EscapedPath())
// ModifyResponse // ModifyResponse
e = echo.New() e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{ e.Use(ProxyWithConfig(ProxyConfig{

View File

@ -1,11 +1,8 @@
package middleware package middleware
import ( import (
"net/http"
"regexp"
"strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"regexp"
) )
type ( type (
@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper config.Skipper = DefaultBodyDumpConfig.Skipper
} }
config.rulesRegex = map[*regexp.Regexp]string{}
// Initialize config.rulesRegex = rewriteRulesRegex(config.Rules)
for k, v := range config.Rules {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
config.rulesRegex[regexp.MustCompile(k)] = v
}
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) (err error) {
@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
} }
req := c.Request() req := c.Request()
// Rewrite // Set rewrite path and raw path
for k, v := range config.rulesRegex { rewritePath(config.rulesRegex, req)
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
break
}
}
return next(c) return next(c)
} }
} }

View File

@ -4,6 +4,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) {
})) }))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req.URL.Path = "/api/users" req.URL, _ = url.Parse("/api/users")
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, "/users", req.URL.EscapedPath())
req.URL.Path = "/js/main.js" req.URL, _ = url.Parse("/js/main.js")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
req.URL.Path = "/old" req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, "/new", req.URL.EscapedPath())
req.URL.Path = "/users/jack/orders/1" req.URL, _ = url.Parse("/users/jack/orders/1")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
req.URL.Path = "/api/new users" req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/%%%%" req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, "/new%20users", req.URL.EscapedPath())
} }
// Issue #1086 // Issue #1086
@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
r := e.Router() r := e.Router()
// Rewrite old url to new one // Rewrite old url to new one
e.Pre(RewriteWithConfig(RewriteConfig{ e.Pre(Rewrite(map[string]string{
Rules: map[string]string{
"/old": "/new", "/old": "/new",
}, },
})) ))
// Route // Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error { r.Add(http.MethodGet, "/new", func(c echo.Context) error {

View File

@ -36,6 +36,12 @@ type (
// Enable directory browsing. // Enable directory browsing.
// Optional. Default value false. // Optional. Default value false.
Browse bool `yaml:"browse"` Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
} }
) )
@ -163,6 +169,15 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
} }
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
if config.IgnoreBase {
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
baseURLPath := path.Base(p)
if baseURLPath == routePath {
i := strings.LastIndex(name, routePath)
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
}
}
fi, err := os.Stat(name) fi, err := os.Stat(name)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), "cert.pem") assert.Contains(rec.Body.String(), "cert.pem")
} }
// IgnoreBase
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = true
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = false
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
} }

View File

@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed") r.echo.Logger.Warn("response already committed")
return return
} }
r.Status = code
for _, fn := range r.beforeFuncs { for _, fn := range r.beforeFuncs {
fn() fn()
} }
r.Status = code r.Writer.WriteHeader(r.Status)
r.Writer.WriteHeader(code)
r.Committed = true r.Committed = true
} }

View File

@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush() res.Flush()
assert.True(t, rec.Flushed) assert.True(t, rec.Flushed)
} }
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
res.Before(func() {
if 200 < res.Status && res.Status < 300 {
res.Status = 200
}
})
res.WriteHeader(209)
assert.Equal(t, http.StatusOK, rec.Code)
}