diff --git a/echo.go b/echo.go index 29e3e139..f229a7ef 100644 --- a/echo.go +++ b/echo.go @@ -193,12 +193,13 @@ var ( // Errors var ( - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - ErrNotFound = NewHTTPError(http.StatusNotFound) - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) + ErrNotFound = NewHTTPError(http.StatusNotFound) + ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) + ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) + ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") ) // Error handlers diff --git a/engine/engine.go b/engine/engine.go index 021f0ed2..10c90e6e 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -69,6 +69,9 @@ type ( // Body returns request's body. Body() io.Reader + // Body sets request's body. + SetBody(io.Reader) + // FormValue returns the form field value for the provided name. FormValue(string) string diff --git a/engine/fasthttp/request.go b/engine/fasthttp/request.go index de84180b..892958ca 100644 --- a/engine/fasthttp/request.go +++ b/engine/fasthttp/request.go @@ -7,6 +7,7 @@ import ( "io" "mime/multipart" + "github.com/labstack/echo" "github.com/labstack/echo/engine" "github.com/labstack/gommon/log" "github.com/valyala/fasthttp" @@ -97,6 +98,11 @@ func (r *Request) Body() io.Reader { return bytes.NewBuffer(r.PostBody()) } +// SetBody implements `engine.Request#SetBody` function. +func (r *Request) SetBody(reader io.Reader) { + r.SetBodyStream(reader, r.header.Get(echo.HeaderContentType)) +} + // FormValue implements `engine.Request#FormValue` function. func (r *Request) FormValue(name string) string { return string(r.RequestCtx.FormValue(name)) diff --git a/engine/standard/request.go b/engine/standard/request.go index 5ac85dbc..9d311a84 100644 --- a/engine/standard/request.go +++ b/engine/standard/request.go @@ -2,6 +2,7 @@ package standard import ( "io" + "io/ioutil" "mime/multipart" "net/http" @@ -109,6 +110,11 @@ func (r *Request) Body() io.Reader { return r.Request.Body } +// SetBody implements `engine.Request#SetBody` function. +func (r *Request) SetBody(reader io.Reader) { + r.Request.Body = ioutil.NopCloser(reader) +} + // FormValue implements `engine.Request#FormValue` function. func (r *Request) FormValue(name string) string { return r.Request.FormValue(name) diff --git a/glide.lock b/glide.lock index a60e5336..6e6f035e 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ hash: 21820434709470e49c64df0f854d3352088ca664d193e29bc6cd434518c27a7c -updated: 2016-04-24T11:03:22.86754619-07:00 +updated: 2016-04-30T19:53:13.742557404-07:00 imports: - name: github.com/dgrijalva/jwt-go version: a2c85815a77d0f951e33ba4db5ae93629a1530af @@ -14,7 +14,7 @@ imports: - name: github.com/klauspost/crc32 version: 19b0b332c9e4516a6370a0456e6182c3b5036720 - name: github.com/labstack/gommon - version: 4fae226dd67b1100622ab213e798e5ee4c5d4230 + version: 2e62be24dbb1ceb226554aaccfe5a89ec71043b3 subpackages: - color - log @@ -27,16 +27,16 @@ imports: subpackages: - assert - name: github.com/valyala/fasthttp - version: 229698876475fef591fb04fabba5479fc3182291 + version: 3509bd8a7d1d9e9a9a4e8594e78af85fc01e2fba - name: github.com/valyala/fasttemplate version: 3b874956e03f1636d171bda64b130f9135f42cff - name: golang.org/x/net - version: b797637b7aeeed133049c7281bfa31dcc9ca42d6 + version: 1aafd77e1e7f6849ad16a7bdeb65e3589a10b2bb subpackages: - context - websocket - name: golang.org/x/sys - version: f64b50fbea64174967a8882830d621a18ee1548e + version: b776ec39b3e54652e09028aaaaac9757f4f8211a subpackages: - unix devImports: [] diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index ac8c07e6..12f102e3 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -21,20 +21,13 @@ const ( basic = "Basic" ) -var ( - // DefaultBasicAuthConfig is the default basic auth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{} -) - // BasicAuth returns an HTTP basic auth middleware. // // For valid credentials it calls the next handler. // For invalid credentials, it sends "401 - Unauthorized" response. // For empty or invalid `Authorization` header, it sends "400 - Bad Request" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.Validator = fn - return BasicAuthWithConfig(c) + return BasicAuthWithConfig(BasicAuthConfig{fn}) } // BasicAuthWithConfig returns an HTTP basic auth middleware from config. diff --git a/middleware/body_limit.go b/middleware/body_limit.go new file mode 100644 index 00000000..8e517b47 --- /dev/null +++ b/middleware/body_limit.go @@ -0,0 +1,84 @@ +package middleware + +import ( + "fmt" + "io" + "net/http" + "sync" + + "github.com/labstack/echo" + "github.com/labstack/gommon/bytes" +) + +type ( + // BodyLimitConfig defines the config for body limit middleware. + BodyLimitConfig struct { + Limit string `json:"limit"` + limit int + } + + limitedReader struct { + BodyLimitConfig + reader io.Reader + read int + context echo.Context + } +) + +// BodyLimit returns a body limit middleware. +// +// BodyLimit middleware sets the maximum allowed size for a request body, if the +// size exceeds the configured limit, it sends "413 - Request Entity Too Large" +// response. The body limit is determined based on the actually read and not `Content-Length` +// request header, which makes it super secure. +// Limit can be specifed as `4x` or `4xB`, where x is one of the multple from K, M, +// G, T or P. +func BodyLimit(limit string) echo.MiddlewareFunc { + return BodyLimitWithConfig(BodyLimitConfig{Limit: limit}) +} + +// BodyLimitWithConfig returns a body limit middleware from config. +// See `BodyLimit()`. +func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { + limit, err := bytes.Parse(config.Limit) + if err != nil { + panic(fmt.Errorf("invalid body-limit=%s", config.Limit)) + } + config.limit = limit + pool := limitedReaderPool(config) + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + req := c.Request() + r := pool.Get().(*limitedReader) + r.Reset(req.Body(), c) + defer pool.Put(r) + req.SetBody(r) + return next(c) + } + } +} + +func (r *limitedReader) Read(b []byte) (n int, err error) { + n, err = r.reader.Read(b) + r.read += n + if r.read > r.limit { + s := http.StatusRequestEntityTooLarge + r.context.String(s, http.StatusText(s)) + return n, echo.ErrStatusRequestEntityTooLarge + } + return +} + +func (r *limitedReader) Reset(reader io.Reader, context echo.Context) { + r.reader = reader + r.context = context +} + +func limitedReaderPool(c BodyLimitConfig) sync.Pool { + return sync.Pool{ + New: func() interface{} { + return &limitedReader{BodyLimitConfig: c} + }, + } +} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go new file mode 100644 index 00000000..0538735a --- /dev/null +++ b/middleware/body_limit_test.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "io/ioutil" + "net/http" + "testing" + + "bytes" + + "github.com/labstack/echo" + "github.com/labstack/echo/test" + "github.com/stretchr/testify/assert" +) + +func TestBodyLimit(t *testing.T) { + e := echo.New() + req := test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("Hello, World!"))) + rec := test.NewResponseRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, _ := ioutil.ReadAll(c.Request().Body()) + return c.String(http.StatusOK, string(body)) + } + + // Within limit + BodyLimit("2M")(h)(c) + assert.Equal(t, http.StatusOK, rec.Status()) + assert.Equal(t, "Hello, World!", rec.Body.String()) + + // Overlimit + // BodyLimit("2B")(h)(c) + // assert.Equal(t, "Hello, World!", rec.Body.String()) +} diff --git a/test/request.go b/test/request.go index e4067388..8240509c 100644 --- a/test/request.go +++ b/test/request.go @@ -2,6 +2,7 @@ package test import ( "io" + "io/ioutil" "mime/multipart" "net/http" @@ -93,6 +94,10 @@ func (r *Request) Body() io.Reader { return r.request.Body } +func (r *Request) SetBody(reader io.Reader) { + r.request.Body = ioutil.NopCloser(reader) +} + func (r *Request) FormValue(name string) string { return r.request.FormValue(name) }