mirror of
https://github.com/labstack/echo.git
synced 2025-03-31 22:05:06 +02:00
Merge branch 'master' of https://github.com/labstack/echo
This commit is contained in:
commit
cb15226984
.github/workflows
.travis.ymlREADME.md_fixture/_fixture
bind_test.gocodecov.ymlcontext.gocontext_test.goecho.goecho_go1.13_test.goecho_test.gogo.modgo.sumgroup.gogroup_test.gomiddleware
compress.gocompress_test.gocors.gocors_test.godecompress.godecompress_test.gojwt.gojwt_test.gomiddleware.goproxy.goproxy_1_11.goproxy_test.gorecover.gorecover_test.gorewrite.gorewrite_test.gostatic.gostatic_test.go
response_test.gorouter.gorouter_test.go
79
.github/workflows/echo.yml
vendored
79
.github/workflows/echo.yml
vendored
@ -4,20 +4,28 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
|
paths:
|
||||||
|
- '**.go'
|
||||||
|
- 'go.*'
|
||||||
|
- '_fixture/**'
|
||||||
|
- '.github/**'
|
||||||
|
- 'codecov.yml'
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
|
paths:
|
||||||
env:
|
- '**.go'
|
||||||
GO111MODULE: on
|
- 'go.*'
|
||||||
GOPROXY: https://proxy.golang.org
|
- '_fixture/**'
|
||||||
|
- '.github/**'
|
||||||
|
- 'codecov.yml'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||||
go: [1.11, 1.12, 1.13]
|
go: [1.12, 1.13, 1.14, 1.15]
|
||||||
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
|
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
@ -28,10 +36,15 @@ jobs:
|
|||||||
|
|
||||||
- name: Set GOPATH and PATH
|
- name: Set GOPATH and PATH
|
||||||
run: |
|
run: |
|
||||||
echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)"
|
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
|
||||||
echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin"
|
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
|
- name: Set build variables
|
||||||
|
run: |
|
||||||
|
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
|
||||||
|
echo "GO111MODULE=on" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v1
|
uses: actions/checkout@v1
|
||||||
with:
|
with:
|
||||||
@ -51,3 +64,55 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
token:
|
token:
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
benchmark:
|
||||||
|
needs: test
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
go: [1.15]
|
||||||
|
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- name: Set up Go ${{ matrix.go }}
|
||||||
|
uses: actions/setup-go@v1
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
|
- name: Set GOPATH and PATH
|
||||||
|
run: |
|
||||||
|
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
|
||||||
|
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
|
||||||
|
shell: bash
|
||||||
|
|
||||||
|
- name: Set build variables
|
||||||
|
run: |
|
||||||
|
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
|
||||||
|
echo "GO111MODULE=on" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Checkout Code (Previous)
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
ref: ${{ github.base_ref }}
|
||||||
|
path: previous
|
||||||
|
|
||||||
|
- name: Checkout Code (New)
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
path: new
|
||||||
|
|
||||||
|
- name: Install Dependencies
|
||||||
|
run: go get -v golang.org/x/perf/cmd/benchstat
|
||||||
|
|
||||||
|
- name: Run Benchmark (Previous)
|
||||||
|
run: |
|
||||||
|
cd previous
|
||||||
|
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
|
||||||
|
|
||||||
|
- name: Run Benchmark (New)
|
||||||
|
run: |
|
||||||
|
cd new
|
||||||
|
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
|
||||||
|
|
||||||
|
- name: Run Benchstat
|
||||||
|
run: |
|
||||||
|
benchstat previous/benchmark.txt new/benchmark.txt
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
|
arch:
|
||||||
|
- amd64
|
||||||
|
- ppc64le
|
||||||
|
|
||||||
language: go
|
language: go
|
||||||
go:
|
go:
|
||||||
- 1.12.x
|
- 1.14.x
|
||||||
- 1.13.x
|
- 1.15.x
|
||||||
- tip
|
- tip
|
||||||
env:
|
env:
|
||||||
- GO111MODULE=on
|
- GO111MODULE=on
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
<a href="https://echo.labstack.com"><img height="80" src="https://cdn.labstack.com/images/echo-logo.svg"></a>
|
<a href="https://echo.labstack.com"><img height="80" src="https://cdn.labstack.com/images/echo-logo.svg"></a>
|
||||||
|
|
||||||
[](https://sourcegraph.com/github.com/labstack/echo?badge)
|
[](https://sourcegraph.com/github.com/labstack/echo?badge)
|
||||||
[](http://godoc.org/github.com/labstack/echo)
|
[](https://pkg.go.dev/github.com/labstack/echo/v4)
|
||||||
[](https://goreportcard.com/report/github.com/labstack/echo)
|
[](https://goreportcard.com/report/github.com/labstack/echo)
|
||||||
[](https://travis-ci.org/labstack/echo)
|
[](https://travis-ci.org/labstack/echo)
|
||||||
[](https://codecov.io/gh/labstack/echo)
|
[](https://codecov.io/gh/labstack/echo)
|
||||||
@ -17,7 +17,7 @@ Therefore a Go version capable of understanding /vN suffixed imports is required
|
|||||||
|
|
||||||
- 1.9.7+
|
- 1.9.7+
|
||||||
- 1.10.3+
|
- 1.10.3+
|
||||||
- 1.11+
|
- 1.14+
|
||||||
|
|
||||||
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
|
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
|
||||||
way of using Echo going forward.
|
way of using Echo going forward.
|
||||||
@ -52,7 +52,7 @@ Lower is better!
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
```go
|
```sh
|
||||||
// go get github.com/labstack/echo/{version}
|
// go get github.com/labstack/echo/{version}
|
||||||
go get github.com/labstack/echo/v4
|
go get github.com/labstack/echo/v4
|
||||||
```
|
```
|
||||||
|
1
_fixture/_fixture/README.md
Normal file
1
_fixture/_fixture/README.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
This directory is used for the static middleware test
|
@ -332,7 +332,6 @@ func TestBindbindData(t *testing.T) {
|
|||||||
|
|
||||||
func TestBindParam(t *testing.T) {
|
func TestBindParam(t *testing.T) {
|
||||||
e := New()
|
e := New()
|
||||||
*e.maxParam = 2
|
|
||||||
req := httptest.NewRequest(GET, "/", nil)
|
req := httptest.NewRequest(GET, "/", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
@ -363,7 +362,6 @@ func TestBindParam(t *testing.T) {
|
|||||||
// Bind something with param and post data payload
|
// Bind something with param and post data payload
|
||||||
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
|
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
|
||||||
e2 := New()
|
e2 := New()
|
||||||
*e2.maxParam = 2
|
|
||||||
req2 := httptest.NewRequest(POST, "/", body)
|
req2 := httptest.NewRequest(POST, "/", body)
|
||||||
req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
|
req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
|
||||||
|
|
||||||
|
11
codecov.yml
Normal file
11
codecov.yml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
coverage:
|
||||||
|
status:
|
||||||
|
project:
|
||||||
|
default:
|
||||||
|
threshold: 1%
|
||||||
|
patch:
|
||||||
|
default:
|
||||||
|
threshold: 1%
|
||||||
|
|
||||||
|
comment:
|
||||||
|
require_changes: true
|
14
context.go
14
context.go
@ -276,7 +276,11 @@ func (c *context) RealIP() string {
|
|||||||
}
|
}
|
||||||
// Fall back to legacy behavior
|
// Fall back to legacy behavior
|
||||||
if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
|
if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
|
||||||
return strings.Split(ip, ", ")[0]
|
i := strings.IndexAny(ip, ", ")
|
||||||
|
if i > 0 {
|
||||||
|
return ip[:i]
|
||||||
|
}
|
||||||
|
return ip
|
||||||
}
|
}
|
||||||
if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
|
if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
|
||||||
return ip
|
return ip
|
||||||
@ -310,6 +314,7 @@ func (c *context) ParamNames() []string {
|
|||||||
|
|
||||||
func (c *context) SetParamNames(names ...string) {
|
func (c *context) SetParamNames(names ...string) {
|
||||||
c.pnames = names
|
c.pnames = names
|
||||||
|
*c.echo.maxParam = len(names)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *context) ParamValues() []string {
|
func (c *context) ParamValues() []string {
|
||||||
@ -317,10 +322,7 @@ func (c *context) ParamValues() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *context) SetParamValues(values ...string) {
|
func (c *context) SetParamValues(values ...string) {
|
||||||
// NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times
|
c.pvalues = values
|
||||||
for i, val := range values {
|
|
||||||
c.pvalues[i] = val
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *context) QueryParam(name string) string {
|
func (c *context) QueryParam(name string) string {
|
||||||
@ -363,7 +365,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
f.Close()
|
||||||
return fh, nil
|
return fh, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,6 +72,15 @@ func BenchmarkAllocXML(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) {
|
||||||
|
c := context{request: &http.Request{
|
||||||
|
Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
|
||||||
|
}}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
c.RealIP()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) error {
|
func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) error {
|
||||||
return t.templates.ExecuteTemplate(w, name, data)
|
return t.templates.ExecuteTemplate(w, name, data)
|
||||||
}
|
}
|
||||||
@ -93,7 +102,6 @@ func (responseWriterErr) WriteHeader(statusCode int) {
|
|||||||
|
|
||||||
func TestContext(t *testing.T) {
|
func TestContext(t *testing.T) {
|
||||||
e := New()
|
e := New()
|
||||||
*e.maxParam = 1
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c := e.NewContext(req, rec).(*context)
|
c := e.NewContext(req, rec).(*context)
|
||||||
@ -472,7 +480,6 @@ func TestContextPath(t *testing.T) {
|
|||||||
|
|
||||||
func TestContextPathParam(t *testing.T) {
|
func TestContextPathParam(t *testing.T) {
|
||||||
e := New()
|
e := New()
|
||||||
*e.maxParam = 2
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
c := e.NewContext(req, nil)
|
c := e.NewContext(req, nil)
|
||||||
|
|
||||||
@ -491,7 +498,8 @@ func TestContextPathParam(t *testing.T) {
|
|||||||
|
|
||||||
func TestContextGetAndSetParam(t *testing.T) {
|
func TestContextGetAndSetParam(t *testing.T) {
|
||||||
e := New()
|
e := New()
|
||||||
*e.maxParam = 2
|
r := e.Router()
|
||||||
|
r.Add(http.MethodGet, "/:foo", func(Context) error { return nil })
|
||||||
req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
|
req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
|
||||||
c := e.NewContext(req, nil)
|
c := e.NewContext(req, nil)
|
||||||
c.SetParamNames("foo")
|
c.SetParamNames("foo")
|
||||||
@ -848,6 +856,14 @@ func TestContext_RealIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"127.0.0.1",
|
"127.0.0.1",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
&context{
|
||||||
|
request: &http.Request{
|
||||||
|
Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"127.0.0.1",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
&context{
|
&context{
|
||||||
request: &http.Request{
|
request: &http.Request{
|
||||||
|
53
echo.go
53
echo.go
@ -48,6 +48,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -230,7 +231,7 @@ const (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// Version of Echo
|
// Version of Echo
|
||||||
Version = "4.1.15"
|
Version = "4.1.17"
|
||||||
website = "https://echo.labstack.com"
|
website = "https://echo.labstack.com"
|
||||||
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
|
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
|
||||||
banner = `
|
banner = `
|
||||||
@ -361,10 +362,12 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
|
|||||||
// Issue #1426
|
// Issue #1426
|
||||||
code := he.Code
|
code := he.Code
|
||||||
message := he.Message
|
message := he.Message
|
||||||
if e.Debug {
|
if m, ok := he.Message.(string); ok {
|
||||||
message = err.Error()
|
if e.Debug {
|
||||||
} else if m, ok := message.(string); ok {
|
message = Map{"message": m, "error": err.Error()}
|
||||||
message = Map{"message": m}
|
} else {
|
||||||
|
message = Map{"message": m}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send response
|
// Send response
|
||||||
@ -479,7 +482,20 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security
|
name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security
|
||||||
|
fi, err := os.Stat(name)
|
||||||
|
if err != nil {
|
||||||
|
// The access path does not exist
|
||||||
|
return NotFoundHandler(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the request is for a directory and does not end with "/"
|
||||||
|
p = c.Request().URL.Path // path must not be empty.
|
||||||
|
if fi.IsDir() && p[len(p)-1] != '/' {
|
||||||
|
// Redirect to ends with "/"
|
||||||
|
return c.Redirect(http.StatusMovedPermanently, p+"/")
|
||||||
|
}
|
||||||
return c.File(name)
|
return c.File(name)
|
||||||
}
|
}
|
||||||
if prefix == "/" {
|
if prefix == "/" {
|
||||||
@ -504,11 +520,7 @@ func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ..
|
|||||||
name := handlerName(handler)
|
name := handlerName(handler)
|
||||||
router := e.findRouter(host)
|
router := e.findRouter(host)
|
||||||
router.Add(method, path, func(c Context) error {
|
router.Add(method, path, func(c Context) error {
|
||||||
h := handler
|
h := applyMiddleware(handler, middleware...)
|
||||||
// Chain middleware
|
|
||||||
for i := len(middleware) - 1; i >= 0; i-- {
|
|
||||||
h = middleware[i](h)
|
|
||||||
}
|
|
||||||
return h(c)
|
return h(c)
|
||||||
})
|
})
|
||||||
r := &Route{
|
r := &Route{
|
||||||
@ -602,16 +614,15 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Acquire context
|
// Acquire context
|
||||||
c := e.pool.Get().(*context)
|
c := e.pool.Get().(*context)
|
||||||
c.Reset(r, w)
|
c.Reset(r, w)
|
||||||
|
|
||||||
h := NotFoundHandler
|
h := NotFoundHandler
|
||||||
|
|
||||||
if e.premiddleware == nil {
|
if e.premiddleware == nil {
|
||||||
e.findRouter(r.Host).Find(r.Method, getPath(r), c)
|
e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c)
|
||||||
h = c.Handler()
|
h = c.Handler()
|
||||||
h = applyMiddleware(h, e.middleware...)
|
h = applyMiddleware(h, e.middleware...)
|
||||||
} else {
|
} else {
|
||||||
h = func(c Context) error {
|
h = func(c Context) error {
|
||||||
e.findRouter(r.Host).Find(r.Method, getPath(r), c)
|
e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c)
|
||||||
h := c.Handler()
|
h := c.Handler()
|
||||||
h = applyMiddleware(h, e.middleware...)
|
h = applyMiddleware(h, e.middleware...)
|
||||||
return h(c)
|
return h(c)
|
||||||
@ -783,6 +794,9 @@ func NewHTTPError(code int, message ...interface{}) *HTTPError {
|
|||||||
|
|
||||||
// Error makes it compatible with `error` interface.
|
// Error makes it compatible with `error` interface.
|
||||||
func (he *HTTPError) Error() string {
|
func (he *HTTPError) Error() string {
|
||||||
|
if he.Internal == nil {
|
||||||
|
return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
|
||||||
|
}
|
||||||
return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
|
return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -792,6 +806,11 @@ func (he *HTTPError) SetInternal(err error) *HTTPError {
|
|||||||
return he
|
return he
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unwrap satisfies the Go 1.13 error wrapper interface.
|
||||||
|
func (he *HTTPError) Unwrap() error {
|
||||||
|
return he.Internal
|
||||||
|
}
|
||||||
|
|
||||||
// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
|
// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
|
||||||
func WrapHandler(h http.Handler) HandlerFunc {
|
func WrapHandler(h http.Handler) HandlerFunc {
|
||||||
return func(c Context) error {
|
return func(c Context) error {
|
||||||
@ -814,14 +833,6 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPath(r *http.Request) string {
|
|
||||||
path := r.URL.RawPath
|
|
||||||
if path == "" {
|
|
||||||
path = r.URL.Path
|
|
||||||
}
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Echo) findRouter(host string) *Router {
|
func (e *Echo) findRouter(host string) *Router {
|
||||||
if len(e.routers) > 0 {
|
if len(e.routers) > 0 {
|
||||||
if r, ok := e.routers[host]; ok {
|
if r, ok := e.routers[host]; ok {
|
||||||
|
28
echo_go1.13_test.go
Normal file
28
echo_go1.13_test.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// +build go1.13
|
||||||
|
|
||||||
|
package echo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHTTPError_Unwrap(t *testing.T) {
|
||||||
|
t.Run("non-internal", func(t *testing.T) {
|
||||||
|
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||||
|
"code": 12,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Nil(t, errors.Unwrap(err))
|
||||||
|
})
|
||||||
|
t.Run("internal", func(t *testing.T) {
|
||||||
|
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||||
|
"code": 12,
|
||||||
|
})
|
||||||
|
err.SetInternal(errors.New("internal error"))
|
||||||
|
assert.Equal(t, "internal error", errors.Unwrap(err).Error())
|
||||||
|
})
|
||||||
|
}
|
72
echo_test.go
72
echo_test.go
@ -76,9 +76,17 @@ func TestEchoStatic(t *testing.T) {
|
|||||||
|
|
||||||
// Directory
|
// Directory
|
||||||
e.Static("/images", "_fixture/images")
|
e.Static("/images", "_fixture/images")
|
||||||
c, _ = request(http.MethodGet, "/images", e)
|
c, _ = request(http.MethodGet, "/images/", e)
|
||||||
assert.Equal(http.StatusNotFound, c)
|
assert.Equal(http.StatusNotFound, c)
|
||||||
|
|
||||||
|
// Directory Redirect
|
||||||
|
e.Static("/", "_fixture")
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/folder", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(http.StatusMovedPermanently, rec.Code)
|
||||||
|
assert.Equal("/folder/", rec.HeaderMap["Location"][0])
|
||||||
|
|
||||||
// Directory with index.html
|
// Directory with index.html
|
||||||
e.Static("/", "_fixture")
|
e.Static("/", "_fixture")
|
||||||
c, r := request(http.MethodGet, "/", e)
|
c, r := request(http.MethodGet, "/", e)
|
||||||
@ -86,9 +94,10 @@ func TestEchoStatic(t *testing.T) {
|
|||||||
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
|
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
|
||||||
|
|
||||||
// Sub-directory with index.html
|
// Sub-directory with index.html
|
||||||
c, r = request(http.MethodGet, "/folder", e)
|
c, r = request(http.MethodGet, "/folder/", e)
|
||||||
assert.Equal(http.StatusOK, c)
|
assert.Equal(http.StatusOK, c)
|
||||||
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
|
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEchoFile(t *testing.T) {
|
func TestEchoFile(t *testing.T) {
|
||||||
@ -543,10 +552,63 @@ func request(method, path string, e *Echo) (int, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPError(t *testing.T) {
|
func TestHTTPError(t *testing.T) {
|
||||||
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
t.Run("non-internal", func(t *testing.T) {
|
||||||
"code": 12,
|
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||||
|
"code": 12,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, "code=400, message=map[code:12]", err.Error())
|
||||||
})
|
})
|
||||||
assert.Equal(t, "code=400, message=map[code:12], internal=<nil>", err.Error())
|
t.Run("internal", func(t *testing.T) {
|
||||||
|
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||||
|
"code": 12,
|
||||||
|
})
|
||||||
|
err.SetInternal(errors.New("internal error"))
|
||||||
|
assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultHTTPErrorHandler(t *testing.T) {
|
||||||
|
e := New()
|
||||||
|
e.Debug = true
|
||||||
|
e.Any("/plain", func(c Context) error {
|
||||||
|
return errors.New("An error occurred")
|
||||||
|
})
|
||||||
|
e.Any("/badrequest", func(c Context) error {
|
||||||
|
return NewHTTPError(http.StatusBadRequest, "Invalid request")
|
||||||
|
})
|
||||||
|
e.Any("/servererror", func(c Context) error {
|
||||||
|
return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{
|
||||||
|
"code": 33,
|
||||||
|
"message": "Something bad happened",
|
||||||
|
"error": "stackinfo",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
// With Debug=true plain response contains error message
|
||||||
|
c, b := request(http.MethodGet, "/plain", e)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, c)
|
||||||
|
assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b)
|
||||||
|
// and special handling for HTTPError
|
||||||
|
c, b = request(http.MethodGet, "/badrequest", e)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, c)
|
||||||
|
assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b)
|
||||||
|
// complex errors are serialized to pretty JSON
|
||||||
|
c, b = request(http.MethodGet, "/servererror", e)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, c)
|
||||||
|
assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b)
|
||||||
|
|
||||||
|
e.Debug = false
|
||||||
|
// With Debug=false the error response is shortened
|
||||||
|
c, b = request(http.MethodGet, "/plain", e)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, c)
|
||||||
|
assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b)
|
||||||
|
c, b = request(http.MethodGet, "/badrequest", e)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, c)
|
||||||
|
assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b)
|
||||||
|
// No difference for error response with non plain string errors
|
||||||
|
c, b = request(http.MethodGet, "/servererror", e)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, c)
|
||||||
|
assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEchoClose(t *testing.T) {
|
func TestEchoClose(t *testing.T) {
|
||||||
|
14
go.mod
14
go.mod
@ -1,13 +1,15 @@
|
|||||||
module github.com/labstack/echo/v4
|
module github.com/labstack/echo/v4
|
||||||
|
|
||||||
go 1.14
|
go 1.15
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||||
github.com/labstack/gommon v0.3.0
|
github.com/labstack/gommon v0.3.0
|
||||||
github.com/mattn/go-colorable v0.1.6 // indirect
|
github.com/mattn/go-colorable v0.1.7 // indirect
|
||||||
github.com/stretchr/testify v1.4.0
|
github.com/stretchr/testify v1.4.0
|
||||||
github.com/valyala/fasttemplate v1.1.0 // indirect
|
github.com/valyala/fasttemplate v1.2.1
|
||||||
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d
|
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a
|
||||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b
|
golang.org/x/net v0.0.0-20200822124328-c89045814202
|
||||||
golang.org/x/text v0.3.2 // indirect
|
golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect
|
||||||
|
golang.org/x/text v0.3.3 // indirect
|
||||||
)
|
)
|
||||||
|
27
go.sum
27
go.sum
@ -1,11 +1,13 @@
|
|||||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||||
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||||
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
|
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
|
||||||
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
|
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
|
||||||
github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU=
|
github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU=
|
||||||
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||||
github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE=
|
github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw=
|
||||||
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||||
github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg=
|
github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg=
|
||||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||||
@ -20,14 +22,15 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
|
|||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
github.com/valyala/fasttemplate v1.0.1 h1:tY9CJiPnMXf1ERmG2EyK7gNUd+c6RKGD0IfU8WdUSz8=
|
github.com/valyala/fasttemplate v1.0.1 h1:tY9CJiPnMXf1ERmG2EyK7gNUd+c6RKGD0IfU8WdUSz8=
|
||||||
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
||||||
github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4=
|
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
|
||||||
github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM=
|
||||||
|
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8=
|
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
|
||||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@ -36,11 +39,15 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8=
|
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8=
|
||||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 h1:DvY3Zkh7KabQE/kfzMvYvKirSiguP9Q/veMtkYyf0o8=
|
||||||
|
golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
2
group.go
2
group.go
@ -109,7 +109,7 @@ func (g *Group) Static(prefix, root string) {
|
|||||||
|
|
||||||
// File implements `Echo#File()` for sub-routes within the Group.
|
// File implements `Echo#File()` for sub-routes within the Group.
|
||||||
func (g *Group) File(path, file string) {
|
func (g *Group) File(path, file string) {
|
||||||
g.file(g.prefix+path, file, g.GET)
|
g.file(path, file, g.GET)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add implements `Echo#Add()` for sub-routes within the Group.
|
// Add implements `Echo#Add()` for sub-routes within the Group.
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package echo
|
package echo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -26,6 +28,19 @@ func TestGroup(t *testing.T) {
|
|||||||
g.File("/walle", "_fixture/images//walle.png")
|
g.File("/walle", "_fixture/images//walle.png")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGroupFile(t *testing.T) {
|
||||||
|
e := New()
|
||||||
|
g := e.Group("/group")
|
||||||
|
g.File("/walle", "_fixture/images/walle.png")
|
||||||
|
expectedData, err := ioutil.ReadFile("_fixture/images/walle.png")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/group/walle", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, expectedData, rec.Body.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
func TestGroupRouteMiddleware(t *testing.T) {
|
func TestGroupRouteMiddleware(t *testing.T) {
|
||||||
// Ensure middleware slices are not re-used
|
// Ensure middleware slices are not re-used
|
||||||
e := New()
|
e := New()
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
)
|
)
|
||||||
@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
|||||||
config.Level = DefaultGzipConfig.Level
|
config.Level = DefaultGzipConfig.Level
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pool := gzipPool(config)
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
if config.Skipper(c) {
|
if config.Skipper(c) {
|
||||||
@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
|||||||
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
|
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
|
||||||
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
|
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
|
||||||
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
||||||
rw := res.Writer
|
i := pool.Get()
|
||||||
w, err := gzip.NewWriterLevel(rw, config.Level)
|
w, ok := i.(*gzip.Writer)
|
||||||
if err != nil {
|
if !ok {
|
||||||
return err
|
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
|
||||||
}
|
}
|
||||||
|
rw := res.Writer
|
||||||
|
w.Reset(rw)
|
||||||
defer func() {
|
defer func() {
|
||||||
if res.Size == 0 {
|
if res.Size == 0 {
|
||||||
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
|
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
|
||||||
@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
|||||||
w.Reset(ioutil.Discard)
|
w.Reset(ioutil.Discard)
|
||||||
}
|
}
|
||||||
w.Close()
|
w.Close()
|
||||||
|
pool.Put(w)
|
||||||
}()
|
}()
|
||||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
|
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
|
||||||
res.Writer = grw
|
res.Writer = grw
|
||||||
@ -119,3 +125,22 @@ func (w *gzipResponseWriter) Flush() {
|
|||||||
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
return w.ResponseWriter.(http.Hijacker).Hijack()
|
return w.ResponseWriter.(http.Hijacker).Hijack()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||||
|
if p, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||||
|
return p.Push(target, opts)
|
||||||
|
}
|
||||||
|
return http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func gzipPool(config GzipConfig) sync.Pool {
|
||||||
|
return sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) {
|
|||||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
// Invalid level
|
||||||
|
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
c.Response().Write([]byte("test"))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Contains(t, rec.Body.String(), "gzip")
|
||||||
|
}
|
||||||
|
|
||||||
// Issue #806
|
// Issue #806
|
||||||
func TestGzipWithStatic(t *testing.T) {
|
func TestGzipWithStatic(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkGzip(b *testing.B) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||||
|
|
||||||
|
h := Gzip()(func(c echo.Context) error {
|
||||||
|
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Gzip
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
h(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -18,6 +19,13 @@ type (
|
|||||||
// Optional. Default value []string{"*"}.
|
// Optional. Default value []string{"*"}.
|
||||||
AllowOrigins []string `yaml:"allow_origins"`
|
AllowOrigins []string `yaml:"allow_origins"`
|
||||||
|
|
||||||
|
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||||
|
// origin as an argument and returns true if allowed or false otherwise. If
|
||||||
|
// an error is returned, it is returned by the handler. If this option is
|
||||||
|
// set, AllowOrigins is ignored.
|
||||||
|
// Optional.
|
||||||
|
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
|
||||||
|
|
||||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||||
// This is used in response to a preflight request.
|
// This is used in response to a preflight request.
|
||||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||||
@ -76,6 +84,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|||||||
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowOriginPatterns := []string{}
|
||||||
|
for _, origin := range config.AllowOrigins {
|
||||||
|
pattern := regexp.QuoteMeta(origin)
|
||||||
|
pattern = strings.Replace(pattern, "\\*", ".*", -1)
|
||||||
|
pattern = strings.Replace(pattern, "\\?", ".", -1)
|
||||||
|
pattern = "^" + pattern + "$"
|
||||||
|
allowOriginPatterns = append(allowOriginPatterns, pattern)
|
||||||
|
}
|
||||||
|
|
||||||
allowMethods := strings.Join(config.AllowMethods, ",")
|
allowMethods := strings.Join(config.AllowMethods, ",")
|
||||||
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
||||||
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
||||||
@ -92,25 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|||||||
origin := req.Header.Get(echo.HeaderOrigin)
|
origin := req.Header.Get(echo.HeaderOrigin)
|
||||||
allowOrigin := ""
|
allowOrigin := ""
|
||||||
|
|
||||||
// Check allowed origins
|
preflight := req.Method == http.MethodOptions
|
||||||
for _, o := range config.AllowOrigins {
|
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||||
if o == "*" && config.AllowCredentials {
|
|
||||||
allowOrigin = origin
|
// No Origin provided
|
||||||
break
|
if origin == "" {
|
||||||
|
if !preflight {
|
||||||
|
return next(c)
|
||||||
}
|
}
|
||||||
if o == "*" || o == origin {
|
return c.NoContent(http.StatusNoContent)
|
||||||
allowOrigin = o
|
}
|
||||||
break
|
|
||||||
|
if config.AllowOriginFunc != nil {
|
||||||
|
allowed, err := config.AllowOriginFunc(origin)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if matchSubdomain(origin, o) {
|
if allowed {
|
||||||
allowOrigin = origin
|
allowOrigin = origin
|
||||||
break
|
}
|
||||||
|
} else {
|
||||||
|
// Check allowed origins
|
||||||
|
for _, o := range config.AllowOrigins {
|
||||||
|
if o == "*" && config.AllowCredentials {
|
||||||
|
allowOrigin = origin
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if o == "*" || o == origin {
|
||||||
|
allowOrigin = o
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if matchSubdomain(origin, o) {
|
||||||
|
allowOrigin = origin
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check allowed origin patterns
|
||||||
|
for _, re := range allowOriginPatterns {
|
||||||
|
if allowOrigin == "" {
|
||||||
|
didx := strings.Index(origin, "://")
|
||||||
|
if didx == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domAuth := origin[didx+3:]
|
||||||
|
// to avoid regex cost by invalid long domain
|
||||||
|
if len(domAuth) > 253 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if match, _ := regexp.MatchString(re, origin); match {
|
||||||
|
allowOrigin = origin
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Origin not allowed
|
||||||
|
if allowOrigin == "" {
|
||||||
|
if !preflight {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
return c.NoContent(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
// Simple request
|
// Simple request
|
||||||
if req.Method != http.MethodOptions {
|
if !preflight {
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
|
||||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
if config.AllowCredentials {
|
if config.AllowCredentials {
|
||||||
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||||
@ -122,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Preflight request
|
// Preflight request
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
||||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -17,19 +18,31 @@ func TestCORS(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
h := CORS()(echo.NotFoundHandler)
|
h := CORS()(echo.NotFoundHandler)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||||
h(c)
|
h(c)
|
||||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
|
||||||
|
// Wildcard AllowedOrigin with no Origin header in request
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
c = e.NewContext(req, rec)
|
||||||
|
h = CORS()(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
|
|
||||||
// Allow origins
|
// Allow origins
|
||||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
c = e.NewContext(req, rec)
|
c = e.NewContext(req, rec)
|
||||||
h = CORSWithConfig(CORSConfig{
|
h = CORSWithConfig(CORSConfig{
|
||||||
AllowOrigins: []string{"localhost"},
|
AllowOrigins: []string{"localhost"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAge: 3600,
|
||||||
})(echo.NotFoundHandler)
|
})(echo.NotFoundHandler)
|
||||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||||
h(c)
|
h(c)
|
||||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||||
|
|
||||||
// Preflight request
|
// Preflight request
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
@ -67,6 +80,22 @@ func TestCORS(t *testing.T) {
|
|||||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||||
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
||||||
|
|
||||||
|
// Preflight request with Access-Control-Request-Headers
|
||||||
|
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
c = e.NewContext(req, rec)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||||
|
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||||
|
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
|
||||||
|
cors = CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"*"},
|
||||||
|
})
|
||||||
|
h = cors(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
|
||||||
|
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||||
|
|
||||||
// Preflight request with `AllowOrigins` which allow all subdomains with *
|
// Preflight request with `AllowOrigins` which allow all subdomains with *
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
@ -83,3 +112,298 @@ func TestCORS(t *testing.T) {
|
|||||||
h(c)
|
h(c)
|
||||||
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_allowOriginScheme(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
domain, pattern string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
pattern: "http://example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "https://example.com",
|
||||||
|
pattern: "https://example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
pattern: "https://example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "https://example.com",
|
||||||
|
pattern: "http://example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||||
|
cors := CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{tt.pattern},
|
||||||
|
})
|
||||||
|
h := cors(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
|
||||||
|
if tt.expected {
|
||||||
|
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
} else {
|
||||||
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_allowOriginSubdomain(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
domain, pattern string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
domain: "http://aaa.example.com",
|
||||||
|
pattern: "http://*.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://bbb.aaa.example.com",
|
||||||
|
pattern: "http://*.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://bbb.aaa.example.com",
|
||||||
|
pattern: "http://*.aaa.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://aaa.example.com:8080",
|
||||||
|
pattern: "http://*.example.com:8080",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
domain: "http://fuga.hoge.com",
|
||||||
|
pattern: "http://*.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://ccc.bbb.example.com",
|
||||||
|
pattern: "http://*.aaa.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||||
|
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||||
|
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||||
|
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
|
||||||
|
pattern: "http://*.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
|
||||||
|
pattern: "http://*.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://ccc.bbb.example.com",
|
||||||
|
pattern: "http://example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "https://prod-preview--aaa.bbb.com",
|
||||||
|
pattern: "https://*--aaa.bbb.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://ccc.bbb.example.com",
|
||||||
|
pattern: "http://*.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://ccc.bbb.example.com",
|
||||||
|
pattern: "http://foo.[a-z]*.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||||
|
cors := CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{tt.pattern},
|
||||||
|
})
|
||||||
|
h := cors(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
|
||||||
|
if tt.expected {
|
||||||
|
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
} else {
|
||||||
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCorsHeaders(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
domain, allowedOrigin, method string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://bar.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://bar.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest(tt.method, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
if tt.domain != "" {
|
||||||
|
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||||
|
}
|
||||||
|
cors := CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{tt.allowedOrigin},
|
||||||
|
//AllowCredentials: true,
|
||||||
|
//MaxAge: 3600,
|
||||||
|
})
|
||||||
|
h := cors(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
|
||||||
|
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||||
|
|
||||||
|
expectedAllowOrigin := ""
|
||||||
|
if tt.allowedOrigin == "*" {
|
||||||
|
expectedAllowOrigin = "*"
|
||||||
|
} else {
|
||||||
|
expectedAllowOrigin = tt.domain
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case tt.expected && tt.method == http.MethodOptions:
|
||||||
|
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
|
||||||
|
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
|
||||||
|
case tt.expected && tt.method == http.MethodGet:
|
||||||
|
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||||
|
default:
|
||||||
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
|
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.method == http.MethodOptions {
|
||||||
|
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_allowOriginFunc(t *testing.T) {
|
||||||
|
returnTrue := func(origin string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
returnFalse := func(origin string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
returnError := func(origin string) (bool, error) {
|
||||||
|
return true, errors.New("this is a test error")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowOriginFuncs := []func(origin string) (bool, error){
|
||||||
|
returnTrue,
|
||||||
|
returnFalse,
|
||||||
|
returnError,
|
||||||
|
}
|
||||||
|
|
||||||
|
const origin = "http://example.com"
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
for _, allowOriginFunc := range allowOriginFuncs {
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, origin)
|
||||||
|
cors := CORSWithConfig(CORSConfig{
|
||||||
|
AllowOriginFunc: allowOriginFunc,
|
||||||
|
})
|
||||||
|
h := cors(echo.NotFoundHandler)
|
||||||
|
err := h(c)
|
||||||
|
|
||||||
|
expected, expectedErr := allowOriginFunc(origin)
|
||||||
|
if expectedErr != nil {
|
||||||
|
assert.Equal(t, expectedErr, err)
|
||||||
|
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected {
|
||||||
|
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
58
middleware/decompress.go
Normal file
58
middleware/decompress.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// DecompressConfig defines the config for Decompress middleware.
|
||||||
|
DecompressConfig struct {
|
||||||
|
// Skipper defines a function to skip middleware.
|
||||||
|
Skipper Skipper
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
|
||||||
|
const GZIPEncoding string = "gzip"
|
||||||
|
|
||||||
|
var (
|
||||||
|
//DefaultDecompressConfig defines the config for decompress middleware
|
||||||
|
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
|
||||||
|
)
|
||||||
|
|
||||||
|
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
|
||||||
|
func Decompress() echo.MiddlewareFunc {
|
||||||
|
return DecompressWithConfig(DefaultDecompressConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
|
||||||
|
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
|
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
if config.Skipper(c) {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
|
||||||
|
case GZIPEncoding:
|
||||||
|
gr, err := gzip.NewReader(c.Request().Body)
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF { //ignore if body is empty
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer gr.Close()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
io.Copy(&buf, gr)
|
||||||
|
r := ioutil.NopCloser(&buf)
|
||||||
|
defer r.Close()
|
||||||
|
c.Request().Body = r
|
||||||
|
}
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
148
middleware/decompress_test.go
Normal file
148
middleware/decompress_test.go
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDecompress(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
// Skip if no Content-Encoding header
|
||||||
|
h := Decompress()(func(c echo.Context) error {
|
||||||
|
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
h(c)
|
||||||
|
|
||||||
|
assert := assert.New(t)
|
||||||
|
assert.Equal("test", rec.Body.String())
|
||||||
|
|
||||||
|
// Decompress
|
||||||
|
body := `{"name": "echo"}`
|
||||||
|
gz, _ := gzipString(body)
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||||
|
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
c = e.NewContext(req, rec)
|
||||||
|
h(c)
|
||||||
|
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||||
|
b, err := ioutil.ReadAll(req.Body)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal(body, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
body := `{"name":"echo"}`
|
||||||
|
gz, _ := gzipString(body)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||||
|
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.NewContext(req, rec)
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||||
|
b, err := ioutil.ReadAll(req.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, b, body)
|
||||||
|
assert.Equal(t, b, gz)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecompressNoContent(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
h := Decompress()(func(c echo.Context) error {
|
||||||
|
return c.NoContent(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
if assert.NoError(t, h(c)) {
|
||||||
|
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||||
|
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
|
||||||
|
assert.Equal(t, 0, len(rec.Body.Bytes()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecompressErrorReturned(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
e.Use(Decompress())
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
return echo.ErrNotFound
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||||
|
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecompressSkipper(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
e.Use(DecompressWithConfig(DecompressConfig{
|
||||||
|
Skipper: func(c echo.Context) bool {
|
||||||
|
return c.Request().URL.Path == "/skip"
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
body := `{"name": "echo"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
|
||||||
|
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
|
||||||
|
reqBody, err := ioutil.ReadAll(c.Request().Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, body, string(reqBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecompress(b *testing.B) {
|
||||||
|
e := echo.New()
|
||||||
|
body := `{"name": "echo"}`
|
||||||
|
gz, _ := gzipString(body)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||||
|
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||||
|
|
||||||
|
h := Decompress()(func(c echo.Context) error {
|
||||||
|
c.Response().Write([]byte(body)) // For Content-Type sniffing
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Decompress
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
h(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func gzipString(body string) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
|
||||||
|
_, err := gz.Write([]byte(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gz.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
@ -86,6 +86,7 @@ const (
|
|||||||
// Errors
|
// Errors
|
||||||
var (
|
var (
|
||||||
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
|
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
|
||||||
|
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -213,8 +214,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
|||||||
return config.ErrorHandlerWithContext(err, c)
|
return config.ErrorHandlerWithContext(err, c)
|
||||||
}
|
}
|
||||||
return &echo.HTTPError{
|
return &echo.HTTPError{
|
||||||
Code: http.StatusUnauthorized,
|
Code: ErrJWTInvalid.Code,
|
||||||
Message: "invalid or expired jwt",
|
Message: ErrJWTInvalid.Message,
|
||||||
Internal: err,
|
Internal: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -60,8 +60,6 @@ func TestJWTRace(t *testing.T) {
|
|||||||
|
|
||||||
func TestJWT(t *testing.T) {
|
func TestJWT(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
r := e.Router()
|
|
||||||
r.Add("GET", "/:jwt", func(echo.Context) error { return nil })
|
|
||||||
handler := func(c echo.Context) error {
|
handler := func(c echo.Context) error {
|
||||||
return c.String(http.StatusOK, "test")
|
return c.String(http.StatusOK, "test")
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -32,6 +33,31 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
|||||||
return strings.NewReplacer(replace...)
|
return strings.NewReplacer(replace...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
|
||||||
|
// Initialize
|
||||||
|
rulesRegex := map[*regexp.Regexp]string{}
|
||||||
|
for k, v := range rewrite {
|
||||||
|
k = regexp.QuoteMeta(k)
|
||||||
|
k = strings.Replace(k, `\*`, "(.*)", -1)
|
||||||
|
if strings.HasPrefix(k, `\^`) {
|
||||||
|
k = strings.Replace(k, `\^`, "^", -1)
|
||||||
|
}
|
||||||
|
k = k + "$"
|
||||||
|
rulesRegex[regexp.MustCompile(k)] = v
|
||||||
|
}
|
||||||
|
return rulesRegex
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
|
||||||
|
for k, v := range rewriteRegex {
|
||||||
|
replacerRawPath := captureTokens(k, req.URL.EscapedPath())
|
||||||
|
if replacerRawPath != nil {
|
||||||
|
replacerPath := captureTokens(k, req.URL.Path)
|
||||||
|
req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultSkipper returns false which processes the middleware.
|
// DefaultSkipper returns false which processes the middleware.
|
||||||
func DefaultSkipper(echo.Context) bool {
|
func DefaultSkipper(echo.Context) bool {
|
||||||
return false
|
return false
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -45,6 +44,9 @@ type (
|
|||||||
// Examples: If custom TLS certificates are required.
|
// Examples: If custom TLS certificates are required.
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
|
|
||||||
|
// ModifyResponse defines function to modify response from ProxyTarget.
|
||||||
|
ModifyResponse func(*http.Response) error
|
||||||
|
|
||||||
rewriteRegex map[*regexp.Regexp]string
|
rewriteRegex map[*regexp.Regexp]string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
|||||||
if config.Balancer == nil {
|
if config.Balancer == nil {
|
||||||
panic("echo: proxy middleware requires balancer")
|
panic("echo: proxy middleware requires balancer")
|
||||||
}
|
}
|
||||||
config.rewriteRegex = map[*regexp.Regexp]string{}
|
|
||||||
|
|
||||||
// Initialize
|
config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
|
||||||
for k, v := range config.Rewrite {
|
|
||||||
k = strings.Replace(k, "*", "(\\S*)", -1)
|
|
||||||
config.rewriteRegex[regexp.MustCompile(k)] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) (err error) {
|
return func(c echo.Context) (err error) {
|
||||||
@ -222,13 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
|||||||
tgt := config.Balancer.Next(c)
|
tgt := config.Balancer.Next(c)
|
||||||
c.Set(config.ContextKey, tgt)
|
c.Set(config.ContextKey, tgt)
|
||||||
|
|
||||||
// Rewrite
|
// Set rewrite path and raw path
|
||||||
for k, v := range config.rewriteRegex {
|
rewritePath(config.rewriteRegex, req)
|
||||||
replacer := captureTokens(k, req.URL.Path)
|
|
||||||
if replacer != nil {
|
|
||||||
req.URL.Path = replacer.Replace(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fix header
|
// Fix header
|
||||||
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
|
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
|
||||||
@ -259,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,5 +20,6 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle
|
|||||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)))
|
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)))
|
||||||
}
|
}
|
||||||
proxy.Transport = config.Transport
|
proxy.Transport = config.Transport
|
||||||
|
proxy.ModifyResponse = config.ModifyResponse
|
||||||
return proxy
|
return proxy
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -12,6 +14,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//Assert expected with url.EscapedPath method to obtain the path.
|
||||||
func TestProxy(t *testing.T) {
|
func TestProxy(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -91,19 +94,49 @@ func TestProxy(t *testing.T) {
|
|||||||
"/users/*/orders/*": "/user/$1/order/$2",
|
"/users/*/orders/*": "/user/$1/order/$2",
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
req.URL.Path = "/api/users"
|
req.URL, _ = url.Parse("/api/users")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/users", req.URL.Path)
|
assert.Equal(t, "/users", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/js/main.js"
|
|
||||||
e.ServeHTTP(rec, req)
|
|
||||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.Path)
|
|
||||||
req.URL.Path = "/old"
|
|
||||||
e.ServeHTTP(rec, req)
|
|
||||||
assert.Equal(t, "/new", req.URL.Path)
|
|
||||||
req.URL.Path = "/users/jack/orders/1"
|
|
||||||
e.ServeHTTP(rec, req)
|
|
||||||
assert.Equal(t, "/user/jack/order/1", req.URL.Path)
|
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
req.URL, _ = url.Parse( "/js/main.js")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
req.URL, _ = url.Parse("/old")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
req.URL, _ = url.Parse( "/users/jack/orders/1")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
req.URL, _ = url.Parse("/api/new users")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
|
||||||
|
// ModifyResponse
|
||||||
|
e = echo.New()
|
||||||
|
e.Use(ProxyWithConfig(ProxyConfig{
|
||||||
|
Balancer: rrb,
|
||||||
|
ModifyResponse: func(res *http.Response) error {
|
||||||
|
res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified")))
|
||||||
|
res.Header.Set("X-Modified", "1")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "modified", rec.Body.String())
|
||||||
|
assert.Equal(t, "1", rec.Header().Get("X-Modified"))
|
||||||
|
|
||||||
// ProxyTarget is set in context
|
// ProxyTarget is set in context
|
||||||
contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
|
contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/labstack/gommon/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -25,6 +26,10 @@ type (
|
|||||||
// DisablePrintStack disables printing stack trace.
|
// DisablePrintStack disables printing stack trace.
|
||||||
// Optional. Default value as false.
|
// Optional. Default value as false.
|
||||||
DisablePrintStack bool `yaml:"disable_print_stack"`
|
DisablePrintStack bool `yaml:"disable_print_stack"`
|
||||||
|
|
||||||
|
// LogLevel is log level to printing stack trace.
|
||||||
|
// Optional. Default value 0 (Print).
|
||||||
|
LogLevel log.Lvl
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -35,6 +40,7 @@ var (
|
|||||||
StackSize: 4 << 10, // 4 KB
|
StackSize: 4 << 10, // 4 KB
|
||||||
DisableStackAll: false,
|
DisableStackAll: false,
|
||||||
DisablePrintStack: false,
|
DisablePrintStack: false,
|
||||||
|
LogLevel: 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,7 +76,21 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
|
|||||||
stack := make([]byte, config.StackSize)
|
stack := make([]byte, config.StackSize)
|
||||||
length := runtime.Stack(stack, !config.DisableStackAll)
|
length := runtime.Stack(stack, !config.DisableStackAll)
|
||||||
if !config.DisablePrintStack {
|
if !config.DisablePrintStack {
|
||||||
c.Logger().Printf("[PANIC RECOVER] %v %s\n", err, stack[:length])
|
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
|
||||||
|
switch config.LogLevel {
|
||||||
|
case log.DEBUG:
|
||||||
|
c.Logger().Debug(msg)
|
||||||
|
case log.INFO:
|
||||||
|
c.Logger().Info(msg)
|
||||||
|
case log.WARN:
|
||||||
|
c.Logger().Warn(msg)
|
||||||
|
case log.ERROR:
|
||||||
|
c.Logger().Error(msg)
|
||||||
|
case log.OFF:
|
||||||
|
// None.
|
||||||
|
default:
|
||||||
|
c.Logger().Print(msg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.Error(err)
|
c.Error(err)
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,13 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/labstack/gommon/log"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -24,3 +26,58 @@ func TestRecover(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
assert.Contains(t, buf.String(), "PANIC RECOVER")
|
assert.Contains(t, buf.String(), "PANIC RECOVER")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRecoverWithConfig_LogLevel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
logLevel log.Lvl
|
||||||
|
levelName string
|
||||||
|
}{{
|
||||||
|
logLevel: log.DEBUG,
|
||||||
|
levelName: "DEBUG",
|
||||||
|
}, {
|
||||||
|
logLevel: log.INFO,
|
||||||
|
levelName: "INFO",
|
||||||
|
}, {
|
||||||
|
logLevel: log.WARN,
|
||||||
|
levelName: "WARN",
|
||||||
|
}, {
|
||||||
|
logLevel: log.ERROR,
|
||||||
|
levelName: "ERROR",
|
||||||
|
}, {
|
||||||
|
logLevel: log.OFF,
|
||||||
|
levelName: "OFF",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.levelName, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
e.Logger.SetLevel(log.DEBUG)
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
e.Logger.SetOutput(buf)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
config := DefaultRecoverConfig
|
||||||
|
config.LogLevel = tt.logLevel
|
||||||
|
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
|
||||||
|
panic("test")
|
||||||
|
}))
|
||||||
|
|
||||||
|
h(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if tt.logLevel == log.OFF {
|
||||||
|
assert.Empty(t, output)
|
||||||
|
} else {
|
||||||
|
assert.Contains(t, output, "PANIC RECOVER")
|
||||||
|
assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -53,14 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
|||||||
if config.Skipper == nil {
|
if config.Skipper == nil {
|
||||||
config.Skipper = DefaultBodyDumpConfig.Skipper
|
config.Skipper = DefaultBodyDumpConfig.Skipper
|
||||||
}
|
}
|
||||||
config.rulesRegex = map[*regexp.Regexp]string{}
|
|
||||||
|
|
||||||
// Initialize
|
config.rulesRegex = rewriteRulesRegex(config.Rules)
|
||||||
for k, v := range config.Rules {
|
|
||||||
k = strings.Replace(k, "*", "(.*)", -1)
|
|
||||||
k = k + "$"
|
|
||||||
config.rulesRegex[regexp.MustCompile(k)] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) (err error) {
|
return func(c echo.Context) (err error) {
|
||||||
@ -69,15 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := c.Request()
|
req := c.Request()
|
||||||
|
// Set rewrite path and raw path
|
||||||
// Rewrite
|
rewritePath(config.rulesRegex, req)
|
||||||
for k, v := range config.rulesRegex {
|
|
||||||
replacer := captureTokens(k, req.URL.Path)
|
|
||||||
if replacer != nil {
|
|
||||||
req.URL.Path = replacer.Replace(v)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,12 +4,14 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//Assert expected with url.EscapedPath method to obtain the path.
|
||||||
func TestRewrite(t *testing.T) {
|
func TestRewrite(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
e.Use(RewriteWithConfig(RewriteConfig{
|
e.Use(RewriteWithConfig(RewriteConfig{
|
||||||
@ -22,21 +24,28 @@ func TestRewrite(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req.URL.Path = "/api/users"
|
req.URL, _ = url.Parse("/api/users")
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/users", req.URL.Path)
|
assert.Equal(t, "/users", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/js/main.js"
|
req.URL, _ = url.Parse("/js/main.js")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.Path)
|
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/old"
|
req.URL, _ = url.Parse("/old")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/new", req.URL.Path)
|
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/users/jack/orders/1"
|
req.URL, _ = url.Parse("/users/jack/orders/1")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/user/jack/order/1", req.URL.Path)
|
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/api/new users"
|
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/new users", req.URL.Path)
|
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
||||||
|
req.URL, _ = url.Parse("/api/new users")
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Issue #1086
|
// Issue #1086
|
||||||
@ -45,22 +54,21 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
|
|||||||
r := e.Router()
|
r := e.Router()
|
||||||
|
|
||||||
// Rewrite old url to new one
|
// Rewrite old url to new one
|
||||||
e.Pre(RewriteWithConfig(RewriteConfig{
|
e.Pre(Rewrite(map[string]string{
|
||||||
Rules: map[string]string{
|
|
||||||
"/old": "/new",
|
"/old": "/new",
|
||||||
},
|
},
|
||||||
}))
|
))
|
||||||
|
|
||||||
// Route
|
// Route
|
||||||
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
|
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
|
||||||
return c.NoContent(200)
|
return c.NoContent(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/old", nil)
|
req := httptest.NewRequest(http.MethodGet, "/old", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/new", req.URL.Path)
|
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||||
assert.Equal(t, 200, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Issue #1143
|
// Issue #1143
|
||||||
@ -76,21 +84,48 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
|
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
|
||||||
return c.String(200, "hosts")
|
return c.String(http.StatusOK, "hosts")
|
||||||
})
|
})
|
||||||
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
|
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
|
||||||
return c.String(200, "eng")
|
return c.String(http.StatusOK, "eng")
|
||||||
})
|
})
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/api/v1/hosts/test", req.URL.Path)
|
assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath())
|
||||||
assert.Equal(t, 200, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
defer rec.Result().Body.Close()
|
defer rec.Result().Body.Close()
|
||||||
bodyBytes, _ := ioutil.ReadAll(rec.Result().Body)
|
bodyBytes, _ := ioutil.ReadAll(rec.Result().Body)
|
||||||
assert.Equal(t, "hosts", string(bodyBytes))
|
assert.Equal(t, "hosts", string(bodyBytes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Issue #1573
|
||||||
|
func TestEchoRewriteWithCaret(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
e.Pre(RewriteWithConfig(RewriteConfig{
|
||||||
|
Rules: map[string]string{
|
||||||
|
"^/abc/*": "/v1/abc/$1",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
var req *http.Request
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/abc/test", nil)
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/v1/abc/test", req.URL.Path)
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil)
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/v1/abc/test", req.URL.Path)
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil)
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, "/v2/abc/test", req.URL.Path)
|
||||||
|
}
|
||||||
|
@ -36,6 +36,12 @@ type (
|
|||||||
// Enable directory browsing.
|
// Enable directory browsing.
|
||||||
// Optional. Default value false.
|
// Optional. Default value false.
|
||||||
Browse bool `yaml:"browse"`
|
Browse bool `yaml:"browse"`
|
||||||
|
|
||||||
|
// Enable ignoring of the base of the URL path.
|
||||||
|
// Example: when assigning a static middleware to a non root path group,
|
||||||
|
// the filesystem path is not doubled
|
||||||
|
// Optional. Default value false.
|
||||||
|
IgnoreBase bool `yaml:"ignoreBase"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -163,6 +169,15 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
|
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
|
||||||
|
|
||||||
|
if config.IgnoreBase {
|
||||||
|
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
|
||||||
|
baseURLPath := path.Base(p)
|
||||||
|
if baseURLPath == routePath {
|
||||||
|
i := strings.LastIndex(name, routePath)
|
||||||
|
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fi, err := os.Stat(name)
|
fi, err := os.Stat(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
|
@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
|
|||||||
assert.Equal(http.StatusOK, rec.Code)
|
assert.Equal(http.StatusOK, rec.Code)
|
||||||
assert.Contains(rec.Body.String(), "cert.pem")
|
assert.Contains(rec.Body.String(), "cert.pem")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IgnoreBase
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
config.Root = "../_fixture"
|
||||||
|
config.IgnoreBase = true
|
||||||
|
static = StaticWithConfig(config)
|
||||||
|
c.Echo().Group("_fixture", static)
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
config.Root = "../_fixture"
|
||||||
|
config.IgnoreBase = false
|
||||||
|
static = StaticWithConfig(config)
|
||||||
|
c.Echo().Group("_fixture", static)
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(http.StatusOK, rec.Code)
|
||||||
|
assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
|
||||||
}
|
}
|
||||||
|
@ -19,8 +19,13 @@ func TestResponse(t *testing.T) {
|
|||||||
res.Before(func() {
|
res.Before(func() {
|
||||||
c.Response().Header().Set(HeaderServer, "echo")
|
c.Response().Header().Set(HeaderServer, "echo")
|
||||||
})
|
})
|
||||||
|
// After
|
||||||
|
res.After(func() {
|
||||||
|
c.Response().Header().Set(HeaderXFrameOptions, "DENY")
|
||||||
|
})
|
||||||
res.Write([]byte("test"))
|
res.Write([]byte("test"))
|
||||||
assert.Equal(t, "echo", rec.Header().Get(HeaderServer))
|
assert.Equal(t, "echo", rec.Header().Get(HeaderServer))
|
||||||
|
assert.Equal(t, "DENY", rec.Header().Get(HeaderXFrameOptions))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResponse_Write_FallsBackToDefaultStatus(t *testing.T) {
|
func TestResponse_Write_FallsBackToDefaultStatus(t *testing.T) {
|
||||||
@ -41,3 +46,13 @@ func TestResponse_Write_UsesSetResponseCode(t *testing.T) {
|
|||||||
res.Write([]byte("test"))
|
res.Write([]byte("test"))
|
||||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponse_Flush(t *testing.T) {
|
||||||
|
e := New()
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
res := &Response{echo: e, Writer: rec}
|
||||||
|
|
||||||
|
res.Write([]byte("test"))
|
||||||
|
res.Flush()
|
||||||
|
assert.True(t, rec.Flushed)
|
||||||
|
}
|
||||||
|
@ -355,6 +355,10 @@ func (r *Router) Find(method, path string, c Context) {
|
|||||||
|
|
||||||
// Attempt to go back up the tree on no matching prefix or no remaining search
|
// Attempt to go back up the tree on no matching prefix or no remaining search
|
||||||
if l != pl || search == "" {
|
if l != pl || search == "" {
|
||||||
|
// Handle special case of trailing slash route with existing any route (see #1526)
|
||||||
|
if path[len(path)-1] == '/' && cn.findChildByKind(akind) != nil {
|
||||||
|
goto Any
|
||||||
|
}
|
||||||
if nn == nil { // Issue #1348
|
if nn == nil { // Issue #1348
|
||||||
return // Not found
|
return // Not found
|
||||||
}
|
}
|
||||||
|
@ -608,7 +608,6 @@ func TestRouterMatchAny(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
c := e.NewContext(nil, nil).(*context)
|
c := e.NewContext(nil, nil).(*context)
|
||||||
|
|
||||||
r.Find(http.MethodGet, "/", c)
|
r.Find(http.MethodGet, "/", c)
|
||||||
assert.Equal(t, "", c.Param("*"))
|
assert.Equal(t, "", c.Param("*"))
|
||||||
|
|
||||||
@ -619,6 +618,78 @@ func TestRouterMatchAny(t *testing.T) {
|
|||||||
assert.Equal(t, "joe", c.Param("*"))
|
assert.Equal(t, "joe", c.Param("*"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRouterMatchAnySlash shall verify finding the best route
|
||||||
|
// for any routes with trailing slash requests
|
||||||
|
func TestRouterMatchAnySlash(t *testing.T) {
|
||||||
|
e := New()
|
||||||
|
r := e.router
|
||||||
|
|
||||||
|
handler := func(c Context) error {
|
||||||
|
c.Set("path", c.Path())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Routes
|
||||||
|
r.Add(http.MethodGet, "/users", handler)
|
||||||
|
r.Add(http.MethodGet, "/users/*", handler)
|
||||||
|
r.Add(http.MethodGet, "/img/*", handler)
|
||||||
|
r.Add(http.MethodGet, "/img/load", handler)
|
||||||
|
r.Add(http.MethodGet, "/img/load/*", handler)
|
||||||
|
r.Add(http.MethodGet, "/assets/*", handler)
|
||||||
|
|
||||||
|
c := e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/", c)
|
||||||
|
assert.Equal(t, "", c.Param("*"))
|
||||||
|
|
||||||
|
// Test trailing slash request for simple any route (see #1526)
|
||||||
|
c = e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/users/", c)
|
||||||
|
c.handler(c)
|
||||||
|
assert.Equal(t, "/users/*", c.Get("path"))
|
||||||
|
assert.Equal(t, "", c.Param("*"))
|
||||||
|
|
||||||
|
c = e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/users/joe", c)
|
||||||
|
c.handler(c)
|
||||||
|
assert.Equal(t, "/users/*", c.Get("path"))
|
||||||
|
assert.Equal(t, "joe", c.Param("*"))
|
||||||
|
|
||||||
|
// Test trailing slash request for nested any route (see #1526)
|
||||||
|
c = e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/img/load", c)
|
||||||
|
c.handler(c)
|
||||||
|
assert.Equal(t, "/img/load", c.Get("path"))
|
||||||
|
assert.Equal(t, "", c.Param("*"))
|
||||||
|
|
||||||
|
c = e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/img/load/", c)
|
||||||
|
c.handler(c)
|
||||||
|
assert.Equal(t, "/img/load/*", c.Get("path"))
|
||||||
|
assert.Equal(t, "", c.Param("*"))
|
||||||
|
|
||||||
|
c = e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/img/load/ben", c)
|
||||||
|
c.handler(c)
|
||||||
|
assert.Equal(t, "/img/load/*", c.Get("path"))
|
||||||
|
assert.Equal(t, "ben", c.Param("*"))
|
||||||
|
|
||||||
|
// Test /assets/* any route
|
||||||
|
// ... without trailing slash must not match
|
||||||
|
c = e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/assets", c)
|
||||||
|
c.handler(c)
|
||||||
|
assert.Equal(t, nil, c.Get("path"))
|
||||||
|
assert.Equal(t, "", c.Param("*"))
|
||||||
|
|
||||||
|
// ... with trailing slash must match
|
||||||
|
c = e.NewContext(nil, nil).(*context)
|
||||||
|
r.Find(http.MethodGet, "/assets/", c)
|
||||||
|
c.handler(c)
|
||||||
|
assert.Equal(t, "/assets/*", c.Get("path"))
|
||||||
|
assert.Equal(t, "", c.Param("*"))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestRouterMatchAnyMultiLevel(t *testing.T) {
|
func TestRouterMatchAnyMultiLevel(t *testing.T) {
|
||||||
e := New()
|
e := New()
|
||||||
r := e.router
|
r := e.router
|
||||||
|
Loading…
x
Reference in New Issue
Block a user