From ae905c6d806de7f9cf38e72b919d4c9926805d9a Mon Sep 17 00:00:00 2001 From: irl Date: Sun, 10 Nov 2024 13:38:51 +0000 Subject: [PATCH] feat: initial api implementation --- app/__init__.py | 18 ++-- app/api/__init__.py | 192 ++++++++++++++++++++++++++++++++++++++++++ app/models/base.py | 9 ++ app/models/mirrors.py | 29 ++++++- 4 files changed, 239 insertions(+), 9 deletions(-) create mode 100644 app/api/__init__.py diff --git a/app/__init__.py b/app/__init__.py index 2b650fb..7ad99aa 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,19 +1,20 @@ import os +import sys from typing import Iterator +import yaml from flask import Flask, redirect, url_for, send_from_directory from flask.typing import ResponseReturnValue +from prometheus_client import make_wsgi_app, REGISTRY, Metric from prometheus_client.metrics_core import GaugeMetricFamily, CounterMetricFamily from prometheus_client.registry import Collector from sqlalchemy import text from werkzeug.middleware.dispatcher import DispatcherMiddleware -from prometheus_client import make_wsgi_app, REGISTRY, Metric -import yaml -import sys +from app.api import api +from app.extensions import bootstrap from app.extensions import db from app.extensions import migrate -from app.extensions import bootstrap from app.models.automation import Automation, AutomationState from app.portal import portal from app.portal.report import report @@ -26,10 +27,11 @@ app.wsgi_app = DispatcherMiddleware(app.wsgi_app, { # type: ignore[method-assig '/metrics': make_wsgi_app() }) -db.init_app(app) # type: ignore[no-untyped-call] +db.init_app(app) migrate.init_app(app, db, render_as_batch=True) bootstrap.init_app(app) +app.register_blueprint(api, url_prefix="/api") app.register_blueprint(portal, url_prefix="/portal") app.register_blueprint(tfstate, url_prefix="/tfstate") app.register_blueprint(report, url_prefix="/report") @@ -110,17 +112,19 @@ if not_migrating() and 'DISABLE_METRICS' not in os.environ: @app.route('/ui') -def redirect_ui(): +def redirect_ui() -> ResponseReturnValue: return redirect("/ui/") + @app.route('/ui/', defaults={'path': ''}) @app.route('/ui/') -def serve_ui(path): +def serve_ui(path: str) -> ResponseReturnValue: if path != "" and os.path.exists("app/static/ui/" + path): return send_from_directory('static/ui', path) else: return send_from_directory('static/ui', 'index.html') + @app.route('/') def index() -> ResponseReturnValue: # TODO: update to point at new UI when ready diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..7f6c3a4 --- /dev/null +++ b/app/api/__init__.py @@ -0,0 +1,192 @@ +import base64 +import binascii +import logging +import re + +from flask import Blueprint, request, jsonify, abort +from sqlalchemy import select + +from app.extensions import db +from app.models.base import Group +from app.models.mirrors import Origin, Proxy + +api = Blueprint('api', __name__) +logger = logging.getLogger(__name__) + +MAX_DOMAIN_NAME_LENGTH = 255 +DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$') +MAX_ALLOWED_ITEMS = 100 + + +@api.errorhandler(400) +def bad_request(error): + response = jsonify({'error': 'Bad Request', 'message': error.description}) + response.status_code = 400 + return response + + +@api.errorhandler(401) +def unauthorized(error): + response = jsonify({'error': 'Unauthorized', 'message': error.description}) + response.status_code = 401 + return response + + +@api.errorhandler(404) +def not_found(error): + response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'}) + response.status_code = 404 + return response + + +@api.errorhandler(500) +def internal_server_error(error): + response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'}) + response.status_code = 500 + return response + + +@api.teardown_app_request +def shutdown_session(exception=None): + db.session.remove() + + +def validate_max_items(max_items_str, max_allowed): + try: + max_items = int(max_items_str) + if max_items <= 0 or max_items > max_allowed: + raise ValueError() + return max_items + except ValueError: + abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.") + + +def validate_marker(marker_str): + try: + marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode() + marker_id = int(marker_decoded) + return marker_id + except (ValueError, binascii.Error): + abort(400, description="Marker must be a valid token.") + + +def list_resources( + model, + filters=None, + order_by=None, + serialize_func=None, + resource_name='ResourceList', + max_items_param='MaxItems', + marker_param='Marker', + max_allowed_items=100 +): + try: + marker = request.args.get(marker_param) + max_items = validate_max_items( + request.args.get(max_items_param, default='100'), max_allowed_items) + query = select(model) + + if filters: + query = query.where(*filters) + + if marker: + marker_id = validate_marker(marker) + query = query.where(model.id > marker_id) + + query = query.order_by(order_by or model.id) + query = query.limit(max_items + 1) # Need to know if there's more + + result = db.session.execute(query) + items = result.scalars().all() + items_list = [serialize_func(item) for item in items[:max_items]] + is_truncated = len(items) > max_items + + response = { + resource_name: { + marker_param: marker if marker else None, + max_items_param: str(max_items), + "Quantity": len(items_list), + "Items": items_list, + "IsTruncated": is_truncated, + } + } + + if is_truncated: + last_id = items[max_items - 1].id + next_marker = base64.urlsafe_b64encode(str(last_id).encode()).decode() + response[resource_name]["NextMarker"] = next_marker + + return jsonify(response) + except Exception: + logger.exception("An unexpected error occurred") + abort(500) + + +@api.route('/web/group', methods=['GET']) +def list_groups(): + return list_resources( + model=Group, + serialize_func=lambda group: group.to_dict(), + resource_name='OriginGroupList', + max_allowed_items=MAX_ALLOWED_ITEMS + ) + + +@api.route('/web/origin', methods=['GET']) +def list_origins(): + domain_name_filter = request.args.get('DomainName') + group_id_filter = request.args.get('GroupId') + + filters = [] + + if domain_name_filter: + if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: + abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") + if not DOMAIN_NAME_REGEX.match(domain_name_filter): + abort(400, description="DomainName contains invalid characters.") + filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%")) + + if group_id_filter: + try: + group_id_filter = int(group_id_filter) + filters.append(Origin.group_id == group_id_filter) + except ValueError: + abort(400, description="GroupId must be a valid integer.") + + return list_resources( + model=Origin, + filters=filters, + serialize_func=lambda origin: origin.to_dict(), + resource_name='OriginsList', + max_allowed_items=MAX_ALLOWED_ITEMS + ) + + +@api.route('/web/mirror', methods=['GET']) +def list_mirrors(): + status_filter = request.args.get('Status') + + filters = [] + + if status_filter: + if status_filter == "pending": + filters.append(Proxy.url.is_(None)) + filters.append(Proxy.deprecated.is_(None)) + filters.append(Proxy.destroyed.is_(None)) + if status_filter == "active": + filters.append(Proxy.url.is_not(None)) + filters.append(Proxy.deprecated.is_(None)) + filters.append(Proxy.destroyed.is_(None)) + if status_filter == "expiring": + filters.append(Proxy.deprecated.is_not(None)) + filters.append(Proxy.destroyed.is_(None)) + if status_filter == "destroyed": + filters.append(Proxy.destroyed.is_not(None)) + + return list_resources( + model=Proxy, + filters=filters, + serialize_func=lambda proxy: proxy.to_dict(), + resource_name='MirrorsList', + max_allowed_items=MAX_ALLOWED_ITEMS + ) diff --git a/app/models/base.py b/app/models/base.py index d230ac9..ea9ac10 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -33,6 +33,15 @@ class Group(AbstractConfiguration): resource_id=str(self.id) ) + def to_dict(self): + active_origins = [o for o in self.origins if o.destroyed is None] + return { + "Id": self.id, + "GroupName": self.group_name, + "Description": self.description, + "ActiveOriginCount": len(active_origins), + } + class Pool(AbstractConfiguration): pool_name = db.Column(db.String(80), unique=True, nullable=False) diff --git a/app/models/mirrors.py b/app/models/mirrors.py index 563de53..2bf1439 100644 --- a/app/models/mirrors.py +++ b/app/models/mirrors.py @@ -1,10 +1,10 @@ from __future__ import annotations import json -import tldextract from datetime import datetime, timedelta from typing import Optional, List, Union, Any, Dict +import tldextract from sqlalchemy.orm import Mapped, mapped_column from tldextract import extract from werkzeug.datastructures import FileStorage @@ -96,9 +96,19 @@ class Origin(AbstractConfiguration): frequency_factor += 1 risk_levels: Dict[str, int] = {} for country in self.countries: - risk_levels[country.country_code.upper()] = int(max(1, min(10, frequency_factor * recency_factor))) + country.risk_level + risk_levels[country.country_code.upper()] = int( + max(1, min(10, frequency_factor * recency_factor))) + country.risk_level return risk_levels + def to_dict(self): + return { + "Id": self.id, + "Description": self.description, + "DomainName": self.domain_name, + "RiskLevel": self.risk_level, + "RiskLevelOverride": self.risk_level_override, + } + class Country(AbstractConfiguration): @property @@ -268,6 +278,21 @@ class Proxy(AbstractResource): "origin_id", "provider", "psg", "slug", "terraform_updated", "url" ] + def to_dict(self): + status = "active" + if self.url is None: + status = "pending" + if self.deprecated is not None: + status = "expiring" + if self.destroyed is not None: + status = "destroyed" + return { + "Id": self.id, + "OriginDomain": self.origin.domain_name, + "MirrorDomain": self.url.replace("https://", "") if self.url else None, + "Status": status, + } + class SmartProxy(AbstractResource): group_id = mapped_column(db.Integer(), db.ForeignKey("group.id"), nullable=False)