majuna/app/util/x509.py

343 lines
12 KiB
Python
Raw Normal View History

import ssl
2024-12-06 13:34:44 +00:00
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from cryptography import x509
2024-12-06 13:34:44 +00:00
from cryptography.hazmat._oid import ExtensionOID
from cryptography.hazmat.backends import default_backend
2024-12-06 13:34:44 +00:00
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa, ec
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePublicKey,
EllipticCurvePrivateKey,
)
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
from cryptography.hazmat.primitives import hashes
class TLSValidationError(ValueError):
    """Base class for every TLS validation failure raised by this module.

    Concrete subclasses define two class attributes:
    ``key`` (a stable, machine-readable identifier) and ``message``
    (human-readable text). Because it derives from ``ValueError``, callers
    catching ``ValueError`` also catch these errors.
    """

    # Machine-readable identifier, provided by each subclass.
    key: str
    # Human-readable description, provided by each subclass.
    message: str

    def __str__(self) -> str:
        return self.message

    def as_dict(self) -> Dict[str, str]:
        """Return the error as ``{"key": ..., "message": ...}``."""
        return dict(key=self.key, message=self.message)


class TLSCertificateParsingError(TLSValidationError):
    """The supplied PEM data could not be parsed as certificates."""

    message = "TLS certificate parsing error"
    key = "could_not_parse_certificate"


class TLSInvalidPrivateKeyError(TLSValidationError):
    """The supplied private key could not be loaded."""

    message = "Private key is invalid"
    key = "invalid_private_key"


class TLSNoEndEntityError(TLSValidationError):
    """No certificate in the bundle qualifies as the end-entity (leaf)."""

    message = "Cannot identify the end-entity certificate."
    key = "could_not_identify_end_entity"


class TLSChainNotValidError(TLSValidationError):
    """The certificates cannot be linked, or a link fails verification."""

    message = "Certificates do not form a valid chain."
    key = "invalid_tls_chain"


class TLSUnsupportedAlgorithmError(TLSValidationError):
    """A certificate in the chain uses a key type other than RSA or EC."""

    message = "Certificate using unsupported algorithm"
    key = "invalid_algorithm"


class TLSMissingHashError(TLSValidationError):
    """A certificate signature does not name a hash algorithm."""

    message = "Certificate missing hash algorithm."
    key = "missing_hash_algorithm"


class TLSUntrustedRootCAError(TLSValidationError):
    """The chain is not anchored in the system trust store."""

    message = "Certificate chain does not terminate at a trusted root CA."
    key = "untrusted_root_ca"
def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]:
    """Parse every PEM-encoded certificate contained in ``pem_data``.

    The data is split on ``-----END CERTIFICATE-----`` markers. Each
    non-blank piece, with its end marker restored, is parsed as one
    certificate. Input order is preserved.

    Args:
        pem_data: One or more concatenated PEM certificates.

    Returns:
        The parsed certificates, or an empty list if ``pem_data`` holds
        only whitespace.

    Raises:
        TLSCertificateParsingError: If any piece fails to parse. The
            underlying parser error is attached as ``__cause__``.
    """
    end_marker = b"-----END CERTIFICATE-----"
    certificates = []
    try:
        for pem_block in pem_data.split(end_marker):
            pem_block = pem_block.strip()
            if not pem_block:
                # Whitespace between or after certificates.
                continue
            certificates.append(
                x509.load_pem_x509_certificate(
                    pem_block + end_marker, default_backend()
                )
            )
    except ValueError as exc:
        # Chain the parser error so the root cause survives in logs/tracebacks.
        raise TLSCertificateParsingError() from exc
    return certificates
