1
0
mirror of https://github.com/labstack/echo.git synced 2024-11-24 08:22:21 +02:00

Merge branch 'master' into fix-conflict

This commit is contained in:
pwli 2020-12-16 09:21:26 +08:00 committed by GitHub
commit 89ec0070b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1122 additions and 110 deletions

12
.github/stale.yml vendored
View File

@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale # Number of days of inactivity before an issue becomes stale
daysUntilStale: 60 daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed # Number of days of inactivity before a stale issue is closed
daysUntilClose: 7 daysUntilClose: 30
# Issues with these labels will never be considered stale # Issues with these labels will never be considered stale
exemptLabels: exemptLabels:
- pinned - pinned
- security - security
- bug
- enhancement
# Label to use when marking an issue as stale # Label to use when marking an issue as stale
staleLabel: wontfix staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable # Comment to post when marking an issue as stale. Set to `false` to disable
markComment: > markComment: >
This issue has been automatically marked as stale because it has not had This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you recent activity. It will be closed within a month if no further activity occurs.
for your contributions. Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable # Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false closeComment: false

View File

@ -4,20 +4,28 @@ on:
push: push:
branches: branches:
- master - master
paths:
- '**.go'
- 'go.*'
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
pull_request: pull_request:
branches: branches:
- master - master
paths:
env: - '**.go'
GO111MODULE: on - 'go.*'
GOPROXY: https://proxy.golang.org - '_fixture/**'
- '.github/**'
- 'codecov.yml'
jobs: jobs:
test: test:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
go: [1.12, 1.13, 1.14] go: [1.12, 1.13, 1.14, 1.15]
name: ${{ matrix.os }} @ Go ${{ matrix.go }} name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@ -28,10 +36,15 @@ jobs:
- name: Set GOPATH and PATH - name: Set GOPATH and PATH
run: | run: |
echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)" echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin" echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v1 uses: actions/checkout@v1
with: with:
@ -51,3 +64,55 @@ jobs:
with: with:
token: token:
fail_ci_if_error: false fail_ci_if_error: false
benchmark:
needs: test
strategy:
matrix:
os: [ubuntu-latest]
go: [1.15]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v1
with:
go-version: ${{ matrix.go }}
- name: Set GOPATH and PATH
run: |
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code (Previous)
uses: actions/checkout@v2
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
uses: actions/checkout@v2
with:
path: new
- name: Install Dependencies
run: go get -v golang.org/x/perf/cmd/benchstat
- name: Run Benchmark (Previous)
run: |
cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New)
run: |
cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat
run: |
benchstat previous/benchmark.txt new/benchmark.txt

View File

@ -1,3 +1,7 @@
arch:
- amd64
- ppc64le
language: go language: go
go: go:
- 1.14.x - 1.14.x

View File

@ -42,11 +42,14 @@ For older versions, please use the latest v3 tag.
## Benchmarks ## Benchmarks
Date: 2018/03/15<br> Date: 2020/11/11<br>
Source: https://github.com/vishr/web-framework-benchmark<br> Source: https://github.com/vishr/web-framework-benchmark<br>
Lower is better! Lower is better!
<img src="https://i.imgur.com/I32VdMJ.png"> <img src="https://i.imgur.com/qwPNQbl.png">
<img src="https://i.imgur.com/s8yKQjx.png">
The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
## [Guide](https://echo.labstack.com/guide) ## [Guide](https://echo.labstack.com/guide)

View File

@ -0,0 +1 @@
This directory is used for the static middleware test

11
codecov.yml Normal file
View File

@ -0,0 +1,11 @@
coverage:
status:
project:
default:
threshold: 1%
patch:
default:
threshold: 1%
comment:
require_changes: true

View File

@ -365,7 +365,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer f.Close() f.Close()
return fh, nil return fh, nil
} }

27
echo.go
View File

@ -49,7 +49,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime" "runtime"
@ -92,6 +91,7 @@ type (
Renderer Renderer Renderer Renderer
Logger Logger Logger Logger
IPExtractor IPExtractor IPExtractor IPExtractor
ListenerNetwork string
} }
// Route contains a handler and information for matching against requests. // Route contains a handler and information for matching against requests.
@ -281,6 +281,7 @@ var (
ErrInvalidRedirectCode = errors.New("invalid redirect status code") ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found") ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
ErrInvalidListenerNetwork = errors.New("invalid listener network")
) )
// Error handlers // Error handlers
@ -302,9 +303,10 @@ func New() (e *Echo) {
AutoTLSManager: autocert.Manager{ AutoTLSManager: autocert.Manager{
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
}, },
Logger: log.New("echo"), Logger: log.New("echo"),
colorer: color.New(), colorer: color.New(),
maxParam: new(int), maxParam: new(int),
ListenerNetwork: "tcp",
} }
e.Server.Handler = e e.Server.Handler = e
e.TLSServer.Handler = e e.TLSServer.Handler = e
@ -483,7 +485,7 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl
return err return err
} }
name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security
fi, err := os.Stat(name) fi, err := os.Stat(name)
if err != nil { if err != nil {
// The access path does not exist // The access path does not exist
@ -573,7 +575,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes { for _, r := range e.router.routes {
if r.Name == name { if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ { for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < ln { if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ { for ; i < l && r.Path[i] != '/'; i++ {
} }
uri.WriteString(fmt.Sprintf("%v", params[n])) uri.WriteString(fmt.Sprintf("%v", params[n]))
@ -715,7 +717,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
if s.TLSConfig == nil { if s.TLSConfig == nil {
if e.Listener == nil { if e.Listener == nil {
e.Listener, err = newListener(s.Addr) e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil { if err != nil {
return err return err
} }
@ -726,7 +728,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
return s.Serve(e.Listener) return s.Serve(e.Listener)
} }
if e.TLSListener == nil { if e.TLSListener == nil {
l, err := newListener(s.Addr) l, err := newListener(s.Addr, e.ListenerNetwork)
if err != nil { if err != nil {
return err return err
} }
@ -755,7 +757,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
} }
if e.Listener == nil { if e.Listener == nil {
e.Listener, err = newListener(s.Addr) e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil { if err != nil {
return err return err
} }
@ -876,8 +878,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
return return
} }
func newListener(address string) (*tcpKeepAliveListener, error) { func newListener(address, network string) (*tcpKeepAliveListener, error) {
l, err := net.Listen("tcp", address) if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, ErrInvalidListenerNetwork
}
l, err := net.Listen(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
stdContext "context" stdContext "context"
"errors" "errors"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -59,53 +60,114 @@ func TestEcho(t *testing.T) {
} }
func TestEchoStatic(t *testing.T) { func TestEchoStatic(t *testing.T) {
e := New() var testCases = []struct {
name string
givenPrefix string
givenRoot string
whenURL string
expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
}{
{
name: "ok",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "No file",
givenPrefix: "/images",
givenRoot: "_fixture/scripts",
whenURL: "/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory Redirect",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/static/",
expectBodyStartsWith: "",
},
{
name: "Directory Redirect with non-root path",
givenPrefix: "/static",
givenRoot: "_fixture",
whenURL: "/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
assert := assert.New(t) for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// OK e := New()
e.Static("/images", "_fixture/images") e.Static(tc.givenPrefix, tc.givenRoot)
c, b := request(http.MethodGet, "/images/walle.png", e) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
assert.Equal(http.StatusOK, c) rec := httptest.NewRecorder()
assert.NotEmpty(b) e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
// No file body := rec.Body.String()
e.Static("/images", "_fixture/scripts") if tc.expectBodyStartsWith != "" {
c, _ = request(http.MethodGet, "/images/bolt.png", e) assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
assert.Equal(http.StatusNotFound, c) } else {
assert.Equal(t, "", body)
// Directory }
e.Static("/images", "_fixture/images")
c, _ = request(http.MethodGet, "/images/", e)
assert.Equal(http.StatusNotFound, c)
// Directory Redirect
e.Static("/", "_fixture")
req := httptest.NewRequest(http.MethodGet, "/folder", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(http.StatusMovedPermanently, rec.Code)
assert.Equal("/folder/", rec.HeaderMap["Location"][0])
// Directory Redirect with non-root path
e.Static("/static", "_fixture")
req = httptest.NewRequest(http.MethodGet, "/static", nil)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(http.StatusMovedPermanently, rec.Code)
assert.Equal("/static/", rec.HeaderMap["Location"][0])
// Directory with index.html
e.Static("/", "_fixture")
c, r := request(http.MethodGet, "/", e)
assert.Equal(http.StatusOK, c)
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
// Sub-directory with index.html
c, r = request(http.MethodGet, "/folder/", e)
assert.Equal(http.StatusOK, c)
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
if tc.expectHeaderLocation != "" {
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
} else {
_, ok := rec.Result().Header["Location"]
assert.False(t, ok)
}
})
}
} }
func TestEchoStaticRedirectIndex(t *testing.T) { func TestEchoStaticRedirectIndex(t *testing.T) {
@ -319,10 +381,12 @@ func TestEchoURL(t *testing.T) {
e := New() e := New()
static := func(Context) error { return nil } static := func(Context) error { return nil }
getUser := func(Context) error { return nil } getUser := func(Context) error { return nil }
getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil } getFile := func(Context) error { return nil }
e.GET("/static/file", static) e.GET("/static/file", static)
e.GET("/users/:id", getUser) e.GET("/users/:id", getUser)
e.GET("/documents/*", getAny)
g := e.Group("/group") g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile) g.GET("/users/:uid/files/:fid", getFile)
@ -331,6 +395,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static)) assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser)) assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1")) assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
} }
@ -694,3 +761,91 @@ func TestEchoShutdown(t *testing.T) {
err := <-errCh err := <-errCh
assert.Equal(t, err.Error(), "http: Server closed") assert.Equal(t, err.Error(), "http: Server closed")
} }
var listenerNetworkTests = []struct {
test string
network string
address string
}{
{"tcp ipv4 address", "tcp", "127.0.0.1:1323"},
{"tcp ipv6 address", "tcp", "[::1]:1323"},
{"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"},
{"tcp6 ipv6 address", "tcp6", "[::1]:1323"},
}
func TestEchoListenerNetwork(t *testing.T) {
for _, tt := range listenerNetworkTests {
t.Run(tt.test, func(t *testing.T) {
e := New()
e.ListenerNetwork = tt.network
// HandlerFunc
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
errCh := make(chan error)
go func() {
errCh <- e.Start(tt.address)
}()
time.Sleep(200 * time.Millisecond)
if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
if body, err := ioutil.ReadAll(resp.Body); err == nil {
assert.Equal(t, "OK", string(body))
} else {
assert.Fail(t, err.Error())
}
} else {
assert.Fail(t, err.Error())
}
if err := e.Close(); err != nil {
t.Fatal(err)
}
})
}
}
func TestEchoListenerNetworkInvalid(t *testing.T) {
e := New()
e.ListenerNetwork = "unix"
// HandlerFunc
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
}
func TestEchoReverse(t *testing.T) {
assert := assert.New(t)
e := New()
dummyHandler := func(Context) error { return nil }
e.GET("/static", dummyHandler).Name = "/static"
e.GET("/static/*", dummyHandler).Name = "/static/*"
e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

View File

@ -8,6 +8,7 @@ import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
"sync"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level config.Level = DefaultGzipConfig.Level
} }
pool := gzipCompressPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
rw := res.Writer i := pool.Get()
w, err := gzip.NewWriterLevel(rw, config.Level) w, ok := i.(*gzip.Writer)
if err != nil { if !ok {
return err return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
} }
rw := res.Writer
w.Reset(rw)
defer func() { defer func() {
if res.Size == 0 { if res.Size == 0 {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
w.Reset(ioutil.Discard) w.Reset(ioutil.Discard)
} }
w.Close() w.Close()
pool.Put(w)
}() }()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
res.Writer = grw res.Writer = grw
@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
} }
return http.ErrNotSupported return http.ErrNotSupported
} }
func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
}

