feat: break up validate_tls_keys and add unit tests

I've split the existing code in several new functions:
 - load_certificates_from_pem (takes pem data as bytes)
 - build_certificate_chain (takes a list of Certificates)
 - validate_certificate_chain (takes a list of Certificates)
 - validate_key (takes pem data as a string)
 - validate_key_matches_cert (now takes a pem key string and a Certificate)
 - extract_sans (now takes a Certificate)
 - validate_end_entity_expired (now takes a Certificate)
 - validate_end_entity_not_yet_valid (now takes a Certificate)

When a relevant exception arises, these functions raise a type of TLSValidationError,
these are appended to the list of errors when validating a cert.
This commit is contained in:
Ana Custura 2024-12-14 14:26:10 +00:00 committed by irl
parent 5275a2a882
commit d5fa521fa1
10 changed files with 1091 additions and 120 deletions

542
tests/utils/test_x509.py Normal file
View file

@ -0,0 +1,542 @@
import pytest
from cryptography.x509 import Certificate
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from datetime import datetime, timedelta
from app.util import x509
from app.util.x509 import TLSValidationError
from tests.api.test_onion import generate_self_signed_tls_certificate
from datetime import datetime, timedelta
import shutil
from itertools import permutations
@pytest.fixture
def self_signed_rsa_cert(tmp_path):
directory = (
tmp_path / "cert"
) # tmp_path is a pytest fixture implemented as a pathlib.Path object
if directory.exists():
shutil.rmtree(directory)
directory.mkdir()
onion_address = "test.onion"
valid_from = datetime.now()
valid_to = valid_from + timedelta(days=30)
dns_names = ["test.onion", "www.test.onion"]
# Generate certificate
private_key_path, cert_path = generate_self_signed_tls_certificate(
str(directory), onion_address, valid_from, valid_to, dns_names=dns_names
)
return private_key_path, cert_path
@pytest.fixture
def expired_cert(tmp_path):
# Fixture for generating a certificate that is expired
directory = tmp_path / "cert_expired"
if directory.exists():
shutil.rmtree(directory)
directory.mkdir()
onion_address = "test.onion"
valid_from = datetime.now() - timedelta(days=60)
valid_to = datetime.now() - timedelta(days=30)
dns_names = ["test.onion", "www.test.onion"]
private_key_path, cert_path = generate_self_signed_tls_certificate(
str(directory), onion_address, valid_from, valid_to, dns_names=dns_names
)
return private_key_path, cert_path
@pytest.fixture
def not_yet_valid_cert(tmp_path):
# Fixture for generating a certificate that is not yet valid
directory = tmp_path / "cert_not_yet_valid"
if directory.exists():
shutil.rmtree(directory)
directory.mkdir()
onion_address = "test.onion"
valid_from = datetime.now() + timedelta(days=30)
valid_to = valid_from + timedelta(days=30)
dns_names = ["test.onion", "www.test.onion"]
private_key_path, cert_path = generate_self_signed_tls_certificate(
str(directory), onion_address, valid_from, valid_to, dns_names=dns_names
)
return private_key_path, cert_path
@pytest.fixture
def letsencrypt_cert():
cert_path = "tests/data/letsencrypt-issued/fullchain.pem"
private_key_path = "tests/data/letsencrypt-issued/privkey.pem"
return private_key_path, cert_path
@pytest.fixture
def invalid_algorithm_key():
key_path = "tests/data/invalid-algorithm/dsa_private_key.pem"
return key_path
@pytest.fixture
def letsencrypt_valid_chain():
cert_path = "tests/data/letsencrypt-issued/fullchain.pem"
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
certificates = x509.load_certificates_from_pem(pem_cert)
return certificates
@pytest.fixture
def invalid_algorithm_chain():
cert_path = "tests/data/invalid-algorithm/fullchain.pem"
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
certificates = x509.load_certificates_from_pem(pem_cert)
return certificates
@pytest.fixture
def untrusted_root_ca_chain():
cert_path = "tests/data/untrusted-root-ca/fullchain.pem"
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
certificates = x509.load_certificates_from_pem(pem_cert)
return certificates
@pytest.fixture
def circular_chain():
cert_path = "tests/data/no-end-entity-chain/fullchain.pem"
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
certificates = x509.load_certificates_from_pem(pem_cert)
return certificates
def test_load_single_certificate(self_signed_rsa_cert):
"""Test loading a single certificate from PEM data."""
_, cert_path = self_signed_rsa_cert
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
certificates = x509.load_certificates_from_pem(pem_cert)
assert len(certificates) == 1
assert isinstance(certificates[0], Certificate)
def test_load_multiple_certificates(self_signed_rsa_cert):
"""Test loading multiple certificates from PEM data."""
_, cert_path = self_signed_rsa_cert
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
pem_data = pem_cert + b"\n" + pem_cert
certificates = x509.load_certificates_from_pem(pem_data)
assert len(certificates) == 2
for loaded_cert in certificates:
assert isinstance(loaded_cert, Certificate)
def test_load_invalid_pem_data():
"""Test handling of invalid PEM data."""
invalid_pem = b"-----BEGIN CERTIFICATE-----\nInvalidCertificateData\n-----END CERTIFICATE-----"
with pytest.raises(TLSValidationError, match="TLS certificate parsing error"):
x509.load_certificates_from_pem(invalid_pem)
def test_validate_key_with_valid_RSA_key(self_signed_rsa_cert):
"""Test that a valid RSA private key is accepted."""
private_key_path, _ = self_signed_rsa_cert
with open(private_key_path, "r") as private_key_file:
private_key_pem = private_key_file.read()
assert x509.validate_key(private_key_pem) is True
def test_validate_key_with_valid_EC_key(letsencrypt_cert):
"""Test that a valid EC private key is accepted."""
private_key_path, _ = letsencrypt_cert
with open(private_key_path, "r") as private_key_file:
private_key_pem = private_key_file.read()
assert x509.validate_key(private_key_pem) is True
def test_validate_key_with_valid_DSA_key(invalid_algorithm_key):
"""Test that a valid DSA private key is not accepted."""
with open(invalid_algorithm_key, "r") as private_key_file:
private_key_pem = private_key_file.read()
assert x509.validate_key(private_key_pem) is False
@pytest.mark.parametrize(
"scenario, key_data, expected_result, expected_exception",
[
(
"invalid_key_format",
"-----BEGIN INVALID KEY-----\nInvalidData\n-----END INVALID KEY-----",
None,
x509.TLSInvalidPrivateKeyError,
),
("none_key", None, False, None),
("empty_key", "", False, None),
],
)
def test_validate_key_with_invalid_key(
scenario, key_data, expected_result, expected_exception
):
"""Test that various invalid private keys are not recognized."""
if expected_exception:
with pytest.raises(expected_exception):
x509.validate_key(key_data)
else:
assert x509.validate_key(key_data) is expected_result
@pytest.mark.parametrize(
"scenario, certificate_fixture, expected_result",
[
("expired_certificate", "expired_cert", True),
("valid_certificate", "self_signed_rsa_cert", False),
],
)
def test_validate_end_entity_expired(
scenario, certificate_fixture, expected_result, request
):
"""Test with both expired and valid certificates"""
_, cert_path = request.getfixturevalue(
certificate_fixture
) # Access cert_path from the fixture
with open(cert_path, "r") as cert_file:
tls_certificate_pem = cert_file.read().encode("utf-8")
end_entity_cert = x509.load_certificates_from_pem(tls_certificate_pem)[0]
assert x509.validate_end_entity_expired(end_entity_cert) is expected_result
@pytest.mark.parametrize(
"scenario, certificate_fixture, expected_result",
[
("not_yet_valid_certificate", "not_yet_valid_cert", True),
("valid_certificate", "self_signed_rsa_cert", False),
],
)
def test_validate_end_entity_not_yet_valid(
scenario, certificate_fixture, expected_result, request
):
# Test with a certificate that is not yet valid
_, cert_path = request.getfixturevalue(
certificate_fixture
) # Access cert_path from the fixture
with open(cert_path, "r") as cert_file:
tls_certificate_pem = cert_file.read().encode("utf-8")
end_entity_cert = x509.load_certificates_from_pem(tls_certificate_pem)[0]
assert x509.validate_end_entity_not_yet_valid(end_entity_cert) is expected_result
@pytest.mark.parametrize(
"scenario, cert_fixture, key_fixture, expected_result",
[
(
"matching_keys_rsa",
"self_signed_rsa_cert",
"self_signed_rsa_cert",
True,
), # Matching certificate and key
(
"matching_keys_ec",
"letsencrypt_cert",
"letsencrypt_cert",
True,
), # Matching certificate and key
(
"mismatched_keys",
"self_signed_rsa_cert",
"expired_cert",
False,
), # Mismatched certificate and key
("missing_key", "self_signed_rsa_cert", None, False), # Missing private key
("missing_cert", None, "self_signed_rsa_cert", False), # Missing certificate
],
)
def test_validate_key_matches_cert(
scenario, cert_fixture, key_fixture, expected_result, request
):
"""Test the validate_key_matches_cert function with different scenarios"""
cert_path = None
key_path = None
if cert_fixture:
_, cert_path = request.getfixturevalue(cert_fixture)
if key_fixture:
key_path, _ = request.getfixturevalue(key_fixture)
if cert_path is None or key_path is None:
assert x509.validate_key_matches_cert(key_path, cert_path) is expected_result
return
with open(cert_path, "r") as cert_file:
tls_certificate_pem = cert_file.read().encode("utf-8")
end_entity_cert = x509.load_certificates_from_pem(tls_certificate_pem)[0]
with open(key_path, "r") as key_file:
tls_private_key_pem = key_file.read()
assert (
x509.validate_key_matches_cert(tls_private_key_pem, end_entity_cert)
is expected_result
)
def test_extract_sans(letsencrypt_cert):
_, cert_path = letsencrypt_cert
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
certs = x509.load_certificates_from_pem(pem_cert)
sans = x509.extract_sans(certs[0])
assert isinstance(sans, list)
expected_sans = ["xahm6aech1mie5queyo8.censorship.guide"]
assert sans == expected_sans
@pytest.mark.parametrize(
"scenario, certificate_input, expected_result",
[
(
"invalid_certificate_format",
"-----BEGIN CERTIFICATE-----\nINVALID_CERTIFICATE\n-----END CERTIFICATE-----",
False,
),
("certificate_is_none", None, False),
("certificate_is_empty_string", "", False),
],
)
def test_extract_sans_invalid(scenario, certificate_input, expected_result):
sans = x509.extract_sans(certificate_input)
assert isinstance(sans, list)
def test_validate_certificate_chain_invalid_algorithm(invalid_algorithm_chain):
with pytest.raises(
TLSValidationError, match="Certificate using unsupported algorithm"
):
x509.validate_certificate_chain(invalid_algorithm_chain)
def test_validate_certificate_chain_invalid_root_ca(untrusted_root_ca_chain):
with pytest.raises(
TLSValidationError,
match="Certificate chain does not terminate at a trusted root CA.",
):
x509.validate_certificate_chain(untrusted_root_ca_chain)
def test_validate_certificate_chain_valid(letsencrypt_valid_chain):
assert x509.validate_certificate_chain(letsencrypt_valid_chain) is True
def test_build_chain_single_certificate(self_signed_rsa_cert):
_, cert_path = self_signed_rsa_cert
with open(cert_path, "rb") as cert_file:
pem_cert = cert_file.read()
certs = x509.load_certificates_from_pem(pem_cert)
assert x509.build_certificate_chain(certs) == certs
def test_build_chain_valid_certificate_letsencrypt(letsencrypt_valid_chain):
assert (
x509.build_certificate_chain(letsencrypt_valid_chain) == letsencrypt_valid_chain
)
def test_build_chain_certificates_in_order(untrusted_root_ca_chain):
assert (
x509.build_certificate_chain(untrusted_root_ca_chain) == untrusted_root_ca_chain
)
def test_build_chain_certificates_in_random_order(untrusted_root_ca_chain):
for perm in permutations(untrusted_root_ca_chain):
assert x509.build_certificate_chain(perm) == untrusted_root_ca_chain
def test_build_chain_empty_list():
with pytest.raises(
TLSValidationError, match="Cannot identify the end-entity certificate."
):
x509.build_certificate_chain([]) == []
def test_build_chain_invalid_chain_no_end_entity(circular_chain):
with pytest.raises(
TLSValidationError, match="Cannot identify the end-entity certificate."
):
x509.build_certificate_chain(circular_chain)
def test_build_chain_certificates_missing_issuers(untrusted_root_ca_chain):
untrusted_root_ca_chain.pop(1)
with pytest.raises(
TLSValidationError, match="Certificates do not form a valid chain."
):
x509.build_certificate_chain(untrusted_root_ca_chain)
def test_build_chain_duplicate_certificates(untrusted_root_ca_chain):
duplicate_untrusted_root_ca_chain = untrusted_root_ca_chain * 2
assert (
x509.build_certificate_chain(duplicate_untrusted_root_ca_chain)
== untrusted_root_ca_chain
)
def test_build_chain_non_chain_certificates(
untrusted_root_ca_chain, letsencrypt_valid_chain
):
non_chain = untrusted_root_ca_chain + letsencrypt_valid_chain
with pytest.raises(
TLSValidationError, match="Certificates do not form a valid chain."
):
x509.build_certificate_chain(non_chain)
@pytest.mark.parametrize(
"scenario, cert_fixture, key_fixture, hostname, expected_result, expected_errors",
[
(
"empty_cert",
None,
None,
None,
None,
[
{
"key": "no_valid_certificates",
"message": "No valid certificates supplied.",
},
],
),
(
"trusted_cert_wrong_hostname",
"letsencrypt_cert",
"letsencrypt_cert",
"wrong.example.com",
None,
[
{
"key": "hostname_not_in_san",
"message": "wrong.example.com not found in SANs.",
},
{
"key": "hostname_not_in_san",
"message": "*.wrong.example.com not found in SANs.",
},
],
),
(
"expired_self_signed_cert",
"expired_cert",
"expired_cert",
None,
None,
[
{"key": "public_key_expired", "message": "TLS public key is expired."},
{
"key": "untrusted_root_ca",
"message": "Certificate chain does not terminate at a trusted root CA.",
},
],
),
(
"not_yet_valid_self_signed_cert",
"not_yet_valid_cert",
"not_yet_valid_cert",
None,
None,
[
{
"key": "public_key_future",
"message": "TLS public key is not yet valid.",
},
{
"key": "untrusted_root_ca",
"message": "Certificate chain does not terminate at a trusted root CA.",
},
],
),
(
"mismatched_key_cert",
"letsencrypt_cert",
"self_signed_rsa_cert",
None,
None,
[
{
"key": "key_mismatch",
"message": "Private key does not match certificate.",
},
],
),
],
)
def test_validate_tls_keys(
scenario,
cert_fixture,
key_fixture,
hostname,
expected_result,
expected_errors,
request,
):
# Get the certificate and key paths from the fixture
if cert_fixture and key_fixture:
_, cert_path = request.getfixturevalue(cert_fixture)
key_path, _ = request.getfixturevalue(key_fixture)
# Read the certificate and key files
with open(cert_path, "rb") as cert_file:
cert_pem = cert_file.read().decode("utf-8")
with open(key_path, "rb") as key_file:
key_pem = key_file.read().decode("utf-8")
else:
cert_pem = None
key_pem = None
# Call the validate_tls_keys function
skip_name_verification = False
if not hostname:
skip_name_verification = True
chain, san_list, errors = x509.validate_tls_keys(
tls_private_key_pem=key_pem,
tls_certificate_pem=cert_pem,
hostname=hostname,
skip_name_verification=skip_name_verification,
skip_chain_verification=False,
)
# Assert that the errors match the expected errors
assert errors == expected_errors
def test_validate_tls_keys_invalid_algorithn(
invalid_algorithm_key, self_signed_rsa_cert
):
with open(invalid_algorithm_key, "rb") as key_file:
key_pem = key_file.read().decode("utf-8")
_, cert_path = self_signed_rsa_cert
with open(cert_path, "rb") as cert_file:
cert_pem = cert_file.read().decode("utf-8")
chain, san_list, errors = x509.validate_tls_keys(
tls_private_key_pem=key_pem,
tls_certificate_pem=cert_pem,
hostname=None,
skip_name_verification=True,
skip_chain_verification=True,
)
expected_errors = [
{
"key": "private_key_not_rsa_or_ec",
"message": "Private key must be RSA or Elliptic-Curve.",
}
]
assert errors == expected_errors