feat: break up validate_tls_keys and add unit tests
I've split the existing code in several new functions: - load_certificates_from_pem (takes pem data as bytes) - build_certificate_chain (takes a list of Certificates) - validate_certificate_chain (takes a list of Certificates) - validate_key (takes pem data as a string) - validate_key_matches_cert (now takes a pem key string and a Certificate) - extract_sans (now takes a Certificate) - validate_end_entity_expired (now takes a Certificate) - validate_end_entity_not_yet_valid (now takes a Certificate) When a relevant exception arises, these functions raise a type of TLSValidationError, these are appended to the list of errors when validating a cert.
This commit is contained in:
parent
5275a2a882
commit
d5fa521fa1
10 changed files with 1091 additions and 120 deletions
378
app/util/x509.py
378
app/util/x509.py
|
@ -6,19 +6,77 @@ from cryptography import x509
|
|||
from cryptography.hazmat._oid import ExtensionOID
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa, ec
|
||||
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
EllipticCurvePublicKey,
|
||||
EllipticCurvePrivateKey,
|
||||
)
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
|
||||
class TLSValidationError(ValueError):
|
||||
key: str
|
||||
message: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.message
|
||||
|
||||
def as_dict(self) -> Dict[str, str]:
|
||||
return {"key": self.key, "message": self.message}
|
||||
|
||||
|
||||
class TLSCertificateParsingError(TLSValidationError):
|
||||
key = "could_not_parse_certificate"
|
||||
message = "TLS certificate parsing error"
|
||||
|
||||
|
||||
class TLSInvalidPrivateKeyError(TLSValidationError):
|
||||
key = "invalid_private_key"
|
||||
message = "Private key is invalid"
|
||||
|
||||
|
||||
class TLSNoEndEntityError(TLSValidationError):
|
||||
key = "could_not_identify_end_entity"
|
||||
message = "Cannot identify the end-entity certificate."
|
||||
|
||||
|
||||
class TLSChainNotValidError(TLSValidationError):
|
||||
key = "invalid_tls_chain"
|
||||
message = "Certificates do not form a valid chain."
|
||||
|
||||
|
||||
class TLSUnsupportedAlgorithmError(TLSValidationError):
|
||||
key = "invalid_algorithm"
|
||||
message = "Certificate using unsupported algorithm"
|
||||
|
||||
|
||||
class TLSMissingHashError(TLSValidationError):
|
||||
key = "missing_hash_algorithm"
|
||||
message = "Certificate missing hash algorithm."
|
||||
|
||||
|
||||
class TLSUntrustedRootCAError(TLSValidationError):
|
||||
key = "untrusted_root_ca"
|
||||
message = "Certificate chain does not terminate at a trusted root CA."
|
||||
|
||||
|
||||
def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]:
|
||||
certificates = []
|
||||
for pem_block in pem_data.split(b"-----END CERTIFICATE-----"):
|
||||
pem_block = pem_block.strip()
|
||||
if pem_block:
|
||||
pem_block += b"-----END CERTIFICATE-----"
|
||||
certificate = x509.load_pem_x509_certificate(pem_block, default_backend())
|
||||
certificates.append(certificate)
|
||||
try:
|
||||
for pem_block in pem_data.split(b"-----END CERTIFICATE-----"):
|
||||
pem_block = pem_block.strip()
|
||||
if pem_block:
|
||||
pem_block += b"-----END CERTIFICATE-----"
|
||||
certificate = x509.load_pem_x509_certificate(
|
||||
pem_block, default_backend()
|
||||
)
|
||||
certificates.append(certificate)
|
||||
except ValueError:
|
||||
raise TLSCertificateParsingError()
|
||||
return certificates
|
||||
|
||||
|
||||
|
@ -33,17 +91,22 @@ def build_certificate_chain(
|
|||
(
|
||||
cert
|
||||
for cert in certificates
|
||||
if cert.subject.rfc4514_string() not in cert_map
|
||||
if not any(cert.subject == other_cert.issuer for other_cert in certificates)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not end_entity:
|
||||
raise ValueError("Cannot identify the end-entity certificate.")
|
||||
raise TLSNoEndEntityError()
|
||||
chain.append(end_entity)
|
||||
current_cert = end_entity
|
||||
while current_cert.issuer.rfc4514_string() in cert_map:
|
||||
next_cert = cert_map[current_cert.issuer.rfc4514_string()]
|
||||
# when there is 1 item left in cert_map, that will be the Root CA
|
||||
while len(cert_map) > 1:
|
||||
issuer_key = current_cert.issuer.rfc4514_string()
|
||||
if issuer_key not in cert_map:
|
||||
raise TLSChainNotValidError()
|
||||
next_cert = cert_map[issuer_key]
|
||||
chain.append(next_cert)
|
||||
cert_map.pop(current_cert.subject.rfc4514_string())
|
||||
current_cert = next_cert
|
||||
return chain
|
||||
|
||||
|
@ -56,28 +119,142 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
|
|||
|
||||
for i in range(len(chain) - 1):
|
||||
next_public_key = chain[i + 1].public_key()
|
||||
if not (isinstance(next_public_key, RSAPublicKey)):
|
||||
raise ValueError(
|
||||
f"Certificate using unsupported algorithm: {type(next_public_key)}"
|
||||
)
|
||||
if not isinstance(next_public_key, RSAPublicKey) and not isinstance(
|
||||
next_public_key, EllipticCurvePublicKey
|
||||
):
|
||||
raise TLSUnsupportedAlgorithmError()
|
||||
hash_algorithm = chain[i].signature_hash_algorithm
|
||||
if hash_algorithm is None:
|
||||
raise ValueError("Certificate missing hash algorithm")
|
||||
next_public_key.verify(
|
||||
chain[i].signature,
|
||||
chain[i].tbs_certificate_bytes,
|
||||
PKCS1v15(),
|
||||
hash_algorithm,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
if hash_algorithm is None:
|
||||
raise TLSMissingHashError()
|
||||
if isinstance(next_public_key, RSAPublicKey):
|
||||
next_public_key.verify(
|
||||
chain[i].signature,
|
||||
chain[i].tbs_certificate_bytes,
|
||||
PKCS1v15(),
|
||||
hash_algorithm,
|
||||
)
|
||||
elif isinstance(next_public_key, EllipticCurvePublicKey):
|
||||
digest = hashes.Hash(hash_algorithm)
|
||||
digest.update(chain[i].tbs_certificate_bytes)
|
||||
digest_value = digest.finalize()
|
||||
next_public_key.verify(
|
||||
chain[i].signature,
|
||||
digest_value,
|
||||
ec.ECDSA(Prehashed(hash_algorithm)),
|
||||
)
|
||||
|
||||
end_cert = chain[-1]
|
||||
if not any(
|
||||
end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates
|
||||
):
|
||||
raise ValueError("Certificate chain does not terminate at a trusted root CA.")
|
||||
raise TLSUntrustedRootCAError()
|
||||
return True
|
||||
|
||||
|
||||
def validate_key(tls_private_key_pem: Optional[str]) -> bool:
|
||||
if tls_private_key_pem:
|
||||
try:
|
||||
private_key = serialization.load_pem_private_key(
|
||||
tls_private_key_pem.encode("utf-8"),
|
||||
password=None,
|
||||
backend=default_backend(),
|
||||
)
|
||||
return isinstance(private_key, rsa.RSAPrivateKey) or isinstance(
|
||||
private_key, ec.EllipticCurvePrivateKey
|
||||
)
|
||||
except ValueError:
|
||||
raise TLSInvalidPrivateKeyError()
|
||||
return False
|
||||
|
||||
|
||||
def extract_sans(cert: x509.Certificate) -> List[str]:
|
||||
try:
|
||||
san_extension = cert.extensions.get_extension_for_oid(
|
||||
ExtensionOID.SUBJECT_ALTERNATIVE_NAME
|
||||
)
|
||||
sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined]
|
||||
return sans
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def validate_end_entity_expired(certificate: x509.Certificate) -> bool:
|
||||
if certificate.not_valid_after_utc < datetime.now(timezone.utc):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def validate_end_entity_not_yet_valid(certificate: x509.Certificate) -> bool:
|
||||
if certificate.not_valid_before_utc > datetime.now(timezone.utc):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def validate_key_matches_cert(
|
||||
tls_private_key_pem: Optional[str], certificate: x509.Certificate
|
||||
) -> bool:
|
||||
if not tls_private_key_pem or not certificate:
|
||||
return False
|
||||
private_key = serialization.load_pem_private_key(
|
||||
tls_private_key_pem.encode("utf-8"),
|
||||
password=None,
|
||||
backend=default_backend(),
|
||||
)
|
||||
public_key = certificate.public_key()
|
||||
signature_hash_algorithm = certificate.signature_hash_algorithm
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(public_key, rsa.RSAPublicKey) or isinstance(
|
||||
public_key, EllipticCurvePublicKey
|
||||
) # nosec: B101
|
||||
assert isinstance(private_key, rsa.RSAPrivateKey) or isinstance(
|
||||
public_key, EllipticCurvePrivateKey
|
||||
) # nosec: B101
|
||||
assert signature_hash_algorithm is not None # nosec: B101
|
||||
if not (
|
||||
(
|
||||
isinstance(private_key, rsa.RSAPrivateKey)
|
||||
and isinstance(public_key, rsa.RSAPublicKey)
|
||||
)
|
||||
or (
|
||||
isinstance(private_key, ec.EllipticCurvePrivateKey)
|
||||
and isinstance(public_key, ec.EllipticCurvePublicKey)
|
||||
)
|
||||
):
|
||||
return False
|
||||
try:
|
||||
test_message = b"test"
|
||||
if isinstance(public_key, RSAPublicKey) and isinstance(
|
||||
private_key, rsa.RSAPrivateKey
|
||||
):
|
||||
signature = private_key.sign(
|
||||
test_message,
|
||||
padding.PKCS1v15(),
|
||||
signature_hash_algorithm,
|
||||
)
|
||||
public_key.verify(
|
||||
signature,
|
||||
test_message,
|
||||
padding.PKCS1v15(),
|
||||
signature_hash_algorithm,
|
||||
)
|
||||
if isinstance(public_key, EllipticCurvePublicKey) and isinstance(
|
||||
private_key, ec.EllipticCurvePrivateKey
|
||||
):
|
||||
signature = private_key.sign(
|
||||
test_message,
|
||||
ec.ECDSA(signature_hash_algorithm),
|
||||
)
|
||||
public_key.verify(
|
||||
signature,
|
||||
test_message,
|
||||
ec.ECDSA(signature_hash_algorithm),
|
||||
)
|
||||
return True
|
||||
except InvalidSignature:
|
||||
return False
|
||||
|
||||
|
||||
def validate_tls_keys(
|
||||
tls_private_key_pem: Optional[str],
|
||||
tls_certificate_pem: Optional[str],
|
||||
|
@ -91,114 +268,75 @@ def validate_tls_keys(
|
|||
|
||||
skip_chain_verification = skip_chain_verification or False
|
||||
skip_name_verification = skip_name_verification or False
|
||||
|
||||
certificates = []
|
||||
try:
|
||||
private_key = None
|
||||
if tls_private_key_pem:
|
||||
private_key = serialization.load_pem_private_key(
|
||||
tls_private_key_pem.encode("utf-8"),
|
||||
password=None,
|
||||
backend=default_backend(),
|
||||
)
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_private_key_invalid",
|
||||
"Message": "Private key must be RSA.",
|
||||
}
|
||||
)
|
||||
|
||||
if tls_certificate_pem:
|
||||
certificates = list(
|
||||
load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))
|
||||
certificates = load_certificates_from_pem(
|
||||
tls_certificate_pem.encode("utf-8")
|
||||
)
|
||||
if not certificates:
|
||||
except TLSValidationError as e:
|
||||
errors.append(e.as_dict())
|
||||
if len(certificates) > 0:
|
||||
try:
|
||||
chain = build_certificate_chain(certificates)
|
||||
end_entity_cert = chain[0]
|
||||
# validate expiry
|
||||
if validate_end_entity_expired(end_entity_cert):
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_certificate_invalid",
|
||||
"Message": "No valid certificate found.",
|
||||
"key": "public_key_expired",
|
||||
"message": "TLS public key is expired.",
|
||||
}
|
||||
)
|
||||
else:
|
||||
chain = build_certificate_chain(certificates)
|
||||
end_entity_cert = chain[0]
|
||||
|
||||
if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc):
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_public_key_expired",
|
||||
"Message": "TLS public key is expired.",
|
||||
}
|
||||
)
|
||||
|
||||
if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc):
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_public_key_future",
|
||||
"Message": "TLS public key is not yet valid.",
|
||||
}
|
||||
)
|
||||
|
||||
if private_key:
|
||||
public_key = end_entity_cert.public_key()
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101
|
||||
assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101
|
||||
assert (
|
||||
end_entity_cert.signature_hash_algorithm is not None
|
||||
) # nosec: B101
|
||||
try:
|
||||
test_message = b"test"
|
||||
signature = private_key.sign(
|
||||
test_message,
|
||||
padding.PKCS1v15(),
|
||||
end_entity_cert.signature_hash_algorithm,
|
||||
)
|
||||
public_key.verify(
|
||||
signature,
|
||||
test_message,
|
||||
padding.PKCS1v15(),
|
||||
end_entity_cert.signature_hash_algorithm,
|
||||
)
|
||||
except Exception:
|
||||
# validate beginning
|
||||
if validate_end_entity_not_yet_valid(end_entity_cert):
|
||||
errors.append(
|
||||
{
|
||||
"key": "public_key_future",
|
||||
"message": "TLS public key is not yet valid.",
|
||||
}
|
||||
)
|
||||
# chain verification
|
||||
if not skip_chain_verification:
|
||||
validate_certificate_chain(chain)
|
||||
# name verification
|
||||
if not skip_name_verification:
|
||||
san_list = extract_sans(end_entity_cert)
|
||||
for expected_hostname in [hostname, f"*.{hostname}"]:
|
||||
if expected_hostname not in san_list:
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_key_mismatch",
|
||||
"Message": "Private key does not match certificate.",
|
||||
"key": "hostname_not_in_san",
|
||||
"message": f"{expected_hostname} not found in SANs.",
|
||||
}
|
||||
)
|
||||
# check if key is valid
|
||||
if validate_key(tls_private_key_pem):
|
||||
# check if key matches cert
|
||||
if not validate_key_matches_cert(tls_private_key_pem, end_entity_cert):
|
||||
errors.append(
|
||||
{
|
||||
"key": "key_mismatch",
|
||||
"message": "Private key does not match certificate.",
|
||||
}
|
||||
)
|
||||
else:
|
||||
errors.append(
|
||||
{
|
||||
"key": "private_key_not_rsa_or_ec",
|
||||
"message": "Private key must be RSA or Elliptic-Curve.",
|
||||
}
|
||||
)
|
||||
|
||||
if not skip_chain_verification:
|
||||
try:
|
||||
validate_certificate_chain(chain)
|
||||
except ValueError as e:
|
||||
errors.append(
|
||||
{"Error": "certificate_chain_invalid", "Message": str(e)}
|
||||
)
|
||||
except TLSValidationError as e:
|
||||
errors.append(e.as_dict())
|
||||
|
||||
if not skip_name_verification:
|
||||
san_list = extract_sans(end_entity_cert)
|
||||
for expected_hostname in [hostname, f"*.{hostname}"]:
|
||||
if expected_hostname not in san_list:
|
||||
errors.append(
|
||||
{
|
||||
"Error": "hostname_not_in_san",
|
||||
"Message": f"{expected_hostname} not found in SANs.",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
errors.append({"Error": "tls_validation_error", "Message": str(e)})
|
||||
else:
|
||||
errors.append(
|
||||
{
|
||||
"key": "no_valid_certificates",
|
||||
"message": "No valid certificates supplied.",
|
||||
}
|
||||
)
|
||||
|
||||
return chain, san_list, errors
|
||||
|
||||
|
||||
def extract_sans(cert: x509.Certificate) -> List[str]:
|
||||
try:
|
||||
san_extension = cert.extensions.get_extension_for_oid(
|
||||
ExtensionOID.SUBJECT_ALTERNATIVE_NAME
|
||||
)
|
||||
sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined]
|
||||
return sans
|
||||
except Exception:
|
||||
return []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue