diff --git a/app/api/__init__.py b/app/api/__init__.py index 7f6c3a4..020f8c3 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -2,9 +2,12 @@ import base64 import binascii import logging import re +from typing import Optional, List, Callable, Any, Type, Dict, Union from flask import Blueprint, request, jsonify, abort -from sqlalchemy import select +from flask.typing import ResponseReturnValue +from sqlalchemy import select, BinaryExpression, ColumnElement +from werkzeug.exceptions import HTTPException from app.extensions import db from app.models.base import Group @@ -17,41 +20,38 @@ MAX_DOMAIN_NAME_LENGTH = 255 DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$') MAX_ALLOWED_ITEMS = 100 +ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]] + @api.errorhandler(400) -def bad_request(error): +def bad_request(error: HTTPException) -> ResponseReturnValue: response = jsonify({'error': 'Bad Request', 'message': error.description}) response.status_code = 400 return response @api.errorhandler(401) -def unauthorized(error): +def unauthorized(error: HTTPException) -> ResponseReturnValue: response = jsonify({'error': 'Unauthorized', 'message': error.description}) response.status_code = 401 return response @api.errorhandler(404) -def not_found(error): +def not_found(_: HTTPException) -> ResponseReturnValue: response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'}) response.status_code = 404 return response @api.errorhandler(500) -def internal_server_error(error): +def internal_server_error(_: HTTPException) -> ResponseReturnValue: response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'}) response.status_code = 500 return response -@api.teardown_app_request -def shutdown_session(exception=None): - db.session.remove() - - -def validate_max_items(max_items_str, max_allowed): +def validate_max_items(max_items_str: str, max_allowed: int) -> int: try: max_items = int(max_items_str) if max_items <= 0 or max_items > max_allowed: @@ -61,7 +61,7 @@ def validate_max_items(max_items_str, max_allowed): abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.") -def validate_marker(marker_str): +def validate_marker(marker_str: str) -> int: try: marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode() marker_id = int(marker_decoded) @@ -71,15 +71,15 @@ def validate_marker(marker_str): def list_resources( - model, - filters=None, - order_by=None, - serialize_func=None, - resource_name='ResourceList', - max_items_param='MaxItems', - marker_param='Marker', - max_allowed_items=100 -): + model: Type[Any], + serialize_func: Callable[[Any], Dict[str, Any]], + filters: Optional[List[ListFilter]] = None, + order_by: Optional[ColumnElement[Any]] = None, + resource_name: str = 'ResourceList', + max_items_param: str = 'MaxItems', + marker_param: str = 'Marker', + max_allowed_items: int = 100 +) -> ResponseReturnValue: try: marker = request.args.get(marker_param) max_items = validate_max_items( @@ -123,7 +123,7 @@ def list_resources( @api.route('/web/group', methods=['GET']) -def list_groups(): +def list_groups() -> ResponseReturnValue: return list_resources( model=Group, serialize_func=lambda group: group.to_dict(), @@ -133,11 +133,11 @@ def list_groups(): @api.route('/web/origin', methods=['GET']) -def list_origins(): +def list_origins() -> ResponseReturnValue: domain_name_filter = request.args.get('DomainName') group_id_filter = request.args.get('GroupId') - filters = [] + filters: List[ListFilter] = [] if domain_name_filter: if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: @@ -148,22 +148,21 @@ def list_origins(): if group_id_filter: try: - group_id_filter = int(group_id_filter) - filters.append(Origin.group_id == group_id_filter) + filters.append(Origin.group_id == int(group_id_filter)) except ValueError: abort(400, description="GroupId must be a valid integer.") return list_resources( model=Origin, - filters=filters, serialize_func=lambda origin: origin.to_dict(), + filters=filters, resource_name='OriginsList', max_allowed_items=MAX_ALLOWED_ITEMS ) @api.route('/web/mirror', methods=['GET']) -def list_mirrors(): +def list_mirrors() -> ResponseReturnValue: status_filter = request.args.get('Status') filters = [] @@ -185,8 +184,8 @@ def list_mirrors(): return list_resources( model=Proxy, - filters=filters, serialize_func=lambda proxy: proxy.to_dict(), + filters=filters, resource_name='MirrorsList', max_allowed_items=MAX_ALLOWED_ITEMS ) diff --git a/app/models/__init__.py b/app/models/__init__.py index 7b4c6c4..899ac45 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -13,10 +13,10 @@ class AbstractConfiguration(db.Model): # type: ignore __abstract__ = True id: Mapped[int] = mapped_column(db.Integer, primary_key=True) - description: Mapped[str] = mapped_column(db.String(255), nullable=False) - added: Mapped[datetime] = mapped_column(db.DateTime(), default=datetime.utcnow, nullable=False) - updated: Mapped[datetime] = mapped_column(db.DateTime(), default=datetime.utcnow, nullable=False) - destroyed: Mapped[datetime] = mapped_column(db.DateTime()) + description: Mapped[str] + added: Mapped[datetime] + updated: Mapped[datetime] + destroyed: Mapped[Optional[datetime]] @property @abstractmethod diff --git a/app/models/base.py b/app/models/base.py index ea9ac10..aefc087 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -1,21 +1,36 @@ from datetime import datetime -from typing import List +from typing import List, TypedDict, Optional, TYPE_CHECKING + +from sqlalchemy import and_ +from sqlalchemy.orm import Mapped, aliased, mapped_column, relationship from app.brm.brn import BRN from app.extensions import db from app.models import AbstractConfiguration +if TYPE_CHECKING: + from app.models.bridges import BridgeConf + from app.models.mirrors import Origin, Proxy, SmartProxy, StaticOrigin + from app.models.onions import Eotk, Onion + + +class GroupDict(TypedDict): + Id: int + GroupName: str + Description: str + ActiveOriginCount: int + class Group(AbstractConfiguration): - group_name = db.Column(db.String(80), unique=True, nullable=False) - eotk = db.Column(db.Boolean()) + group_name: Mapped[str] = db.Column(db.String(80), unique=True, nullable=False) + eotk: Mapped[bool] - origins = db.relationship("Origin", back_populates="group") - statics = db.relationship("StaticOrigin", back_populates="group") - eotks = db.relationship("Eotk", back_populates="group") - onions = db.relationship("Onion", back_populates="group") - smart_proxies = db.relationship("SmartProxy", back_populates="group") - pools = db.relationship("Pool", secondary="pool_group", back_populates="groups") + origins: Mapped[List["Origin"]] = relationship("Origin", back_populates="group") + statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="group") + eotks: Mapped[List["Eotk"]] = relationship("Eotk", back_populates="group") + onions: Mapped[List["Onion"]] = relationship("Onion", back_populates="group") + smart_proxies: Mapped[List["SmartProxy"]] = relationship("SmartProxy", back_populates="group") + pools: Mapped[List["Pool"]] = relationship("Pool", secondary="pool_group", back_populates="groups") @classmethod def csv_header(cls) -> List[str]: @@ -33,25 +48,29 @@ class Group(AbstractConfiguration): resource_id=str(self.id) ) - def to_dict(self): - active_origins = [o for o in self.origins if o.destroyed is None] + def to_dict(self) -> GroupDict: + active_origins_query = ( + db.session.query(aliased(Origin)) + .filter(and_(Origin.group_id == self.id, Origin.destroyed.is_(None))) + ) + active_origins_count = active_origins_query.count() return { "Id": self.id, "GroupName": self.group_name, "Description": self.description, - "ActiveOriginCount": len(active_origins), + "ActiveOriginCount": active_origins_count, } class Pool(AbstractConfiguration): - pool_name = db.Column(db.String(80), unique=True, nullable=False) - api_key = db.Column(db.String(80), nullable=False) - redirector_domain = db.Column(db.String(128), nullable=True) + pool_name: Mapped[str] = mapped_column(db.String, unique=True) + api_key: Mapped[str] + redirector_domain: Mapped[Optional[str]] - bridgeconfs = db.relationship("BridgeConf", back_populates="pool") - proxies = db.relationship("Proxy", back_populates="pool") - lists = db.relationship("MirrorList", back_populates="pool") - groups = db.relationship("Group", secondary="pool_group", back_populates="pools") + bridgeconfs: Mapped[List["BridgeConf"]] = relationship("BridgeConf", back_populates="pool") + proxies: Mapped[List["Proxy"]] = relationship("Proxy", back_populates="pool") + lists: Mapped[List["MirrorList"]] = relationship("MirrorList", back_populates="pool") + groups: Mapped[List[Group]] = relationship("Group", secondary="pool_group", back_populates="pools") @classmethod def csv_header(cls) -> List[str]: diff --git a/app/models/bridges.py b/app/models/bridges.py index bc774c8..87420c6 100644 --- a/app/models/bridges.py +++ b/app/models/bridges.py @@ -2,9 +2,12 @@ import enum from datetime import datetime from typing import List +from sqlalchemy.orm import Mapped, mapped_column, relationship + from app.brm.brn import BRN from app.extensions import db from app.models import AbstractConfiguration, AbstractResource +from app.models.base import Pool class ProviderAllocation(enum.Enum): @@ -13,15 +16,15 @@ class ProviderAllocation(enum.Enum): class BridgeConf(AbstractConfiguration): - pool_id = db.Column(db.Integer, db.ForeignKey("pool.id"), nullable=False) - method = db.Column(db.String(20), nullable=False) - target_number = db.Column(db.Integer()) - max_number = db.Column(db.Integer()) - expiry_hours = db.Column(db.Integer()) - provider_allocation = db.Column(db.Enum(ProviderAllocation)) + pool_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("pool.id")) + method: Mapped[str] + target_number: Mapped[int] + max_number: Mapped[int] + expiry_hours: Mapped[int] + provider_allocation: Mapped[ProviderAllocation] - pool = db.relationship("Pool", back_populates="bridgeconfs") - bridges = db.relationship("Bridge", back_populates="conf") + pool: Mapped[Pool] = relationship("Pool", back_populates="bridgeconfs") + bridges: Mapped[List["Bridge"]] = relationship("Bridge", back_populates="conf") @property def brn(self) -> BRN: diff --git a/app/models/cloud.py b/app/models/cloud.py index fbbc802..f3aec7d 100644 --- a/app/models/cloud.py +++ b/app/models/cloud.py @@ -1,4 +1,7 @@ import enum +from typing import Any, Dict, List, TYPE_CHECKING + +from sqlalchemy.orm import Mapped, mapped_column, relationship from app.brm.brn import BRN from app.extensions import db @@ -6,6 +9,10 @@ from app.models import AbstractConfiguration from app.models.mirrors import StaticOrigin +if TYPE_CHECKING: + from app.models.bridges import Bridge + + class CloudProvider(enum.Enum): AWS = ("aws", "Amazon Web Services") AZURE = ("azure", "Microsoft Azure") @@ -27,17 +34,17 @@ class CloudProvider(enum.Enum): class CloudAccount(AbstractConfiguration): - provider = db.Column(db.Enum(CloudProvider)) - credentials = db.Column(db.JSON()) - enabled = db.Column(db.Boolean()) + provider: Mapped[CloudProvider] + credentials: Mapped[Dict[str, Any]] = mapped_column(db.JSON()) + enabled: Mapped[bool] # CDN Quotas - max_distributions = db.Column(db.Integer()) - max_sub_distributions = db.Column(db.Integer()) + max_distributions: Mapped[int] + max_sub_distributions: Mapped[int] # Compute Quotas - max_instances = db.Column(db.Integer()) + max_instances: Mapped[int] - bridges = db.relationship("Bridge", back_populates="cloud_account") - statics = db.relationship("StaticOrigin", back_populates="storage_cloud_account", foreign_keys=[ + bridges: Mapped[List["Bridge"]] = relationship("Bridge", back_populates="cloud_account") + statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="storage_cloud_account", foreign_keys=[ StaticOrigin.storage_cloud_account_id]) @property diff --git a/app/models/mirrors.py b/app/models/mirrors.py index 2bf1439..67b65fb 100644 --- a/app/models/mirrors.py +++ b/app/models/mirrors.py @@ -2,10 +2,10 @@ from __future__ import annotations import json from datetime import datetime, timedelta -from typing import Optional, List, Union, Any, Dict +from typing import Optional, List, Union, Any, Dict, TypedDict, Literal import tldextract -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from tldextract import extract from werkzeug.datastructures import FileStorage @@ -13,7 +13,7 @@ from app.brm.brn import BRN from app.brm.utils import thumbnail_uploaded_image, create_data_uri, normalize_color from app.extensions import db from app.models import AbstractConfiguration, AbstractResource, Deprecation -from app.models.base import Pool +from app.models.base import Pool, Group from app.models.onions import Onion country_origin = db.Table( @@ -25,17 +25,25 @@ country_origin = db.Table( ) -class Origin(AbstractConfiguration): - group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False) - domain_name = mapped_column(db.String(255), unique=True, nullable=False) - auto_rotation = mapped_column(db.Boolean, nullable=False) - smart = mapped_column(db.Boolean(), nullable=False) - assets = mapped_column(db.Boolean(), nullable=False) - risk_level_override = mapped_column(db.Integer(), nullable=True) +class OriginDict(TypedDict): + Id: int + Description: str + DomainName: str + RiskLevel: Dict[str, int] + RiskLevelOverride: Optional[int] - group = db.relationship("Group", back_populates="origins") - proxies = db.relationship("Proxy", back_populates="origin") - countries = db.relationship("Country", secondary=country_origin, back_populates='origins') + +class Origin(AbstractConfiguration): + group_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("group.id")) + domain_name: Mapped[str] = mapped_column(db.String(255), unique=True) + auto_rotation: Mapped[bool] + smart: Mapped[bool] + assets: Mapped[bool] + risk_level_override: Mapped[Optional[int]] + + group: Mapped[Group] = relationship("Group", back_populates="origins") + proxies: Mapped[List[Proxy]] = relationship("Proxy", back_populates="origin") + countries: Mapped[List[Country]] = relationship("Country", secondary=country_origin, back_populates='origins') @property def brn(self) -> BRN: @@ -100,7 +108,7 @@ class Origin(AbstractConfiguration): max(1, min(10, frequency_factor * recency_factor))) + country.risk_level return risk_levels - def to_dict(self): + def to_dict(self) -> OriginDict: return { "Id": self.id, "Description": self.description, @@ -250,6 +258,16 @@ class StaticOrigin(AbstractConfiguration): self.updated = datetime.utcnow() +ResourceStatus = Union[Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"]] + + +class ProxyDict(TypedDict): + Id: int + OriginDomain: str + MirrorDomain: Optional[str] + Status: ResourceStatus + + class Proxy(AbstractResource): origin_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("origin.id"), nullable=False) pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id")) @@ -259,8 +277,8 @@ class Proxy(AbstractResource): terraform_updated: Mapped[Optional[datetime]] = mapped_column(db.DateTime(), nullable=True) url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - origin: Mapped[Origin] = db.relationship("Origin", back_populates="proxies") - pool: Mapped[Pool] = db.relationship("Pool", back_populates="proxies") + origin: Mapped[Origin] = relationship("Origin", back_populates="proxies") + pool: Mapped[Pool] = relationship("Pool", back_populates="proxies") @property def brn(self) -> BRN: @@ -278,8 +296,8 @@ class Proxy(AbstractResource): "origin_id", "provider", "psg", "slug", "terraform_updated", "url" ] - def to_dict(self): - status = "active" + def to_dict(self) -> ProxyDict: + status: ResourceStatus = "active" if self.url is None: status = "pending" if self.deprecated is not None: diff --git a/app/models/onions.py b/app/models/onions.py index 06929f3..6eda839 100644 --- a/app/models/onions.py +++ b/app/models/onions.py @@ -1,9 +1,13 @@ import base64 import hashlib +from typing import Optional + +from sqlalchemy.orm import Mapped, mapped_column, relationship from app.brm.brn import BRN from app.extensions import db from app.models import AbstractConfiguration, AbstractResource +from app.models.base import Group class Onion(AbstractConfiguration): @@ -47,12 +51,12 @@ class Onion(AbstractConfiguration): class Eotk(AbstractResource): - group_id = db.Column(db.Integer(), db.ForeignKey("group.id"), nullable=False) - instance_id = db.Column(db.String(100), nullable=True) - provider = db.Column(db.String(20), nullable=False) - region = db.Column(db.String(20), nullable=False) + group_id: Mapped[int] = mapped_column(db.Integer(), db.ForeignKey("group.id")) + instance_id: Mapped[Optional[str]] + provider: Mapped[str] + region: Mapped[str] - group = db.relationship("Group", back_populates="eotks") + group: Mapped[Group] = relationship("Group", back_populates="eotks") @property def brn(self) -> BRN: diff --git a/app/models/tfstate.py b/app/models/tfstate.py index 0529435..9bea3b9 100644 --- a/app/models/tfstate.py +++ b/app/models/tfstate.py @@ -1,7 +1,11 @@ +from typing import Optional + +from sqlalchemy.orm import Mapped, mapped_column + from app.extensions import db class TerraformState(db.Model): # type: ignore - key = db.Column(db.String, primary_key=True) - state = db.Column(db.String) - lock = db.Column(db.String) + key: Mapped[str] = mapped_column(db.String, primary_key=True) + state: Mapped[str] + lock: Mapped[Optional[str]] diff --git a/migrations/versions/13b1d64f134a_enforce_not_null_restrctions.py b/migrations/versions/13b1d64f134a_enforce_not_null_restrctions.py new file mode 100644 index 0000000..d6e04a3 --- /dev/null +++ b/migrations/versions/13b1d64f134a_enforce_not_null_restrctions.py @@ -0,0 +1,124 @@ +"""enforce not null restrctions + +Revision ID: 13b1d64f134a +Revises: bbec86de37c4 +Create Date: 2024-11-10 15:12:16.589705 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import sqlite + +revision = '13b1d64f134a' +down_revision = 'bbec86de37c4' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('bridge_conf', schema=None) as batch_op: + batch_op.alter_column('target_number', + existing_type=sa.INTEGER(), + nullable=False) + batch_op.alter_column('max_number', + existing_type=sa.INTEGER(), + nullable=False) + batch_op.alter_column('expiry_hours', + existing_type=sa.INTEGER(), + nullable=False) + batch_op.alter_column('provider_allocation', + existing_type=sa.VARCHAR(length=6), + nullable=False) + + with op.batch_alter_table('cloud_account', schema=None) as batch_op: + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=10), + nullable=False) + batch_op.alter_column('credentials', + existing_type=sqlite.JSON(), + nullable=False) + batch_op.alter_column('enabled', + existing_type=sa.BOOLEAN(), + nullable=False) + batch_op.alter_column('max_distributions', + existing_type=sa.INTEGER(), + nullable=False) + batch_op.alter_column('max_sub_distributions', + existing_type=sa.INTEGER(), + nullable=False) + batch_op.alter_column('max_instances', + existing_type=sa.INTEGER(), + nullable=False) + + with op.batch_alter_table('deprecation', schema=None) as batch_op: + batch_op.alter_column('resource_type', + existing_type=sa.VARCHAR(length=50), + nullable=False) + batch_op.alter_column('resource_id', + existing_type=sa.INTEGER(), + nullable=False) + + with op.batch_alter_table('group', schema=None) as batch_op: + batch_op.alter_column('eotk', + existing_type=sa.BOOLEAN(), + nullable=False) + + with op.batch_alter_table('terraform_state', schema=None) as batch_op: + batch_op.alter_column('state', + existing_type=sa.VARCHAR(), + nullable=False) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('terraform_state', schema=None) as batch_op: + batch_op.alter_column('state', + existing_type=sa.VARCHAR(), + nullable=True) + + with op.batch_alter_table('group', schema=None) as batch_op: + batch_op.alter_column('eotk', + existing_type=sa.BOOLEAN(), + nullable=True) + + with op.batch_alter_table('deprecation', schema=None) as batch_op: + batch_op.alter_column('resource_id', + existing_type=sa.INTEGER(), + nullable=True) + batch_op.alter_column('resource_type', + existing_type=sa.VARCHAR(length=50), + nullable=True) + + with op.batch_alter_table('cloud_account', schema=None) as batch_op: + batch_op.alter_column('max_instances', + existing_type=sa.INTEGER(), + nullable=True) + batch_op.alter_column('max_sub_distributions', + existing_type=sa.INTEGER(), + nullable=True) + batch_op.alter_column('max_distributions', + existing_type=sa.INTEGER(), + nullable=True) + batch_op.alter_column('enabled', + existing_type=sa.BOOLEAN(), + nullable=True) + batch_op.alter_column('credentials', + existing_type=sqlite.JSON(), + nullable=True) + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=10), + nullable=True) + + with op.batch_alter_table('bridge_conf', schema=None) as batch_op: + batch_op.alter_column('provider_allocation', + existing_type=sa.VARCHAR(length=6), + nullable=True) + batch_op.alter_column('expiry_hours', + existing_type=sa.INTEGER(), + nullable=True) + batch_op.alter_column('max_number', + existing_type=sa.INTEGER(), + nullable=True) + batch_op.alter_column('target_number', + existing_type=sa.INTEGER(), + nullable=True)