feat: initial api implementation

This commit is contained in:
Iain Learmonth 2024-11-10 13:38:51 +00:00
parent a482d5bba8
commit ae905c6d80
4 changed files with 239 additions and 9 deletions

View file

@ -1,19 +1,20 @@
import os import os
import sys
from typing import Iterator from typing import Iterator
import yaml
from flask import Flask, redirect, url_for, send_from_directory from flask import Flask, redirect, url_for, send_from_directory
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from prometheus_client import make_wsgi_app, REGISTRY, Metric
from prometheus_client.metrics_core import GaugeMetricFamily, CounterMetricFamily from prometheus_client.metrics_core import GaugeMetricFamily, CounterMetricFamily
from prometheus_client.registry import Collector from prometheus_client.registry import Collector
from sqlalchemy import text from sqlalchemy import text
from werkzeug.middleware.dispatcher import DispatcherMiddleware from werkzeug.middleware.dispatcher import DispatcherMiddleware
from prometheus_client import make_wsgi_app, REGISTRY, Metric
import yaml
import sys
from app.api import api
from app.extensions import bootstrap
from app.extensions import db from app.extensions import db
from app.extensions import migrate from app.extensions import migrate
from app.extensions import bootstrap
from app.models.automation import Automation, AutomationState from app.models.automation import Automation, AutomationState
from app.portal import portal from app.portal import portal
from app.portal.report import report from app.portal.report import report
@ -26,10 +27,11 @@ app.wsgi_app = DispatcherMiddleware(app.wsgi_app, { # type: ignore[method-assig
'/metrics': make_wsgi_app() '/metrics': make_wsgi_app()
}) })
db.init_app(app) # type: ignore[no-untyped-call] db.init_app(app)
migrate.init_app(app, db, render_as_batch=True) migrate.init_app(app, db, render_as_batch=True)
bootstrap.init_app(app) bootstrap.init_app(app)
app.register_blueprint(api, url_prefix="/api")
app.register_blueprint(portal, url_prefix="/portal") app.register_blueprint(portal, url_prefix="/portal")
app.register_blueprint(tfstate, url_prefix="/tfstate") app.register_blueprint(tfstate, url_prefix="/tfstate")
app.register_blueprint(report, url_prefix="/report") app.register_blueprint(report, url_prefix="/report")
@ -110,17 +112,19 @@ if not_migrating() and 'DISABLE_METRICS' not in os.environ:
@app.route('/ui') @app.route('/ui')
def redirect_ui(): def redirect_ui() -> ResponseReturnValue:
return redirect("/ui/") return redirect("/ui/")
@app.route('/ui/', defaults={'path': ''}) @app.route('/ui/', defaults={'path': ''})
@app.route('/ui/<path:path>') @app.route('/ui/<path:path>')
def serve_ui(path): def serve_ui(path: str) -> ResponseReturnValue:
if path != "" and os.path.exists("app/static/ui/" + path): if path != "" and os.path.exists("app/static/ui/" + path):
return send_from_directory('static/ui', path) return send_from_directory('static/ui', path)
else: else:
return send_from_directory('static/ui', 'index.html') return send_from_directory('static/ui', 'index.html')
@app.route('/') @app.route('/')
def index() -> ResponseReturnValue: def index() -> ResponseReturnValue:
# TODO: update to point at new UI when ready # TODO: update to point at new UI when ready

192
app/api/__init__.py Normal file
View file

