543 lines
18 KiB
Python
543 lines
18 KiB
Python
![]() |
import pytest
|
||
|
from cryptography.x509 import Certificate
|
||
|
from cryptography.x509.oid import NameOID
|
||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||
|
from cryptography.hazmat.primitives import serialization
|
||
|
from cryptography.hazmat.backends import default_backend
|
||
|
from datetime import datetime, timedelta
|
||
|
|
||
|
from app.util import x509
|
||
|
from app.util.x509 import TLSValidationError
|
||
|
from tests.api.test_onion import generate_self_signed_tls_certificate
|
||
|
from datetime import datetime, timedelta
|
||
|
import shutil
|
||
|
from itertools import permutations
|
||
|
|
||
|
|
||
|
@pytest.fixture
def self_signed_rsa_cert(tmp_path):
    """Generate a self-signed certificate valid from now for 30 days.

    The files are written beneath ``tmp_path / "cert"``.

    Returns:
        The ``(private_key_path, cert_path)`` pair produced by
        ``generate_self_signed_tls_certificate``.
    """
    # tmp_path is pytest's per-test pathlib.Path; always start from an
    # empty output directory.
    cert_dir = tmp_path / "cert"
    if cert_dir.exists():
        shutil.rmtree(cert_dir)
    cert_dir.mkdir()

    not_before = datetime.now()
    not_after = not_before + timedelta(days=30)

    key_file, cert_file = generate_self_signed_tls_certificate(
        str(cert_dir),
        "test.onion",
        not_before,
        not_after,
        dns_names=["test.onion", "www.test.onion"],
    )
    return key_file, cert_file
|
||
|
|
||
|
|
||
|
@pytest.fixture
def expired_cert(tmp_path):
    """Generate a self-signed certificate whose validity ended 30 days ago.

    The validity window runs from 60 days ago to 30 days ago, and the files
    are written beneath ``tmp_path / "cert_expired"``.

    Returns:
        The ``(private_key_path, cert_path)`` pair produced by
        ``generate_self_signed_tls_certificate``.
    """
    cert_dir = tmp_path / "cert_expired"
    if cert_dir.exists():
        shutil.rmtree(cert_dir)
    cert_dir.mkdir()

    not_before = datetime.now() - timedelta(days=60)
    not_after = datetime.now() - timedelta(days=30)

    key_file, cert_file = generate_self_signed_tls_certificate(
        str(cert_dir),
        "test.onion",
        not_before,
        not_after,
        dns_names=["test.onion", "www.test.onion"],
    )
    return key_file, cert_file
|
||
|
|
||
|
|
||
|
@pytest.fixture
def not_yet_valid_cert(tmp_path):
    """Generate a self-signed certificate whose validity starts in 30 days.

    The validity window runs from 30 days ahead to 60 days ahead, and the
    files are written beneath ``tmp_path / "cert_not_yet_valid"``.

    Returns:
        The ``(private_key_path, cert_path)`` pair produced by
        ``generate_self_signed_tls_certificate``.
    """
    cert_dir = tmp_path / "cert_not_yet_valid"
    if cert_dir.exists():
        shutil.rmtree(cert_dir)
    cert_dir.mkdir()

    not_before = datetime.now() + timedelta(days=30)
    not_after = not_before + timedelta(days=30)

    key_file, cert_file = generate_self_signed_tls_certificate(
        str(cert_dir),
        "test.onion",
        not_before,
        not_after,
        dns_names=["test.onion", "www.test.onion"],
    )
    return key_file, cert_file
|
||
|
|
||
|
|
||
|
@pytest.fixture
def letsencrypt_cert():
    """Return ``(private_key_path, cert_path)`` for the checked-in sample.

    Both paths point into ``tests/data/letsencrypt-issued`` and are relative
    to the directory the tests are run from.
    """
    return (
        "tests/data/letsencrypt-issued/privkey.pem",
        "tests/data/letsencrypt-issued/fullchain.pem",
    )
|
||
|
|
||
|
|
||
|
@pytest.fixture
def invalid_algorithm_key():
    """Return the relative path of the checked-in DSA private key sample."""
    return "tests/data/invalid-algorithm/dsa_private_key.pem"
|
||
|
|
||
|
|
||
|
@pytest.fixture
def letsencrypt_valid_chain():
    """Load the certificates stored in the letsencrypt-issued ``fullchain.pem``.

    Returns:
        Whatever ``x509.load_certificates_from_pem`` yields for the file's
        raw bytes.
    """
    with open("tests/data/letsencrypt-issued/fullchain.pem", "rb") as handle:
        raw_pem = handle.read()
    return x509.load_certificates_from_pem(raw_pem)
|
||
|
|
||
|
|
||
|
@pytest.fixture
def invalid_algorithm_chain():
    """Load the certificates stored in the invalid-algorithm ``fullchain.pem``.

    Returns:
        Whatever ``x509.load_certificates_from_pem`` yields for the file's
        raw bytes.
    """
    with open("tests/data/invalid-algorithm/fullchain.pem", "rb") as handle:
        raw_pem = handle.read()
    return x509.load_certificates_from_pem(raw_pem)
|
||
|
|
||
|
|
||
|
@pytest.fixture
def untrusted_root_ca_chain():
    """Load the certificates stored in the untrusted-root-ca ``fullchain.pem``.

    Returns:
        Whatever ``x509.load_certificates_from_pem`` yields for the file's
        raw bytes.
    """
    with open("tests/data/untrusted-root-ca/fullchain.pem", "rb") as handle:
        raw_pem = handle.read()
    return x509.load_certificates_from_pem(raw_pem)
|
||
|
|
||
|
|
||
|
@pytest.fixture
def circular_chain():
    """Load the certificates stored in the no-end-entity-chain ``fullchain.pem``.

    Returns:
        Whatever ``x509.load_certificates_from_pem`` yields for the file's
        raw bytes.
    """
    with open("tests/data/no-end-entity-chain/fullchain.pem", "rb") as handle:
        raw_pem = handle.read()
    return x509.load_certificates_from_pem(raw_pem)
|
||
|
|
||
|
|
||
|
def test_load_single_certificate(self_signed_rsa_cert):
    """A PEM blob holding one certificate loads as a single ``Certificate``."""
    _key_path, pem_path = self_signed_rsa_cert
    with open(pem_path, "rb") as handle:
        raw_pem = handle.read()

    loaded = x509.load_certificates_from_pem(raw_pem)

    assert len(loaded) == 1
    assert isinstance(loaded[0], Certificate)
|
||
|
|
||
|
|
||
|
def test_load_multiple_certificates(self_signed_rsa_cert):
    """Two concatenated PEM certificates load as two ``Certificate`` objects."""
    _key_path, pem_path = self_signed_rsa_cert
    with open(pem_path, "rb") as handle:
        raw_pem = handle.read()

    # Same certificate twice, separated by a newline.
    doubled_pem = b"\n".join([raw_pem, raw_pem])
    loaded = x509.load_certificates_from_pem(doubled_pem)

    assert len(loaded) == 2
    assert all(isinstance(entry, Certificate) for entry in loaded)
|
||
|
|
||
|
|
||
|
def test_load_invalid_pem_data():
    """Well-framed but undecodable PEM content raises ``TLSValidationError``."""
    garbage_pem = b"-----BEGIN CERTIFICATE-----\nInvalidCertificateData\n-----END CERTIFICATE-----"
    with pytest.raises(TLSValidationError, match="TLS certificate parsing error"):
        x509.load_certificates_from_pem(garbage_pem)
|
||
|
|
||
|
|
||
|
def test_validate_key_with_valid_RSA_key(self_signed_rsa_cert):
    """``validate_key`` returns ``True`` for the generated RSA private key."""
    key_path, _cert_path = self_signed_rsa_cert
    with open(key_path, "r") as handle:
        key_pem = handle.read()
    assert x509.validate_key(key_pem) is True
|
||
|
|
||
|
|
||
|
def test_validate_key_with_valid_EC_key(letsencrypt_cert):
    """``validate_key`` returns ``True`` for the checked-in EC private key."""
    key_path, _cert_path = letsencrypt_cert
    with open(key_path, "r") as handle:
        key_pem = handle.read()
    assert x509.validate_key(key_pem) is True
