1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Merge branch 'master' into fix_router_find_after_invalid_set_param_values

This commit is contained in:
Pablo Andres Fuente 2020-12-16 01:56:28 +00:00
commit 53653b3df6
29 changed files with 1210 additions and 181 deletions

12
.github/stale.yml vendored
View File

@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 7
daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
- bug
- enhancement
# Label to use when marking an issue as stale
staleLabel: wontfix
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
recent activity. It will be closed within a month if no further activity occurs.
Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false
closeComment: false

View File

@ -4,20 +4,28 @@ on:
push:
branches:
- master
paths:
- '**.go'
- 'go.*'
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
pull_request:
branches:
- master
env:
GO111MODULE: on
GOPROXY: https://proxy.golang.org
paths:
- '**.go'
- 'go.*'
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
go: [1.12, 1.13, 1.14]
go: [1.12, 1.13, 1.14, 1.15]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
@ -28,10 +36,15 @@ jobs:
- name: Set GOPATH and PATH
run: |
echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)"
echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin"
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code
uses: actions/checkout@v1
with:
@ -51,3 +64,55 @@ jobs:
with:
token:
fail_ci_if_error: false
benchmark:
needs: test
strategy:
matrix:
os: [ubuntu-latest]
go: [1.15]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v1
with:
go-version: ${{ matrix.go }}
- name: Set GOPATH and PATH
run: |
echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV
echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH
shell: bash
- name: Set build variables
run: |
echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV
echo "GO111MODULE=on" >> $GITHUB_ENV
- name: Checkout Code (Previous)
uses: actions/checkout@v2
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
uses: actions/checkout@v2
with:
path: new
- name: Install Dependencies
run: go get -v golang.org/x/perf/cmd/benchstat
- name: Run Benchmark (Previous)
run: |
cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New)
run: |
cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat
run: |
benchstat previous/benchmark.txt new/benchmark.txt

View File

@ -1,3 +1,7 @@
arch:
- amd64
- ppc64le
language: go
go:
- 1.14.x

View File

@ -42,11 +42,14 @@ For older versions, please use the latest v3 tag.
## Benchmarks
Date: 2018/03/15<br>
Date: 2020/11/11<br>
Source: https://github.com/vishr/web-framework-benchmark<br>
Lower is better!
<img src="https://i.imgur.com/I32VdMJ.png">
<img src="https://i.imgur.com/qwPNQbl.png">
<img src="https://i.imgur.com/s8yKQjx.png">
The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
## [Guide](https://echo.labstack.com/guide)

View File

@ -0,0 +1 @@
This directory is used for the static middleware test

11
codecov.yml Normal file
View File

@ -0,0 +1,11 @@
coverage:
status:
project:
default:
threshold: 1%
patch:
default:
threshold: 1%
comment:
require_changes: true

View File

@ -385,7 +385,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
if err != nil {
return nil, err
}
defer f.Close()
f.Close()
return fh, nil
}

37
echo.go
View File

@ -49,7 +49,6 @@ import (
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
@ -92,6 +91,7 @@ type (
Renderer Renderer
Logger Logger
IPExtractor IPExtractor
ListenerNetwork string
}
// Route contains a handler and information for matching against requests.
@ -281,6 +281,7 @@ var (
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
ErrInvalidListenerNetwork = errors.New("invalid listener network")
)
// Error handlers
@ -302,9 +303,10 @@ func New() (e *Echo) {
AutoTLSManager: autocert.Manager{
Prompt: autocert.AcceptTOS,
},
Logger: log.New("echo"),
colorer: color.New(),
maxParam: new(int),
Logger: log.New("echo"),
colorer: color.New(),
maxParam: new(int),
ListenerNetwork: "tcp",
}
e.Server.Handler = e
e.TLSServer.Handler = e
@ -362,10 +364,12 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
// Issue #1426
code := he.Code
message := he.Message
if e.Debug {
message = err.Error()
} else if m, ok := message.(string); ok {
message = Map{"message": m}
if m, ok := he.Message.(string); ok {
if e.Debug {
message = Map{"message": m, "error": err.Error()}
} else {
message = Map{"message": m}
}
}
// Send response
@ -481,7 +485,7 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl
return err
}
name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security
name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security
fi, err := os.Stat(name)
if err != nil {
// The access path does not exist
@ -570,7 +574,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
for _, r := range e.router.routes {
if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < ln {
if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
@ -712,7 +716,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
if s.TLSConfig == nil {
if e.Listener == nil {
e.Listener, err = newListener(s.Addr)
e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
@ -723,7 +727,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
return s.Serve(e.Listener)
}
if e.TLSListener == nil {
l, err := newListener(s.Addr)
l, err := newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
@ -752,7 +756,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
}
if e.Listener == nil {
e.Listener, err = newListener(s.Addr)
e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
@ -873,8 +877,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
return
}
func newListener(address string) (*tcpKeepAliveListener, error) {
l, err := net.Listen("tcp", address)
func newListener(address, network string) (*tcpKeepAliveListener, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, ErrInvalidListenerNetwork
}
l, err := net.Listen(network, address)
if err != nil {
return nil, err
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
stdContext "context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -59,45 +60,105 @@ func TestEcho(t *testing.T) {
}
func TestEchoStatic(t *testing.T) {
e := New()
var testCases = []struct {
name string
givenPrefix string
givenRoot string
whenURL string
expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
}{
{
name: "ok",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "No file",
givenPrefix: "/images",
givenRoot: "_fixture/scripts",
whenURL: "/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory Redirect",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
assert := assert.New(t)
// OK
e.Static("/images", "_fixture/images")
c, b := request(http.MethodGet, "/images/walle.png", e)
assert.Equal(http.StatusOK, c)
assert.NotEmpty(b)
// No file
e.Static("/images", "_fixture/scripts")
c, _ = request(http.MethodGet, "/images/bolt.png", e)
assert.Equal(http.StatusNotFound, c)
// Directory
e.Static("/images", "_fixture/images")
c, _ = request(http.MethodGet, "/images/", e)
assert.Equal(http.StatusNotFound, c)
// Directory Redirect
e.Static("/", "_fixture")
req := httptest.NewRequest(http.MethodGet, "/folder", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(http.StatusMovedPermanently, rec.Code)
assert.Equal("/folder/", rec.HeaderMap["Location"][0])
// Directory with index.html
e.Static("/", "_fixture")
c, r := request(http.MethodGet, "/", e)
assert.Equal(http.StatusOK, c)
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
// Sub-directory with index.html
c, r = request(http.MethodGet, "/folder/", e)
assert.Equal(http.StatusOK, c)
assert.Equal(true, strings.HasPrefix(r, "<!doctype html>"))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.Static(tc.givenPrefix, tc.givenRoot)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
} else {
assert.Equal(t, "", body)
}
if tc.expectHeaderLocation != "" {
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
} else {
_, ok := rec.Result().Header["Location"]
assert.False(t, ok)
}
})
}
}
func TestEchoFile(t *testing.T) {
@ -277,10 +338,12 @@ func TestEchoURL(t *testing.T) {
e := New()
static := func(Context) error { return nil }
getUser := func(Context) error { return nil }
getAny := func(Context) error { return nil }
getFile := func(Context) error { return nil }
e.GET("/static/file", static)
e.GET("/users/:id", getUser)
e.GET("/documents/*", getAny)
g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile)
@ -289,6 +352,9 @@ func TestEchoURL(t *testing.T) {
assert.Equal("/static/file", e.URL(static))
assert.Equal("/users/:id", e.URL(getUser))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/users/1", e.URL(getUser, "1"))
assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal("/documents/*", e.URL(getAny))
assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
}
@ -568,6 +634,49 @@ func TestHTTPError(t *testing.T) {
})
}
func TestDefaultHTTPErrorHandler(t *testing.T) {
e := New()
e.Debug = true
e.Any("/plain", func(c Context) error {
return errors.New("An error occurred")
})
e.Any("/badrequest", func(c Context) error {
return NewHTTPError(http.StatusBadRequest, "Invalid request")
})
e.Any("/servererror", func(c Context) error {
return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{
"code": 33,
"message": "Something bad happened",
"error": "stackinfo",
})
})
// With Debug=true plain response contains error message
c, b := request(http.MethodGet, "/plain", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b)
// and special handling for HTTPError
c, b = request(http.MethodGet, "/badrequest", e)
assert.Equal(t, http.StatusBadRequest, c)
assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b)
// complex errors are serialized to pretty JSON
c, b = request(http.MethodGet, "/servererror", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b)
e.Debug = false
// With Debug=false the error response is shortened
c, b = request(http.MethodGet, "/plain", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b)
c, b = request(http.MethodGet, "/badrequest", e)
assert.Equal(t, http.StatusBadRequest, c)
assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b)
// No difference for error response with non plain string errors
c, b = request(http.MethodGet, "/servererror", e)
assert.Equal(t, http.StatusInternalServerError, c)
assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b)
}
func TestEchoClose(t *testing.T) {
e := New()
errCh := make(chan error)
@ -609,3 +718,91 @@ func TestEchoShutdown(t *testing.T) {
err := <-errCh
assert.Equal(t, err.Error(), "http: Server closed")
}
var listenerNetworkTests = []struct {
test string
network string
address string
}{
{"tcp ipv4 address", "tcp", "127.0.0.1:1323"},
{"tcp ipv6 address", "tcp", "[::1]:1323"},
{"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"},
{"tcp6 ipv6 address", "tcp6", "[::1]:1323"},
}
func TestEchoListenerNetwork(t *testing.T) {
for _, tt := range listenerNetworkTests {
t.Run(tt.test, func(t *testing.T) {
e := New()
e.ListenerNetwork = tt.network
// HandlerFunc
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
errCh := make(chan error)
go func() {
errCh <- e.Start(tt.address)
}()
time.Sleep(200 * time.Millisecond)
if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
if body, err := ioutil.ReadAll(resp.Body); err == nil {
assert.Equal(t, "OK", string(body))
} else {
assert.Fail(t, err.Error())
}
} else {
assert.Fail(t, err.Error())
}
if err := e.Close(); err != nil {
t.Fatal(err)
}
})
}
}
func TestEchoListenerNetworkInvalid(t *testing.T) {
e := New()
e.ListenerNetwork = "unix"
// HandlerFunc
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
}
func TestEchoReverse(t *testing.T) {
assert := assert.New(t)
e := New()
dummyHandler := func(Context) error { return nil }
e.GET("/static", dummyHandler).Name = "/static"
e.GET("/static/*", dummyHandler).Name = "/static/*"
e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

View File

@ -8,6 +8,7 @@ import (
"net"
"net/http"
"strings"
"sync"
"github.com/labstack/echo/v4"
)
@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level
}
pool := gzipCompressPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
rw := res.Writer
w, err := gzip.NewWriterLevel(rw, config.Level)
if err != nil {
return err
i := pool.Get()
w, ok := i.(*gzip.Writer)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
rw := res.Writer
w.Reset(rw)
defer func() {
if res.Size == 0 {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
w.Reset(ioutil.Discard)
}
w.Close()
pool.Put(w)
}()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
res.Writer = grw
@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
}
return http.ErrNotSupported
}
func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
}