2024-12-06 18:15:47 +00:00
def build_certificate_chain(
    certificates: list[x509.Certificate],
) -> list[x509.Certificate]:
    """Order an unordered certificate bundle leaf-first.

    The leaf (end-entity) is the first certificate whose subject is not the
    issuer of any certificate in the bundle. From the leaf, the chain is
    extended by looking up each certificate's issuer by subject (compared as
    RFC 4514 strings). It stops when only one not-yet-consumed subject
    remains.

    Args:
        certificates: The parsed certificates, in any order.

    Returns:
        The chain, leaf first. A single-element input is returned unchanged
        (the same list object).

    Raises:
        TLSNoEndEntityError: No certificate qualifies as the leaf.
        TLSChainNotValidError: An issuer needed to continue the chain is
            not in the bundle.
    """
    # A lone certificate is trivially its own chain.
    if len(certificates) == 1:
        return certificates

    # Later certificates with a duplicate subject overwrite earlier ones.
    remaining = {}
    for candidate in certificates:
        remaining[candidate.subject.rfc4514_string()] = candidate

    issuer_names = [other.issuer for other in certificates]
    leaf = None
    for candidate in certificates:
        # A self-signed cert names itself as issuer, so it is never the leaf.
        if not any(candidate.subject == name for name in issuer_names):
            leaf = candidate
            break
    if leaf is None:
        raise TLSNoEndEntityError()

    chain = [leaf]
    cursor = leaf
    # Once a single entry is left, it is treated as the top of the chain.
    while len(remaining) > 1:
        parent = remaining.get(cursor.issuer.rfc4514_string())
        if parent is None:
            raise TLSChainNotValidError()
        chain.append(parent)
        del remaining[cursor.subject.rfc4514_string()]
        cursor = parent
    return chain
def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
    """Validate a certificate chain against the system's root CA store.

    ``chain`` is expected leaf-first, with each certificate followed by its
    issuer (the order ``build_certificate_chain`` produces). Two things are
    checked:

    1. Every link's signature verifies under the next certificate's public
       key, using the padding and hash recorded in the certificate. This
       covers RSA PKCS#1 v1.5, RSA-PSS and ECDSA.
    2. The last certificate is signed by a system-trusted certificate whose
       subject equals its issuer. The signature is verified, not only the
       name, so a self-made certificate that merely copies a trusted root's
       name is rejected.

    Args:
        chain: Leaf-first certificate chain (must be non-empty).

    Returns:
        True when all checks pass.

    Raises:
        TLSUnsupportedAlgorithmError: An issuer's key is neither RSA nor EC.
        TLSMissingHashError: A certificate's signature names no hash.
        TLSChainNotValidError: A link's signature does not verify.
        TLSUntrustedRootCAError: No trusted certificate signed the last
            certificate in the chain.
        IndexError: ``chain`` is empty.
    """
    for cert, issuer in zip(chain, chain[1:]):
        issuer_public_key = issuer.public_key()
        if not isinstance(issuer_public_key, (RSAPublicKey, EllipticCurvePublicKey)):
            raise TLSUnsupportedAlgorithmError()
        # Checked at runtime (not only for the type checker).
        if cert.signature_hash_algorithm is None:
            raise TLSMissingHashError()
        try:
            cert.verify_directly_issued_by(issuer)
        except (InvalidSignature, ValueError, TypeError) as exc:
            # InvalidSignature is not a TLSValidationError; translate it so
            # callers that collect TLSValidationError see a chain error.
            raise TLSChainNotValidError() from exc

    end_cert = chain[-1]
    context = ssl.create_default_context()
    # NOTE: per the ssl docs, get_ca_certs() omits capath-directory
    # certificates that have not been used yet; trust stores configured only
    # via capath may therefore appear incomplete here.
    trusted_certificates = [
        x509.load_der_x509_certificate(der)
        for der in context.get_ca_certs(binary_form=True)
    ]
    for trusted_cert in trusted_certificates:
        if trusted_cert.subject != end_cert.issuer:
            continue
        try:
            # Proves the trusted certificate's key actually signed end_cert.
            end_cert.verify_directly_issued_by(trusted_cert)
        except (InvalidSignature, ValueError, TypeError):
            # Same name but different key/algorithm: try other candidates.
            continue
        return True
    raise TLSUntrustedRootCAError()
