1
0
mirror of https://github.com/umputun/reproxy.git synced 2025-11-29 22:08:14 +02:00
Files
reproxy/app/proxy/middleware/compress_test.go
2021-04-09 13:50:01 -05:00

219 lines
5.4 KiB
Go

package middleware
import (
"compress/flate"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// based on https://github.com/go-chi/chi/blob/master/middleware/compress_test.go
// removed chi deps from the testing code
func TestCompressor(t *testing.T) {
r := http.NewServeMux()
compressor := NewCompressor(5, "text/html", "text/css")
if len(compressor.encoders) != 0 || len(compressor.pooledEncoders) != 2 {
t.Errorf("gzip and deflate should be pooled")
}
compressor.SetEncoder("nop", func(w io.Writer, _ int) io.Writer {
return w
})
if len(compressor.encoders) != 1 {
t.Errorf("nop encoder should be stored in the encoders map")
}
r.Handle("/gethtml", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Write([]byte("textstring"))
}))
r.Handle("/getcss", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Write([]byte("textstring"))
}))
r.Handle("/getplain", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Write([]byte("textstring"))
}))
ts := httptest.NewServer(compressor.Handler(r))
defer ts.Close()
tests := []struct {
name string
path string
expectedEncoding string
acceptedEncodings []string
}{
{
name: "no expected encodings due to no accepted encodings",
path: "/gethtml",
acceptedEncodings: nil,
expectedEncoding: "",
},
{
name: "no expected encodings due to content type",
path: "/getplain",
acceptedEncodings: nil,
expectedEncoding: "",
},
{
name: "gzip is only encoding",
path: "/gethtml",
acceptedEncodings: []string{"gzip"},
expectedEncoding: "gzip",
},
{
name: "gzip is preferred over deflate",
path: "/getcss",
acceptedEncodings: []string{"gzip", "deflate"},
expectedEncoding: "gzip",
},
{
name: "deflate is used",
path: "/getcss",
acceptedEncodings: []string{"deflate"},
expectedEncoding: "deflate",
},
{
name: "nop is preferred",
path: "/getcss",
acceptedEncodings: []string{"nop, gzip, deflate"},
expectedEncoding: "nop",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...)
if respString != "textstring" {
t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString)
}
if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding {
t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got)
}
})
}
}
func TestCompressorWildcards(t *testing.T) {
tests := []struct {
name string
recover string
types []string
typesCount int
wcCount int
}{
{
name: "defaults",
typesCount: 10,
},
{
name: "no wildcard",
types: []string{"text/plain", "text/html"},
typesCount: 2,
},
{
name: "invalid wildcard #1",
types: []string{"audio/*wav"},
recover: "middleware/compress: Unsupported content-type wildcard pattern 'audio/*wav'. Only '/*' supported",
},
{
name: "invalid wildcard #2",
types: []string{"application*/*"},
recover: "middleware/compress: Unsupported content-type wildcard pattern 'application*/*'. Only '/*' supported",
},
{
name: "valid wildcard",
types: []string{"text/*"},
wcCount: 1,
},
{
name: "mixed",
types: []string{"audio/wav", "text/*"},
typesCount: 1,
wcCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if tt.recover == "" {
tt.recover = "<nil>"
}
if r := recover(); tt.recover != fmt.Sprintf("%v", r) {
t.Errorf("Unexpected value recovered: %v", r)
}
}()
compressor := NewCompressor(5, tt.types...)
if len(compressor.allowedTypes) != tt.typesCount {
t.Errorf("expected %d allowedTypes, got %d", tt.typesCount, len(compressor.allowedTypes))
}
if len(compressor.allowedWildcards) != tt.wcCount {
t.Errorf("expected %d allowedWildcards, got %d", tt.wcCount, len(compressor.allowedWildcards))
}
})
}
}
func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) {
req, err := http.NewRequest(method, ts.URL+path, nil)
if err != nil {
t.Fatal(err)
return nil, ""
}
if len(encodings) > 0 {
encodingsString := strings.Join(encodings, ",")
req.Header.Set("Accept-Encoding", encodingsString)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
return nil, ""
}
respBody := decodeResponseBody(t, resp)
defer resp.Body.Close()
return resp, respBody
}
func decodeResponseBody(t *testing.T, resp *http.Response) string {
var reader io.ReadCloser
switch resp.Header.Get("Content-Encoding") {
case "gzip":
var err error
reader, err = gzip.NewReader(resp.Body)
if err != nil {
t.Fatal(err)
}
case "deflate":
reader = flate.NewReader(resp.Body)
default:
reader = resp.Body
}
respBody, err := ioutil.ReadAll(reader)
if err != nil {
t.Fatal(err)
return ""
}
reader.Close()
return string(respBody)
}