View File

@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
e := echo.New()
// Invalid level
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
}
// Issue #806
func TestGzipWithStatic(t *testing.T) {
e := echo.New()
@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) {
}
}
}
func BenchmarkGzip(b *testing.B) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Gzip
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}

View File

@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
@ -102,45 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided
if origin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
if match, _ := regexp.MatchString(re, origin); match {
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
// Simple request
if req.Method != http.MethodOptions {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
@ -152,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
// Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)

View File

@ -1,6 +1,7 @@
package middleware
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
@ -17,19 +18,31 @@ func TestCORS(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard AllowedOrigin with no Origin header in request
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
// Allow origins
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
// Preflight request
req = httptest.NewRequest(http.MethodOptions, "/", nil)
@ -67,6 +80,22 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
// Preflight request with Access-Control-Request-Headers
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
// Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
@ -126,7 +155,7 @@ func Test_allowOriginScheme(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
@ -216,6 +245,163 @@ func Test_allowOriginSubdomain(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
func TestCorsHeaders(t *testing.T) {
tests := []struct {
domain, allowedOrigin, method string
expected bool
}{
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
},
}
e := echo.New()
for _, tt := range tests {
req := httptest.NewRequest(tt.method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tt.domain != "" {
req.Header.Set(echo.HeaderOrigin, tt.domain)
}
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
})
h := cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
expectedAllowOrigin := ""
if tt.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tt.domain
}
switch {
case tt.expected && tt.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tt.expected && tt.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}
if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
}
}
func Test_allowOriginFunc(t *testing.T) {
returnTrue := func(origin string) (bool, error) {
return true, nil
}
returnFalse := func(origin string) (bool, error) {
return false, nil
}
returnError := func(origin string) (bool, error) {
return true, errors.New("this is a test error")
}
allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
}
const origin = "http://example.com"
e := echo.New()
for _, allowOriginFunc := range allowOriginFuncs {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}
if expected {
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}

120
middleware/decompress.go Normal file
View File

@ -0,0 +1,120 @@
package middleware
import (
"bytes"
"compress/gzip"
"io"
"io/ioutil"
"net/http"
"sync"
"github.com/labstack/echo/v4"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
b := c.Request().Body
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
var buf bytes.Buffer
io.Copy(&buf, gr)
gr.Close()
pool.Put(gr)
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
r := ioutil.NopCloser(&buf)
c.Request().Body = r
}
return next(c)
}
}
}

