majuna/app/api/__init__.py

192 lines
5.7 KiB
Python

import base64
import binascii
import logging
import re
from flask import Blueprint, request, jsonify, abort
from sqlalchemy import select
from app.extensions import db
from app.models.base import Group
from app.models.mirrors import Origin, Proxy
# Blueprint holding the JSON list endpoints and their error handlers.
api = Blueprint('api', __name__)
logger = logging.getLogger(__name__)
# Longest DomainName filter value accepted by the /web/origin endpoint.
MAX_DOMAIN_NAME_LENGTH = 255
# Characters allowed in the DomainName filter. The pattern is anchored with
# ``\Z`` rather than ``$``: ``$`` also matches just before a trailing newline,
# so "example.com\n" would otherwise pass a ``.match()`` check.
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*\Z')
# Largest page size (MaxItems) the list endpoints accept.
MAX_ALLOWED_ITEMS = 100
@api.errorhandler(400)
def bad_request(error):
    """Render a 400 raised within this blueprint as a JSON error body.

    The ``message`` field is taken from ``error.description``, i.e. the
    text passed to ``abort(400, description=...)``.
    """
    payload = {'error': 'Bad Request', 'message': error.description}
    json_response = jsonify(payload)
    json_response.status_code = 400
    return json_response
@api.errorhandler(401)
def unauthorized(error):
    """Render a 401 raised within this blueprint as a JSON error body.

    The ``message`` field echoes ``error.description``.
    """
    payload = {'error': 'Unauthorized', 'message': error.description}
    json_response = jsonify(payload)
    json_response.status_code = 401
    return json_response
@api.errorhandler(404)
def not_found(error):
    """Render a 404 as a JSON body with a fixed, generic message.

    ``error`` is ignored so no internal detail leaks into the response.

    NOTE(review): Flask does not route URL-matching (unknown path) 404s to
    blueprint-level handlers; confirm whether an app-level handler is needed.
    """
    payload = {'error': 'Not found', 'message': 'Resource could not be found.'}
    json_response = jsonify(payload)
    json_response.status_code = 404
    return json_response
@api.errorhandler(500)
def internal_server_error(error):
    """Render a 500 as a JSON body with a fixed, generic message.

    ``error`` is ignored; details are expected to be logged where the
    failure occurred, not returned to the client.
    """
    payload = {'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'}
    json_response = jsonify(payload)
    json_response.status_code = 500
    return json_response
@api.teardown_app_request
def shutdown_session(exception=None):
    """Remove ``db.session`` when a request is torn down.

    Registered via ``teardown_app_request``, so it runs for every request to
    the application, not only those handled by this blueprint. ``exception``
    is the value Flask passes to teardown callbacks and is unused here.
    """
    session_registry = db.session
    session_registry.remove()
def validate_max_items(max_items_str, max_allowed):
    """Parse a page-size string and require ``1 <= value <= max_allowed``.

    Args:
        max_items_str: Raw query-string value to parse with ``int()``.
        max_allowed: Inclusive upper bound on the page size.

    Returns:
        The parsed page size as an ``int``.

    Raises:
        Aborts with HTTP 400 when the value is not an integer or is outside
        ``1..max_allowed``.
    """
    try:
        parsed = int(max_items_str)
    except ValueError:
        parsed = None
    if parsed is None or not 0 < parsed <= max_allowed:
        abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.")
    return parsed
def validate_marker(marker_str):
    """Decode a pagination marker back into the integer id it encodes.

    A marker is the URL-safe base64 encoding of a decimal id string, the
    same form ``list_resources`` emits as ``NextMarker``.

    Returns:
        The decoded id as an ``int``.

    Raises:
        Aborts with HTTP 400 when the marker is not valid base64, not UTF-8,
        or does not decode to an integer.
    """
    try:
        return int(base64.urlsafe_b64decode(marker_str.encode()).decode())
    except (ValueError, binascii.Error):
        abort(400, description="Marker must be a valid token.")
def _default_serialize(item):
    """Serialize a row via its ``to_dict()`` method (what every endpoint here uses)."""
    return item.to_dict()


