mirror of
https://github.com/labstack/echo.git
synced 2025-01-26 03:20:08 +02:00
Enhanced body-limit middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
96898d5b9b
commit
08a173b476
@ -3,7 +3,6 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
@ -31,8 +30,8 @@ type (
|
|||||||
//
|
//
|
||||||
// BodyLimit middleware sets the maximum allowed size for a request body, if the
|
// BodyLimit middleware sets the maximum allowed size for a request body, if the
|
||||||
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
|
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
|
||||||
// response. The body limit is determined based on the actually read and not `Content-Length`
|
// response. The body limit is determined based on both `Content-Length` request
|
||||||
// request header, which makes it super secure.
|
// header and actual content read, which makes it super secure.
|
||||||
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
|
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
|
||||||
// G, T or P.
|
// G, T or P.
|
||||||
func BodyLimit(limit string) echo.MiddlewareFunc {
|
func BodyLimit(limit string) echo.MiddlewareFunc {
|
||||||
@ -52,10 +51,18 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
|
|||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
req := c.Request()
|
req := c.Request()
|
||||||
|
|
||||||
|
// Based on content length
|
||||||
|
if req.ContentLength() > config.limit {
|
||||||
|
return echo.ErrStatusRequestEntityTooLarge
|
||||||
|
}
|
||||||
|
|
||||||
|
// Based on content read
|
||||||
r := pool.Get().(*limitedReader)
|
r := pool.Get().(*limitedReader)
|
||||||
r.Reset(req.Body(), c)
|
r.Reset(req.Body(), c)
|
||||||
defer pool.Put(r)
|
defer pool.Put(r)
|
||||||
req.SetBody(r)
|
req.SetBody(r)
|
||||||
|
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -65,8 +72,6 @@ func (r *limitedReader) Read(b []byte) (n int, err error) {
|
|||||||
n, err = r.reader.Read(b)
|
n, err = r.reader.Read(b)
|
||||||
r.read += int64(n)
|
r.read += int64(n)
|
||||||
if r.read > r.limit {
|
if r.read > r.limit {
|
||||||
s := http.StatusRequestEntityTooLarge
|
|
||||||
r.context.String(s, http.StatusText(s))
|
|
||||||
return n, echo.ErrStatusRequestEntityTooLarge
|
return n, echo.ErrStatusRequestEntityTooLarge
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
"github.com/labstack/echo/test"
|
"github.com/labstack/echo/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -14,7 +13,8 @@ import (
|
|||||||
|
|
||||||
func TestBodyLimit(t *testing.T) {
|
func TestBodyLimit(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
req := test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!")))
|
hw := []byte("Hello, World!")
|
||||||
|
req := test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
|
||||||
rec := test.NewResponseRecorder()
|
rec := test.NewResponseRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
h := func(c echo.Context) error {
|
h := func(c echo.Context) error {
|
||||||
@ -25,15 +25,29 @@ func TestBodyLimit(t *testing.T) {
|
|||||||
return c.String(http.StatusOK, string(body))
|
return c.String(http.StatusOK, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Within limit
|
// Based on content length (within limit)
|
||||||
BodyLimit("2M")(h)(c)
|
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
|
||||||
assert.Equal(t, http.StatusOK, rec.Status())
|
assert.Equal(t, http.StatusOK, rec.Status())
|
||||||
assert.Equal(t, "Hello, World!", rec.Body.String())
|
assert.Equal(t, hw, rec.Body.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
// Overlimit
|
// Based on content read (overlimit)
|
||||||
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!")))
|
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||||
|
|
||||||
|
// Based on content read (within limit)
|
||||||
|
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
|
||||||
rec = test.NewResponseRecorder()
|
rec = test.NewResponseRecorder()
|
||||||
c = e.NewContext(req, rec)
|
c = e.NewContext(req, rec)
|
||||||
BodyLimit("2B")(h)(c)
|
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
|
||||||
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Status())
|
assert.Equal(t, http.StatusOK, rec.Status())
|
||||||
|
assert.Equal(t, "Hello, World!", rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Based on content read (overlimit)
|
||||||
|
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
|
||||||
|
rec = test.NewResponseRecorder()
|
||||||
|
c = e.NewContext(req, rec)
|
||||||
|
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||||
}
|
}
|
||||||
|
@ -36,9 +36,13 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
|
|||||||
if qs != "" {
|
if qs != "" {
|
||||||
uri += "?" + qs
|
uri += "?" + qs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect
|
||||||
if config.RedirectCode != 0 {
|
if config.RedirectCode != 0 {
|
||||||
return c.Redirect(config.RedirectCode, uri)
|
return c.Redirect(config.RedirectCode, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward
|
||||||
req.SetURI(uri)
|
req.SetURI(uri)
|
||||||
url.SetPath(path)
|
url.SetPath(path)
|
||||||
}
|
}
|
||||||
@ -71,9 +75,13 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
|
|||||||
if qs != "" {
|
if qs != "" {
|
||||||
uri += "?" + qs
|
uri += "?" + qs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect
|
||||||
if config.RedirectCode != 0 {
|
if config.RedirectCode != 0 {
|
||||||
return c.Redirect(config.RedirectCode, uri)
|
return c.Redirect(config.RedirectCode, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward
|
||||||
req.SetURI(uri)
|
req.SetURI(uri)
|
||||||
url.SetPath(path)
|
url.SetPath(path)
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,8 @@ type (
|
|||||||
// Optional. Default value "index.html".
|
// Optional. Default value "index.html".
|
||||||
Index string `json:"index"`
|
Index string `json:"index"`
|
||||||
|
|
||||||
// Enable HTML5 mode by forwarding all not-found routes to root.
|
// Enable HTML5 mode by forwarding all not-found requests to root so that
|
||||||
|
// SPA (single-page application) can handle the routing.
|
||||||
HTML5 bool `json:"html5"`
|
HTML5 bool `json:"html5"`
|
||||||
|
|
||||||
// Enable directory browsing.
|
// Enable directory browsing.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user