1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00

Enhanced body-limit middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-05-23 11:23:15 -07:00
parent 96898d5b9b
commit 08a173b476
4 changed files with 45 additions and 17 deletions

View File

@ -3,7 +3,6 @@ package middleware
import (
"fmt"
"io"
"net/http"
"sync"
"github.com/labstack/echo"
@ -31,8 +30,8 @@ type (
//
// BodyLimit middleware sets the maximum allowed size for a request body, if the
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
// response. The body limit is determined based on the actually read and not `Content-Length`
// request header, which makes it super secure.
// response. The body limit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
// G, T or P.
func BodyLimit(limit string) echo.MiddlewareFunc {
@ -52,10 +51,18 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
// Based on content length
if req.ContentLength() > config.limit {
return echo.ErrStatusRequestEntityTooLarge
}
// Based on content read
r := pool.Get().(*limitedReader)
r.Reset(req.Body(), c)
defer pool.Put(r)
req.SetBody(r)
return next(c)
}
}
@ -65,8 +72,6 @@ func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += int64(n)
if r.read > r.limit {
s := http.StatusRequestEntityTooLarge
r.context.String(s, http.StatusText(s))
return n, echo.ErrStatusRequestEntityTooLarge
}
return

View File

@ -1,12 +1,11 @@
package middleware
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"bytes"
"github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert"
@ -14,7 +13,8 @@ import (
func TestBodyLimit(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!")))
hw := []byte("Hello, World!")
req := test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec := test.NewResponseRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
@ -25,15 +25,29 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body))
}
// Within limit
BodyLimit("2M")(h)(c)
// Based on content length (within limit)
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, "Hello, World!", rec.Body.String())
assert.Equal(t, hw, rec.Body.Bytes())
}
// Overlimit
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!")))
// Based on content read (overlimit)
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec)
BodyLimit("2B")(h)(c)
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Status())
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, "Hello, World!", rec.Body.String())
}
// Based on content read (overlimit)
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}

View File

@ -36,9 +36,13 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
if qs != "" {
uri += "?" + qs
}
// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
}
// Forward
req.SetURI(uri)
url.SetPath(path)
}
@ -71,9 +75,13 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
if qs != "" {
uri += "?" + qs
}
// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
}
// Forward
req.SetURI(uri)
url.SetPath(path)
}

View File

@ -20,7 +20,8 @@ type (
// Optional. Default value "index.html".
Index string `json:"index"`
// Enable HTML5 mode by forwarding all not-found routes to root.
// Enable HTML5 mode by forwarding all not-found requests to root so that
// SPA (single-page application) can handle the routing.
HTML5 bool `json:"html5"`
// Enable directory browsing.