diff --git a/app/api/__init__.py b/app/api/__init__.py index 20ca112..8c579e9 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -1,28 +1,11 @@ -import base64 -import binascii -import logging -import re -from datetime import datetime, timedelta, timezone -from typing import Optional, List, Callable, Any, Type, Dict, Union, Literal - -from flask import Blueprint, request, jsonify, abort +from flask import Blueprint, jsonify from flask.typing import ResponseReturnValue -from sqlalchemy import select, BinaryExpression, ColumnElement from werkzeug.exceptions import HTTPException -from app.extensions import db -from app.models.base import Group -from app.models.mirrors import Origin, Proxy -from app.models.onions import Onion +from app.api.web import api_web api = Blueprint('api', __name__) -logger = logging.getLogger(__name__) - -MAX_DOMAIN_NAME_LENGTH = 255 -DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$') -MAX_ALLOWED_ITEMS = 100 - -ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]] +api.register_blueprint(api_web, url_prefix='/web') @api.errorhandler(400) @@ -51,191 +34,3 @@ def internal_server_error(_: HTTPException) -> ResponseReturnValue: response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'}) response.status_code = 500 return response - - -def validate_max_items(max_items_str: str, max_allowed: int) -> int: - try: - max_items = int(max_items_str) - if max_items <= 0 or max_items > max_allowed: - raise ValueError() - return max_items - except ValueError: - abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.") - - -def validate_marker(marker_str: str) -> int: - try: - marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode() - marker_id = int(marker_decoded) - return marker_id - except (ValueError, binascii.Error): - abort(400, description="Marker must be a valid token.") - - -TlpMarkings = Union[ - Literal["default"], - Literal["clear"], - Literal["green"], - Literal["amber"], - Literal["amber+strict"], - Literal["red"], -] - - -def list_resources( # pylint: disable=too-many-arguments,too-many-locals - model: Type[Any], - serialize_func: Callable[[Any], Dict[str, Any]], - *, - filters: List[ListFilter] = None, - order_by: Optional[ColumnElement[Any]] = None, - resource_name: str = 'ResourceList', - max_items_param: str = 'MaxItems', - marker_param: str = 'Marker', - max_allowed_items: int = 100, - protective_marking: TlpMarkings = 'default', -) -> ResponseReturnValue: - try: - marker = request.args.get(marker_param) - max_items = validate_max_items( - request.args.get(max_items_param, default='100'), max_allowed_items) - query = select(model) - - if filters: - query = query.where(*filters) - - if marker: - marker_id = validate_marker(marker) - query = query.where(model.id > marker_id) - - query = query.order_by(order_by or model.id) - query = query.limit(max_items + 1) # Need to know if there's more - - result = db.session.execute(query) - items = result.scalars().all() - items_list = [serialize_func(item) for item in items[:max_items]] - is_truncated = len(items) > max_items - - response = { - resource_name: { - marker_param: marker if marker else None, - max_items_param: str(max_items), - "Quantity": len(items_list), - "Items": items_list, - "IsTruncated": is_truncated, - "ProtectiveMarking": protective_marking, - } - } - - if is_truncated: - last_id = items[max_items - 1].id - next_marker = base64.urlsafe_b64encode(str(last_id).encode()).decode() - response[resource_name]["NextMarker"] = next_marker - - return jsonify(response) - except Exception: # pylint: disable=broad-exception-caught - logger.exception("An unexpected error occurred") - abort(500) - - -@api.route('/web/group', methods=['GET']) -def list_groups() -> ResponseReturnValue: - return list_resources( - Group, - lambda group: group.to_dict(), - resource_name='OriginGroupList', - max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', - ) - - -@api.route('/web/origin', methods=['GET']) -def list_origins() -> ResponseReturnValue: - domain_name_filter = request.args.get('DomainName') - group_id_filter = request.args.get('GroupId') - - filters: List[ListFilter] = [] - - if domain_name_filter: - if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: - abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") - if not DOMAIN_NAME_REGEX.match(domain_name_filter): - abort(400, description="DomainName contains invalid characters.") - filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%")) - - if group_id_filter: - try: - filters.append(Origin.group_id == int(group_id_filter)) - except ValueError: - abort(400, description="GroupId must be a valid integer.") - - return list_resources( - Origin, - lambda origin: origin.to_dict(), - filters=filters, - resource_name='OriginsList', - max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', - ) - - -@api.route('/web/mirror', methods=['GET']) -def list_mirrors() -> ResponseReturnValue: - filters = [] - - twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24) - status_filter = request.args.get('Status') - if status_filter: - if status_filter == "pending": - filters.append(Proxy.url.is_(None)) - filters.append(Proxy.deprecated.is_(None)) - filters.append(Proxy.destroyed.is_(None)) - if status_filter == "active": - filters.append(Proxy.url.is_not(None)) - filters.append(Proxy.deprecated.is_(None)) - filters.append(Proxy.destroyed.is_(None)) - if status_filter == "expiring": - filters.append(Proxy.deprecated.is_not(None)) - filters.append(Proxy.destroyed.is_(None)) - if status_filter == "destroyed": - filters.append(Proxy.destroyed > twenty_four_hours_ago) - else: - filters.append((Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago)) - - return list_resources( - Proxy, - lambda proxy: proxy.to_dict(), - filters=filters, - resource_name='MirrorsList', - max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', - ) - - -@api.route('/web/onion', methods=['GET']) -def list_onions() -> ResponseReturnValue: - domain_name_filter = request.args.get('DomainName') - group_id_filter = request.args.get('GroupId') - - filters: List[ListFilter] = [] - - if domain_name_filter: - if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: - abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") - if not DOMAIN_NAME_REGEX.match(domain_name_filter): - abort(400, description="DomainName contains invalid characters.") - filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%")) - - if group_id_filter: - try: - filters.append(Onion.group_id == int(group_id_filter)) - except ValueError: - abort(400, description="GroupId must be a valid integer.") - - return list_resources( - Onion, - lambda onion: onion.to_dict(), - filters=filters, - resource_name='OnionsList', - max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', - ) diff --git a/app/api/util.py b/app/api/util.py new file mode 100644 index 0000000..1069310 --- /dev/null +++ b/app/api/util.py @@ -0,0 +1,101 @@ +import base64 +import binascii +import logging +import re +from typing import Union, Any, Literal, Type, Callable, Dict, List, Optional + +from flask import abort, request, jsonify +from flask.typing import ResponseReturnValue +from sqlalchemy import BinaryExpression, ColumnElement, select + +from app.extensions import db + +logger = logging.getLogger(__name__) +MAX_DOMAIN_NAME_LENGTH = 255 +DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$') +MAX_ALLOWED_ITEMS = 100 +ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]] + + +def validate_max_items(max_items_str: str, max_allowed: int) -> int: + try: + max_items = int(max_items_str) + if max_items <= 0 or max_items > max_allowed: + raise ValueError() + return max_items + except ValueError: + abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.") + + +def validate_marker(marker_str: str) -> int: + try: + marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode() + marker_id = int(marker_decoded) + return marker_id + except (ValueError, binascii.Error): + abort(400, description="Marker must be a valid token.") + + +TlpMarkings = Union[ + Literal["default"], + Literal["clear"], + Literal["green"], + Literal["amber"], + Literal["amber+strict"], + Literal["red"], +] + + +def list_resources( # pylint: disable=too-many-arguments,too-many-locals + model: Type[Any], + serialize_func: Callable[[Any], Dict[str, Any]], + *, + filters: Optional[List[ListFilter]] = None, + order_by: Optional[ColumnElement[Any]] = None, + resource_name: str = 'ResourceList', + max_items_param: str = 'MaxItems', + marker_param: str = 'Marker', + max_allowed_items: int = 100, + protective_marking: TlpMarkings = 'default', +) -> ResponseReturnValue: + try: + marker = request.args.get(marker_param) + max_items = validate_max_items( + request.args.get(max_items_param, default='100'), max_allowed_items) + query = select(model) + + if filters: + query = query.where(*filters) + + if marker: + marker_id = validate_marker(marker) + query = query.where(model.id > marker_id) + + query = query.order_by(order_by or model.id) + query = query.limit(max_items + 1) # Need to know if there's more + + result = db.session.execute(query) + items = result.scalars().all() + items_list = [serialize_func(item) for item in items[:max_items]] + is_truncated = len(items) > max_items + + response = { + resource_name: { + marker_param: marker if marker else None, + max_items_param: str(max_items), + "Quantity": len(items_list), + "Items": items_list, + "IsTruncated": is_truncated, + "ProtectiveMarking": protective_marking, + } + } + + if is_truncated: + last_id = items[max_items - 1].id + next_marker = base64.urlsafe_b64encode(str(last_id).encode()).decode() + response[resource_name]["NextMarker"] = next_marker + + return jsonify(response) + except Exception: # pylint: disable=broad-exception-caught + logger.exception("An unexpected error occurred") + abort(500) diff --git a/app/api/web.py b/app/api/web.py new file mode 100644 index 0000000..74f55d6 --- /dev/null +++ b/app/api/web.py @@ -0,0 +1,275 @@ +import base64 +import logging +from datetime import datetime, timezone, timedelta +from typing import List, TypedDict, Optional + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey, RSAPrivateKey +from cryptography.x509.oid import ExtensionOID +from flask import Blueprint, request, jsonify, abort +from flask.typing import ResponseReturnValue +from sqlalchemy import exc +from tldextract import tldextract + +from app.api.util import ListFilter, MAX_DOMAIN_NAME_LENGTH, DOMAIN_NAME_REGEX, list_resources, MAX_ALLOWED_ITEMS +from app.util.x509 import build_certificate_chain, validate_certificate_chain, load_certificates_from_pem +from app.extensions import db +from app.models.base import Group +from app.models.mirrors import Origin, Proxy +from app.models.onions import Onion +from app.util.onion import onion_hostname + +api_web = Blueprint('web', __name__) + + +@api_web.route('/group', methods=['GET']) +def list_groups() -> ResponseReturnValue: + return list_resources( + Group, + lambda group: group.to_dict(), + resource_name='OriginGroupList', + max_allowed_items=MAX_ALLOWED_ITEMS, + protective_marking='amber', + ) + + +@api_web.route('/origin', methods=['GET']) +def list_origins() -> ResponseReturnValue: + domain_name_filter = request.args.get('DomainName') + group_id_filter = request.args.get('GroupId') + + filters: List[ListFilter] = [] + + if domain_name_filter: + if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: + abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") + if not DOMAIN_NAME_REGEX.match(domain_name_filter): + abort(400, description="DomainName contains invalid characters.") + filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%")) + + if group_id_filter: + try: + filters.append(Origin.group_id == int(group_id_filter)) + except ValueError: + abort(400, description="GroupId must be a valid integer.") + + return list_resources( + Origin, + lambda origin: origin.to_dict(), + filters=filters, + resource_name='OriginsList', + max_allowed_items=MAX_ALLOWED_ITEMS, + protective_marking='amber', + ) + + +@api_web.route('/mirror', methods=['GET']) +def list_mirrors() -> ResponseReturnValue: + filters = [] + + twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24) + status_filter = request.args.get('Status') + if status_filter: + if status_filter == "pending": + filters.append(Proxy.url.is_(None)) + filters.append(Proxy.deprecated.is_(None)) + filters.append(Proxy.destroyed.is_(None)) + if status_filter == "active": + filters.append(Proxy.url.is_not(None)) + filters.append(Proxy.deprecated.is_(None)) + filters.append(Proxy.destroyed.is_(None)) + if status_filter == "expiring": + filters.append(Proxy.deprecated.is_not(None)) + filters.append(Proxy.destroyed.is_(None)) + if status_filter == "destroyed": + filters.append(Proxy.destroyed > twenty_four_hours_ago) + else: + filters.append((Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago)) + + return list_resources( + Proxy, + lambda proxy: proxy.to_dict(), + filters=filters, + resource_name='MirrorsList', + max_allowed_items=MAX_ALLOWED_ITEMS, + protective_marking='amber', + ) + + +@api_web.route('/onion', methods=['GET']) +def list_onions() -> ResponseReturnValue: + domain_name_filter = request.args.get('DomainName') + group_id_filter = request.args.get('GroupId') + + filters: List[ListFilter] = [ + (Onion.destroyed.is_(None)) + ] + + if domain_name_filter: + if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: + abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") + if not DOMAIN_NAME_REGEX.match(domain_name_filter): + abort(400, description="DomainName contains invalid characters.") + filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%")) + + if group_id_filter: + try: + filters.append(Onion.group_id == int(group_id_filter)) + except ValueError: + abort(400, description="GroupId must be a valid integer.") + + return list_resources( + Onion, + lambda onion: onion.to_dict(), + filters=filters, + resource_name='OnionsList', + max_allowed_items=MAX_ALLOWED_ITEMS, + protective_marking='amber', + ) + + +class CreateOnionRequest(TypedDict): + DomainName: str + Description: str + GroupId: int + OnionPrivateKey: str + OnionPublicKey: str + TlsPrivateKey: str + TlsCertificate: str + SkipChainValidation: Optional[bool] + SkipNameValidation: Optional[bool] + + +@api_web.route("/onion", methods=["POST"]) +def create_onion() -> ResponseReturnValue: + data: Optional[CreateOnionRequest] = request.json + if not data: + abort(400) + + errors = [] + + required_fields = ["DomainName", "Description", "OnionPrivateKey", "OnionPublicKey", "GroupId", "TlsPrivateKey", "TlsCertificate"] + for field in required_fields: + if field not in data or not data[field]: # type: ignore[literal-required] + errors.append({"Error": f"{field}_missing", "Message": f"Missing required field: {field}"}) + + if "GroupId" in data: + group = Group.query.get(data["GroupId"]) + if not group: + errors.append({"Error": "group_id_not_found", "Message": "Specified group ID does not exist."}) + + try: + onion_private_key = base64.b64decode(data["OnionPrivateKey"]) + onion_public_key = base64.b64decode(data["OnionPublicKey"]) + except (KeyError, ValueError, TypeError): + errors.append({"Error": "onion_key_invalid", "Message": "Onion keys must be valid Base64-encoded data."}) + + if errors: + return jsonify({"Errors": errors}), 400 + + skip_chain_verification = data.get("SkipChainVerification", False) + skip_name_verification = data.get("SkipNameVerification", False) + + try: + private_key = serialization.load_pem_private_key( + data["TlsPrivateKey"].encode("utf-8"), + password=None, + backend=default_backend() + ) + certificates = list(load_certificates_from_pem(data["TlsCertificate"].encode("utf-8"))) + if not certificates: + errors.append( + {"Error": "tls_public_key_invalid", "Message": "TLS public key must contain at least one certificate."}) + return jsonify({"Errors": errors}), 400 + + chain = build_certificate_chain(certificates) + end_entity_cert = chain[0] + + test_message = b"test" + if not isinstance(private_key, RSAPrivateKey): + errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be an RSA private key."}) + return jsonify({"Errors": errors}), 400 + hash_algorithm = end_entity_cert.signature_hash_algorithm + if hash_algorithm is None: + errors.append({"Error": "tls_public_key_invalid", "Message": "Public key using unsupported algorithm."}) + return jsonify({"Errors": errors}), 400 + signature = private_key.sign( + test_message, + PKCS1v15(), + hash_algorithm + ) + end_entity_public_key = end_entity_cert.public_key() + if isinstance(end_entity_public_key, RSAPublicKey): + end_entity_public_key.verify( + signature, + test_message, + PKCS1v15(), + hash_algorithm + ) + else: + errors.append({"Error": "tls_public_key_invalid", "Message": "Public key using unsupported algorithm."}) + + if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc): + errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."}) + + if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc): + errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."}) + + try: + san_extension = end_entity_cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + san_list = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined] + except x509.ExtensionNotFound: + san_list = [] + errors.append({"Error": "san_not_found", + "Message": "No Subject Alternative Names (SANs) found in the TLS public key."}) + + if not skip_chain_verification: + try: + validate_certificate_chain(chain) + except ValueError as e: + errors.append({"Error": "certificate_chain_invalid", "Message": str(e)}) + + if not skip_name_verification: + if "DomainName" in data: + registered_domain = tldextract.extract(data["DomainName"]).registered_domain + if data["DomainName"] != registered_domain: + errors.append({"Error": "domain_name_not_registered_domain", "Message": "The domain name is invalid, or is not the top-level domain (i.e. a subdomain was entered)."}) + + if "OnionPublicKey" in data: + hostname = f"{onion_hostname(onion_public_key)}.onion" + for name in [hostname, f"*.{hostname}"]: + if name not in san_list: + errors.append({"Error": "certificate_san_missing", "Message": f"{name} missing from certificate SAN list."}) + + if errors: + return jsonify({"Errors": errors}), 400 + + except Exception as e: + errors.append({"Error": "tls_validation_error", "Message": f"TLS key/certificate validation failed: {str(e)}"}) + return jsonify({"Errors": errors}), 400 + + onion = Onion( + domain_name=data["DomainName"].strip(), + description=data["Description"], + onion_private_key=onion_private_key, + onion_public_key=onion_public_key, + tls_private_key=data["TlsPrivateKey"].encode("utf-8"), + tls_public_key=data["TlsCertificate"].encode("utf-8"), + group_id=data["GroupId"], + added=datetime.now(timezone.utc), + updated=datetime.now(timezone.utc), + # cert_expiry_date=end_entity_cert.not_valid_after_utc, TODO: extend schema to accommodate these fields + # cert_sans=",".join(san_list), + ) + + try: + db.session.add(onion) + db.session.commit() + return jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), 201 + except exc.SQLAlchemyError as e: + logging.exception(e) + errors.append({"Error": "database_error", "Message": "Failed to create onion service."}) + return jsonify({"Errors": errors}), 500 diff --git a/app/models/onions.py b/app/models/onions.py index 9fa628e..8287fbb 100644 --- a/app/models/onions.py +++ b/app/models/onions.py @@ -1,5 +1,3 @@ -import base64 -import hashlib from typing import Optional, TypedDict from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -8,6 +6,7 @@ from app.brm.brn import BRN from app.extensions import db from app.models import AbstractConfiguration, AbstractResource from app.models.base import Group +from app.util.onion import onion_hostname class OnionDict(TypedDict): @@ -40,20 +39,7 @@ class Onion(AbstractConfiguration): @property def onion_name(self) -> str: - p = self.onion_public_key[32:] - - h = hashlib.sha3_256() - h.update(b".onion checksum") - h.update(p) - h.update(b"\x03") - checksum = h.digest() - - result = bytearray(p) - result.extend(checksum[0:2]) - result.append(0x03) - - onion = base64.b32encode(result).decode("utf-8").strip("=") - return onion.lower() + return onion_hostname(self.onion_public_key) def to_dict(self) -> OnionDict: return { diff --git a/app/portal/onion.py b/app/portal/onion.py index 02b15d3..f80055f 100644 --- a/app/portal/onion.py +++ b/app/portal/onion.py @@ -1,10 +1,9 @@ from datetime import datetime from typing import Optional -from flask import flash, redirect, url_for, render_template, Response, Blueprint +from flask import flash, redirect, render_template, Response, Blueprint from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm -from flask_wtf.file import FileRequired from sqlalchemy import exc from wtforms import StringField, SelectField, SubmitField from flask_wtf.file import FileField @@ -18,21 +17,6 @@ from app.portal.util import response_404, view_lifecycle bp = Blueprint("onion", __name__) -class NewOnionForm(FlaskForm): # type: ignore - domain_name = StringField('Domain Name', validators=[DataRequired()]) - description = StringField('Description', validators=[DataRequired()]) - onion_private_key = FileField('Onion Private Key', validators=[FileRequired()]) - onion_public_key = FileField('Onion Public Key', - description="The onion hostname will be automatically calculated from the public key.", - validators=[FileRequired()]) - tls_private_key = FileField('TLS Private Key (PEM format)', - description=("If no TLS key and certificate are provided, a self-signed certificate " - "will be generated.")) - tls_public_key = FileField('TLS Certificate (PEM format)') - group = SelectField('Group', validators=[DataRequired()]) - submit = SubmitField('Save Changes') - - class EditOnionForm(FlaskForm): # type: ignore domain_name = StringField('Domain Name', validators=[DataRequired()]) description = StringField('Description', validators=[DataRequired()]) @@ -47,36 +31,7 @@ class EditOnionForm(FlaskForm): # type: ignore @bp.route("/new", methods=['GET', 'POST']) @bp.route("/new/", methods=['GET', 'POST']) def onion_new(group_id: Optional[int] = None) -> ResponseReturnValue: - form = NewOnionForm() - form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] - if form.validate_on_submit(): - onion = Onion() - onion.group_id = form.group.data - onion.domain_name = form.domain_name.data - for at in [ - "onion_private_key", - "onion_public_key", - "tls_private_key", - "tls_public_key" - ]: - if form.__getattribute__(at).data is None: - flash(f"Failed to create new onion. {at} was not provided.", "danger") - return redirect(url_for("portal.onion.onion_list")) - onion.__setattr__(at, form.__getattribute__(at).data.read()) - onion.description = form.description.data - onion.created = datetime.utcnow() - onion.updated = datetime.utcnow() - try: - db.session.add(onion) - db.session.commit() - flash(f"Created new onion {onion.onion_name}.", "success") - return redirect(url_for("portal.onion.onion_edit", onion_id=onion.id)) - except exc.SQLAlchemyError: - flash("Failed to create new onion.", "danger") - return redirect(url_for("portal.onion.onion_list")) - if group_id: - form.group.data = group_id - return render_template("new.html.j2", section="onion", form=form) + return redirect("/ui/web/onions/new") @bp.route('/edit/', methods=['GET', 'POST']) @@ -116,18 +71,12 @@ def onion_edit(onion_id: int) -> ResponseReturnValue: @bp.route("/list") def onion_list() -> ResponseReturnValue: - onions = Onion.query.order_by(Onion.domain_name).all() - return render_template("list.html.j2", - section="onion", - title="Onion Services", - item="onion service", - new_link=url_for("portal.onion.onion_new"), - items=onions) + return redirect("/ui/web/onions") @bp.route("/destroy/", methods=['GET', 'POST']) -def onion_destroy(onion_id: int) -> ResponseReturnValue: - onion = Onion.query.filter(Onion.id == onion_id, Onion.destroyed.is_(None)).first() +def onion_destroy(onion_id: str) -> ResponseReturnValue: + onion = Onion.query.filter(Onion.id == int(onion_id), Onion.destroyed.is_(None)).first() if onion is None: return response_404("The requested onion service could not be found.") return view_lifecycle( diff --git a/app/util/__init__.py b/app/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/util/onion.py b/app/util/onion.py new file mode 100644 index 0000000..a7876d2 --- /dev/null +++ b/app/util/onion.py @@ -0,0 +1,19 @@ +import base64 +import hashlib + + +def onion_hostname(onion_public_key: bytes) -> str: + p = onion_public_key[32:] + + h = hashlib.sha3_256() + h.update(b".onion checksum") + h.update(p) + h.update(b"\x03") + checksum = h.digest() + + result = bytearray(p) + result.extend(checksum[0:2]) + result.append(0x03) + + onion = base64.b32encode(result).decode("utf-8").strip("=") + return onion.lower() diff --git a/app/util/x509.py b/app/util/x509.py new file mode 100644 index 0000000..6422464 --- /dev/null +++ b/app/util/x509.py @@ -0,0 +1,65 @@ +import ssl + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + + +def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]: + certificates = [] + for pem_block in pem_data.split(b"-----END CERTIFICATE-----"): + pem_block = pem_block.strip() + if pem_block: + pem_block += b"-----END CERTIFICATE-----" + certificate = x509.load_pem_x509_certificate(pem_block, default_backend()) + certificates.append(certificate) + return certificates + + +def build_certificate_chain(certificates: list[x509.Certificate]) -> list[x509.Certificate]: + if len(certificates) == 1: + return certificates + chain = [] + cert_map = {cert.subject.rfc4514_string(): cert for cert in certificates} + end_entity = next( + (cert for cert in certificates if cert.subject.rfc4514_string() not in cert_map), + None + ) + if not end_entity: + raise ValueError("Cannot identify the end-entity certificate.") + chain.append(end_entity) + current_cert = end_entity + while current_cert.issuer.rfc4514_string() in cert_map: + next_cert = cert_map[current_cert.issuer.rfc4514_string()] + chain.append(next_cert) + current_cert = next_cert + return chain + + +def validate_certificate_chain(chain: list[x509.Certificate]) -> bool: + """Validate a certificate chain against the system's root CA store.""" + context = ssl.create_default_context() + store = context.get_ca_certs(binary_form=True) + trusted_certificates = [x509.load_der_x509_certificate(cert) for cert in store] + + for i in range(len(chain) - 1): + next_public_key = chain[i + 1].public_key() + if not (isinstance(next_public_key, RSAPublicKey)): + raise ValueError(f"Certificate using unsupported algorithm: {type(next_public_key)}") + hash_algorithm = chain[i].signature_hash_algorithm + if hash_algorithm is None: + raise ValueError("Certificate missing hash algorithm") + next_public_key.verify( + chain[i].signature, + chain[i].tbs_certificate_bytes, + PKCS1v15(), + hash_algorithm + ) + + end_cert = chain[-1] + if not any( + end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates + ): + raise ValueError("Certificate chain does not terminate at a trusted root CA.") + return True diff --git a/requirements.txt b/requirements.txt index e3e7394..473ae09 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ beautifulsoup4 bootstrap-flask boto3 bs4 +cryptography flask flask-migrate flask-wtf diff --git a/scripts/generate_test_onion_tls.py b/scripts/generate_test_onion_tls.py new file mode 100644 index 0000000..54608ef --- /dev/null +++ b/scripts/generate_test_onion_tls.py @@ -0,0 +1,160 @@ +import os +import shutil +import subprocess +import base64 +import json +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.hashes import SHA256 +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.backends import default_backend +from datetime import datetime, timedelta, timezone + + +def generate_onion_keys_with_mkp224o(folder_name: str, label: str): + """ + Generate Tor-compatible Onion service keys using mkp224o. + The keys are saved in the specified folder, and the Onion address is returned. + """ + os.makedirs(folder_name, exist_ok=True) + + # Call mkp224o to generate a single Onion service key + process = subprocess.run( + ["mkp224o", "-n", "1", "-d", folder_name, label], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + try: + process.check_returncode() + except subprocess.CalledProcessError: + print("STDOUT:", process.stdout.decode()) + print("STDERR:", process.stderr.decode()) + raise + + # Find the generated Onion address + for filename in os.listdir(folder_name): + if filename.endswith(".onion"): + onion_address = filename + onion_dir = os.path.join(folder_name, filename) + + # Move files to parent directory + for key_file in ["hs_ed25519_secret_key", "hs_ed25519_public_key", "hostname"]: + src = os.path.join(onion_dir, key_file) + dst = os.path.join(folder_name, key_file) + if os.path.exists(src): + shutil.move(src, dst) + + # Remove the now-empty directory + os.rmdir(onion_dir) + + return onion_address + + raise RuntimeError("Failed to generate Onion keys using mkp224o") + + +def generate_self_signed_tls_certificate(folder_name: str, onion_address: str, valid_from: datetime, valid_to: datetime, dns_names=None): + """ + Generate a self-signed TLS certificate for the Onion address and save it in the specified folder. + """ + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend(), + ) + subject = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test State"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test City"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Org"), + x509.NameAttribute(NameOID.COMMON_NAME, onion_address), + ] + ) + if dns_names is None: + dns_names = [onion_address, f"*.{onion_address}"] + + san_extension = x509.SubjectAlternativeName([x509.DNSName(name) for name in dns_names]) + + certificate = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(valid_from) + .not_valid_after(valid_to) + .add_extension(san_extension, critical=False) + .sign(private_key, SHA256(), default_backend()) + ) + + private_key_path = os.path.join(folder_name, "tls_private_key.pem") + certificate_path = os.path.join(folder_name, "tls_certificate.pem") + + with open(private_key_path, "wb") as f: + f.write( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + with open(certificate_path, "wb") as f: + f.write(certificate.public_bytes(serialization.Encoding.PEM)) + + return private_key_path, certificate_path + + +def generate_rest_payload(parent_folder: str, folder_name: str, onion_address: str): + """ + Generate REST payload for a specific Onion service and append it to a shared .rest file. + """ + rest_file_path = os.path.join(parent_folder, "new_onion.rest") + + with open(os.path.join(folder_name, "hs_ed25519_secret_key"), "rb") as f: + onion_private_key = base64.b64encode(f.read()).decode("utf-8") + with open(os.path.join(folder_name, "hs_ed25519_public_key"), "rb") as f: + onion_public_key = base64.b64encode(f.read()).decode("utf-8") + with open(os.path.join(folder_name, "tls_private_key.pem"), "r") as f: + tls_private_key = f.read() + with open(os.path.join(folder_name, "tls_certificate.pem"), "r") as f: + tls_public_key = f.read() + + payload = { + "DomainName": onion_address, + "Description": f"Generated Onion Service for {folder_name}", + "OnionPrivateKey": onion_private_key, + "OnionPublicKey": onion_public_key, + "TlsPrivateKey": tls_private_key, + "TlsCertificate": tls_public_key, + "SkipChainVerification": True, + "GroupId": 1, + } + + with open(rest_file_path, "a") as f: + f.write(f"### Create Onion Service ({folder_name})\n") + f.write("POST http://localhost:5000/api/web/onion\n") + f.write("Content-Type: application/json\n\n") + json.dump(payload, f, indent=4) + f.write("\n\n") + + +if __name__ == "__main__": + parent_folder = "." + scenarios = [ + ("self_signed_onion_service", datetime.now(timezone.utc), datetime.now(timezone.utc) + timedelta(days=365), None), + ("expired_onion_service", datetime.now(timezone.utc) - timedelta(days=730), datetime.now(timezone.utc) - timedelta(days=365), None), + ("future_onion_service", datetime.now(timezone.utc) + timedelta(days=365), datetime.now(timezone.utc) + timedelta(days=730), None), + ("wrong_name_onion_service", datetime.now(timezone.utc), datetime.now(timezone.utc) + timedelta(days=365), ["wrong-name.example.com"]), + ] + + if os.path.exists("new_onion.rest"): + os.remove("new_onion.rest") + + for folder_name, valid_from, valid_to, dns_names in scenarios: + print(f"Generating {folder_name}...") + onion_address = generate_onion_keys_with_mkp224o(folder_name, "test") + generate_self_signed_tls_certificate(folder_name, onion_address, valid_from, valid_to, dns_names) + generate_rest_payload(parent_folder, folder_name, onion_address) + + print("All Onion services and REST requests generated successfully.")