"""Tests for the TLS/X.509 helpers in ``app.util.x509``."""

import shutil
from datetime import datetime, timedelta
from itertools import permutations

import pytest
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509 import Certificate
from cryptography.x509.oid import NameOID

from app.util import x509
from app.util.x509 import TLSValidationError
from tests.api.test_onion import generate_self_signed_tls_certificate

# Subject / SAN values shared by every generated self-signed certificate.
_ONION_ADDRESS = "test.onion"
_DNS_NAMES = ("test.onion", "www.test.onion")


def _generate_cert(tmp_path, subdir, valid_from, valid_to):
    """Generate a self-signed certificate under ``tmp_path / subdir``.

    Returns the ``(private_key_path, cert_path)`` pair produced by
    ``generate_self_signed_tls_certificate``.
    """
    # tmp_path is a pytest fixture implemented as a pathlib.Path object
    directory = tmp_path / subdir
    if directory.exists():
        shutil.rmtree(directory)
    directory.mkdir()
    private_key_path, cert_path = generate_self_signed_tls_certificate(
        str(directory),
        _ONION_ADDRESS,
        valid_from,
        valid_to,
        dns_names=list(_DNS_NAMES),  # fresh list so the shared tuple cannot be mutated
    )
    return private_key_path, cert_path


def _read_bytes(path):
    """Return the raw contents of the file at ``path``."""
    with open(path, "rb") as handle:
        return handle.read()


def _read_text(path):
    """Return the contents of the file at ``path`` decoded as UTF-8."""
    return _read_bytes(path).decode("utf-8")


def _load_chain(cert_path):
    """Parse every certificate contained in the PEM file at ``cert_path``."""
    return x509.load_certificates_from_pem(_read_bytes(cert_path))


def _load_end_entity(cert_path):
    """Return the first certificate in the PEM file at ``cert_path``."""
    return _load_chain(cert_path)[0]


@pytest.fixture
def self_signed_rsa_cert(tmp_path):
    """Currently valid self-signed certificate: ``(private_key_path, cert_path)``."""
    now = datetime.now()
    # Backdate the start by a day: if the generator hands naive datetimes to
    # cryptography (which treats them as UTC), a local "now" would put
    # not_valid_before in the future on hosts east of UTC.
    return _generate_cert(
        tmp_path, "cert", now - timedelta(days=1), now + timedelta(days=30)
    )


@pytest.fixture
def expired_cert(tmp_path):
    """Self-signed certificate whose validity window ended 30 days ago."""
    now = datetime.now()
    return _generate_cert(
        tmp_path,
        "cert_expired",
        now - timedelta(days=60),
        now - timedelta(days=30),
    )


@pytest.fixture
def not_yet_valid_cert(tmp_path):
    """Self-signed certificate whose validity window starts in 30 days."""
    valid_from = datetime.now() + timedelta(days=30)
    return _generate_cert(
        tmp_path, "cert_not_yet_valid", valid_from, valid_from + timedelta(days=30)
    )


@pytest.fixture
def letsencrypt_cert():
    """Paths of the checked-in Let's Encrypt key and full chain."""
    cert_path = "tests/data/letsencrypt-issued/fullchain.pem"
    private_key_path = "tests/data/letsencrypt-issued/privkey.pem"
    return private_key_path, cert_path


@pytest.fixture
def invalid_algorithm_key():
    """Path of a DSA private key, an algorithm the validator should reject."""
    return "tests/data/invalid-algorithm/dsa_private_key.pem"


@pytest.fixture
def letsencrypt_valid_chain():
    """Parsed certificates of the Let's Encrypt full chain."""
    return _load_chain("tests/data/letsencrypt-issued/fullchain.pem")


@pytest.fixture
def invalid_algorithm_chain():
    """Parsed certificates of a chain that uses an unsupported algorithm."""
    return _load_chain("tests/data/invalid-algorithm/fullchain.pem")


@pytest.fixture
def untrusted_root_ca_chain():
    """Parsed certificates of a chain anchored at an untrusted root CA."""
    return _load_chain("tests/data/untrusted-root-ca/fullchain.pem")


@pytest.fixture
def circular_chain():
    """Parsed certificates of a chain with no identifiable end-entity."""
    return _load_chain("tests/data/no-end-entity-chain/fullchain.pem")


