1
0
mirror of https://github.com/go-acme/lego.git synced 2024-12-23 09:15:11 +02:00
lego/providers/dns/checkdomain/internal/client.go
2023-05-05 09:49:38 +02:00

384 lines
9.3 KiB
Go

package internal
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
"golang.org/x/oauth2"
)
const (
	// ns1 and ns2 are the Checkdomain nameservers. CheckNameservers requires a
	// domain to be delegated to both before its records can be managed.
	ns1 = "ns.checkdomain.de"
	ns2 = "ns2.checkdomain.de"
)

// DefaultEndpoint is the default Checkdomain API endpoint.
const DefaultEndpoint = "https://api.checkdomain.de"

const (
	// domainNotFound is the ID returned together with an error when no
	// registered domain matches a lookup.
	domainNotFound = -1

	// maxLimit is the largest page size the Checkdomain API accepts.
	maxLimit = 100

	// maxInt is the largest value representable by int on this platform.
	// The pagination loops use it as a "page count not yet known" sentinel.
	maxInt = int(^uint(0) >> 1)
)
// Client is a Checkdomain API client.
type Client struct {
	// domainIDMapping caches domain IDs keyed by the name passed to
	// GetDomainIDByName; every access must hold domainIDMu.
	domainIDMapping map[string]int
	domainIDMu      sync.Mutex

	// BaseURL is the API root; request paths are joined onto it.
	BaseURL    *url.URL
	httpClient *http.Client
}
// NewClient returns a Client pointed at DefaultEndpoint with an empty
// domain-ID cache. If hc is nil, a new http.Client with a 10-second timeout
// is used instead.
func NewClient(hc *http.Client) *Client {
	httpClient := hc
	if httpClient == nil {
		httpClient = &http.Client{Timeout: 10 * time.Second}
	}

	// DefaultEndpoint is a constant, well-formed URL, so the parse error is ignored.
	endpoint, _ := url.Parse(DefaultEndpoint)

	client := &Client{
		domainIDMapping: make(map[string]int),
		BaseURL:         endpoint,
		httpClient:      httpClient,
	}

	return client
}
// GetDomainIDByName returns the ID of the registered Checkdomain domain that
// name belongs to. A domain matches when it equals name or when name is one of
// its subdomains.
//
// If several registered domains match (e.g. "example.com" and
// "sub.example.com" for "_acme-challenge.sub.example.com"), the longest one is
// chosen, because it is the most specific.
//
// Successful lookups are cached per name; CleanCache removes an entry.
// When nothing matches, domainNotFound is returned together with an error.
func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) {
	// Load from cache if it exists.
	c.domainIDMu.Lock()
	id, ok := c.domainIDMapping[name]
	c.domainIDMu.Unlock()
	if ok {
		return id, nil
	}

	// The API's server-side filtering is unreliable (see listDomains), so the
	// complete domain list is fetched and searched locally.
	domains, err := c.listDomains(ctx)
	if err != nil {
		return domainNotFound, err
	}

	var best *Domain
	for _, domain := range domains {
		if domain == nil {
			continue
		}
		if domain.Name != name && !strings.HasSuffix(name, "."+domain.Name) {
			continue
		}
		// Prefer the most specific (longest) matching domain.
		if best == nil || len(domain.Name) > len(best.Name) {
			best = domain
		}
	}

	if best == nil {
		return domainNotFound, fmt.Errorf("domain not found: %s", name)
	}

	c.domainIDMu.Lock()
	c.domainIDMapping[name] = best.ID
	c.domainIDMu.Unlock()

	return best.ID, nil
}
// listDomains returns every domain registered in the account. It pages
// through GET /v1/domains using the largest page size the API allows.
func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) {
	endpoint := c.BaseURL.JoinPath("v1", "domains")

	// Checkdomain also offers a 'query' parameter to filter domains by a string,
	// but it does not work reliably. The whole list is fetched instead, and the
	// caller searches it for the domain of interest.
	params := endpoint.Query()
	params.Set("limit", strconv.Itoa(maxLimit))

	var domains []*Domain
	lastPage := maxInt // sentinel: unknown until the first response arrives

	for page := 1; page <= lastPage; page++ {
		params.Set("page", strconv.Itoa(page))
		endpoint.RawQuery = params.Encode()

		req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
		if err != nil {
			return nil, fmt.Errorf("failed to make request: %w", err)
		}

		var res DomainListingResponse
		if err = c.do(req, &res); err != nil {
			return nil, fmt.Errorf("failed to send domain listing request: %w", err)
		}

		if lastPage == maxInt {
			// The first response reports the page and domain totals.
			lastPage = res.Pages
			domains = make([]*Domain, 0, res.Total)
		}

		domains = append(domains, res.Embedded.Domains...)
	}

	return domains, nil
}
// getNameserverInfo fetches the nameserver configuration of a domain via
// GET /v1/domains/{id}/nameservers.
func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) {
	id := strconv.Itoa(domainID)

	req, err := newJSONRequest(ctx, http.MethodGet, c.BaseURL.JoinPath("v1", "domains", id, "nameservers"), nil)
	if err != nil {
		return nil, err
	}

	var info NameserverResponse
	if err = c.do(req, &info); err != nil {
		return nil, err
	}

	return &info, nil
}
// CheckNameservers verifies that the domain is delegated to both Checkdomain
// nameservers (ns1 and ns2). If it is not, an error is returned because its
// records cannot be updated through the API.
func (c *Client) CheckNameservers(ctx context.Context, domainID int) error {
	info, err := c.getNameserverInfo(ctx, domainID)
	if err != nil {
		return err
	}

	var hasNS1, hasNS2 bool
	for _, ns := range info.Nameservers {
		hasNS1 = hasNS1 || ns.Name == ns1
		hasNS2 = hasNS2 || ns.Name == ns2
	}

	if hasNS1 && hasNS2 {
		return nil
	}

	return errors.New("not using checkdomain nameservers, can not update records")
}
// CreateRecord adds record to the domain's zone via
// POST /v1/domains/{id}/nameservers/records. The response body is not decoded.
func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")

	req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
	if err == nil {
		err = c.do(req, nil)
	}

	return err
}
// DeleteTXTRecord removes the TXT record named recordName with value recordValue.
//
// Checkdomain does not seem to provide a way to delete a single record, but the
// whole record set can be replaced at once. This method therefore fetches all
// records and writes back every record except the matching one(s). Records that
// skipRecord excludes, such as SOA, NS and "@" records, are also left out of the
// write-back.
//
// TODO: Simplify this function once Checkdomain provides the functionality.
func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error {
	domainInfo, err := c.getDomainInfo(ctx, domainID)
	if err != nil {
		return err
	}

	nsInfo, err := c.getNameserverInfo(ctx, domainID)
	if err != nil {
		return err
	}

	records, err := c.listRecords(ctx, domainID, "")
	if err != nil {
		return err
	}

	// Strip the "."+<domain>+"." suffix so the name can be compared with
	// record.Name, which is presumably zone-relative (e.g. "_acme-challenge").
	relativeName := strings.TrimSuffix(recordName, "."+domainInfo.Name+".")

	var kept []*Record
	for _, rec := range records {
		if skipRecord(relativeName, recordValue, rec, nsInfo) {
			continue
		}

		// The API can return records without a TTL (value 0), and the replace
		// call fails unless one is set. Fall back to the zone's SOA TTL.
		if rec.TTL == 0 {
			rec.TTL = nsInfo.SOA.TTL
		}

		kept = append(kept, rec)
	}

	return c.replaceRecords(ctx, domainID, kept)
}
// getDomainInfo fetches the details of a single domain via GET /v1/domains/{id}.
func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID))

	req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	info := new(DomainResponse)
	if err := c.do(req, info); err != nil {
		return nil, err
	}

	return info, nil
}
// listRecords returns every record in the domain's zone. It pages through
// GET /v1/domains/{id}/nameservers/records using the largest allowed page size.
// A non-empty recordType restricts the listing to that record type.
func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) {
	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")

	params := endpoint.Query()
	params.Set("limit", strconv.Itoa(maxLimit))
	if recordType != "" {
		params.Set("type", recordType)
	}

	var records []*Record
	lastPage := maxInt // sentinel: unknown until the first response arrives

	for page := 1; page <= lastPage; page++ {
		params.Set("page", strconv.Itoa(page))
		endpoint.RawQuery = params.Encode()

		req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
		if err != nil {
			return nil, fmt.Errorf("failed to create request: %w", err)
		}

		var res RecordListingResponse
		if err = c.do(req, &res); err != nil {
			return nil, fmt.Errorf("failed to send record listing request: %w", err)
		}

		if lastPage == maxInt {
			// The first response reports the page and record totals.
			lastPage = res.Pages
			records = make([]*Record, 0, res.Total)
		}

		records = append(records, res.Embedded.Records...)
	}

	return records, nil
}
// replaceRecords overwrites the domain's record set with records via
// PUT /v1/domains/{id}/nameservers/records. The response body is not decoded.
func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error {
	target := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")

	req, err := newJSONRequest(ctx, http.MethodPut, target, records)
	if err == nil {
		err = c.do(req, nil)
	}

	return err
}
// do sends req with the client's HTTP client.
//
// Any non-2xx status is returned as an error. When result is non-nil, the
// response body is JSON-decoded into it. When result is nil, the body is
// drained and discarded.
func (c *Client) do(req *http.Request, result any) error {
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return errutils.NewHTTPDoError(req, err)
	}

	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode/100 != 2 {
		return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
	}

	if result == nil {
		// Read the body to EOF so the transport can reuse the keep-alive
		// connection. This is best-effort: the request already succeeded, so a
		// read error here is deliberately ignored.
		_, _ = io.Copy(io.Discard, resp.Body)
		return nil
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return errutils.NewReadResponseError(req, resp.StatusCode, err)
	}

	err = json.Unmarshal(raw, result)
	if err != nil {
		return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
	}

	return nil
}
// CleanCache drops the cached domain ID for fqdn. The next GetDomainIDByName
// call for that name therefore queries the API again.
func (c *Client) CleanCache(fqdn string) {
	c.domainIDMu.Lock()
	defer c.domainIDMu.Unlock()

	delete(c.domainIDMapping, fqdn)
}
// skipRecord reports whether record must be left out of the record set that
// DeleteTXTRecord writes back. A record is skipped when it:
//   - has an empty value;
//   - is a special record (SOA, NS, "@", or "www" when nsInfo includes www),
//     which would otherwise make the update fail with "Nameserver update failed";
//   - is the TXT record being deleted. An empty recordName or recordValue
//     matches any name or value.
func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
	switch {
	case record.Value == "":
		return true
	case record.Type == "SOA", record.Type == "NS", record.Name == "@":
		return true
	case nsInfo.General.IncludeWWW && record.Name == "www":
		return true
	}

	if record.Type != "TXT" {
		return false
	}

	return (recordName == "" || record.Name == recordName) &&
		(recordValue == "" || record.Value == recordValue)
}
// newJSONRequest builds a request for endpoint bound to ctx, with an
// "Accept: application/json" header. A non-nil payload is JSON-encoded into the
// request body, and "Content-Type: application/json" is set for it.
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
	var body bytes.Buffer

	if payload != nil {
		if err := json.NewEncoder(&body).Encode(payload); err != nil {
			return nil, fmt.Errorf("failed to create request JSON body: %w", err)
		}
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), &body)
	if err != nil {
		return nil, fmt.Errorf("unable to create request: %w", err)
	}

	req.Header.Set("Accept", "application/json")
	if payload == nil {
		return req, nil
	}

	req.Header.Set("Content-Type", "application/json")

	return req, nil
}
// OAuthStaticAccessToken returns an HTTP client that sends accessToken as a
// static OAuth2 token on every request.
//
// The given client is not modified. A shallow copy is made and only the copy's
// Transport is wrapped, so a shared client (e.g. http.DefaultClient) is never
// altered. Calling this twice with the same client also does not stack two
// token transports. If client is nil, a new client with a 5-second timeout is
// used.
func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
	var authClient http.Client
	if client != nil {
		authClient = *client
	} else {
		authClient = http.Client{Timeout: 5 * time.Second}
	}

	// Base keeps the original transport; a nil Base makes oauth2.Transport use
	// http.DefaultTransport.
	authClient.Transport = &oauth2.Transport{
		Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
		Base:   authClient.Transport,
	}

	return &authClient
}