"""Read-only JSON list endpoints (groups, origins, mirrors, onions).

Every endpoint delegates to :func:`list_resources`, which applies optional
filters and marker-based pagination and wraps the result in a common
response envelope.
"""
import base64
import binascii
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Callable, Any, Type, Dict, Union, Literal

from flask import Blueprint, request, jsonify, abort
from flask.typing import ResponseReturnValue
from sqlalchemy import select, BinaryExpression, ColumnElement
from werkzeug.exceptions import HTTPException

from app.extensions import db
from app.models.base import Group
from app.models.mirrors import Origin, Proxy
from app.models.onions import Onion

api = Blueprint('api', __name__)
logger = logging.getLogger(__name__)

MAX_DOMAIN_NAME_LENGTH = 255
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$')
MAX_ALLOWED_ITEMS = 100
ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]]


def _json_error(status_code: int, error: str, message: Optional[str]) -> ResponseReturnValue:
    """Build the JSON error body shared by all error handlers of this blueprint."""
    response = jsonify({'error': error, 'message': message})
    response.status_code = status_code
    return response


@api.errorhandler(400)
def bad_request(error: HTTPException) -> ResponseReturnValue:
    """Render a 400 as JSON, exposing the abort description as the message."""
    return _json_error(400, 'Bad Request', error.description)


@api.errorhandler(401)
def unauthorized(error: HTTPException) -> ResponseReturnValue:
    """Render a 401 as JSON, exposing the abort description as the message."""
    return _json_error(401, 'Unauthorized', error.description)


@api.errorhandler(404)
def not_found(_: HTTPException) -> ResponseReturnValue:
    """Render a 404 as JSON with a fixed message."""
    return _json_error(404, 'Not found', 'Resource could not be found.')


@api.errorhandler(500)
def internal_server_error(_: HTTPException) -> ResponseReturnValue:
    """Render a 500 as JSON with a fixed message (no internal details)."""
    return _json_error(500, 'Internal Server Error', 'An unexpected error occurred.')


def validate_max_items(max_items_str: str, max_allowed: int) -> int:
    """Parse a page size and check that it lies in ``1..max_allowed``.

    Args:
        max_items_str: Raw value taken from the query string.
        max_allowed: Largest page size that is accepted.

    Returns:
        The parsed page size.

    Raises:
        werkzeug.exceptions.BadRequest: via ``abort(400)`` when the value is
            not an integer or is outside the permitted range.
    """
    try:
        max_items: Optional[int] = int(max_items_str)
    except ValueError:
        max_items = None
    if max_items is None or not 0 < max_items <= max_allowed:
        abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.")
    return max_items


def validate_marker(marker_str: str) -> int:
    """Decode a pagination marker back into the integer id it encodes.

    Markers are the URL-safe base64 encoding of the decimal id of the last
    item on the previous page (see ``NextMarker`` in :func:`list_resources`).

    Raises:
        werkzeug.exceptions.BadRequest: via ``abort(400)`` when the marker is
            not valid base64, not valid text, or does not contain an integer.
    """
    try:
        marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode()
        marker_id = int(marker_decoded)
        return marker_id
    except (ValueError, binascii.Error):
        # UnicodeDecodeError is a ValueError subclass, so undecodable bytes
        # are rejected here as well.
        abort(400, description="Marker must be a valid token.")


# Values accepted for the "ProtectiveMarking" field of a list response.
TlpMarkings = Union[
    Literal["default"],
    Literal["clear"],
    Literal["green"],
    Literal["amber"],
    Literal["amber+strict"],
    Literal["red"],
]


def list_resources(  # pylint: disable=too-many-arguments,too-many-locals
        model: Type[Any],
        serialize_func: Callable[[Any], Dict[str, Any]],
        *,
        filters: Optional[List[ListFilter]] = None,
        order_by: Optional[ColumnElement[Any]] = None,
        resource_name: str = 'ResourceList',
        max_items_param: str = 'MaxItems',
        marker_param: str = 'Marker',
        max_allowed_items: int = 100,
        protective_marking: TlpMarkings = 'default',
) -> ResponseReturnValue:
    """Return one page of ``model`` rows as a JSON list envelope.

    The page size and the pagination marker are read from the current
    request's query string. One row more than the page size is fetched so
    that ``IsTruncated`` can be reported; when it is true a ``NextMarker``
    encoding the id of the last returned row is included.

    Args:
        model: Mapped class to select from; must expose an ``id`` column.
        serialize_func: Converts one row into a JSON-serialisable dict.
        filters: Optional WHERE conditions, all of which are applied.
        order_by: Optional ordering column; defaults to ``model.id``.
            NOTE(review): the marker always filters on ``model.id``, so a
            different ordering will not paginate consistently.
        resource_name: Top-level key of the response envelope.
        max_items_param: Name of the page-size query parameter.
        marker_param: Name of the marker query parameter.
        max_allowed_items: Upper bound accepted for the page size.
        protective_marking: Value echoed back as ``ProtectiveMarking``.

    Raises:
        werkzeug.exceptions.HTTPException: 400 for an invalid page size or
            marker; 500 (after logging) for any other failure.
    """
    try:
        marker = request.args.get(marker_param)
        # Keep the historical default of 100, but never default to a value
        # that the validation below would reject.
        default_max_items = str(min(100, max_allowed_items))
        max_items = validate_max_items(
            request.args.get(max_items_param, default=default_max_items),
            max_allowed_items,
        )

        query = select(model)
        if filters:
            query = query.where(*filters)
        if marker:
            marker_id = validate_marker(marker)
            query = query.where(model.id > marker_id)
        # Compare with None explicitly: truth-testing a SQLAlchemy column
        # element raises TypeError.
        query = query.order_by(model.id if order_by is None else order_by)
        query = query.limit(max_items + 1)  # Need to know if there's more

        result = db.session.execute(query)
        items = result.scalars().all()
        items_list = [serialize_func(item) for item in items[:max_items]]
        is_truncated = len(items) > max_items

        response = {
            resource_name: {
                marker_param: marker if marker else None,
                max_items_param: str(max_items),
                "Quantity": len(items_list),
                "Items": items_list,
                "IsTruncated": is_truncated,
                "ProtectiveMarking": protective_marking,
            }
        }
        if is_truncated:
            last_id = items[max_items - 1].id
            next_marker = base64.urlsafe_b64encode(str(last_id).encode()).decode()
            response[resource_name]["NextMarker"] = next_marker
        return jsonify(response)
    except HTTPException:
        # abort(400) from the validators above must reach the client as-is
        # instead of being logged and rewritten to a 500 below.
        raise
    except Exception:  # pylint: disable=broad-exception-caught
        logger.exception("An unexpected error occurred")
        abort(500)