|
||
|
|
||
|
|
||
|
def test_validate_key_with_valid_DSA_key(invalid_algorithm_key):
    """``validate_key`` returns ``False`` for a well-formed DSA private key."""
    with open(invalid_algorithm_key, "r") as handle:
        key_pem = handle.read()
    assert x509.validate_key(key_pem) is False
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "scenario, key_data, expected_result, expected_exception",
    [
        (
            "invalid_key_format",
            "-----BEGIN INVALID KEY-----\nInvalidData\n-----END INVALID KEY-----",
            None,
            x509.TLSInvalidPrivateKeyError,
        ),
        ("none_key", None, False, None),
        ("empty_key", "", False, None),
    ],
)
def test_validate_key_with_invalid_key(
    scenario, key_data, expected_result, expected_exception
):
    """Check ``validate_key`` against malformed, ``None`` and empty input.

    Each case either names the exception ``validate_key`` must raise or,
    when ``expected_exception`` is falsy, the exact value it must return.
    ``scenario`` only labels the case.
    """
    if not expected_exception:
        assert x509.validate_key(key_data) is expected_result
        return

    with pytest.raises(expected_exception):
        x509.validate_key(key_data)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "scenario, certificate_fixture, expected_result",
    [
        ("expired_certificate", "expired_cert", True),
        ("valid_certificate", "self_signed_rsa_cert", False),
    ],
)
def test_validate_end_entity_expired(
    scenario, certificate_fixture, expected_result, request
):
    """Check ``validate_end_entity_expired`` on an expired and a current cert.

    ``certificate_fixture`` names the fixture to resolve through ``request``;
    the first certificate loaded from its PEM file is the one validated.
    ``scenario`` only labels the case.
    """
    _key_path, pem_path = request.getfixturevalue(certificate_fixture)
    with open(pem_path, "r") as handle:
        pem_bytes = handle.read().encode("utf-8")

    leaf = x509.load_certificates_from_pem(pem_bytes)[0]
    assert x509.validate_end_entity_expired(leaf) is expected_result
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "scenario, certificate_fixture, expected_result",
    [
        ("not_yet_valid_certificate", "not_yet_valid_cert", True),
        ("valid_certificate", "self_signed_rsa_cert", False),
    ],
)
def test_validate_end_entity_not_yet_valid(
    scenario, certificate_fixture, expected_result, request
):
    """Check ``validate_end_entity_not_yet_valid`` on a future and a current cert.

    ``certificate_fixture`` names the fixture to resolve through ``request``;
    the first certificate loaded from its PEM file is the one validated.
    ``scenario`` only labels the case.
    """
    _key_path, pem_path = request.getfixturevalue(certificate_fixture)
    with open(pem_path, "r") as handle:
        pem_bytes = handle.read().encode("utf-8")

    leaf = x509.load_certificates_from_pem(pem_bytes)[0]
    assert x509.validate_end_entity_not_yet_valid(leaf) is expected_result
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "scenario, cert_fixture, key_fixture, expected_result",
    [
        # Matching certificate and key
        ("matching_keys_rsa", "self_signed_rsa_cert", "self_signed_rsa_cert", True),
        # Matching certificate and key
        ("matching_keys_ec", "letsencrypt_cert", "letsencrypt_cert", True),
        # Mismatched certificate and key
        ("mismatched_keys", "self_signed_rsa_cert", "expired_cert", False),
        # Missing private key
        ("missing_key", "self_signed_rsa_cert", None, False),
        # Missing certificate
        ("missing_cert", None, "self_signed_rsa_cert", False),
    ],
)
def test_validate_key_matches_cert(
    scenario, cert_fixture, key_fixture, expected_result, request
):
    """Check ``validate_key_matches_cert`` for matching, mismatched and missing input.

    ``cert_fixture`` and ``key_fixture`` name fixtures resolved through
    ``request``; either may be ``None`` to model a missing certificate or
    key. ``scenario`` only labels the case.
    """
    # Resolve the certificate fixture first, then the key fixture.
    cert_path = request.getfixturevalue(cert_fixture)[1] if cert_fixture else None
    key_path = request.getfixturevalue(key_fixture)[0] if key_fixture else None

    if cert_path is not None and key_path is not None:
        with open(cert_path, "r") as cert_handle:
            cert_pem_bytes = cert_handle.read().encode("utf-8")
        leaf = x509.load_certificates_from_pem(cert_pem_bytes)[0]
        with open(key_path, "r") as key_handle:
            key_pem = key_handle.read()
        assert x509.validate_key_matches_cert(key_pem, leaf) is expected_result
    else:
        # With one side missing nothing is loaded from disk: the raw values
        # (a path or None) are handed to the function as they are.
        assert x509.validate_key_matches_cert(key_path, cert_path) is expected_result
|
||
|
|
||
|
|
||
|
def test_extract_sans(letsencrypt_cert):
    """``extract_sans`` returns the expected name list for the first sample cert."""
    _key_path, pem_path = letsencrypt_cert
    with open(pem_path, "rb") as handle:
        raw_pem = handle.read()

    leaf = x509.load_certificates_from_pem(raw_pem)[0]
    found = x509.extract_sans(leaf)

    assert isinstance(found, list)
    assert found == ["xahm6aech1mie5queyo8.censorship.guide"]
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "scenario, certificate_input, expected_result",
    [
        (
            "invalid_certificate_format",
            "-----BEGIN CERTIFICATE-----\nINVALID_CERTIFICATE\n-----END CERTIFICATE-----",
            False,
        ),
        ("certificate_is_none", None, False),
        ("certificate_is_empty_string", "", False),
    ],
)
def test_extract_sans_invalid(scenario, certificate_input, expected_result):
    """``extract_sans`` still returns a list when given a non-certificate value.

    Only the return type is checked; ``expected_result`` is supplied by the
    parametrization but is not asserted against. ``scenario`` only labels
    the case.
    """
    found = x509.extract_sans(certificate_input)
    assert isinstance(found, list)