2024-12-06 13:34:44 +00:00
def validate_key(tls_private_key_pem: Optional[str]) -> bool:
    """Check that a PEM private key loads and is an RSA or EC key.

    Args:
        tls_private_key_pem: PEM text of an unencrypted private key, or
            None/empty if no key was supplied.

    Returns:
        True if the key loads and is an RSA or elliptic-curve private key.
        False if no key was supplied or it is some other key type.

    Raises:
        TLSInvalidPrivateKeyError: If the key cannot be loaded. This covers
            malformed PEM, an encrypted key (no password is ever supplied),
            or a key type the backend does not support. The original
            exception is attached as ``__cause__``.
    """
    from cryptography.exceptions import UnsupportedAlgorithm

    if not tls_private_key_pem:
        return False
    try:
        private_key = serialization.load_pem_private_key(
            tls_private_key_pem.encode("utf-8"),
            password=None,
            backend=default_backend(),
        )
    except (ValueError, TypeError, UnsupportedAlgorithm) as exc:
        # TypeError is raised for encrypted keys loaded without a password;
        # without catching it here it would escape validate_tls_keys.
        raise TLSInvalidPrivateKeyError() from exc
    return isinstance(private_key, (rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey))
def extract_sans(cert: x509.Certificate) -> List[str]:
    """Return the DNS names from the certificate's Subject Alternative Name.

    This is best-effort: a missing SAN extension, or any error while reading
    it, yields an empty list instead of an exception.
    """
    try:
        extension = cert.extensions.get_extension_for_oid(
            ExtensionOID.SUBJECT_ALTERNATIVE_NAME
        )
        dns_names: List[str] = extension.value.get_values_for_type(x509.DNSName)  # type: ignore[attr-defined]
    except Exception:
        return []
    return dns_names
def validate_end_entity_expired(certificate: x509.Certificate) -> bool:
    """Return True if the certificate's notAfter time is already in the past."""
    expires_at = certificate.not_valid_after_utc
    return expires_at < datetime.now(timezone.utc)
def validate_end_entity_not_yet_valid(certificate: x509.Certificate) -> bool:
    """Return True if the certificate's notBefore time is still in the future."""
    valid_from = certificate.not_valid_before_utc
    return valid_from > datetime.now(timezone.utc)
def validate_key_matches_cert(
    tls_private_key_pem: Optional[str], certificate: x509.Certificate
) -> bool:
    """Check whether a PEM private key is the counterpart of a certificate's key.

    The pair matches when the public half derived from the private key
    serializes (DER SubjectPublicKeyInfo) identically to the certificate's
    public key. Comparing keys directly avoids a test sign/verify round
    trip. That approach borrowed the certificate's *signature* hash
    algorithm, which belongs to the issuer's signature, not the leaf key,
    and can be None (e.g. for Ed25519-signed certificates).

    Args:
        tls_private_key_pem: PEM text of an unencrypted private key.
        certificate: The end-entity certificate.

    Returns:
        True if both keys are RSA, or both are elliptic-curve, and they
        match. False if either input is missing, the key types differ or
        are not RSA/EC, or the keys do not match.

    Raises:
        TLSInvalidPrivateKeyError: If the private key cannot be loaded
            (malformed, or encrypted while no password is supplied).
    """
    if not tls_private_key_pem or not certificate:
        return False
    try:
        private_key = serialization.load_pem_private_key(
            tls_private_key_pem.encode("utf-8"),
            password=None,
            backend=default_backend(),
        )
    except (ValueError, TypeError) as exc:
        raise TLSInvalidPrivateKeyError() from exc
    public_key = certificate.public_key()

    # Only RSA/RSA and EC/EC pairs are accepted, as before.
    rsa_pair = isinstance(private_key, rsa.RSAPrivateKey) and isinstance(
        public_key, rsa.RSAPublicKey
    )
    ec_pair = isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
        public_key, ec.EllipticCurvePublicKey
    )
    if not (rsa_pair or ec_pair):
        return False

    encoding = serialization.Encoding.DER
    public_format = serialization.PublicFormat.SubjectPublicKeyInfo
    derived_spki = private_key.public_key().public_bytes(encoding, public_format)
    cert_spki = public_key.public_bytes(encoding, public_format)
    return derived_spki == cert_spki
