import ssl
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa, ec
from cryptography.hazmat.primitives.asymmetric.ec import (
    EllipticCurvePublicKey,
    EllipticCurvePrivateKey,
)
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed

# ExtensionOID is imported from its public location rather than the private
# ``cryptography.hazmat._oid`` module.
from cryptography.x509.oid import ExtensionOID


class TLSValidationError(ValueError):
    """Base class for every TLS validation failure raised by this module.

    Each subclass overrides two class attributes:

    * ``key`` - a stable, machine-readable identifier of the failure.
    * ``message`` - a human-readable description of the failure.
    """

    # Defaults so that the base class itself can be raised and rendered
    # without hitting an AttributeError in __str__ / as_dict.
    key: str = "tls_validation_error"
    message: str = "TLS validation error"

    def __init__(self, *args: object) -> None:
        # When raised without arguments, store the class message in ``args``
        # so that repr(), logging and pickling do not lose it.
        super().__init__(*(args or (self.message,)))

    def __str__(self) -> str:
        """Return the human-readable message of the error."""
        return self.message

    def as_dict(self) -> Dict[str, str]:
        """Return the error as a ``{"key": ..., "message": ...}`` mapping."""
        return {"key": self.key, "message": self.message}


class TLSCertificateParsingError(TLSValidationError):
    """The supplied PEM data could not be parsed as certificates."""

    key = "could_not_parse_certificate"
    message = "TLS certificate parsing error"


class TLSInvalidPrivateKeyError(TLSValidationError):
    """The supplied private key could not be loaded."""

    key = "invalid_private_key"
    message = "Private key is invalid"


class TLSNoEndEntityError(TLSValidationError):
    """No certificate in the bundle could be identified as the end entity."""

    key = "could_not_identify_end_entity"
    message = "Cannot identify the end-entity certificate."


class TLSChainNotValidError(TLSValidationError):
    """The supplied certificates do not link into a valid chain."""

    key = "invalid_tls_chain"
    message = "Certificates do not form a valid chain."


class TLSUnsupportedAlgorithmError(TLSValidationError):
    """A certificate uses a key or signature algorithm that is not supported."""

    key = "invalid_algorithm"
    message = "Certificate using unsupported algorithm"


class TLSMissingHashError(TLSValidationError):
    """A certificate signature carries no hash algorithm."""

    key = "missing_hash_algorithm"
    message = "Certificate missing hash algorithm."


class TLSUntrustedRootCAError(TLSValidationError):
    """The chain does not end at a certificate authority that is trusted."""

    key = "untrusted_root_ca"
    message = "Certificate chain does not terminate at a trusted root CA."
def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]:
    """Parse every PEM certificate contained in ``pem_data``.

    The data is split on the PEM end marker and each non-empty piece is
    parsed on its own, so any non-whitespace content following the last
    certificate is treated as a (broken) certificate.

    Args:
        pem_data: One or more concatenated PEM-encoded certificates.

    Returns:
        The parsed certificates, in the order they appear in ``pem_data``.
        Empty or whitespace-only input yields an empty list.

    Raises:
        TLSCertificateParsingError: If any piece cannot be parsed.
    """
    end_marker = b"-----END CERTIFICATE-----"
    certificates = []
    try:
        for pem_block in pem_data.split(end_marker):
            pem_block = pem_block.strip()
            if pem_block:
                # split() removed the end marker; put it back before parsing.
                certificates.append(
                    x509.load_pem_x509_certificate(
                        pem_block + end_marker, default_backend()
                    )
                )
    except ValueError as exc:
        raise TLSCertificateParsingError() from exc
    return certificates