|
||
|
|
||
|
|
||
|
def test_validate_certificate_chain_invalid_algorithm(invalid_algorithm_chain):
    """Validating the invalid-algorithm chain raises ``TLSValidationError``."""
    expected_message = "Certificate using unsupported algorithm"
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.validate_certificate_chain(invalid_algorithm_chain)
|
||
|
|
||
|
|
||
|
def test_validate_certificate_chain_invalid_root_ca(untrusted_root_ca_chain):
    """Validating the untrusted-root-ca chain raises ``TLSValidationError``."""
    expected_message = "Certificate chain does not terminate at a trusted root CA."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.validate_certificate_chain(untrusted_root_ca_chain)
|
||
|
|
||
|
|
||
|
def test_validate_certificate_chain_valid(letsencrypt_valid_chain):
    """``validate_certificate_chain`` returns ``True`` for the sample chain."""
    outcome = x509.validate_certificate_chain(letsencrypt_valid_chain)
    assert outcome is True
|
||
|
|
||
|
|
||
|
def test_build_chain_single_certificate(self_signed_rsa_cert):
    """Building a chain from one self-signed certificate returns it unchanged."""
    _key_path, pem_path = self_signed_rsa_cert
    with open(pem_path, "rb") as handle:
        raw_pem = handle.read()

    loaded = x509.load_certificates_from_pem(raw_pem)
    assert x509.build_certificate_chain(loaded) == loaded
|
||
|
|
||
|
|
||
|
def test_build_chain_valid_certificate_letsencrypt(letsencrypt_valid_chain):
    """The sample chain comes back from ``build_certificate_chain`` unchanged."""
    built = x509.build_certificate_chain(letsencrypt_valid_chain)
    assert built == letsencrypt_valid_chain
|
||
|
|
||
|
|
||
|
def test_build_chain_certificates_in_order(untrusted_root_ca_chain):
    """The untrusted-root chain, as loaded, is returned in the same order."""
    built = x509.build_certificate_chain(untrusted_root_ca_chain)
    assert built == untrusted_root_ca_chain
|
||
|
|
||
|
|
||
|
def test_build_chain_certificates_in_random_order(untrusted_root_ca_chain):
    """Every ordering of the chain's certificates rebuilds to the loaded order."""
    for shuffled in permutations(untrusted_root_ca_chain):
        # ``shuffled`` is passed on exactly as permutations() yields it.
        rebuilt = x509.build_certificate_chain(shuffled)
        assert rebuilt == untrusted_root_ca_chain
|
||
|
|
||
|
|
||
|
def test_build_chain_empty_list():
    """An empty certificate list is rejected with ``TLSValidationError``.

    The expected message is "Cannot identify the end-entity certificate.".
    """
    with pytest.raises(
        TLSValidationError, match="Cannot identify the end-entity certificate."
    ):
        # Call only: the raised exception is the assertion. The former
        # trailing ``== []`` was a discarded comparison that could never be
        # evaluated when the call raises as required.
        x509.build_certificate_chain([])
|
||
|
|
||
|
|
||
|
def test_build_chain_invalid_chain_no_end_entity(circular_chain):
    """The no-end-entity sample chain is rejected with ``TLSValidationError``."""
    expected_message = "Cannot identify the end-entity certificate."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.build_certificate_chain(circular_chain)
|
||
|
|
||
|
|
||
|
def test_build_chain_certificates_missing_issuers(untrusted_root_ca_chain):
    """Removing the second certificate makes chain building fail."""
    # Drop the entry at index 1 from the fixture's list in place.
    del untrusted_root_ca_chain[1]

    expected_message = "Certificates do not form a valid chain."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.build_certificate_chain(untrusted_root_ca_chain)
|
||
|
|
||
|
|
||
|
def test_build_chain_duplicate_certificates(untrusted_root_ca_chain):
    """Supplying every certificate twice still yields the original chain."""
    doubled_chain = untrusted_root_ca_chain + untrusted_root_ca_chain
    rebuilt = x509.build_certificate_chain(doubled_chain)
    assert rebuilt == untrusted_root_ca_chain
