feat: remove pydantic from list generation
This commit is contained in:
parent
1e70ec8fa6
commit
d08388c339
8 changed files with 164 additions and 197 deletions
|
@ -1,87 +1,76 @@
|
|||
# pylint: disable=too-few-public-methods
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Union, Optional
|
||||
|
||||
from typing import Dict, List, Optional, TypedDict
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
from tldextract import extract
|
||||
|
||||
from app.extensions import db
|
||||
from app.models.base import Group, Pool
|
||||
from app.models.mirrors import Proxy
|
||||
from app.models.mirrors import Proxy, Origin
|
||||
|
||||
|
||||
class MMMirror(BaseModel):
|
||||
origin_domain: str = Field(description="The full origin domain name")
|
||||
origin_domain_normalized: str = Field(description="The origin_domain with \"www.\" removed, if present")
|
||||
origin_domain_root: str = Field(description="The registered domain name of the origin, excluding subdomains")
|
||||
valid_from: str = Field(description="The date on which the mirror was added to the system")
|
||||
valid_to: Optional[str] = Field(description="The date on which the mirror was decommissioned")
|
||||
countries: Dict[str, int] = Field(description="A list mapping of risk levels to country")
|
||||
country: Optional[str] = Field(
|
||||
description="The country code of the country with the highest risk level where the origin is targeted")
|
||||
risk: int = Field(description="The risk score for the highest risk country")
|
||||
class MirrorMappingMirror(TypedDict):
|
||||
origin_domain: str
|
||||
origin_domain_normalized: str
|
||||
origin_domain_root: str
|
||||
valid_from: str
|
||||
valid_to: Optional[str]
|
||||
countries: Dict[str, int]
|
||||
country: Optional[str]
|
||||
risk: int
|
||||
|
||||
|
||||
class MirrorMapping(BaseModel):
|
||||
version: str = Field(
|
||||
description="Version number of the mirror mapping schema in use"
|
||||
)
|
||||
mappings: Dict[str, MMMirror] = Field(
|
||||
description="The domain name for the mirror"
|
||||
)
|
||||
s3_buckets: List[str] = Field(
|
||||
description="The names of all S3 buckets used for CloudFront logs"
|
||||
)
|
||||
|
||||
class Config:
|
||||
title = "Mirror Mapping Version 1.2"
|
||||
class MirrorMapping(TypedDict):
|
||||
version: str
|
||||
mappings: Dict[str, MirrorMappingMirror]
|
||||
s3_buckets: List[str]
|
||||
|
||||
|
||||
def mirror_mapping(_: Optional[Pool]) -> Dict[str, Union[str, Dict[str, str]]]:
|
||||
one_week_ago = datetime.utcnow() - timedelta(days=7)
|
||||
def mirror_mapping(_: Optional[Pool]) -> MirrorMapping:
|
||||
two_days_ago = datetime.utcnow() - timedelta(days=2)
|
||||
|
||||
proxies = (
|
||||
db.session.query(Proxy) # type: ignore[no-untyped-call]
|
||||
.filter(or_(Proxy.destroyed.is_(None), Proxy.destroyed > one_week_ago))
|
||||
db.session.query(Proxy)
|
||||
.options(selectinload(Proxy.origin).selectinload(Origin.countries))
|
||||
.filter(or_(Proxy.destroyed.is_(None), Proxy.destroyed > two_days_ago))
|
||||
.filter(Proxy.url.is_not(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
result = {}
|
||||
result: Dict[str, MirrorMappingMirror] = {}
|
||||
for proxy in proxies:
|
||||
if proxy.origin.countries: # Check if there are any associated countries
|
||||
risk_levels = proxy.origin.risk_level.items()
|
||||
highest_risk_country = max(risk_levels, key=lambda x: x[1])
|
||||
highest_risk_country_code = highest_risk_country[0]
|
||||
highest_risk_level = highest_risk_country[1]
|
||||
if proxy.url is None:
|
||||
logging.error("No URL for proxy %s", proxy)
|
||||
continue
|
||||
|
||||
countries = proxy.origin.risk_level
|
||||
if countries:
|
||||
highest_risk_country_code, highest_risk_level = max(countries.items(), key=lambda x: x[1])
|
||||
else:
|
||||
highest_risk_country_code = "ZZ"
|
||||
highest_risk_level = 0
|
||||
|
||||
result[proxy.url.lstrip("https://")] = MMMirror(
|
||||
origin_domain=proxy.origin.domain_name,
|
||||
origin_domain_normalized=proxy.origin.domain_name.replace("www.", ""),
|
||||
origin_domain_root=extract(proxy.origin.domain_name).registered_domain,
|
||||
valid_from=proxy.added.isoformat(),
|
||||
valid_to=proxy.destroyed.isoformat() if proxy.destroyed is not None else None,
|
||||
countries=proxy.origin.risk_level,
|
||||
country=highest_risk_country_code,
|
||||
risk=highest_risk_level
|
||||
)
|
||||
result[proxy.url.lstrip("https://")] = {
|
||||
"origin_domain": proxy.origin.domain_name,
|
||||
"origin_domain_normalized": proxy.origin.domain_name.replace("www.", ""),
|
||||
"origin_domain_root": extract(proxy.origin.domain_name).registered_domain,
|
||||
"valid_from": proxy.added.isoformat(),
|
||||
"valid_to": proxy.destroyed.isoformat() if proxy.destroyed else None,
|
||||
"countries": countries,
|
||||
"country": highest_risk_country_code,
|
||||
"risk": highest_risk_level
|
||||
}
|
||||
|
||||
return MirrorMapping(
|
||||
version="1.2",
|
||||
mappings=result,
|
||||
s3_buckets=[
|
||||
f"{current_app.config['GLOBAL_NAMESPACE']}-{g.group_name.lower()}-logs-cloudfront"
|
||||
for g in Group.query.filter(Group.destroyed.is_(None)).all()
|
||||
]
|
||||
).dict()
|
||||
groups = db.session.query(Group).options(selectinload(Group.pools))
|
||||
s3_buckets = [
|
||||
f"{current_app.config['GLOBAL_NAMESPACE']}-{g.group_name.lower()}-logs-cloudfront"
|
||||
for g in groups.filter(Group.destroyed.is_(None)).all()
|
||||
]
|
||||
|
||||
|
||||
if getattr(builtins, "__sphinx_build__", False):
|
||||
schema = MirrorMapping.schema_json()
|
||||
return {
|
||||
"version": "1.2",
|
||||
"mappings": result,
|
||||
"s3_buckets": s3_buckets
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue