mirror of
https://github.com/labstack/echo.git
synced 2024-11-28 08:38:39 +02:00
Timeout middleware implementation for go1.13+ (#1743)
Co-authored-by: Ilija Matoski <imatoski@schubergphilis.com>
This commit is contained in:
parent
02ed3f3126
commit
67263b5e45
81
middleware/timeout.go
Normal file
81
middleware/timeout.go
Normal file
@ -0,0 +1,81 @@
|
||||
// +build go1.13
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/labstack/echo/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
// TimeoutConfig defines the config for Timeout middleware.
|
||||
TimeoutConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// ErrorHandler defines a function which is executed for a timeout
|
||||
// It can be used to define a custom timeout error
|
||||
ErrorHandler TimeoutErrorHandlerWithContext
|
||||
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can
|
||||
// handle the error as we see fit
|
||||
TimeoutErrorHandlerWithContext func(error, echo.Context) error
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultTimeoutConfig is the default Timeout middleware config.
|
||||
DefaultTimeoutConfig = TimeoutConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Timeout: 0,
|
||||
ErrorHandler: nil,
|
||||
}
|
||||
)
|
||||
|
||||
// Timeout returns a middleware which recovers from panics anywhere in the chain
|
||||
// and handles the control to the centralized HTTPErrorHandler.
|
||||
func Timeout() echo.MiddlewareFunc {
|
||||
return TimeoutWithConfig(DefaultTimeoutConfig)
|
||||
}
|
||||
|
||||
// TimeoutWithConfig returns a Timeout middleware with config.
|
||||
// See: `Timeout()`.
|
||||
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultTimeoutConfig.Skipper
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) || config.Timeout == 0 {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// this does a deep clone of the context, wondering if there is a better way to do this?
|
||||
c.SetRequest(c.Request().Clone(ctx))
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
// This goroutine will keep running even if this middleware times out and
|
||||
// will be stopped when ctx.Done() is called down the next(c) call chain
|
||||
done <- next(c)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(ctx.Err(), c)
|
||||
}
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
177
middleware/timeout_test.go
Normal file
177
middleware/timeout_test.go
Normal file
@ -0,0 +1,177 @@
|
||||
// +build go1.13
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTimeoutSkipper(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := TimeoutWithConfig(TimeoutConfig{
|
||||
Skipper: func(context echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(func(c echo.Context) error {
|
||||
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
|
||||
return nil
|
||||
})(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTimeoutWithTimeout0(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := TimeoutWithConfig(TimeoutConfig{
|
||||
Timeout: 0,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(func(c echo.Context) error {
|
||||
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
|
||||
return nil
|
||||
})(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTimeoutIsCancelable(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := TimeoutWithConfig(TimeoutConfig{
|
||||
Timeout: time.Minute,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(func(c echo.Context) error {
|
||||
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
|
||||
return nil
|
||||
})(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTimeoutErrorOutInHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := Timeout()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(func(c echo.Context) error {
|
||||
return errors.New("err")
|
||||
})(c)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := TimeoutWithConfig(TimeoutConfig{
|
||||
Timeout: time.Second,
|
||||
ErrorHandler: func(err error, e echo.Context) error {
|
||||
assert.EqualError(t, err, context.DeadlineExceeded.Error())
|
||||
return errors.New("err")
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(func(c echo.Context) error {
|
||||
time.Sleep(time.Minute)
|
||||
return nil
|
||||
})(c)
|
||||
|
||||
assert.EqualError(t, err, errors.New("err").Error())
|
||||
}
|
||||
|
||||
func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := TimeoutWithConfig(TimeoutConfig{
|
||||
Timeout: time.Second,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(func(c echo.Context) error {
|
||||
time.Sleep(time.Minute)
|
||||
return nil
|
||||
})(c)
|
||||
|
||||
assert.EqualError(t, err, context.DeadlineExceeded.Error())
|
||||
}
|
||||
|
||||
func TestTimeoutTestRequestClone(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
|
||||
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
m := TimeoutWithConfig(TimeoutConfig{
|
||||
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
|
||||
Timeout: time.Second,
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(func(c echo.Context) error {
|
||||
// Cookie test
|
||||
cookie, err := c.Request().Cookie("cookie")
|
||||
if assert.NoError(t, err) {
|
||||
assert.EqualValues(t, "cookie", cookie.Name)
|
||||
assert.EqualValues(t, "value", cookie.Value)
|
||||
}
|
||||
|
||||
// Form values
|
||||
if assert.NoError(t, c.Request().ParseForm()) {
|
||||
assert.EqualValues(t, "value", c.Request().FormValue("form"))
|
||||
}
|
||||
|
||||
// Query string
|
||||
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
|
||||
return nil
|
||||
})(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user