refactor: moving more models to mapped_column

This commit is contained in:
Iain Learmonth 2024-11-10 15:13:29 +00:00
parent ea020d6edd
commit 75b2c1adf0
9 changed files with 272 additions and 94 deletions

View file

@ -2,9 +2,12 @@ import base64
import binascii
import logging
import re
from typing import Optional, List, Callable, Any, Type, Dict, Union
from flask import Blueprint, request, jsonify, abort
from sqlalchemy import select
from flask.typing import ResponseReturnValue
from sqlalchemy import select, BinaryExpression, ColumnElement
from werkzeug.exceptions import HTTPException
from app.extensions import db
from app.models.base import Group
@ -17,41 +20,38 @@ MAX_DOMAIN_NAME_LENGTH = 255
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$')
MAX_ALLOWED_ITEMS = 100
ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]]
@api.errorhandler(400)
def bad_request(error):
def bad_request(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Bad Request', 'message': error.description})
response.status_code = 400
return response
@api.errorhandler(401)
def unauthorized(error):
def unauthorized(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Unauthorized', 'message': error.description})
response.status_code = 401
return response
@api.errorhandler(404)
def not_found(error):
def not_found(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'})
response.status_code = 404
return response
@api.errorhandler(500)
def internal_server_error(error):
def internal_server_error(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'})
response.status_code = 500
return response
@api.teardown_app_request
def shutdown_session(exception=None):
db.session.remove()
def validate_max_items(max_items_str, max_allowed):
def validate_max_items(max_items_str: str, max_allowed: int) -> int:
try:
max_items = int(max_items_str)
if max_items <= 0 or max_items > max_allowed:
@ -61,7 +61,7 @@ def validate_max_items(max_items_str, max_allowed):
abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.")
def validate_marker(marker_str):
def validate_marker(marker_str: str) -> int:
try:
marker_decoded = base64.urlsafe_b64decode(marker_str.encode()).decode()
marker_id = int(marker_decoded)
@ -71,15 +71,15 @@ def validate_marker(marker_str):
def list_resources(
model,
filters=None,
order_by=None,
serialize_func=None,
resource_name='ResourceList',
max_items_param='MaxItems',
marker_param='Marker',
max_allowed_items=100
):
model: Type[Any],
serialize_func: Callable[[Any], Dict[str, Any]],
filters: Optional[List[ListFilter]] = None,
order_by: Optional[ColumnElement[Any]] = None,
resource_name: str = 'ResourceList',
max_items_param: str = 'MaxItems',
marker_param: str = 'Marker',
max_allowed_items: int = 100
) -> ResponseReturnValue:
try:
marker = request.args.get(marker_param)
max_items = validate_max_items(
@ -123,7 +123,7 @@ def list_resources(
@api.route('/web/group', methods=['GET'])
def list_groups():
def list_groups() -> ResponseReturnValue:
return list_resources(
model=Group,
serialize_func=lambda group: group.to_dict(),
@ -133,11 +133,11 @@ def list_groups():
@api.route('/web/origin', methods=['GET'])
def list_origins():
def list_origins() -> ResponseReturnValue:
domain_name_filter = request.args.get('DomainName')
group_id_filter = request.args.get('GroupId')
filters = []
filters: List[ListFilter] = []
if domain_name_filter:
if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
@ -148,22 +148,21 @@ def list_origins():
if group_id_filter:
try:
group_id_filter = int(group_id_filter)
filters.append(Origin.group_id == group_id_filter)
filters.append(Origin.group_id == int(group_id_filter))
except ValueError:
abort(400, description="GroupId must be a valid integer.")
return list_resources(
model=Origin,
filters=filters,
serialize_func=lambda origin: origin.to_dict(),
filters=filters,
resource_name='OriginsList',
max_allowed_items=MAX_ALLOWED_ITEMS
)
@api.route('/web/mirror', methods=['GET'])
def list_mirrors():
def list_mirrors() -> ResponseReturnValue:
status_filter = request.args.get('Status')
filters = []
@ -185,8 +184,8 @@ def list_mirrors():
return list_resources(
model=Proxy,
filters=filters,
serialize_func=lambda proxy: proxy.to_dict(),
filters=filters,
resource_name='MirrorsList',
max_allowed_items=MAX_ALLOWED_ITEMS
)

View file

@ -13,10 +13,10 @@ class AbstractConfiguration(db.Model): # type: ignore
__abstract__ = True
id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
description: Mapped[str] = mapped_column(db.String(255), nullable=False)
added: Mapped[datetime] = mapped_column(db.DateTime(), default=datetime.utcnow, nullable=False)
updated: Mapped[datetime] = mapped_column(db.DateTime(), default=datetime.utcnow, nullable=False)
destroyed: Mapped[datetime] = mapped_column(db.DateTime())
description: Mapped[str]
added: Mapped[datetime]
updated: Mapped[datetime]
destroyed: Mapped[Optional[datetime]]
@property
@abstractmethod

View file

@ -1,21 +1,36 @@
from datetime import datetime
from typing import List
from typing import List, TypedDict, Optional, TYPE_CHECKING
from sqlalchemy import and_
from sqlalchemy.orm import Mapped, aliased, mapped_column, relationship
from app.brm.brn import BRN
from app.extensions import db
from app.models import AbstractConfiguration
if TYPE_CHECKING:
from app.models.bridges import BridgeConf
from app.models.mirrors import Origin, Proxy, SmartProxy, StaticOrigin
from app.models.onions import Eotk, Onion
class GroupDict(TypedDict):
Id: int
GroupName: str
Description: str
ActiveOriginCount: int
class Group(AbstractConfiguration):
group_name = db.Column(db.String(80), unique=True, nullable=False)
eotk = db.Column(db.Boolean())
group_name: Mapped[str] = db.Column(db.String(80), unique=True, nullable=False)
eotk: Mapped[bool]
origins = db.relationship("Origin", back_populates="group")
statics = db.relationship("StaticOrigin", back_populates="group")
eotks = db.relationship("Eotk", back_populates="group")
onions = db.relationship("Onion", back_populates="group")
smart_proxies = db.relationship("SmartProxy", back_populates="group")
pools = db.relationship("Pool", secondary="pool_group", back_populates="groups")
origins: Mapped[List["Origin"]] = relationship("Origin", back_populates="group")
statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="group")
eotks: Mapped[List["Eotk"]] = relationship("Eotk", back_populates="group")
onions: Mapped[List["Onion"]] = relationship("Onion", back_populates="group")
smart_proxies: Mapped[List["SmartProxy"]] = relationship("SmartProxy", back_populates="group")
pools: Mapped[List["Pool"]] = relationship("Pool", secondary="pool_group", back_populates="groups")
@classmethod
def csv_header(cls) -> List[str]:
@ -33,25 +48,29 @@ class Group(AbstractConfiguration):
resource_id=str(self.id)
)
def to_dict(self):
active_origins = [o for o in self.origins if o.destroyed is None]
def to_dict(self) -> GroupDict:
active_origins_query = (
db.session.query(aliased(Origin))
.filter(and_(Origin.group_id == self.id, Origin.destroyed.is_(None)))
)
active_origins_count = active_origins_query.count()
return {
"Id": self.id,
"GroupName": self.group_name,
"Description": self.description,
"ActiveOriginCount": len(active_origins),
"ActiveOriginCount": active_origins_count,
}
class Pool(AbstractConfiguration):
pool_name = db.Column(db.String(80), unique=True, nullable=False)
api_key = db.Column(db.String(80), nullable=False)
redirector_domain = db.Column(db.String(128), nullable=True)
pool_name: Mapped[str] = mapped_column(db.String, unique=True)
api_key: Mapped[str]
redirector_domain: Mapped[Optional[str]]
bridgeconfs = db.relationship("BridgeConf", back_populates="pool")
proxies = db.relationship("Proxy", back_populates="pool")
lists = db.relationship("MirrorList", back_populates="pool")
groups = db.relationship("Group", secondary="pool_group", back_populates="pools")
bridgeconfs: Mapped[List["BridgeConf"]] = relationship("BridgeConf", back_populates="pool")
proxies: Mapped[List["Proxy"]] = relationship("Proxy", back_populates="pool")
lists: Mapped[List["MirrorList"]] = relationship("MirrorList", back_populates="pool")
groups: Mapped[List[Group]] = relationship("Group", secondary="pool_group", back_populates="pools")
@classmethod
def csv_header(cls) -> List[str]:

View file

@ -2,9 +2,12 @@ import enum
from datetime import datetime
from typing import List
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.brm.brn import BRN
from app.extensions import db
from app.models import AbstractConfiguration, AbstractResource
from app.models.base import Pool
class ProviderAllocation(enum.Enum):
@ -13,15 +16,15 @@ class ProviderAllocation(enum.Enum):
class BridgeConf(AbstractConfiguration):
pool_id = db.Column(db.Integer, db.ForeignKey("pool.id"), nullable=False)
method = db.Column(db.String(20), nullable=False)
target_number = db.Column(db.Integer())
max_number = db.Column(db.Integer())
expiry_hours = db.Column(db.Integer())
provider_allocation = db.Column(db.Enum(ProviderAllocation))
pool_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("pool.id"))
method: Mapped[str]
target_number: Mapped[int]
max_number: Mapped[int]
expiry_hours: Mapped[int]
provider_allocation: Mapped[ProviderAllocation]
pool = db.relationship("Pool", back_populates="bridgeconfs")
bridges = db.relationship("Bridge", back_populates="conf")
pool: Mapped[Pool] = relationship("Pool", back_populates="bridgeconfs")
bridges: Mapped[List["Bridge"]] = relationship("Bridge", back_populates="conf")
@property
def brn(self) -> BRN:

View file

@ -1,4 +1,7 @@
import enum
from typing import Any, Dict, List, TYPE_CHECKING
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.brm.brn import BRN
from app.extensions import db
@ -6,6 +9,10 @@ from app.models import AbstractConfiguration
from app.models.mirrors import StaticOrigin
if TYPE_CHECKING:
from app.models.bridges import Bridge
class CloudProvider(enum.Enum):
AWS = ("aws", "Amazon Web Services")
AZURE = ("azure", "Microsoft Azure")
@ -27,17 +34,17 @@ class CloudProvider(enum.Enum):
class CloudAccount(AbstractConfiguration):
provider = db.Column(db.Enum(CloudProvider))
credentials = db.Column(db.JSON())
enabled = db.Column(db.Boolean())
provider: Mapped[CloudProvider]
credentials: Mapped[Dict[str, Any]] = mapped_column(db.JSON())
enabled: Mapped[bool]
# CDN Quotas
max_distributions = db.Column(db.Integer())
max_sub_distributions = db.Column(db.Integer())
max_distributions: Mapped[int]
max_sub_distributions: Mapped[int]
# Compute Quotas
max_instances = db.Column(db.Integer())
max_instances: Mapped[int]
bridges = db.relationship("Bridge", back_populates="cloud_account")
statics = db.relationship("StaticOrigin", back_populates="storage_cloud_account", foreign_keys=[
bridges: Mapped[List["Bridge"]] = relationship("Bridge", back_populates="cloud_account")
statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="storage_cloud_account", foreign_keys=[
StaticOrigin.storage_cloud_account_id])
@property

View file

@ -2,10 +2,10 @@ from __future__ import annotations
import json
from datetime import datetime, timedelta
from typing import Optional, List, Union, Any, Dict
from typing import Optional, List, Union, Any, Dict, TypedDict, Literal
import tldextract
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, relationship
from tldextract import extract
from werkzeug.datastructures import FileStorage
@ -13,7 +13,7 @@ from app.brm.brn import BRN
from app.brm.utils import thumbnail_uploaded_image, create_data_uri, normalize_color
from app.extensions import db
from app.models import AbstractConfiguration, AbstractResource, Deprecation
from app.models.base import Pool
from app.models.base import Pool, Group
from app.models.onions import Onion
country_origin = db.Table(
@ -25,17 +25,25 @@ country_origin = db.Table(
)
class Origin(AbstractConfiguration):
group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False)
domain_name = mapped_column(db.String(255), unique=True, nullable=False)
auto_rotation = mapped_column(db.Boolean, nullable=False)
smart = mapped_column(db.Boolean(), nullable=False)
assets = mapped_column(db.Boolean(), nullable=False)
risk_level_override = mapped_column(db.Integer(), nullable=True)
class OriginDict(TypedDict):
Id: int
Description: str
DomainName: str
RiskLevel: Dict[str, int]
RiskLevelOverride: Optional[int]
group = db.relationship("Group", back_populates="origins")
proxies = db.relationship("Proxy", back_populates="origin")
countries = db.relationship("Country", secondary=country_origin, back_populates='origins')
class Origin(AbstractConfiguration):
group_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("group.id"))
domain_name: Mapped[str] = mapped_column(db.String(255), unique=True)
auto_rotation: Mapped[bool]
smart: Mapped[bool]
assets: Mapped[bool]
risk_level_override: Mapped[Optional[int]]
group: Mapped[Group] = relationship("Group", back_populates="origins")
proxies: Mapped[List[Proxy]] = relationship("Proxy", back_populates="origin")
countries: Mapped[List[Country]] = relationship("Country", secondary=country_origin, back_populates='origins')
@property
def brn(self) -> BRN:
@ -100,7 +108,7 @@ class Origin(AbstractConfiguration):
max(1, min(10, frequency_factor * recency_factor))) + country.risk_level
return risk_levels
def to_dict(self):
def to_dict(self) -> OriginDict:
return {
"Id": self.id,
"Description": self.description,
@ -250,6 +258,16 @@ class StaticOrigin(AbstractConfiguration):
self.updated = datetime.utcnow()
ResourceStatus = Union[Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"]]
class ProxyDict(TypedDict):
Id: int
OriginDomain: str
MirrorDomain: Optional[str]
Status: ResourceStatus
class Proxy(AbstractResource):
origin_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("origin.id"), nullable=False)
pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id"))
@ -259,8 +277,8 @@ class Proxy(AbstractResource):
terraform_updated: Mapped[Optional[datetime]] = mapped_column(db.DateTime(), nullable=True)
url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
origin: Mapped[Origin] = db.relationship("Origin", back_populates="proxies")
pool: Mapped[Pool] = db.relationship("Pool", back_populates="proxies")
origin: Mapped[Origin] = relationship("Origin", back_populates="proxies")
pool: Mapped[Pool] = relationship("Pool", back_populates="proxies")
@property
def brn(self) -> BRN:
@ -278,8 +296,8 @@ class Proxy(AbstractResource):
"origin_id", "provider", "psg", "slug", "terraform_updated", "url"
]
def to_dict(self):
status = "active"
def to_dict(self) -> ProxyDict:
status: ResourceStatus = "active"
if self.url is None:
status = "pending"
if self.deprecated is not None:

View file

@ -1,9 +1,13 @@
import base64
import hashlib
from typing import Optional
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.brm.brn import BRN
from app.extensions import db
from app.models import AbstractConfiguration, AbstractResource
from app.models.base import Group
class Onion(AbstractConfiguration):
@ -47,12 +51,12 @@ class Onion(AbstractConfiguration):
class Eotk(AbstractResource):
group_id = db.Column(db.Integer(), db.ForeignKey("group.id"), nullable=False)
instance_id = db.Column(db.String(100), nullable=True)
provider = db.Column(db.String(20), nullable=False)
region = db.Column(db.String(20), nullable=False)
group_id: Mapped[int] = mapped_column(db.Integer(), db.ForeignKey("group.id"))
instance_id: Mapped[Optional[str]]
provider: Mapped[str]
region: Mapped[str]
group = db.relationship("Group", back_populates="eotks")
group: Mapped[Group] = relationship("Group", back_populates="eotks")
@property
def brn(self) -> BRN:

View file

@ -1,7 +1,11 @@
from typing import Optional
from sqlalchemy.orm import Mapped, mapped_column
from app.extensions import db
class TerraformState(db.Model): # type: ignore
key = db.Column(db.String, primary_key=True)
state = db.Column(db.String)
lock = db.Column(db.String)
key: Mapped[str] = mapped_column(db.String, primary_key=True)
state: Mapped[str]
lock: Mapped[Optional[str]]

View file

@ -0,0 +1,124 @@
"""enforce not null restrctions
Revision ID: 13b1d64f134a
Revises: bbec86de37c4
Create Date: 2024-11-10 15:12:16.589705
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite
revision = '13b1d64f134a'
down_revision = 'bbec86de37c4'
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table('bridge_conf', schema=None) as batch_op:
batch_op.alter_column('target_number',
existing_type=sa.INTEGER(),
nullable=False)
batch_op.alter_column('max_number',
existing_type=sa.INTEGER(),
nullable=False)
batch_op.alter_column('expiry_hours',
existing_type=sa.INTEGER(),
nullable=False)
batch_op.alter_column('provider_allocation',
existing_type=sa.VARCHAR(length=6),
nullable=False)
with op.batch_alter_table('cloud_account', schema=None) as batch_op:
batch_op.alter_column('provider',
existing_type=sa.VARCHAR(length=10),
nullable=False)
batch_op.alter_column('credentials',
existing_type=sqlite.JSON(),
nullable=False)
batch_op.alter_column('enabled',
existing_type=sa.BOOLEAN(),
nullable=False)
batch_op.alter_column('max_distributions',
existing_type=sa.INTEGER(),
nullable=False)
batch_op.alter_column('max_sub_distributions',
existing_type=sa.INTEGER(),
nullable=False)
batch_op.alter_column('max_instances',
existing_type=sa.INTEGER(),
nullable=False)
with op.batch_alter_table('deprecation', schema=None) as batch_op:
batch_op.alter_column('resource_type',
existing_type=sa.VARCHAR(length=50),
nullable=False)
batch_op.alter_column('resource_id',
existing_type=sa.INTEGER(),
nullable=False)
with op.batch_alter_table('group', schema=None) as batch_op:
batch_op.alter_column('eotk',
existing_type=sa.BOOLEAN(),
nullable=False)
with op.batch_alter_table('terraform_state', schema=None) as batch_op:
batch_op.alter_column('state',
existing_type=sa.VARCHAR(),
nullable=False)
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('terraform_state', schema=None) as batch_op:
batch_op.alter_column('state',
existing_type=sa.VARCHAR(),
nullable=True)
with op.batch_alter_table('group', schema=None) as batch_op:
batch_op.alter_column('eotk',
existing_type=sa.BOOLEAN(),
nullable=True)
with op.batch_alter_table('deprecation', schema=None) as batch_op:
batch_op.alter_column('resource_id',
existing_type=sa.INTEGER(),
nullable=True)
batch_op.alter_column('resource_type',
existing_type=sa.VARCHAR(length=50),
nullable=True)
with op.batch_alter_table('cloud_account', schema=None) as batch_op:
batch_op.alter_column('max_instances',
existing_type=sa.INTEGER(),
nullable=True)
batch_op.alter_column('max_sub_distributions',
existing_type=sa.INTEGER(),
nullable=True)
batch_op.alter_column('max_distributions',
existing_type=sa.INTEGER(),
nullable=True)
batch_op.alter_column('enabled',
existing_type=sa.BOOLEAN(),
nullable=True)
batch_op.alter_column('credentials',
existing_type=sqlite.JSON(),
nullable=True)
batch_op.alter_column('provider',
existing_type=sa.VARCHAR(length=10),
nullable=True)
with op.batch_alter_table('bridge_conf', schema=None) as batch_op:
batch_op.alter_column('provider_allocation',
existing_type=sa.VARCHAR(length=6),
nullable=True)
batch_op.alter_column('expiry_hours',
existing_type=sa.INTEGER(),
nullable=True)
batch_op.alter_column('max_number',
existing_type=sa.INTEGER(),
nullable=True)
batch_op.alter_column('target_number',
existing_type=sa.INTEGER(),
nullable=True)