diff --git a/.github/stale.yml b/.github/stale.yml
index d9f65632..04dd169c 100644
--- a/.github/stale.yml
+++ b/.github/stale.yml
@@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
-daysUntilClose: 7
+daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
+ - bug
+ - enhancement
# Label to use when marking an issue as stale
-staleLabel: wontfix
+staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
- recent activity. It will be closed if no further activity occurs. Thank you
- for your contributions.
+ recent activity. It will be closed within a month if no further activity occurs.
+ Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
-closeComment: false
\ No newline at end of file
+closeComment: false
diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml
index 1d7508b9..2aec272d 100644
--- a/.github/workflows/echo.yml
+++ b/.github/workflows/echo.yml
@@ -4,20 +4,28 @@ on:
push:
branches:
- master
+ paths:
+ - '**.go'
+ - 'go.*'
+ - '_fixture/**'
+ - '.github/**'
+ - 'codecov.yml'
pull_request:
branches:
- master
-
-env:
- GO111MODULE: on
- GOPROXY: https://proxy.golang.org
+ paths:
+ - '**.go'
+ - 'go.*'
+ - '_fixture/**'
+ - '.github/**'
+ - 'codecov.yml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
- go: [1.12, 1.13, 1.14]
+ go: [1.12, 1.13, 1.14, 1.15]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
@@ -28,10 +36,15 @@ jobs:
- name: Set GOPATH and PATH
run: |
- echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)"
- echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin"
+ echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
+ echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
+ - name: Set build variables
+ run: |
+ echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
+ echo "GO111MODULE=on" >> $GITHUB_ENV
+
- name: Checkout Code
uses: actions/checkout@v1
with:
@@ -51,3 +64,55 @@ jobs:
with:
token:
fail_ci_if_error: false
+ benchmark:
+ needs: test
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ go: [1.15]
+ name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
+ runs-on: ${{ matrix.os }}
+ steps:
+ - name: Set up Go ${{ matrix.go }}
+ uses: actions/setup-go@v1
+ with:
+ go-version: ${{ matrix.go }}
+
+ - name: Set GOPATH and PATH
+ run: |
+ echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
+ echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
+ shell: bash
+
+ - name: Set build variables
+ run: |
+ echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
+ echo "GO111MODULE=on" >> $GITHUB_ENV
+
+ - name: Checkout Code (Previous)
+ uses: actions/checkout@v2
+ with:
+ ref: ${{ github.base_ref }}
+ path: previous
+
+ - name: Checkout Code (New)
+ uses: actions/checkout@v2
+ with:
+ path: new
+
+ - name: Install Dependencies
+ run: go get -v golang.org/x/perf/cmd/benchstat
+
+ - name: Run Benchmark (Previous)
+ run: |
+ cd previous
+ go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
+
+ - name: Run Benchmark (New)
+ run: |
+ cd new
+ go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
+
+ - name: Run Benchstat
+ run: |
+ benchstat previous/benchmark.txt new/benchmark.txt
diff --git a/.travis.yml b/.travis.yml
index ef826e95..67d45ad7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,3 +1,7 @@
+arch:
+ - amd64
+ - ppc64le
+
language: go
go:
- 1.14.x
diff --git a/README.md b/README.md
index 03ad4dca..deba54f4 100644
--- a/README.md
+++ b/README.md
@@ -42,11 +42,14 @@ For older versions, please use the latest v3 tag.
## Benchmarks
-Date: 2018/03/15
+Date: 2020/11/11
Source: https://github.com/vishr/web-framework-benchmark
Lower is better!
-
+
+
+
+The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
## [Guide](https://echo.labstack.com/guide)
diff --git a/_fixture/_fixture/README.md b/_fixture/_fixture/README.md
new file mode 100644
index 00000000..21a78585
--- /dev/null
+++ b/_fixture/_fixture/README.md
@@ -0,0 +1 @@
+This directory is used for the static middleware test
\ No newline at end of file
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 00000000..0fa3a3f1
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,11 @@
+coverage:
+ status:
+ project:
+ default:
+ threshold: 1%
+ patch:
+ default:
+ threshold: 1%
+
+comment:
+ require_changes: true
\ No newline at end of file
diff --git a/context.go b/context.go
index 0507f139..8ba98477 100644
--- a/context.go
+++ b/context.go
@@ -365,7 +365,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
if err != nil {
return nil, err
}
- defer f.Close()
+ f.Close()
return fh, nil
}
diff --git a/echo.go b/echo.go
index db64e1c0..4b0c785a 100644
--- a/echo.go
+++ b/echo.go
@@ -49,7 +49,6 @@ import (
"net/http"
"net/url"
"os"
- "path"
"path/filepath"
"reflect"
"runtime"
@@ -92,6 +91,7 @@ type (
Renderer Renderer
Logger Logger
IPExtractor IPExtractor
+ ListenerNetwork string
}
// Route contains a handler and information for matching against requests.
@@ -281,6 +281,7 @@ var (
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
+ ErrInvalidListenerNetwork = errors.New("invalid listener network")
)
// Error handlers
@@ -302,9 +303,10 @@ func New() (e *Echo) {
AutoTLSManager: autocert.Manager{
Prompt: autocert.AcceptTOS,
},
- Logger: log.New("echo"),
- colorer: color.New(),
- maxParam: new(int),
+ Logger: log.New("echo"),
+ colorer: color.New(),
+ maxParam: new(int),
+ ListenerNetwork: "tcp",
}
e.Server.Handler = e
e.TLSServer.Handler = e
@@ -483,7 +485,7 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl
return err
}
- name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security
+ name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security
fi, err := os.Stat(name)
if err != nil {
// The access path does not exist
@@ -573,7 +575,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes {
if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ {
- if r.Path[i] == ':' && n < ln {
+ if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
@@ -715,7 +717,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
if s.TLSConfig == nil {
if e.Listener == nil {
- e.Listener, err = newListener(s.Addr)
+ e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
@@ -726,7 +728,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
return s.Serve(e.Listener)
}
if e.TLSListener == nil {
- l, err := newListener(s.Addr)
+ l, err := newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
@@ -755,7 +757,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
}
if e.Listener == nil {
- e.Listener, err = newListener(s.Addr)
+ e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
@@ -876,8 +878,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
return
}
-func newListener(address string) (*tcpKeepAliveListener, error) {
- l, err := net.Listen("tcp", address)
+func newListener(address, network string) (*tcpKeepAliveListener, error) {
+ if network != "tcp" && network != "tcp4" && network != "tcp6" {
+ return nil, ErrInvalidListenerNetwork
+ }
+ l, err := net.Listen(network, address)
if err != nil {
return nil, err
}
diff --git a/echo_test.go b/echo_test.go
index ac400122..82ccad0c 100644
--- a/echo_test.go
+++ b/echo_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
stdContext "context"
"errors"
+ "fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@@ -59,53 +60,114 @@ func TestEcho(t *testing.T) {
}
func TestEchoStatic(t *testing.T) {
- e := New()
+ var testCases = []struct {
+ name string
+ givenPrefix string
+ givenRoot string
+ whenURL string
+ expectStatus int
+ expectHeaderLocation string
+ expectBodyStartsWith string
+ }{
+ {
+ name: "ok",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "No file",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/scripts",
+ whenURL: "/images/bolt.png",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/images/",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory Redirect",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/static/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory Redirect with non-root path",
+ givenPrefix: "/static",
+ givenRoot: "_fixture",
+ whenURL: "/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Sub-directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/folder/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "do not allow directory traversal (backslash - windows separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/..\\middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "do not allow directory traversal (slash - unix separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/../middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ }
- assert := assert.New(t)
-
- // OK
- e.Static("/images", "_fixture/images")
- c, b := request(http.MethodGet, "/images/walle.png", e)
- assert.Equal(http.StatusOK, c)
- assert.NotEmpty(b)
-
- // No file
- e.Static("/images", "_fixture/scripts")
- c, _ = request(http.MethodGet, "/images/bolt.png", e)
- assert.Equal(http.StatusNotFound, c)
-
- // Directory
- e.Static("/images", "_fixture/images")
- c, _ = request(http.MethodGet, "/images/", e)
- assert.Equal(http.StatusNotFound, c)
-
- // Directory Redirect
- e.Static("/", "_fixture")
- req := httptest.NewRequest(http.MethodGet, "/folder", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- assert.Equal(http.StatusMovedPermanently, rec.Code)
- assert.Equal("/folder/", rec.HeaderMap["Location"][0])
-
- // Directory Redirect with non-root path
- e.Static("/static", "_fixture")
- req = httptest.NewRequest(http.MethodGet, "/static", nil)
- rec = httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- assert.Equal(http.StatusMovedPermanently, rec.Code)
- assert.Equal("/static/", rec.HeaderMap["Location"][0])
-
- // Directory with index.html
- e.Static("/", "_fixture")
- c, r := request(http.MethodGet, "/", e)
- assert.Equal(http.StatusOK, c)
- assert.Equal(true, strings.HasPrefix(r, ""))
-
- // Sub-directory with index.html
- c, r = request(http.MethodGet, "/folder/", e)
- assert.Equal(http.StatusOK, c)
- assert.Equal(true, strings.HasPrefix(r, ""))
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Static(tc.givenPrefix, tc.givenRoot)
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ body := rec.Body.String()
+ if tc.expectBodyStartsWith != "" {
+ assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
+ } else {
+ assert.Equal(t, "", body)
+ }
+ if tc.expectHeaderLocation != "" {
+ assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
+ } else {
+ _, ok := rec.Result().Header["Location"]
+ assert.False(t, ok)
+ }
+ })
+ }
}
func TestEchoStaticRedirectIndex(t *testing.T) {
@@ -319,10 +381,12 @@ func TestEchoURL(t *testing.T) {
e := New()
static := func(Context) error { return nil }
getUser := func(Context) error { return nil }
+ getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil }
e.GET("/static/file", static)
e.GET("/users/:id", getUser)
+ e.GET("/documents/*", getAny)
g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile)
@@ -331,6 +395,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1"))
+ assert.Equal("/users/1", e.URL(getUser, "1"))
+ assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
+ assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
}
@@ -694,3 +761,91 @@ func TestEchoShutdown(t *testing.T) {
err := <-errCh
assert.Equal(t, err.Error(), "http: Server closed")
}
+
+var listenerNetworkTests = []struct {
+ test string
+ network string
+ address string
+}{
+ {"tcp ipv4 address", "tcp", "127.0.0.1:1323"},
+ {"tcp ipv6 address", "tcp", "[::1]:1323"},
+ {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"},
+ {"tcp6 ipv6 address", "tcp6", "[::1]:1323"},
+}
+
+func TestEchoListenerNetwork(t *testing.T) {
+ for _, tt := range listenerNetworkTests {
+ t.Run(tt.test, func(t *testing.T) {
+ e := New()
+ e.ListenerNetwork = tt.network
+
+ // HandlerFunc
+ e.GET("/ok", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+
+ errCh := make(chan error)
+
+ go func() {
+ errCh <- e.Start(tt.address)
+ }()
+
+ time.Sleep(200 * time.Millisecond)
+
+ if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
+ defer resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ if body, err := ioutil.ReadAll(resp.Body); err == nil {
+ assert.Equal(t, "OK", string(body))
+ } else {
+ assert.Fail(t, err.Error())
+ }
+
+ } else {
+ assert.Fail(t, err.Error())
+ }
+
+ if err := e.Close(); err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
+
+func TestEchoListenerNetworkInvalid(t *testing.T) {
+ e := New()
+ e.ListenerNetwork = "unix"
+
+ // HandlerFunc
+ e.GET("/ok", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+
+ assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
+}
+
+func TestEchoReverse(t *testing.T) {
+ assert := assert.New(t)
+
+ e := New()
+ dummyHandler := func(Context) error { return nil }
+
+ e.GET("/static", dummyHandler).Name = "/static"
+ e.GET("/static/*", dummyHandler).Name = "/static/*"
+ e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
+ e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
+ e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
+
+ assert.Equal("/static", e.Reverse("/static"))
+ assert.Equal("/static", e.Reverse("/static", "missing param"))
+ assert.Equal("/static/*", e.Reverse("/static/*"))
+ assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
+
+ assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
+ assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
+ assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
+ assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
+ assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
+ assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
+}
diff --git a/middleware/compress.go b/middleware/compress.go
index dd97d983..6ae19745 100644
--- a/middleware/compress.go
+++ b/middleware/compress.go
@@ -8,6 +8,7 @@ import (
"net"
"net/http"
"strings"
+ "sync"
"github.com/labstack/echo/v4"
)
@@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level
}
+ pool := gzipCompressPool(config)
+
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
@@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
- rw := res.Writer
- w, err := gzip.NewWriterLevel(rw, config.Level)
- if err != nil {
- return err
+ i := pool.Get()
+ w, ok := i.(*gzip.Writer)
+ if !ok {
+ return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
+ rw := res.Writer
+ w.Reset(rw)
defer func() {
if res.Size == 0 {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
@@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
w.Reset(ioutil.Discard)
}
w.Close()
+ pool.Put(w)
}()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
res.Writer = grw
@@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
}
return http.ErrNotSupported
}
+
+func gzipCompressPool(config GzipConfig) sync.Pool {
+ return sync.Pool{
+ New: func() interface{} {
+ w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
+ if err != nil {
+ return err
+ }
+ return w
+ },
+ }
+}
diff --git a/middleware/compress_test.go b/middleware/compress_test.go
index ac5b6c3b..d16ffca4 100644
--- a/middleware/compress_test.go
+++ b/middleware/compress_test.go
@@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
+func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
+ e := echo.New()
+ // Invalid level
+ e.Use(GzipWithConfig(GzipConfig{Level: 12}))
+ e.GET("/", func(c echo.Context) error {
+ c.Response().Write([]byte("test"))
+ return nil
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.Contains(t, rec.Body.String(), "gzip")
+}
+
// Issue #806
func TestGzipWithStatic(t *testing.T) {
e := echo.New()
@@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) {
}
}
}
+
+func BenchmarkGzip(b *testing.B) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+
+ h := Gzip()(func(c echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ // Gzip
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h(c)
+ }
+}
diff --git a/middleware/cors.go b/middleware/cors.go
index c263f731..d6ef8964 100644
--- a/middleware/cors.go
+++ b/middleware/cors.go
@@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`
+ // AllowOriginFunc is a custom function to validate the origin. It takes the
+ // origin as an argument and returns true if allowed or false otherwise. If
+ // an error is returned, it is returned by the handler. If this option is
+ // set, AllowOrigins is ignored.
+ // Optional.
+ AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
+
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
@@ -102,45 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""
- // Check allowed origins
- for _, o := range config.AllowOrigins {
- if o == "*" && config.AllowCredentials {
- allowOrigin = origin
- break
- }
- if o == "*" || o == origin {
- allowOrigin = o
- break
- }
- if matchSubdomain(origin, o) {
- allowOrigin = origin
- break
+ preflight := req.Method == http.MethodOptions
+ res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
+
+ // No Origin provided
+ if origin == "" {
+ if !preflight {
+ return next(c)
}
+ return c.NoContent(http.StatusNoContent)
}
- // Check allowed origin patterns
- for _, re := range allowOriginPatterns {
- if allowOrigin == "" {
- didx := strings.Index(origin, "://")
- if didx == -1 {
- continue
- }
- domAuth := origin[didx+3:]
- // to avoid regex cost by invalid long domain
- if len(domAuth) > 253 {
+ if config.AllowOriginFunc != nil {
+ allowed, err := config.AllowOriginFunc(origin)
+ if err != nil {
+ return err
+ }
+ if allowed {
+ allowOrigin = origin
+ }
+ } else {
+ // Check allowed origins
+ for _, o := range config.AllowOrigins {
+ if o == "*" && config.AllowCredentials {
+ allowOrigin = origin
break
}
-
- if match, _ := regexp.MatchString(re, origin); match {
+ if o == "*" || o == origin {
+ allowOrigin = o
+ break
+ }
+ if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
+
+ // Check allowed origin patterns
+ for _, re := range allowOriginPatterns {
+ if allowOrigin == "" {
+ didx := strings.Index(origin, "://")
+ if didx == -1 {
+ continue
+ }
+ domAuth := origin[didx+3:]
+ // to avoid regex cost by invalid long domain
+ if len(domAuth) > 253 {
+ break
+ }
+
+ if match, _ := regexp.MatchString(re, origin); match {
+ allowOrigin = origin
+ break
+ }
+ }
+ }
+ }
+
+ // Origin not allowed
+ if allowOrigin == "" {
+ if !preflight {
+ return next(c)
+ }
+ return c.NoContent(http.StatusNoContent)
}
// Simple request
- if req.Method != http.MethodOptions {
- res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
+ if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
@@ -152,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
// Preflight request
- res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
diff --git a/middleware/cors_test.go b/middleware/cors_test.go
index ca922321..717abe49 100644
--- a/middleware/cors_test.go
+++ b/middleware/cors_test.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "errors"
"net/http"
"net/http/httptest"
"testing"
@@ -17,19 +18,31 @@ func TestCORS(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler)
+ req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ // Wildcard AllowedOrigin with no Origin header in request
+ req = httptest.NewRequest(http.MethodGet, "/", nil)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ h = CORS()(echo.NotFoundHandler)
+ h(c)
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+
// Allow origins
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORSWithConfig(CORSConfig{
- AllowOrigins: []string{"localhost"},
+ AllowOrigins: []string{"localhost"},
+ AllowCredentials: true,
+ MaxAge: 3600,
})(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
// Preflight request
req = httptest.NewRequest(http.MethodOptions, "/", nil)
@@ -67,6 +80,22 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
+ // Preflight request with Access-Control-Request-Headers
+ req = httptest.NewRequest(http.MethodOptions, "/", nil)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ req.Header.Set(echo.HeaderOrigin, "localhost")
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+ req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
+ cors = CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"*"},
+ })
+ h = cors(echo.NotFoundHandler)
+ h(c)
+ assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
+ assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
+
// Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
@@ -126,7 +155,7 @@ func Test_allowOriginScheme(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
- assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
@@ -216,6 +245,163 @@ func Test_allowOriginSubdomain(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ } else {
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+ }
+ }
+}
+
+func TestCorsHeaders(t *testing.T) {
+ tests := []struct {
+ domain, allowedOrigin, method string
+ expected bool
+ }{
+ {
+ domain: "", // Request does not have Origin header
+ allowedOrigin: "*",
+ method: http.MethodGet,
+ expected: false,
+ },
+ {
+ domain: "http://example.com",
+ allowedOrigin: "*",
+ method: http.MethodGet,
+ expected: true,
+ },
+ {
+ domain: "", // Request does not have Origin header
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: false,
+ },
+ {
+ domain: "http://bar.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: false,
+ },
+ {
+ domain: "http://example.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: true,
+ },
+ {
+ domain: "", // Request does not have Origin header
+ allowedOrigin: "*",
+ method: http.MethodOptions,
+ expected: false,
+ },
+ {
+ domain: "http://example.com",
+ allowedOrigin: "*",
+ method: http.MethodOptions,
+ expected: true,
+ },
+ {
+ domain: "", // Request does not have Origin header
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: false,
+ },
+ {
+ domain: "http://bar.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: false,
+ },
+ {
+ domain: "http://example.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: true,
+ },
+ }
+
+ e := echo.New()
+ for _, tt := range tests {
+ req := httptest.NewRequest(tt.method, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ if tt.domain != "" {
+ req.Header.Set(echo.HeaderOrigin, tt.domain)
+ }
+ cors := CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{tt.allowedOrigin},
+ //AllowCredentials: true,
+ //MaxAge: 3600,
+ })
+ h := cors(echo.NotFoundHandler)
+ h(c)
+
+ assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
+
+ expectedAllowOrigin := ""
+ if tt.allowedOrigin == "*" {
+ expectedAllowOrigin = "*"
+ } else {
+ expectedAllowOrigin = tt.domain
+ }
+
+ switch {
+ case tt.expected && tt.method == http.MethodOptions:
+ assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
+ assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
+ case tt.expected && tt.method == http.MethodGet:
+ assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
+ default:
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+ assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
+ }
+
+ if tt.method == http.MethodOptions {
+ assert.Equal(t, http.StatusNoContent, rec.Code)
+ }
+ }
+}
+
+func Test_allowOriginFunc(t *testing.T) {
+ returnTrue := func(origin string) (bool, error) {
+ return true, nil
+ }
+ returnFalse := func(origin string) (bool, error) {
+ return false, nil
+ }
+ returnError := func(origin string) (bool, error) {
+ return true, errors.New("this is a test error")
+ }
+
+ allowOriginFuncs := []func(origin string) (bool, error){
+ returnTrue,
+ returnFalse,
+ returnError,
+ }
+
+ const origin = "http://example.com"
+
+ e := echo.New()
+ for _, allowOriginFunc := range allowOriginFuncs {
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(echo.HeaderOrigin, origin)
+ cors := CORSWithConfig(CORSConfig{
+ AllowOriginFunc: allowOriginFunc,
+ })
+ h := cors(echo.NotFoundHandler)
+ err := h(c)
+
+ expected, expectedErr := allowOriginFunc(origin)
+ if expectedErr != nil {
+ assert.Equal(t, expectedErr, err)
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ continue
+ }
+
+ if expected {
+ assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}
diff --git a/middleware/decompress.go b/middleware/decompress.go
new file mode 100644
index 00000000..c046359a
--- /dev/null
+++ b/middleware/decompress.go
@@ -0,0 +1,120 @@
+package middleware
+
+import (
+ "bytes"
+ "compress/gzip"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "sync"
+
+ "github.com/labstack/echo/v4"
+)
+
+type (
+ // DecompressConfig defines the config for Decompress middleware.
+ DecompressConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
+ GzipDecompressPool Decompressor
+ }
+)
+
+//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
+const GZIPEncoding string = "gzip"
+
+// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
+type Decompressor interface {
+ gzipDecompressPool() sync.Pool
+}
+
+var (
+ //DefaultDecompressConfig defines the config for decompress middleware
+ DefaultDecompressConfig = DecompressConfig{
+ Skipper: DefaultSkipper,
+ GzipDecompressPool: &DefaultGzipDecompressPool{},
+ }
+)
+
+// DefaultGzipDecompressPool is the default implementation of Decompressor interface
+type DefaultGzipDecompressPool struct {
+}
+
+func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
+ return sync.Pool{
+ New: func() interface{} {
+ // create with an empty reader (but with GZIP header)
+ w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
+ if err != nil {
+ return err
+ }
+
+ b := new(bytes.Buffer)
+ w.Reset(b)
+ w.Flush()
+ w.Close()
+
+ r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
+ if err != nil {
+ return err
+ }
+ return r
+ },
+ }
+}
+
+//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
+func Decompress() echo.MiddlewareFunc {
+ return DecompressWithConfig(DefaultDecompressConfig)
+}
+
+//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
+func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
+ // Defaults
+ if config.Skipper == nil {
+ config.Skipper = DefaultGzipConfig.Skipper
+ }
+ if config.GzipDecompressPool == nil {
+ config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ pool := config.GzipDecompressPool.gzipDecompressPool()
+ return func(c echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+ switch c.Request().Header.Get(echo.HeaderContentEncoding) {
+ case GZIPEncoding:
+ b := c.Request().Body
+
+ i := pool.Get()
+ gr, ok := i.(*gzip.Reader)
+ if !ok {
+ return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
+ }
+
+ if err := gr.Reset(b); err != nil {
+ pool.Put(gr)
+ if err == io.EOF { //ignore if body is empty
+ return next(c)
+ }
+ return err
+ }
+ var buf bytes.Buffer
+ io.Copy(&buf, gr)
+
+ gr.Close()
+ pool.Put(gr)
+
+ b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
+
+ r := ioutil.NopCloser(&buf)
+ c.Request().Body = r
+ }
+ return next(c)
+ }
+ }
+}
diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go
new file mode 100644
index 00000000..51fa6b0f
--- /dev/null
+++ b/middleware/decompress_test.go
@@ -0,0 +1,209 @@
+package middleware
+
+import (
+ "bytes"
+ "compress/gzip"
+ "errors"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/labstack/echo/v4"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDecompress(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Skip if no Content-Encoding header
+ h := Decompress()(func(c echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+ h(c)
+
+ assert := assert.New(t)
+ assert.Equal("test", rec.Body.String())
+
+ // Decompress
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ h(c)
+ assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := ioutil.ReadAll(req.Body)
+ assert.NoError(err)
+ assert.Equal(body, string(b))
+}
+
+func TestDecompressDefaultConfig(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+ h(c)
+
+ assert := assert.New(t)
+ assert.Equal("test", rec.Body.String())
+
+ // Decompress
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ h(c)
+ assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := ioutil.ReadAll(req.Body)
+ assert.NoError(err)
+ assert.Equal(body, string(b))
+}
+
+func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
+ e := echo.New()
+ body := `{"name":"echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ e.NewContext(req, rec)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := ioutil.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.NotEqual(t, b, body)
+ assert.Equal(t, b, gz)
+}
+
+func TestDecompressNoContent(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Decompress()(func(c echo.Context) error {
+ return c.NoContent(http.StatusNoContent)
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
+ assert.Equal(t, 0, len(rec.Body.Bytes()))
+ }
+}
+
+func TestDecompressErrorReturned(t *testing.T) {
+ e := echo.New()
+ e.Use(Decompress())
+ e.GET("/", func(c echo.Context) error {
+ return echo.ErrNotFound
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+}
+
+func TestDecompressSkipper(t *testing.T) {
+ e := echo.New()
+ e.Use(DecompressWithConfig(DecompressConfig{
+ Skipper: func(c echo.Context) bool {
+ return c.Request().URL.Path == "/skip"
+ },
+ }))
+ body := `{"name": "echo"}`
+ req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
+ reqBody, err := ioutil.ReadAll(c.Request().Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(reqBody))
+}
+
+type TestDecompressPoolWithError struct {
+}
+
+func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
+ return sync.Pool{
+ New: func() interface{} {
+ return errors.New("pool error")
+ },
+ }
+}
+
+func TestDecompressPoolError(t *testing.T) {
+ e := echo.New()
+ e.Use(DecompressWithConfig(DecompressConfig{
+ Skipper: DefaultSkipper,
+ GzipDecompressPool: &TestDecompressPoolWithError{},
+ }))
+ body := `{"name": "echo"}`
+ req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ reqBody, err := ioutil.ReadAll(c.Request().Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(reqBody))
+ assert.Equal(t, rec.Code, http.StatusInternalServerError)
+}
+
+func BenchmarkDecompress(b *testing.B) {
+ e := echo.New()
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+
+ h := Decompress()(func(c echo.Context) error {
+ c.Response().Write([]byte(body)) // For Content-Type sniffing
+ return nil
+ })
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ // Decompress
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h(c)
+ }
+}
+
+func gzipString(body string) ([]byte, error) {
+ var buf bytes.Buffer
+ gz := gzip.NewWriter(&buf)
+
+ _, err := gz.Write([]byte(body))
+ if err != nil {
+ return nil, err
+ }
+
+ if err := gz.Close(); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
diff --git a/middleware/jwt.go b/middleware/jwt.go
index bab00c9f..da00ea56 100644
--- a/middleware/jwt.go
+++ b/middleware/jwt.go
@@ -57,6 +57,7 @@ type (
// - "query:"
// - "param:"
// - "cookie:"
+ // - "form:"
TokenLookup string
// AuthScheme to be used in the Authorization header.
@@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1])
case "cookie":
extractor = jwtFromCookie(parts[1])
+ case "form":
+ extractor = jwtFromForm(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil
}
}
+
+// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
+func jwtFromForm(name string) jwtExtractor {
+ return func(c echo.Context) (string, error) {
+ field := c.FormValue(name)
+ if field == "" {
+ return "", ErrJWTMissing
+ }
+ return field, nil
+ }
+}
diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go
index ce44f9c9..205721ae 100644
--- a/middleware/jwt_test.go
+++ b/middleware/jwt_test.go
@@ -3,6 +3,8 @@ package middleware
import (
"net/http"
"net/http/httptest"
+ "net/url"
+ "strings"
"testing"
"github.com/dgrijalva/jwt-go"
@@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
+ formValues map[string]string
info string
}{
{
@@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest,
info: "Empty cookie",
},
+ {
+ config: JWTConfig{
+ SigningKey: validKey,
+ TokenLookup: "form:jwt",
+ },
+ formValues: map[string]string{"jwt": token},
+ info: "Valid form method",
+ },
+ {
+ config: JWTConfig{
+ SigningKey: validKey,
+ TokenLookup: "form:jwt",
+ },
+ expErrCode: http.StatusUnauthorized,
+ formValues: map[string]string{"jwt": "invalid"},
+ info: "Invalid token with form method",
+ },
+ {
+ config: JWTConfig{
+ SigningKey: validKey,
+ TokenLookup: "form:jwt",
+ },
+ expErrCode: http.StatusBadRequest,
+ info: "Empty form field",
+ },
} {
if tc.reqURL == "" {
tc.reqURL = "/"
}
- req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
+ var req *http.Request
+ if len(tc.formValues) > 0 {
+ form := url.Values{}
+ for k, v := range tc.formValues {
+ form.Set(k, v)
+ }
+ req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
+ req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
+ req.ParseForm()
+ } else {
+ req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
+ }
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go
index 30eecdef..86eec8c3 100644
--- a/middleware/request_id_test.go
+++ b/middleware/request_id_test.go
@@ -31,3 +31,20 @@ func TestRequestID(t *testing.T) {
h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
}
+
+func TestRequestID_IDNotAltered(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRequestID, "")
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestIDWithConfig(RequestIDConfig{})
+ h := rid(handler)
+ _ = h(c)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "")
+}
diff --git a/middleware/static.go b/middleware/static.go
index bc2087a7..ae79cb5f 100644
--- a/middleware/static.go
+++ b/middleware/static.go
@@ -36,6 +36,12 @@ type (
// Enable directory browsing.
// Optional. Default value false.
Browse bool `yaml:"browse"`
+
+ // Enable ignoring of the base of the URL path.
+ // Example: when assigning a static middleware to a non root path group,
+ // the filesystem path is not doubled
+ // Optional. Default value false.
+ IgnoreBase bool `yaml:"ignoreBase"`
}
)
@@ -161,7 +167,16 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if err != nil {
return
}
- name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
+ name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
+
+ if config.IgnoreBase {
+ routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
+ baseURLPath := path.Base(p)
+ if baseURLPath == routePath {
+ i := strings.LastIndex(name, routePath)
+ name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
+ }
+ }
fi, err := os.Stat(name)
if err != nil {
diff --git a/middleware/static_test.go b/middleware/static_test.go
index 0d695d3d..407dd15c 100644
--- a/middleware/static_test.go
+++ b/middleware/static_test.go
@@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
+ "path/filepath"
"testing"
"github.com/labstack/echo/v4"
@@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), "cert.pem")
}
+
+ // IgnoreBase
+ req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
+ rec = httptest.NewRecorder()
+ config.Root = "../_fixture"
+ config.IgnoreBase = true
+ static = StaticWithConfig(config)
+ c.Echo().Group("_fixture", static)
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(http.StatusOK, rec.Code)
+ assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
+
+ req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
+ rec = httptest.NewRecorder()
+ config.Root = "../_fixture"
+ config.IgnoreBase = false
+ static = StaticWithConfig(config)
+ c.Echo().Group("_fixture", static)
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(http.StatusOK, rec.Code)
+ assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
}
diff --git a/response.go b/response.go
index ca7405c5..84f7c9e7 100644
--- a/response.go
+++ b/response.go
@@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed")
return
}
+ r.Status = code
for _, fn := range r.beforeFuncs {
fn()
}
- r.Status = code
- r.Writer.WriteHeader(code)
+ r.Writer.WriteHeader(r.Status)
r.Committed = true
}
diff --git a/response_test.go b/response_test.go
index 7a9c51c6..d95e079f 100644
--- a/response_test.go
+++ b/response_test.go
@@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush()
assert.True(t, rec.Flushed)
}
+
+func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ res := &Response{echo: e, Writer: rec}
+
+ res.Before(func() {
+ if 200 < res.Status && res.Status < 300 {
+ res.Status = 200
+ }
+ })
+
+ res.WriteHeader(209)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
diff --git a/router.go b/router.go
index ed728d6a..4c3898c4 100644
--- a/router.go
+++ b/router.go
@@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) {
pos := strings.IndexByte(ns, '/')
if pos == -1 {
// If no slash is remaining in search string set param value
- pvalues[len(cn.pnames)-1] = search
+ if len(cn.pnames) > 0 {
+ pvalues[len(cn.pnames)-1] = search
+ }
break
} else if pos > 0 {
// Otherwise continue route processing with restored next node
diff --git a/router_test.go b/router_test.go
index 0e883233..fca3a79b 100644
--- a/router_test.go
+++ b/router_test.go
@@ -1298,6 +1298,40 @@ func TestRouterParam1466(t *testing.T) {
assert.Equal(t, 0, c.response.Status)
}
+// Issue #1653
+func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
+ e := New()
+ r := e.router
+
+ r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1))
+ r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error {
+ return nil
+ })
+ r.Add(http.MethodGet, "/users/:id/active", func(c Context) error {
+ return nil
+ })
+
+ c := e.NewContext(nil, nil).(*context)
+ r.Find(http.MethodGet, "/users/alice/edit", c)
+ assert.Equal(t, "alice", c.Param("id"))
+
+ c = e.NewContext(nil, nil).(*context)
+ r.Find(http.MethodGet, "/users/bob/active", c)
+ assert.Equal(t, "bob", c.Param("id"))
+
+ c = e.NewContext(nil, nil).(*context)
+ r.Find(http.MethodGet, "/users/create", c)
+ c.Handler()(c)
+ assert.Equal(t, 1, c.Get("create"))
+ assert.Equal(t, "/users/create", c.Get("path"))
+
+ //This panic before the fix for Issue #1653
+ c = e.NewContext(nil, nil).(*context)
+ r.Find(http.MethodGet, "/users/createNotFound", c)
+ he := c.Handler()(c).(*HTTPError)
+ assert.Equal(t, http.StatusNotFound, he.Code)
+}
+
func benchmarkRouterRoutes(b *testing.B, routes []*Route) {
e := New()
r := e.router