fix: typing fixes since moving to Mapped types in models

This commit is contained in:
Iain Learmonth 2024-11-09 11:09:44 +00:00
parent d08388c339
commit 4693e994ba
5 changed files with 71 additions and 44 deletions

View file

@ -1,8 +1,11 @@
from __future__ import annotations
import json
import tldextract
from datetime import datetime, timedelta
from typing import Optional, List, Union, Any, Dict
from sqlalchemy.orm import Mapped, mapped_column
from tldextract import extract
from werkzeug.datastructures import FileStorage
@ -10,6 +13,7 @@ from app.brm.brn import BRN
from app.brm.utils import thumbnail_uploaded_image, create_data_uri, normalize_color
from app.extensions import db
from app.models import AbstractConfiguration, AbstractResource, Deprecation
from app.models.base import Pool
from app.models.onions import Onion
country_origin = db.Table(
@ -22,12 +26,12 @@ country_origin = db.Table(
class Origin(AbstractConfiguration):
group_id = db.Column(db.Integer, db.ForeignKey("group.id"), nullable=False)
domain_name = db.Column(db.String(255), unique=True, nullable=False)
auto_rotation = db.Column(db.Boolean, nullable=False)
smart = db.Column(db.Boolean(), nullable=False)
assets = db.Column(db.Boolean(), nullable=False)
risk_level_override = db.Column(db.Integer(), nullable=True)
group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False)
domain_name = mapped_column(db.String(255), unique=True, nullable=False)
auto_rotation = mapped_column(db.Boolean, nullable=False)
smart = mapped_column(db.Boolean(), nullable=False)
assets = mapped_column(db.Boolean(), nullable=False)
risk_level_override = mapped_column(db.Integer(), nullable=True)
group = db.relationship("Group", back_populates="origins")
proxies = db.relationship("Proxy", back_populates="origin")
@ -71,10 +75,10 @@ class Origin(AbstractConfiguration):
def risk_level(self) -> Dict[str, int]:
if self.risk_level_override:
return {country.country_code: self.risk_level_override for country in self.countries}
frequency_factor = 0
recency_factor = 0
frequency_factor = 0.0
recency_factor = 0.0
recent_deprecations = (
db.session.query(Deprecation) # type: ignore[no-untyped-call]
db.session.query(Deprecation)
.join(Proxy,
Deprecation.resource_id == Proxy.id)
.join(Origin, Origin.id == Proxy.origin_id)
@ -106,8 +110,8 @@ class Country(AbstractConfiguration):
resource_id=self.country_code
)
country_code = db.Column(db.String(2), nullable=False)
risk_level_override = db.Column(db.Integer(), nullable=True)
country_code = mapped_column(db.String(2), nullable=False)
risk_level_override = mapped_column(db.Integer(), nullable=True)
origins = db.relationship("Origin", secondary=country_origin, back_populates='countries')
@ -115,10 +119,10 @@ class Country(AbstractConfiguration):
def risk_level(self) -> int:
if self.risk_level_override:
return int(self.risk_level_override // 2)
frequency_factor = 0
recency_factor = 0
frequency_factor = 0.0
recency_factor = 0.0
recent_deprecations = (
db.session.query(Deprecation) # type: ignore[no-untyped-call]
db.session.query(Deprecation)
.join(Proxy,
Deprecation.resource_id == Proxy.id)
.join(Origin, Origin.id == Proxy.origin_id)
@ -138,16 +142,16 @@ class Country(AbstractConfiguration):
class StaticOrigin(AbstractConfiguration):
group_id = db.Column(db.Integer, db.ForeignKey("group.id"), nullable=False)
storage_cloud_account_id = db.Column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False)
source_cloud_account_id = db.Column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False)
source_project = db.Column(db.String(255), nullable=False)
auto_rotate = db.Column(db.Boolean, nullable=False)
matrix_homeserver = db.Column(db.String(255), nullable=True)
keanu_convene_path = db.Column(db.String(255), nullable=True)
keanu_convene_config = db.Column(db.String(), nullable=True)
clean_insights_backend = db.Column(db.String(255), nullable=True)
origin_domain_name = db.Column(db.String(255), nullable=True)
group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False)
storage_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False)
source_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False)
source_project = mapped_column(db.String(255), nullable=False)
auto_rotate = mapped_column(db.Boolean, nullable=False)
matrix_homeserver = mapped_column(db.String(255), nullable=True)
keanu_convene_path = mapped_column(db.String(255), nullable=True)
keanu_convene_config = mapped_column(db.String(), nullable=True)
clean_insights_backend = mapped_column(db.String(255), nullable=True)
origin_domain_name = mapped_column(db.String(255), nullable=True)
@property
def brn(self) -> BRN:
@ -235,19 +239,21 @@ class StaticOrigin(AbstractConfiguration):
class Proxy(AbstractResource):
origin_id = db.Column(db.Integer, db.ForeignKey("origin.id"), nullable=False)
pool_id = db.Column(db.Integer, db.ForeignKey("pool.id"))
provider = db.Column(db.String(20), nullable=False)
psg = db.Column(db.Integer, nullable=True)
slug = db.Column(db.String(20), nullable=True)
terraform_updated = db.Column(db.DateTime(), nullable=True)
url = db.Column(db.String(255), nullable=True)
origin_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("origin.id"), nullable=False)
pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id"))
provider: Mapped[str] = mapped_column(db.String(20), nullable=False)
psg: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
slug: Mapped[Optional[str]] = mapped_column(db.String(20), nullable=True)
terraform_updated: Mapped[Optional[datetime]] = mapped_column(db.DateTime(), nullable=True)
url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
origin = db.relationship("Origin", back_populates="proxies")
pool = db.relationship("Pool", back_populates="proxies")
origin: Mapped[Origin] = db.relationship("Origin", back_populates="proxies")
pool: Mapped[Pool] = db.relationship("Pool", back_populates="proxies")
@property
def brn(self) -> BRN:
assert self.group_id is not None, "group_id should never be None" # nosec: B101
assert self.provider is not None, "provider should never be None" # nosec: B101
return BRN(
group_id=self.origin.group_id,
product="mirror",
@ -264,10 +270,10 @@ class Proxy(AbstractResource):
class SmartProxy(AbstractResource):
group_id = db.Column(db.Integer(), db.ForeignKey("group.id"), nullable=False)
instance_id = db.Column(db.String(100), nullable=True)
provider = db.Column(db.String(20), nullable=False)
region = db.Column(db.String(20), nullable=False)
group_id = mapped_column(db.Integer(), db.ForeignKey("group.id"), nullable=False)
instance_id = mapped_column(db.String(100), nullable=True)
provider = mapped_column(db.String(20), nullable=False)
region = mapped_column(db.String(20), nullable=False)
group = db.relationship("Group", back_populates="smart_proxies")

View file

@ -1,4 +1,5 @@
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
@ -52,6 +53,9 @@ portal.register_blueprint(webhook, url_prefix="/webhook")
@portal.app_template_filter("bridge_expiry")
def calculate_bridge_expiry(b: Bridge) -> str:
if b.deprecated is None:
logging.warning("Bridge expiry requested by template for a bridge %s that was not expiring.", b.id)
return "Not expiring"
expiry = b.deprecated + timedelta(hours=b.conf.expiry_hours)
countdown = expiry - datetime.utcnow()
if countdown.days == 0:

View file

@ -1,36 +1,37 @@
import datetime
import os
import sys
from typing import Optional, Any, List, Tuple
from typing import Optional, Any, List, Sequence, Tuple
from sqlalchemy import select
from sqlalchemy import select, Row
from app import app
from app.extensions import db
from app.models import AbstractResource
from app.models.bridges import Bridge, BridgeConf
from app.models.cloud import CloudAccount, CloudProvider
from app.terraform.terraform import TerraformAutomation
BridgeResourceRow = List[Tuple[Bridge, BridgeConf, CloudAccount]]
BridgeResourceRow = Row[Tuple[AbstractResource, BridgeConf, CloudAccount]]
def active_bridges_by_provider(provider: CloudProvider) -> List[BridgeResourceRow]:
def active_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]:
stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where(
CloudAccount.provider == provider,
Bridge.destroyed.is_(None),
)
bridges: List[BridgeResourceRow] = db.session.execute(stmt).all()
bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all()
return bridges
def recently_destroyed_bridges_by_provider(provider: CloudProvider) -> List[BridgeResourceRow]:
def recently_destroyed_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]:
cutoff = datetime.datetime.utcnow() - datetime.timedelta(hours=72)
stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where(
CloudAccount.provider == provider,
Bridge.destroyed.is_not(None),
Bridge.destroyed >= cutoff,
)
bridges: List[BridgeResourceRow] = db.session.execute(stmt).all()
bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all()
return bridges