View File

@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
} }
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
e := echo.New()
// Invalid level
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
}
// Issue #806 // Issue #806
func TestGzipWithStatic(t *testing.T) { func TestGzipWithStatic(t *testing.T) {
e := echo.New() e := echo.New()
@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) {
} }
} }
} }
func BenchmarkGzip(b *testing.B) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Gzip
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}

View File

@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"` AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
@ -102,45 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin) origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := "" allowOrigin := ""
// Check allowed origins preflight := req.Method == http.MethodOptions
for _, o := range config.AllowOrigins { res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if o == "*" && config.AllowCredentials {
allowOrigin = origin // No Origin provided
break if origin == "" {
} if !preflight {
if o == "*" || o == origin { return next(c)
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
} }
return c.NoContent(http.StatusNoContent)
} }
// Check allowed origin patterns if config.AllowOriginFunc != nil {
for _, re := range allowOriginPatterns { allowed, err := config.AllowOriginFunc(origin)
if allowOrigin == "" { if err != nil {
didx := strings.Index(origin, "://") return err
if didx == -1 { }
continue if allowed {
} allowOrigin = origin
domAuth := origin[didx+3:] }
// to avoid regex cost by invalid long domain } else {
if len(domAuth) > 253 { // Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break break
} }
if o == "*" || o == origin {
if match, _ := regexp.MatchString(re, origin); match { allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin allowOrigin = origin
break break
} }
} }
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
} }
// Simple request // Simple request
if req.Method != http.MethodOptions { if !preflight {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials { if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
@ -152,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
// Preflight request // Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -17,19 +18,31 @@ func TestCORS(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler) h := CORS()(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard AllowedOrigin with no Origin header in request
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
// Allow origins // Allow origins
req = httptest.NewRequest(http.MethodGet, "/", nil) req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = CORSWithConfig(CORSConfig{ h = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"}, AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler) })(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
// Preflight request // Preflight request
req = httptest.NewRequest(http.MethodOptions, "/", nil) req = httptest.NewRequest(http.MethodOptions, "/", nil)
@ -67,6 +80,22 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
// Preflight request with Access-Control-Request-Headers
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
// Preflight request with `AllowOrigins` which allow all subdomains with * // Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil) req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
@ -126,7 +155,7 @@ func Test_allowOriginScheme(t *testing.T) {
if tt.expected { if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else { } else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
} }
} }
} }
@ -216,6 +245,163 @@ func Test_allowOriginSubdomain(t *testing.T) {
if tt.expected { if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
func TestCorsHeaders(t *testing.T) {
tests := []struct {
domain, allowedOrigin, method string
expected bool
}{
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
},
}
e := echo.New()
for _, tt := range tests {
req := httptest.NewRequest(tt.method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tt.domain != "" {
req.Header.Set(echo.HeaderOrigin, tt.domain)
}
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
})
h := cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
expectedAllowOrigin := ""
if tt.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tt.domain
}
switch {
case tt.expected && tt.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tt.expected && tt.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}
if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
}
}
func Test_allowOriginFunc(t *testing.T) {
returnTrue := func(origin string) (bool, error) {
return true, nil
}
returnFalse := func(origin string) (bool, error) {
return false, nil
}
returnError := func(origin string) (bool, error) {
return true, errors.New("this is a test error")
}
allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
}
const origin = "http://example.com"
e := echo.New()
for _, allowOriginFunc := range allowOriginFuncs {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}
if expected {
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else { } else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} }

