1
0
mirror of https://github.com/labstack/echo.git synced 2025-03-17 21:08:05 +02:00

Merge branch 'master' into listener_network_configurable

This commit is contained in:
Pablo Andres Fuente 2020-12-10 04:10:13 +00:00
commit 78fe2224b6
25 changed files with 851 additions and 130 deletions

12
.github/stale.yml vendored
View File

@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 7
daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
- bug
- enhancement
# Label to use when marking an issue as stale
staleLabel: wontfix
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
recent activity. It will be closed within a month if no further activity occurs.
Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false
closeComment: false

View File

@ -4,20 +4,28 @@ on:
push:
branches:
- master
paths:
- '**.go'
- 'go.*'
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
pull_request:
branches:
- master
env:
GO111MODULE: on
GOPROXY: https://proxy.golang.org
paths:
- '**.go'
- 'go.*'
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
go: [1.12, 1.13, 1.14]
go: [1.12, 1.13, 1.14, 1.15]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
@ -28,10 +36,15 @@ jobs:
- name: Set GOPATH and PATH
run: |
echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)"
echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin"
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code
uses: actions/checkout@v1
with:
@ -51,3 +64,55 @@ jobs:
with:
token:
fail_ci_if_error: false
benchmark:
needs: test
strategy:
matrix:
os: [ubuntu-latest]
go: [1.15]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v1
with:
go-version: ${{ matrix.go }}
- name: Set GOPATH and PATH
run: |
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code (Previous)
uses: actions/checkout@v2
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
uses: actions/checkout@v2
with:
path: new
- name: Install Dependencies
run: go get -v golang.org/x/perf/cmd/benchstat
- name: Run Benchmark (Previous)
run: |
cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New)
run: |
cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat
run: |
benchstat previous/benchmark.txt new/benchmark.txt

View File

@ -1,3 +1,7 @@
arch:
- amd64
- ppc64le
language: go
go:
- 1.14.x

View File

@ -0,0 +1 @@
This directory is used for the static middleware test

11
codecov.yml Normal file
View File

@ -0,0 +1,11 @@
coverage:
status:
project:
default:
threshold: 1%
patch:
default:
threshold: 1%
comment:
require_changes: true

View File

@ -365,7 +365,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
if err != nil {
return nil, err
}
defer f.Close()
f.Close()
return fh, nil
}

12
echo.go
View File

@ -365,10 +365,12 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
// Issue #1426
code := he.Code
message := he.Message
if e.Debug {
message = err.Error()
} else if m, ok := message.(string); ok {
message = Map{"message": m}
if m, ok := he.Message.(string); ok {
if e.Debug {
message = Map{"message": m, "error": err.Error()}
} else {
message = Map{"message": m}
}
}
// Send response
@ -573,7 +575,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes {
if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < ln {
if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))

View File

@ -278,10 +278,12 @@ func TestEchoURL(t *testing.T) {
e := New()
static := func(Context) error { return nil }
getUser := func(Context) error { return nil }
getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil }
e.GET("/static/file", static)
e.GET("/users/:id", getUser)
e.GET("/documents/*", getAny)
g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile)
@ -290,6 +292,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
}
@ -569,6 +574,49 @@ func TestHTTPError(t *testing.T) {
})
}
func TestDefaultHTTPErrorHandler(t *testing.T) {
e := New()
e.Debug = true
e.Any("/plain", func(c Context) error {
return errors.New("An error occurred")
})
e.Any("/badrequest", func(c Context) error {
return NewHTTPError(http.StatusBadRequest, "Invalid request")
})
e.Any("/servererror", func(c Context) error {
return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{
"code": 33,
"message": "Something bad happened",
"error": "stackinfo",
})
})
// With Debug=true plain response contains error message
c, b := request(http.MethodGet, "/plain", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b)
// and special handling for HTTPError
c, b = request(http.MethodGet, "/badrequest", e)
assert.Equal(t, http.StatusBadRequest, c)
assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b)
// complex errors are serialized to pretty JSON
c, b = request(http.MethodGet, "/servererror", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b)
e.Debug = false
// With Debug=false the error response is shortened
c, b = request(http.MethodGet, "/plain", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b)
c, b = request(http.MethodGet, "/badrequest", e)
assert.Equal(t, http.StatusBadRequest, c)
assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b)
// No difference for error response with non plain string errors
c, b = request(http.MethodGet, "/servererror", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b)
}
func TestEchoClose(t *testing.T) {
e := New()
errCh := make(chan error)
@ -673,3 +721,28 @@ func TestEchoListenerNetworkInvalid(t *testing.T) {
assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
}
func TestEchoReverse(t *testing.T) {
assert := assert.New(t)
e := New()
dummyHandler := func(Context) error { return nil }
e.GET("/static", dummyHandler).Name = "/static"
e.GET("/static/*", dummyHandler).Name = "/static/*"
e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

View File

@ -8,6 +8,7 @@ import (
"net"
"net/http"
"strings"
"sync"
"github.com/labstack/echo/v4"
)
@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level
}
pool := gzipPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
rw := res.Writer
w, err := gzip.NewWriterLevel(rw, config.Level)
if err != nil {
return err
i := pool.Get()
w, ok := i.(*gzip.Writer)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
rw := res.Writer
w.Reset(rw)
defer func() {
if res.Size == 0 {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
w.Reset(ioutil.Discard)
}
w.Close()
pool.Put(w)
}()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
res.Writer = grw
@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
}
return http.ErrNotSupported
}
func gzipPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
}