def build_certificate_chain(
    certificates: list[x509.Certificate],
) -> list[x509.Certificate]:
    """Order ``certificates`` from the end-entity certificate upwards.

    The end entity is the first certificate whose subject is not the issuer
    of any supplied certificate. From there the chain follows issuer names
    until every distinct subject has been consumed.

    Args:
        certificates: The certificates of one bundle, in any order.

    Returns:
        The chain, end-entity certificate first. A single certificate is
        returned unchanged.

    Raises:
        TLSNoEndEntityError: If no end-entity certificate can be identified
            (including when ``certificates`` is empty).
        TLSChainNotValidError: If an issuer is missing from the bundle, or a
            self-issued certificate is reached while unused certificates
            remain.
    """
    if len(certificates) == 1:
        return certificates

    cert_map = {cert.subject.rfc4514_string(): cert for cert in certificates}
    end_entity = next(
        (
            cert
            for cert in certificates
            if not any(cert.subject == other_cert.issuer for other_cert in certificates)
        ),
        None,
    )
    if end_entity is None:
        raise TLSNoEndEntityError()

    chain = [end_entity]
    current_cert = end_entity
    # Every step removes the current certificate's entry; when a single entry
    # is left it belongs to the last certificate appended (the root CA).
    while len(cert_map) > 1:
        subject_key = current_cert.subject.rfc4514_string()
        issuer_key = current_cert.issuer.rfc4514_string()
        # A self-issued certificate cannot lead to the remaining entries.
        # Following it would append the same certificate a second time.
        if issuer_key == subject_key or issuer_key not in cert_map:
            raise TLSChainNotValidError()
        next_cert = cert_map[issuer_key]
        chain.append(next_cert)
        cert_map.pop(subject_key)
        current_cert = next_cert
    return chain


def _signature_is_valid(certificate: x509.Certificate, issuer_public_key) -> bool:
    """Return whether ``certificate`` is signed by ``issuer_public_key``.

    Only RSA (PKCS#1 v1.5 padding) and ECDSA signatures are checked.

    Raises:
        TLSUnsupportedAlgorithmError: If the issuer key is neither RSA nor
            elliptic-curve, or the signature algorithm is not recognised.
        TLSMissingHashError: If the signature has no hash algorithm.
    """
    from cryptography.exceptions import UnsupportedAlgorithm

    if not isinstance(issuer_public_key, (RSAPublicKey, EllipticCurvePublicKey)):
        raise TLSUnsupportedAlgorithmError()
    try:
        hash_algorithm = certificate.signature_hash_algorithm
    except UnsupportedAlgorithm as exc:
        raise TLSUnsupportedAlgorithmError() from exc
    # This must be a real runtime check: verifying with a missing hash would
    # fail with an unrelated error.
    if hash_algorithm is None:
        raise TLSMissingHashError()
    try:
        if isinstance(issuer_public_key, RSAPublicKey):
            issuer_public_key.verify(
                certificate.signature,
                certificate.tbs_certificate_bytes,
                PKCS1v15(),
                hash_algorithm,
            )
        else:
            issuer_public_key.verify(
                certificate.signature,
                certificate.tbs_certificate_bytes,
                ec.ECDSA(hash_algorithm),
            )
    except InvalidSignature:
        return False
    return True


def _load_trusted_certificates() -> list[x509.Certificate]:
    """Return the CA certificates loaded in the default SSL context.

    Entries that cannot be parsed are skipped, so that a single malformed
    system certificate does not abort the whole validation.
    """
    # NOTE(review): SSLContext.get_ca_certs() is documented not to include
    # certificates from a capath directory until they have been used - confirm
    # the deployment relies on a CA file/bundle.
    context = ssl.create_default_context()
    trusted = []
    for der_bytes in context.get_ca_certs(binary_form=True):
        try:
            trusted.append(x509.load_der_x509_certificate(der_bytes))
        except ValueError:
            continue
    return trusted


def _is_signed_by_trusted_ca(
    certificate: x509.Certificate,
    trusted_certificates: list[x509.Certificate],
) -> bool:
    """Return whether a trusted CA both names and signed ``certificate``."""
    from cryptography.exceptions import UnsupportedAlgorithm

    for trusted_cert in trusted_certificates:
        if certificate.issuer != trusted_cert.subject:
            continue
        try:
            if _signature_is_valid(certificate, trusted_cert.public_key()):
                return True
        except (ValueError, UnsupportedAlgorithm):
            # TLSValidationError is a ValueError: an anchor that cannot be
            # used for verification is skipped, not treated as fatal.
            continue
    return False


def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
    """Validate a certificate chain against the system's root CA store.

    Every certificate must be signed by the next certificate in ``chain``,
    and the last certificate must be signed by a certificate from the
    default SSL context's CA store whose subject equals its issuer. A
    matching issuer name alone is not sufficient.

    Args:
        chain: Certificates ordered from end entity to root.

    Returns:
        ``True`` when the chain is valid; failures are reported by raising.

    Raises:
        TLSChainNotValidError: If ``chain`` is empty or a signature between
            two links does not verify.
        TLSUnsupportedAlgorithmError: If an issuer key is neither RSA nor
            elliptic-curve.
        TLSMissingHashError: If a certificate signature has no hash algorithm.
        TLSUntrustedRootCAError: If the last certificate is not signed by a
            trusted CA.
    """
    if not chain:
        raise TLSChainNotValidError()

    for certificate, issuer in zip(chain, chain[1:]):
        if not _signature_is_valid(certificate, issuer.public_key()):
            raise TLSChainNotValidError()

    if not _is_signed_by_trusted_ca(chain[-1], _load_trusted_certificates()):
        raise TLSUntrustedRootCAError()
    return True


def _load_private_key(tls_private_key_pem: str):
    """Load an unencrypted PEM private key.

    Raises:
        TLSInvalidPrivateKeyError: If the key cannot be decoded, is encrypted
            (no password is supplied), or uses an unsupported algorithm.
    """
    from cryptography.exceptions import UnsupportedAlgorithm

    try:
        return serialization.load_pem_private_key(
            tls_private_key_pem.encode("utf-8"),
            password=None,
            backend=default_backend(),
        )
    except (ValueError, TypeError, UnsupportedAlgorithm) as exc:
        raise TLSInvalidPrivateKeyError() from exc


def validate_key(tls_private_key_pem: Optional[str]) -> bool:
    """Return whether ``tls_private_key_pem`` is an RSA or EC private key.

    Args:
        tls_private_key_pem: PEM-encoded private key, or ``None``.

    Returns:
        ``True`` for an RSA or elliptic-curve key; ``False`` for any other
        key type or when no key is supplied.

    Raises:
        TLSInvalidPrivateKeyError: If the key cannot be loaded.
    """
    if not tls_private_key_pem:
        return False
    private_key = _load_private_key(tls_private_key_pem)
    return isinstance(private_key, (rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey))


def extract_sans(cert: x509.Certificate) -> List[str]:
    """Return the DNS names from the certificate's Subject Alternative Name.

    Returns an empty list when the extension is absent or cannot be read.
    """
    try:
        san_extension = cert.extensions.get_extension_for_oid(
            ExtensionOID.SUBJECT_ALTERNATIVE_NAME
        )
        sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName)  # type: ignore[attr-defined]
        return sans
    except Exception:
        # Deliberately best-effort: any failure to read the extension is
        # reported to the caller as "no SANs".
        return []


def validate_end_entity_expired(certificate: x509.Certificate) -> bool:
    """Return ``True`` when the certificate's validity period has ended."""
    return certificate.not_valid_after_utc < datetime.now(timezone.utc)


def validate_end_entity_not_yet_valid(certificate: x509.Certificate) -> bool:
    """Return ``True`` when the certificate's validity period has not begun."""
    return certificate.not_valid_before_utc > datetime.now(timezone.utc)


