From 72c8292d5fd7c5cb3a46bc2a19a93d41c60f87f7 Mon Sep 17 00:00:00 2001 From: Yelshat Duskaliyev Date: Tue, 17 Sep 2024 11:37:24 +0500 Subject: [PATCH] fix: issue new cert correctly in tests, added test for DNS-01 challenge --- app/proxy/acmetest/ca.go | 59 +++++++++----- app/proxy/dns_provider_mock.go | 143 +++++++++++++++++++++++++++++++++ app/proxy/proxy.go | 2 + app/proxy/ssl.go | 5 ++ app/proxy/ssl_test.go | 93 +++++++++++++++++++++ 5 files changed, 283 insertions(+), 19 deletions(-) create mode 100644 app/proxy/dns_provider_mock.go diff --git a/app/proxy/acmetest/ca.go b/app/proxy/acmetest/ca.go index 556ae17..4b139cb 100644 --- a/app/proxy/acmetest/ca.go +++ b/app/proxy/acmetest/ca.go @@ -8,6 +8,7 @@ import ( "crypto/x509/pkix" "encoding/base64" "encoding/json" + "encoding/pem" "fmt" "io" "math/big" @@ -27,7 +28,7 @@ type ACMEServer struct { t *testing.T url string cl *http.Client - checkDNS func(domain, token string) (exists bool, value string, err error) + checkDNS func(domain string) (exists bool, value string, err error) modifyReq func(*http.Request) issuedCerts map[string][]byte @@ -35,14 +36,16 @@ type ACMEServer struct { orders map[string]order // map[orderID]order mu sync.Mutex - privateKey *ecdsa.PrivateKey + rootKey *ecdsa.PrivateKey + rootTemplate *x509.Certificate + rootCert []byte } // Option is a function that configures the ACMEServer. type Option func(*ACMEServer) // CheckDNS is an option to enable DNS check for DNS-01 challenge. -func CheckDNS(fn func(domain, token string) (exists bool, value string, err error)) Option { +func CheckDNS(fn func(domain string) (exists bool, value string, err error)) Option { return func(s *ACMEServer) { s.checkDNS = fn } } @@ -53,17 +56,11 @@ func ModifyRequest(fn func(r *http.Request)) Option { // NewACMEServer creates a new ACMEServer for testing. func NewACMEServer(t *testing.T, opts ...Option) *ACMEServer { - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("[acmetest] failed to generate private key: %v", err) - } - s := &ACMEServer{ - privateKey: privateKey, orders: make(map[string]order), orderByAuthz: make(map[string]string), issuedCerts: make(map[string][]byte), - checkDNS: func(string, string) (bool, string, error) { return false, "", nil }, + checkDNS: func(string) (bool, string, error) { return false, "", nil }, modifyReq: func(*http.Request) {}, cl: &http.Client{ // Prevent HTTP redirects @@ -80,10 +77,31 @@ func NewACMEServer(t *testing.T, opts ...Option) *ACMEServer { t.Cleanup(srv.Close) s.url = srv.URL + s.genRoot() return s } +func (s *ACMEServer) genRoot() { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(s.t, err) + + s.rootTemplate = &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test Reproxy Co Root CA"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + } + der, err := x509.CreateCertificate(rand.Reader, s.rootTemplate, s.rootTemplate, &key.PublicKey, key) + require.NoError(s.t, err) + + s.rootCert = der + s.rootKey = key +} + // URL returns the URL of the server. func (s *ACMEServer) URL() string { return s.url } @@ -359,7 +377,7 @@ func (s *ACMEServer) finalizeCtrl(w http.ResponseWriter, r *http.Request) { leaf.DNSNames = []string{csr.Subject.CommonName} } - cert, err := x509.CreateCertificate(rand.Reader, leaf, leaf, s.privateKey.Public(), s.privateKey) + cert, err := x509.CreateCertificate(rand.Reader, s.rootTemplate, leaf, csr.PublicKey, s.rootKey) if err != nil { s.error(w, 500, "create certificate: %v", err) return @@ -384,8 +402,9 @@ func (s *ACMEServer) certCtrl(w http.ResponseWriter, r *http.Request) { return } - w.Header().Set("Content-Type", "application/pkix-cert") - w.Write(cert) + w.Header().Set("Content-Type", "application/pem-certificate-chain") + pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: cert}) + pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: s.rootCert}) } // POST /challenge - verify a challenge @@ -435,7 +454,7 @@ func (s *ACMEServer) challengeCtrl(w http.ResponseWriter, r *http.Request) { s.verifyHTTP01Challenge(w, token, domain) o.HTTP01Accepted = true case challengeType == "dns-01": - s.verifyDNS01Challenge(w, token, domain) + s.verifyDNS01Challenge(w, domain) o.DNS01Accepted = true default: s.error(w, 400, "invalid challenge type") @@ -473,17 +492,19 @@ func (s *ACMEServer) verifyHTTP01Challenge(w http.ResponseWriter, token, domain } // requires the server to be locked -func (s *ACMEServer) verifyDNS01Challenge(w http.ResponseWriter, token, domain string) { - exists, value, err := s.checkDNS(domain, token) +func (s *ACMEServer) verifyDNS01Challenge(w http.ResponseWriter, domain string) { + exists, value, err := s.checkDNS(domain) if err != nil { s.t.Logf("[acmetest] DNS-01 challenge check failed: %v", err) require.NoError(s.t, rest.EncodeJSON(w, 200, rest.JSON{"status": "invalid"})) return } - expectedValue := base64.RawURLEncoding.EncodeToString(s.privateKey.Public().(*ecdsa.PublicKey).X.Bytes()) - if !exists || value != expectedValue { - s.t.Logf("[acmetest] DNS-01 challenge invalid. Expected: %s, Got: %s", expectedValue, value) + // we don't check the token, as it is derived from account's public key, + // but we check whether the consumer's code assumes that the record exists + // and has a value + if !exists || value == "" { + s.t.Logf("[acmetest] DNS-01 challenge failed: domain %s does not exist or has no value", domain) require.NoError(s.t, rest.EncodeJSON(w, 200, rest.JSON{"status": "invalid"})) return } diff --git a/app/proxy/dns_provider_mock.go b/app/proxy/dns_provider_mock.go new file mode 100644 index 0000000..ef4a9ac --- /dev/null +++ b/app/proxy/dns_provider_mock.go @@ -0,0 +1,143 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package proxy + +import ( + "context" + "sync" + + "github.com/libdns/libdns" +) + +// Ensure, that dnsProviderMock does implement dnsProvider. +// If this is not the case, regenerate this file with moq. +var _ dnsProvider = &dnsProviderMock{} + +// dnsProviderMock is a mock implementation of dnsProvider. +// +// func TestSomethingThatUsesdnsProvider(t *testing.T) { +// +// // make and configure a mocked dnsProvider +// mockeddnsProvider := &dnsProviderMock{ +// AppendRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { +// panic("mock out the AppendRecords method") +// }, +// DeleteRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { +// panic("mock out the DeleteRecords method") +// }, +// } +// +// // use mockeddnsProvider in code that requires dnsProvider +// // and then make assertions. +// +// } +type dnsProviderMock struct { + // AppendRecordsFunc mocks the AppendRecords method. + AppendRecordsFunc func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) + + // DeleteRecordsFunc mocks the DeleteRecords method. + DeleteRecordsFunc func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) + + // calls tracks calls to the methods. + calls struct { + // AppendRecords holds details about calls to the AppendRecords method. + AppendRecords []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Zone is the zone argument value. + Zone string + // Recs is the recs argument value. + Recs []libdns.Record + } + // DeleteRecords holds details about calls to the DeleteRecords method. + DeleteRecords []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Zone is the zone argument value. + Zone string + // Recs is the recs argument value. + Recs []libdns.Record + } + } + lockAppendRecords sync.RWMutex + lockDeleteRecords sync.RWMutex +} + +// AppendRecords calls AppendRecordsFunc. +func (mock *dnsProviderMock) AppendRecords(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { + if mock.AppendRecordsFunc == nil { + panic("dnsProviderMock.AppendRecordsFunc: method is nil but dnsProvider.AppendRecords was just called") + } + callInfo := struct { + Ctx context.Context + Zone string + Recs []libdns.Record + }{ + Ctx: ctx, + Zone: zone, + Recs: recs, + } + mock.lockAppendRecords.Lock() + mock.calls.AppendRecords = append(mock.calls.AppendRecords, callInfo) + mock.lockAppendRecords.Unlock() + return mock.AppendRecordsFunc(ctx, zone, recs) +} + +// AppendRecordsCalls gets all the calls that were made to AppendRecords. +// Check the length with: +// len(mockeddnsProvider.AppendRecordsCalls()) +func (mock *dnsProviderMock) AppendRecordsCalls() []struct { + Ctx context.Context + Zone string + Recs []libdns.Record +} { + var calls []struct { + Ctx context.Context + Zone string + Recs []libdns.Record + } + mock.lockAppendRecords.RLock() + calls = mock.calls.AppendRecords + mock.lockAppendRecords.RUnlock() + return calls +} + +// DeleteRecords calls DeleteRecordsFunc. +func (mock *dnsProviderMock) DeleteRecords(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { + if mock.DeleteRecordsFunc == nil { + panic("dnsProviderMock.DeleteRecordsFunc: method is nil but dnsProvider.DeleteRecords was just called") + } + callInfo := struct { + Ctx context.Context + Zone string + Recs []libdns.Record + }{ + Ctx: ctx, + Zone: zone, + Recs: recs, + } + mock.lockDeleteRecords.Lock() + mock.calls.DeleteRecords = append(mock.calls.DeleteRecords, callInfo) + mock.lockDeleteRecords.Unlock() + return mock.DeleteRecordsFunc(ctx, zone, recs) +} + +// DeleteRecordsCalls gets all the calls that were made to DeleteRecords. +// Check the length with: +// len(mockeddnsProvider.DeleteRecordsCalls()) +func (mock *dnsProviderMock) DeleteRecordsCalls() []struct { + Ctx context.Context + Zone string + Recs []libdns.Record +} { + var calls []struct { + Ctx context.Context + Zone string + Recs []libdns.Record + } + mock.lockDeleteRecords.RLock() + calls = mock.calls.DeleteRecords + mock.lockDeleteRecords.RUnlock() + return calls +} diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index b1618bb..3d36531 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -57,6 +57,8 @@ type Http struct { // nolint golint ThrottleUser int KeepHost bool + + dnsResolvers []string // used to mock DNS resolvers for testing } // Matcher source info (server and route) to the destination url diff --git a/app/proxy/ssl.go b/app/proxy/ssl.go index 85a9d1a..c682d22 100644 --- a/app/proxy/ssl.go +++ b/app/proxy/ssl.go @@ -72,6 +72,10 @@ func (h *Http) redirectHandler() http.Handler { }) } +//go:generate moq -out dns_provider_mock.go -fmt goimports . dnsProvider + +type dnsProvider interface{ certmagic.DNSProvider } + // AutocertManager specifies methods for the automatic ACME certificate manager to implement type AutocertManager interface { GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) @@ -136,6 +140,7 @@ func (h *Http) makeAutocertManager() AutocertManager { DNSProvider: h.SSLConfig.DNSProvider, TTL: h.SSLConfig.TTL, Logger: logger, + Resolvers: h.dnsResolvers, }, } } diff --git a/app/proxy/ssl_test.go b/app/proxy/ssl_test.go index c5d31fe..65c17ad 100644 --- a/app/proxy/ssl_test.go +++ b/app/proxy/ssl_test.go @@ -1,7 +1,9 @@ package proxy import ( + "context" "crypto/tls" + "net" "net/http" "net/http/httptest" "os" @@ -9,6 +11,8 @@ import ( "testing" log "github.com/go-pkgz/lgr" + "github.com/libdns/libdns" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/umputun/reproxy/app/proxy/acmetest" @@ -87,3 +91,92 @@ func TestSSL_ACME_HTTPChallengeRouter(t *testing.T) { require.NoError(t, err) assert.NotNil(t, cert) } + +func TestSSL_ACME_DNSChallenge(t *testing.T) { + log.Setup(log.Debug, log.LevelBraces) + + dir, err := os.MkdirTemp("", "acme") + require.NoError(t, err) + log.Printf("[DEBUG] acme dir: %s", dir) + defer os.RemoveAll(dir) + + var expectedToken string + cas := acmetest.NewACMEServer(t, + acmetest.CheckDNS(func(domain string) (exists bool, value string, err error) { + assert.Equal(t, "example.com", domain) + return true, expectedToken, nil + }), + ) + + dnsListener, err := net.ListenPacket("udp", ":0") + require.NoError(t, err) + dnsPort := strings.TrimPrefix(dnsListener.LocalAddr().String(), "[::]:") + + dnsMock := &dns.Server{ + Addr: "localhost:" + dnsPort, + Net: "udp", + PacketConn: dnsListener, + Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + msg := &dns.Msg{} + msg.SetReply(r) + + switch r.Question[0].Qtype { + case dns.TypeSOA: + msg.Answer = []dns.RR{&dns.SOA{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + Ttl: 0, + }, + Ns: "ns1.example.com.", + Mbox: "hostmaster.example.com.", + }} + case dns.TypeTXT: + assert.Equal(t, "_acme-challenge.example.com.", r.Question[0].Name) + msg.Answer = []dns.RR{&dns.TXT{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}, + Txt: []string{expectedToken}, + }} + default: + msg.SetRcode(r, dns.RcodeNameError) + } + + _ = w.WriteMsg(msg) + }), + } + + go func() { require.NoError(t, dnsMock.ActivateAndServe()) }() + defer dnsMock.Shutdown() + + t.Log("dns server started at", dnsMock.Addr) + + p := Http{ + SSLConfig: SSLConfig{ + ACMELocation: dir, + FQDNs: []string{"example.com", "localhost"}, + ACMEDirectory: cas.URL(), + DNSProvider: &dnsProviderMock{ + AppendRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { + assert.Equal(t, "example.com.", zone) + assert.Equal(t, 1, len(recs)) + assert.Equal(t, "_acme-challenge", recs[0].Name) + assert.Equal(t, "TXT", recs[0].Type) + assert.NotEmpty(t, recs[0].Value) + expectedToken = recs[0].Value + return recs, nil + }, + DeleteRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { + return recs, nil + }, + }, + }, + dnsResolvers: []string{dnsMock.Addr}, + } + + m := p.makeAutocertManager() + + cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}) + require.NoError(t, err) + assert.NotNil(t, cert) +}