import ssl
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from cryptography import x509
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.x509.oid import ExtensionOID

if TYPE_CHECKING:
    from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes

_PEM_CERTIFICATE_END = b"-----END CERTIFICATE-----"


def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]:
    """Parse the PEM certificates concatenated in *pem_data*, in order of appearance.

    The data is split on the ``-----END CERTIFICATE-----`` marker; every non-blank chunk (with
    the marker restored) is parsed as one certificate. Whitespace-only input yields ``[]``.
    If a chunk is not a valid PEM certificate, the loader's exception propagates.
    """
    certificates = []
    for chunk in pem_data.split(_PEM_CERTIFICATE_END):
        chunk = chunk.strip()
        if chunk:
            certificates.append(
                x509.load_pem_x509_certificate(chunk + _PEM_CERTIFICATE_END, default_backend())
            )
    return certificates


def build_certificate_chain(certificates: list[x509.Certificate]) -> list[x509.Certificate]:
    """Order *certificates* from the end-entity (leaf) certificate up towards the root.

    The leaf is the first certificate whose subject is not the issuer of any certificate in the
    input. From there, each certificate's issuer is looked up by subject name; the walk stops when
    the issuer is not in the input, at a self-issued certificate, or on an issuer cycle.

    A single-element list is returned unchanged.

    Raises:
        ValueError: if no leaf can be identified (e.g. empty input, or every certificate
            issues itself or another one in the input).
    """
    if len(certificates) == 1:
        return certificates

    by_subject = {cert.subject.rfc4514_string(): cert for cert in certificates}
    # A certificate that issued any certificate in the bundle (including itself, when
    # self-signed) cannot be the leaf.
    issuer_names = {cert.issuer.rfc4514_string() for cert in certificates}
    end_entity = next(
        (cert for cert in certificates if cert.subject.rfc4514_string() not in issuer_names),
        None,
    )
    if end_entity is None:
        raise ValueError("Cannot identify the end-entity certificate.")

    chain = [end_entity]
    seen_subjects = {end_entity.subject.rfc4514_string()}
    current = end_entity
    while True:
        issuer_name = current.issuer.rfc4514_string()
        # A self-issued certificate (issuer == its own subject) or a cycle would otherwise
        # make this walk loop forever.
        if issuer_name in seen_subjects:
            break
        issuer_cert = by_subject.get(issuer_name)
        if issuer_cert is None:
            break
        chain.append(issuer_cert)
        seen_subjects.add(issuer_name)
        current = issuer_cert
    return chain


def _verify_rsa_signed_by(cert: x509.Certificate, issuer: x509.Certificate) -> None:
    """Raise ValueError unless *cert* has a valid RSA PKCS#1 v1.5 signature from *issuer*'s key."""
    issuer_public_key = issuer.public_key()
    if not isinstance(issuer_public_key, RSAPublicKey):
        raise ValueError(f"Certificate using unsupported algorithm: {type(issuer_public_key)}")
    hash_algorithm = cert.signature_hash_algorithm
    if hash_algorithm is None:
        raise ValueError("Certificate missing hash algorithm")
    try:
        issuer_public_key.verify(
            cert.signature, cert.tbs_certificate_bytes, PKCS1v15(), hash_algorithm
        )
    except InvalidSignature:
        # InvalidSignature is not a ValueError; convert it so callers handle one exception type.
        raise ValueError(
            f"Invalid signature on certificate {cert.subject.rfc4514_string()}."
        ) from None


def _trusted_ca_certificates() -> List[x509.Certificate]:
    """Return the CA certificates reported by a default SSL context."""
    context = ssl.create_default_context()
    # NOTE(review): get_ca_certs() only lists CAs loaded from a CA *file*; CAs in a capath
    # directory are not reported until they have been used, so this list can be empty on some
    # systems and every chain would then be rejected — confirm on the target platform.
    return [
        x509.load_der_x509_certificate(der) for der in context.get_ca_certs(binary_form=True)
    ]


def _is_directly_issued_by(cert: x509.Certificate, ca_cert: x509.Certificate) -> bool:
    """Return True if *ca_cert* is named as *cert*'s issuer AND its key verifies *cert*'s signature."""
    if cert.issuer != ca_cert.subject:
        return False
    try:
        # A name match alone is forgeable; require a valid signature from the CA's key.
        cert.verify_directly_issued_by(ca_cert)
    except (InvalidSignature, ValueError, TypeError, UnsupportedAlgorithm):
        return False
    return True


def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
    """Validate a certificate chain against the system's root CA store.

    *chain* is ordered leaf first (as produced by build_certificate_chain). Each certificate must
    carry a valid RSA PKCS#1 v1.5 signature from the next certificate's key, and the last
    certificate must be directly signed by a CA certificate from the default SSL context.

    Returns:
        True when every check passes.

    Raises:
        ValueError: describing the first check that failed.
    """
    for cert, issuer in zip(chain, chain[1:]):
        _verify_rsa_signed_by(cert, issuer)

    end_cert = chain[-1]
    if not any(_is_directly_issued_by(end_cert, ca) for ca in _trusted_ca_certificates()):
        raise ValueError("Certificate chain does not terminate at a trusted root CA.")
    return True


def _validity_period_errors(cert: x509.Certificate) -> List[Dict[str, str]]:
    """Return error entries for *cert* being expired or not yet valid at the current UTC time."""
    now = datetime.now(timezone.utc)
    errors: List[Dict[str, str]] = []
    if cert.not_valid_after_utc < now:
        errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."})
    if cert.not_valid_before_utc > now:
        errors.append(
            {"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."}
        )
    return errors


def _key_matches_certificate(private_key: "PrivateKeyTypes", cert: x509.Certificate) -> bool:
    """Return True if *private_key* is the private half of the public key certified by *cert*.

    Compares the DER SubjectPublicKeyInfo encodings of both public keys, which works for any key
    type and does not depend on the certificate's signature hash algorithm. Returns False if
    either key cannot be loaded or encoded.
    """
    encoding = serialization.Encoding.DER
    spki = serialization.PublicFormat.SubjectPublicKeyInfo
    try:
        private_side = private_key.public_key().public_bytes(encoding, spki)
        certified_side = cert.public_key().public_bytes(encoding, spki)
    except (ValueError, TypeError, UnsupportedAlgorithm):
        return False
    return private_side == certified_side


def validate_tls_keys(
    tls_private_key_pem: Optional[str],
    tls_certificate_pem: Optional[str],
    skip_chain_verification: Optional[bool],
    skip_name_verification: Optional[bool],
    hostname: str,
) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]:
    """Validate a PEM private key and PEM certificate bundle for serving TLS on *hostname*.

    Each failed check appends ``{"Error": <code>, "Message": <text>}`` to the error list:
      - the private key, if given, must be RSA;
      - the certificate bundle must contain at least one certificate;
      - the end-entity certificate must be within its validity period;
      - the private key, if given, must match the end-entity certificate;
      - unless *skip_chain_verification*, the chain must validate against the system CA store;
      - unless *skip_name_verification*, both *hostname* and ``*.{hostname}`` must be DNS SANs.
    Any unexpected exception is recorded as ``tls_validation_error`` and ends further checks.

    Returns:
        ``(chain, san_list, errors)``: the ordered chain (None if it could not be built), the
        end-entity certificate's DNS SANs (empty unless name verification ran), and the errors.
    """
    errors: List[Dict[str, str]] = []
    san_list: List[str] = []
    chain: Optional[List[x509.Certificate]] = None

    try:
        private_key: Optional["PrivateKeyTypes"] = None
        if tls_private_key_pem:
            private_key = serialization.load_pem_private_key(
                tls_private_key_pem.encode("utf-8"), password=None, backend=default_backend()
            )
            if not isinstance(private_key, rsa.RSAPrivateKey):
                errors.append(
                    {"Error": "tls_private_key_invalid", "Message": "Private key must be RSA."}
                )

        if tls_certificate_pem:
            certificates = load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))
            if not certificates:
                errors.append(
                    {"Error": "tls_certificate_invalid", "Message": "No valid certificate found."}
                )
            else:
                chain = build_certificate_chain(certificates)
                end_entity_cert = chain[0]

                errors.extend(_validity_period_errors(end_entity_cert))

                if private_key is not None and not _key_matches_certificate(
                    private_key, end_entity_cert
                ):
                    errors.append(
                        {
                            "Error": "tls_key_mismatch",
                            "Message": "Private key does not match certificate.",
                        }
                    )

                if not skip_chain_verification:
                    try:
                        validate_certificate_chain(chain)
                    except ValueError as e:
                        errors.append({"Error": "certificate_chain_invalid", "Message": str(e)})

                if not skip_name_verification:
                    san_list = extract_sans(end_entity_cert)
                    # Both the bare hostname and its wildcard form are required.
                    for expected_hostname in (hostname, f"*.{hostname}"):
                        if expected_hostname not in san_list:
                            errors.append(
                                {
                                    "Error": "hostname_not_in_san",
                                    "Message": f"{expected_hostname} not found in SANs.",
                                }
                            )
    except Exception as e:  # top-level boundary: report every failure instead of raising
        errors.append({"Error": "tls_validation_error", "Message": str(e)})

    return chain, san_list, errors


def extract_sans(cert: x509.Certificate) -> List[str]:
    """Return the DNS names listed in *cert*'s Subject Alternative Name extension.

    Returns an empty list when the certificate has no SAN extension or its extensions cannot
    be parsed.
    """
    try:
        san_extension = cert.extensions.get_extension_for_oid(
            ExtensionOID.SUBJECT_ALTERNATIVE_NAME
        )
    except (x509.ExtensionNotFound, x509.DuplicateExtension, ValueError):
        return []
    sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName)  # type: ignore[attr-defined]
    return sans