def _domain_and_group_filters(model: Type[Any]) -> List[ListFilter]:
    """Build filters from the ``DomainName`` and ``GroupId`` query parameters.

    ``model`` must expose ``domain_name`` and ``group_id`` columns. Aborts
    with 400 when either parameter is present but invalid.
    """
    domain_name_filter = request.args.get('DomainName')
    group_id_filter = request.args.get('GroupId')
    filters: List[ListFilter] = []
    if domain_name_filter:
        if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
            abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.")
        # fullmatch rather than match: "$" also matches just before a
        # trailing newline, which would let "name\n" through.
        if not DOMAIN_NAME_REGEX.fullmatch(domain_name_filter):
            abort(400, description="DomainName contains invalid characters.")
        filters.append(model.domain_name.ilike(f"%{domain_name_filter}%"))
    if group_id_filter:
        try:
            filters.append(model.group_id == int(group_id_filter))
        except ValueError:
            abort(400, description="GroupId must be a valid integer.")
    return filters


@api.route('/web/group', methods=['GET'])
def list_groups() -> ResponseReturnValue:
    """List groups, paginated."""
    return list_resources(
        Group,
        lambda group: group.to_dict(),
        resource_name='OriginGroupList',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        protective_marking='amber',
    )


@api.route('/web/origin', methods=['GET'])
def list_origins() -> ResponseReturnValue:
    """List origins, optionally filtered by ``DomainName`` and ``GroupId``."""
    return list_resources(
        Origin,
        lambda origin: origin.to_dict(),
        filters=_domain_and_group_filters(Origin),
        resource_name='OriginsList',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        protective_marking='amber',
    )


@api.route('/web/mirror', methods=['GET'])
def list_mirrors() -> ResponseReturnValue:
    """List mirrors (proxies), optionally restricted by ``Status``.

    Without ``Status`` the list contains proxies that are not destroyed or
    were destroyed within the last 24 hours. Recognised values are
    ``pending``, ``active``, ``expiring`` and ``destroyed``.
    """
    filters: List[ListFilter] = []
    twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
    status_filter = request.args.get('Status')
    if status_filter:
        if status_filter == "pending":
            filters.append(Proxy.url.is_(None))
            filters.append(Proxy.deprecated.is_(None))
            filters.append(Proxy.destroyed.is_(None))
        elif status_filter == "active":
            filters.append(Proxy.url.is_not(None))
            filters.append(Proxy.deprecated.is_(None))
            filters.append(Proxy.destroyed.is_(None))
        elif status_filter == "expiring":
            filters.append(Proxy.deprecated.is_not(None))
            filters.append(Proxy.destroyed.is_(None))
        elif status_filter == "destroyed":
            filters.append(Proxy.destroyed > twenty_four_hours_ago)
        # NOTE(review): any other Status value adds no filter at all, so the
        # response then also includes proxies destroyed more than 24 hours
        # ago. Confirm whether this should be a 400 instead.
    else:
        filters.append((Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago))
    return list_resources(
        Proxy,
        lambda proxy: proxy.to_dict(),
        filters=filters,
        resource_name='MirrorsList',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        protective_marking='amber',
    )


@api.route('/web/onion', methods=['GET'])
def list_onions() -> ResponseReturnValue:
    """List onions, optionally filtered by ``DomainName`` and ``GroupId``."""
    return list_resources(
        Onion,
        lambda onion: onion.to_dict(),
        filters=_domain_and_group_filters(Onion),
        resource_name='OnionsList',
        max_allowed_items=MAX_ALLOWED_ITEMS,
        protective_marking='amber',
    )