feat: expand onion service api

This commit is contained in:
Iain Learmonth 2024-12-06 13:34:44 +00:00
parent c1b385ed99
commit e5976c4739
11 changed files with 646 additions and 348 deletions

View file

@ -1,7 +1,12 @@
import ssl
from datetime import datetime, timezone
from typing import Optional, Tuple, List, Dict, TYPE_CHECKING
from cryptography import x509
from cryptography.hazmat._oid import ExtensionOID
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
@ -63,3 +68,93 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
):
raise ValueError("Certificate chain does not terminate at a trusted root CA.")
return True
def validate_tls_keys(
tls_private_key_pem: Optional[str],
tls_certificate_pem: Optional[str],
skip_chain_verification: Optional[bool],
skip_name_verification: Optional[bool],
hostname: str
) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]:
errors = []
san_list = []
chain = None
skip_chain_verification = skip_chain_verification or False
skip_name_verification = skip_name_verification or False
try:
private_key = None
if tls_private_key_pem:
private_key = serialization.load_pem_private_key(
tls_private_key_pem.encode("utf-8"),
password=None,
backend=default_backend()
)
if not isinstance(private_key, rsa.RSAPrivateKey):
errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be RSA."})
if tls_certificate_pem:
certificates = list(load_certificates_from_pem(tls_certificate_pem.encode("utf-8")))
if not certificates:
errors.append({"Error": "tls_certificate_invalid", "Message": "No valid certificate found."})
else:
chain = build_certificate_chain(certificates)
end_entity_cert = chain[0]
if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc):
errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."})
if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc):
errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."})
if private_key:
public_key = end_entity_cert.public_key()
if TYPE_CHECKING:
assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101
assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101
assert end_entity_cert.signature_hash_algorithm is not None # nosec: B101
try:
test_message = b"test"
signature = private_key.sign(
test_message,
padding.PKCS1v15(),
end_entity_cert.signature_hash_algorithm,
)
public_key.verify(
signature,
test_message,
padding.PKCS1v15(),
end_entity_cert.signature_hash_algorithm,
)
except Exception:
errors.append(
{"Error": "tls_key_mismatch", "Message": "Private key does not match certificate."})
if not skip_chain_verification:
try:
validate_certificate_chain(chain)
except ValueError as e:
errors.append({"Error": "certificate_chain_invalid", "Message": str(e)})
if not skip_name_verification:
san_list = extract_sans(end_entity_cert)
for expected_hostname in [hostname, f"*.{hostname}"]:
if expected_hostname not in san_list:
errors.append(
{"Error": "hostname_not_in_san", "Message": f"{expected_hostname} not found in SANs."})
except Exception as e:
errors.append({"Error": "tls_validation_error", "Message": str(e)})
return chain, san_list, errors
def extract_sans(cert: x509.Certificate) -> List[str]:
try:
san_extension = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined]
return sans
except Exception:
return []