def validate_key_matches_cert(
    tls_private_key_pem: Optional[str], certificate: x509.Certificate
) -> bool:
    """Return whether the private key belongs to the certificate.

    The public half of the private key is compared with the certificate's
    public key. This does not depend on the hash algorithm the issuer used
    to sign the certificate.

    Args:
        tls_private_key_pem: PEM-encoded private key, or ``None``.
        certificate: The certificate the key is expected to match.

    Returns:
        ``True`` when both keys are RSA or both are elliptic-curve and their
        public numbers are equal; ``False`` otherwise, including when either
        argument is missing.

    Raises:
        TLSInvalidPrivateKeyError: If the key cannot be loaded. This is a
            ``ValueError`` subclass.
    """
    if not tls_private_key_pem or not certificate:
        return False

    private_key = _load_private_key(tls_private_key_pem)
    public_key = certificate.public_key()

    both_rsa = isinstance(private_key, rsa.RSAPrivateKey) and isinstance(
        public_key, RSAPublicKey
    )
    both_ec = isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
        public_key, EllipticCurvePublicKey
    )
    if not (both_rsa or both_ec):
        return False

    return private_key.public_key().public_numbers() == public_key.public_numbers()


def validate_tls_keys(
    tls_private_key_pem: Optional[str],
    tls_certificate_pem: Optional[str],
    skip_chain_verification: Optional[bool],
    skip_name_verification: Optional[bool],
    hostname: str,
) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]:
    """Validate a certificate bundle and private key for ``hostname``.

    Checks run in this order: parsing, chain building, expiry, start of
    validity, chain verification, SAN check, private-key check. The first
    ``TLSValidationError`` raised after parsing stops the remaining checks
    and is recorded as an error entry.

    Args:
        tls_private_key_pem: PEM-encoded private key, or ``None``.
        tls_certificate_pem: PEM-encoded certificate bundle, or ``None``.
        skip_chain_verification: When truthy, skip chain verification.
        skip_name_verification: When truthy, skip the SAN check.
        hostname: Host name the certificate is expected to cover.

    Returns:
        A ``(chain, san_list, errors)`` tuple: the ordered chain (``None`` if
        it could not be built), the DNS SANs of the end-entity certificate
        (empty when name verification is skipped or not reached), and a list
        of ``{"key": ..., "message": ...}`` error entries.
    """
    errors: List[Dict[str, str]] = []
    san_list: List[str] = []
    chain = None
    skip_chain_verification = skip_chain_verification or False
    skip_name_verification = skip_name_verification or False

    certificates = []
    try:
        if tls_certificate_pem:
            certificates = load_certificates_from_pem(
                tls_certificate_pem.encode("utf-8")
            )
    except TLSValidationError as e:
        errors.append(e.as_dict())

    if not certificates:
        errors.append(
            {
                "key": "no_valid_certificates",
                "message": "No valid certificates supplied.",
            }
        )
        return chain, san_list, errors

    try:
        chain = build_certificate_chain(certificates)
        end_entity_cert = chain[0]

        # validate expiry
        if validate_end_entity_expired(end_entity_cert):
            errors.append(
                {
                    "key": "public_key_expired",
                    "message": "TLS public key is expired.",
                }
            )

        # validate beginning
        if validate_end_entity_not_yet_valid(end_entity_cert):
            errors.append(
                {
                    "key": "public_key_future",
                    "message": "TLS public key is not yet valid.",
                }
            )

        # chain verification
        if not skip_chain_verification:
            validate_certificate_chain(chain)

        # name verification
        if not skip_name_verification:
            san_list = extract_sans(end_entity_cert)
            # NOTE(review): both the bare hostname and its wildcard form must
            # be listed literally - confirm that requiring both is intended.
            for expected_hostname in [hostname, f"*.{hostname}"]:
                if expected_hostname not in san_list:
                    errors.append(
                        {
                            "key": "hostname_not_in_san",
                            "message": f"{expected_hostname} not found in SANs.",
                        }
                    )

        # check if key is valid
        if validate_key(tls_private_key_pem):
            # check if key matches cert
            if not validate_key_matches_cert(tls_private_key_pem, end_entity_cert):
                errors.append(
                    {
                        "key": "key_mismatch",
                        "message": "Private key does not match certificate.",
                    }
                )
        else:
            errors.append(
                {
                    "key": "private_key_not_rsa_or_ec",
                    "message": "Private key must be RSA or Elliptic-Curve.",
                }
            )
    except TLSValidationError as e:
        errors.append(e.as_dict())

    return chain, san_list, errors