@ -0,0 +1,192 @@
import base64
import binascii
import logging
import re
from flask import Blueprint, request, jsonify, abort
from sqlalchemy import select
from app.extensions import db
from app.models.base import Group
from app.models.mirrors import Origin, Proxy
api = Blueprint('api', __name__)
logger = logging.getLogger(__name__)
MAX_DOMAIN_NAME_LENGTH = 255
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$')
MAX_ALLOWED_ITEMS = 100
@api.errorhandler(400)
def bad_request(error):
response = jsonify({'error': 'Bad Request', 'message': error.description})
response.status_code = 400
return response
@api.errorhandler(401)
def unauthorized(error):
response = jsonify({'error': 'Unauthorized', 'message': error.description})
response.status_code = 401
return response
@api.errorhandler(404)
def not_found(error):
response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'})
response.status_code = 404
return response
@api.errorhandler(500)
def internal_server_error(error):
response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'})
response.status_code = 500
return response
@api.teardown_app_request
def shutdown_session(exception=None):
db.session.remove()
def validate_max_items(max_items_str, max_allowed):
try:
max_items = int(max_items_str)
if max_items <= 0 or max_items > max_allowed:
raise ValueError()
return max_items
except ValueError:
abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.")
def validate_marker(marker_str):
try:
marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode()
marker_id = int(marker_decoded)
return marker_id
except (ValueError, binascii.Error):
abort(400, description="Marker must be a valid token.")
def list_resources(
model,
filters=None,
order_by=None,
serialize_func=None,
resource_name='ResourceList',
max_items_param='MaxItems',
marker_param='Marker',
max_allowed_items=100
):
try:
marker = request.args.get(marker_param)
max_items = validate_max_items(
request.args.get(max_items_param, default='100'), max_allowed_items)
query = select(model)
if filters:
query = query.where(*filters)
if marker:
marker_id = validate_marker(marker)
query = query.where(model.id > marker_id)
query = query.order_by(order_by or model.id)
query = query.limit(max_items + 1) # Need to know if there's more
result = db.session.execute(query)
items = result.scalars().all()
items_list = [serialize_func(item) for item in items[:max_items]]
is_truncated = len(items) > max_items
response = {
resource_name: {
marker_param: marker if marker else None,
max_items_param: str(max_items),
"Quantity": len(items_list),
"Items": items_list,
"IsTruncated": is_truncated,
}
}
if is_truncated:
last_id = items[max_items - 1].id
next_marker = base64.urlsafe_b64encode(str(last_id).encode()).decode()
response[resource_name]["NextMarker"] = next_marker
return jsonify(response)
except Exception:
logger.exception("An unexpected error occurred")
abort(500)
@api.route('/web/group', methods=['GET'])
def list_groups():
return list_resources(
model=Group,
serialize_func=lambda group: group.to_dict(),
resource_name='OriginGroupList',
max_allowed_items=MAX_ALLOWED_ITEMS
)
@api.route('/web/origin', methods=['GET'])
def list_origins():
domain_name_filter = request.args.get('DomainName')
group_id_filter = request.args.get('GroupId')
filters = []
if domain_name_filter:
if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.")
if not DOMAIN_NAME_REGEX.match(domain_name_filter):
abort(400, description="DomainName contains invalid characters.")
filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%"))
if group_id_filter:
try:
group_id_filter = int(group_id_filter)
filters.append(Origin.group_id == group_id_filter)
except ValueError:
abort(400, description="GroupId must be a valid integer.")
return list_resources(
model=Origin,
filters=filters,
serialize_func=lambda origin: origin.to_dict(),
resource_name='OriginsList',
max_allowed_items=MAX_ALLOWED_ITEMS
)
@api.route('/web/mirror', methods=['GET'])
def list_mirrors():
status_filter = request.args.get('Status')
filters = []
if status_filter:
if status_filter == "pending":
filters.append(Proxy.url.is_(None))
filters.append(Proxy.deprecated.is_(None))
filters.append(Proxy.destroyed.is_(None))
if status_filter == "active":
filters.append(Proxy.url.is_not(None))
filters.append(Proxy.deprecated.is_(None))
filters.append(Proxy.destroyed.is_(None))
if status_filter == "expiring":
filters.append(Proxy.deprecated.is_not(None))
filters.append(Proxy.destroyed.is_(None))
if status_filter == "destroyed":
filters.append(Proxy.destroyed.is_not(None))
return list_resources(
model=Proxy,
filters=filters,
serialize_func=lambda proxy: proxy.to_dict(),
resource_name='MirrorsList',
max_allowed_items=MAX_ALLOWED_ITEMS
)

View file

@ -33,6 +33,15 @@ class Group(AbstractConfiguration):
resource_id=str(self.id) resource_id=str(self.id)
) )
def to_dict(self):
active_origins = [o for o in self.origins if o.destroyed is None]
return {
"Id": self.id,
"GroupName": self.group_name,
"Description": self.description,
"ActiveOriginCount": len(active_origins),
}
class Pool(AbstractConfiguration): class Pool(AbstractConfiguration):
pool_name = db.Column(db.String(80), unique=True, nullable=False) pool_name = db.Column(db.String(80), unique=True, nullable=False)

View file

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import json import json
import tldextract
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, List, Union, Any, Dict from typing import Optional, List, Union, Any, Dict
import tldextract
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from tldextract import extract from tldextract import extract
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
@ -96,9 +96,19 @@ class Origin(AbstractConfiguration):
frequency_factor += 1 frequency_factor += 1
risk_levels: Dict[str, int] = {} risk_levels: Dict[str, int] = {}
for country in self.countries: for country in self.countries:
risk_levels[country.country_code.upper()] = int(max(1, min(10, frequency_factor * recency_factor))) + country.risk_level risk_levels[country.country_code.upper()] = int(
max(1, min(10, frequency_factor * recency_factor))) + country.risk_level
return risk_levels return risk_levels
def to_dict(self):
return {
"Id": self.id,
"Description": self.description,
"DomainName": self.domain_name,
"RiskLevel": self.risk_level,
"RiskLevelOverride": self.risk_level_override,
}
class Country(AbstractConfiguration): class Country(AbstractConfiguration):
@property @property
@ -268,6 +278,21 @@ class Proxy(AbstractResource):
"origin_id", "provider", "psg", "slug", "terraform_updated", "url" "origin_id", "provider", "psg", "slug", "terraform_updated", "url"
] ]
def to_dict(self):
status = "active"
if self.url is None:
status = "pending"
if self.deprecated is not None:
status = "expiring"
if self.destroyed is not None:
status = "destroyed"
return {
"Id": self.id,
"OriginDomain": self.origin.domain_name,
"MirrorDomain": self.url.replace("https://", "") if self.url else None,
"Status": status,
}
class SmartProxy(AbstractResource): class SmartProxy(AbstractResource):
group_id = mapped_column(db.Integer(), db.ForeignKey("group.id"), nullable=False) group_id = mapped_column(db.Integer(), db.ForeignKey("group.id"), nullable=False)