def test_load_single_certificate(self_signed_rsa_cert):
    """Test loading a single certificate from PEM data."""
    _, cert_path = self_signed_rsa_cert
    certificates = x509.load_certificates_from_pem(_read_bytes(cert_path))
    assert len(certificates) == 1
    assert isinstance(certificates[0], Certificate)


def test_load_multiple_certificates(self_signed_rsa_cert):
    """Test loading multiple certificates from PEM data."""
    _, cert_path = self_signed_rsa_cert
    pem_cert = _read_bytes(cert_path)
    certificates = x509.load_certificates_from_pem(pem_cert + b"\n" + pem_cert)
    assert len(certificates) == 2
    for loaded_cert in certificates:
        assert isinstance(loaded_cert, Certificate)


def test_load_invalid_pem_data():
    """Test handling of invalid PEM data."""
    invalid_pem = (
        b"-----BEGIN CERTIFICATE-----\n"
        b"InvalidCertificateData\n"
        b"-----END CERTIFICATE-----"
    )
    with pytest.raises(TLSValidationError, match="TLS certificate parsing error"):
        x509.load_certificates_from_pem(invalid_pem)


def test_validate_key_with_valid_RSA_key(self_signed_rsa_cert):
    """Test that a valid RSA private key is accepted."""
    private_key_path, _ = self_signed_rsa_cert
    assert x509.validate_key(_read_text(private_key_path)) is True


def test_validate_key_with_valid_EC_key(letsencrypt_cert):
    """Test that a valid EC private key is accepted."""
    private_key_path, _ = letsencrypt_cert
    assert x509.validate_key(_read_text(private_key_path)) is True


def test_validate_key_with_valid_DSA_key(invalid_algorithm_key):
    """Test that a valid DSA private key is not accepted."""
    assert x509.validate_key(_read_text(invalid_algorithm_key)) is False


@pytest.mark.parametrize(
    "scenario, key_data, expected_result, expected_exception",
    [
        (
            "invalid_key_format",
            "-----BEGIN INVALID KEY-----\nInvalidData\n-----END INVALID KEY-----",
            None,
            x509.TLSInvalidPrivateKeyError,
        ),
        ("none_key", None, False, None),
        ("empty_key", "", False, None),
    ],
)
def test_validate_key_with_invalid_key(
    scenario, key_data, expected_result, expected_exception
):
    """Test that various invalid private keys are not recognized."""
    if expected_exception:
        with pytest.raises(expected_exception):
            x509.validate_key(key_data)
    else:
        assert x509.validate_key(key_data) is expected_result


@pytest.mark.parametrize(
    "scenario, certificate_fixture, expected_result",
    [
        ("expired_certificate", "expired_cert", True),
        ("valid_certificate", "self_signed_rsa_cert", False),
    ],
)
def test_validate_end_entity_expired(
    scenario, certificate_fixture, expected_result, request
):
    """Test with both expired and valid certificates."""
    _, cert_path = request.getfixturevalue(certificate_fixture)
    end_entity_cert = _load_end_entity(cert_path)
    assert x509.validate_end_entity_expired(end_entity_cert) is expected_result


@pytest.mark.parametrize(
    "scenario, certificate_fixture, expected_result",
    [
        ("not_yet_valid_certificate", "not_yet_valid_cert", True),
        ("valid_certificate", "self_signed_rsa_cert", False),
    ],
)
def test_validate_end_entity_not_yet_valid(
    scenario, certificate_fixture, expected_result, request
):
    """Test with both not-yet-valid and currently valid certificates."""
    _, cert_path = request.getfixturevalue(certificate_fixture)
    end_entity_cert = _load_end_entity(cert_path)
    assert x509.validate_end_entity_not_yet_valid(end_entity_cert) is expected_result


