From 8f7209ba1a836b7e26ecdf18f1d23b289447df28 Mon Sep 17 00:00:00 2001 From: Josef Johansson Date: Sun, 3 Mar 2024 20:17:36 +0100 Subject: [PATCH] pkg/http: Fix leaking goroutines in tests By using the context created by the test, the goroutines produced in http.Client is actually closed when cancelled and such, not leaked. Signed-off-by: Josef Johansson --- CHANGELOG.md | 2 ++ pkg/http/http_suite_test.go | 20 +++++++++----- pkg/http/server_test.go | 52 ++++++++++++++++++------------------- 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94625408..ba4eb149 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ ## Changes since v7.6.0 +- [#2539](https://github.com/oauth2-proxy/oauth2-proxy/pull/2539) pkg/http: Fix leaky test (@isodude) + # V7.6.0 ## Release Highlights diff --git a/pkg/http/http_suite_test.go b/pkg/http/http_suite_test.go index 79aa19a8..edd440c7 100644 --- a/pkg/http/http_suite_test.go +++ b/pkg/http/http_suite_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/pem" @@ -18,7 +19,7 @@ import ( var ipv4CertData, ipv6CertData []byte var ipv4CertDataSource, ipv4KeyDataSource options.SecretSource var ipv6CertDataSource, ipv6KeyDataSource options.SecretSource -var client *http.Client +var transport *http.Transport func TestHTTPSuite(t *testing.T) { logger.SetOutput(GinkgoWriter) @@ -28,6 +29,17 @@ func TestHTTPSuite(t *testing.T) { RunSpecs(t, "HTTP") } +func httpGet(ctx context.Context, url string) (*http.Response, error) { + c := &http.Client{ + Transport: transport.Clone(), + } + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + var _ = BeforeSuite(func() { By("Generating a ipv4 self-signed cert for TLS tests", func() { certBytes, keyBytes, err := util.GenerateCert("127.0.0.1") @@ -70,11 +82,7 @@ var _ = BeforeSuite(func() { certpool.AddCert(ipv4certificate) certpool.AddCert(ipv6certificate) - transport := http.DefaultTransport.(*http.Transport).Clone() + transport = http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig.RootCAs = certpool - - client = &http.Client{ - Transport: transport, - } }) }) diff --git a/pkg/http/server_test.go b/pkg/http/server_test.go index e10979bf..aae7458e 100644 --- a/pkg/http/server_test.go +++ b/pkg/http/server_test.go @@ -587,7 +587,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -602,13 +602,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) }) @@ -641,7 +641,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -656,13 +656,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) }) @@ -673,7 +673,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -712,7 +712,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -727,7 +727,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -742,19 +742,19 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) - _, err = client.Get(secureListenAddr) + _, err = httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) }) @@ -784,7 +784,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -799,13 +799,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) }) @@ -839,7 +839,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -854,13 +854,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) }) @@ -871,7 +871,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -911,7 +911,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -926,7 +926,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -941,19 +941,19 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) - _, err = client.Get(secureListenAddr) + _, err = httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) })