diff --git a/app/main.go b/app/main.go index 4b49612..020f7e0 100644 --- a/app/main.go +++ b/app/main.go @@ -49,6 +49,7 @@ var opts struct { Assets struct { Location string `short:"a" long:"location" env:"LOCATION" default:"" description:"assets location"` WebRoot string `long:"root" env:"ROOT" default:"/" description:"assets web root"` + SPA bool `long:"spa" env:"SPA" description:"spa treatment for assets"` CacheControl []string `long:"cache" env:"CACHE" description:"cache duration for assets" env-delim:","` } `group:"assets" namespace:"assets" env-namespace:"ASSETS"` @@ -222,6 +223,7 @@ func run() error { MaxBodySize: int64(maxBodySize), AssetsLocation: opts.Assets.Location, AssetsWebRoot: opts.Assets.WebRoot, + AssetsSPA: opts.Assets.SPA, CacheControl: cacheControl, GzEnabled: opts.GzipEnabled, SSLConfig: sslConfig, diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index 433e606..6995371 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -317,7 +317,7 @@ func (h *Http) assetsHandler() http.HandlerFunc { if h.AssetsLocation == "" || h.AssetsWebRoot == "" { return func(writer http.ResponseWriter, request *http.Request) {} } - log.Printf("[DEBUG] shared assets server enabled for %s %s", h.AssetsWebRoot, h.AssetsLocation) + log.Printf("[DEBUG] shared assets server enabled for %s %s, spa=%v", h.AssetsWebRoot, h.AssetsLocation, h.AssetsSPA) fs, err := h.fileServer(h.AssetsWebRoot, h.AssetsLocation, h.AssetsSPA) if err != nil { log.Printf("[WARN] can't initialize assets server, %v", err) diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index a1b31d0..1fbbee6 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -171,6 +171,13 @@ func TestHttp_DoWithAssets(t *testing.T) { assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) } + { + resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + } + { resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") require.NoError(t, err) @@ -181,7 +188,101 @@ func TestHttp_DoWithAssets(t *testing.T) { assert.Contains(t, string(body), "Server error") assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) } +} +func TestHttp_DoWithSpaAssets(t *testing.T) { + port := rand.Intn(10000) + 40000 + cc := NewCacheControl(time.Hour * 12) + h := Http{Timeouts: Timeouts{ResponseHeader: 200 * time.Millisecond}, Address: fmt.Sprintf("127.0.0.1:%d", port), + AccessLog: io.Discard, AssetsWebRoot: "/static", AssetsLocation: "testdata", AssetsSPA: true, + CacheControl: cc, Reporter: &ErrorReporter{Nice: false}} + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + ds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("req: %v", r) + w.Header().Add("h1", "v1") + require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP")) + fmt.Fprintf(w, "response %s", r.URL.String()) + })) + + svc := discovery.NewService([]discovery.Provider{ + &provider.Static{Rules: []string{ + "localhost,^/api/(.*)," + ds.URL + "/123/$1,", + "127.0.0.1,^/api/(.*)," + ds.URL + "/567/$1,", + }, + }}, time.Millisecond*10) + + go func() { + _ = svc.Run(context.Background()) + }() + time.Sleep(50 * time.Millisecond) + h.Matcher = svc + h.Metrics = mgmt.NewMetrics() + + go func() { + _ = h.Run(ctx) + }() + time.Sleep(10 * time.Millisecond) + + client := http.Client{} + + { + req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + t.Logf("%+v", resp.Header) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "response /567/something", string(body)) + assert.Equal(t, "", resp.Header.Get("App-Method")) + assert.Equal(t, "v1", resp.Header.Get("h1")) + } + + { + resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + t.Logf("%+v", resp.Header) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "test html", string(body)) + assert.Equal(t, "", resp.Header.Get("App-Method")) + assert.Equal(t, "", resp.Header.Get("h1")) + assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) + } + + { + resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + t.Logf("%+v", resp.Header) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "index html", string(body)) + assert.Equal(t, "", resp.Header.Get("App-Method")) + assert.Equal(t, "", resp.Header.Get("h1")) + assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) + } + + { + resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "Server error") + assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) + } } func TestHttp_DoWithAssetRules(t *testing.T) {