@pytest.mark.parametrize(
    "scenario, cert_fixture, key_fixture, expected_result",
    [
        # Matching certificate and key (RSA)
        ("matching_keys_rsa", "self_signed_rsa_cert", "self_signed_rsa_cert", True),
        # Matching certificate and key (EC)
        ("matching_keys_ec", "letsencrypt_cert", "letsencrypt_cert", True),
        # Key from a different generated certificate
        ("mismatched_keys", "self_signed_rsa_cert", "expired_cert", False),
        ("missing_key", "self_signed_rsa_cert", None, False),
        ("missing_cert", None, "self_signed_rsa_cert", False),
    ],
)
def test_validate_key_matches_cert(
    scenario, cert_fixture, key_fixture, expected_result, request
):
    """Test the validate_key_matches_cert function with different scenarios."""
    cert_path = request.getfixturevalue(cert_fixture)[1] if cert_fixture else None
    key_path = request.getfixturevalue(key_fixture)[0] if key_fixture else None

    if cert_path is None or key_path is None:
        # A missing side is passed straight through as None.
        assert x509.validate_key_matches_cert(key_path, cert_path) is expected_result
        return

    end_entity_cert = _load_end_entity(cert_path)
    tls_private_key_pem = _read_text(key_path)
    assert (
        x509.validate_key_matches_cert(tls_private_key_pem, end_entity_cert)
        is expected_result
    )


def test_extract_sans(letsencrypt_cert):
    """SANs are extracted from the Let's Encrypt end-entity certificate."""
    _, cert_path = letsencrypt_cert
    sans = x509.extract_sans(_load_end_entity(cert_path))
    assert isinstance(sans, list)
    assert sans == ["xahm6aech1mie5queyo8.censorship.guide"]


@pytest.mark.parametrize(
    "scenario, certificate_input, expected_result",
    [
        (
            "invalid_certificate_format",
            "-----BEGIN CERTIFICATE-----\nINVALID_CERTIFICATE\n-----END CERTIFICATE-----",
            [],
        ),
        ("certificate_is_none", None, []),
        ("certificate_is_empty_string", "", []),
    ],
)
def test_extract_sans_invalid(scenario, certificate_input, expected_result):
    """Inputs that are not certificates yield an empty SAN list."""
    sans = x509.extract_sans(certificate_input)
    assert isinstance(sans, list)
    assert sans == expected_result


def test_validate_certificate_chain_invalid_algorithm(invalid_algorithm_chain):
    """A chain containing an unsupported algorithm is rejected."""
    with pytest.raises(
        TLSValidationError, match="Certificate using unsupported algorithm"
    ):
        x509.validate_certificate_chain(invalid_algorithm_chain)


def test_validate_certificate_chain_invalid_root_ca(untrusted_root_ca_chain):
    """A chain anchored at an untrusted root CA is rejected."""
    with pytest.raises(
        TLSValidationError,
        match="Certificate chain does not terminate at a trusted root CA.",
    ):
        x509.validate_certificate_chain(untrusted_root_ca_chain)


def test_validate_certificate_chain_valid(letsencrypt_valid_chain):
    """The Let's Encrypt chain validates successfully."""
    assert x509.validate_certificate_chain(letsencrypt_valid_chain) is True


def test_build_chain_single_certificate(self_signed_rsa_cert):
    """A single self-signed certificate is returned unchanged."""
    _, cert_path = self_signed_rsa_cert
    certs = _load_chain(cert_path)
    assert x509.build_certificate_chain(certs) == certs


def test_build_chain_valid_certificate_letsencrypt(letsencrypt_valid_chain):
    """An already ordered Let's Encrypt chain is returned unchanged."""
    assert (
        x509.build_certificate_chain(letsencrypt_valid_chain)
        == letsencrypt_valid_chain
    )


def test_build_chain_certificates_in_order(untrusted_root_ca_chain):
    """An already ordered chain is returned unchanged."""
    assert (
        x509.build_certificate_chain(untrusted_root_ca_chain)
        == untrusted_root_ca_chain
    )


def test_build_chain_certificates_in_random_order(untrusted_root_ca_chain):
    """Every ordering of the chain is rebuilt into the canonical order."""
    for perm in permutations(untrusted_root_ca_chain):
        # Pass a list, matching what every other caller hands the function.
        assert x509.build_certificate_chain(list(perm)) == untrusted_root_ca_chain


def test_build_chain_empty_list():
    """An empty certificate list has no end-entity to start from."""
    with pytest.raises(
        TLSValidationError, match="Cannot identify the end-entity certificate."
    ):
        x509.build_certificate_chain([])


def test_build_chain_invalid_chain_no_end_entity(circular_chain):
    """A chain with no identifiable end-entity is rejected."""
    with pytest.raises(
        TLSValidationError, match="Cannot identify the end-entity certificate."
    ):
        x509.build_certificate_chain(circular_chain)


