From 2386e17b21ae319c0f8044326b2c65c15f0cccd2 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Sat, 28 Nov 2020 02:03:54 +0000 Subject: [PATCH] Increasing Decompress Middleware coverage --- middleware/decompress.go | 73 +++++++++++++++++++++++------------ middleware/decompress_test.go | 61 +++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 25 deletions(-) diff --git a/middleware/decompress.go b/middleware/decompress.go index 3785ab0f..c046359a 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -16,17 +16,55 @@ type ( DecompressConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor } ) //GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" +// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers +type Decompressor interface { + gzipDecompressPool() sync.Pool +} + var ( //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper} + DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, + } ) +// DefaultGzipDecompressPool is the default implementation of Decompressor interface +type DefaultGzipDecompressPool struct { +} + +func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + // create with an empty reader (but with GZIP header) + w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) + if err != nil { + return err + } + + b := new(bytes.Buffer) + w.Reset(b) + w.Flush() + w.Close() + + r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) + if err != nil { + return err + } + return r + }, + } +} + //Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { return DecompressWithConfig(DefaultDecompressConfig) @@ -34,8 +72,16 @@ func Decompress() echo.MiddlewareFunc { //DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultGzipConfig.Skipper + } + if config.GzipDecompressPool == nil { + config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + } + return func(next echo.HandlerFunc) echo.HandlerFunc { - pool := gzipDecompressPool() + pool := config.GzipDecompressPool.gzipDecompressPool() return func(c echo.Context) error { if config.Skipper(c) { return next(c) @@ -72,26 +118,3 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { } } } - -func gzipDecompressPool() sync.Pool { - return sync.Pool{ - New: func() interface{} { - // create with an empty reader (but with GZIP header) - w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) - if err != nil { - return err - } - - b := new(bytes.Buffer) - w.Reset(b) - w.Flush() - w.Close() - - r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) - if err != nil { - return err - } - return r - }, - } -} diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 772c14f6..51fa6b0f 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -3,10 +3,12 @@ package middleware import ( "bytes" "compress/gzip" + "errors" "io/ioutil" "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/labstack/echo/v4" @@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) { assert.Equal(body, string(b)) } +func TestDecompressDefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e := echo.New() body := `{"name":"echo"}` @@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) { assert.Equal(t, body, string(reqBody)) } +type TestDecompressPoolWithError struct { +} + +func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + return errors.New("pool error") + }, + } +} + +func TestDecompressPoolError(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &TestDecompressPoolWithError{}, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) + assert.Equal(t, rec.Code, http.StatusInternalServerError) +} + func BenchmarkDecompress(b *testing.B) { e := echo.New() body := `{"name": "echo"}`