1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-01 22:09:21 +02:00

Allow ResponseWriters to unwrap writers when flushing/hijacking (#2595)

* Allow ResponseWriters to unwrap writers when flushing/hijacking
This commit is contained in:
Martti T 2024-03-09 10:50:47 +02:00 committed by GitHub
parent 3e04e3e2f2
commit bc1e1904f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 289 additions and 8 deletions

View File

@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
}
func (w *bodyDumpResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
err := responseControllerFlush(w.ResponseWriter)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}
func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return responseControllerHijack(w.ResponseWriter)
}
func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) {
}
})
}
func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
}
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
bdrw.Flush()
})
}
func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu,
}
bdrw.Flush()
assert.Equal(t, 1, trwu.unwrapCalled)
}
func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: trwu,
}
result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}
func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}
func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}

View File

@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() {
}
w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
_ = responseControllerFlush(w.ResponseWriter)
}
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return responseControllerHijack(w.ResponseWriter)
}
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {

View File

@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) {
}
}
func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: trwu,
}
result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}
func TestGzipResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}
func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}
func BenchmarkGzip(b *testing.B) {
e := echo.New()

View File

@ -1,7 +1,10 @@
package middleware
import (
"bufio"
"errors"
"github.com/stretchr/testify/assert"
"net"
"net/http"
"net/http/httptest"
"regexp"
@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) {
})
}
}
type testResponseWriterNoFlushHijack struct {
}
func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
}
func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
return 0, nil
}
func (w *testResponseWriterNoFlushHijack) Header() http.Header {
return nil
}
type testResponseWriterUnwrapper struct {
unwrapCalled int
rw http.ResponseWriter
}
func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
}
func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
return 0, nil
}
func (w *testResponseWriterUnwrapper) Header() http.Header {
return nil
}
func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
w.unwrapCalled++
return w.rw
}
type testResponseWriterUnwrapperHijack struct {
testResponseWriterUnwrapper
}
func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("can hijack")
}

View File

@ -0,0 +1,41 @@
//go:build !go1.20
package middleware
import (
"bufio"
"fmt"
"net"
"net/http"
)
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

View File

@ -0,0 +1,17 @@
//go:build go1.20
package middleware
import (
"bufio"
"net"
"net/http"
)
func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}

View File

@ -2,6 +2,7 @@ package echo
import (
"bufio"
"errors"
"net"
"net/http"
)
@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
// buffered data to the client.
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
func (r *Response) Flush() {
r.Writer.(http.Flusher).Flush()
err := responseControllerFlush(r.Writer)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}
// Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection.
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.Writer.(http.Hijacker).Hijack()
return responseControllerHijack(r.Writer)
}
// Unwrap returns the original http.ResponseWriter.

View File

@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) {
assert.True(t, rec.Flushed)
}
type testResponseWriter struct {
}
func (w *testResponseWriter) WriteHeader(statusCode int) {
}
func (w *testResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w *testResponseWriter) Header() http.Header {
return nil
}
func TestResponse_FlushPanics(t *testing.T) {
e := New()
rw := new(testResponseWriter)
res := &Response{echo: e, Writer: rw}
// we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
res.Flush()
})
}
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()

View File

@ -0,0 +1,41 @@
//go:build !go1.20
package echo
import (
"bufio"
"fmt"
"net"
"net/http"
)
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

View File

@ -0,0 +1,17 @@
//go:build go1.20
package echo
import (
"bufio"
"net"
"net/http"
)
func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}