View File

@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
e := echo.New()
// Invalid level
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
}
// Issue #806
func TestGzipWithStatic(t *testing.T) {
e := echo.New()
@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) {
}
}
}
func BenchmarkGzip(b *testing.B) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Gzip
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}

View File

@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
@ -102,45 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided
if origin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
if match, _ := regexp.MatchString(re, origin); match {
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
// Simple request
if req.Method != http.MethodOptions {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
@ -152,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
// Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)

View File

@ -1,6 +1,7 @@
package middleware
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
@ -17,19 +18,31 @@ func TestCORS(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard AllowedOrigin with no Origin header in request
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
// Allow origins
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
// Preflight request
req = httptest.NewRequest(http.MethodOptions, "/", nil)
@ -67,6 +80,22 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
// Preflight request with Access-Control-Request-Headers
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
// Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
@ -126,7 +155,7 @@ func Test_allowOriginScheme(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
@ -216,6 +245,163 @@ func Test_allowOriginSubdomain(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
func TestCorsHeaders(t *testing.T) {
tests := []struct {
domain, allowedOrigin, method string
expected bool
}{
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
},
}
e := echo.New()
for _, tt := range tests {
req := httptest.NewRequest(tt.method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tt.domain != "" {
req.Header.Set(echo.HeaderOrigin, tt.domain)
}
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
})
h := cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
expectedAllowOrigin := ""
if tt.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tt.domain
}
switch {
case tt.expected && tt.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tt.expected && tt.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}
if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
}
}
func Test_allowOriginFunc(t *testing.T) {
returnTrue := func(origin string) (bool, error) {
return true, nil
}
returnFalse := func(origin string) (bool, error) {
return false, nil
}
returnError := func(origin string) (bool, error) {
return true, errors.New("this is a test error")
}
allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
}
const origin = "http://example.com"
e := echo.New()
for _, allowOriginFunc := range allowOriginFuncs {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}
if expected {
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}

58
middleware/decompress.go Normal file
View File

@ -0,0 +1,58 @@
package middleware
import (
"bytes"
"compress/gzip"
"github.com/labstack/echo/v4"
"io"
"io/ioutil"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
)
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
gr, err := gzip.NewReader(c.Request().Body)
if err != nil {
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
defer gr.Close()
var buf bytes.Buffer
io.Copy(&buf, gr)
r := ioutil.NopCloser(&buf)
defer r.Close()
c.Request().Body = r
}
return next(c)
}
}
}

View File

