diff --git a/challenge/http01/domain_matcher.go b/challenge/http01/domain_matcher.go index 5c755c4b..c31aeed6 100644 --- a/challenge/http01/domain_matcher.go +++ b/challenge/http01/domain_matcher.go @@ -3,6 +3,7 @@ package http01 import ( "fmt" "net/http" + "net/netip" "strings" ) @@ -54,7 +55,7 @@ func (m *hostMatcher) name() string { } func (m *hostMatcher) matches(r *http.Request, domain string) bool { - return strings.HasPrefix(r.Host, domain) + return matchDomain(r.Host, domain) } // arbitraryMatcher checks whether the specified (*net/http.Request).Header value starts with a domain name. @@ -65,7 +66,7 @@ func (m arbitraryMatcher) name() string { } func (m arbitraryMatcher) matches(r *http.Request, domain string) bool { - return strings.HasPrefix(r.Header.Get(m.name()), domain) + return matchDomain(r.Header.Get(m.name()), domain) } // forwardedMatcher checks whether the Forwarded header contains a "host" element starting with a domain name. @@ -87,7 +88,7 @@ func (m *forwardedMatcher) matches(r *http.Request, domain string) bool { } host := fwds[0]["host"] - return strings.HasPrefix(host, domain) + return matchDomain(host, domain) } // parsing requires some form of state machine. @@ -133,9 +134,7 @@ func parseForwardedHeader(s string) (elements []map[string]string, err error) { case r == ',': // end of forwarded-element if key != "" { - if val == "" { - val = s[pos:i] - } + val = s[pos:i] cur[key] = val } elements = append(elements, cur) @@ -185,3 +184,12 @@ func skipWS(s string, i int) int { func isWS(r rune) bool { return strings.ContainsRune(" \t\v\r\n", r) } + +func matchDomain(src, domain string) bool { + addr, err := netip.ParseAddr(domain) + if err == nil && addr.Is6() { + domain = "[" + domain + "]" + } + + return strings.HasPrefix(src, domain) +} diff --git a/challenge/http01/domain_matcher_test.go b/challenge/http01/domain_matcher_test.go index 94add14b..efdc4641 100644 --- a/challenge/http01/domain_matcher_test.go +++ b/challenge/http01/domain_matcher_test.go @@ -1,13 +1,15 @@ package http01 import ( + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestParseForwardedHeader(t *testing.T) { +func Test_parseForwardedHeader(t *testing.T) { testCases := []struct { name string input string @@ -83,3 +85,54 @@ func TestParseForwardedHeader(t *testing.T) { }) } } + +func Test_hostMatcher_matches(t *testing.T) { + hm := &hostMatcher{} + + testCases := []struct { + desc string + domain string + req *http.Request + expected assert.BoolAssertionFunc + }{ + { + desc: "exact domain", + domain: "example.com", + req: httptest.NewRequest(http.MethodGet, "http://example.com", nil), + expected: assert.True, + }, + { + desc: "request with path", + domain: "example.com", + req: httptest.NewRequest(http.MethodGet, "http://example.com/foo/bar", nil), + expected: assert.True, + }, + { + desc: "ipv4", + domain: "127.0.0.1", + req: httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil), + expected: assert.True, + }, + { + desc: "ipv6", + domain: "2001:db8::1", + req: httptest.NewRequest(http.MethodGet, "http://[2001:db8::1]", nil), + expected: assert.True, + }, + { + desc: "ipv6 with brackets", + domain: "[2001:db8::1]", + req: httptest.NewRequest(http.MethodGet, "http://[2001:db8::1]", nil), + expected: assert.True, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + hm.matches(test.req, test.domain) + + test.expected(t, hm.matches(test.req, test.domain)) + }) + } +} diff --git a/challenge/http01/http_challenge_server.go b/challenge/http01/http_challenge_server.go index f69f5ac1..009271ce 100644 --- a/challenge/http01/http_challenge_server.go +++ b/challenge/http01/http_challenge_server.go @@ -56,7 +56,9 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error { } s.done = make(chan bool) + go s.serve(domain, token, keyAuth) + return nil } @@ -69,8 +71,11 @@ func (s *ProviderServer) CleanUp(domain, token, keyAuth string) error { if s.listener == nil { return nil } + s.listener.Close() + <-s.done + return nil } @@ -107,19 +112,23 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) { mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet && s.matcher.matches(r, domain) { w.Header().Set("Content-Type", "text/plain") + _, err := w.Write([]byte(keyAuth)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } + log.Infof("[%s] Served key authentication", domain) - } else { - log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure you are passing the %s header properly.", r.Host, r.Method, s.matcher.name()) - _, err := w.Write([]byte("TEST")) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + return + } + + log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure you are passing the %s header properly.", r.Host, r.Method, s.matcher.name()) + + _, err := w.Write([]byte("TEST")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return } }) @@ -133,5 +142,6 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) { if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { log.Println(err) } + s.done <- true }