View file

@ -127,6 +127,8 @@ class BridgeMetaAutomation(BaseAutomation):
).all()
logging.debug("Found %s deprecated bridges", len(deprecated_bridges))
for bridge in deprecated_bridges:
if bridge.deprecated is None:
continue # Possible due to SQLAlchemy lazy loading
cutoff = datetime.datetime.utcnow() - datetime.timedelta(hours=bridge.conf.expiry_hours)
if bridge.deprecated < cutoff:
logging.debug("Destroying expired bridge")

View file

@ -80,6 +80,9 @@ def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupC
proxies = all_active_proxies()
subgroup_count: SubgroupCount = OrderedDict()
for proxy in proxies:
if not proxy.psg:
logging.warning("Proxy %s has no psg", proxy.id)
continue
if proxy.provider not in subgroup_count:
subgroup_count[proxy.provider] = OrderedDict()
if proxy.origin.group_id not in subgroup_count[proxy.provider]:
@ -142,6 +145,7 @@ def auto_deprecate_proxies() -> None:
days=1, seconds=86400 * random.random()) # nosec: B311
if proxy.added < max_age_cutoff:
proxy.deprecate(reason="max_age_reached")
proxy.destroy()
def destroy_expired_proxies() -> None:
@ -298,6 +302,16 @@ class ProxyMetaAutomation(BaseAutomation):
Origin.destroyed.is_(None)
).all()
for origin in origins:
if origin.countries:
risk_levels = origin.risk_level.items()
highest_risk_country = max(risk_levels, key=lambda x: x[1])
highest_risk_level = highest_risk_country[1]
if highest_risk_level < 4:
for proxy in origin.proxies:
if proxy.destroyed is None and proxy.pool_id == -1:
logging.debug("Destroying hot spare proxy for origin %s (low risk)", origin)
proxy.destroy()
continue
if origin.destroyed is not None:
continue
proxies = Proxy.query.filter(