1
0
mirror of https://github.com/umputun/reproxy.git synced 2025-06-30 22:13:42 +02:00

fix: issue new cert correctly in tests, added test for DNS-01 challenge

This commit is contained in:
Yelshat Duskaliyev
2024-09-17 11:37:24 +05:00
committed by Yelshat Duskaliyev
parent 9f1fd3e86a
commit 72c8292d5f
5 changed files with 283 additions and 19 deletions

View File

@ -8,6 +8,7 @@ import (
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
@ -27,7 +28,7 @@ type ACMEServer struct {
t *testing.T
url string
cl *http.Client
checkDNS func(domain, token string) (exists bool, value string, err error)
checkDNS func(domain string) (exists bool, value string, err error)
modifyReq func(*http.Request)
issuedCerts map[string][]byte
@ -35,14 +36,16 @@ type ACMEServer struct {
orders map[string]order // map[orderID]order
mu sync.Mutex
privateKey *ecdsa.PrivateKey
rootKey *ecdsa.PrivateKey
rootTemplate *x509.Certificate
rootCert []byte
}
// Option is a function that configures the ACMEServer.
type Option func(*ACMEServer)
// CheckDNS is an option to enable DNS check for DNS-01 challenge.
func CheckDNS(fn func(domain, token string) (exists bool, value string, err error)) Option {
func CheckDNS(fn func(domain string) (exists bool, value string, err error)) Option {
return func(s *ACMEServer) { s.checkDNS = fn }
}
@ -53,17 +56,11 @@ func ModifyRequest(fn func(r *http.Request)) Option {
// NewACMEServer creates a new ACMEServer for testing.
func NewACMEServer(t *testing.T, opts ...Option) *ACMEServer {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("[acmetest] failed to generate private key: %v", err)
}
s := &ACMEServer{
privateKey: privateKey,
orders: make(map[string]order),
orderByAuthz: make(map[string]string),
issuedCerts: make(map[string][]byte),
checkDNS: func(string, string) (bool, string, error) { return false, "", nil },
checkDNS: func(string) (bool, string, error) { return false, "", nil },
modifyReq: func(*http.Request) {},
cl: &http.Client{
// Prevent HTTP redirects
@ -80,10 +77,31 @@ func NewACMEServer(t *testing.T, opts ...Option) *ACMEServer {
t.Cleanup(srv.Close)
s.url = srv.URL
s.genRoot()
return s
}
func (s *ACMEServer) genRoot() {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(s.t, err)
s.rootTemplate = &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test Reproxy Co Root CA"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
}
der, err := x509.CreateCertificate(rand.Reader, s.rootTemplate, s.rootTemplate, &key.PublicKey, key)
require.NoError(s.t, err)
s.rootCert = der
s.rootKey = key
}
// URL returns the URL of the server.
func (s *ACMEServer) URL() string { return s.url }
@ -359,7 +377,7 @@ func (s *ACMEServer) finalizeCtrl(w http.ResponseWriter, r *http.Request) {
leaf.DNSNames = []string{csr.Subject.CommonName}
}
cert, err := x509.CreateCertificate(rand.Reader, leaf, leaf, s.privateKey.Public(), s.privateKey)
cert, err := x509.CreateCertificate(rand.Reader, s.rootTemplate, leaf, csr.PublicKey, s.rootKey)
if err != nil {
s.error(w, 500, "create certificate: %v", err)
return
@ -384,8 +402,9 @@ func (s *ACMEServer) certCtrl(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Set("Content-Type", "application/pkix-cert")
w.Write(cert)
w.Header().Set("Content-Type", "application/pem-certificate-chain")
pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: cert})
pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: s.rootCert})
}
// POST /challenge - verify a challenge
@ -435,7 +454,7 @@ func (s *ACMEServer) challengeCtrl(w http.ResponseWriter, r *http.Request) {
s.verifyHTTP01Challenge(w, token, domain)
o.HTTP01Accepted = true
case challengeType == "dns-01":
s.verifyDNS01Challenge(w, token, domain)
s.verifyDNS01Challenge(w, domain)
o.DNS01Accepted = true
default:
s.error(w, 400, "invalid challenge type")
@ -473,17 +492,19 @@ func (s *ACMEServer) verifyHTTP01Challenge(w http.ResponseWriter, token, domain
}
// requires the server to be locked
func (s *ACMEServer) verifyDNS01Challenge(w http.ResponseWriter, token, domain string) {
exists, value, err := s.checkDNS(domain, token)
func (s *ACMEServer) verifyDNS01Challenge(w http.ResponseWriter, domain string) {
exists, value, err := s.checkDNS(domain)
if err != nil {
s.t.Logf("[acmetest] DNS-01 challenge check failed: %v", err)
require.NoError(s.t, rest.EncodeJSON(w, 200, rest.JSON{"status": "invalid"}))
return
}
expectedValue := base64.RawURLEncoding.EncodeToString(s.privateKey.Public().(*ecdsa.PublicKey).X.Bytes())
if !exists || value != expectedValue {
s.t.Logf("[acmetest] DNS-01 challenge invalid. Expected: %s, Got: %s", expectedValue, value)
// we don't check the token, as it is derived from account's public key,
// but we check whether the consumer's code assumes that the record exists
// and has a value
if !exists || value == "" {
s.t.Logf("[acmetest] DNS-01 challenge failed: domain %s does not exist or has no value", domain)
require.NoError(s.t, rest.EncodeJSON(w, 200, rest.JSON{"status": "invalid"}))
return
}

View File

@ -0,0 +1,143 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package proxy
import (
"context"
"sync"
"github.com/libdns/libdns"
)
// Ensure, that dnsProviderMock does implement dnsProvider.
// If this is not the case, regenerate this file with moq.
var _ dnsProvider = &dnsProviderMock{}
// dnsProviderMock is a mock implementation of dnsProvider.
//
// func TestSomethingThatUsesdnsProvider(t *testing.T) {
//
// // make and configure a mocked dnsProvider
// mockeddnsProvider := &dnsProviderMock{
// AppendRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
// panic("mock out the AppendRecords method")
// },
// DeleteRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
// panic("mock out the DeleteRecords method")
// },
// }
//
// // use mockeddnsProvider in code that requires dnsProvider
// // and then make assertions.
//
// }
type dnsProviderMock struct {
// AppendRecordsFunc mocks the AppendRecords method.
AppendRecordsFunc func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error)
// DeleteRecordsFunc mocks the DeleteRecords method.
DeleteRecordsFunc func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error)
// calls tracks calls to the methods.
calls struct {
// AppendRecords holds details about calls to the AppendRecords method.
AppendRecords []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Zone is the zone argument value.
Zone string
// Recs is the recs argument value.
Recs []libdns.Record
}
// DeleteRecords holds details about calls to the DeleteRecords method.
DeleteRecords []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Zone is the zone argument value.
Zone string
// Recs is the recs argument value.
Recs []libdns.Record
}
}
lockAppendRecords sync.RWMutex
lockDeleteRecords sync.RWMutex
}
// AppendRecords calls AppendRecordsFunc.
func (mock *dnsProviderMock) AppendRecords(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
if mock.AppendRecordsFunc == nil {
panic("dnsProviderMock.AppendRecordsFunc: method is nil but dnsProvider.AppendRecords was just called")
}
callInfo := struct {
Ctx context.Context
Zone string
Recs []libdns.Record
}{
Ctx: ctx,
Zone: zone,
Recs: recs,
}
mock.lockAppendRecords.Lock()
mock.calls.AppendRecords = append(mock.calls.AppendRecords, callInfo)
mock.lockAppendRecords.Unlock()
return mock.AppendRecordsFunc(ctx, zone, recs)
}
// AppendRecordsCalls gets all the calls that were made to AppendRecords.
// Check the length with:
// len(mockeddnsProvider.AppendRecordsCalls())
func (mock *dnsProviderMock) AppendRecordsCalls() []struct {
Ctx context.Context
Zone string
Recs []libdns.Record
} {
var calls []struct {
Ctx context.Context
Zone string
Recs []libdns.Record
}
mock.lockAppendRecords.RLock()
calls = mock.calls.AppendRecords
mock.lockAppendRecords.RUnlock()
return calls
}
// DeleteRecords calls DeleteRecordsFunc.
func (mock *dnsProviderMock) DeleteRecords(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
if mock.DeleteRecordsFunc == nil {
panic("dnsProviderMock.DeleteRecordsFunc: method is nil but dnsProvider.DeleteRecords was just called")
}
callInfo := struct {
Ctx context.Context
Zone string
Recs []libdns.Record
}{
Ctx: ctx,
Zone: zone,
Recs: recs,
}
mock.lockDeleteRecords.Lock()
mock.calls.DeleteRecords = append(mock.calls.DeleteRecords, callInfo)
mock.lockDeleteRecords.Unlock()
return mock.DeleteRecordsFunc(ctx, zone, recs)
}
// DeleteRecordsCalls gets all the calls that were made to DeleteRecords.
// Check the length with:
// len(mockeddnsProvider.DeleteRecordsCalls())
func (mock *dnsProviderMock) DeleteRecordsCalls() []struct {
Ctx context.Context
Zone string
Recs []libdns.Record
} {
var calls []struct {
Ctx context.Context
Zone string
Recs []libdns.Record
}
mock.lockDeleteRecords.RLock()
calls = mock.calls.DeleteRecords
mock.lockDeleteRecords.RUnlock()
return calls
}

View File

@ -57,6 +57,8 @@ type Http struct { // nolint golint
ThrottleUser int
KeepHost bool
dnsResolvers []string // used to mock DNS resolvers for testing
}
// Matcher source info (server and route) to the destination url

View File

@ -72,6 +72,10 @@ func (h *Http) redirectHandler() http.Handler {
})
}
//go:generate moq -out dns_provider_mock.go -fmt goimports . dnsProvider
type dnsProvider interface{ certmagic.DNSProvider }
// AutocertManager specifies methods for the automatic ACME certificate manager to implement
type AutocertManager interface {
GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
@ -136,6 +140,7 @@ func (h *Http) makeAutocertManager() AutocertManager {
DNSProvider: h.SSLConfig.DNSProvider,
TTL: h.SSLConfig.TTL,
Logger: logger,
Resolvers: h.dnsResolvers,
},
}
}

