1
0
mirror of https://github.com/go-acme/lego.git synced 2025-01-25 06:43:10 +02:00
2023-05-05 09:49:38 +02:00

293 lines
6.1 KiB
Go

package internal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupTest starts an httptest server backed by handler and returns a Client
// whose base URL and HTTP client both target that server.
//
// The server is closed automatically through t.Cleanup. Any failure while
// building the client or parsing the server URL aborts the calling test
// immediately instead of being silently ignored.
func setupTest(t *testing.T, handler http.HandlerFunc) *Client {
	t.Helper()

	server := httptest.NewServer(handler)
	t.Cleanup(server.Close)

	client, err := NewClient("myKey", "mySecret")
	require.NoError(t, err)

	// The trailing slash is appended so the base URL ends with a path separator.
	baseURL, err := url.Parse(server.URL + "/")
	require.NoError(t, err)

	client.baseURL = baseURL
	client.HTTPClient = server.Client()

	return client
}
// handlerMock returns an HTTP handler that emulates the remote API in tests.
//
// When the request method differs from method, the handler answers with
// status 400 and a JSON-encoded apiResponse (code 999) describing the
// mismatch. Otherwise data is JSON-encoded, stored in the Data field of a
// copy of response, and that copy is written as the response body.
//
// The struct pointed to by response is never modified, and a nil response is
// treated as an empty apiResponse.
func handlerMock(method string, response *apiResponse, data interface{}) http.HandlerFunc {
	return func(rw http.ResponseWriter, req *http.Request) {
		if req.Method != method {
			content, err := json.Marshal(apiResponse{
				Code:    999, // random code only for the test
				Message: fmt.Sprintf("invalid method: got %s want %s", req.Method, method),
			})
			if err != nil {
				http.Error(rw, err.Error(), http.StatusInternalServerError)
				return
			}

			http.Error(rw, string(content), http.StatusBadRequest)
			return
		}

		jsonData, err := json.Marshal(data)
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
			return
		}

		// Work on a local copy: the handler runs on the server's goroutine and
		// must not write to a struct owned by the test that created it.
		var payload apiResponse
		if response != nil {
			payload = *response
		}
		payload.Data = jsonData

		content, err := json.Marshal(&payload)
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
			return
		}

		// Once Write has been called the status line and headers are already
		// committed, so a write failure cannot be reported to the client with
		// http.Error; the client observes the broken response on its own.
		_, _ = rw.Write(content)
	}
}
func TestClient_GetDomainInformation(t *testing.T) {
type result struct {
domain *Data
error bool
}
testCases := []struct {
desc string
fqdn string
response *apiResponse
data []Data
expected result
}{
{
desc: "domain found",
fqdn: "_acme-challenge.foo.com.",
response: &apiResponse{
Code: 1,
},
data: []Data{
{
ID: "1",
Domain: "bar.com.",
},
{
ID: "2",
Domain: "foo.com.",
},
},
expected: result{domain: &Data{
ID: "2",
Domain: "foo.com.",
}},
},
{
desc: "domains not found",
fqdn: "_acme-challenge.huu.com.",
response: &apiResponse{
Code: 1,
},
data: []Data{
{
ID: "5",
Domain: "bar.com.",
},
{
ID: "6",
Domain: "foo.com.",
},
},
expected: result{error: true},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
client := setupTest(t, handlerMock(http.MethodGet, test.response, test.data))
domain, err := client.GetDomainInformation(context.Background(), test.fqdn)
if test.expected.error {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expected.domain, domain)
}
})
}
}
// TestClient_FindTxtRecord runs FindTxtRecord against a mock API that serves
// two TXT records. One case expects the record whose host is
// "_acme-challenge.foo.com" to be returned, the other expects an error because
// no served record corresponds to the requested FQDN.
func TestClient_FindTxtRecord(t *testing.T) {
	// Both cases are served the same records, so the fixture is declared once.
	records := []TXTRecord{
		{
			ID:       1,
			RecordID: "Record-A",
			Host:     "_acme-challenge.foo.com",
			Value:    "txtTXTtxtTXTtxtTXTtxtTXT",
			Type:     "TXT",
			LineID:   6,
			TTL:      30,
		},
		{
			ID:       2,
			RecordID: "Record-B",
			Host:     "_acme-challenge.bar.com",
			Value:    "TXTtxtTXTtxtTXTtxtTXTtxt",
			Type:     "TXT",
			LineID:   6,
			TTL:      30,
		},
	}

	cases := []struct {
		name       string
		fqdn       string
		zoneID     string
		served     []TXTRecord
		response   *apiResponse
		wantRecord *TXTRecord
		wantErr    bool
	}{
		{
			name:     "record found",
			fqdn:     "_acme-challenge.foo.com.",
			zoneID:   "test-zone",
			served:   records,
			response: &apiResponse{Code: 1},
			wantRecord: &TXTRecord{
				ID:       1,
				RecordID: "Record-A",
				Host:     "_acme-challenge.foo.com",
				Value:    "txtTXTtxtTXTtxtTXTtxtTXT",
				Type:     "TXT",
				LineID:   6,
				TTL:      30,
			},
		},
		{
			name:     "record not found",
			fqdn:     "_acme-challenge.huu.com.",
			zoneID:   "test-zone",
			served:   records,
			response: &apiResponse{Code: 1},
			wantErr:  true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			handler := handlerMock(http.MethodGet, tc.response, tc.served)
			client := setupTest(t, handler)

			got, err := client.FindTxtRecord(context.Background(), tc.zoneID, tc.fqdn)
			if tc.wantErr {
				require.Error(t, err)
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tc.wantRecord, got)
		})
	}
}
// TestClient_AddTxtRecord checks the JSON body that AddTxtRecord sends to the
// API. The mock server compares the trimmed request body with the expected
// payload before answering with a successful response, and the test then
// requires AddTxtRecord to return no error.
func TestClient_AddTxtRecord(t *testing.T) {
	testCases := []struct {
		desc     string
		domain   *Data
		fqdn     string
		value    string
		ttl      int
		expected string
	}{
		{
			desc: "sub-domain",
			domain: &Data{
				ID:     "1",
				Domain: "bar.com.",
			},
			fqdn:     "_acme-challenge.foo.bar.com.",
			value:    "txtTXTtxtTXTtxtTXTtxtTXT",
			ttl:      30,
			expected: `{"domain_id":1,"host":"_acme-challenge.foo","value":"txtTXTtxtTXTtxtTXTtxtTXT","type":"TXT","line_id":"1","ttl":"30"}`,
		},
		{
			desc: "main domain",
			domain: &Data{
				ID:     "2",
				Domain: "bar.com.",
			},
			fqdn:     "_acme-challenge.bar.com.",
			value:    "TXTtxtTXTtxtTXTtxtTXTtxt",
			ttl:      30,
			expected: `{"domain_id":2,"host":"_acme-challenge","value":"TXTtxtTXTtxtTXTtxtTXTtxt","type":"TXT","line_id":"1","ttl":"30"}`,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			response := &apiResponse{
				Code: 1,
			}

			client := setupTest(t, func(rw http.ResponseWriter, req *http.Request) {
				// This handler runs on the server's goroutine, not on the test
				// goroutine, so only non-fatal assertions are used here:
				// t.FailNow (and therefore require.*) must not be called from it.
				assert.NotNil(t, req.Body)

				content, err := io.ReadAll(req.Body)
				if !assert.NoError(t, err) {
					// Answer with an error so the failure is also observed by
					// the client call running on the test goroutine.
					http.Error(rw, err.Error(), http.StatusInternalServerError)
					return
				}

				assert.Equal(t, test.expected, string(bytes.TrimSpace(content)))

				handlerMock(http.MethodPost, response, nil).ServeHTTP(rw, req)
			})

			err := client.AddTxtRecord(context.Background(), test.domain, test.fqdn, test.value, test.ttl)
			require.NoError(t, err)
		})
	}
}