mirror of
https://github.com/labstack/echo.git
synced 2025-07-03 00:56:59 +02:00
Merge branch 'master' of https://github.com/labstack/echo
This commit is contained in:
@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
||||
config.Level = DefaultGzipConfig.Level
|
||||
}
|
||||
|
||||
pool := gzipPool(config)
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
|
||||
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
|
||||
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
|
||||
rw := res.Writer
|
||||
w, err := gzip.NewWriterLevel(rw, config.Level)
|
||||
if err != nil {
|
||||
return err
|
||||
i := pool.Get()
|
||||
w, ok := i.(*gzip.Writer)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
|
||||
}
|
||||
rw := res.Writer
|
||||
w.Reset(rw)
|
||||
defer func() {
|
||||
if res.Size == 0 {
|
||||
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
|
||||
@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
||||
w.Reset(ioutil.Discard)
|
||||
}
|
||||
w.Close()
|
||||
pool.Put(w)
|
||||
}()
|
||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
|
||||
res.Writer = grw
|
||||
@ -119,3 +125,22 @@ func (w *gzipResponseWriter) Flush() {
|
||||
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.ResponseWriter.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||
if p, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||
return p.Push(target, opts)
|
||||
}
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
|
||||
func gzipPool(config GzipConfig) sync.Pool {
|
||||
return sync.Pool{
|
||||
New: func() interface{} {
|
||||
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) {
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
}
|
||||
|
||||
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
// Invalid level
|
||||
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test"))
|
||||
return nil
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "gzip")
|
||||
}
|
||||
|
||||
// Issue #806
|
||||
func TestGzipWithStatic(t *testing.T) {
|
||||
e := echo.New()
|
||||
@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGzip(b *testing.B) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
|
||||
h := Gzip()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Gzip
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h(c)
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -18,6 +19,13 @@ type (
|
||||
// Optional. Default value []string{"*"}.
|
||||
AllowOrigins []string `yaml:"allow_origins"`
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||
// origin as an argument and returns true if allowed or false otherwise. If
|
||||
// an error is returned, it is returned by the handler. If this option is
|
||||
// set, AllowOrigins is ignored.
|
||||
// Optional.
|
||||
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
|
||||
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
@ -76,6 +84,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
||||
}
|
||||
|
||||
allowOriginPatterns := []string{}
|
||||
for _, origin := range config.AllowOrigins {
|
||||
pattern := regexp.QuoteMeta(origin)
|
||||
pattern = strings.Replace(pattern, "\\*", ".*", -1)
|
||||
pattern = strings.Replace(pattern, "\\?", ".", -1)
|
||||
pattern = "^" + pattern + "$"
|
||||
allowOriginPatterns = append(allowOriginPatterns, pattern)
|
||||
}
|
||||
|
||||
allowMethods := strings.Join(config.AllowMethods, ",")
|
||||
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
||||
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
||||
@ -92,25 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
origin := req.Header.Get(echo.HeaderOrigin)
|
||||
allowOrigin := ""
|
||||
|
||||
// Check allowed origins
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" && config.AllowCredentials {
|
||||
allowOrigin = origin
|
||||
break
|
||||
preflight := req.Method == http.MethodOptions
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||
|
||||
// No Origin provided
|
||||
if origin == "" {
|
||||
if !preflight {
|
||||
return next(c)
|
||||
}
|
||||
if o == "*" || o == origin {
|
||||
allowOrigin = o
|
||||
break
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
if config.AllowOriginFunc != nil {
|
||||
allowed, err := config.AllowOriginFunc(origin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
if allowed {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// Check allowed origins
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" && config.AllowCredentials {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
if o == "*" || o == origin {
|
||||
allowOrigin = o
|
||||
break
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed origin patterns
|
||||
for _, re := range allowOriginPatterns {
|
||||
if allowOrigin == "" {
|
||||
didx := strings.Index(origin, "://")
|
||||
if didx == -1 {
|
||||
continue
|
||||
}
|
||||
domAuth := origin[didx+3:]
|
||||
// to avoid regex cost by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
break
|
||||
}
|
||||
|
||||
if match, _ := regexp.MatchString(re, origin); match {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Origin not allowed
|
||||
if allowOrigin == "" {
|
||||
if !preflight {
|
||||
return next(c)
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Simple request
|
||||
if req.Method != http.MethodOptions {
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||
if !preflight {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
if config.AllowCredentials {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||
@ -122,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
// Preflight request
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@ -17,19 +18,31 @@ func TestCORS(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := CORS()(echo.NotFoundHandler)
|
||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||
h(c)
|
||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
// Wildcard AllowedOrigin with no Origin header in request
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
h = CORS()(echo.NotFoundHandler)
|
||||
h(c)
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
|
||||
// Allow origins
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
h = CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{"localhost"},
|
||||
AllowOrigins: []string{"localhost"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
})(echo.NotFoundHandler)
|
||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||
h(c)
|
||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||
|
||||
// Preflight request
|
||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
@ -67,6 +80,22 @@ func TestCORS(t *testing.T) {
|
||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
||||
|
||||
// Preflight request with Access-Control-Request-Headers
|
||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
|
||||
cors = CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
})
|
||||
h = cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
|
||||
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||
|
||||
// Preflight request with `AllowOrigins` which allow all subdomains with *
|
||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
@ -83,3 +112,298 @@ func TestCORS(t *testing.T) {
|
||||
h(c)
|
||||
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
}
|
||||
|
||||
func Test_allowOriginScheme(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, pattern string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
domain: "http://example.com",
|
||||
pattern: "http://example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "https://example.com",
|
||||
pattern: "https://example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
pattern: "https://example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "https://example.com",
|
||||
pattern: "http://example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tt.pattern},
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
|
||||
if tt.expected {
|
||||
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
} else {
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_allowOriginSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, pattern string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
domain: "http://aaa.example.com",
|
||||
pattern: "http://*.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://bbb.aaa.example.com",
|
||||
pattern: "http://*.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://bbb.aaa.example.com",
|
||||
pattern: "http://*.aaa.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://aaa.example.com:8080",
|
||||
pattern: "http://*.example.com:8080",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
{
|
||||
domain: "http://fuga.hoge.com",
|
||||
pattern: "http://*.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://ccc.bbb.example.com",
|
||||
pattern: "http://*.aaa.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
|
||||
pattern: "http://*.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
|
||||
pattern: "http://*.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://ccc.bbb.example.com",
|
||||
pattern: "http://example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "https://prod-preview--aaa.bbb.com",
|
||||
pattern: "https://*--aaa.bbb.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://ccc.bbb.example.com",
|
||||
pattern: "http://*.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://ccc.bbb.example.com",
|
||||
pattern: "http://foo.[a-z]*.example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tt.pattern},
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
|
||||
if tt.expected {
|
||||
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
} else {
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, allowedOrigin, method string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(tt.method, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
if tt.domain != "" {
|
||||
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||
}
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tt.allowedOrigin},
|
||||
//AllowCredentials: true,
|
||||
//MaxAge: 3600,
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
|
||||
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||
|
||||
expectedAllowOrigin := ""
|
||||
if tt.allowedOrigin == "*" {
|
||||
expectedAllowOrigin = "*"
|
||||
} else {
|
||||
expectedAllowOrigin = tt.domain
|
||||
}
|
||||
|
||||
switch {
|
||||
case tt.expected && tt.method == http.MethodOptions:
|
||||
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
|
||||
case tt.expected && tt.method == http.MethodGet:
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
default:
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
}
|
||||
|
||||
if tt.method == http.MethodOptions {
|
||||
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_allowOriginFunc(t *testing.T) {
|
||||
returnTrue := func(origin string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
returnFalse := func(origin string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
returnError := func(origin string) (bool, error) {
|
||||
return true, errors.New("this is a test error")
|
||||
}
|
||||
|
||||
allowOriginFuncs := []func(origin string) (bool, error){
|
||||
returnTrue,
|
||||
returnFalse,
|
||||
returnError,
|
||||
}
|
||||
|
||||
const origin = "http://example.com"
|
||||
|
||||
e := echo.New()
|
||||
for _, allowOriginFunc := range allowOriginFuncs {
|
||||
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
req.Header.Set(echo.HeaderOrigin, origin)
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOriginFunc: allowOriginFunc,
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
err := h(c)
|
||||
|
||||
expected, expectedErr := allowOriginFunc(origin)
|
||||
if expectedErr != nil {
|
||||
assert.Equal(t, expectedErr, err)
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
continue
|
||||
}
|
||||
|
||||
if expected {
|
||||
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
} else {
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
58
middleware/decompress.go
Normal file
58
middleware/decompress.go
Normal file
@ -0,0 +1,58 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"github.com/labstack/echo/v4"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
type (
|
||||
// DecompressConfig defines the config for Decompress middleware.
|
||||
DecompressConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
}
|
||||
)
|
||||
|
||||
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
|
||||
const GZIPEncoding string = "gzip"
|
||||
|
||||
var (
|
||||
//DefaultDecompressConfig defines the config for decompress middleware
|
||||
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
|
||||
)
|
||||
|
||||
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
|
||||
func Decompress() echo.MiddlewareFunc {
|
||||
return DecompressWithConfig(DefaultDecompressConfig)
|
||||
}
|
||||
|
||||
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
|
||||
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
|
||||
case GZIPEncoding:
|
||||
gr, err := gzip.NewReader(c.Request().Body)
|
||||
if err != nil {
|
||||
if err == io.EOF { //ignore if body is empty
|
||||
return next(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer gr.Close()
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, gr)
|
||||
r := ioutil.NopCloser(&buf)
|
||||
defer r.Close()
|
||||
c.Request().Body = r
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
148
middleware/decompress_test.go
Normal file
148
middleware/decompress_test.go
Normal file
@ -0,0 +1,148 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecompress(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Skip if no Content-Encoding header
|
||||
h := Decompress()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
h(c)
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.Equal("test", rec.Body.String())
|
||||
|
||||
// Decompress
|
||||
body := `{"name": "echo"}`
|
||||
gz, _ := gzipString(body)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
h(c)
|
||||
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
b, err := ioutil.ReadAll(req.Body)
|
||||
assert.NoError(err)
|
||||
assert.Equal(body, string(b))
|
||||
}
|
||||
|
||||
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
|
||||
e := echo.New()
|
||||
body := `{"name":"echo"}`
|
||||
gz, _ := gzipString(body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
e.NewContext(req, rec)
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
b, err := ioutil.ReadAll(req.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, b, body)
|
||||
assert.Equal(t, b, gz)
|
||||
}
|
||||
|
||||
func TestDecompressNoContent(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := Decompress()(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
})
|
||||
if assert.NoError(t, h(c)) {
|
||||
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
|
||||
assert.Equal(t, 0, len(rec.Body.Bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecompressErrorReturned(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Use(Decompress())
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return echo.ErrNotFound
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
}
|
||||
|
||||
func TestDecompressSkipper(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Use(DecompressWithConfig(DecompressConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return c.Request().URL.Path == "/skip"
|
||||
},
|
||||
}))
|
||||
body := `{"name": "echo"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
|
||||
reqBody, err := ioutil.ReadAll(c.Request().Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, string(reqBody))
|
||||
}
|
||||
|
||||
func BenchmarkDecompress(b *testing.B) {
|
||||
e := echo.New()
|
||||
body := `{"name": "echo"}`
|
||||
gz, _ := gzipString(body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
|
||||
h := Decompress()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte(body)) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Decompress
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h(c)
|
||||
}
|
||||
}
|
||||
|
||||
func gzipString(body string) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
|
||||
_, err := gz.Write([]byte(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := gz.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
@ -86,6 +86,7 @@ const (
|
||||
// Errors
|
||||
var (
|
||||
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
|
||||
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
|
||||
)
|
||||
|
||||
var (
|
||||
@ -213,8 +214,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
return config.ErrorHandlerWithContext(err, c)
|
||||
}
|
||||
return &echo.HTTPError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "invalid or expired jwt",
|
||||
Code: ErrJWTInvalid.Code,
|
||||
Message: ErrJWTInvalid.Message,
|
||||
Internal: err,
|
||||
}
|
||||
}
|
||||
|
@ -60,8 +60,6 @@ func TestJWTRace(t *testing.T) {
|
||||
|
||||
func TestJWT(t *testing.T) {
|
||||
e := echo.New()
|
||||
r := e.Router()
|
||||
r.Add("GET", "/:jwt", func(echo.Context) error { return nil })
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -32,6 +33,31 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
||||
return strings.NewReplacer(replace...)
|
||||
}
|
||||
|
||||
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
|
||||
// Initialize
|
||||
rulesRegex := map[*regexp.Regexp]string{}
|
||||
for k, v := range rewrite {
|
||||
k = regexp.QuoteMeta(k)
|
||||
k = strings.Replace(k, `\*`, "(.*)", -1)
|
||||
if strings.HasPrefix(k, `\^`) {
|
||||
k = strings.Replace(k, `\^`, "^", -1)
|
||||
}
|
||||
k = k + "$"
|
||||
rulesRegex[regexp.MustCompile(k)] = v
|
||||
}
|
||||
return rulesRegex
|
||||
}
|
||||
|
||||
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
|
||||
for k, v := range rewriteRegex {
|
||||
replacerRawPath := captureTokens(k, req.URL.EscapedPath())
|
||||
if replacerRawPath != nil {
|
||||
replacerPath := captureTokens(k, req.URL.Path)
|
||||
req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSkipper returns false which processes the middleware.
|
||||
func DefaultSkipper(echo.Context) bool {
|
||||
return false
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -45,6 +44,9 @@ type (
|
||||
// Examples: If custom TLS certificates are required.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// ModifyResponse defines function to modify response from ProxyTarget.
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
||||
rewriteRegex map[*regexp.Regexp]string
|
||||
}
|
||||
|
||||
@ -203,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
if config.Balancer == nil {
|
||||
panic("echo: proxy middleware requires balancer")
|
||||
}
|
||||
config.rewriteRegex = map[*regexp.Regexp]string{}
|
||||
|
||||
// Initialize
|
||||
for k, v := range config.Rewrite {
|
||||
k = strings.Replace(k, "*", "(\\S*)", -1)
|
||||
config.rewriteRegex[regexp.MustCompile(k)] = v
|
||||
}
|
||||
config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
@ -222,13 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
tgt := config.Balancer.Next(c)
|
||||
c.Set(config.ContextKey, tgt)
|
||||
|
||||
// Rewrite
|
||||
for k, v := range config.rewriteRegex {
|
||||
replacer := captureTokens(k, req.URL.Path)
|
||||
if replacer != nil {
|
||||
req.URL.Path = replacer.Replace(v)
|
||||
}
|
||||
}
|
||||
// Set rewrite path and raw path
|
||||
rewritePath(config.rewriteRegex, req)
|
||||
|
||||
// Fix header
|
||||
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
|
||||
@ -259,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -20,5 +20,6 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)))
|
||||
}
|
||||
proxy.Transport = config.Transport
|
||||
proxy.ModifyResponse = config.ModifyResponse
|
||||
return proxy
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -12,6 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
//Assert expected with url.EscapedPath method to obtain the path.
|
||||
func TestProxy(t *testing.T) {
|
||||
// Setup
|
||||
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -91,19 +94,49 @@ func TestProxy(t *testing.T) {
|
||||
"/users/*/orders/*": "/user/$1/order/$2",
|
||||
},
|
||||
}))
|
||||
req.URL.Path = "/api/users"
|
||||
req.URL, _ = url.Parse("/api/users")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/users", req.URL.Path)
|
||||
req.URL.Path = "/js/main.js"
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.Path)
|
||||
req.URL.Path = "/old"
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/new", req.URL.Path)
|
||||
req.URL.Path = "/users/jack/orders/1"
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/user/jack/order/1", req.URL.Path)
|
||||
assert.Equal(t, "/users", req.URL.EscapedPath())
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
req.URL, _ = url.Parse( "/js/main.js")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
req.URL, _ = url.Parse("/old")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
req.URL, _ = url.Parse( "/users/jack/orders/1")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
req.URL, _ = url.Parse("/api/new users")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
|
||||
// ModifyResponse
|
||||
e = echo.New()
|
||||
e.Use(ProxyWithConfig(ProxyConfig{
|
||||
Balancer: rrb,
|
||||
ModifyResponse: func(res *http.Response) error {
|
||||
res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified")))
|
||||
res.Header.Set("X-Modified", "1")
|
||||
return nil
|
||||
},
|
||||
}))
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "modified", rec.Body.String())
|
||||
assert.Equal(t, "1", rec.Header().Get("X-Modified"))
|
||||
|
||||
// ProxyTarget is set in context
|
||||
contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"runtime"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/log"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -25,6 +26,10 @@ type (
|
||||
// DisablePrintStack disables printing stack trace.
|
||||
// Optional. Default value as false.
|
||||
DisablePrintStack bool `yaml:"disable_print_stack"`
|
||||
|
||||
// LogLevel is log level to printing stack trace.
|
||||
// Optional. Default value 0 (Print).
|
||||
LogLevel log.Lvl
|
||||
}
|
||||
)
|
||||
|
||||
@ -35,6 +40,7 @@ var (
|
||||
StackSize: 4 << 10, // 4 KB
|
||||
DisableStackAll: false,
|
||||
DisablePrintStack: false,
|
||||
LogLevel: 0,
|
||||
}
|
||||
)
|
||||
|
||||
@ -70,7 +76,21 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
|
||||
stack := make([]byte, config.StackSize)
|
||||
length := runtime.Stack(stack, !config.DisableStackAll)
|
||||
if !config.DisablePrintStack {
|
||||
c.Logger().Printf("[PANIC RECOVER] %v %s\n", err, stack[:length])
|
||||
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
|
||||
switch config.LogLevel {
|
||||
case log.DEBUG:
|
||||
c.Logger().Debug(msg)
|
||||
case log.INFO:
|
||||
c.Logger().Info(msg)
|
||||
case log.WARN:
|
||||
c.Logger().Warn(msg)
|
||||
case log.ERROR:
|
||||
c.Logger().Error(msg)
|
||||
case log.OFF:
|
||||
// None.
|
||||
default:
|
||||
c.Logger().Print(msg)
|
||||
}
|
||||
}
|
||||
c.Error(err)
|
||||
}
|
||||
|
@ -2,11 +2,13 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -24,3 +26,58 @@ func TestRecover(t *testing.T) {
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, buf.String(), "PANIC RECOVER")
|
||||
}
|
||||
|
||||
func TestRecoverWithConfig_LogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
logLevel log.Lvl
|
||||
levelName string
|
||||
}{{
|
||||
logLevel: log.DEBUG,
|
||||
levelName: "DEBUG",
|
||||
}, {
|
||||
logLevel: log.INFO,
|
||||
levelName: "INFO",
|
||||
}, {
|
||||
logLevel: log.WARN,
|
||||
levelName: "WARN",
|
||||
}, {
|
||||
logLevel: log.ERROR,
|
||||
levelName: "ERROR",
|
||||
}, {
|
||||
logLevel: log.OFF,
|
||||
levelName: "OFF",
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.levelName, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Logger.SetLevel(log.DEBUG)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
e.Logger.SetOutput(buf)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
config := DefaultRecoverConfig
|
||||
config.LogLevel = tt.logLevel
|
||||
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
|
||||
panic("test")
|
||||
}))
|
||||
|
||||
h(c)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
|
||||
output := buf.String()
|
||||
if tt.logLevel == log.OFF {
|
||||
assert.Empty(t, output)
|
||||
} else {
|
||||
assert.Contains(t, output, "PANIC RECOVER")
|
||||
assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -53,14 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultBodyDumpConfig.Skipper
|
||||
}
|
||||
config.rulesRegex = map[*regexp.Regexp]string{}
|
||||
|
||||
// Initialize
|
||||
for k, v := range config.Rules {
|
||||
k = strings.Replace(k, "*", "(.*)", -1)
|
||||
k = k + "$"
|
||||
config.rulesRegex[regexp.MustCompile(k)] = v
|
||||
}
|
||||
config.rulesRegex = rewriteRulesRegex(config.Rules)
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
@ -69,15 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
req := c.Request()
|
||||
|
||||
// Rewrite
|
||||
for k, v := range config.rulesRegex {
|
||||
replacer := captureTokens(k, req.URL.Path)
|
||||
if replacer != nil {
|
||||
req.URL.Path = replacer.Replace(v)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Set rewrite path and raw path
|
||||
rewritePath(config.rulesRegex, req)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
|
@ -4,12 +4,14 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
//Assert expected with url.EscapedPath method to obtain the path.
|
||||
func TestRewrite(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Use(RewriteWithConfig(RewriteConfig{
|
||||
@ -22,21 +24,28 @@ func TestRewrite(t *testing.T) {
|
||||
}))
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
req.URL.Path = "/api/users"
|
||||
req.URL, _ = url.Parse("/api/users")
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/users", req.URL.Path)
|
||||
req.URL.Path = "/js/main.js"
|
||||
assert.Equal(t, "/users", req.URL.EscapedPath())
|
||||
req.URL, _ = url.Parse("/js/main.js")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.Path)
|
||||
req.URL.Path = "/old"
|
||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
||||
req.URL, _ = url.Parse("/old")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/new", req.URL.Path)
|
||||
req.URL.Path = "/users/jack/orders/1"
|
||||
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||
req.URL, _ = url.Parse("/users/jack/orders/1")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/user/jack/order/1", req.URL.Path)
|
||||
req.URL.Path = "/api/new users"
|
||||
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
||||
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/new users", req.URL.Path)
|
||||
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
||||
req.URL, _ = url.Parse("/api/new users")
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
|
||||
}
|
||||
|
||||
// Issue #1086
|
||||
@ -45,22 +54,21 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
|
||||
r := e.Router()
|
||||
|
||||
// Rewrite old url to new one
|
||||
e.Pre(RewriteWithConfig(RewriteConfig{
|
||||
Rules: map[string]string{
|
||||
e.Pre(Rewrite(map[string]string{
|
||||
"/old": "/new",
|
||||
},
|
||||
}))
|
||||
))
|
||||
|
||||
// Route
|
||||
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
|
||||
return c.NoContent(200)
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/old", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/new", req.URL.Path)
|
||||
assert.Equal(t, 200, rec.Code)
|
||||
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// Issue #1143
|
||||
@ -76,21 +84,48 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
|
||||
}))
|
||||
|
||||
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
|
||||
return c.String(200, "hosts")
|
||||
return c.String(http.StatusOK, "hosts")
|
||||
})
|
||||
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
|
||||
return c.String(200, "eng")
|
||||
return c.String(http.StatusOK, "eng")
|
||||
})
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/api/v1/hosts/test", req.URL.Path)
|
||||
assert.Equal(t, 200, rec.Code)
|
||||
assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath())
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
defer rec.Result().Body.Close()
|
||||
bodyBytes, _ := ioutil.ReadAll(rec.Result().Body)
|
||||
assert.Equal(t, "hosts", string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
// Issue #1573
|
||||
func TestEchoRewriteWithCaret(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.Pre(RewriteWithConfig(RewriteConfig{
|
||||
Rules: map[string]string{
|
||||
"^/abc/*": "/v1/abc/$1",
|
||||
},
|
||||
}))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
var req *http.Request
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/abc/test", nil)
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/v1/abc/test", req.URL.Path)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil)
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/v1/abc/test", req.URL.Path)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil)
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/v2/abc/test", req.URL.Path)
|
||||
}
|
||||
|
@ -36,6 +36,12 @@ type (
|
||||
// Enable directory browsing.
|
||||
// Optional. Default value false.
|
||||
Browse bool `yaml:"browse"`
|
||||
|
||||
// Enable ignoring of the base of the URL path.
|
||||
// Example: when assigning a static middleware to a non root path group,
|
||||
// the filesystem path is not doubled
|
||||
// Optional. Default value false.
|
||||
IgnoreBase bool `yaml:"ignoreBase"`
|
||||
}
|
||||
)
|
||||
|
||||
@ -163,6 +169,15 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
|
||||
|
||||
if config.IgnoreBase {
|
||||
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
|
||||
baseURLPath := path.Base(p)
|
||||
if baseURLPath == routePath {
|
||||
i := strings.LastIndex(name, routePath)
|
||||
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
|
||||
}
|
||||
}
|
||||
|
||||
fi, err := os.Stat(name)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
@ -67,4 +68,27 @@ func TestStatic(t *testing.T) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Contains(rec.Body.String(), "cert.pem")
|
||||
}
|
||||
|
||||
// IgnoreBase
|
||||
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
config.Root = "../_fixture"
|
||||
config.IgnoreBase = true
|
||||
static = StaticWithConfig(config)
|
||||
c.Echo().Group("_fixture", static)
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122")
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/_fixture", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
config.Root = "../_fixture"
|
||||
config.IgnoreBase = false
|
||||
static = StaticWithConfig(config)
|
||||
c.Echo().Group("_fixture", static)
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture"))
|
||||
}
|
||||
|
Reference in New Issue
Block a user