1
0
mirror of https://github.com/umputun/reproxy.git synced 2025-06-30 22:13:42 +02:00

fix: add timeouts and address race conditions in DNS challenge tests

- Add proper synchronization for DNS mock server
- Fix race condition with thread-safe token access
- Add timeouts to certificate acquisition to prevent test hanging
- Improve error handling in DNS server
- Normalize comments with unfuck-ai-comments
This commit is contained in:
Umputun
2025-04-19 12:28:07 -05:00
parent 528807d744
commit d563c2a648
9 changed files with 109 additions and 26 deletions

View File

@ -38,7 +38,7 @@ func NewClient(address string, httpClient HTTPClient) ConsulClient {
// Get implements ConsulClient interface and returns consul services list, // Get implements ConsulClient interface and returns consul services list,
// which have any tag with 'reproxy.' prefix // which have any tag with 'reproxy.' prefix
func (cl *consulClient) Get() ([]consulService, error) { func (cl *consulClient) Get() ([]consulService, error) {
var result []consulService //nolint:prealloc // We cannot calc slice size var result []consulService //nolint:prealloc // we cannot calc slice size
serviceNames, err := cl.getServiceNames() serviceNames, err := cl.getServiceNames()
if err != nil { if err != nil {

View File

@ -241,7 +241,7 @@ func (d *Docker) events(ctx context.Context, eventsCh chan<- discovery.ProviderI
ticker := time.NewTicker(d.RefreshInterval) ticker := time.NewTicker(d.RefreshInterval)
defer ticker.Stop() defer ticker.Stop()
// Keep track of running containers // keep track of running containers
saved := make(map[string]containerInfo) saved := make(map[string]containerInfo)
update := func() { update := func() {
@ -276,7 +276,7 @@ func (d *Docker) events(ctx context.Context, eventsCh chan<- discovery.ProviderI
} }
} }
update() // Refresh immediately update() // refresh immediately
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -371,7 +371,7 @@ func NewDockerClient(host, network string) DockerClient {
} }
func (d *dockerClient) ListContainers() ([]containerInfo, error) { func (d *dockerClient) ListContainers() ([]containerInfo, error) {
// Minimum API version that returns attached networks // minimum API version that returns attached networks
// docs.docker.com/engine/api/version-history/#v122-api-changes // docs.docker.com/engine/api/version-history/#v122-api-changes
const APIVersion = "v1.24" const APIVersion = "v1.24"

View File

@ -344,20 +344,20 @@ func TestDocker_refresh(t *testing.T) {
} }
}() }()
// Start some // start some
containers <- []containerInfo{stub("1"), stub("2")} containers <- []containerInfo{stub("1"), stub("2")}
recv() recv()
// Nothing changed // nothing changed
containers <- []containerInfo{stub("1"), stub("2")} containers <- []containerInfo{stub("1"), stub("2")}
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
assert.Empty(t, events, "unexpected refresh notification") assert.Empty(t, events, "unexpected refresh notification")
// Stopped // stopped
containers <- []containerInfo{stub("1")} containers <- []containerInfo{stub("1")}
recv() recv()
// One changed // one changed
containers <- []containerInfo{ containers <- []containerInfo{
{ID: "1", Name: "1", State: "running", IP: "127.42.42.42", Ports: []int{12345}}, {ID: "1", Name: "1", State: "running", IP: "127.42.42.42", Ports: []int{12345}},
} }

View File

@ -63,7 +63,7 @@ func NewACMEServer(t *testing.T, opts ...Option) *ACMEServer {
checkDNS: func(string) (bool, string, error) { return false, "", nil }, checkDNS: func(string) (bool, string, error) { return false, "", nil },
modifyReq: func(*http.Request) {}, modifyReq: func(*http.Request) {},
cl: &http.Client{ cl: &http.Client{
// Prevent HTTP redirects // prevent HTTP redirects
CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }, CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse },
}, },
t: t, t: t,

View File

@ -317,7 +317,7 @@ func TestHeaders_CSPParsing(t *testing.T) {
handler.ServeHTTP(wr, req) handler.ServeHTTP(wr, req)
if len(tt.expected) == 0 { if len(tt.expected) == 0 {
// For malformed headers, check they weren't set // for malformed headers, check they weren't set
assert.Equal(t, 0, len(wr.Header())) assert.Equal(t, 0, len(wr.Header()))
return return
} }

View File

@ -72,10 +72,10 @@ func (o *OnlyFrom) realIP(ipLookups []OFLookup, r *http.Request) string {
} }
if lookup == OFForwarded && forwardedFor != "" { if lookup == OFForwarded && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with "," // x-Forwarded-For is potentially a list of addresses separated with ","
// The left-most being the original client, and each successive proxy that passed the request // the left-most being the original client, and each successive proxy that passed the request
// adding the IP address where it received the request from. // adding the IP address where it received the request from.
// In case if the original IP is a private behind a proxy, we need to get the first public IP from the list // in case if the original IP is a private behind a proxy, we need to get the first public IP from the list
return preferPublicIP(strings.Split(forwardedFor, ",")) return preferPublicIP(strings.Split(forwardedFor, ","))
} }

View File

@ -242,7 +242,7 @@ func (h *Http) proxyHandler() http.HandlerFunc {
IdleConnTimeout: h.Timeouts.IdleConn, IdleConnTimeout: h.Timeouts.IdleConn,
TLSHandshakeTimeout: h.Timeouts.TLSHandshake, TLSHandshakeTimeout: h.Timeouts.TLSHandshake,
ExpectContinueTimeout: h.Timeouts.ExpectContinue, ExpectContinueTimeout: h.Timeouts.ExpectContinue,
TLSClientConfig: &tls.Config{InsecureSkipVerify: h.Insecure}, //nolint:gosec // G402: User defined option to disable verification for self-signed certificates TLSClientConfig: &tls.Config{InsecureSkipVerify: h.Insecure}, //nolint:gosec // g402: User defined option to disable verification for self-signed certificates
}, },
ErrorLog: log.ToStdLogger(log.Default(), "WARN"), ErrorLog: log.ToStdLogger(log.Default(), "WARN"),
} }

