1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-26 03:20:08 +02:00

Enhanced body-limit middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-05-23 11:23:15 -07:00
parent 96898d5b9b
commit 08a173b476
4 changed files with 45 additions and 17 deletions

View File

@ -3,7 +3,6 @@ package middleware
import ( import (
"fmt" "fmt"
"io" "io"
"net/http"
"sync" "sync"
"github.com/labstack/echo" "github.com/labstack/echo"
@ -31,8 +30,8 @@ type (
// //
// BodyLimit middleware sets the maximum allowed size for a request body, if the // BodyLimit middleware sets the maximum allowed size for a request body, if the
// size exceeds the configured limit, it sends "413 - Request Entity Too Large" // size exceeds the configured limit, it sends "413 - Request Entity Too Large"
// response. The body limit is determined based on the actually read and not `Content-Length` // response. The body limit is determined based on both `Content-Length` request
// request header, which makes it super secure. // header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, // Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
// G, T or P. // G, T or P.
func BodyLimit(limit string) echo.MiddlewareFunc { func BodyLimit(limit string) echo.MiddlewareFunc {
@ -52,10 +51,18 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
req := c.Request() req := c.Request()
// Based on content length
if req.ContentLength() > config.limit {
return echo.ErrStatusRequestEntityTooLarge
}
// Based on content read
r := pool.Get().(*limitedReader) r := pool.Get().(*limitedReader)
r.Reset(req.Body(), c) r.Reset(req.Body(), c)
defer pool.Put(r) defer pool.Put(r)
req.SetBody(r) req.SetBody(r)
return next(c) return next(c)
} }
} }
@ -65,8 +72,6 @@ func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b) n, err = r.reader.Read(b)
r.read += int64(n) r.read += int64(n)
if r.read > r.limit { if r.read > r.limit {
s := http.StatusRequestEntityTooLarge
r.context.String(s, http.StatusText(s))
return n, echo.ErrStatusRequestEntityTooLarge return n, echo.ErrStatusRequestEntityTooLarge
} }
return return

View File

@ -1,12 +1,11 @@
package middleware package middleware
import ( import (
"bytes"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"testing" "testing"
"bytes"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test" "github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -14,7 +13,8 @@ import (
func TestBodyLimit(t *testing.T) { func TestBodyLimit(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!"))) hw := []byte("Hello, World!")
req := test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := func(c echo.Context) error { h := func(c echo.Context) error {
@ -25,15 +25,29 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body)) return c.String(http.StatusOK, string(body))
} }
// Within limit // Based on content length (within limit)
BodyLimit("2M")(h)(c) if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, "Hello, World!", rec.Body.String()) assert.Equal(t, hw, rec.Body.Bytes())
}
// Overlimit // Based on content read (overlimit)
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!"))) he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
BodyLimit("2B")(h)(c) if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Status()) assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, "Hello, World!", rec.Body.String())
}
// Based on content read (overlimit)
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
} }

View File

@ -36,9 +36,13 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
if qs != "" { if qs != "" {
uri += "?" + qs uri += "?" + qs
} }
// Redirect
if config.RedirectCode != 0 { if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri) return c.Redirect(config.RedirectCode, uri)
} }
// Forward
req.SetURI(uri) req.SetURI(uri)
url.SetPath(path) url.SetPath(path)
} }
@ -71,9 +75,13 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
if qs != "" { if qs != "" {
uri += "?" + qs uri += "?" + qs
} }
// Redirect
if config.RedirectCode != 0 { if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri) return c.Redirect(config.RedirectCode, uri)
} }
// Forward
req.SetURI(uri) req.SetURI(uri)
url.SetPath(path) url.SetPath(path)
} }

View File

@ -20,7 +20,8 @@ type (
// Optional. Default value "index.html". // Optional. Default value "index.html".
Index string `json:"index"` Index string `json:"index"`
// Enable HTML5 mode by forwarding all not-found routes to root. // Enable HTML5 mode by forwarding all not-found requests to root so that
// SPA (single-page application) can handle the routing.
HTML5 bool `json:"html5"` HTML5 bool `json:"html5"`
// Enable directory browsing. // Enable directory browsing.