|
||
|
|
||
|
|
||
|
def test_build_chain_non_chain_certificates(
    untrusted_root_ca_chain, letsencrypt_valid_chain
):
    """Two sample chains concatenated together are rejected as one chain."""
    mixed_certs = untrusted_root_ca_chain + letsencrypt_valid_chain
    expected_message = "Certificates do not form a valid chain."
    with pytest.raises(TLSValidationError, match=expected_message):
        x509.build_certificate_chain(mixed_certs)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "scenario, cert_fixture, key_fixture, hostname, expected_result, expected_errors",
    [
        (
            "empty_cert",
            None,
            None,
            None,
            None,
            [
                {
                    "key": "no_valid_certificates",
                    "message": "No valid certificates supplied.",
                },
            ],
        ),
        (
            "trusted_cert_wrong_hostname",
            "letsencrypt_cert",
            "letsencrypt_cert",
            "wrong.example.com",
            None,
            [
                {
                    "key": "hostname_not_in_san",
                    "message": "wrong.example.com not found in SANs.",
                },
                {
                    "key": "hostname_not_in_san",
                    "message": "*.wrong.example.com not found in SANs.",
                },
            ],
        ),
        (
            "expired_self_signed_cert",
            "expired_cert",
            "expired_cert",
            None,
            None,
            [
                {"key": "public_key_expired", "message": "TLS public key is expired."},
                {
                    "key": "untrusted_root_ca",
                    "message": "Certificate chain does not terminate at a trusted root CA.",
                },
            ],
        ),
        (
            "not_yet_valid_self_signed_cert",
            "not_yet_valid_cert",
            "not_yet_valid_cert",
            None,
            None,
            [
                {
                    "key": "public_key_future",
                    "message": "TLS public key is not yet valid.",
                },
                {
                    "key": "untrusted_root_ca",
                    "message": "Certificate chain does not terminate at a trusted root CA.",
                },
            ],
        ),
        (
            "mismatched_key_cert",
            "letsencrypt_cert",
            "self_signed_rsa_cert",
            None,
            None,
            [
                {
                    "key": "key_mismatch",
                    "message": "Private key does not match certificate.",
                },
            ],
        ),
    ],
)
def test_validate_tls_keys(
    scenario,
    cert_fixture,
    key_fixture,
    hostname,
    expected_result,
    expected_errors,
    request,
):
    """Check the error list returned by ``validate_tls_keys`` per scenario.

    ``cert_fixture`` and ``key_fixture`` name fixtures resolved through
    ``request``; unless both are given, ``None`` is passed for both PEM
    arguments. Name verification is skipped whenever ``hostname`` is falsy,
    while chain verification is always enabled. Only the returned errors
    are compared; ``expected_result`` is not asserted against, and
    ``scenario`` only labels the case.
    """
    cert_pem = None
    key_pem = None
    if cert_fixture and key_fixture:
        # Resolve the certificate fixture first, then the key fixture.
        _unused_key, cert_path = request.getfixturevalue(cert_fixture)
        key_path, _unused_cert = request.getfixturevalue(key_fixture)

        with open(cert_path, "rb") as cert_handle:
            cert_pem = cert_handle.read().decode("utf-8")
        with open(key_path, "rb") as key_handle:
            key_pem = key_handle.read().decode("utf-8")

    # The chain and SAN list are returned too but are not inspected here.
    _chain, _san_list, errors = x509.validate_tls_keys(
        tls_private_key_pem=key_pem,
        tls_certificate_pem=cert_pem,
        hostname=hostname,
        skip_name_verification=not hostname,
        skip_chain_verification=False,
    )

    assert errors == expected_errors
|
||
|
|
||
|
|
||
|
def test_validate_tls_keys_invalid_algorithn(
    invalid_algorithm_key, self_signed_rsa_cert
):
    """A DSA private key yields exactly the ``private_key_not_rsa_or_ec`` error.

    Both name and chain verification are skipped for this call.
    """
    with open(invalid_algorithm_key, "rb") as key_handle:
        key_pem = key_handle.read().decode("utf-8")

    _unused_key, cert_path = self_signed_rsa_cert
    with open(cert_path, "rb") as cert_handle:
        cert_pem = cert_handle.read().decode("utf-8")

    # The chain and SAN list are returned too but are not inspected here.
    _chain, _san_list, errors = x509.validate_tls_keys(
        tls_private_key_pem=key_pem,
        tls_certificate_pem=cert_pem,
        hostname=None,
        skip_name_verification=True,
        skip_chain_verification=True,
    )

    assert errors == [
        {
            "key": "private_key_not_rsa_or_ec",
            "message": "Private key must be RSA or Elliptic-Curve.",
        }
    ]
|