@ -0,0 +1,148 @@
package middleware
import (
"bytes"
"compress/gzip"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestDecompress(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Skip if no Content-Encoding header
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New()
body := `{"name":"echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.NotEqual(t, b, body)
assert.Equal(t, b, gz)
}
func TestDecompressNoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
}
}
func TestDecompressErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Decompress())
e.GET("/", func(c echo.Context) error {
return echo.ErrNotFound
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestDecompressSkipper(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: func(c echo.Context) bool {
return c.Request().URL.Path == "/skip"
},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
}
func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte(body)) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Decompress
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}
func gzipString(body string) ([]byte, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte(body))
if err != nil {
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@ -57,6 +57,7 @@ type (
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
TokenLookup string
// AuthScheme to be used in the Authorization header.
@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1])
case "cookie":
extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -3,6 +3,8 @@ package middleware
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/dgrijalva/jwt-go"
@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
}{
{
@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest,
info: "Empty cookie",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
} {
if tc.reqURL == "" {
tc.reqURL = "/"
}
req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)

View File

@ -2,7 +2,6 @@ package middleware
import (
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
return strings.NewReplacer(replace...)
}
//rewritePath sets request url path and raw path
func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error {
replacerRawPath := replacer.Replace(target)
replacerPath, err := url.PathUnescape(replacerRawPath)
if err != nil {
return err
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
// Initialize
rulesRegex := map[*regexp.Regexp]string{}
for k, v := range rewrite {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
rulesRegex[regexp.MustCompile(k)] = v
}
return rulesRegex
}
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
for k, v := range rewriteRegex {
replacerRawPath := captureTokens(k, req.URL.EscapedPath())
if replacerRawPath != nil {
replacerPath := captureTokens(k, req.URL.Path)
req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v)
}
}
req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath
return nil
}
// DefaultSkipper returns false which processes the middleware.

View File

@ -8,7 +8,6 @@ import (
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
}
config.rewriteRegex = map[*regexp.Regexp]string{}
// Initialize
for k, v := range config.Rewrite {
k = strings.Replace(k, "*", "(\\S*)", -1)
config.rewriteRegex[regexp.MustCompile(k)] = v
}
config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
tgt := config.Balancer.Next(c)
c.Set(config.ContextKey, tgt)
// Rewrite
for k, v := range config.rewriteRegex {
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
}
}
// Set rewrite path and raw path
rewritePath(config.rewriteRegex, req)
// Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
}
}
}

View File

@ -94,36 +94,35 @@ func TestProxy(t *testing.T) {
"/users/*/orders/*": "/user/$1/order/$2",
},
}))
req.URL.Path = "/api/users"
req.URL, _ = url.Parse("/api/users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/js/main.js"
req.URL, _ = url.Parse( "/js/main.js")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/old"
req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jack/orders/1"
req.URL, _ = url.Parse( "/users/jack/orders/1")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/%%%%"
req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
// ModifyResponse
e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{

View File

@ -1,11 +1,8 @@
package middleware
import (
"net/http"
"regexp"
"strings"
"github.com/labstack/echo/v4"
"regexp"
)
type (
@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
}
config.rulesRegex = map[*regexp.Regexp]string{}
// Initialize
for k, v := range config.Rules {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
config.rulesRegex[regexp.MustCompile(k)] = v
}
config.rulesRegex = rewriteRulesRegex(config.Rules)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
}
req := c.Request()
// Rewrite
for k, v := range config.rulesRegex {
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
break
}
}
// Set rewrite path and raw path
rewritePath(config.rulesRegex, req)
return next(c)
}
}

View File

@ -4,6 +4,7 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/labstack/echo/v4"
@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) {
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
req.URL.Path = "/api/users"
req.URL, _ = url.Parse("/api/users")
e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath())
req.URL.Path = "/js/main.js"
req.URL, _ = url.Parse("/js/main.js")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
req.URL.Path = "/old"
req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath())
req.URL.Path = "/users/jack/orders/1"
req.URL, _ = url.Parse("/users/jack/orders/1")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
req.URL.Path = "/api/new users"
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/%%%%"
rec = httptest.NewRecorder()
req.URL, _ = url.Parse("/api/new users")
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
}
// Issue #1086
@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
r := e.Router()
// Rewrite old url to new one
e.Pre(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{
e.Pre(Rewrite(map[string]string{
"/old": "/new",
},
}))
))
// Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error {

View File

@ -36,6 +36,12 @@ type (
// Enable directory browsing.
// Optional. Default value false.
Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
}
)
@ -163,6 +169,15 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
}
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
if config.IgnoreBase {
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
baseURLPath := path.Base(p)
if baseURLPath == routePath {
i := strings.LastIndex(name, routePath)
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
}
}
fi, err := os.Stat(name)
if err != nil {
if os.IsNotExist(err) {

View File

@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"github.com/labstack/echo/v4"
@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), "cert.pem")
}
// IgnoreBase
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = true
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = false
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
}

View File

@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed")
return
}
r.Status = code
for _, fn := range r.beforeFuncs {
fn()
}
r.Status = code
r.Writer.WriteHeader(code)
r.Writer.WriteHeader(r.Status)
r.Committed = true
}

View File

@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush()
assert.True(t, rec.Flushed)
}
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
res.Before(func() {
if 200 < res.Status && res.Status < 300 {
res.Status = 200
}
})
res.WriteHeader(209)
assert.Equal(t, http.StatusOK, rec.Code)
}