View File

@ -0,0 +1,209 @@
package middleware
import (
"bytes"
"compress/gzip"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestDecompress(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Skip if no Content-Encoding header
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestDecompressDefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New()
body := `{"name":"echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.NotEqual(t, b, body)
assert.Equal(t, b, gz)
}
func TestDecompressNoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
}
}
func TestDecompressErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Decompress())
e.GET("/", func(c echo.Context) error {
return echo.ErrNotFound
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestDecompressSkipper(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: func(c echo.Context) bool {
return c.Request().URL.Path == "/skip"
},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
}
type TestDecompressPoolWithError struct {
}
func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
return errors.New("pool error")
},
}
}
func TestDecompressPoolError(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &TestDecompressPoolWithError{},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
assert.Equal(t, rec.Code, http.StatusInternalServerError)
}
func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte(body)) // For Content-Type sniffing
return nil
})
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Decompress
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
}
}
func gzipString(body string) ([]byte, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte(body))
if err != nil {
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@ -57,6 +57,7 @@ type (
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
TokenLookup string
// AuthScheme to be used in the Authorization header.
@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1])
case "cookie":
extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -3,6 +3,8 @@ package middleware
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/dgrijalva/jwt-go"
@ -75,6 +77,7 @@ func TestJWT(t *testing.T) {
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
}{
{
@ -192,12 +195,48 @@ func TestJWT(t *testing.T) {
expErrCode: http.StatusBadRequest,
info: "Empty cookie",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
} {
if tc.reqURL == "" {
tc.reqURL = "/"
}
req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)

View File

@ -2,7 +2,6 @@ package middleware
import (
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
return strings.NewReplacer(replace...)
}
//rewritePath sets request url path and raw path
func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error {
replacerRawPath := replacer.Replace(target)
replacerPath, err := url.PathUnescape(replacerRawPath)
if err != nil {
return err
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
// Initialize
rulesRegex := map[*regexp.Regexp]string{}
for k, v := range rewrite {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
rulesRegex[regexp.MustCompile(k)] = v
}
return rulesRegex
}
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
for k, v := range rewriteRegex {
replacerRawPath := captureTokens(k, req.URL.EscapedPath())
if replacerRawPath != nil {
replacerPath := captureTokens(k, req.URL.Path)
req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v)
}
}
req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath
return nil
}
// DefaultSkipper returns false which processes the middleware.

View File

@ -8,7 +8,6 @@ import (
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
}
config.rewriteRegex = map[*regexp.Regexp]string{}
// Initialize
for k, v := range config.Rewrite {
k = strings.Replace(k, "*", "(\\S*)", -1)
config.rewriteRegex[regexp.MustCompile(k)] = v
}
config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
tgt := config.Balancer.Next(c)
c.Set(config.ContextKey, tgt)
// Rewrite
for k, v := range config.rewriteRegex {
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
}
}
// Set rewrite path and raw path
rewritePath(config.rewriteRegex, req)
// Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
}
}
}

View File

@ -94,36 +94,35 @@ func TestProxy(t *testing.T) {
"/users/*/orders/*": "/user/$1/order/$2",
},
}))
req.URL.Path = "/api/users"
req.URL, _ = url.Parse("/api/users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/js/main.js"
req.URL, _ = url.Parse( "/js/main.js")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/old"
req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jack/orders/1"
req.URL, _ = url.Parse( "/users/jack/orders/1")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/%%%%"
req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
// ModifyResponse
e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{

View File

@ -31,3 +31,20 @@ func TestRequestID(t *testing.T) {
h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
}
func TestRequestID_IDNotAltered(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderXRequestID, "<sample-request-id>")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{})
h := rid(handler)
_ = h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "<sample-request-id>")
}

View File

@ -1,11 +1,8 @@
package middleware
import (
"net/http"
"regexp"
"strings"
"github.com/labstack/echo/v4"
"regexp"
)
type (
@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
}
config.rulesRegex = map[*regexp.Regexp]string{}
// Initialize
for k, v := range config.Rules {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
config.rulesRegex[regexp.MustCompile(k)] = v
}
config.rulesRegex = rewriteRulesRegex(config.Rules)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
}
req := c.Request()
// Rewrite
for k, v := range config.rulesRegex {
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
break
}
}
// Set rewrite path and raw path
rewritePath(config.rulesRegex, req)
return next(c)
}
}

View File

@ -4,6 +4,7 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/labstack/echo/v4"
@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) {
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
req.URL.Path = "/api/users"
req.URL, _ = url.Parse("/api/users")
e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath())
req.URL.Path = "/js/main.js"
req.URL, _ = url.Parse("/js/main.js")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
req.URL.Path = "/old"
req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath())
req.URL.Path = "/users/jack/orders/1"
req.URL, _ = url.Parse("/users/jack/orders/1")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
req.URL.Path = "/api/new users"
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/%%%%"
rec = httptest.NewRecorder()
req.URL, _ = url.Parse("/api/new users")
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
}
// Issue #1086
@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
r := e.Router()
// Rewrite old url to new one
e.Pre(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{
e.Pre(Rewrite(map[string]string{
"/old": "/new",
},
}))
))
// Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error {

View File

@ -36,6 +36,12 @@ type (
// Enable directory browsing.
// Optional. Default value false.
Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
}
)
@ -161,7 +167,16 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if err != nil {
return
}
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
if config.IgnoreBase {
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
baseURLPath := path.Base(p)
if baseURLPath == routePath {
i := strings.LastIndex(name, routePath)
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
}
}
fi, err := os.Stat(name)
if err != nil {

View File

@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"github.com/labstack/echo/v4"
@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), "cert.pem")
}
// IgnoreBase
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = true
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
rec = httptest.NewRecorder()
config.Root = "../_fixture"
config.IgnoreBase = false
static = StaticWithConfig(config)
c.Echo().Group("_fixture", static)
e.ServeHTTP(rec, req)
assert.Equal(http.StatusOK, rec.Code)
assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
}

View File

@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) {
r.echo.Logger.Warn("response already committed")
return
}
r.Status = code
for _, fn := range r.beforeFuncs {
fn()
}
r.Status = code
r.Writer.WriteHeader(code)
r.Writer.WriteHeader(r.Status)
r.Committed = true
}

View File

@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) {
res.Flush()
assert.True(t, rec.Flushed)
}
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
res := &Response{echo: e, Writer: rec}
res.Before(func() {
if 200 < res.Status && res.Status < 300 {
res.Status = 200
}
})
res.WriteHeader(209)
assert.Equal(t, http.StatusOK, rec.Code)
}

View File

@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) {
pos := strings.IndexByte(ns, '/')
if pos == -1 {
// If no slash is remaining in search string set param value
pvalues[len(cn.pnames)-1] = search
if len(cn.pnames) > 0 {
pvalues[len(cn.pnames)-1] = search
}
break
} else if pos > 0 {
// Otherwise continue route processing with restored next node

View File

@ -1335,6 +1335,40 @@ func TestRouterFindNotPanicOrLoopsWhenContextSetParamValuesIsCalledWithLessValue
assert.Equal(t, 1, c.Get("i"))
}
// Issue #1653
func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1))
r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error {
return nil
})
r.Add(http.MethodGet, "/users/:id/active", func(c Context) error {
return nil
})
c := e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/alice/edit", c)
assert.Equal(t, "alice", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/bob/active", c)
assert.Equal(t, "bob", c.Param("id"))
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/create", c)
c.Handler()(c)
assert.Equal(t, 1, c.Get("create"))
assert.Equal(t, "/users/create", c.Get("path"))
//This panic before the fix for Issue #1653
c = e.NewContext(nil, nil).(*context)
r.Find(http.MethodGet, "/users/createNotFound", c)
he := c.Handler()(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.Code)
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route) {
e := New()
r := e.router