From e5976c473980557c9109883e022cdb7b76076222 Mon Sep 17 00:00:00 2001 From: irl Date: Fri, 6 Dec 2024 13:34:44 +0000 Subject: [PATCH] feat: expand onion service api --- app/api/__init__.py | 2 + app/api/onion.py | 206 +++++++++++++++++++++++++++ app/api/util.py | 20 +++ app/api/web.py | 196 +------------------------- app/models/onions.py | 33 +++-- app/portal/onion.py | 53 +------ app/util/onion.py | 11 ++ app/util/x509.py | 95 +++++++++++++ scripts/generate_test_onion_tls.py | 99 +------------ tests/api/conftest.py | 64 +++++++++ tests/api/test_onion.py | 215 +++++++++++++++++++++++++++++ 11 files changed, 646 insertions(+), 348 deletions(-) create mode 100644 app/api/onion.py create mode 100644 tests/api/conftest.py create mode 100644 tests/api/test_onion.py diff --git a/app/api/__init__.py b/app/api/__init__.py index 8c579e9..8ae0071 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -2,9 +2,11 @@ from flask import Blueprint, jsonify from flask.typing import ResponseReturnValue from werkzeug.exceptions import HTTPException +from app.api.onion import api_onion from app.api.web import api_web api = Blueprint('api', __name__) +api.register_blueprint(api_onion, url_prefix='/onion') api.register_blueprint(api_web, url_prefix='/web') diff --git a/app/api/onion.py b/app/api/onion.py new file mode 100644 index 0000000..0092821 --- /dev/null +++ b/app/api/onion.py @@ -0,0 +1,206 @@ +import sys +from datetime import datetime, timezone +from typing import List, TypedDict, NotRequired, Optional + +from cryptography import x509 +from flask import request, abort, jsonify, Blueprint +from flask.typing import ResponseReturnValue +from sqlalchemy import exc + +from app.extensions import db +from app.api.util import ListFilter, MAX_DOMAIN_NAME_LENGTH, DOMAIN_NAME_REGEX, list_resources, MAX_ALLOWED_ITEMS, \ + validate_description, get_single_resource +from app.models.base import Group +from app.models.onions import Onion +from app.util.onion import onion_hostname, decode_onion_keys +from app.util.x509 import validate_tls_keys + +api_onion = Blueprint('api_onion', __name__) + + +@api_onion.route('/onion', methods=['GET']) +def list_onions() -> ResponseReturnValue: + domain_name_filter = request.args.get('DomainName') + group_id_filter = request.args.get('GroupId') + + filters: List[ListFilter] = [ + (Onion.destroyed.is_(None)) + ] + + if domain_name_filter: + if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: + abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") + if not DOMAIN_NAME_REGEX.match(domain_name_filter): + abort(400, description="DomainName contains invalid characters.") + filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%")) + + if group_id_filter: + try: + filters.append(Onion.group_id == int(group_id_filter)) + except ValueError: + abort(400, description="GroupId must be a valid integer.") + + return list_resources( + Onion, + lambda onion: onion.to_dict(), + filters=filters, + resource_name='OnionsList', + max_allowed_items=MAX_ALLOWED_ITEMS, + protective_marking='amber', + ) + + +class CreateOnionRequest(TypedDict): + DomainName: str + Description: str + GroupId: int + OnionPrivateKey: str + OnionPublicKey: str + TlsPrivateKey: str + TlsCertificate: str + SkipChainVerification: NotRequired[bool] + SkipNameVerification: NotRequired[bool] + + +@api_onion.route("/onion", methods=["POST"]) +def create_onion() -> ResponseReturnValue: + data: Optional[CreateOnionRequest] = request.json + if not data: + abort(400) + + errors = [] + for field in ["DomainName", "Description", "OnionPrivateKey", "OnionPublicKey", "GroupId", "TlsPrivateKey", + "TlsCertificate"]: + if not data.get(field): + errors.append({"Error": f"{field}_missing", "Message": f"Missing required field: {field}"}) + + onion_private_key, onion_public_key, onion_errors = decode_onion_keys(data["OnionPrivateKey"], + data["OnionPublicKey"]) + if onion_errors: + errors.extend(onion_errors) + + if onion_public_key is None: + return jsonify({"Errors": errors}), 400 + + if onion_private_key: + existing_onion = db.session.query(Onion).where( + Onion.onion_private_key == onion_private_key, + Onion.destroyed.is_(None), + ).first() + if existing_onion: + errors.append( + {"Error": "duplicate_onion_key", "Message": "An onion service with this private key already exists."}) + + if "GroupId" in data: + group = Group.query.get(data["GroupId"]) + if not group: + errors.append({"Error": "group_id_not_found", "Message": "Invalid group ID."}) + + chain, san_list, tls_errors = validate_tls_keys( + data["TlsPrivateKey"], data["TlsCertificate"], data.get("SkipChainVerification"), + data.get("SkipNameVerification"), + f"{onion_hostname(onion_public_key)}.onion" + ) + + if tls_errors: + errors.extend(tls_errors) + + if errors: + return jsonify({"Errors": errors}), 400 + + cert_expiry_date = chain[0].not_valid_after if chain else None + + onion = Onion( + domain_name=data["DomainName"], + description=data["Description"], + onion_private_key=onion_private_key, + onion_public_key=onion_public_key, + tls_private_key=data["TlsPrivateKey"].encode("utf-8"), + tls_public_key=data["TlsCertificate"].encode("utf-8"), + group_id=data["GroupId"], + added=datetime.now(timezone.utc), + updated=datetime.now(timezone.utc), + cert_expiry=cert_expiry_date, + cert_sans=",".join(san_list) + ) + + try: + db.session.add(onion) + db.session.commit() + return jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), 201 + except exc.SQLAlchemyError as e: + return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500 + + +class UpdateOnionRequest(TypedDict): + Description: NotRequired[str] + TlsPrivateKey: NotRequired[str] + TlsCertificate: NotRequired[str] + SkipChainVerification: NotRequired[bool] + SkipNameVerification: NotRequired[bool] + + +@api_onion.route("/onion/", methods=["PUT"]) +def update_onion(onion_id: int) -> ResponseReturnValue: + data: Optional[UpdateOnionRequest] = request.json + if not data: + abort(400) + + errors = [] + + onion = Onion.query.get(onion_id) + if not onion: + return jsonify( + {"Errors": [{"Error": "onion_not_found", "Message": f"No Onion service found with ID {onion_id}"}]}), 404 + + if "Description" in data: + description = data["Description"] + print(f"Description {description}", file=sys.stderr) + if validate_description(description): + onion.description = description + else: + errors.append({"Error": "description_error", "Message": "Description field is invalid"}) + + tls_private_key_pem: Optional[str] = None + tls_certificate_pem: Optional[str] = None + chain: Optional[List[x509.Certificate]] = None + san_list: Optional[List[str]] = None + + if "TlsCertificate" in data: + tls_certificate_pem = data.get("TlsCertificate") + if "TlsPrivateKey" in data: + tls_private_key_pem = data.get("TlsPrivateKey") + else: + tls_private_key_pem = onion.tls_private_key.decode("utf-8") + + chain, san_list, tls_errors = validate_tls_keys( + tls_private_key_pem, tls_certificate_pem, data.get("SkipChainVerification", False), + data.get("SkipNameVerification", False), + f"{onion_hostname(onion.onion_public_key)}.onion", + ) + if tls_errors: + errors.extend(tls_errors) + + if errors: + return jsonify({"Errors": errors}), 400 + + if tls_private_key_pem: + onion.tls_private_key = tls_private_key_pem.encode("utf-8") + + if tls_certificate_pem and san_list: + onion.tls_public_key = tls_certificate_pem.encode("utf-8") + onion.cert_expiry_date = chain[0].not_valid_after_utc if chain else None + onion.cert_sans = ",".join(san_list) + + onion.updated = datetime.now(timezone.utc) + + try: + db.session.commit() + return jsonify({"Message": "Onion service updated successfully."}), 200 + except exc.SQLAlchemyError as e: + return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500 + + +@api_onion.route("/onion/", methods=["GET"]) +def get_onion(onion_id: int) -> ResponseReturnValue: + return get_single_resource(Onion, onion_id, "Onion") diff --git a/app/api/util.py b/app/api/util.py index 1069310..2d659d9 100644 --- a/app/api/util.py +++ b/app/api/util.py @@ -99,3 +99,23 @@ def list_resources( # pylint: disable=too-many-arguments,too-many-locals except Exception: # pylint: disable=broad-exception-caught logger.exception("An unexpected error occurred") abort(500) + + +def get_single_resource(model: Type[Any], id_: int, resource_name: str) -> ResponseReturnValue: + try: + resource = db.session.get(model, id_) + if not resource: + return jsonify({ + "Error": "resource_not_found", + "Message": f"No {resource_name} found with ID {id_}" + }), 404 + return jsonify({resource_name: resource.to_dict()}), 200 + except Exception: # pylint: disable=broad-exception-caught + logger.exception("An unexpected error occurred while retrieving the onion") + abort(500) + + +def validate_description(description: Optional[str]) -> bool: + if description is None: + return False + return True diff --git a/app/api/web.py b/app/api/web.py index 74f55d6..adc48a1 100644 --- a/app/api/web.py +++ b/app/api/web.py @@ -1,26 +1,12 @@ -import base64 -import logging -from datetime import datetime, timezone, timedelta -from typing import List, TypedDict, Optional +from datetime import datetime, timedelta, timezone +from typing import List -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey, RSAPrivateKey -from cryptography.x509.oid import ExtensionOID -from flask import Blueprint, request, jsonify, abort +from flask import Blueprint, request, abort from flask.typing import ResponseReturnValue -from sqlalchemy import exc -from tldextract import tldextract from app.api.util import ListFilter, MAX_DOMAIN_NAME_LENGTH, DOMAIN_NAME_REGEX, list_resources, MAX_ALLOWED_ITEMS -from app.util.x509 import build_certificate_chain, validate_certificate_chain, load_certificates_from_pem -from app.extensions import db from app.models.base import Group from app.models.mirrors import Origin, Proxy -from app.models.onions import Onion -from app.util.onion import onion_hostname api_web = Blueprint('web', __name__) @@ -97,179 +83,3 @@ def list_mirrors() -> ResponseReturnValue: max_allowed_items=MAX_ALLOWED_ITEMS, protective_marking='amber', ) - - -@api_web.route('/onion', methods=['GET']) -def list_onions() -> ResponseReturnValue: - domain_name_filter = request.args.get('DomainName') - group_id_filter = request.args.get('GroupId') - - filters: List[ListFilter] = [ - (Onion.destroyed.is_(None)) - ] - - if domain_name_filter: - if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: - abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") - if not DOMAIN_NAME_REGEX.match(domain_name_filter): - abort(400, description="DomainName contains invalid characters.") - filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%")) - - if group_id_filter: - try: - filters.append(Onion.group_id == int(group_id_filter)) - except ValueError: - abort(400, description="GroupId must be a valid integer.") - - return list_resources( - Onion, - lambda onion: onion.to_dict(), - filters=filters, - resource_name='OnionsList', - max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', - ) - - -class CreateOnionRequest(TypedDict): - DomainName: str - Description: str - GroupId: int - OnionPrivateKey: str - OnionPublicKey: str - TlsPrivateKey: str - TlsCertificate: str - SkipChainValidation: Optional[bool] - SkipNameValidation: Optional[bool] - - -@api_web.route("/onion", methods=["POST"]) -def create_onion() -> ResponseReturnValue: - data: Optional[CreateOnionRequest] = request.json - if not data: - abort(400) - - errors = [] - - required_fields = ["DomainName", "Description", "OnionPrivateKey", "OnionPublicKey", "GroupId", "TlsPrivateKey", "TlsCertificate"] - for field in required_fields: - if field not in data or not data[field]: # type: ignore[literal-required] - errors.append({"Error": f"{field}_missing", "Message": f"Missing required field: {field}"}) - - if "GroupId" in data: - group = Group.query.get(data["GroupId"]) - if not group: - errors.append({"Error": "group_id_not_found", "Message": "Specified group ID does not exist."}) - - try: - onion_private_key = base64.b64decode(data["OnionPrivateKey"]) - onion_public_key = base64.b64decode(data["OnionPublicKey"]) - except (KeyError, ValueError, TypeError): - errors.append({"Error": "onion_key_invalid", "Message": "Onion keys must be valid Base64-encoded data."}) - - if errors: - return jsonify({"Errors": errors}), 400 - - skip_chain_verification = data.get("SkipChainVerification", False) - skip_name_verification = data.get("SkipNameVerification", False) - - try: - private_key = serialization.load_pem_private_key( - data["TlsPrivateKey"].encode("utf-8"), - password=None, - backend=default_backend() - ) - certificates = list(load_certificates_from_pem(data["TlsCertificate"].encode("utf-8"))) - if not certificates: - errors.append( - {"Error": "tls_public_key_invalid", "Message": "TLS public key must contain at least one certificate."}) - return jsonify({"Errors": errors}), 400 - - chain = build_certificate_chain(certificates) - end_entity_cert = chain[0] - - test_message = b"test" - if not isinstance(private_key, RSAPrivateKey): - errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be an RSA private key."}) - return jsonify({"Errors": errors}), 400 - hash_algorithm = end_entity_cert.signature_hash_algorithm - if hash_algorithm is None: - errors.append({"Error": "tls_public_key_invalid", "Message": "Public key using unsupported algorithm."}) - return jsonify({"Errors": errors}), 400 - signature = private_key.sign( - test_message, - PKCS1v15(), - hash_algorithm - ) - end_entity_public_key = end_entity_cert.public_key() - if isinstance(end_entity_public_key, RSAPublicKey): - end_entity_public_key.verify( - signature, - test_message, - PKCS1v15(), - hash_algorithm - ) - else: - errors.append({"Error": "tls_public_key_invalid", "Message": "Public key using unsupported algorithm."}) - - if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc): - errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."}) - - if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc): - errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."}) - - try: - san_extension = end_entity_cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME) - san_list = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined] - except x509.ExtensionNotFound: - san_list = [] - errors.append({"Error": "san_not_found", - "Message": "No Subject Alternative Names (SANs) found in the TLS public key."}) - - if not skip_chain_verification: - try: - validate_certificate_chain(chain) - except ValueError as e: - errors.append({"Error": "certificate_chain_invalid", "Message": str(e)}) - - if not skip_name_verification: - if "DomainName" in data: - registered_domain = tldextract.extract(data["DomainName"]).registered_domain - if data["DomainName"] != registered_domain: - errors.append({"Error": "domain_name_not_registered_domain", "Message": "The domain name is invalid, or is not the top-level domain (i.e. a subdomain was entered)."}) - - if "OnionPublicKey" in data: - hostname = f"{onion_hostname(onion_public_key)}.onion" - for name in [hostname, f"*.{hostname}"]: - if name not in san_list: - errors.append({"Error": "certificate_san_missing", "Message": f"{name} missing from certificate SAN list."}) - - if errors: - return jsonify({"Errors": errors}), 400 - - except Exception as e: - errors.append({"Error": "tls_validation_error", "Message": f"TLS key/certificate validation failed: {str(e)}"}) - return jsonify({"Errors": errors}), 400 - - onion = Onion( - domain_name=data["DomainName"].strip(), - description=data["Description"], - onion_private_key=onion_private_key, - onion_public_key=onion_public_key, - tls_private_key=data["TlsPrivateKey"].encode("utf-8"), - tls_public_key=data["TlsCertificate"].encode("utf-8"), - group_id=data["GroupId"], - added=datetime.now(timezone.utc), - updated=datetime.now(timezone.utc), - # cert_expiry_date=end_entity_cert.not_valid_after_utc, TODO: extend schema to accommodate these fields - # cert_sans=",".join(san_list), - ) - - try: - db.session.add(onion) - db.session.commit() - return jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), 201 - except exc.SQLAlchemyError as e: - logging.exception(e) - errors.append({"Error": "database_error", "Message": "Failed to create onion service."}) - return jsonify({"Errors": errors}), 500 diff --git a/app/models/onions.py b/app/models/onions.py index 8287fbb..897f6f2 100644 --- a/app/models/onions.py +++ b/app/models/onions.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Optional, TypedDict from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -10,9 +11,16 @@ from app.util.onion import onion_hostname class OnionDict(TypedDict): - Id: int + Added: str + CertExpiry: str + CertSans: str + Description: str DomainName: str + GroupId: int + GroupName: str + Id: int OnionName: str + Updated: str class Onion(AbstractConfiguration): @@ -26,14 +34,14 @@ class Onion(AbstractConfiguration): resource_id=self.onion_name ) - group_id = db.Column(db.Integer(), db.ForeignKey("group.id"), nullable=False) - domain_name = db.Column(db.String(255), nullable=False) - - onion_public_key = db.Column(db.LargeBinary, nullable=False) - onion_private_key = db.Column(db.LargeBinary, nullable=False) - - tls_public_key = db.Column(db.LargeBinary, nullable=False) - tls_private_key = db.Column(db.LargeBinary, nullable=False) + group_id: Mapped[int] = mapped_column(db.ForeignKey("group.id")) + domain_name: Mapped[str] + cert_expiry: Mapped[datetime] = mapped_column(db.DateTime(timezone=True)) + cert_sans: Mapped[str] + onion_public_key: Mapped[bytes] + onion_private_key: Mapped[bytes] + tls_public_key: Mapped[bytes] + tls_private_key: Mapped[bytes] group = db.relationship("Group", back_populates="onions") @@ -43,9 +51,16 @@ class Onion(AbstractConfiguration): def to_dict(self) -> OnionDict: return { + "Added": self.added.isoformat(), "Id": self.id, + "CertExpiry": self.cert_expiry.isoformat(), + "CertSans": self.cert_sans, + "Description": self.description, "DomainName": self.domain_name, + "GroupId": self.group_id, + "GroupName": self.group.group_name, "OnionName": self.onion_name, + "Updated": self.updated.isoformat(), } diff --git a/app/portal/onion.py b/app/portal/onion.py index f80055f..e2b4271 100644 --- a/app/portal/onion.py +++ b/app/portal/onion.py @@ -1,33 +1,14 @@ -from datetime import datetime from typing import Optional -from flask import flash, redirect, render_template, Response, Blueprint +from flask import redirect, Blueprint from flask.typing import ResponseReturnValue -from flask_wtf import FlaskForm -from sqlalchemy import exc -from wtforms import StringField, SelectField, SubmitField -from flask_wtf.file import FileField -from wtforms.validators import DataRequired -from app.extensions import db -from app.models.base import Group from app.models.onions import Onion from app.portal.util import response_404, view_lifecycle bp = Blueprint("onion", __name__) -class EditOnionForm(FlaskForm): # type: ignore - domain_name = StringField('Domain Name', validators=[DataRequired()]) - description = StringField('Description', validators=[DataRequired()]) - tls_private_key = FileField('TLS Private Key (PEM format)', - description="If no file is submitted, the TLS key will remain unchanged.") - tls_public_key = FileField('TLS Certificate (PEM format)', - description="If no file is submitted, the TLS certificate will remain unchanged.") - group = SelectField('Group', validators=[DataRequired()]) - submit = SubmitField('Save Changes') - - @bp.route("/new", methods=['GET', 'POST']) @bp.route("/new/", methods=['GET', 'POST']) def onion_new(group_id: Optional[int] = None) -> ResponseReturnValue: @@ -36,37 +17,7 @@ def onion_new(group_id: Optional[int] = None) -> ResponseReturnValue: @bp.route('/edit/', methods=['GET', 'POST']) def onion_edit(onion_id: int) -> ResponseReturnValue: - onion: Optional[Onion] = Onion.query.filter(Onion.id == onion_id).first() - if onion is None: - return Response(render_template("error.html.j2", - section="onion", - header="404 Onion Not Found", - message="The requested onion service could not be found."), - status=404) - form = EditOnionForm(group=onion.group_id, - domain_name=onion.domain_name, - description=onion.description) - form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] - if form.validate_on_submit(): - onion.group_id = form.group.data - onion.description = form.description.data - onion.domain_name = form.domain_name.data - for at in [ - "tls_private_key", - "tls_public_key" - ]: - if getattr(form, at).data is not None: - # Don't clear the key if no key is uploaded - setattr(onion, at, getattr(form, at).data.read()) - onion.updated = datetime.utcnow() - try: - db.session.commit() - flash("Saved changes to group.", "success") - except exc.SQLAlchemyError: - flash("An error occurred saving the changes to the group.", "danger") - return render_template("onion.html.j2", - section="onion", - onion=onion, form=form) + return redirect("/ui/web/onions/edit/{}".format(onion_id)) @bp.route("/list") diff --git a/app/util/onion.py b/app/util/onion.py index a7876d2..527972d 100644 --- a/app/util/onion.py +++ b/app/util/onion.py @@ -1,5 +1,6 @@ import base64 import hashlib +from typing import Tuple, Optional, List, Dict def onion_hostname(onion_public_key: bytes) -> str: @@ -17,3 +18,13 @@ def onion_hostname(onion_public_key: bytes) -> str: onion = base64.b32encode(result).decode("utf-8").strip("=") return onion.lower() + + +def decode_onion_keys(onion_private_key_base64: str, onion_public_key_base64: str) -> Tuple[ + Optional[bytes], Optional[bytes], List[Dict[str, str]]]: + try: + onion_private_key = base64.b64decode(onion_private_key_base64) + onion_public_key = base64.b64decode(onion_public_key_base64) + return onion_private_key, onion_public_key, [] + except ValueError as e: + return None, None, [{"Error": "invalid_onion_key", "Message": str(e)}] diff --git a/app/util/x509.py b/app/util/x509.py index 6422464..0813a50 100644 --- a/app/util/x509.py +++ b/app/util/x509.py @@ -1,7 +1,12 @@ import ssl +from datetime import datetime, timezone +from typing import Optional, Tuple, List, Dict, TYPE_CHECKING from cryptography import x509 +from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa, padding from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey @@ -63,3 +68,93 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool: ): raise ValueError("Certificate chain does not terminate at a trusted root CA.") return True + + +def validate_tls_keys( + tls_private_key_pem: Optional[str], + tls_certificate_pem: Optional[str], + skip_chain_verification: Optional[bool], + skip_name_verification: Optional[bool], + hostname: str +) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]: + errors = [] + san_list = [] + chain = None + + skip_chain_verification = skip_chain_verification or False + skip_name_verification = skip_name_verification or False + + try: + private_key = None + if tls_private_key_pem: + private_key = serialization.load_pem_private_key( + tls_private_key_pem.encode("utf-8"), + password=None, + backend=default_backend() + ) + if not isinstance(private_key, rsa.RSAPrivateKey): + errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be RSA."}) + + if tls_certificate_pem: + certificates = list(load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))) + if not certificates: + errors.append({"Error": "tls_certificate_invalid", "Message": "No valid certificate found."}) + else: + chain = build_certificate_chain(certificates) + end_entity_cert = chain[0] + + if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc): + errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."}) + + if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc): + errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."}) + + if private_key: + public_key = end_entity_cert.public_key() + if TYPE_CHECKING: + assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101 + assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101 + assert end_entity_cert.signature_hash_algorithm is not None # nosec: B101 + try: + test_message = b"test" + signature = private_key.sign( + test_message, + padding.PKCS1v15(), + end_entity_cert.signature_hash_algorithm, + ) + public_key.verify( + signature, + test_message, + padding.PKCS1v15(), + end_entity_cert.signature_hash_algorithm, + ) + except Exception: + errors.append( + {"Error": "tls_key_mismatch", "Message": "Private key does not match certificate."}) + + if not skip_chain_verification: + try: + validate_certificate_chain(chain) + except ValueError as e: + errors.append({"Error": "certificate_chain_invalid", "Message": str(e)}) + + if not skip_name_verification: + san_list = extract_sans(end_entity_cert) + for expected_hostname in [hostname, f"*.{hostname}"]: + if expected_hostname not in san_list: + errors.append( + {"Error": "hostname_not_in_san", "Message": f"{expected_hostname} not found in SANs."}) + + except Exception as e: + errors.append({"Error": "tls_validation_error", "Message": str(e)}) + + return chain, san_list, errors + + +def extract_sans(cert: x509.Certificate) -> List[str]: + try: + san_extension = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined] + return sans + except Exception: + return [] diff --git a/scripts/generate_test_onion_tls.py b/scripts/generate_test_onion_tls.py index 54608ef..ea965b0 100644 --- a/scripts/generate_test_onion_tls.py +++ b/scripts/generate_test_onion_tls.py @@ -11,101 +11,10 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend from datetime import datetime, timedelta, timezone - -def generate_onion_keys_with_mkp224o(folder_name: str, label: str): - """ - Generate Tor-compatible Onion service keys using mkp224o. - The keys are saved in the specified folder, and the Onion address is returned. - """ - os.makedirs(folder_name, exist_ok=True) - - # Call mkp224o to generate a single Onion service key - process = subprocess.run( - ["mkp224o", "-n", "1", "-d", folder_name, label], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE - ) - try: - process.check_returncode() - except subprocess.CalledProcessError: - print("STDOUT:", process.stdout.decode()) - print("STDERR:", process.stderr.decode()) - raise - - # Find the generated Onion address - for filename in os.listdir(folder_name): - if filename.endswith(".onion"): - onion_address = filename - onion_dir = os.path.join(folder_name, filename) - - # Move files to parent directory - for key_file in ["hs_ed25519_secret_key", "hs_ed25519_public_key", "hostname"]: - src = os.path.join(onion_dir, key_file) - dst = os.path.join(folder_name, key_file) - if os.path.exists(src): - shutil.move(src, dst) - - # Remove the now-empty directory - os.rmdir(onion_dir) - - return onion_address - - raise RuntimeError("Failed to generate Onion keys using mkp224o") +from tests.api.test_onion import generate_onion_keys_with_mkp224o, generate_self_signed_tls_certificate -def generate_self_signed_tls_certificate(folder_name: str, onion_address: str, valid_from: datetime, valid_to: datetime, dns_names=None): - """ - Generate a self-signed TLS certificate for the Onion address and save it in the specified folder. - """ - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend(), - ) - subject = x509.Name( - [ - x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), - x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test State"), - x509.NameAttribute(NameOID.LOCALITY_NAME, "Test City"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Org"), - x509.NameAttribute(NameOID.COMMON_NAME, onion_address), - ] - ) - if dns_names is None: - dns_names = [onion_address, f"*.{onion_address}"] - - san_extension = x509.SubjectAlternativeName([x509.DNSName(name) for name in dns_names]) - - certificate = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(subject) - .public_key(private_key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(valid_from) - .not_valid_after(valid_to) - .add_extension(san_extension, critical=False) - .sign(private_key, SHA256(), default_backend()) - ) - - private_key_path = os.path.join(folder_name, "tls_private_key.pem") - certificate_path = os.path.join(folder_name, "tls_certificate.pem") - - with open(private_key_path, "wb") as f: - f.write( - private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - ) - with open(certificate_path, "wb") as f: - f.write(certificate.public_bytes(serialization.Encoding.PEM)) - - return private_key_path, certificate_path - - -def generate_rest_payload(parent_folder: str, folder_name: str, onion_address: str): +def generate_create_rest_payload(parent_folder: str, folder_name: str): """ Generate REST payload for a specific Onion service and append it to a shared .rest file. """ @@ -121,7 +30,7 @@ def generate_rest_payload(parent_folder: str, folder_name: str, onion_address: s tls_public_key = f.read() payload = { - "DomainName": onion_address, + "DomainName": "example.com", "Description": f"Generated Onion Service for {folder_name}", "OnionPrivateKey": onion_private_key, "OnionPublicKey": onion_public_key, @@ -155,6 +64,6 @@ if __name__ == "__main__": print(f"Generating {folder_name}...") onion_address = generate_onion_keys_with_mkp224o(folder_name, "test") generate_self_signed_tls_certificate(folder_name, onion_address, valid_from, valid_to, dns_names) - generate_rest_payload(parent_folder, folder_name, onion_address) + generate_create_rest_payload(parent_folder, folder_name) print("All Onion services and REST requests generated successfully.") diff --git a/tests/api/conftest.py b/tests/api/conftest.py new file mode 100644 index 0000000..2a9d58f --- /dev/null +++ b/tests/api/conftest.py @@ -0,0 +1,64 @@ +import os +import shutil +import tempfile +from datetime import datetime, timezone +from multiprocessing import Process +from time import sleep + +import pytest +from sqlalchemy.exc import IntegrityError + +from app import app, db +from app.models.base import Group + + +@pytest.fixture(scope="session", autouse=True) +def test_server(test_database): + process = Process(target=run_app) + process.start() + + sleep(2) + + yield + + process.terminate() + process.join() + + +@pytest.fixture(scope="session", autouse=True) +def test_database(): + temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + db_path = temp_db.name + + # app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_path}" + app.config["TESTING"] = True + + with app.app_context(): + db.create_all() + group = Group(group_name="test-group", + description="Test group", + eotk=True, added=datetime.now(timezone.utc), + updated=datetime.now(timezone.utc)) + try: + db.session.add(group) + db.session.commit() + except IntegrityError: + db.session.rollback() + + yield db_path + + with app.app_context(): + db.drop_all() + + os.unlink(db_path) + + +def run_app(): + app.run(host="localhost", port=5001, debug=False, use_reloader=False) + + +@pytest.fixture(scope="function") +def temporary_test_directory(): + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) # Cleanup after the test diff --git a/tests/api/test_onion.py b/tests/api/test_onion.py new file mode 100644 index 0000000..933c419 --- /dev/null +++ b/tests/api/test_onion.py @@ -0,0 +1,215 @@ +import base64 +import json +import os +import shutil +import subprocess +from datetime import datetime, timedelta, timezone + +import pytest +import requests +from cryptography import x509 +from cryptography.hazmat._oid import NameOID +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.hashes import SHA256 + +from tests.api.conftest import temporary_test_directory + +mkp224o_available = shutil.which("mkp224o") is not None + + +def generate_self_signed_tls_certificate(folder_name: str, onion_address: str, valid_from: datetime, valid_to: datetime, + dns_names=None): + """ + Generate a self-signed TLS certificate for the Onion address and save it in the specified folder. + """ + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend(), + ) + subject = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test State"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test City"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Org"), + x509.NameAttribute(NameOID.COMMON_NAME, onion_address), + ] + ) + if dns_names is None: + dns_names = [onion_address, f"*.{onion_address}"] + + san_extension = x509.SubjectAlternativeName([x509.DNSName(name) for name in dns_names]) + + certificate = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(valid_from) + .not_valid_after(valid_to) + .add_extension(san_extension, critical=False) + .sign(private_key, SHA256(), default_backend()) + ) + + private_key_path = os.path.join(folder_name, "tls_private_key.pem") + certificate_path = os.path.join(folder_name, "tls_certificate.pem") + + with open(private_key_path, "wb") as f: + f.write( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + with open(certificate_path, "wb") as f: + f.write(certificate.public_bytes(serialization.Encoding.PEM)) + + return private_key_path, certificate_path + + +def generate_create_request_payload(folder_name: str): + with open(os.path.join(folder_name, "hs_ed25519_secret_key"), "rb") as f: + onion_private_key = base64.b64encode(f.read()).decode("utf-8") + with open(os.path.join(folder_name, "hs_ed25519_public_key"), "rb") as f: + onion_public_key = base64.b64encode(f.read()).decode("utf-8") + with open(os.path.join(folder_name, "tls_private_key.pem"), "r") as f: + tls_private_key = f.read() + with open(os.path.join(folder_name, "tls_certificate.pem"), "r") as f: + tls_public_key = f.read() + + payload = { + "DomainName": "example.com", + "Description": f"Generated Onion Service for {folder_name}", + "OnionPrivateKey": onion_private_key, + "OnionPublicKey": onion_public_key, + "TlsPrivateKey": tls_private_key, + "TlsCertificate": tls_public_key, + "SkipChainVerification": True, + "GroupId": 1, + } + + return payload + + +def generate_onion_keys_with_mkp224o(folder_name: str, label: str): + """ + Generate Tor-compatible Onion service keys using mkp224o. + The keys are saved in the specified folder, and the Onion address is returned. + """ + os.makedirs(folder_name, exist_ok=True) + + # Call mkp224o to generate a single Onion service key + process = subprocess.run( + ["mkp224o", "-n", "1", "-d", folder_name, label], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + try: + process.check_returncode() + except subprocess.CalledProcessError: + print("STDOUT:", process.stdout.decode()) + print("STDERR:", process.stderr.decode()) + raise + + # Find the generated Onion address + for filename in os.listdir(folder_name): + if filename.endswith(".onion"): + onion_address = filename + onion_dir = os.path.join(folder_name, filename) + + # Move files to parent directory + for key_file in ["hs_ed25519_secret_key", "hs_ed25519_public_key", "hostname"]: + src = os.path.join(onion_dir, key_file) + dst = os.path.join(folder_name, key_file) + if os.path.exists(src): + shutil.move(src, dst) + + # Remove the now-empty directory + os.rmdir(onion_dir) + + return onion_address + + raise RuntimeError("Failed to generate Onion keys using mkp224o") + +@pytest.mark.skipif(not mkp224o_available, reason="mkp224o is not available in PATH") +@pytest.mark.parametrize("scenario", [ + ("self_signed_onion_service", datetime.now(timezone.utc), datetime.now(timezone.utc) + timedelta(days=365), None, + 201, []), + ("expired_onion_service", datetime.now(timezone.utc) - timedelta(days=730), + datetime.now(timezone.utc) - timedelta(days=365), None, 400, ["tls_public_key_expired"]), + ("future_onion_service", datetime.now(timezone.utc) + timedelta(days=365), + datetime.now(timezone.utc) + timedelta(days=730), None, 400, ["tls_public_key_future"]), + ("wrong_name_onion_service", datetime.now(timezone.utc), datetime.now(timezone.utc) + timedelta(days=365), + ["wrong-name.example.com"], 400, ["hostname_not_in_san"]), +]) +def test_onion_service_creation(temporary_test_directory, scenario): + folder_name, valid_from, valid_to, dns_names, expected_status, expected_errors = scenario + + folder_path = os.path.join(temporary_test_directory, folder_name) + os.makedirs(folder_path, exist_ok=True) + + onion_address = generate_onion_keys_with_mkp224o(folder_path, "test") + generate_self_signed_tls_certificate(folder_path, onion_address, valid_from, valid_to, dns_names) + + payload = generate_create_request_payload(folder_path) + + response = requests.post( + "http://localhost:5001/api/onion/onion", + headers={"Content-Type": "application/json"}, + data=json.dumps(payload), + ) + + assert response.status_code == expected_status, f"Unexpected response: {response.text}" + + response_data = response.json() + if expected_errors: + assert "Errors" in response_data + assert set(expected_errors) == set([e["Error"] for e in response_data["Errors"]]) + else: + assert "Errors" not in response_data + + if os.path.exists(folder_path): + shutil.rmtree(folder_path) + +@pytest.mark.skipif(not mkp224o_available, reason="mkp224o is not available in PATH") +@pytest.mark.parametrize("new_description, expected_status", [ + ("Updated description", 200), + (None, 400), +]) +def test_update_onion_description(temporary_test_directory, new_description, expected_status): + update_payload = {"Description": new_description} + response = requests.put("http://localhost:5001/api/onion/onion/1", headers={"Content-Type": "application/json"}, + data=json.dumps(update_payload)) + + assert response.status_code == expected_status + +@pytest.mark.skipif(not mkp224o_available, reason="mkp224o is not available in PATH") +def test_update_tls_certificate(temporary_test_directory): + folder_path = os.path.join(temporary_test_directory, "update_certificate") + os.makedirs(folder_path) + onion_address = "test.onion" + private_key_path, certificate_path = generate_self_signed_tls_certificate( + folder_path, onion_address, datetime.now(timezone.utc), datetime.now(timezone.utc) + timedelta(days=365) + ) + + with open(private_key_path, "r") as f: + new_private_key = f.read() + with open(certificate_path, "r") as f: + new_certificate = f.read() + + update_payload = { + "TlsPrivateKey": new_private_key, + "TlsCertificate": new_certificate, + "SkipChainVerification": True, + "SkipNameVerification": True, + } + + response = requests.put("http://localhost:5001/api/onion/onion/1", headers={"Content-Type": "application/json"}, + data=json.dumps(update_payload)) + + assert response.status_code == 200