refactor: moving more models to mapped_column

This commit is contained in:
Iain Learmonth 2024-11-10 15:13:29 +00:00
parent ea020d6edd
commit 75b2c1adf0
9 changed files with 272 additions and 94 deletions

View file

@ -2,9 +2,12 @@ import base64
import binascii
import logging
import re
from typing import Optional, List, Callable, Any, Type, Dict, Union
from flask import Blueprint, request, jsonify, abort
from sqlalchemy import select
from flask.typing import ResponseReturnValue
from sqlalchemy import select, BinaryExpression, ColumnElement
from werkzeug.exceptions import HTTPException
from app.extensions import db
from app.models.base import Group
@ -17,41 +20,38 @@ MAX_DOMAIN_NAME_LENGTH = 255
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$')
MAX_ALLOWED_ITEMS = 100
ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]]
@api.errorhandler(400)
def bad_request(error):
def bad_request(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Bad Request', 'message': error.description})
response.status_code = 400
return response
@api.errorhandler(401)
def unauthorized(error):
def unauthorized(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Unauthorized', 'message': error.description})
response.status_code = 401
return response
@api.errorhandler(404)
def not_found(error):
def not_found(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'})
response.status_code = 404
return response
@api.errorhandler(500)
def internal_server_error(error):
def internal_server_error(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'})
response.status_code = 500
return response
@api.teardown_app_request
def shutdown_session(exception=None):
db.session.remove()
def validate_max_items(max_items_str, max_allowed):
def validate_max_items(max_items_str: str, max_allowed: int) -> int:
try:
max_items = int(max_items_str)
if max_items <= 0 or max_items > max_allowed:
@ -61,7 +61,7 @@ def validate_max_items(max_items_str, max_allowed):
abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.")
def validate_marker(marker_str):
def validate_marker(marker_str: str) -> int:
try:
marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode()
marker_id = int(marker_decoded)
@ -71,15 +71,15 @@ def validate_marker(marker_str):
def list_resources(
model,
filters=None,
order_by=None,
serialize_func=None,
resource_name='ResourceList',
max_items_param='MaxItems',
marker_param='Marker',
max_allowed_items=100
):
model: Type[Any],
serialize_func: Callable[[Any], Dict[str, Any]],
filters: Optional[List[ListFilter]] = None,
order_by: Optional[ColumnElement[Any]] = None,
resource_name: str = 'ResourceList',
max_items_param: str = 'MaxItems',
marker_param: str = 'Marker',
max_allowed_items: int = 100
) -> ResponseReturnValue:
try:
marker = request.args.get(marker_param)
max_items = validate_max_items(
@ -123,7 +123,7 @@ def list_resources(
@api.route('/web/group', methods=['GET'])
def list_groups():
def list_groups() -> ResponseReturnValue:
return list_resources(
model=Group,
serialize_func=lambda group: group.to_dict(),
@ -133,11 +133,11 @@ def list_groups():
@api.route('/web/origin', methods=['GET'])
def list_origins():
def list_origins() -> ResponseReturnValue:
domain_name_filter = request.args.get('DomainName')
group_id_filter = request.args.get('GroupId')
filters = []
filters: List[ListFilter] = []
if domain_name_filter:
if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
@ -148,22 +148,21 @@ def list_origins():
if group_id_filter:
try:
group_id_filter = int(group_id_filter)
filters.append(Origin.group_id == group_id_filter)
filters.append(Origin.group_id == int(group_id_filter))
except ValueError:
abort(400, description="GroupId must be a valid integer.")
return list_resources(
model=Origin,
filters=filters,
serialize_func=lambda origin: origin.to_dict(),
filters=filters,
resource_name='OriginsList',
max_allowed_items=MAX_ALLOWED_ITEMS
)
@api.route('/web/mirror', methods=['GET'])
def list_mirrors():
def list_mirrors() -> ResponseReturnValue:
status_filter = request.args.get('Status')
filters = []
@ -185,8 +184,8 @@ def list_mirrors():
return list_resources(
model=Proxy,
filters=filters,
serialize_func=lambda proxy: proxy.to_dict(),
filters=filters,
resource_name='MirrorsList',
max_allowed_items=MAX_ALLOWED_ITEMS
)