majuna/app/api/web.py

275 lines
11 KiB
Python

import base64
import logging
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional, Tuple, TypedDict

from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey, RSAPrivateKey
from cryptography.x509.oid import ExtensionOID
from flask import Blueprint, request, jsonify, abort
from flask.typing import ResponseReturnValue
from sqlalchemy import exc
from tldextract import tldextract

from app.api.util import ListFilter, MAX_DOMAIN_NAME_LENGTH, DOMAIN_NAME_REGEX, list_resources, MAX_ALLOWED_ITEMS
from app.util.x509 import build_certificate_chain, validate_certificate_chain, load_certificates_from_pem
from app.extensions import db
from app.models.base import Group
from app.models.mirrors import Origin, Proxy
from app.models.onions import Onion
from app.util.onion import onion_hostname
# Blueprint holding the web-facing group/origin/mirror/onion endpoints below.
api_web = Blueprint(name='web', import_name=__name__)
@api_web.route('/group', methods=['GET'])
def list_groups() -> ResponseReturnValue:
    """List groups through ``list_resources`` as an ``OriginGroupList``.

    Each group is serialised with its own ``to_dict()``; the response carries
    an ``amber`` protective marking and is capped at ``MAX_ALLOWED_ITEMS``.
    """
    return list_resources(
        Group,
        lambda g: g.to_dict(),
        protective_marking='amber',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        resource_name='OriginGroupList',
    )
@api_web.route('/origin', methods=['GET'])
def list_origins() -> ResponseReturnValue:
    """List origins, optionally narrowed by query parameters.

    Query parameters:
        DomainName: case-insensitive substring of the origin's domain name.
            Aborts with 400 if longer than ``MAX_DOMAIN_NAME_LENGTH`` or if it
            does not match ``DOMAIN_NAME_REGEX``.
        GroupId: exact group id. Aborts with 400 if it is not an integer.
    """
    filters: List[ListFilter] = []

    domain_query = request.args.get('DomainName')
    if domain_query:
        if len(domain_query) > MAX_DOMAIN_NAME_LENGTH:
            abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.")
        if not DOMAIN_NAME_REGEX.match(domain_query):
            abort(400, description="DomainName contains invalid characters.")
        filters.append(Origin.domain_name.ilike(f"%{domain_query}%"))

    group_query = request.args.get('GroupId')
    if group_query:
        try:
            group_id = int(group_query)
        except ValueError:
            abort(400, description="GroupId must be a valid integer.")
        filters.append(Origin.group_id == group_id)

    return list_resources(
        Origin,
        lambda o: o.to_dict(),
        filters=filters,
        protective_marking='amber',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        resource_name='OriginsList',
    )
@api_web.route('/mirror', methods=['GET'])
def list_mirrors() -> ResponseReturnValue:
    """List mirror proxies, optionally narrowed by lifecycle ``Status``.

    Query parameters:
        Status: one of
            ``pending``   - no URL yet, not deprecated, not destroyed;
            ``active``    - has a URL, not deprecated, not destroyed;
            ``expiring``  - deprecated but not destroyed;
            ``destroyed`` - destroyed within the last 24 hours.
            When absent or any other value, proxies that are not destroyed
            or were destroyed within the last 24 hours are returned.
    """
    twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
    status_filter = request.args.get('Status')

    # A single if/elif chain so exactly one branch applies. The default branch
    # also covers a missing Status, which previously applied no filter at all
    # and so exposed proxies destroyed arbitrarily long ago.
    filters: List[ListFilter]
    if status_filter == "pending":
        filters = [
            Proxy.url.is_(None),
            Proxy.deprecated.is_(None),
            Proxy.destroyed.is_(None),
        ]
    elif status_filter == "active":
        filters = [
            Proxy.url.is_not(None),
            Proxy.deprecated.is_(None),
            Proxy.destroyed.is_(None),
        ]
    elif status_filter == "expiring":
        filters = [
            Proxy.deprecated.is_not(None),
            Proxy.destroyed.is_(None),
        ]
    elif status_filter == "destroyed":
        filters = [Proxy.destroyed > twenty_four_hours_ago]
    else:
        filters = [(Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago)]

    return list_resources(
        Proxy,
        lambda proxy: proxy.to_dict(),
        filters=filters,
        resource_name='MirrorsList',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        protective_marking='amber',
    )
@api_web.route('/onion', methods=['GET'])
def list_onions() -> ResponseReturnValue:
    """List onion services that have not been destroyed.

    Query parameters:
        DomainName: case-insensitive substring of the onion's domain name.
            Aborts with 400 if longer than ``MAX_DOMAIN_NAME_LENGTH`` or if it
            does not match ``DOMAIN_NAME_REGEX``.
        GroupId: exact group id. Aborts with 400 if it is not an integer.
    """
    # Destroyed onions are never listed, whatever else is requested.
    filters: List[ListFilter] = [Onion.destroyed.is_(None)]

    domain_query = request.args.get('DomainName')
    if domain_query:
        if len(domain_query) > MAX_DOMAIN_NAME_LENGTH:
            abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.")
        if not DOMAIN_NAME_REGEX.match(domain_query):
            abort(400, description="DomainName contains invalid characters.")
        filters.append(Onion.domain_name.ilike(f"%{domain_query}%"))

    group_query = request.args.get('GroupId')
    if group_query:
        try:
            group_id = int(group_query)
        except ValueError:
            abort(400, description="GroupId must be a valid integer.")
        filters.append(Onion.group_id == group_id)

    return list_resources(
        Onion,
        lambda o: o.to_dict(),
        filters=filters,
        protective_marking='amber',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        resource_name='OnionsList',
    )
class _CreateOnionRequestRequired(TypedDict):
    """Keys that must be present in a ``POST /onion`` JSON body."""

    DomainName: str
    Description: str
    GroupId: int
    OnionPrivateKey: str  # Base64-encoded
    OnionPublicKey: str  # Base64-encoded
    TlsPrivateKey: str  # PEM-encoded, unencrypted
    TlsCertificate: str  # PEM-encoded certificate chain


class CreateOnionRequest(_CreateOnionRequestRequired, total=False):
    """JSON body accepted by ``POST /onion``.

    The skip flags are optional keys. Both the ``...Validation`` and the
    ``...Verification`` spellings are declared because the ``create_onion``
    handler reads the ``...Verification`` keys.
    """

    SkipChainValidation: Optional[bool]
    SkipNameValidation: Optional[bool]
    SkipChainVerification: Optional[bool]
    SkipNameVerification: Optional[bool]
# Fields that must be present and non-empty in a create-onion request, in the
# order their "<Field>_missing" errors are reported.
_CREATE_ONION_REQUIRED_FIELDS = (
    "DomainName",
    "Description",
    "OnionPrivateKey",
    "OnionPublicKey",
    "GroupId",
    "TlsPrivateKey",
    "TlsCertificate",
)
# Required fields the handler treats as text (``.strip()`` / ``.encode()``);
# values of any other JSON type are rejected up front with "<Field>_invalid".
_CREATE_ONION_STRING_FIELDS = ("DomainName", "Description", "TlsPrivateKey", "TlsCertificate")


def _onion_error(code: str, message: str) -> Dict[str, str]:
    """Build one entry of the ``Errors`` list returned by ``create_onion``."""
    return {"Error": code, "Message": message}