def list_resources(
    model,
    filters=None,
    order_by=None,
    serialize_func=None,
    resource_name='ResourceList',
    max_items_param='MaxItems',
    marker_param='Marker',
    max_allowed_items=MAX_ALLOWED_ITEMS,
    default_max_items=100,
):
    """Return one page of ``model`` rows as a JSON response, paginated by id marker.

    Args:
        model: Mapped class to select from. It must expose an ``id`` column,
            which the marker is compared against.
        filters: Optional sequence of SQLAlchemy conditions, combined with AND.
        order_by: Optional ORDER BY clause; defaults to ``model.id``. The
            marker compares ``model.id``, so ordering by anything else makes
            consecutive pages inconsistent.
        serialize_func: Callable turning a row into a JSON-serializable value;
            defaults to calling ``item.to_dict()``.
        resource_name: Top-level key of the JSON response.
        max_items_param: Query-string name of the page-size parameter.
        marker_param: Query-string name of the pagination-marker parameter.
        max_allowed_items: Largest page size a client may request.
        default_max_items: Page size used when the client omits
            ``max_items_param``; clamped to ``max_allowed_items``.

    Returns:
        A Flask JSON response ``{resource_name: {...}}`` with the echoed
        marker, the page size, ``Quantity``, ``Items``, ``IsTruncated`` and,
        when more rows exist, ``NextMarker`` (URL-safe base64 of the last id).

    Raises:
        Aborts with HTTP 400 for an invalid page size or marker, and with
        HTTP 500 (after logging) for any other failure while querying or
        serializing.
    """
    # Validate client input *outside* the try block: abort(400) raises an
    # HTTPException, which the broad ``except Exception`` below would
    # otherwise swallow, log as unexpected, and turn into a 500.
    marker = request.args.get(marker_param)
    # Clamp the default so a caller with a cap below 100 does not reject
    # requests that simply omit the page-size parameter.
    default_page_size = str(min(default_max_items, max_allowed_items))
    max_items = validate_max_items(
        request.args.get(max_items_param, default=default_page_size), max_allowed_items)
    marker_id = validate_marker(marker) if marker else None
    serialize = serialize_func if serialize_func is not None else _default_serialize

    try:
        query = select(model)
        if filters:
            query = query.where(*filters)
        if marker_id is not None:
            query = query.where(model.id > marker_id)
        # ``is not None`` rather than ``or``: SQLAlchemy expressions such as
        # ``desc(col)`` raise TypeError when truth-tested.
        query = query.order_by(order_by if order_by is not None else model.id)
        # Fetch one extra row to learn whether another page exists.
        query = query.limit(max_items + 1)
        items = db.session.execute(query).scalars().all()

        page = items[:max_items]
        is_truncated = len(items) > max_items
        body = {
            marker_param: marker or None,
            max_items_param: str(max_items),
            "Quantity": len(page),
            "Items": [serialize(item) for item in page],
            "IsTruncated": is_truncated,
        }
        if is_truncated:
            # The last row of this page is the marker for the next one.
            body["NextMarker"] = base64.urlsafe_b64encode(str(page[-1].id).encode()).decode()
        return jsonify({resource_name: body})
    except Exception:
        logger.exception("An unexpected error occurred")
        abort(500)
@api.route('/web/group', methods=['GET'])
def list_groups():
    """List groups under the ``OriginGroupList`` key using marker pagination.

    Page size and marker come from the ``MaxItems`` / ``Marker`` query
    parameters, as handled by ``list_resources``.
    """

    def serialize_group(group):
        return group.to_dict()

    return list_resources(
        Group,
        resource_name='OriginGroupList',
        serialize_func=serialize_group,
        max_allowed_items=MAX_ALLOWED_ITEMS,
    )
@api.route('/web/origin', methods=['GET'])
def list_origins():
    """List origins, optionally filtered by domain-name substring and group id.

    Query parameters:
        DomainName: Case-insensitive substring match on ``Origin.domain_name``.
            At most ``MAX_DOMAIN_NAME_LENGTH`` characters, all matching
            ``DOMAIN_NAME_REGEX``.
        GroupId: Exact match on ``Origin.group_id``; must parse as an integer.
        MaxItems / Marker: Pagination, handled by ``list_resources``.

    Aborts with HTTP 400 when a filter value is invalid.
    """
    domain_name_filter = request.args.get('DomainName')
    group_id_filter = request.args.get('GroupId')
    filters = []

    if domain_name_filter:
        if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
            abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.")
        # fullmatch rather than match: a pattern ending in ``$`` also matches
        # just before a trailing newline, so ``match`` accepts "example.com\n".
        if not DOMAIN_NAME_REGEX.fullmatch(domain_name_filter):
            abort(400, description="DomainName contains invalid characters.")
        # The regex's character class excludes ``%`` and ``_``, so client
        # input cannot inject extra LIKE wildcards into this pattern.
        filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%"))

    if group_id_filter:
        try:
            group_id = int(group_id_filter)
        except ValueError:
            abort(400, description="GroupId must be a valid integer.")
        filters.append(Origin.group_id == group_id)

    return list_resources(
        model=Origin,
        filters=filters,
        serialize_func=lambda origin: origin.to_dict(),
        resource_name='OriginsList',
        max_allowed_items=MAX_ALLOWED_ITEMS
    )
@api.route('/web/mirror', methods=['GET'])
def list_mirrors():
    """List proxies (mirrors), optionally filtered by lifecycle ``Status``.

    Status values and the column conditions that define them:
        pending:   no url, not deprecated, not destroyed
        active:    has url, not deprecated, not destroyed
        expiring:  deprecated, not destroyed
        destroyed: destroyed

    Aborts with HTTP 400 for an unrecognised ``Status``. Previously an
    unknown value was ignored and every mirror was returned unfiltered.
    Pagination is handled by ``list_resources``.
    """
    status_filter = request.args.get('Status')
    filters = []
    if status_filter:
        # A request selects at most one status, so a lookup table replaces
        # the chain of independent ``if`` checks.
        status_conditions = {
            "pending": (
                Proxy.url.is_(None),
                Proxy.deprecated.is_(None),
                Proxy.destroyed.is_(None),
            ),
            "active": (
                Proxy.url.is_not(None),
                Proxy.deprecated.is_(None),
                Proxy.destroyed.is_(None),
            ),
            "expiring": (
                Proxy.deprecated.is_not(None),
                Proxy.destroyed.is_(None),
            ),
            "destroyed": (
                Proxy.destroyed.is_not(None),
            ),
        }
        if status_filter not in status_conditions:
            abort(400, description=f"Status must be one of: {', '.join(status_conditions)}.")
        filters.extend(status_conditions[status_filter])
    return list_resources(
        model=Proxy,
        filters=filters,
        serialize_func=lambda proxy: proxy.to_dict(),
        resource_name='MirrorsList',
        max_allowed_items=MAX_ALLOWED_ITEMS
    )