# majuna/tests/utils/test_x509.py — tests for app.util.x509

import pytest
from cryptography.x509 import Certificate
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from datetime import datetime, timedelta
from app.util import x509
from app.util.x509 import TLSValidationError
from tests.api.test_onion import generate_self_signed_tls_certificate
from datetime import datetime, timedelta
import shutil
from itertools import permutations
def _generate_test_cert(tmp_path, subdir, valid_from, valid_to):
    """Generate a self-signed certificate for ``test.onion`` under ``tmp_path/subdir``.

    Returns the ``(private_key_path, cert_path)`` pair produced by
    ``generate_self_signed_tls_certificate``.
    """
    # tmp_path is a pytest fixture implemented as a pathlib.Path object
    directory = tmp_path / subdir
    if directory.exists():
        shutil.rmtree(directory)
    directory.mkdir()
    dns_names = ["test.onion", "www.test.onion"]
    private_key_path, cert_path = generate_self_signed_tls_certificate(
        str(directory), "test.onion", valid_from, valid_to, dns_names=dns_names
    )
    return private_key_path, cert_path


@pytest.fixture
def self_signed_rsa_cert(tmp_path):
    """Self-signed certificate that is valid at the time the test runs."""
    # NOTE(review): the window starts one day in the past so the certificate is
    # already valid even if the generator treats this naive local time as UTC
    # (which would otherwise put the start in the future east of UTC) — confirm
    # how generate_self_signed_tls_certificate handles tzinfo.
    valid_from = datetime.now() - timedelta(days=1)
    valid_to = valid_from + timedelta(days=30)
    return _generate_test_cert(tmp_path, "cert", valid_from, valid_to)


@pytest.fixture
def expired_cert(tmp_path):
    """Self-signed certificate whose validity window ended 30 days ago."""
    now = datetime.now()
    return _generate_test_cert(
        tmp_path, "cert_expired", now - timedelta(days=60), now - timedelta(days=30)
    )


@pytest.fixture
def not_yet_valid_cert(tmp_path):
    """Self-signed certificate whose validity window starts 30 days from now."""
    valid_from = datetime.now() + timedelta(days=30)
    return _generate_test_cert(
        tmp_path, "cert_not_yet_valid", valid_from, valid_from + timedelta(days=30)
    )
@pytest.fixture
def letsencrypt_cert():
    """Return ``(private_key_path, cert_path)`` for the Let's Encrypt-issued test data."""
    return (
        "tests/data/letsencrypt-issued/privkey.pem",
        "tests/data/letsencrypt-issued/fullchain.pem",
    )
@pytest.fixture
def invalid_algorithm_key():
    """Return the path of a DSA private key, a key type validate_key rejects."""
    return "tests/data/invalid-algorithm/dsa_private_key.pem"
def _load_chain_from_file(cert_path):
    """Read the PEM file at ``cert_path`` and parse it with ``x509.load_certificates_from_pem``."""
    with open(cert_path, "rb") as cert_file:
        pem_data = cert_file.read()
    return x509.load_certificates_from_pem(pem_data)


@pytest.fixture
def letsencrypt_valid_chain():
    """Certificates from the Let's Encrypt-issued full chain (expected to validate)."""
    return _load_chain_from_file("tests/data/letsencrypt-issued/fullchain.pem")


@pytest.fixture
def invalid_algorithm_chain():
    """Certificates from a chain that uses an unsupported algorithm."""
    return _load_chain_from_file("tests/data/invalid-algorithm/fullchain.pem")


@pytest.fixture
def untrusted_root_ca_chain():
    """Certificates from a chain that does not end at a trusted root CA."""
    return _load_chain_from_file("tests/data/untrusted-root-ca/fullchain.pem")


@pytest.fixture
def circular_chain():
    """Certificates from a chain with no identifiable end-entity certificate."""
    return _load_chain_from_file("tests/data/no-end-entity-chain/fullchain.pem")
def test_load_single_certificate(self_signed_rsa_cert):
    """A PEM file holding one certificate parses into exactly one Certificate."""
    cert_path = self_signed_rsa_cert[1]
    with open(cert_path, "rb") as handle:
        pem_bytes = handle.read()
    loaded = x509.load_certificates_from_pem(pem_bytes)
    assert len(loaded) == 1
    assert isinstance(loaded[0], Certificate)
def test_load_multiple_certificates(self_signed_rsa_cert):
    """Two concatenated PEM blocks parse into two Certificate objects."""
    _, cert_path = self_signed_rsa_cert
    with open(cert_path, "rb") as handle:
        single_pem = handle.read()
    doubled_pem = b"\n".join([single_pem, single_pem])
    loaded = x509.load_certificates_from_pem(doubled_pem)
    assert len(loaded) == 2
    assert all(isinstance(item, Certificate) for item in loaded)
