1
0
mirror of https://github.com/go-acme/lego.git synced 2024-11-21 13:25:48 +02:00

fix: HTTP server IPv6 matching (#2345)

This commit is contained in:
Ludovic Fernandez 2024-11-11 18:45:08 +01:00 committed by GitHub
parent e0207678be
commit 11929c9c78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 14 deletions

View File

@ -3,6 +3,7 @@ package http01
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/netip"
"strings" "strings"
) )
@ -54,7 +55,7 @@ func (m *hostMatcher) name() string {
} }
func (m *hostMatcher) matches(r *http.Request, domain string) bool { func (m *hostMatcher) matches(r *http.Request, domain string) bool {
return strings.HasPrefix(r.Host, domain) return matchDomain(r.Host, domain)
} }
// arbitraryMatcher checks whether the specified (*net/http.Request).Header value starts with a domain name. // arbitraryMatcher checks whether the specified (*net/http.Request).Header value starts with a domain name.
@ -65,7 +66,7 @@ func (m arbitraryMatcher) name() string {
} }
func (m arbitraryMatcher) matches(r *http.Request, domain string) bool { func (m arbitraryMatcher) matches(r *http.Request, domain string) bool {
return strings.HasPrefix(r.Header.Get(m.name()), domain) return matchDomain(r.Header.Get(m.name()), domain)
} }
// forwardedMatcher checks whether the Forwarded header contains a "host" element starting with a domain name. // forwardedMatcher checks whether the Forwarded header contains a "host" element starting with a domain name.
@ -87,7 +88,7 @@ func (m *forwardedMatcher) matches(r *http.Request, domain string) bool {
} }
host := fwds[0]["host"] host := fwds[0]["host"]
return strings.HasPrefix(host, domain) return matchDomain(host, domain)
} }
// parsing requires some form of state machine. // parsing requires some form of state machine.
@ -133,9 +134,7 @@ func parseForwardedHeader(s string) (elements []map[string]string, err error) {
case r == ',': // end of forwarded-element case r == ',': // end of forwarded-element
if key != "" { if key != "" {
if val == "" { val = s[pos:i]
val = s[pos:i]
}
cur[key] = val cur[key] = val
} }
elements = append(elements, cur) elements = append(elements, cur)
@ -185,3 +184,12 @@ func skipWS(s string, i int) int {
func isWS(r rune) bool { func isWS(r rune) bool {
return strings.ContainsRune(" \t\v\r\n", r) return strings.ContainsRune(" \t\v\r\n", r)
} }
func matchDomain(src, domain string) bool {
addr, err := netip.ParseAddr(domain)
if err == nil && addr.Is6() {
domain = "[" + domain + "]"
}
return strings.HasPrefix(src, domain)
}

View File

@ -1,13 +1,15 @@
package http01 package http01
import ( import (
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestParseForwardedHeader(t *testing.T) { func Test_parseForwardedHeader(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
input string input string
@ -83,3 +85,54 @@ func TestParseForwardedHeader(t *testing.T) {
}) })
} }
} }
func Test_hostMatcher_matches(t *testing.T) {
hm := &hostMatcher{}
testCases := []struct {
desc string
domain string
req *http.Request
expected assert.BoolAssertionFunc
}{
{
desc: "exact domain",
domain: "example.com",
req: httptest.NewRequest(http.MethodGet, "http://example.com", nil),
expected: assert.True,
},
{
desc: "request with path",
domain: "example.com",
req: httptest.NewRequest(http.MethodGet, "http://example.com/foo/bar", nil),
expected: assert.True,
},
{
desc: "ipv4",
domain: "127.0.0.1",
req: httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil),
expected: assert.True,
},
{
desc: "ipv6",
domain: "2001:db8::1",
req: httptest.NewRequest(http.MethodGet, "http://[2001:db8::1]", nil),
expected: assert.True,
},
{
desc: "ipv6 with brackets",
domain: "[2001:db8::1]",
req: httptest.NewRequest(http.MethodGet, "http://[2001:db8::1]", nil),
expected: assert.True,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
hm.matches(test.req, test.domain)
test.expected(t, hm.matches(test.req, test.domain))
})
}
}

View File

@ -56,7 +56,9 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
} }
s.done = make(chan bool) s.done = make(chan bool)
go s.serve(domain, token, keyAuth) go s.serve(domain, token, keyAuth)
return nil return nil
} }
@ -69,8 +71,11 @@ func (s *ProviderServer) CleanUp(domain, token, keyAuth string) error {
if s.listener == nil { if s.listener == nil {
return nil return nil
} }
s.listener.Close() s.listener.Close()
<-s.done <-s.done
return nil return nil
} }
@ -107,19 +112,23 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) {
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && s.matcher.matches(r, domain) { if r.Method == http.MethodGet && s.matcher.matches(r, domain) {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
_, err := w.Write([]byte(keyAuth)) _, err := w.Write([]byte(keyAuth))
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
log.Infof("[%s] Served key authentication", domain) log.Infof("[%s] Served key authentication", domain)
} else { return
log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure you are passing the %s header properly.", r.Host, r.Method, s.matcher.name()) }
_, err := w.Write([]byte("TEST"))
if err != nil { log.Warnf("Received request for domain %s with method %s but the domain did not match any challenge. Please ensure you are passing the %s header properly.", r.Host, r.Method, s.matcher.name())
http.Error(w, err.Error(), http.StatusInternalServerError)
return _, err := w.Write([]byte("TEST"))
} if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
} }
}) })
@ -133,5 +142,6 @@ func (s *ProviderServer) serve(domain, token, keyAuth string) {
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
log.Println(err) log.Println(err)
} }
s.done <- true s.done <- true
} }