mirror of
https://github.com/labstack/echo.git
synced 2025-01-01 22:09:21 +02:00
Allow ResponseWriters to unwrap writers when flushing/hijacking (#2595)
* Allow ResponseWriters to unwrap writers when flushing/hijacking
This commit is contained in:
parent
3e04e3e2f2
commit
bc1e1904f1
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *bodyDumpResponseWriter) Flush() {
|
||||
w.ResponseWriter.(http.Flusher).Flush()
|
||||
err := responseControllerFlush(w.ResponseWriter)
|
||||
if err != nil && errors.Is(err, http.ErrNotSupported) {
|
||||
panic(errors.New("response writer flushing is not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.ResponseWriter.(http.Hijacker).Hijack()
|
||||
return responseControllerHijack(w.ResponseWriter)
|
||||
}
|
||||
|
||||
func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
|
||||
bdrw := bodyDumpResponseWriter{
|
||||
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
|
||||
}
|
||||
|
||||
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
|
||||
bdrw.Flush()
|
||||
})
|
||||
}
|
||||
|
||||
func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
|
||||
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
|
||||
bdrw := bodyDumpResponseWriter{
|
||||
ResponseWriter: &trwu,
|
||||
}
|
||||
|
||||
bdrw.Flush()
|
||||
assert.Equal(t, 1, trwu.unwrapCalled)
|
||||
}
|
||||
|
||||
func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
|
||||
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
|
||||
bdrw := bodyDumpResponseWriter{
|
||||
ResponseWriter: trwu,
|
||||
}
|
||||
|
||||
result := bdrw.Unwrap()
|
||||
assert.Equal(t, trwu, result)
|
||||
}
|
||||
|
||||
func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
|
||||
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
|
||||
bdrw := bodyDumpResponseWriter{
|
||||
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
|
||||
}
|
||||
|
||||
_, _, err := bdrw.Hijack()
|
||||
assert.EqualError(t, err, "can hijack")
|
||||
}
|
||||
|
||||
func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
|
||||
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
|
||||
bdrw := bodyDumpResponseWriter{
|
||||
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
|
||||
}
|
||||
|
||||
_, _, err := bdrw.Hijack()
|
||||
assert.EqualError(t, err, "feature not supported")
|
||||
}
|
||||
|
@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() {
|
||||
}
|
||||
|
||||
w.Writer.(*gzip.Writer).Flush()
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
_ = responseControllerFlush(w.ResponseWriter)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.ResponseWriter.(http.Hijacker).Hijack()
|
||||
return responseControllerHijack(w.ResponseWriter)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||
|
@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
|
||||
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
|
||||
bdrw := gzipResponseWriter{
|
||||
ResponseWriter: trwu,
|
||||
}
|
||||
|
||||
result := bdrw.Unwrap()
|
||||
assert.Equal(t, trwu, result)
|
||||
}
|
||||
|
||||
func TestGzipResponseWriter_CanHijack(t *testing.T) {
|
||||
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
|
||||
bdrw := gzipResponseWriter{
|
||||
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
|
||||
}
|
||||
|
||||
_, _, err := bdrw.Hijack()
|
||||
assert.EqualError(t, err, "can hijack")
|
||||
}
|
||||
|
||||
func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
|
||||
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
|
||||
bdrw := gzipResponseWriter{
|
||||
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
|
||||
}
|
||||
|
||||
_, _, err := bdrw.Hijack()
|
||||
assert.EqualError(t, err, "feature not supported")
|
||||
}
|
||||
|
||||
func BenchmarkGzip(b *testing.B) {
|
||||
e := echo.New()
|
||||
|
||||
|
@ -1,7 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testResponseWriterNoFlushHijack struct {
|
||||
}
|
||||
|
||||
func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
|
||||
}
|
||||
|
||||
func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (w *testResponseWriterNoFlushHijack) Header() http.Header {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testResponseWriterUnwrapper struct {
|
||||
unwrapCalled int
|
||||
rw http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
|
||||
}
|
||||
|
||||
func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (w *testResponseWriterUnwrapper) Header() http.Header {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
|
||||
w.unwrapCalled++
|
||||
return w.rw
|
||||
}
|
||||
|
||||
type testResponseWriterUnwrapperHijack struct {
|
||||
testResponseWriterUnwrapper
|
||||
}
|
||||
|
||||
func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, errors.New("can hijack")
|
||||
}
|
||||
|
41
middleware/responsecontroller_1.19.go
Normal file
41
middleware/responsecontroller_1.19.go
Normal file
@ -0,0 +1,41 @@
|
||||
//go:build !go1.20
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
|
||||
func responseControllerFlush(rw http.ResponseWriter) error {
|
||||
for {
|
||||
switch t := rw.(type) {
|
||||
case interface{ FlushError() error }:
|
||||
return t.FlushError()
|
||||
case http.Flusher:
|
||||
t.Flush()
|
||||
return nil
|
||||
case interface{ Unwrap() http.ResponseWriter }:
|
||||
rw = t.Unwrap()
|
||||
default:
|
||||
return fmt.Errorf("%w", http.ErrNotSupported)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
|
||||
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
|
||||
for {
|
||||
switch t := rw.(type) {
|
||||
case http.Hijacker:
|
||||
return t.Hijack()
|
||||
case interface{ Unwrap() http.ResponseWriter }:
|
||||
rw = t.Unwrap()
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
|
||||
}
|
||||
}
|
||||
}
|
17
middleware/responsecontroller_1.20.go
Normal file
17
middleware/responsecontroller_1.20.go
Normal file
@ -0,0 +1,17 @@
|
||||
//go:build go1.20
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func responseControllerFlush(rw http.ResponseWriter) error {
|
||||
return http.NewResponseController(rw).Flush()
|
||||
}
|
||||
|
||||
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
|
||||
return http.NewResponseController(rw).Hijack()
|
||||
}
|
@ -2,6 +2,7 @@ package echo
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
|
||||
// buffered data to the client.
|
||||
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
|
||||
func (r *Response) Flush() {
|
||||
r.Writer.(http.Flusher).Flush()
|
||||
err := responseControllerFlush(r.Writer)
|
||||
if err != nil && errors.Is(err, http.ErrNotSupported) {
|
||||
panic(errors.New("response writer flushing is not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface to allow an HTTP handler to
|
||||
// take over the connection.
|
||||
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
|
||||
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return r.Writer.(http.Hijacker).Hijack()
|
||||
return responseControllerHijack(r.Writer)
|
||||
}
|
||||
|
||||
// Unwrap returns the original http.ResponseWriter.
|
||||
|
@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) {
|
||||
assert.True(t, rec.Flushed)
|
||||
}
|
||||
|
||||
type testResponseWriter struct {
|
||||
}
|
||||
|
||||
func (w *testResponseWriter) WriteHeader(statusCode int) {
|
||||
}
|
||||
|
||||
func (w *testResponseWriter) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (w *testResponseWriter) Header() http.Header {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestResponse_FlushPanics(t *testing.T) {
|
||||
e := New()
|
||||
rw := new(testResponseWriter)
|
||||
res := &Response{echo: e, Writer: rw}
|
||||
|
||||
// we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic
|
||||
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
|
||||
res.Flush()
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
|
||||
e := New()
|
||||
rec := httptest.NewRecorder()
|
||||
|
41
responsecontroller_1.19.go
Normal file
41
responsecontroller_1.19.go
Normal file
@ -0,0 +1,41 @@
|
||||
//go:build !go1.20
|
||||
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
|
||||
func responseControllerFlush(rw http.ResponseWriter) error {
|
||||
for {
|
||||
switch t := rw.(type) {
|
||||
case interface{ FlushError() error }:
|
||||
return t.FlushError()
|
||||
case http.Flusher:
|
||||
t.Flush()
|
||||
return nil
|
||||
case interface{ Unwrap() http.ResponseWriter }:
|
||||
rw = t.Unwrap()
|
||||
default:
|
||||
return fmt.Errorf("%w", http.ErrNotSupported)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
|
||||
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
|
||||
for {
|
||||
switch t := rw.(type) {
|
||||
case http.Hijacker:
|
||||
return t.Hijack()
|
||||
case interface{ Unwrap() http.ResponseWriter }:
|
||||
rw = t.Unwrap()
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
|
||||
}
|
||||
}
|
||||
}
|
17
responsecontroller_1.20.go
Normal file
17
responsecontroller_1.20.go
Normal file
@ -0,0 +1,17 @@
|
||||
//go:build go1.20
|
||||
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func responseControllerFlush(rw http.ResponseWriter) error {
|
||||
return http.NewResponseController(rw).Flush()
|
||||
}
|
||||
|
||||
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
|
||||
return http.NewResponseController(rw).Hijack()
|
||||
}
|
Loading…
Reference in New Issue
Block a user