def test_load_invalid_pem_data():
    """Garbage between PEM certificate markers raises TLSValidationError."""
    bogus_pem = b"-----BEGIN CERTIFICATE-----\nInvalidCertificateData\n-----END CERTIFICATE-----"
    expected_message = "TLS certificate parsing error"
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.load_certificates_from_pem(bogus_pem)
def test_validate_key_with_valid_RSA_key(self_signed_rsa_cert):
    """validate_key returns True for the generated self-signed (RSA) key."""
    key_path = self_signed_rsa_cert[0]
    with open(key_path, "r") as key_file:
        key_pem = key_file.read()
    outcome = x509.validate_key(key_pem)
    assert outcome is True
def test_validate_key_with_valid_EC_key(letsencrypt_cert):
    """validate_key returns True for the Let's Encrypt-issued (EC) key."""
    key_path = letsencrypt_cert[0]
    with open(key_path, "r") as key_file:
        key_pem = key_file.read()
    outcome = x509.validate_key(key_pem)
    assert outcome is True
def test_validate_key_with_valid_DSA_key(invalid_algorithm_key):
    """A well-formed DSA key is still rejected: validate_key returns False."""
    with open(invalid_algorithm_key, "r") as key_file:
        dsa_pem = key_file.read()
    outcome = x509.validate_key(dsa_pem)
    assert outcome is False
@pytest.mark.parametrize(
    "scenario, key_data, expected_result, expected_exception",
    [
        (
            "invalid_key_format",
            "-----BEGIN INVALID KEY-----\nInvalidData\n-----END INVALID KEY-----",
            None,
            x509.TLSInvalidPrivateKeyError,
        ),
        ("none_key", None, False, None),
        ("empty_key", "", False, None),
    ],
)
def test_validate_key_with_invalid_key(
    scenario, key_data, expected_result, expected_exception
):
    """Malformed keys raise; missing or empty keys are reported as invalid."""
    # Scenarios without an expected exception compare the return value instead.
    if not expected_exception:
        assert x509.validate_key(key_data) is expected_result
        return
    with pytest.raises(expected_exception):
        x509.validate_key(key_data)
@pytest.mark.parametrize(
    "scenario, certificate_fixture, expected_result",
    [
        ("expired_certificate", "expired_cert", True),
        ("valid_certificate", "self_signed_rsa_cert", False),
    ],
)
def test_validate_end_entity_expired(
    scenario, certificate_fixture, expected_result, request
):
    """validate_end_entity_expired is True for the expired certificate only."""
    # Resolve the fixture named by the parameter; only the cert path is needed.
    _, cert_path = request.getfixturevalue(certificate_fixture)
    # PEM is ASCII: read raw bytes instead of decoding with the platform
    # default encoding and re-encoding to UTF-8.
    with open(cert_path, "rb") as cert_file:
        tls_certificate_pem = cert_file.read()
    end_entity_cert = x509.load_certificates_from_pem(tls_certificate_pem)[0]
    assert x509.validate_end_entity_expired(end_entity_cert) is expected_result
@pytest.mark.parametrize(
    "scenario, certificate_fixture, expected_result",
    [
        ("not_yet_valid_certificate", "not_yet_valid_cert", True),
        ("valid_certificate", "self_signed_rsa_cert", False),
    ],
)
def test_validate_end_entity_not_yet_valid(
    scenario, certificate_fixture, expected_result, request
):
    """validate_end_entity_not_yet_valid is True for the future-dated certificate only."""
    # Resolve the fixture named by the parameter; only the cert path is needed.
    _, cert_path = request.getfixturevalue(certificate_fixture)
    # PEM is ASCII: read raw bytes instead of decoding with the platform
    # default encoding and re-encoding to UTF-8.
    with open(cert_path, "rb") as cert_file:
        tls_certificate_pem = cert_file.read()
    end_entity_cert = x509.load_certificates_from_pem(tls_certificate_pem)[0]
    assert x509.validate_end_entity_not_yet_valid(end_entity_cert) is expected_result
@pytest.mark.parametrize(
    "scenario, cert_fixture, key_fixture, expected_result",
    [
        (
            "matching_keys_rsa",
            "self_signed_rsa_cert",
            "self_signed_rsa_cert",
            True,
        ),  # Matching certificate and key
        (
            "matching_keys_ec",
            "letsencrypt_cert",
            "letsencrypt_cert",
            True,
        ),  # Matching certificate and key
        (
            "mismatched_keys",
            "self_signed_rsa_cert",
            "expired_cert",
            False,
        ),  # Mismatched certificate and key
        ("missing_key", "self_signed_rsa_cert", None, False),  # Missing private key
        ("missing_cert", None, "self_signed_rsa_cert", False),  # Missing certificate
    ],
)
def test_validate_key_matches_cert(
    scenario, cert_fixture, key_fixture, expected_result, request
):
    """Check validate_key_matches_cert for matching, mismatched and missing inputs.

    The key is passed as PEM text and the certificate as a parsed certificate
    object. A scenario without a fixture passes ``None`` for that argument.
    """
    tls_private_key_pem = None
    end_entity_cert = None
    if cert_fixture:
        _, cert_path = request.getfixturevalue(cert_fixture)
        with open(cert_path, "rb") as cert_file:
            end_entity_cert = x509.load_certificates_from_pem(cert_file.read())[0]
    if key_fixture:
        key_path, _ = request.getfixturevalue(key_fixture)
        with open(key_path, "r") as key_file:
            tls_private_key_pem = key_file.read()
    # Always pass real argument types (PEM text / certificate object). The
    # "missing" scenarios then fail because of the None argument, not because
    # a file path was passed where PEM data or a certificate was expected.
    assert (
        x509.validate_key_matches_cert(tls_private_key_pem, end_entity_cert)
        is expected_result
    )
def test_extract_sans(letsencrypt_cert):
    """The Let's Encrypt leaf certificate yields exactly the expected SAN list."""
    cert_path = letsencrypt_cert[1]
    with open(cert_path, "rb") as cert_file:
        leaf = x509.load_certificates_from_pem(cert_file.read())[0]
    sans = x509.extract_sans(leaf)
    assert isinstance(sans, list)
    assert sans == ["xahm6aech1mie5queyo8.censorship.guide"]
@pytest.mark.parametrize(
    "scenario, certificate_input, expected_result",
    [
        (
            "invalid_certificate_format",
            "-----BEGIN CERTIFICATE-----\nINVALID_CERTIFICATE\n-----END CERTIFICATE-----",
            [],
        ),
        ("certificate_is_none", None, []),
        ("certificate_is_empty_string", "", []),
    ],
)
def test_extract_sans_invalid(scenario, certificate_input, expected_result):
    """extract_sans returns an empty list for input that is not a certificate."""
    sans = x509.extract_sans(certificate_input)
    assert isinstance(sans, list)
    # Invalid input must produce no SANs at all, not merely some list.
    assert sans == expected_result
def test_validate_certificate_chain_invalid_algorithm(invalid_algorithm_chain):
    """A chain containing a certificate with an unsupported algorithm is rejected."""
    expected_message = "Certificate using unsupported algorithm"
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.validate_certificate_chain(invalid_algorithm_chain)
def test_validate_certificate_chain_invalid_root_ca(untrusted_root_ca_chain):
    """A chain that does not end at a trusted root CA is rejected."""
    expected_message = "Certificate chain does not terminate at a trusted root CA."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.validate_certificate_chain(untrusted_root_ca_chain)
def test_validate_certificate_chain_valid(letsencrypt_valid_chain):
    """The Let's Encrypt-issued chain passes validation."""
    is_valid = x509.validate_certificate_chain(letsencrypt_valid_chain)
    assert is_valid is True
def test_build_chain_single_certificate(self_signed_rsa_cert):
    """A lone self-signed certificate forms a chain consisting of just itself."""
    cert_path = self_signed_rsa_cert[1]
    with open(cert_path, "rb") as cert_file:
        single = x509.load_certificates_from_pem(cert_file.read())
    assert x509.build_certificate_chain(single) == single
def test_build_chain_valid_certificate_letsencrypt(letsencrypt_valid_chain):
    """An already ordered Let's Encrypt chain comes back equal to the input."""
    built = x509.build_certificate_chain(letsencrypt_valid_chain)
    assert built == letsencrypt_valid_chain
def test_build_chain_certificates_in_order(untrusted_root_ca_chain):
    """Input that is already in chain order is returned in the same order."""
    ordered = x509.build_certificate_chain(untrusted_root_ca_chain)
    assert ordered == untrusted_root_ca_chain
def test_build_chain_certificates_in_random_order(untrusted_root_ca_chain):
    """Every ordering of the input certificates yields the same ordered chain."""
    expected_chain = untrusted_root_ca_chain
    for shuffled in permutations(untrusted_root_ca_chain):
        assert x509.build_certificate_chain(shuffled) == expected_chain
def test_build_chain_empty_list():
    """An empty certificate list has no end-entity certificate and is rejected."""
    # The call must raise, so there is no return value to compare; the former
    # `== []` comparison was dead code that never executed.
    with pytest.raises(
        TLSValidationError, match="Cannot identify the end-entity certificate."
    ):
        x509.build_certificate_chain([])
def test_build_chain_invalid_chain_no_end_entity(circular_chain):
    """A chain with no identifiable end-entity certificate is rejected."""
    expected_message = "Cannot identify the end-entity certificate."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.build_certificate_chain(circular_chain)
