package fshttp

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"
	"time"

	"github.com/rclone/rclone/fs"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestCleanAuth runs a table of header-like inputs through cleanAuth,
// using the first entry of authBufs as the header to look for, and
// compares the returned bytes with the expected text.
//
// As the expectations in the table show, text without the header (or with
// an empty value) comes back unchanged, while a non-empty value is
// replaced by "X" characters, at most four of them.
func TestCleanAuth(t *testing.T) {
	type cleanCase struct {
		in   string
		want string
	}
	cases := []cleanCase{
		{in: "", want: ""},
		{in: "floo", want: "floo"},
		{in: "Authorization: ", want: "Authorization: "},
		{in: "Authorization: \n", want: "Authorization: \n"},
		{in: "Authorization: A", want: "Authorization: X"},
		{in: "Authorization: A\n", want: "Authorization: X\n"},
		{in: "Authorization: AAAA", want: "Authorization: XXXX"},
		{in: "Authorization: AAAA\n", want: "Authorization: XXXX\n"},
		{in: "Authorization: AAAAA", want: "Authorization: XXXX"},
		{in: "Authorization: AAAAA\n", want: "Authorization: XXXX\n"},
		{in: "Authorization: AAAA\n", want: "Authorization: XXXX\n"},
		{in: "Authorization: AAAAAAAAA\nPotato: Help\n", want: "Authorization: XXXX\nPotato: Help\n"},
		{in: "Sausage: 1\nAuthorization: AAAAAAAAA\nPotato: Help\n", want: "Sausage: 1\nAuthorization: XXXX\nPotato: Help\n"},
	}
	for _, tc := range cases {
		cleaned := cleanAuth([]byte(tc.in), authBufs[0])
		// The input doubles as the failure message so the failing row
		// can be identified.
		assert.Equal(t, tc.want, string(cleaned), tc.in)
	}
}

// TestCleanAuths runs a table of header-like inputs through cleanAuths and
// compares the returned bytes with the expected text.
//
// The table expects the values of both "Authorization" and "X-Auth-Token"
// to be replaced by "XXXX", including when both appear in the same input,
// and expects other text to come back unchanged.
func TestCleanAuths(t *testing.T) {
	type cleanCase struct {
		in   string
		want string
	}
	cases := []cleanCase{
		{in: "", want: ""},
		{in: "floo", want: "floo"},
		{in: "Authorization: AAAAAAAAA\nPotato: Help\n", want: "Authorization: XXXX\nPotato: Help\n"},
		{in: "X-Auth-Token: AAAAAAAAA\nPotato: Help\n", want: "X-Auth-Token: XXXX\nPotato: Help\n"},
		{in: "X-Auth-Token: AAAAAAAAA\nAuthorization: AAAAAAAAA\nPotato: Help\n", want: "X-Auth-Token: XXXX\nAuthorization: XXXX\nPotato: Help\n"},
	}
	for _, tc := range cases {
		cleaned := cleanAuths([]byte(tc.in))
		// The input doubles as the failure message so the failing row
		// can be identified.
		assert.Equal(t, tc.want, string(cleaned), tc.in)
	}
}

// certSerial is the serial number of the most recently created test
// certificate. createTestCert increments it before stamping it into each
// new certificate, so the first certificate gets serial 1.
var certSerial int64

