package middleware import ( "context" "errors" "github.com/labstack/echo/v5" "time" ) // ContextTimeoutConfig defines the config for ContextTimeout middleware. type ContextTimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // ErrorHandler is a function when error aries in middeware execution. ErrorHandler func(c echo.Context, err error) error // Timeout configures a timeout for the middleware Timeout time.Duration } // ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client // when underlying method returns context.DeadlineExceeded error. func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout}) } // ContextTimeoutWithConfig returns a Timeout middleware with config. func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts Config to middleware. func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Timeout == 0 { return nil, errors.New("timeout must be set") } if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.ErrorHandler == nil { config.ErrorHandler = func(c echo.Context, err error) error { if err != nil && errors.Is(err, context.DeadlineExceeded) { return echo.ErrServiceUnavailable.WithInternal(err) } return err } } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) defer cancel() c.SetRequest(c.Request().WithContext(timeoutContext)) if err := next(c); err != nil { return config.ErrorHandler(c, err) } return nil } }, nil }