View File

@ -1,7 +1,9 @@
package proxy
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"os"
@ -9,6 +11,8 @@ import (
"testing"
log "github.com/go-pkgz/lgr"
"github.com/libdns/libdns"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/proxy/acmetest"
@ -87,3 +91,92 @@ func TestSSL_ACME_HTTPChallengeRouter(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, cert)
}
func TestSSL_ACME_DNSChallenge(t *testing.T) {
log.Setup(log.Debug, log.LevelBraces)
dir, err := os.MkdirTemp("", "acme")
require.NoError(t, err)
log.Printf("[DEBUG] acme dir: %s", dir)
defer os.RemoveAll(dir)
var expectedToken string
cas := acmetest.NewACMEServer(t,
acmetest.CheckDNS(func(domain string) (exists bool, value string, err error) {
assert.Equal(t, "example.com", domain)
return true, expectedToken, nil
}),
)
dnsListener, err := net.ListenPacket("udp", ":0")
require.NoError(t, err)
dnsPort := strings.TrimPrefix(dnsListener.LocalAddr().String(), "[::]:")
dnsMock := &dns.Server{
Addr: "localhost:" + dnsPort,
Net: "udp",
PacketConn: dnsListener,
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
msg := &dns.Msg{}
msg.SetReply(r)
switch r.Question[0].Qtype {
case dns.TypeSOA:
msg.Answer = []dns.RR{&dns.SOA{
Hdr: dns.RR_Header{
Name: "example.com.",
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: "ns1.example.com.",
Mbox: "hostmaster.example.com.",
}}
case dns.TypeTXT:
assert.Equal(t, "_acme-challenge.example.com.", r.Question[0].Name)
msg.Answer = []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
Txt: []string{expectedToken},
}}
default:
msg.SetRcode(r, dns.RcodeNameError)
}
_ = w.WriteMsg(msg)
}),
}
go func() { require.NoError(t, dnsMock.ActivateAndServe()) }()
defer dnsMock.Shutdown()
t.Log("dns server started at", dnsMock.Addr)
p := Http{
SSLConfig: SSLConfig{
ACMELocation: dir,
FQDNs: []string{"example.com", "localhost"},
ACMEDirectory: cas.URL(),
DNSProvider: &dnsProviderMock{
AppendRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
assert.Equal(t, "example.com.", zone)
assert.Equal(t, 1, len(recs))
assert.Equal(t, "_acme-challenge", recs[0].Name)
assert.Equal(t, "TXT", recs[0].Type)
assert.NotEmpty(t, recs[0].Value)
expectedToken = recs[0].Value
return recs, nil
},
DeleteRecordsFunc: func(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) {
return recs, nil
},
},
},
dnsResolvers: []string{dnsMock.Addr},
}
m := p.makeAutocertManager()
cert, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"})
require.NoError(t, err)
assert.NotNil(t, cert)
}