def test_build_chain_certificates_missing_issuers(untrusted_root_ca_chain):
    """A chain with a missing link is rejected."""
    # Drop the second certificate so the remaining ones leave a gap.
    untrusted_root_ca_chain.pop(1)
    with pytest.raises(
        TLSValidationError, match="Certificates do not form a valid chain."
    ):
        x509.build_certificate_chain(untrusted_root_ca_chain)


def test_build_chain_duplicate_certificates(untrusted_root_ca_chain):
    """Duplicate certificates are collapsed into a single ordered chain."""
    duplicate_untrusted_root_ca_chain = untrusted_root_ca_chain * 2
    assert (
        x509.build_certificate_chain(duplicate_untrusted_root_ca_chain)
        == untrusted_root_ca_chain
    )


def test_build_chain_non_chain_certificates(
    untrusted_root_ca_chain, letsencrypt_valid_chain
):
    """Certificates from two unrelated chains cannot be combined."""
    non_chain = untrusted_root_ca_chain + letsencrypt_valid_chain
    with pytest.raises(
        TLSValidationError, match="Certificates do not form a valid chain."
    ):
        x509.build_certificate_chain(non_chain)


@pytest.mark.parametrize(
    "scenario, cert_fixture, key_fixture, hostname, expected_errors",
    [
        (
            "empty_cert",
            None,
            None,
            None,
            [
                {
                    "key": "no_valid_certificates",
                    "message": "No valid certificates supplied.",
                },
            ],
        ),
        (
            "trusted_cert_wrong_hostname",
            "letsencrypt_cert",
            "letsencrypt_cert",
            "wrong.example.com",
            [
                {
                    "key": "hostname_not_in_san",
                    "message": "wrong.example.com not found in SANs.",
                },
                {
                    "key": "hostname_not_in_san",
                    "message": "*.wrong.example.com not found in SANs.",
                },
            ],
        ),
        (
            "expired_self_signed_cert",
            "expired_cert",
            "expired_cert",
            None,
            [
                {"key": "public_key_expired", "message": "TLS public key is expired."},
                {
                    "key": "untrusted_root_ca",
                    "message": "Certificate chain does not terminate at a trusted root CA.",
                },
            ],
        ),
        (
            "not_yet_valid_self_signed_cert",
            "not_yet_valid_cert",
            "not_yet_valid_cert",
            None,
            [
                {
                    "key": "public_key_future",
                    "message": "TLS public key is not yet valid.",
                },
                {
                    "key": "untrusted_root_ca",
                    "message": "Certificate chain does not terminate at a trusted root CA.",
                },
            ],
        ),
        (
            "mismatched_key_cert",
            "letsencrypt_cert",
            "self_signed_rsa_cert",
            None,
            [
                {
                    "key": "key_mismatch",
                    "message": "Private key does not match certificate.",
                },
            ],
        ),
    ],
)
def test_validate_tls_keys(
    scenario,
    cert_fixture,
    key_fixture,
    hostname,
    expected_errors,
    request,
):
    """validate_tls_keys reports the expected error list for each scenario."""
    if cert_fixture and key_fixture:
        _, cert_path = request.getfixturevalue(cert_fixture)
        key_path, _ = request.getfixturevalue(key_fixture)
        cert_pem = _read_text(cert_path)
        key_pem = _read_text(key_path)
    else:
        cert_pem = None
        key_pem = None

    _, _, errors = x509.validate_tls_keys(
        tls_private_key_pem=key_pem,
        tls_certificate_pem=cert_pem,
        hostname=hostname,
        # Only check names when the scenario supplies a hostname.
        skip_name_verification=not hostname,
        skip_chain_verification=False,
    )

    assert errors == expected_errors


def test_validate_tls_keys_invalid_algorithm(
    invalid_algorithm_key, self_signed_rsa_cert
):
    """A DSA private key is reported as neither RSA nor EC."""
    key_pem = _read_text(invalid_algorithm_key)
    _, cert_path = self_signed_rsa_cert
    cert_pem = _read_text(cert_path)

    _, _, errors = x509.validate_tls_keys(
        tls_private_key_pem=key_pem,
        tls_certificate_pem=cert_pem,
        hostname=None,
        skip_name_verification=True,
        skip_chain_verification=True,
    )

    expected_errors = [
        {
            "key": "private_key_not_rsa_or_ec",
            "message": "Private key must be RSA or Elliptic-Curve.",
        }
    ]
    assert errors == expected_errors