1
0
mirror of https://github.com/umputun/reproxy.git synced 2025-06-30 22:13:42 +02:00

fix: add timeouts and address race conditions in DNS challenge tests

- Add proper synchronization for DNS mock server
- Fix race condition with thread-safe token access
- Add timeouts to certificate acquisition to prevent test hanging
- Improve error handling in DNS server
- Normalize comments with unfuck-ai-comments
This commit is contained in:
Umputun
2025-04-19 12:28:07 -05:00
parent 528807d744
commit d563c2a648
9 changed files with 109 additions and 26 deletions

View File

@ -38,7 +38,7 @@ func NewClient(address string, httpClient HTTPClient) ConsulClient {
// Get implements ConsulClient interface and returns consul services list,
// which have any tag with 'reproxy.' prefix
func (cl *consulClient) Get() ([]consulService, error) {
var result []consulService //nolint:prealloc // We cannot calc slice size
var result []consulService //nolint:prealloc // we cannot calc slice size
serviceNames, err := cl.getServiceNames()
if err != nil {

View File

@ -241,7 +241,7 @@ func (d *Docker) events(ctx context.Context, eventsCh chan<- discovery.ProviderI
ticker := time.NewTicker(d.RefreshInterval)
defer ticker.Stop()
// Keep track of running containers
// keep track of running containers
saved := make(map[string]containerInfo)
update := func() {
@ -276,7 +276,7 @@ func (d *Docker) events(ctx context.Context, eventsCh chan<- discovery.ProviderI
}
}
update() // Refresh immediately
update() // refresh immediately
for {
select {
case <-ctx.Done():
@ -371,7 +371,7 @@ func NewDockerClient(host, network string) DockerClient {
}
func (d *dockerClient) ListContainers() ([]containerInfo, error) {
// Minimum API version that returns attached networks
// minimum API version that returns attached networks
// docs.docker.com/engine/api/version-history/#v122-api-changes
const APIVersion = "v1.24"

View File

@ -344,20 +344,20 @@ func TestDocker_refresh(t *testing.T) {
}
}()
// Start some
// start some
containers <- []containerInfo{stub("1"), stub("2")}
recv()
// Nothing changed
// nothing changed
containers <- []containerInfo{stub("1"), stub("2")}
time.Sleep(time.Millisecond)
assert.Empty(t, events, "unexpected refresh notification")
// Stopped
// stopped
containers <- []containerInfo{stub("1")}
recv()
// One changed
// one changed
containers <- []containerInfo{
{ID: "1", Name: "1", State: "running", IP: "127.42.42.42", Ports: []int{12345}},
}

View File

@ -63,7 +63,7 @@ func NewACMEServer(t *testing.T, opts ...Option) *ACMEServer {
checkDNS: func(string) (bool, string, error) { return false, "", nil },
modifyReq: func(*http.Request) {},
cl: &http.Client{
// Prevent HTTP redirects
// prevent HTTP redirects
CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse },
},
t: t,

View File

@ -317,7 +317,7 @@ func TestHeaders_CSPParsing(t *testing.T) {
handler.ServeHTTP(wr, req)
if len(tt.expected) == 0 {
// For malformed headers, check they weren't set
// for malformed headers, check they weren't set
assert.Equal(t, 0, len(wr.Header()))
return
}

View File

@ -72,10 +72,10 @@ func (o *OnlyFrom) realIP(ipLookups []OFLookup, r *http.Request) string {
}
if lookup == OFForwarded && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
// The left-most being the original client, and each successive proxy that passed the request
// x-Forwarded-For is potentially a list of addresses separated with ","
// the left-most being the original client, and each successive proxy that passed the request
// adding the IP address where it received the request from.
// In case if the original IP is a private behind a proxy, we need to get the first public IP from the list
// in case if the original IP is a private behind a proxy, we need to get the first public IP from the list
return preferPublicIP(strings.Split(forwardedFor, ","))
}

View File

@ -242,7 +242,7 @@ func (h *Http) proxyHandler() http.HandlerFunc {
IdleConnTimeout: h.Timeouts.IdleConn,
TLSHandshakeTimeout: h.Timeouts.TLSHandshake,
ExpectContinueTimeout: h.Timeouts.ExpectContinue,
TLSClientConfig: &tls.Config{InsecureSkipVerify: h.Insecure}, //nolint:gosec // G402: User defined option to disable verification for self-signed certificates
TLSClientConfig: &tls.Config{InsecureSkipVerify: h.Insecure}, //nolint:gosec // g402: User defined option to disable verification for self-signed certificates
},
ErrorLog: log.ToStdLogger(log.Default(), "WARN"),
}

View File

@ -8,7 +8,9 @@ import (
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"
log "github.com/go-pkgz/lgr"
"github.com/libdns/libdns"
@ -86,10 +88,31 @@ func TestSSL_ACME_HTTPChallengeRouter(t *testing.T) {
assert.Equal(t, 307, resp.StatusCode)
assert.Equal(t, "https://localhost:443/blah?param=1", resp.Header.Get("Location"))
// acquire new cert from CA and check it
// acquire new cert from CA and check it with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
certCh := make(chan struct {
cert *tls.Certificate
err error
})
go func() {
cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"})
require.NoError(t, err)
assert.NotNil(t, cert)
certCh <- struct {
cert *tls.Certificate
err error
}{cert, err}
}()
// wait for certificate or timeout
select {
case <-ctx.Done():
t.Fatalf("Certificate acquisition timed out: %v", ctx.Err())
case result := <-certCh:
require.NoError(t, result.err)
assert.NotNil(t, result.cert)
}
}
func TestSSL_ACME_DNSChallenge(t *testing.T) {
@ -100,11 +123,26 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
log.Printf("[DEBUG] acme dir: %s", dir)
defer os.RemoveAll(dir)
// use mutex to protect expectedToken from race conditions
var expectedToken string
var tokenMutex sync.Mutex
getToken := func() string {
tokenMutex.Lock()
defer tokenMutex.Unlock()
return expectedToken
}
setToken := func(token string) {
tokenMutex.Lock()
defer tokenMutex.Unlock()
expectedToken = token
}
cas := acmetest.NewACMEServer(t,
acmetest.CheckDNS(func(domain string) (exists bool, value string, err error) {
assert.Equal(t, "example.com", domain)
return true, expectedToken, nil
return true, getToken(), nil
}),
)
@ -120,6 +158,12 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
msg := &dns.Msg{}
msg.SetReply(r)
if len(r.Question) == 0 {
msg.SetRcode(r, dns.RcodeNameError)
_ = w.WriteMsg(msg)
return
}
switch r.Question[0].Qtype {
case dns.TypeSOA:
msg.Answer = []dns.RR{&dns.SOA{
@ -136,7 +180,7 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
assert.Equal(t, "_acme-challenge.example.com.", r.Question[0].Name)
msg.Answer = []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
Txt: []string{expectedToken},
Txt: []string{getToken()},
}}
default:
msg.SetRcode(r, dns.RcodeNameError)
@ -146,7 +190,23 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
}),
}
go func() { require.NoError(t, dnsMock.ActivateAndServe()) }()
// create a channel to ensure DNS server is ready
dnsReady := make(chan struct{})
go func() {
// signal that the DNS server is ready to accept connections
close(dnsReady)
// set a timeout for the DNS server
err := dnsMock.ActivateAndServe()
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
t.Logf("DNS server error: %v", err)
}
}()
// wait for DNS server to be ready
<-dnsReady
// ensure the server is shut down at the end of the test
defer dnsMock.Shutdown()
t.Log("dns server started at", dnsMock.Addr)
@ -163,7 +223,7 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
assert.Equal(t, "_acme-challenge", recs[0].Name)
assert.Equal(t, "TXT", recs[0].Type)
assert.NotEmpty(t, recs[0].Value)
expectedToken = recs[0].Value
setToken(recs[0].Value)
return recs, nil
},
DeleteRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
@ -176,7 +236,30 @@ func TestSSL_ACME_DNSChallenge(t *testing.T) {
m := p.makeAutocertManager()
// create context with timeout to prevent test from hanging
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// use a channel to handle the certificate acquisition with timeout
certCh := make(chan struct {
cert *tls.Certificate
err error
})
go func() {
cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"})
require.NoError(t, err)
assert.NotNil(t, cert)
certCh <- struct {
cert *tls.Certificate
err error
}{cert, err}
}()
// wait for certificate or timeout
select {
case <-ctx.Done():
t.Fatalf("Certificate acquisition timed out: %v", ctx.Err())
case result := <-certCh:
require.NoError(t, result.err)
assert.NotNil(t, result.cert)
}
}

View File

@ -17,7 +17,7 @@ func main() {
Methods: []string{"HeaderThing", "ErrorThing"},
}
log.Printf("start demo plugin")
// Do start the plugin listener and register with reproxy plugin conductor
// do start the plugin listener and register with reproxy plugin conductor
if err := plugin.Do(context.TODO(), "http://reproxy:8081", new(Handler)); err != nil {
log.Fatal(err)
}