diff --git a/app/models/mirrors.py b/app/models/mirrors.py index addae61..cdd6b3e 100644 --- a/app/models/mirrors.py +++ b/app/models/mirrors.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import json import tldextract from datetime import datetime, timedelta from typing import Optional, List, Union, Any, Dict +from sqlalchemy.orm import Mapped, mapped_column from tldextract import extract from werkzeug.datastructures import FileStorage @@ -10,6 +13,7 @@ from app.brm.brn import BRN from app.brm.utils import thumbnail_uploaded_image, create_data_uri, normalize_color from app.extensions import db from app.models import AbstractConfiguration, AbstractResource, Deprecation +from app.models.base import Pool from app.models.onions import Onion country_origin = db.Table( @@ -22,12 +26,12 @@ country_origin = db.Table( class Origin(AbstractConfiguration): - group_id = db.Column(db.Integer, db.ForeignKey("group.id"), nullable=False) - domain_name = db.Column(db.String(255), unique=True, nullable=False) - auto_rotation = db.Column(db.Boolean, nullable=False) - smart = db.Column(db.Boolean(), nullable=False) - assets = db.Column(db.Boolean(), nullable=False) - risk_level_override = db.Column(db.Integer(), nullable=True) + group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False) + domain_name = mapped_column(db.String(255), unique=True, nullable=False) + auto_rotation = mapped_column(db.Boolean, nullable=False) + smart = mapped_column(db.Boolean(), nullable=False) + assets = mapped_column(db.Boolean(), nullable=False) + risk_level_override = mapped_column(db.Integer(), nullable=True) group = db.relationship("Group", back_populates="origins") proxies = db.relationship("Proxy", back_populates="origin") @@ -71,10 +75,10 @@ class Origin(AbstractConfiguration): def risk_level(self) -> Dict[str, int]: if self.risk_level_override: return {country.country_code: self.risk_level_override for country in self.countries} - frequency_factor = 0 - recency_factor = 0 + frequency_factor = 0.0 + recency_factor = 0.0 recent_deprecations = ( - db.session.query(Deprecation) # type: ignore[no-untyped-call] + db.session.query(Deprecation) .join(Proxy, Deprecation.resource_id == Proxy.id) .join(Origin, Origin.id == Proxy.origin_id) @@ -106,8 +110,8 @@ class Country(AbstractConfiguration): resource_id=self.country_code ) - country_code = db.Column(db.String(2), nullable=False) - risk_level_override = db.Column(db.Integer(), nullable=True) + country_code = mapped_column(db.String(2), nullable=False) + risk_level_override = mapped_column(db.Integer(), nullable=True) origins = db.relationship("Origin", secondary=country_origin, back_populates='countries') @@ -115,10 +119,10 @@ class Country(AbstractConfiguration): def risk_level(self) -> int: if self.risk_level_override: return int(self.risk_level_override // 2) - frequency_factor = 0 - recency_factor = 0 + frequency_factor = 0.0 + recency_factor = 0.0 recent_deprecations = ( - db.session.query(Deprecation) # type: ignore[no-untyped-call] + db.session.query(Deprecation) .join(Proxy, Deprecation.resource_id == Proxy.id) .join(Origin, Origin.id == Proxy.origin_id) @@ -138,16 +142,16 @@ class Country(AbstractConfiguration): class StaticOrigin(AbstractConfiguration): - group_id = db.Column(db.Integer, db.ForeignKey("group.id"), nullable=False) - storage_cloud_account_id = db.Column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) - source_cloud_account_id = db.Column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) - source_project = db.Column(db.String(255), nullable=False) - auto_rotate = db.Column(db.Boolean, nullable=False) - matrix_homeserver = db.Column(db.String(255), nullable=True) - keanu_convene_path = db.Column(db.String(255), nullable=True) - keanu_convene_config = db.Column(db.String(), nullable=True) - clean_insights_backend = db.Column(db.String(255), nullable=True) - origin_domain_name = db.Column(db.String(255), nullable=True) + group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False) + storage_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) + source_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) + source_project = mapped_column(db.String(255), nullable=False) + auto_rotate = mapped_column(db.Boolean, nullable=False) + matrix_homeserver = mapped_column(db.String(255), nullable=True) + keanu_convene_path = mapped_column(db.String(255), nullable=True) + keanu_convene_config = mapped_column(db.String(), nullable=True) + clean_insights_backend = mapped_column(db.String(255), nullable=True) + origin_domain_name = mapped_column(db.String(255), nullable=True) @property def brn(self) -> BRN: @@ -235,19 +239,21 @@ class StaticOrigin(AbstractConfiguration): class Proxy(AbstractResource): - origin_id = db.Column(db.Integer, db.ForeignKey("origin.id"), nullable=False) - pool_id = db.Column(db.Integer, db.ForeignKey("pool.id")) - provider = db.Column(db.String(20), nullable=False) - psg = db.Column(db.Integer, nullable=True) - slug = db.Column(db.String(20), nullable=True) - terraform_updated = db.Column(db.DateTime(), nullable=True) - url = db.Column(db.String(255), nullable=True) + origin_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("origin.id"), nullable=False) + pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id")) + provider: Mapped[str] = mapped_column(db.String(20), nullable=False) + psg: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + slug: Mapped[Optional[str]] = mapped_column(db.String(20), nullable=True) + terraform_updated: Mapped[Optional[datetime]] = mapped_column(db.DateTime(), nullable=True) + url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - origin = db.relationship("Origin", back_populates="proxies") - pool = db.relationship("Pool", back_populates="proxies") + origin: Mapped[Origin] = db.relationship("Origin", back_populates="proxies") + pool: Mapped[Pool] = db.relationship("Pool", back_populates="proxies") @property def brn(self) -> BRN: + assert self.group_id is not None, "group_id should never be None" # nosec: B101 + assert self.provider is not None, "provider should never be None" # nosec: B101 return BRN( group_id=self.origin.group_id, product="mirror", @@ -264,10 +270,10 @@ class Proxy(AbstractResource): class SmartProxy(AbstractResource): - group_id = db.Column(db.Integer(), db.ForeignKey("group.id"), nullable=False) - instance_id = db.Column(db.String(100), nullable=True) - provider = db.Column(db.String(20), nullable=False) - region = db.Column(db.String(20), nullable=False) + group_id = mapped_column(db.Integer(), db.ForeignKey("group.id"), nullable=False) + instance_id = mapped_column(db.String(100), nullable=True) + provider = mapped_column(db.String(20), nullable=False) + region = mapped_column(db.String(20), nullable=False) group = db.relationship("Group", back_populates="smart_proxies") diff --git a/app/portal/__init__.py b/app/portal/__init__.py index eb1271f..868ae2f 100644 --- a/app/portal/__init__.py +++ b/app/portal/__init__.py @@ -1,4 +1,5 @@ import json +import logging from datetime import datetime, timedelta, timezone from typing import Optional @@ -52,6 +53,9 @@ portal.register_blueprint(webhook, url_prefix="/webhook") @portal.app_template_filter("bridge_expiry") def calculate_bridge_expiry(b: Bridge) -> str: + if b.deprecated is None: + logging.warning("Bridge expiry requested by template for a bridge %s that was not expiring.", b.id) + return "Not expiring" expiry = b.deprecated + timedelta(hours=b.conf.expiry_hours) countdown = expiry - datetime.utcnow() if countdown.days == 0: diff --git a/app/terraform/bridge/__init__.py b/app/terraform/bridge/__init__.py index 7848c97..e18cee7 100644 --- a/app/terraform/bridge/__init__.py +++ b/app/terraform/bridge/__init__.py @@ -1,36 +1,37 @@ import datetime import os import sys -from typing import Optional, Any, List, Tuple +from typing import Optional, Any, List, Sequence, Tuple -from sqlalchemy import select +from sqlalchemy import select, Row from app import app from app.extensions import db +from app.models import AbstractResource from app.models.bridges import Bridge, BridgeConf from app.models.cloud import CloudAccount, CloudProvider from app.terraform.terraform import TerraformAutomation -BridgeResourceRow = List[Tuple[Bridge, BridgeConf, CloudAccount]] +BridgeResourceRow = Row[Tuple[AbstractResource, BridgeConf, CloudAccount]] -def active_bridges_by_provider(provider: CloudProvider) -> List[BridgeResourceRow]: +def active_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]: stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where( CloudAccount.provider == provider, Bridge.destroyed.is_(None), ) - bridges: List[BridgeResourceRow] = db.session.execute(stmt).all() + bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all() return bridges -def recently_destroyed_bridges_by_provider(provider: CloudProvider) -> List[BridgeResourceRow]: +def recently_destroyed_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]: cutoff = datetime.datetime.utcnow() - datetime.timedelta(hours=72) stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where( CloudAccount.provider == provider, Bridge.destroyed.is_not(None), Bridge.destroyed >= cutoff, ) - bridges: List[BridgeResourceRow] = db.session.execute(stmt).all() + bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all() return bridges diff --git a/app/terraform/bridge/meta.py b/app/terraform/bridge/meta.py index d920dc1..f9391c4 100644 --- a/app/terraform/bridge/meta.py +++ b/app/terraform/bridge/meta.py @@ -127,6 +127,8 @@ class BridgeMetaAutomation(BaseAutomation): ).all() logging.debug("Found %s deprecated bridges", len(deprecated_bridges)) for bridge in deprecated_bridges: + if bridge.deprecated is None: + continue # Possible due to SQLAlchemy lazy loading cutoff = datetime.datetime.utcnow() - datetime.timedelta(hours=bridge.conf.expiry_hours) if bridge.deprecated < cutoff: logging.debug("Destroying expired bridge") diff --git a/app/terraform/proxy/meta.py b/app/terraform/proxy/meta.py index ec8fc8b..0ef6df6 100644 --- a/app/terraform/proxy/meta.py +++ b/app/terraform/proxy/meta.py @@ -80,6 +80,9 @@ def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupC proxies = all_active_proxies() subgroup_count: SubgroupCount = OrderedDict() for proxy in proxies: + if not proxy.psg: + logging.warning("Proxy %s has no psg", proxy.id) + continue if proxy.provider not in subgroup_count: subgroup_count[proxy.provider] = OrderedDict() if proxy.origin.group_id not in subgroup_count[proxy.provider]: @@ -142,6 +145,7 @@ def auto_deprecate_proxies() -> None: days=1, seconds=86400 * random.random()) # nosec: B311 if proxy.added < max_age_cutoff: proxy.deprecate(reason="max_age_reached") + proxy.destroy() def destroy_expired_proxies() -> None: @@ -298,6 +302,16 @@ class ProxyMetaAutomation(BaseAutomation): Origin.destroyed.is_(None) ).all() for origin in origins: + if origin.countries: + risk_levels = origin.risk_level.items() + highest_risk_country = max(risk_levels, key=lambda x: x[1]) + highest_risk_level = highest_risk_country[1] + if highest_risk_level < 4: + for proxy in origin.proxies: + if proxy.destroyed is None and proxy.pool_id == -1: + logging.debug("Destroying hot spare proxy for origin %s (low risk)", origin) + proxy.destroy() + continue if origin.destroyed is not None: continue proxies = Proxy.query.filter(