From de09aed7a5c2650ae2303cdba75b2df5bf1cc7ea Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Wed, 11 Nov 2020 23:02:00 +0000 Subject: [PATCH] Introduce Duration so that marshalling works for duration strings --- pkg/apis/options/common.go | 32 +++++++++++++++++++++++++ pkg/apis/options/legacy_options.go | 3 ++- pkg/apis/options/legacy_options_test.go | 8 +++---- pkg/apis/options/upstreams.go | 4 +--- pkg/upstream/http.go | 2 +- pkg/upstream/http_test.go | 14 +++++------ pkg/validation/upstreams.go | 2 +- pkg/validation/upstreams_test.go | 2 +- 8 files changed, 49 insertions(+), 18 deletions(-) diff --git a/pkg/apis/options/common.go b/pkg/apis/options/common.go index 60d352a5..5c8bd80e 100644 --- a/pkg/apis/options/common.go +++ b/pkg/apis/options/common.go @@ -1,5 +1,11 @@ package options +import ( + "fmt" + "strings" + "time" +) + // SecretSource references an individual secret value. // Only one source within the struct should be defined at any time. type SecretSource struct { @@ -12,3 +18,29 @@ type SecretSource struct { // FromFile expects a path to a file containing the secret value. FromFile string } + +type Duration time.Duration + +func (d *Duration) UnmarshalJSON(data []byte) error { + input := string(data) + input = strings.TrimPrefix(input, "\"") + input = strings.TrimSuffix(input, "\"") + du, err := time.ParseDuration(input) + if err != nil { + return err + } + *d = Duration(du) + return nil +} + +func (d *Duration) MarshalJSON() ([]byte, error) { + dStr := fmt.Sprintf("%q", d.Duration().String()) + return []byte(dStr), nil +} + +func (d *Duration) Duration() time.Duration { + if d == nil { + return time.Duration(0) + } + return time.Duration(*d) +} diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index 88784e9b..030dbe83 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -84,6 +84,7 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) { u.Path = "/" } + flushInterval := Duration(l.FlushInterval) upstream := Upstream{ ID: u.Path, Path: u.Path, @@ -91,7 +92,7 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) { InsecureSkipTLSVerify: l.SSLUpstreamInsecureSkipVerify, PassHostHeader: &l.PassHostHeader, ProxyWebSockets: &l.ProxyWebSockets, - FlushInterval: &l.FlushInterval, + FlushInterval: &flushInterval, } switch u.Scheme { diff --git a/pkg/apis/options/legacy_options_test.go b/pkg/apis/options/legacy_options_test.go index 2e50edcc..44c8c728 100644 --- a/pkg/apis/options/legacy_options_test.go +++ b/pkg/apis/options/legacy_options_test.go @@ -17,8 +17,8 @@ var _ = Describe("Legacy Options", func() { legacyOpts := NewLegacyOptions() // Set upstreams and related options to test their conversion - flushInterval := 5 * time.Second - legacyOpts.LegacyUpstreams.FlushInterval = flushInterval + flushInterval := Duration(5 * time.Second) + legacyOpts.LegacyUpstreams.FlushInterval = time.Duration(flushInterval) legacyOpts.LegacyUpstreams.PassHostHeader = true legacyOpts.LegacyUpstreams.ProxyWebSockets = true legacyOpts.LegacyUpstreams.SSLUpstreamInsecureSkipVerify = true @@ -124,7 +124,7 @@ var _ = Describe("Legacy Options", func() { skipVerify := true passHostHeader := false proxyWebSockets := true - flushInterval := 5 * time.Second + flushInterval := Duration(5 * time.Second) // Test cases and expected outcomes validHTTP := "http://foo.bar/baz" @@ -199,7 +199,7 @@ var _ = Describe("Legacy Options", func() { SSLUpstreamInsecureSkipVerify: skipVerify, PassHostHeader: passHostHeader, ProxyWebSockets: proxyWebSockets, - FlushInterval: flushInterval, + FlushInterval: time.Duration(flushInterval), } upstreams, err := legacyUpstreams.convert() diff --git a/pkg/apis/options/upstreams.go b/pkg/apis/options/upstreams.go index e879a107..ab6543c6 100644 --- a/pkg/apis/options/upstreams.go +++ b/pkg/apis/options/upstreams.go @@ -1,7 +1,5 @@ package options -import "time" - // Upstreams is a collection of definitions for upstream servers. type Upstreams []Upstream @@ -47,7 +45,7 @@ type Upstream struct { // FlushInterval is the period between flushing the response buffer when // streaming response from the upstream. // Defaults to 1 second. - FlushInterval *time.Duration `json:"flushInterval,omitempty"` + FlushInterval *Duration `json:"flushInterval,omitempty"` // PassHostHeader determines whether the request host header should be proxied // to the upstream server. diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index 88c0afcd..741e9a97 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -98,7 +98,7 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr // Configure options on the SingleHostReverseProxy if upstream.FlushInterval != nil { - proxy.FlushInterval = *upstream.FlushInterval + proxy.FlushInterval = upstream.FlushInterval.Duration() } else { proxy.FlushInterval = 1 * time.Second } diff --git a/pkg/upstream/http_test.go b/pkg/upstream/http_test.go index 8bfe9087..3ce5bd19 100644 --- a/pkg/upstream/http_test.go +++ b/pkg/upstream/http_test.go @@ -22,8 +22,8 @@ import ( var _ = Describe("HTTP Upstream Suite", func() { - const flushInterval5s = 5 * time.Second - const flushInterval1s = 1 * time.Second + const flushInterval5s = options.Duration(5 * time.Second) + const flushInterval1s = options.Duration(1 * time.Second) truth := true falsum := false @@ -52,7 +52,7 @@ var _ = Describe("HTTP Upstream Suite", func() { rw := httptest.NewRecorder() - flush := 1 * time.Second + flush := options.Duration(1 * time.Second) upstream := options.Upstream{ ID: in.id, @@ -258,7 +258,7 @@ var _ = Describe("HTTP Upstream Suite", func() { req := httptest.NewRequest("", "http://example.localhost/foo", nil) rw := httptest.NewRecorder() - flush := 1 * time.Second + flush := options.Duration(1 * time.Second) upstream := options.Upstream{ ID: "noPassHost", PassHostHeader: &falsum, @@ -290,7 +290,7 @@ var _ = Describe("HTTP Upstream Suite", func() { type newUpstreamTableInput struct { proxyWebSockets bool - flushInterval time.Duration + flushInterval options.Duration skipVerify bool sigData *options.SignatureData errorHandler func(http.ResponseWriter, *http.Request, error) @@ -319,7 +319,7 @@ var _ = Describe("HTTP Upstream Suite", func() { proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy) Expect(ok).To(BeTrue()) - Expect(proxy.FlushInterval).To(Equal(in.flushInterval)) + Expect(proxy.FlushInterval).To(Equal(in.flushInterval.Duration())) Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil)) if in.skipVerify { Expect(proxy.Transport).To(Equal(&http.Transport{ @@ -370,7 +370,7 @@ var _ = Describe("HTTP Upstream Suite", func() { var proxyServer *httptest.Server BeforeEach(func() { - flush := 1 * time.Second + flush := options.Duration(1 * time.Second) upstream := options.Upstream{ ID: "websocketProxy", PassHostHeader: &truth, diff --git a/pkg/validation/upstreams.go b/pkg/validation/upstreams.go index 5cfe0b1e..fbff122c 100644 --- a/pkg/validation/upstreams.go +++ b/pkg/validation/upstreams.go @@ -70,7 +70,7 @@ func validateStaticUpstream(upstream options.Upstream) []string { if upstream.InsecureSkipTLSVerify { msgs = append(msgs, fmt.Sprintf("upstream %q has insecureSkipTLSVerify, but is a static upstream, this will have no effect.", upstream.ID)) } - if upstream.FlushInterval != nil && *upstream.FlushInterval != time.Second { + if upstream.FlushInterval != nil && upstream.FlushInterval.Duration() != time.Second { msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID)) } if upstream.PassHostHeader != nil { diff --git a/pkg/validation/upstreams_test.go b/pkg/validation/upstreams_test.go index 6b8f9829..122286ad 100644 --- a/pkg/validation/upstreams_test.go +++ b/pkg/validation/upstreams_test.go @@ -15,7 +15,7 @@ var _ = Describe("Upstreams", func() { errStrings []string } - flushInterval := 5 * time.Second + flushInterval := options.Duration(5 * time.Second) staticCode200 := 200 truth := true