2024-12-06 13:34:44 +00:00
def validate_tls_keys(
    tls_private_key_pem: Optional[str],
    tls_certificate_pem: Optional[str],
    skip_chain_verification: Optional[bool],
    skip_name_verification: Optional[bool],
    hostname: str,
) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]:
    """Validate a TLS certificate bundle and private key for ``hostname``.

    Problems are collected as ``{"key": ..., "message": ...}`` dicts rather
    than raised. Each independent check (chain trust, names, key) runs even
    when another one fails, so the caller can report all problems at once.

    Checks performed:

    * The PEM bundle parses and the certificates link into a chain.
    * The end-entity certificate is within its validity period.
    * Unless ``skip_chain_verification``, the chain verifies up to a trusted
      root CA.
    * Unless ``skip_name_verification``, both ``hostname`` and
      ``*.hostname`` appear among the SAN DNS names.
    * The private key is RSA/EC and matches the end-entity certificate.

    Returns:
        ``(chain, san_list, errors)``. ``chain`` is the leaf-first chain, or
        None if no chain could be built. ``san_list`` holds the SAN DNS
        names, and is empty when name verification is skipped. ``errors`` is
        the list of error dicts.
    """
    errors: List[Dict[str, str]] = []
    san_list: List[str] = []
    chain: Optional[List[x509.Certificate]] = None
    skip_chain_verification = skip_chain_verification or False
    skip_name_verification = skip_name_verification or False
    certificates: List[x509.Certificate] = []

    try:
        if tls_certificate_pem:
            certificates = load_certificates_from_pem(
                tls_certificate_pem.encode("utf-8")
            )
    except TLSValidationError as e:
        errors.append(e.as_dict())

    if not certificates:
        errors.append(
            {
                "key": "no_valid_certificates",
                "message": "No valid certificates supplied.",
            }
        )
        return chain, san_list, errors

    try:
        chain = build_certificate_chain(certificates)
    except TLSValidationError as e:
        # Without a chain there is no end-entity certificate to check.
        errors.append(e.as_dict())
        return chain, san_list, errors
    end_entity_cert = chain[0]

    # validate expiry
    if validate_end_entity_expired(end_entity_cert):
        errors.append(
            {
                "key": "public_key_expired",
                "message": "TLS public key is expired.",
            }
        )
    # validate beginning
    if validate_end_entity_not_yet_valid(end_entity_cert):
        errors.append(
            {
                "key": "public_key_future",
                "message": "TLS public key is not yet valid.",
            }
        )

    # chain verification
    if not skip_chain_verification:
        try:
            validate_certificate_chain(chain)
        except TLSValidationError as e:
            errors.append(e.as_dict())
        except InvalidSignature:
            # cryptography reports a bad signature with InvalidSignature,
            # which is not a TLSValidationError; report it as a broken chain.
            errors.append(TLSChainNotValidError().as_dict())

    # name verification: both the bare hostname and its wildcard are required
    if not skip_name_verification:
        san_list = extract_sans(end_entity_cert)
        for expected_hostname in (hostname, f"*.{hostname}"):
            if expected_hostname not in san_list:
                errors.append(
                    {
                        "key": "hostname_not_in_san",
                        "message": f"{expected_hostname} not found in SANs.",
                    }
                )

    # key validity and key/certificate match
    try:
        if validate_key(tls_private_key_pem):
            if not validate_key_matches_cert(tls_private_key_pem, end_entity_cert):
                errors.append(
                    {
                        "key": "key_mismatch",
                        "message": "Private key does not match certificate.",
                    }
                )
        else:
            errors.append(
                {
                    "key": "private_key_not_rsa_or_ec",
                    "message": "Private key must be RSA or Elliptic-Curve.",
                }
            )
    except TLSValidationError as e:
        errors.append(e.as_dict())

    return chain, san_list, errors