def test_build_chain_certificates_missing_issuers(untrusted_root_ca_chain):
    """Removing the second certificate leaves a gap, so no valid chain can be built."""
    del untrusted_root_ca_chain[1]  # the fixture is function-scoped, so mutating it is safe
    expected_message = "Certificates do not form a valid chain."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.build_certificate_chain(untrusted_root_ca_chain)
def test_build_chain_duplicate_certificates(untrusted_root_ca_chain):
    """Duplicated input certificates collapse into a single ordered chain."""
    doubled = [*untrusted_root_ca_chain, *untrusted_root_ca_chain]
    built = x509.build_certificate_chain(doubled)
    assert built == untrusted_root_ca_chain
def test_build_chain_non_chain_certificates(
    untrusted_root_ca_chain, letsencrypt_valid_chain
):
    """Mixing two unrelated chains does not produce a single valid chain."""
    mixed = [*untrusted_root_ca_chain, *letsencrypt_valid_chain]
    expected_message = "Certificates do not form a valid chain."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.build_certificate_chain(mixed)
@pytest.mark.parametrize(
    "scenario, cert_fixture, key_fixture, hostname, expected_errors",
    [
        (
            "empty_cert",
            None,
            None,
            None,
            [
                {
                    "key": "no_valid_certificates",
                    "message": "No valid certificates supplied.",
                },
            ],
        ),
        (
            "trusted_cert_wrong_hostname",
            "letsencrypt_cert",
            "letsencrypt_cert",
            "wrong.example.com",
            [
                {
                    "key": "hostname_not_in_san",
                    "message": "wrong.example.com not found in SANs.",
                },
                {
                    "key": "hostname_not_in_san",
                    "message": "*.wrong.example.com not found in SANs.",
                },
            ],
        ),
        (
            "expired_self_signed_cert",
            "expired_cert",
            "expired_cert",
            None,
            [
                {"key": "public_key_expired", "message": "TLS public key is expired."},
                {
                    "key": "untrusted_root_ca",
                    "message": "Certificate chain does not terminate at a trusted root CA.",
                },
            ],
        ),
        (
            "not_yet_valid_self_signed_cert",
            "not_yet_valid_cert",
            "not_yet_valid_cert",
            None,
            [
                {
                    "key": "public_key_future",
                    "message": "TLS public key is not yet valid.",
                },
                {
                    "key": "untrusted_root_ca",
                    "message": "Certificate chain does not terminate at a trusted root CA.",
                },
            ],
        ),
        (
            "mismatched_key_cert",
            "letsencrypt_cert",
            "self_signed_rsa_cert",
            None,
            [
                {
                    "key": "key_mismatch",
                    "message": "Private key does not match certificate.",
                },
            ],
        ),
    ],
)
def test_validate_tls_keys(
    scenario,
    cert_fixture,
    key_fixture,
    hostname,
    expected_errors,
    request,
):
    """Check the error list validate_tls_keys reports for each scenario.

    Chain verification always runs. Name verification runs only when the
    scenario supplies a hostname.
    """
    cert_pem = None
    key_pem = None
    # PEM inputs are loaded only when both fixtures are given; otherwise both
    # stay None (the "empty_cert" scenario).
    if cert_fixture and key_fixture:
        _, cert_path = request.getfixturevalue(cert_fixture)
        key_path, _ = request.getfixturevalue(key_fixture)
        with open(cert_path, "rb") as cert_file:
            cert_pem = cert_file.read().decode("utf-8")
        with open(key_path, "rb") as key_file:
            key_pem = key_file.read().decode("utf-8")
    # Only the error list is under test; the chain and SAN list are ignored.
    _, _, errors = x509.validate_tls_keys(
        tls_private_key_pem=key_pem,
        tls_certificate_pem=cert_pem,
        hostname=hostname,
        skip_name_verification=not hostname,
        skip_chain_verification=False,
    )
    assert errors == expected_errors
def test_validate_tls_keys_invalid_algorithn(
    invalid_algorithm_key, self_signed_rsa_cert
):
    """A DSA private key is reported as neither RSA nor Elliptic-Curve.

    Name and chain checks are skipped, so the key-algorithm error is the only
    one expected. (The misspelt function name is kept so the test's ID does
    not change.)
    """

    def _read_text(path):
        # Read the file as bytes and decode it as UTF-8 text.
        with open(path, "rb") as handle:
            return handle.read().decode("utf-8")

    key_pem = _read_text(invalid_algorithm_key)
    cert_pem = _read_text(self_signed_rsa_cert[1])
    _, _, errors = x509.validate_tls_keys(
        tls_private_key_pem=key_pem,
        tls_certificate_pem=cert_pem,
        hostname=None,
        skip_name_verification=True,
        skip_chain_verification=True,
    )
    assert errors == [
        {
            "key": "private_key_not_rsa_or_ec",
            "message": "Private key must be RSA or Elliptic-Curve.",
        }
    ]