mirror of
https://github.com/labstack/echo.git
synced 2025-01-12 01:22:21 +02:00
Enhanced body-limit middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
96898d5b9b
commit
08a173b476
@ -3,7 +3,6 @@ package middleware
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
@ -31,8 +30,8 @@ type (
|
||||
//
|
||||
// BodyLimit middleware sets the maximum allowed size for a request body, if the
|
||||
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
|
||||
// response. The body limit is determined based on the actually read and not `Content-Length`
|
||||
// request header, which makes it super secure.
|
||||
// response. The body limit is determined based on both `Content-Length` request
|
||||
// header and actual content read, which makes it super secure.
|
||||
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
|
||||
// G, T or P.
|
||||
func BodyLimit(limit string) echo.MiddlewareFunc {
|
||||
@ -52,10 +51,18 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
req := c.Request()
|
||||
|
||||
// Based on content length
|
||||
if req.ContentLength() > config.limit {
|
||||
return echo.ErrStatusRequestEntityTooLarge
|
||||
}
|
||||
|
||||
// Based on content read
|
||||
r := pool.Get().(*limitedReader)
|
||||
r.Reset(req.Body(), c)
|
||||
defer pool.Put(r)
|
||||
req.SetBody(r)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
@ -65,8 +72,6 @@ func (r *limitedReader) Read(b []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(b)
|
||||
r.read += int64(n)
|
||||
if r.read > r.limit {
|
||||
s := http.StatusRequestEntityTooLarge
|
||||
r.context.String(s, http.StatusText(s))
|
||||
return n, echo.ErrStatusRequestEntityTooLarge
|
||||
}
|
||||
return
|
||||
|
@ -1,12 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"github.com/labstack/echo/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -14,7 +13,8 @@ import (
|
||||
|
||||
func TestBodyLimit(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!")))
|
||||
hw := []byte("Hello, World!")
|
||||
req := test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
|
||||
rec := test.NewResponseRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := func(c echo.Context) error {
|
||||
@ -25,15 +25,29 @@ func TestBodyLimit(t *testing.T) {
|
||||
return c.String(http.StatusOK, string(body))
|
||||
}
|
||||
|
||||
// Within limit
|
||||
BodyLimit("2M")(h)(c)
|
||||
assert.Equal(t, http.StatusOK, rec.Status())
|
||||
assert.Equal(t, "Hello, World!", rec.Body.String())
|
||||
// Based on content length (within limit)
|
||||
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
|
||||
assert.Equal(t, http.StatusOK, rec.Status())
|
||||
assert.Equal(t, hw, rec.Body.Bytes())
|
||||
}
|
||||
|
||||
// Overlimit
|
||||
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!")))
|
||||
// Based on content read (overlimit)
|
||||
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||
|
||||
// Based on content read (within limit)
|
||||
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
|
||||
rec = test.NewResponseRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
BodyLimit("2B")(h)(c)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Status())
|
||||
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
|
||||
assert.Equal(t, http.StatusOK, rec.Status())
|
||||
assert.Equal(t, "Hello, World!", rec.Body.String())
|
||||
}
|
||||
|
||||
// Based on content read (overlimit)
|
||||
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
|
||||
rec = test.NewResponseRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||
}
|
||||
|
@ -36,9 +36,13 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
|
||||
if qs != "" {
|
||||
uri += "?" + qs
|
||||
}
|
||||
|
||||
// Redirect
|
||||
if config.RedirectCode != 0 {
|
||||
return c.Redirect(config.RedirectCode, uri)
|
||||
}
|
||||
|
||||
// Forward
|
||||
req.SetURI(uri)
|
||||
url.SetPath(path)
|
||||
}
|
||||
@ -71,9 +75,13 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
|
||||
if qs != "" {
|
||||
uri += "?" + qs
|
||||
}
|
||||
|
||||
// Redirect
|
||||
if config.RedirectCode != 0 {
|
||||
return c.Redirect(config.RedirectCode, uri)
|
||||
}
|
||||
|
||||
// Forward
|
||||
req.SetURI(uri)
|
||||
url.SetPath(path)
|
||||
}
|
||||
|
@ -20,7 +20,8 @@ type (
|
||||
// Optional. Default value "index.html".
|
||||
Index string `json:"index"`
|
||||
|
||||
// Enable HTML5 mode by forwarding all not-found routes to root.
|
||||
// Enable HTML5 mode by forwarding all not-found requests to root so that
|
||||
// SPA (single-page application) can handle the routing.
|
||||
HTML5 bool `json:"html5"`
|
||||
|
||||
// Enable directory browsing.
|
||||
|
Loading…
Reference in New Issue
Block a user