120
middleware/decompress.go Normal file
View File

@ -0,0 +1,120 @@
package middleware
import (
"bytes"
"compress/gzip"
"io"
"io/ioutil"
"net/http"
"sync"
"github.com/labstack/echo/v4"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
b := c.Request().Body
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
var buf bytes.Buffer
io.Copy(&buf, gr)
gr.Close()
pool.Put(gr)
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
r := ioutil.NopCloser(&buf)
c.Request().Body = r
}
return next(c)
}
}
}

View File

@ -0,0 +1,209 @@
package middleware
import (
"bytes"
"compress/gzip"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestDecompress(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Skip if no Content-Encoding header
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestDecompressDefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New()
body := `{"name":"echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.NotEqual(t, b, body)
assert.Equal(t, b, gz)
}
func TestDecompressNoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
}
}
func TestDecompressErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Decompress())
e.GET("/", func(c echo.Context) error {
return echo.ErrNotFound
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestDecompressSkipper(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: func(c echo.Context) bool {
return c.Request().URL.Path == "/skip"
},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
}
type TestDecompressPoolWithError struct {
}
func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
return errors.New("pool error")
},
}
}
func TestDecompressPoolError(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &TestDecompressPoolWithError{},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
assert.Equal(t, rec.Code, http.StatusInternalServerError)
}
func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte(body)) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Decompress
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}
func gzipString(body string) ([]byte, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte(body))
if err != nil {
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@ -57,6 +57,7 @@ type (
// - "query:<name>" // - "query:<name>"
// - "param:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// - "form:<name>"
TokenLookup string TokenLookup string
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1]) extractor = jwtFromParam(parts[1])
case "cookie": case "cookie":
extractor = jwtFromCookie(parts[1]) extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil return cookie.Value, nil
} }
} }
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -3,6 +3,8 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings"
"testing" "testing"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty reqURL string // "/" if empty
hdrAuth string hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string info string
}{ }{
{ {
@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Empty cookie", info: "Empty cookie",
}, },
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
} { } {
if tc.reqURL == "" { if tc.reqURL == "" {
tc.reqURL = "/" tc.reqURL = "/"
} }
req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil) var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder() res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie) req.Header.Set(echo.HeaderCookie, tc.hdrCookie)

View File

@ -31,3 +31,20 @@ func TestRequestID(t *testing.T) {
h(c) h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
} }
func TestRequestID_IDNotAltered(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderXRequestID, "<sample-request-id>")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{})
h := rid(handler)
_ = h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "<sample-request-id>")
}

View File

@ -36,6 +36,12 @@ type (
// Enable directory browsing. // Enable directory browsing.
// Optional. Default value false. // Optional. Default value false.
Browse bool `yaml:"browse"` Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
} }
) )
@ -161,7 +167,16 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if err != nil { if err != nil {
return return
} }
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
if config.IgnoreBase {
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
baseURLPath := path.Base(p)
if baseURLPath == routePath {
i := strings.LastIndex(name, routePath)
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
}
}
fi, err := os.Stat(name) fi, err := os.Stat(name)
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), "cert.pem") assert.Contains(rec.Body.String(), "cert.pem")
} }
// IgnoreBase
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = true
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = false
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
} }

View File

@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed") r.echo.Logger.Warn("response already committed")
return return
} }
r.Status = code
for _, fn := range r.beforeFuncs { for _, fn := range r.beforeFuncs {
fn() fn()
} }
r.Status = code r.Writer.WriteHeader(r.Status)
r.Writer.WriteHeader(code)
r.Committed = true r.Committed = true
} }

View File

@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush() res.Flush()
assert.True(t, rec.Flushed) assert.True(t, rec.Flushed)
} }
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
res.Before(func() {
if 200 < res.Status && res.Status < 300 {
res.Status = 200
}
})
res.WriteHeader(209)
assert.Equal(t, http.StatusOK, rec.Code)
}

View File

@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) {
pos := strings.IndexByte(ns, '/') pos := strings.IndexByte(ns, '/')
if pos == -1 { if pos == -1 {
// If no slash is remaining in search string set param value // If no slash is remaining in search string set param value
pvalues[len(cn.pnames)-1] = search if len(cn.pnames) > 0 {
pvalues[len(cn.pnames)-1] = search
}
break break
} else if pos > 0 { } else if pos > 0 {
// Otherwise continue route processing with restored next node // Otherwise continue route processing with restored next node

View File

@ -1298,6 +1298,40 @@ func TestRouterParam1466(t *testing.T) {
assert.Equal(t, 0, c.response.Status) assert.Equal(t, 0, c.response.Status)
} }
// Issue #1653
func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1))
r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error {
return nil
})
r.Add(http.MethodGet, "/users/:id/active", func(c Context) error {
return nil
})
c := e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/alice/edit", c)
assert.Equal(t, "alice", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/bob/active", c)
assert.Equal(t, "bob", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/create", c)
c.Handler()(c)
assert.Equal(t, 1, c.Get("create"))
assert.Equal(t, "/users/create", c.Get("path"))
//This panic before the fix for Issue #1653
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/createNotFound", c)
he := c.Handler()(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.Code)
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route) { func benchmarkRouterRoutes(b *testing.B, routes []*Route) {
e := New() e := New()
r := e.router r := e.router