2021-06-01 02:56:39 -05:00
|
|
|
package proxy
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
2021-07-03 01:23:50 -05:00
|
|
|
"context"
|
2021-06-01 02:56:39 -05:00
|
|
|
"net/http"
|
|
|
|
|
"net/http/httptest"
|
2021-07-03 01:23:50 -05:00
|
|
|
"strconv"
|
|
|
|
|
"sync"
|
|
|
|
|
"sync/atomic"
|
2021-06-01 02:56:39 -05:00
|
|
|
"testing"
|
|
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/require"
|
2021-07-03 01:23:50 -05:00
|
|
|
|
|
|
|
|
"github.com/umputun/reproxy/app/discovery"
|
2021-06-01 02:56:39 -05:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func Test_headersHandler(t *testing.T) {
|
|
|
|
|
wr := httptest.NewRecorder()
|
2021-09-11 14:38:56 -05:00
|
|
|
handler := headersHandler([]string{"k1:v1", "k2:v2"}, []string{"r1", "r2"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
2021-06-01 02:56:39 -05:00
|
|
|
t.Logf("req: %v", r)
|
2021-09-11 14:38:56 -05:00
|
|
|
assert.Equal(t, "", r.Header.Get("r1"), "r1 header dropped")
|
|
|
|
|
assert.Equal(t, "", r.Header.Get("r2"), "r2 header dropped")
|
|
|
|
|
assert.Equal(t, "rv3", r.Header.Get("r3"), "r3 kept")
|
2021-06-01 02:56:39 -05:00
|
|
|
}))
|
2021-11-09 12:18:26 -06:00
|
|
|
req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
|
2021-06-01 02:56:39 -05:00
|
|
|
require.NoError(t, err)
|
2021-09-11 14:38:56 -05:00
|
|
|
req.Header.Set("r1", "rv1")
|
|
|
|
|
req.Header.Set("r2", "rv2")
|
|
|
|
|
req.Header.Set("r3", "rv3")
|
2021-06-01 02:56:39 -05:00
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, "v1", wr.Result().Header.Get("k1"))
|
|
|
|
|
assert.Equal(t, "v2", wr.Result().Header.Get("k2"))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_maxReqSizeHandler(t *testing.T) {
|
2024-11-26 13:33:02 -06:00
|
|
|
t.Run("good size", func(t *testing.T) {
|
2021-06-01 02:56:39 -05:00
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := maxReqSizeHandler(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
req, err := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("123456"))
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, http.StatusOK, wr.Result().StatusCode, "good size, full response")
|
2024-11-26 13:33:02 -06:00
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("too large size", func(t *testing.T) {
|
2021-06-01 02:56:39 -05:00
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := maxReqSizeHandler(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
req, err := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("123456789012345"))
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, http.StatusRequestEntityTooLarge, wr.Result().StatusCode)
|
2024-11-26 13:33:02 -06:00
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("zero max size", func(t *testing.T) {
|
2021-06-01 02:56:39 -05:00
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := maxReqSizeHandler(0)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
req, err := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("123456"))
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, http.StatusOK, wr.Result().StatusCode, "good size, full response")
|
2024-11-26 13:33:02 -06:00
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("too large request size", func(t *testing.T) {
|
|
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := maxReqSizeHandler(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
req, err := http.NewRequest("GET", "http://example.com?q=123456789012345", http.NoBody)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, http.StatusRequestURITooLong, wr.Result().StatusCode)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("good request size", func(t *testing.T) {
|
|
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := maxReqSizeHandler(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
req, err := http.NewRequest("GET", "http://example.com?q=12345678", http.NoBody)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, http.StatusOK, wr.Result().StatusCode)
|
|
|
|
|
})
|
2021-06-01 02:56:39 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_signatureHandler(t *testing.T) {
|
2024-11-27 03:36:27 -06:00
|
|
|
t.Run("with signature", func(t *testing.T) {
|
2021-06-01 02:56:39 -05:00
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := signatureHandler(true, "v0.0.1")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
req, err := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("123456"))
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, http.StatusOK, wr.Result().StatusCode)
|
|
|
|
|
assert.Equal(t, "reproxy", wr.Result().Header.Get("App-Name"), wr.Result().Header)
|
|
|
|
|
assert.Equal(t, "umputun", wr.Result().Header.Get("Author"), wr.Result().Header)
|
|
|
|
|
assert.Equal(t, "v0.0.1", wr.Result().Header.Get("App-Version"), wr.Result().Header)
|
2024-11-27 03:36:27 -06:00
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("without signature", func(t *testing.T) {
|
2021-06-01 02:56:39 -05:00
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := signatureHandler(false, "v0.0.1")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
req, err := http.NewRequest("POST", "http://example.com", bytes.NewBufferString("123456"))
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
assert.Equal(t, http.StatusOK, wr.Result().StatusCode)
|
|
|
|
|
assert.Equal(t, "", wr.Result().Header.Get("App-Name"), wr.Result().Header)
|
|
|
|
|
assert.Equal(t, "", wr.Result().Header.Get("Author"), wr.Result().Header)
|
|
|
|
|
assert.Equal(t, "", wr.Result().Header.Get("App-Version"), wr.Result().Header)
|
2024-11-27 03:36:27 -06:00
|
|
|
})
|
2021-06-01 02:56:39 -05:00
|
|
|
}
|
2021-07-03 01:23:50 -05:00
|
|
|
|
|
|
|
|
func Test_limiterSystemHandler(t *testing.T) {
|
|
|
|
|
|
|
|
|
|
var passed int32
|
|
|
|
|
handler := limiterSystemHandler(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
atomic.AddInt32(&passed, 1)
|
|
|
|
|
}))
|
|
|
|
|
|
|
|
|
|
ts := httptest.NewServer(handler)
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
|
wg.Add(100)
|
|
|
|
|
for i := 0; i < 100; i++ {
|
|
|
|
|
go func() {
|
|
|
|
|
defer wg.Done()
|
2021-11-09 12:18:26 -06:00
|
|
|
req, err := http.NewRequest("GET", ts.URL, http.NoBody)
|
2021-07-03 01:23:50 -05:00
|
|
|
require.NoError(t, err)
|
|
|
|
|
client := http.Client{}
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
wg.Wait()
|
|
|
|
|
assert.Equal(t, int32(10), atomic.LoadInt32(&passed))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_limiterClientHandlerNoMatches(t *testing.T) {
|
|
|
|
|
|
|
|
|
|
var passed int32
|
|
|
|
|
handler := limiterUserHandler(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
atomic.AddInt32(&passed, 1)
|
|
|
|
|
}))
|
|
|
|
|
|
|
|
|
|
ts := httptest.NewServer(handler)
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
|
wg.Add(100)
|
|
|
|
|
for i := 0; i < 100; i++ {
|
|
|
|
|
go func() {
|
|
|
|
|
defer wg.Done()
|
2021-11-09 12:18:26 -06:00
|
|
|
req, err := http.NewRequest("GET", ts.URL, http.NoBody)
|
2021-07-03 01:23:50 -05:00
|
|
|
require.NoError(t, err)
|
|
|
|
|
client := http.Client{}
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
wg.Wait()
|
|
|
|
|
assert.Equal(t, int32(10), atomic.LoadInt32(&passed))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_limiterClientHandlerWithMatches(t *testing.T) {
|
|
|
|
|
var passed int32
|
|
|
|
|
handler := limiterUserHandler(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
atomic.AddInt32(&passed, 1)
|
|
|
|
|
}))
|
|
|
|
|
|
|
|
|
|
wrapWithContext := func(next http.Handler) http.Handler {
|
|
|
|
|
var id int32
|
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
n := int(atomic.AddInt32(&id, 1))
|
|
|
|
|
m := discovery.MatchedRoute{Mapper: discovery.URLMapper{Dst: strconv.Itoa(n % 2)}}
|
|
|
|
|
ctx := context.WithValue(context.Background(), ctxMatchType, discovery.MTProxy)
|
|
|
|
|
ctx = context.WithValue(ctx, ctxMatch, m)
|
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ts := httptest.NewServer(wrapWithContext(handler))
|
|
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
|
wg.Add(100)
|
|
|
|
|
for i := 0; i < 100; i++ {
|
|
|
|
|
go func(id int) {
|
|
|
|
|
defer wg.Done()
|
|
|
|
|
req, err := http.NewRequest("POST", ts.URL, bytes.NewBufferString("123456"))
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
m := discovery.MatchedRoute{Mapper: discovery.URLMapper{Dst: strconv.Itoa(id % 2)}}
|
|
|
|
|
ctx := context.WithValue(context.Background(), ctxMatchType, discovery.MTProxy)
|
|
|
|
|
ctx = context.WithValue(ctx, ctxMatch, m)
|
|
|
|
|
req = req.WithContext(ctx)
|
|
|
|
|
|
|
|
|
|
client := http.Client{}
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
}(i)
|
|
|
|
|
}
|
|
|
|
|
wg.Wait()
|
|
|
|
|
assert.Equal(t, int32(20), atomic.LoadInt32(&passed))
|
|
|
|
|
}
|
2021-11-07 15:31:33 -06:00
|
|
|
|
|
|
|
|
func TestHttp_basicAuthHandler(t *testing.T) {
|
|
|
|
|
allowed := []string{
|
|
|
|
|
"test:$2y$05$zMxDmK65SjcH2vJQNopVSO/nE8ngVLx65RoETyHpez7yTS/8CLEiW",
|
|
|
|
|
"test2:$2y$05$TLQqHh6VT4JxysdKGPOlJeSkkMsv.Ku/G45i7ssIm80XuouCrES12 ",
|
|
|
|
|
"bad bad",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
handler := basicAuthHandler(true, allowed)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
ts := httptest.NewServer(handler)
|
|
|
|
|
|
|
|
|
|
client := http.Client{}
|
|
|
|
|
|
|
|
|
|
tbl := []struct {
|
|
|
|
|
reqFn func(r *http.Request)
|
|
|
|
|
ok bool
|
|
|
|
|
}{
|
|
|
|
|
{func(r *http.Request) {}, false},
|
|
|
|
|
{func(r *http.Request) { r.SetBasicAuth("test", "passwd") }, true},
|
|
|
|
|
{func(r *http.Request) { r.SetBasicAuth("test", "passwdbad") }, false},
|
|
|
|
|
{func(r *http.Request) { r.SetBasicAuth("test2", "passwd2") }, true},
|
|
|
|
|
{func(r *http.Request) { r.SetBasicAuth("test2", "passwbad") }, false},
|
|
|
|
|
{func(r *http.Request) { r.SetBasicAuth("testbad", "passwbad") }, false},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for i, tt := range tbl {
|
|
|
|
|
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
2021-11-09 12:18:26 -06:00
|
|
|
req, err := http.NewRequest("GET", ts.URL, http.NoBody)
|
2021-11-07 15:31:33 -06:00
|
|
|
require.NoError(t, err)
|
|
|
|
|
tt.reqFn(req)
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
if tt.ok {
|
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
handler = basicAuthHandler(false, allowed)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
t.Logf("req: %v", r)
|
|
|
|
|
}))
|
|
|
|
|
ts2 := httptest.NewServer(handler)
|
|
|
|
|
for i, tt := range tbl {
|
|
|
|
|
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
2021-11-09 12:18:26 -06:00
|
|
|
req, err := http.NewRequest("GET", ts2.URL, http.NoBody)
|
2021-11-07 15:31:33 -06:00
|
|
|
require.NoError(t, err)
|
|
|
|
|
tt.reqFn(req)
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
2025-02-20 12:00:40 -06:00
|
|
|
|
|
|
|
|
func TestHeaders_CSPParsing(t *testing.T) {
|
|
|
|
|
tbl := []struct {
|
|
|
|
|
name string
|
|
|
|
|
headers []string
|
|
|
|
|
expected map[string]string
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "simple headers",
|
|
|
|
|
headers: []string{"X-Frame-Options:SAMEORIGIN", "X-XSS-Protection:1; mode=block"},
|
|
|
|
|
expected: map[string]string{"X-Frame-Options": "SAMEORIGIN", "X-XSS-Protection": "1; mode=block"},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "CSP header with multiple directives",
|
|
|
|
|
headers: []string{"Content-Security-Policy:default-src 'self'; style-src 'self' 'unsafe-inline': something"},
|
|
|
|
|
expected: map[string]string{"Content-Security-Policy": "default-src 'self'; style-src 'self' 'unsafe-inline': something"},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "CSP header with quotes and colons",
|
|
|
|
|
headers: []string{"Content-Security-Policy:script-src 'unsafe-inline' 'unsafe-eval' 'self' https://example.com:443"},
|
|
|
|
|
expected: map[string]string{"Content-Security-Policy": "script-src 'unsafe-inline' 'unsafe-eval' 'self' https://example.com:443"},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "multiple colons in value",
|
|
|
|
|
headers: []string{"Custom-Header:value:with:colons"},
|
|
|
|
|
expected: map[string]string{"Custom-Header": "value:with:colons"},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "empty value after colon",
|
|
|
|
|
headers: []string{"Empty-Header:"},
|
|
|
|
|
expected: map[string]string{"Empty-Header": ""},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "malformed no colon",
|
|
|
|
|
headers: []string{"Bad-Header-No-Colon"},
|
|
|
|
|
expected: map[string]string{},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tbl {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
wr := httptest.NewRecorder()
|
|
|
|
|
handler := headersHandler(tt.headers, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", http.NoBody)
|
|
|
|
|
handler.ServeHTTP(wr, req)
|
|
|
|
|
|
|
|
|
|
if len(tt.expected) == 0 {
|
|
|
|
|
// For malformed headers, check they weren't set
|
|
|
|
|
assert.Equal(t, 0, len(wr.Header()))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for k, v := range tt.expected {
|
|
|
|
|
assert.Equal(t, v, wr.Header().Get(k), "Header %s value mismatch", k)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|