lint: reformat python code with black
This commit is contained in:
parent
331beb01b4
commit
a406a7974b
88 changed files with 2579 additions and 1608 deletions
|
@ -20,8 +20,9 @@ def onion_hostname(onion_public_key: bytes) -> str:
|
|||
return onion.lower()
|
||||
|
||||
|
||||
def decode_onion_keys(onion_private_key_base64: str, onion_public_key_base64: str) -> Tuple[
|
||||
Optional[bytes], Optional[bytes], List[Dict[str, str]]]:
|
||||
def decode_onion_keys(
|
||||
onion_private_key_base64: str, onion_public_key_base64: str
|
||||
) -> Tuple[Optional[bytes], Optional[bytes], List[Dict[str, str]]]:
|
||||
try:
|
||||
onion_private_key = base64.b64decode(onion_private_key_base64)
|
||||
onion_public_key = base64.b64decode(onion_public_key_base64)
|
||||
|
|
|
@ -22,14 +22,20 @@ def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]:
|
|||
return certificates
|
||||
|
||||
|
||||
def build_certificate_chain(certificates: list[x509.Certificate]) -> list[x509.Certificate]:
|
||||
def build_certificate_chain(
|
||||
certificates: list[x509.Certificate],
|
||||
) -> list[x509.Certificate]:
|
||||
if len(certificates) == 1:
|
||||
return certificates
|
||||
chain = []
|
||||
cert_map = {cert.subject.rfc4514_string(): cert for cert in certificates}
|
||||
end_entity = next(
|
||||
(cert for cert in certificates if cert.subject.rfc4514_string() not in cert_map),
|
||||
None
|
||||
(
|
||||
cert
|
||||
for cert in certificates
|
||||
if cert.subject.rfc4514_string() not in cert_map
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not end_entity:
|
||||
raise ValueError("Cannot identify the end-entity certificate.")
|
||||
|
@ -51,7 +57,9 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
|
|||
for i in range(len(chain) - 1):
|
||||
next_public_key = chain[i + 1].public_key()
|
||||
if not (isinstance(next_public_key, RSAPublicKey)):
|
||||
raise ValueError(f"Certificate using unsupported algorithm: {type(next_public_key)}")
|
||||
raise ValueError(
|
||||
f"Certificate using unsupported algorithm: {type(next_public_key)}"
|
||||
)
|
||||
hash_algorithm = chain[i].signature_hash_algorithm
|
||||
if hash_algorithm is None:
|
||||
raise ValueError("Certificate missing hash algorithm")
|
||||
|
@ -59,23 +67,23 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
|
|||
chain[i].signature,
|
||||
chain[i].tbs_certificate_bytes,
|
||||
PKCS1v15(),
|
||||
hash_algorithm
|
||||
hash_algorithm,
|
||||
)
|
||||
|
||||
end_cert = chain[-1]
|
||||
if not any(
|
||||
end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates
|
||||
end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates
|
||||
):
|
||||
raise ValueError("Certificate chain does not terminate at a trusted root CA.")
|
||||
return True
|
||||
|
||||
|
||||
def validate_tls_keys(
|
||||
tls_private_key_pem: Optional[str],
|
||||
tls_certificate_pem: Optional[str],
|
||||
skip_chain_verification: Optional[bool],
|
||||
skip_name_verification: Optional[bool],
|
||||
hostname: str
|
||||
tls_private_key_pem: Optional[str],
|
||||
tls_certificate_pem: Optional[str],
|
||||
skip_chain_verification: Optional[bool],
|
||||
skip_name_verification: Optional[bool],
|
||||
hostname: str,
|
||||
) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]:
|
||||
errors = []
|
||||
san_list = []
|
||||
|
@ -90,31 +98,55 @@ def validate_tls_keys(
|
|||
private_key = serialization.load_pem_private_key(
|
||||
tls_private_key_pem.encode("utf-8"),
|
||||
password=None,
|
||||
backend=default_backend()
|
||||
backend=default_backend(),
|
||||
)
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be RSA."})
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_private_key_invalid",
|
||||
"Message": "Private key must be RSA.",
|
||||
}
|
||||
)
|
||||
|
||||
if tls_certificate_pem:
|
||||
certificates = list(load_certificates_from_pem(tls_certificate_pem.encode("utf-8")))
|
||||
certificates = list(
|
||||
load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))
|
||||
)
|
||||
if not certificates:
|
||||
errors.append({"Error": "tls_certificate_invalid", "Message": "No valid certificate found."})
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_certificate_invalid",
|
||||
"Message": "No valid certificate found.",
|
||||
}
|
||||
)
|
||||
else:
|
||||
chain = build_certificate_chain(certificates)
|
||||
end_entity_cert = chain[0]
|
||||
|
||||
if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc):
|
||||
errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."})
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_public_key_expired",
|
||||
"Message": "TLS public key is expired.",
|
||||
}
|
||||
)
|
||||
|
||||
if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc):
|
||||
errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."})
|
||||
errors.append(
|
||||
{
|
||||
"Error": "tls_public_key_future",
|
||||
"Message": "TLS public key is not yet valid.",
|
||||
}
|
||||
)
|
||||
|
||||
if private_key:
|
||||
public_key = end_entity_cert.public_key()
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101
|
||||
assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101
|
||||
assert end_entity_cert.signature_hash_algorithm is not None # nosec: B101
|
||||
assert (
|
||||
end_entity_cert.signature_hash_algorithm is not None
|
||||
) # nosec: B101
|
||||
try:
|
||||
test_message = b"test"
|
||||
signature = private_key.sign(
|
||||
|
@ -130,20 +162,30 @@ def validate_tls_keys(
|
|||
)
|
||||
except Exception:
|
||||
errors.append(
|
||||
{"Error": "tls_key_mismatch", "Message": "Private key does not match certificate."})
|
||||
{
|
||||
"Error": "tls_key_mismatch",
|
||||
"Message": "Private key does not match certificate.",
|
||||
}
|
||||
)
|
||||
|
||||
if not skip_chain_verification:
|
||||
try:
|
||||
validate_certificate_chain(chain)
|
||||
except ValueError as e:
|
||||
errors.append({"Error": "certificate_chain_invalid", "Message": str(e)})
|
||||
errors.append(
|
||||
{"Error": "certificate_chain_invalid", "Message": str(e)}
|
||||
)
|
||||
|
||||
if not skip_name_verification:
|
||||
san_list = extract_sans(end_entity_cert)
|
||||
for expected_hostname in [hostname, f"*.{hostname}"]:
|
||||
if expected_hostname not in san_list:
|
||||
errors.append(
|
||||
{"Error": "hostname_not_in_san", "Message": f"{expected_hostname} not found in SANs."})
|
||||
{
|
||||
"Error": "hostname_not_in_san",
|
||||
"Message": f"{expected_hostname} not found in SANs.",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
errors.append({"Error": "tls_validation_error", "Message": str(e)})
|
||||
|
@ -153,7 +195,9 @@ def validate_tls_keys(
|
|||
|
||||
def extract_sans(cert: x509.Certificate) -> List[str]:
|
||||
try:
|
||||
san_extension = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
|
||||
san_extension = cert.extensions.get_extension_for_oid(
|
||||
ExtensionOID.SUBJECT_ALTERNATIVE_NAME
|
||||
)
|
||||
sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined]
|
||||
return sans
|
||||
except Exception:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue