1
0
mirror of https://github.com/labstack/echo.git synced 2024-11-24 08:22:21 +02:00

Timeout middleware write race

This commit is contained in:
toimtoimtoim 2022-03-13 16:00:02 +02:00 committed by Martti T
parent 01d7d01bbc
commit 1919cf4491
2 changed files with 89 additions and 46 deletions

View File

@ -2,10 +2,10 @@ package middleware
import ( import (
"context" "context"
"net/http"
"time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http"
"sync"
"time"
) )
// --------------------------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------------------------
@ -55,29 +55,27 @@ import (
// }) // })
// //
type ( // TimeoutConfig defines the config for Timeout middleware.
// TimeoutConfig defines the config for Timeout middleware. type TimeoutConfig struct {
TimeoutConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
// It can be used to define a custom timeout error message // It can be used to define a custom timeout error message
ErrorMessage string ErrorMessage string
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
// request timeouted and we already had sent the error code (503) and message response to the client. // request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context) OnTimeoutRouteErrorHandler func(err error, c echo.Context)
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout // Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable // difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration Timeout time.Duration
} }
)
var ( var (
// DefaultTimeoutConfig is the default Timeout middleware config. // DefaultTimeoutConfig is the default Timeout middleware config.
@ -94,10 +92,17 @@ func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig) return TimeoutWithConfig(DefaultTimeoutConfig)
} }
// TimeoutWithConfig returns a Timeout middleware with config. // TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
// See: `Timeout()`.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}
// ToMiddleware converts Config to middleware or returns an error for invalid configuration
func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper config.Skipper = DefaultTimeoutConfig.Skipper
} }
@ -108,26 +113,29 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
errChan := make(chan error, 1)
handlerWrapper := echoHandlerFuncWrapper{ handlerWrapper := echoHandlerFuncWrapper{
writer: &ignorableWriter{ResponseWriter: c.Response().Writer},
ctx: c, ctx: c,
handler: next, handler: next,
errChan: make(chan error, 1), errChan: errChan,
errHandler: config.OnTimeoutRouteErrorHandler, errHandler: config.OnTimeoutRouteErrorHandler,
} }
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(c.Response().Writer, c.Request()) handler.ServeHTTP(handlerWrapper.writer, c.Request())
select { select {
case err := <-handlerWrapper.errChan: case err := <-errChan:
return err return err
default: default:
return nil return nil
} }
} }
} }, nil
} }
type echoHandlerFuncWrapper struct { type echoHandlerFuncWrapper struct {
writer *ignorableWriter
ctx echo.Context ctx echo.Context
handler echo.HandlerFunc handler echo.HandlerFunc
errHandler func(err error, c echo.Context) errHandler func(err error, c echo.Context)
@ -160,23 +168,53 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
} }
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
} }
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
// and should not anymore send additional headers/data
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
if err != nil { if err != nil {
// Error must be written into Writer created in `http.TimeoutHandler` so to get Response into `commited` state. // This is needed as `http.TimeoutHandler` will write status code by itself on error and after that our tries to write
// So call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send // status code will not work anymore as Echo.Response thinks it has been already "committed" and further writes
// status code by itself and after that our tries to write status code will not work anymore and/or create errors in // create errors in log about `superfluous response.WriteHeader call from`
// log about `superfluous response.WriteHeader call from` t.writer.Ignore(true)
t.ctx.Error(err) t.ctx.Response().Writer = originalWriter // make sure we restore writer before we signal original coroutine about the error
// we pass error from handler to middlewares up in handler chain to act on it if needed. But this means that // we pass error from handler to middlewares up in handler chain to act on it if needed.
// global error handler is probably be called twice as `t.ctx.Error` already does that.
// NB: later call of the global error handler or middlewares will not take any effect, as echo.Response will be
// already marked as `committed` because we called global error handler above.
t.ctx.Response().Writer = originalWriter // make sure we restore before we signal original coroutine about the error
t.errChan <- err t.errChan <- err
return return
} }
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
// and should not anymore send additional headers/data
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
t.ctx.Response().Writer = originalWriter t.ctx.Response().Writer = originalWriter
} }
// ignorableWriter is ResponseWriter implementations that allows us to mark writer to ignore further write calls. This
// is handy in cases when you do not have direct control of code being executed (3rd party middleware) but want to make
// sure that external code will not be able to write response to the client.
// Writer is coroutine safe for writes.
type ignorableWriter struct {
http.ResponseWriter
lock sync.Mutex
ignoreWrites bool
}
func (w *ignorableWriter) Ignore(ignore bool) {
w.lock.Lock()
w.ignoreWrites = ignore
w.lock.Unlock()
}
func (w *ignorableWriter) WriteHeader(code int) {
w.lock.Lock()
defer w.lock.Unlock()
if w.ignoreWrites {
return
}
w.ResponseWriter.WriteHeader(code)
}
func (w *ignorableWriter) Write(b []byte) (int, error) {
w.lock.Lock()
defer w.lock.Unlock()
if w.ignoreWrites {
return len(b), nil
}
return w.ResponseWriter.Write(b)
}

View File

@ -74,13 +74,18 @@ func TestTimeoutErrorOutInHandler(t *testing.T) {
e := echo.New() e := echo.New()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
rec.Code = 1 // we want to be sure that even 200 will not be sent
err := m(func(c echo.Context) error { err := m(func(c echo.Context) error {
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
// to handle returned error and this can be done only then handler has not yet committed (written status code)
// the response.
return echo.NewHTTPError(http.StatusTeapot, "err") return echo.NewHTTPError(http.StatusTeapot, "err")
})(c) })(c)
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, http.StatusTeapot, rec.Code) assert.EqualError(t, err, "code=418, message=err")
assert.Equal(t, "{\"message\":\"err\"}\n", rec.Body.String()) assert.Equal(t, 1, rec.Code)
assert.Equal(t, "", rec.Body.String())
} }
func TestTimeoutSuccessfulRequest(t *testing.T) { func TestTimeoutSuccessfulRequest(t *testing.T) {