diff --git a/middleware/body_dump.go b/middleware/body_dump.go index fa7891b1..946ffc58 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -3,6 +3,7 @@ package middleware import ( "bufio" "bytes" + "errors" "io" "net" "net/http" @@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) { } func (w *bodyDumpResponseWriter) Flush() { - w.ResponseWriter.(http.Flusher).Flush() + err := responseControllerFlush(w.ResponseWriter) + if err != nil && errors.Is(err, http.ErrNotSupported) { + panic(errors.New("response writer flushing is not supported")) + } } func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return responseControllerHijack(w.ResponseWriter) +} + +func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter } diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index de1de335..a68930b4 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) { } }) } + +func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { + bdrw := bodyDumpResponseWriter{ + ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush + } + + assert.PanicsWithError(t, "response writer flushing is not supported", func() { + bdrw.Flush() + }) +} + +func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, + } + + bdrw.Flush() + assert.Equal(t, 1, trwu.unwrapCalled) +} + +func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: trwu, + } + + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} diff --git a/middleware/compress.go b/middleware/compress.go index 3e9bd320..c77062d9 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() { } w.Writer.(*gzip.Writer).Flush() - if flusher, ok := w.ResponseWriter.(http.Flusher); ok { - flusher.Flush() - } + _ = responseControllerFlush(w.ResponseWriter) +} + +func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter } func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return responseControllerHijack(w.ResponseWriter) } func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 0ed16c81..6c5ce412 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) { } } +func TestGzipResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: trwu, + } + + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestGzipResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestGzipResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} + func BenchmarkGzip(b *testing.B) { e := echo.New() diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 44f44142..990568d5 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,7 +1,10 @@ package middleware import ( + "bufio" + "errors" "github.com/stretchr/testify/assert" + "net" "net/http" "net/http/httptest" "regexp" @@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) { }) } } + +type testResponseWriterNoFlushHijack struct { +} + +func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { +} + +func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriterNoFlushHijack) Header() http.Header { + return nil +} + +type testResponseWriterUnwrapper struct { + unwrapCalled int + rw http.ResponseWriter +} + +func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { +} + +func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriterUnwrapper) Header() http.Header { + return nil +} + +func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { + w.unwrapCalled++ + return w.rw +} + +type testResponseWriterUnwrapperHijack struct { + testResponseWriterUnwrapper +} + +func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("can hijack") +} diff --git a/middleware/responsecontroller_1.19.go b/middleware/responsecontroller_1.19.go new file mode 100644 index 00000000..104784fd --- /dev/null +++ b/middleware/responsecontroller_1.19.go @@ -0,0 +1,41 @@ +//go:build !go1.20 + +package middleware + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerFlush(rw http.ResponseWriter) error { + for { + switch t := rw.(type) { + case interface{ FlushError() error }: + return t.FlushError() + case http.Flusher: + t.Flush() + return nil + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return fmt.Errorf("%w", http.ErrNotSupported) + } + } +} + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t.Hijack() + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return nil, nil, fmt.Errorf("%w", http.ErrNotSupported) + } + } +} diff --git a/middleware/responsecontroller_1.20.go b/middleware/responsecontroller_1.20.go new file mode 100644 index 00000000..02a0cb75 --- /dev/null +++ b/middleware/responsecontroller_1.20.go @@ -0,0 +1,17 @@ +//go:build go1.20 + +package middleware + +import ( + "bufio" + "net" + "net/http" +) + +func responseControllerFlush(rw http.ResponseWriter) error { + return http.NewResponseController(rw).Flush() +} + +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(rw).Hijack() +} diff --git a/response.go b/response.go index d9c9aa6e..117881cc 100644 --- a/response.go +++ b/response.go @@ -2,6 +2,7 @@ package echo import ( "bufio" + "errors" "net" "net/http" ) @@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) { // buffered data to the client. // See [http.Flusher](https://golang.org/pkg/net/http/#Flusher) func (r *Response) Flush() { - r.Writer.(http.Flusher).Flush() + err := responseControllerFlush(r.Writer) + if err != nil && errors.Is(err, http.ErrNotSupported) { + panic(errors.New("response writer flushing is not supported")) + } } // Hijack implements the http.Hijacker interface to allow an HTTP handler to // take over the connection. // See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker) func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return r.Writer.(http.Hijacker).Hijack() + return responseControllerHijack(r.Writer) } // Unwrap returns the original http.ResponseWriter. diff --git a/response_test.go b/response_test.go index e4fd636d..e457a019 100644 --- a/response_test.go +++ b/response_test.go @@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) { assert.True(t, rec.Flushed) } +type testResponseWriter struct { +} + +func (w *testResponseWriter) WriteHeader(statusCode int) { +} + +func (w *testResponseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriter) Header() http.Header { + return nil +} + +func TestResponse_FlushPanics(t *testing.T) { + e := New() + rw := new(testResponseWriter) + res := &Response{echo: e, Writer: rw} + + // we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic + assert.PanicsWithError(t, "response writer flushing is not supported", func() { + res.Flush() + }) +} + func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) { e := New() rec := httptest.NewRecorder() diff --git a/responsecontroller_1.19.go b/responsecontroller_1.19.go new file mode 100644 index 00000000..75c6e3e5 --- /dev/null +++ b/responsecontroller_1.19.go @@ -0,0 +1,41 @@ +//go:build !go1.20 + +package echo + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerFlush(rw http.ResponseWriter) error { + for { + switch t := rw.(type) { + case interface{ FlushError() error }: + return t.FlushError() + case http.Flusher: + t.Flush() + return nil + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return fmt.Errorf("%w", http.ErrNotSupported) + } + } +} + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t.Hijack() + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return nil, nil, fmt.Errorf("%w", http.ErrNotSupported) + } + } +} diff --git a/responsecontroller_1.20.go b/responsecontroller_1.20.go new file mode 100644 index 00000000..fa2fe8b3 --- /dev/null +++ b/responsecontroller_1.20.go @@ -0,0 +1,17 @@ +//go:build go1.20 + +package echo + +import ( + "bufio" + "net" + "net/http" +) + +func responseControllerFlush(rw http.ResponseWriter) error { + return http.NewResponseController(rw).Flush() +} + +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(rw).Hijack() +}