1
0
mirror of https://github.com/go-acme/lego.git synced 2025-01-13 18:37:56 +02:00

Add a mechanism to wrap a PreCheckFunc (#783)

This commit is contained in:
Danek Duvall 2019-02-12 08:36:44 -08:00 committed by Ludovic Fernandez
parent 19303d3ac6
commit 1c6f67f47a
3 changed files with 51 additions and 26 deletions

View File

@ -127,7 +127,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers) log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers)
err = wait.For("propagation", timeout, interval, func() (bool, error) { err = wait.For("propagation", timeout, interval, func() (bool, error) {
stop, errP := c.preCheck.call(fqdn, value) stop, errP := c.preCheck.call(domain, fqdn, value)
if !stop || errP != nil { if !stop || errP != nil {
log.Infof("[%s] acme: Waiting for DNS record propagation.", domain) log.Infof("[%s] acme: Waiting for DNS record propagation.", domain)
} }

View File

@ -44,20 +44,20 @@ func TestChallenge_PreSolve(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
validate ValidateFunc validate ValidateFunc
preCheck PreCheckFunc preCheck WrapPreCheckFunc
provider challenge.Provider provider challenge.Provider
expectError bool expectError bool
}{ }{
{ {
desc: "success", desc: "success",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{}, provider: &providerMock{},
}, },
{ {
desc: "validate fail", desc: "validate fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
present: nil, present: nil,
cleanUp: nil, cleanUp: nil,
@ -66,7 +66,7 @@ func TestChallenge_PreSolve(t *testing.T) {
{ {
desc: "preCheck fail", desc: "preCheck fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
provider: &providerTimeoutMock{ provider: &providerTimeoutMock{
timeout: 2 * time.Second, timeout: 2 * time.Second,
interval: 500 * time.Millisecond, interval: 500 * time.Millisecond,
@ -75,7 +75,7 @@ func TestChallenge_PreSolve(t *testing.T) {
{ {
desc: "present fail", desc: "present fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
present: errors.New("OOPS"), present: errors.New("OOPS"),
}, },
@ -84,7 +84,7 @@ func TestChallenge_PreSolve(t *testing.T) {
{ {
desc: "cleanUp fail", desc: "cleanUp fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
cleanUp: errors.New("OOPS"), cleanUp: errors.New("OOPS"),
}, },
@ -94,7 +94,7 @@ func TestChallenge_PreSolve(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck)) chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
authz := acme.Authorization{ authz := acme.Authorization{
Identifier: acme.Identifier{ Identifier: acme.Identifier{
@ -128,20 +128,20 @@ func TestChallenge_Solve(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
validate ValidateFunc validate ValidateFunc
preCheck PreCheckFunc preCheck WrapPreCheckFunc
provider challenge.Provider provider challenge.Provider
expectError bool expectError bool
}{ }{
{ {
desc: "success", desc: "success",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{}, provider: &providerMock{},
}, },
{ {
desc: "validate fail", desc: "validate fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
present: nil, present: nil,
cleanUp: nil, cleanUp: nil,
@ -151,7 +151,7 @@ func TestChallenge_Solve(t *testing.T) {
{ {
desc: "preCheck fail", desc: "preCheck fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
provider: &providerTimeoutMock{ provider: &providerTimeoutMock{
timeout: 2 * time.Second, timeout: 2 * time.Second,
interval: 500 * time.Millisecond, interval: 500 * time.Millisecond,
@ -161,7 +161,7 @@ func TestChallenge_Solve(t *testing.T) {
{ {
desc: "present fail", desc: "present fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
present: errors.New("OOPS"), present: errors.New("OOPS"),
}, },
@ -169,7 +169,7 @@ func TestChallenge_Solve(t *testing.T) {
{ {
desc: "cleanUp fail", desc: "cleanUp fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
cleanUp: errors.New("OOPS"), cleanUp: errors.New("OOPS"),
}, },
@ -179,7 +179,11 @@ func TestChallenge_Solve(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck)) var options []ChallengeOption
if test.preCheck != nil {
options = append(options, WrapPreCheck(test.preCheck))
}
chlg := NewChallenge(core, test.validate, test.provider, options...)
authz := acme.Authorization{ authz := acme.Authorization{
Identifier: acme.Identifier{ Identifier: acme.Identifier{
@ -213,20 +217,20 @@ func TestChallenge_CleanUp(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
validate ValidateFunc validate ValidateFunc
preCheck PreCheckFunc preCheck WrapPreCheckFunc
provider challenge.Provider provider challenge.Provider
expectError bool expectError bool
}{ }{
{ {
desc: "success", desc: "success",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{}, provider: &providerMock{},
}, },
{ {
desc: "validate fail", desc: "validate fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
present: nil, present: nil,
cleanUp: nil, cleanUp: nil,
@ -235,7 +239,7 @@ func TestChallenge_CleanUp(t *testing.T) {
{ {
desc: "preCheck fail", desc: "preCheck fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
provider: &providerTimeoutMock{ provider: &providerTimeoutMock{
timeout: 2 * time.Second, timeout: 2 * time.Second,
interval: 500 * time.Millisecond, interval: 500 * time.Millisecond,
@ -244,7 +248,7 @@ func TestChallenge_CleanUp(t *testing.T) {
{ {
desc: "present fail", desc: "present fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
present: errors.New("OOPS"), present: errors.New("OOPS"),
}, },
@ -252,7 +256,7 @@ func TestChallenge_CleanUp(t *testing.T) {
{ {
desc: "cleanUp fail", desc: "cleanUp fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil }, preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{ provider: &providerMock{
cleanUp: errors.New("OOPS"), cleanUp: errors.New("OOPS"),
}, },
@ -263,7 +267,7 @@ func TestChallenge_CleanUp(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck)) chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
authz := acme.Authorization{ authz := acme.Authorization{
Identifier: acme.Identifier{ Identifier: acme.Identifier{

View File

@ -1,6 +1,7 @@
package dns01 package dns01
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
@ -11,11 +12,30 @@ import (
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready. // PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
type PreCheckFunc func(fqdn, value string) (bool, error) type PreCheckFunc func(fqdn, value string) (bool, error)
// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after
// the main check, put it in a loop, etc.
type WrapPreCheckFunc func(domain, fqdn, value string, check PreCheckFunc) (bool, error)
// WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.checkFunc = wrap
return nil
}
}
// AddPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
// Deprecated: use WrapPreCheck instead.
func AddPreCheck(preCheck PreCheckFunc) ChallengeOption { func AddPreCheck(preCheck PreCheckFunc) ChallengeOption {
// Prevent race condition // Prevent race condition
check := preCheck check := preCheck
return func(chlg *Challenge) error { return func(chlg *Challenge) error {
chlg.preCheck.checkFunc = check chlg.preCheck.checkFunc = func(_, fqdn, value string, _ PreCheckFunc) (bool, error) {
if check == nil {
return false, errors.New("invalid preCheck: preCheck is nil")
}
return check(fqdn, value)
}
return nil return nil
} }
} }
@ -29,7 +49,7 @@ func DisableCompletePropagationRequirement() ChallengeOption {
type preCheck struct { type preCheck struct {
// checks DNS propagation before notifying ACME that the DNS challenge is ready. // checks DNS propagation before notifying ACME that the DNS challenge is ready.
checkFunc PreCheckFunc checkFunc WrapPreCheckFunc
// require the TXT record to be propagated to all authoritative name servers // require the TXT record to be propagated to all authoritative name servers
requireCompletePropagation bool requireCompletePropagation bool
} }
@ -40,11 +60,12 @@ func newPreCheck() preCheck {
} }
} }
func (p preCheck) call(fqdn, value string) (bool, error) { func (p preCheck) call(domain, fqdn, value string) (bool, error) {
if p.checkFunc == nil { if p.checkFunc == nil {
return p.checkDNSPropagation(fqdn, value) return p.checkDNSPropagation(fqdn, value)
} }
return p.checkFunc(fqdn, value)
return p.checkFunc(domain, fqdn, value, p.checkDNSPropagation)
} }
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers. // checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.