majuna/app/api/util.py

101 lines
3.3 KiB
Python

import base64
import binascii
import logging
import re
from typing import Union, Any, Literal, Type, Callable, Dict, List, Optional
from flask import abort, request, jsonify
from flask.typing import ResponseReturnValue
from sqlalchemy import BinaryExpression, ColumnElement, select
from app.extensions import db
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): presumably the maximum accepted domain-name length; it is not
# referenced by any code visible in this section -- confirm against callers.
MAX_DOMAIN_NAME_LENGTH = 255
# Matches strings consisting only of ASCII letters, digits, dots and hyphens.
# Because of the `*` quantifier the empty string matches too, and `$` also
# matches just before a single trailing newline.
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$')
# NOTE(review): presumably the default page-size cap; list_resources below
# carries its own literal 100 rather than referencing this name -- confirm.
MAX_ALLOWED_ITEMS = 100
# Type of the SQLAlchemy expressions accepted as WHERE filters by list_resources.
ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]]
def validate_max_items(max_items_str: str, max_allowed: int) -> int:
    """Parse a page-size value and enforce its permitted range.

    Args:
        max_items_str: Raw text to interpret as an integer.
        max_allowed: Largest value that is accepted (inclusive).

    Returns:
        The parsed integer, which satisfies ``1 <= value <= max_allowed``.

    The request is aborted with HTTP 400 when the text is not an integer or
    the integer falls outside that range.
    """
    try:
        requested = int(max_items_str)
    except ValueError:
        # Unparseable input is rejected by the same check as out-of-range input.
        requested = None

    if requested is None or not 0 < requested <= max_allowed:
        abort(
            400,
            description=f"MaxItems must be a positive integer not exceeding {max_allowed}.",
        )
    return requested
def validate_marker(marker_str: str) -> int:
    """Decode a pagination marker token back into the integer it carries.

    The token is the URL-safe base64 encoding of the decimal text of an
    integer (the same format list_resources emits as ``NextMarker``).

    Args:
        marker_str: Token text to decode.

    Returns:
        The integer encoded in the token.

    The request is aborted with HTTP 400 when the token is not valid
    base64 or does not decode to integer text.
    """
    try:
        token_bytes = marker_str.encode()
        id_text = base64.urlsafe_b64decode(token_bytes).decode()
        return int(id_text)
    except (binascii.Error, ValueError):
        abort(400, description="Marker must be a valid token.")
# Labels accepted for the ``ProtectiveMarking`` field of a list response
# (presumably Traffic Light Protocol markings, plus a "default" fallback).
# A single multi-value Literal is equivalent, for type checkers, to a Union of
# one-value Literals (PEP 586) and lets typing.get_args() return the strings.
TlpMarkings = Literal[
    "default",
    "clear",
    "green",
    "amber",
    "amber+strict",
    "red",
]
def list_resources(  # pylint: disable=too-many-arguments,too-many-locals
    model: Type[Any],
    serialize_func: Callable[[Any], Dict[str, Any]],
    *,
    filters: Optional[List[ListFilter]] = None,
    order_by: Optional[ColumnElement[Any]] = None,
    resource_name: str = 'ResourceList',
    max_items_param: str = 'MaxItems',
    marker_param: str = 'Marker',
    max_allowed_items: int = 100,
    protective_marking: TlpMarkings = 'default',
) -> ResponseReturnValue:
    """Return one marker-paginated page of ``model`` rows as a JSON response.

    The page size and marker are read from the current request's query
    string. One row more than the page size is fetched so the response can
    report whether further results exist.

    Args:
        model: Mapped class to select from; it must expose an ``id`` column.
        serialize_func: Converts one row into a JSON-serialisable dict.
        filters: Optional WHERE expressions applied to the query.
        order_by: Optional ordering expression; defaults to ``model.id``.
        resource_name: Top-level key that wraps the response body.
        max_items_param: Name of the page-size query parameter (also used as
            the response key echoing the page size).
        marker_param: Name of the marker query parameter (also used as the
            response key echoing the marker).
        max_allowed_items: Largest page size a client may request.
        protective_marking: Value reported as ``ProtectiveMarking``.

    Returns:
        A JSON response of the form ``{resource_name: {...}}`` containing the
        echoed marker and page size, ``Quantity``, ``Items``, ``IsTruncated``,
        ``ProtectiveMarking`` and, when more rows exist, ``NextMarker``.

    The request is aborted with HTTP 400 for an invalid page size or marker,
    and with HTTP 500 (after logging) if querying or serialising fails.
    """
    # Validate request parameters *before* the try block: abort() raises an
    # HTTPException, which is an Exception subclass, so inside the try the
    # intended 400 response would be caught below and turned into a 500.
    marker = request.args.get(marker_param)
    # Keep the implicit page size within the caller's limit so that a request
    # without the parameter is not rejected when the limit is below 100.
    default_max_items = str(min(max_allowed_items, 100))
    max_items = validate_max_items(
        request.args.get(max_items_param, default=default_max_items),
        max_allowed_items,
    )
    marker_id = validate_marker(marker) if marker else None

    try:
        query = select(model)
        if filters:
            query = query.where(*filters)
        if marker_id is not None:
            # NOTE(review): the marker always filters on model.id, so pages are
            # only consistent when the ordering below is by id -- confirm that
            # callers passing a custom order_by are aware of this.
            query = query.where(model.id > marker_id)
        # Compare against None explicitly: truth-testing a SQLAlchemy clause
        # (e.g. ``Model.col.desc()``) raises TypeError.
        query = query.order_by(order_by if order_by is not None else model.id)
        query = query.limit(max_items + 1)  # One extra row reveals whether more exist.

        items = db.session.execute(query).scalars().all()
        page = items[:max_items]
        items_list = [serialize_func(item) for item in page]
        is_truncated = len(items) > max_items

        response = {
            resource_name: {
                marker_param: marker or None,
                max_items_param: str(max_items),
                "Quantity": len(items_list),
                "Items": items_list,
                "IsTruncated": is_truncated,
                "ProtectiveMarking": protective_marking,
            }
        }
        if is_truncated:
            # The next page starts after the last row actually returned.
            last_id = page[-1].id
            next_marker = base64.urlsafe_b64encode(str(last_id).encode()).decode()
            response[resource_name]["NextMarker"] = next_marker
        return jsonify(response)
    except Exception:  # pylint: disable=broad-exception-caught
        logger.exception("An unexpected error occurred")
        abort(500)