View File

@ -8,7 +8,9 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"strings" "strings"
"sync"
"testing" "testing"
"time"
log "github.com/go-pkgz/lgr" log "github.com/go-pkgz/lgr"
"github.com/libdns/libdns" "github.com/libdns/libdns"
@ -86,10 +88,31 @@ func TestSSL_ACME_HTTPChallengeRouter(t *testing.T) {
assert.Equal(t, 307, resp.StatusCode) assert.Equal(t, 307, resp.StatusCode)
assert.Equal(t, "https://localhost:443/blah?param=1", resp.Header.Get("Location")) assert.Equal(t, "https://localhost:443/blah?param=1", resp.Header.Get("Location"))
// acquire new cert from CA and check it // acquire new cert from CA and check it with timeout
cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
require.NoError(t, err) defer cancel()
assert.NotNil(t, cert)
certCh := make(chan struct {
cert *tls.Certificate
err error
})
go func() {
cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"})
certCh <- struct {
cert *tls.Certificate
err error
}{cert, err}
}()
// wait for certificate or timeout
select {
case <-ctx.Done():
t.Fatalf("Certificate acquisition timed out: %v", ctx.Err())
case result := <-certCh:
require.NoError(t, result.err)
assert.NotNil(t, result.cert)
}
} }
func TestSSL_ACME_DNSChallenge(t *testing.T) { func TestSSL_ACME_DNSChallenge(t *testing.T) {
@ -100,11 +123,26 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
log.Printf("[DEBUG] acme dir: %s", dir) log.Printf("[DEBUG] acme dir: %s", dir)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
// use mutex to protect expectedToken from race conditions
var expectedToken string var expectedToken string
var tokenMutex sync.Mutex
getToken := func() string {
tokenMutex.Lock()
defer tokenMutex.Unlock()
return expectedToken
}
setToken := func(token string) {
tokenMutex.Lock()
defer tokenMutex.Unlock()
expectedToken = token
}
cas := acmetest.NewACMEServer(t, cas := acmetest.NewACMEServer(t,
acmetest.CheckDNS(func(domain string) (exists bool, value string, err error) { acmetest.CheckDNS(func(domain string) (exists bool, value string, err error) {
assert.Equal(t, "example.com", domain) assert.Equal(t, "example.com", domain)
return true, expectedToken, nil return true, getToken(), nil
}), }),
) )
@ -120,6 +158,12 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
msg := &dns.Msg{} msg := &dns.Msg{}
msg.SetReply(r) msg.SetReply(r)
if len(r.Question) == 0 {
msg.SetRcode(r, dns.RcodeNameError)
_ = w.WriteMsg(msg)
return
}
switch r.Question[0].Qtype { switch r.Question[0].Qtype {
case dns.TypeSOA: case dns.TypeSOA:
msg.Answer = []dns.RR{&dns.SOA{ msg.Answer = []dns.RR{&dns.SOA{
@ -136,7 +180,7 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
assert.Equal(t, "_acme-challenge.example.com.", r.Question[0].Name) assert.Equal(t, "_acme-challenge.example.com.", r.Question[0].Name)
msg.Answer = []dns.RR{&dns.TXT{ msg.Answer = []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}, Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
Txt: []string{expectedToken}, Txt: []string{getToken()},
}} }}
default: default:
msg.SetRcode(r, dns.RcodeNameError) msg.SetRcode(r, dns.RcodeNameError)
@ -146,7 +190,23 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
}), }),
} }
go func() { require.NoError(t, dnsMock.ActivateAndServe()) }() // create a channel to ensure DNS server is ready
dnsReady := make(chan struct{})
go func() {
// signal that the DNS server is ready to accept connections
close(dnsReady)
// set a timeout for the DNS server
err := dnsMock.ActivateAndServe()
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
t.Logf("DNS server error: %v", err)
}
}()
// wait for DNS server to be ready
<-dnsReady
// ensure the server is shut down at the end of the test
defer dnsMock.Shutdown() defer dnsMock.Shutdown()
t.Log("dns server started at", dnsMock.Addr) t.Log("dns server started at", dnsMock.Addr)
@ -163,7 +223,7 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
assert.Equal(t, "_acme-challenge", recs[0].Name) assert.Equal(t, "_acme-challenge", recs[0].Name)
assert.Equal(t, "TXT", recs[0].Type) assert.Equal(t, "TXT", recs[0].Type)
assert.NotEmpty(t, recs[0].Value) assert.NotEmpty(t, recs[0].Value)
expectedToken = recs[0].Value setToken(recs[0].Value)
return recs, nil return recs, nil
}, },
DeleteRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { DeleteRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
@ -176,7 +236,30 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
m := p.makeAutocertManager() m := p.makeAutocertManager()
cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}) // create context with timeout to prevent test from hanging
require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
assert.NotNil(t, cert) defer cancel()
// use a channel to handle the certificate acquisition with timeout
certCh := make(chan struct {
cert *tls.Certificate
err error
})
go func() {
cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"})
certCh <- struct {
cert *tls.Certificate
err error
}{cert, err}
}()
// wait for certificate or timeout
select {
case <-ctx.Done():
t.Fatalf("Certificate acquisition timed out: %v", ctx.Err())
case result := <-certCh:
require.NoError(t, result.err)
assert.NotNil(t, result.cert)
}
} }

View File

@ -17,7 +17,7 @@ func main() {
Methods: []string{"HeaderThing", "ErrorThing"}, Methods: []string{"HeaderThing", "ErrorThing"},
} }
log.Printf("start demo plugin") log.Printf("start demo plugin")
// Do start the plugin listener and register with reproxy plugin conductor // do start the plugin listener and register with reproxy plugin conductor
if err := plugin.Do(context.TODO(), "http://reproxy:8081", new(Handler)); err != nil { if err := plugin.Do(context.TODO(), "http://reproxy:8081", new(Handler)); err != nil {
log.Fatal(err) log.Fatal(err)
} }