def _parse_group_id(value: object) -> Optional[int]:
    """Return *value* as an int if it is an integer or a base-10 integer string, else None.

    Booleans and floats are rejected even though ``int()`` would accept them,
    so a JSON ``true`` or ``1.5`` cannot silently become a group id.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return None
    return None


def _validate_tls(
    tls_private_key_pem: bytes,
    tls_certificate_pem: bytes,
    skip_chain_verification: bool,
) -> Tuple[List[Dict[str, str]], Optional[List[str]]]:
    """Check a PEM TLS private key against a PEM certificate chain.

    Returns ``(errors, san_list)``. ``san_list`` holds the DNS names from the
    end-entity certificate's SAN extension (``[]`` if it has none). It is
    ``None`` when the material is unusable (no certificate, non-RSA private
    key, or a certificate with no signature hash algorithm). In that case
    ``errors`` holds that single problem and no further checks were run.

    Raises whatever the cryptography/x509 helpers raise on malformed input;
    the caller reports those as ``tls_validation_error``.
    """
    errors: List[Dict[str, str]] = []
    private_key = serialization.load_pem_private_key(
        tls_private_key_pem,
        password=None,
        backend=default_backend(),
    )
    certificates = list(load_certificates_from_pem(tls_certificate_pem))
    if not certificates:
        return [_onion_error("tls_public_key_invalid",
                             "TLS public key must contain at least one certificate.")], None
    chain = build_certificate_chain(certificates)
    end_entity_cert = chain[0]
    if not isinstance(private_key, RSAPrivateKey):
        return [_onion_error("tls_private_key_invalid", "Private key must be an RSA private key.")], None
    hash_algorithm = end_entity_cert.signature_hash_algorithm
    if hash_algorithm is None:
        return [_onion_error("tls_public_key_invalid", "Public key using unsupported algorithm.")], None

    # Prove the private key belongs to the certificate: sign a fixed message
    # and verify it with the certificate's public key.
    end_entity_public_key = end_entity_cert.public_key()
    if isinstance(end_entity_public_key, RSAPublicKey):
        test_message = b"test"
        signature = private_key.sign(test_message, PKCS1v15(), hash_algorithm)
        try:
            end_entity_public_key.verify(signature, test_message, PKCS1v15(), hash_algorithm)
        except InvalidSignature:
            # InvalidSignature carries no message, so state the actual failure.
            errors.append(_onion_error(
                "tls_validation_error",
                "TLS key/certificate validation failed: private key does not match the certificate.",
            ))
    else:
        errors.append(_onion_error("tls_public_key_invalid", "Public key using unsupported algorithm."))

    now = datetime.now(timezone.utc)
    if end_entity_cert.not_valid_after_utc < now:
        errors.append(_onion_error("tls_public_key_expired", "TLS public key is expired."))
    if end_entity_cert.not_valid_before_utc > now:
        errors.append(_onion_error("tls_public_key_future", "TLS public key is not yet valid."))

    try:
        san_extension = end_entity_cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
        san_list = san_extension.value.get_values_for_type(x509.DNSName)  # type: ignore[attr-defined]
    except x509.ExtensionNotFound:
        san_list = []
        errors.append(_onion_error("san_not_found",
                                   "No Subject Alternative Names (SANs) found in the TLS public key."))

    if not skip_chain_verification:
        try:
            validate_certificate_chain(chain)
        except ValueError as e:
            errors.append(_onion_error("certificate_chain_invalid", str(e)))

    return errors, san_list


@api_web.route("/onion", methods=["POST"])
def create_onion() -> ResponseReturnValue:
    """Create an onion service from the JSON request body.

    The body is expected to match ``CreateOnionRequest``. Validation problems
    are collected and returned together as
    ``400 {"Errors": [{"Error": <code>, "Message": <text>}, ...]}``.

    Returns:
        201 ``{"Message": ..., "Id": <new onion id>}`` on success;
        400 with ``Errors`` on invalid input;
        500 with a ``database_error`` entry if the insert fails.
    Aborts with 400 if the body is empty or not a JSON object.
    """
    data: Optional[CreateOnionRequest] = request.json
    # A JSON array or scalar body cannot be a CreateOnionRequest.
    if not data or not isinstance(data, dict):
        abort(400)

    errors: List[Dict[str, str]] = []
    for field in _CREATE_ONION_REQUIRED_FIELDS:
        if not data.get(field):
            errors.append(_onion_error(f"{field}_missing", f"Missing required field: {field}"))
    for field in _CREATE_ONION_STRING_FIELDS:
        value = data.get(field)
        if value and not isinstance(value, str):
            errors.append(_onion_error(f"{field}_invalid", f"{field} must be a string."))

    group_id: Optional[int] = None
    if data.get("GroupId"):  # a missing/empty GroupId was already reported above
        group_id = _parse_group_id(data["GroupId"])
        if group_id is None:
            errors.append(_onion_error("group_id_invalid", "GroupId must be an integer."))
        elif Group.query.get(group_id) is None:
            errors.append(_onion_error("group_id_not_found", "Specified group ID does not exist."))

    try:
        onion_private_key = base64.b64decode(data["OnionPrivateKey"])
        onion_public_key = base64.b64decode(data["OnionPublicKey"])
    except (KeyError, ValueError, TypeError):
        # binascii.Error (e.g. bad padding) is a ValueError subclass.
        errors.append(_onion_error("onion_key_invalid", "Onion keys must be valid Base64-encoded data."))

    if errors:
        return jsonify({"Errors": errors}), 400

    # Normalise once so the registered-domain check and the stored value agree.
    domain_name = data["DomainName"].strip()
    if not domain_name:
        return jsonify({"Errors": [_onion_error("DomainName_missing", "Missing required field: DomainName")]}), 400

    # CreateOnionRequest spells these flags "...Validation" while this handler
    # originally read "...Verification"; honour either spelling.
    skip_chain_verification = bool(data.get("SkipChainVerification") or data.get("SkipChainValidation"))
    skip_name_verification = bool(data.get("SkipNameVerification") or data.get("SkipNameValidation"))

    tls_private_key_pem = data["TlsPrivateKey"].encode("utf-8")
    tls_certificate_pem = data["TlsCertificate"].encode("utf-8")
    try:
        tls_errors, san_list = _validate_tls(tls_private_key_pem, tls_certificate_pem, skip_chain_verification)
        if san_list is None:
            # Key/certificate material is unusable; further checks would only add noise.
            return jsonify({"Errors": tls_errors}), 400
        errors.extend(tls_errors)

        if not skip_name_verification:
            registered_domain = tldextract.extract(domain_name).registered_domain
            if domain_name != registered_domain:
                errors.append(_onion_error(
                    "domain_name_not_registered_domain",
                    "The domain name is invalid, or is not the top-level domain (i.e. a subdomain was entered).",
                ))
            # The certificate must name both the onion address and its wildcard.
            hostname = f"{onion_hostname(onion_public_key)}.onion"
            for name in (hostname, f"*.{hostname}"):
                if name not in san_list:
                    errors.append(_onion_error("certificate_san_missing",
                                               f"{name} missing from certificate SAN list."))
    except Exception as e:  # request boundary: malformed key/PEM data raises many unrelated types
        return jsonify({"Errors": [_onion_error(
            "tls_validation_error", f"TLS key/certificate validation failed: {str(e)}")]}), 400

    if errors:
        return jsonify({"Errors": errors}), 400

    now = datetime.now(timezone.utc)  # one timestamp so added == updated on creation
    onion = Onion(
        domain_name=domain_name,
        description=data["Description"],
        onion_private_key=onion_private_key,
        onion_public_key=onion_public_key,
        tls_private_key=tls_private_key_pem,
        tls_public_key=tls_certificate_pem,
        group_id=group_id,
        added=now,
        updated=now,
        # TODO: extend the schema to also record the end-entity certificate's
        # expiry date and its SAN list.
    )
    try:
        db.session.add(onion)
        db.session.commit()
        return jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), 201
    except exc.SQLAlchemyError as e:
        # Roll back so the session is not left inside a failed transaction.
        db.session.rollback()
        logging.exception(e)
        return jsonify({"Errors": [_onion_error("database_error", "Failed to create onion service.")]}), 500