From bae168f06acdab5b29d7167d4bdf48f4de02aaae Mon Sep 17 00:00:00 2001 From: tuunit Date: Sun, 6 Oct 2024 21:43:38 +0200 Subject: [PATCH] better handling of default transport modification --- pkg/requests/http.go | 14 +++++++------- pkg/validation/options.go | 8 ++------ 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/pkg/requests/http.go b/pkg/requests/http.go index 222b92d3..c0035e0a 100644 --- a/pkg/requests/http.go +++ b/pkg/requests/http.go @@ -7,22 +7,22 @@ import ( ) type userAgentTransport struct { - Next http.RoundTripper + next http.RoundTripper userAgent string } func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { r := req.Clone(req.Context()) setDefaultUserAgent(r.Header, t.userAgent) - return t.Next.RoundTrip(r) + return t.next.RoundTrip(r) } -var DefaultHTTPClient = &http.Client{Transport: &DefaultTransport} - -var DefaultTransport = userAgentTransport{ - Next: http.DefaultTransport, +var DefaultHTTPClient = &http.Client{Transport: &userAgentTransport{ + next: DefaultTransport, userAgent: "oauth2-proxy/" + version.VERSION, -} +}} + +var DefaultTransport = http.DefaultTransport func setDefaultUserAgent(header http.Header, userAgent string) { if header != nil && len(header.Values("User-Agent")) == 0 { diff --git a/pkg/validation/options.go b/pkg/validation/options.go index caf896c5..c720f47e 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -31,20 +31,16 @@ func Validate(o *options.Options) error { msgs = parseSignatureKey(o, msgs) if o.SSLInsecureSkipVerify { - transport := requests.DefaultTransport.Next.(*http.Transport).Clone() + transport := requests.DefaultTransport.(*http.Transport) transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 -- InsecureSkipVerify is a configurable option we allow - - requests.DefaultHTTPClient = &http.Client{Transport: transport} } else if len(o.Providers[0].CAFiles) > 0 { pool, err := util.GetCertPool(o.Providers[0].CAFiles, o.Providers[0].UseSystemTrustStore) if err == nil { - transport := requests.DefaultTransport.Next.(*http.Transport).Clone() + transport := requests.DefaultTransport.(*http.Transport) transport.TLSClientConfig = &tls.Config{ RootCAs: pool, MinVersion: tls.VersionTLS12, } - - requests.DefaultHTTPClient = &http.Client{Transport: transport} } else { msgs = append(msgs, fmt.Sprintf("unable to load provider CA file(s): %v", err)) }