// createTestCert generates a fresh 1024-bit RSA key and a self-signed
// certificate for "localhost", returning both PEM encoded.
//
// The certificate is valid from the moment of creation until validity plus
// the package's expireWindow later. Each call bumps the package-level
// certSerial and uses the new value as the certificate's serial number.
//
// If key generation fails both PEM results are nil. If certificate creation
// fails keyPEM is still returned alongside the error, but certPEM is nil.
func createTestCert(validity time.Duration) (keyPEM []byte, certPEM []byte, err error) {
	privKey, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		return nil, nil, err
	}

	// PEM encoding of private key (PKCS#1)
	keyBlock := pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privKey),
	}
	keyPEM = pem.EncodeToMemory(&keyBlock)

	// Validity period of the certificate, starting now
	validFrom := time.Now()
	validUntil := validFrom.Add(validity).Add(expireWindow)

	// Every certificate gets its own, monotonically increasing serial
	certSerial++
	tmpl := x509.Certificate{
		SerialNumber:          big.NewInt(certSerial),
		Subject:               pkix.Name{CommonName: "localhost"},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             validFrom,
		NotAfter:              validUntil,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
	}

	// The template acts as its own parent, i.e. the certificate is
	// self-signed with the key generated above.
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &privKey.PublicKey, privKey)
	if err != nil {
		return keyPEM, nil, err
	}

	certBlock := pem.Block{
		Type:  "CERTIFICATE",
		Bytes: der,
	}
	certPEM = pem.EncodeToMemory(&certBlock)
	return keyPEM, certPEM, nil
}

// writeTestCert creates a new test certificate/key pair with the given
// validity and writes the PEM data to the files named by ci.ClientCert and
// ci.ClientKey, replacing whatever they held before.
//
// Any failure stops the test immediately, so that nothing goes on to use
// missing or half-written files. It must therefore be called from the
// goroutine running the test.
func writeTestCert(t *testing.T, ci *fs.ConfigInfo, validity time.Duration) {
	t.Helper()
	// Honour the requested validity rather than a fixed duration
	keyPEM, certPEM, err := createTestCert(validity)
	require.NoError(t, err, "Cannot create test cert")
	err = os.WriteFile(ci.ClientCert, certPEM, 0666)
	require.NoError(t, err, "Failed to write cert")
	// The private key only needs to be readable by its owner
	err = os.WriteFile(ci.ClientKey, keyPEM, 0600)
	require.NoError(t, err, "Failed to write key")
}

func TestCertificates(t *testing.T) {
	startTime := time.Now()
	// Starting a TLS server
	expectedSerial := int64(0)
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cert := r.TLS.PeerCertificates
		require.Greater(t, len(cert), 0, "No certificates received")
		expectedSerial += 1
		assert.Equal(t, expectedSerial, cert[0].SerialNumber.Int64(), "Did not get the correct serial number in certificate")
		// Check that the certificate hasn't expired. We cannot use cert validation
		// functions because those check for signature as well and our certificates
		// are not properly signed
		if time.Now().After(cert[0].NotAfter) {
			assert.Fail(t, "Certificate expired", "Certificate expires at %s, current time is %s", cert[0].NotAfter.Sub(startTime), time.Since(startTime))
		}

		// Write some test data to fullfil the request
		w.Header().Set("Content-Type", "text/plain")
		_, _ = fmt.Fprintln(w, "test data")
	}))
	defer ts.Close()
	// Modify servers config to request a client certificate
	// we cannot validate the certificate since we are not properly signing it
	ts.TLS.ClientAuth = tls.RequestClientCert

	// Set --client-cert and --client-key in config to
	// a pair of temp files
	// create a test cert/key pair and write it to the files
	ctx := context.TODO()
	ci := fs.GetConfig(ctx)
	// Create a test certificate and write it to a temp file
	ci.ClientCert = t.TempDir() + "client.cert"
	ci.ClientKey = t.TempDir() + "client.key"
	validity := 1 * time.Second
	writeTestCert(t, ci, validity)

	// Now create the client with the above settings
	// we need to disable TLS verification since we don't
	// care about server certificate
	client := NewClient(ctx)
	tt := client.Transport.(*Transport)
	tt.TLSClientConfig.InsecureSkipVerify = true

	// Now make requests, the first request should be within
	// the valid window
	_, err := client.Get(ts.URL)
	assert.NoError(t, err)

	// Wait for the 2* valid duration of the certificate so that has definitely expired
	time.Sleep(2 * validity)

	// Create a new cert and write it to files
	writeTestCert(t, ci, validity)

	// The new cert should be auto-loaded before we make this request
	_, err = client.Get(ts.URL)
	assert.NoError(t, err)
}