diff --git a/certcrypto/crypto.go b/certcrypto/crypto.go index 519c7a45..31e31f1d 100644 --- a/certcrypto/crypto.go +++ b/certcrypto/crypto.go @@ -85,6 +85,9 @@ func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) { // https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238) func ParsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) { keyBlockDER, _ := pem.Decode(key) + if keyBlockDER == nil { + return nil, fmt.Errorf("invalid PEM block") + } if keyBlockDER.Type != "PRIVATE KEY" && !strings.HasSuffix(keyBlockDER.Type, " PRIVATE KEY") { return nil, fmt.Errorf("unknown PEM header %q", keyBlockDER.Type) diff --git a/certcrypto/crypto_test.go b/certcrypto/crypto_test.go index 27e3b412..c56441b5 100644 --- a/certcrypto/crypto_test.go +++ b/certcrypto/crypto_test.go @@ -5,6 +5,7 @@ import ( "crypto" "crypto/rand" "crypto/rsa" + "encoding/pem" "testing" "time" @@ -140,6 +141,30 @@ func TestParsePEMCertificate(t *testing.T) { assert.Equal(t, expiration.UTC(), cert.NotAfter) } +func TestParsePEMPrivateKey(t *testing.T) { + privateKey, err := GeneratePrivateKey(RSA2048) + require.NoError(t, err, "Error generating private key") + + pemPrivateKey := PEMEncode(privateKey) + + // Decoding a key should work and create an identical key to the original + decoded, err := ParsePEMPrivateKey(pemPrivateKey) + require.NoError(t, err) + assert.Equal(t, decoded, privateKey) + + // Decoding a PEM block that doesn't contain a private key should error + _, err = ParsePEMPrivateKey(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE"})) + require.Errorf(t, err, "Expected to return an error for non-private key input") + + // Decoding a PEM block that doesn't actually contain a key should error + _, err = ParsePEMPrivateKey(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY"})) + require.Errorf(t, err, "Expected to return an error for empty input") + + // Decoding non-PEM input should return an error + _, err = ParsePEMPrivateKey([]byte("This is not PEM")) + require.Errorf(t, err, "Expected to return an error for non-PEM input") +} + type MockRandReader struct { b *bytes.Buffer }