From a406a7974b45acc77c69fb9e36d5140cc30a41bb Mon Sep 17 00:00:00 2001 From: irl Date: Fri, 6 Dec 2024 18:15:47 +0000 Subject: [PATCH] lint: reformat python code with black --- app/__init__.py | 150 ++++++++---- app/alarms.py | 13 +- app/api/__init__.py | 18 +- app/api/onion.py | 128 +++++++--- app/api/util.py | 47 ++-- app/api/web.py | 44 ++-- app/brm/brn.py | 28 ++- app/brm/static.py | 34 +-- app/brm/utils.py | 16 +- app/cli/__init__.py | 4 +- app/cli/__main__.py | 7 +- app/cli/automate.py | 99 +++++--- app/cli/db.py | 14 +- app/cli/list.py | 5 +- app/extensions.py | 4 +- app/lists/bc2.py | 52 ++-- app/lists/bridgelines.py | 11 +- app/lists/mirror_mapping.py | 12 +- app/lists/redirector.py | 17 +- app/models/__init__.py | 81 ++++--- app/models/activity.py | 47 ++-- app/models/alarms.py | 24 +- app/models/automation.py | 4 +- app/models/base.py | 53 ++-- app/models/bridges.py | 23 +- app/models/cloud.py | 11 +- app/models/mirrors.py | 160 +++++++++---- app/models/onions.py | 4 +- app/portal/__init__.py | 93 ++++++-- app/portal/automation.py | 87 ++++--- app/portal/bridge.py | 85 ++++--- app/portal/bridgeconf.py | 196 +++++++++------ app/portal/cloud.py | 252 ++++++++++++-------- app/portal/country.py | 57 +++-- app/portal/eotk.py | 29 ++- app/portal/forms.py | 6 +- app/portal/group.py | 48 ++-- app/portal/list.py | 121 ++++++---- app/portal/onion.py | 14 +- app/portal/origin.py | 218 ++++++++++------- app/portal/pool.py | 170 ++++++++----- app/portal/proxy.py | 67 +++--- app/portal/report.py | 91 ++++--- app/portal/smart_proxy.py | 18 +- app/portal/static.py | 232 +++++++++++------- app/portal/storage.py | 41 ++-- app/portal/util.py | 36 +-- app/portal/webhook.py | 77 +++--- app/terraform/__init__.py | 14 +- app/terraform/alarms/eotk_aws.py | 35 +-- app/terraform/alarms/proxy_azure_cdn.py | 23 +- app/terraform/alarms/proxy_cloudfront.py | 33 +-- app/terraform/alarms/proxy_http_status.py | 26 +- app/terraform/alarms/smart_aws.py | 35 +-- app/terraform/block/block_blocky.py | 31 ++- app/terraform/block/block_scriptzteam.py | 3 +- app/terraform/block/bridge.py | 40 +++- app/terraform/block/bridge_bridgelines.py | 4 +- app/terraform/block/bridge_github.py | 14 +- app/terraform/block/bridge_gitlab.py | 15 +- app/terraform/block/bridge_reachability.py | 6 +- app/terraform/block/bridge_roskomsvoboda.py | 4 +- app/terraform/block_external.py | 23 +- app/terraform/block_mirror.py | 25 +- app/terraform/block_ooni.py | 44 ++-- app/terraform/block_roskomsvoboda.py | 62 +++-- app/terraform/bridge/__init__.py | 55 +++-- app/terraform/bridge/gandi.py | 5 +- app/terraform/bridge/hcloud.py | 5 +- app/terraform/bridge/meta.py | 77 ++++-- app/terraform/bridge/ovh.py | 5 +- app/terraform/eotk/aws.py | 62 ++--- app/terraform/list/__init__.py | 32 ++- app/terraform/list/github.py | 4 +- app/terraform/list/gitlab.py | 4 +- app/terraform/list/s3.py | 5 +- app/terraform/proxy/__init__.py | 45 ++-- app/terraform/proxy/azure_cdn.py | 5 +- app/terraform/proxy/cloudfront.py | 33 ++- app/terraform/proxy/fastly.py | 13 +- app/terraform/proxy/meta.py | 101 +++++--- app/terraform/static/aws.py | 35 ++- app/terraform/static/meta.py | 14 +- app/terraform/terraform.py | 91 ++++--- app/tfstate.py | 16 +- app/util/onion.py | 5 +- app/util/x509.py | 88 +++++-- setup.cfg | 2 +- 88 files changed, 2579 insertions(+), 1608 deletions(-) diff --git a/app/__init__.py b/app/__init__.py index dea99a2..596503d 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -6,8 +6,7 @@ import yaml from flask import Flask, redirect, send_from_directory, url_for from flask.typing import ResponseReturnValue from prometheus_client import CollectorRegistry, Metric, make_wsgi_app -from prometheus_client.metrics_core import (CounterMetricFamily, - GaugeMetricFamily) +from prometheus_client.metrics_core import CounterMetricFamily, GaugeMetricFamily from prometheus_client.registry import REGISTRY, Collector from prometheus_flask_exporter import PrometheusMetrics from sqlalchemy import text @@ -28,9 +27,9 @@ app.config.from_file("../config.yaml", load=yaml.safe_load) registry = CollectorRegistry() metrics = PrometheusMetrics(app, registry=registry) -app.wsgi_app = DispatcherMiddleware(app.wsgi_app, { # type: ignore[method-assign] - '/metrics': make_wsgi_app(registry) -}) +app.wsgi_app = DispatcherMiddleware( # type: ignore[method-assign] + app.wsgi_app, {"/metrics": make_wsgi_app(registry)} +) # register default collectors to our new registry collectors = list(REGISTRY._collector_to_names.keys()) @@ -54,12 +53,16 @@ def not_migrating() -> bool: class DefinedProxiesCollector(Collector): def collect(self) -> Iterator[Metric]: with app.app_context(): - ok = GaugeMetricFamily("database_collector", - "Status of a database collector (0: bad, 1: good)", - labels=["collector"]) + ok = GaugeMetricFamily( + "database_collector", + "Status of a database collector (0: bad, 1: good)", + labels=["collector"], + ) try: with db.engine.connect() as conn: - result = conn.execute(text(""" + result = conn.execute( + text( + """ SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name, COUNT(proxy.id) FROM proxy, origin, pool, "group" WHERE proxy.origin_id = origin.id @@ -67,13 +70,24 @@ class DefinedProxiesCollector(Collector): AND proxy.pool_id = pool.id AND proxy.destroyed IS NULL GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name; - """)) - c = GaugeMetricFamily("defined_proxies", "Number of proxies currently defined for deployment", - labels=['group_id', 'group_name', 'provider', 'pool_id', - 'pool_name']) + """ + ) + ) + c = GaugeMetricFamily( + "defined_proxies", + "Number of proxies currently defined for deployment", + labels=[ + "group_id", + "group_name", + "provider", + "pool_id", + "pool_name", + ], + ) for row in result: - c.add_metric([str(row[0]), row[1], row[2], str(row[3]), row[4]], - row[5]) + c.add_metric( + [str(row[0]), row[1], row[2], str(row[3]), row[4]], row[5] + ) yield c ok.add_metric(["defined_proxies"], 1) except SQLAlchemyError: @@ -84,12 +98,16 @@ class DefinedProxiesCollector(Collector): class BlockedProxiesCollector(Collector): def collect(self) -> Iterator[Metric]: with app.app_context(): - ok = GaugeMetricFamily("database_collector", - "Status of a database collector (0: bad, 1: good)", - labels=["collector"]) + ok = GaugeMetricFamily( + "database_collector", + "Status of a database collector (0: bad, 1: good)", + labels=["collector"], + ) try: with db.engine.connect() as conn: - result = conn.execute(text(""" + result = conn.execute( + text( + """ SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name, proxy.deprecation_reason, COUNT(proxy.id) FROM proxy, origin, pool, "group" WHERE proxy.origin_id = origin.id @@ -98,14 +116,26 @@ class BlockedProxiesCollector(Collector): AND proxy.deprecated IS NOT NULL GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name, proxy.deprecation_reason; - """)) - c = CounterMetricFamily("deprecated_proxies", - "Number of proxies deprecated", - labels=['group_id', 'group_name', 'provider', 'pool_id', 'pool_name', - 'deprecation_reason']) + """ + ) + ) + c = CounterMetricFamily( + "deprecated_proxies", + "Number of proxies deprecated", + labels=[ + "group_id", + "group_name", + "provider", + "pool_id", + "pool_name", + "deprecation_reason", + ], + ) for row in result: - c.add_metric([str(row[0]), row[1], row[2], str(row[3]), row[4], row[5]], - row[6]) + c.add_metric( + [str(row[0]), row[1], row[2], str(row[3]), row[4], row[5]], + row[6], + ) yield c ok.add_metric(["deprecated_proxies"], 0) except SQLAlchemyError: @@ -116,24 +146,36 @@ class BlockedProxiesCollector(Collector): class AutomationCollector(Collector): def collect(self) -> Iterator[Metric]: with app.app_context(): - ok = GaugeMetricFamily("database_collector", - "Status of a database collector (0: bad, 1: good)", - labels=["collector"]) + ok = GaugeMetricFamily( + "database_collector", + "Status of a database collector (0: bad, 1: good)", + labels=["collector"], + ) try: - state = GaugeMetricFamily("automation_state", "The automation state (0: idle, 1: running, 2: error)", - labels=['automation_name']) - enabled = GaugeMetricFamily("automation_enabled", - "Whether an automation is enabled (0: disabled, 1: enabled)", - labels=['automation_name']) - next_run = GaugeMetricFamily("automation_next_run", "The timestamp of the next run of the automation", - labels=['automation_name']) - last_run_start = GaugeMetricFamily("automation_last_run_start", - "The timestamp of the last run of the automation ", - labels=['automation_name']) + state = GaugeMetricFamily( + "automation_state", + "The automation state (0: idle, 1: running, 2: error)", + labels=["automation_name"], + ) + enabled = GaugeMetricFamily( + "automation_enabled", + "Whether an automation is enabled (0: disabled, 1: enabled)", + labels=["automation_name"], + ) + next_run = GaugeMetricFamily( + "automation_next_run", + "The timestamp of the next run of the automation", + labels=["automation_name"], + ) + last_run_start = GaugeMetricFamily( + "automation_last_run_start", + "The timestamp of the last run of the automation ", + labels=["automation_name"], + ) automations = Automation.query.all() for automation in automations: - if automation.short_name in app.config['HIDDEN_AUTOMATIONS']: + if automation.short_name in app.config["HIDDEN_AUTOMATIONS"]: continue if automation.state == AutomationState.IDLE: state.add_metric([automation.short_name], 0) @@ -141,13 +183,19 @@ class AutomationCollector(Collector): state.add_metric([automation.short_name], 1) else: state.add_metric([automation.short_name], 2) - enabled.add_metric([automation.short_name], 1 if automation.enabled else 0) + enabled.add_metric( + [automation.short_name], 1 if automation.enabled else 0 + ) if automation.next_run: - next_run.add_metric([automation.short_name], automation.next_run.timestamp()) + next_run.add_metric( + [automation.short_name], automation.next_run.timestamp() + ) else: next_run.add_metric([automation.short_name], 0) if automation.last_run: - last_run_start.add_metric([automation.short_name], automation.last_run.timestamp()) + last_run_start.add_metric( + [automation.short_name], automation.last_run.timestamp() + ) else: last_run_start.add_metric([automation.short_name], 0) yield state @@ -161,31 +209,31 @@ class AutomationCollector(Collector): # register all custom collectors to registry -if not_migrating() and 'DISABLE_METRICS' not in os.environ: +if not_migrating() and "DISABLE_METRICS" not in os.environ: registry.register(DefinedProxiesCollector()) registry.register(BlockedProxiesCollector()) registry.register(AutomationCollector()) -@app.route('/ui') +@app.route("/ui") def redirect_ui() -> ResponseReturnValue: return redirect("/ui/") -@app.route('/ui/', defaults={'path': ''}) -@app.route('/ui/') +@app.route("/ui/", defaults={"path": ""}) +@app.route("/ui/") def serve_ui(path: str) -> ResponseReturnValue: if path != "" and os.path.exists("app/static/ui/" + path): - return send_from_directory('static/ui', path) + return send_from_directory("static/ui", path) else: - return send_from_directory('static/ui', 'index.html') + return send_from_directory("static/ui", "index.html") -@app.route('/') +@app.route("/") def index() -> ResponseReturnValue: # TODO: update to point at new UI when ready return redirect(url_for("portal.portal_home")) -if __name__ == '__main__': +if __name__ == "__main__": app.run() diff --git a/app/alarms.py b/app/alarms.py index 0be0b2c..ee2b507 100644 --- a/app/alarms.py +++ b/app/alarms.py @@ -7,18 +7,15 @@ from app.models.alarms import Alarm def alarms_for(target: BRN) -> List[Alarm]: - return list(Alarm.query.filter( - Alarm.target == str(target) - ).all()) + return list(Alarm.query.filter(Alarm.target == str(target)).all()) -def _get_alarm(target: BRN, - aspect: str, - create_if_missing: bool = True) -> Optional[Alarm]: +def _get_alarm( + target: BRN, aspect: str, create_if_missing: bool = True +) -> Optional[Alarm]: target_str = str(target) alarm: Optional[Alarm] = Alarm.query.filter( - Alarm.aspect == aspect, - Alarm.target == target_str + Alarm.aspect == aspect, Alarm.target == target_str ).first() if create_if_missing and alarm is None: alarm = Alarm() diff --git a/app/api/__init__.py b/app/api/__init__.py index 8ae0071..a61a02f 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -5,34 +5,38 @@ from werkzeug.exceptions import HTTPException from app.api.onion import api_onion from app.api.web import api_web -api = Blueprint('api', __name__) -api.register_blueprint(api_onion, url_prefix='/onion') -api.register_blueprint(api_web, url_prefix='/web') +api = Blueprint("api", __name__) +api.register_blueprint(api_onion, url_prefix="/onion") +api.register_blueprint(api_web, url_prefix="/web") @api.errorhandler(400) def bad_request(error: HTTPException) -> ResponseReturnValue: - response = jsonify({'error': 'Bad Request', 'message': error.description}) + response = jsonify({"error": "Bad Request", "message": error.description}) response.status_code = 400 return response @api.errorhandler(401) def unauthorized(error: HTTPException) -> ResponseReturnValue: - response = jsonify({'error': 'Unauthorized', 'message': error.description}) + response = jsonify({"error": "Unauthorized", "message": error.description}) response.status_code = 401 return response @api.errorhandler(404) def not_found(_: HTTPException) -> ResponseReturnValue: - response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'}) + response = jsonify( + {"error": "Not found", "message": "Resource could not be found."} + ) response.status_code = 404 return response @api.errorhandler(500) def internal_server_error(_: HTTPException) -> ResponseReturnValue: - response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'}) + response = jsonify( + {"error": "Internal Server Error", "message": "An unexpected error occurred."} + ) response.status_code = 500 return response diff --git a/app/api/onion.py b/app/api/onion.py index c284161..8949b0f 100644 --- a/app/api/onion.py +++ b/app/api/onion.py @@ -7,31 +7,37 @@ from flask import Blueprint, abort, jsonify, request from flask.typing import ResponseReturnValue from sqlalchemy import exc -from app.api.util import (DOMAIN_NAME_REGEX, MAX_ALLOWED_ITEMS, - MAX_DOMAIN_NAME_LENGTH, ListFilter, - get_single_resource, list_resources, - validate_description) +from app.api.util import ( + DOMAIN_NAME_REGEX, + MAX_ALLOWED_ITEMS, + MAX_DOMAIN_NAME_LENGTH, + ListFilter, + get_single_resource, + list_resources, + validate_description, +) from app.extensions import db from app.models.base import Group from app.models.onions import Onion from app.util.onion import decode_onion_keys, onion_hostname from app.util.x509 import validate_tls_keys -api_onion = Blueprint('api_onion', __name__) +api_onion = Blueprint("api_onion", __name__) -@api_onion.route('/onion', methods=['GET']) +@api_onion.route("/onion", methods=["GET"]) def list_onions() -> ResponseReturnValue: - domain_name_filter = request.args.get('DomainName') - group_id_filter = request.args.get('GroupId') + domain_name_filter = request.args.get("DomainName") + group_id_filter = request.args.get("GroupId") - filters: List[ListFilter] = [ - (Onion.destroyed.is_(None)) - ] + filters: List[ListFilter] = [(Onion.destroyed.is_(None))] if domain_name_filter: if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: - abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") + abort( + 400, + description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.", + ) if not DOMAIN_NAME_REGEX.match(domain_name_filter): abort(400, description="DomainName contains invalid characters.") filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%")) @@ -46,9 +52,9 @@ def list_onions() -> ResponseReturnValue: Onion, lambda onion: onion.to_dict(), filters=filters, - resource_name='OnionsList', + resource_name="OnionsList", max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', + protective_marking="amber", ) @@ -71,13 +77,26 @@ def create_onion() -> ResponseReturnValue: abort(400) errors = [] - for field in ["DomainName", "Description", "OnionPrivateKey", "OnionPublicKey", "GroupId", "TlsPrivateKey", - "TlsCertificate"]: + for field in [ + "DomainName", + "Description", + "OnionPrivateKey", + "OnionPublicKey", + "GroupId", + "TlsPrivateKey", + "TlsCertificate", + ]: if not data.get(field): - errors.append({"Error": f"{field}_missing", "Message": f"Missing required field: {field}"}) + errors.append( + { + "Error": f"{field}_missing", + "Message": f"Missing required field: {field}", + } + ) - onion_private_key, onion_public_key, onion_errors = decode_onion_keys(data["OnionPrivateKey"], - data["OnionPublicKey"]) + onion_private_key, onion_public_key, onion_errors = decode_onion_keys( + data["OnionPrivateKey"], data["OnionPublicKey"] + ) if onion_errors: errors.extend(onion_errors) @@ -85,23 +104,35 @@ def create_onion() -> ResponseReturnValue: return jsonify({"Errors": errors}), 400 if onion_private_key: - existing_onion = db.session.query(Onion).where( - Onion.onion_private_key == onion_private_key, - Onion.destroyed.is_(None), - ).first() + existing_onion = ( + db.session.query(Onion) + .where( + Onion.onion_private_key == onion_private_key, + Onion.destroyed.is_(None), + ) + .first() + ) if existing_onion: errors.append( - {"Error": "duplicate_onion_key", "Message": "An onion service with this private key already exists."}) + { + "Error": "duplicate_onion_key", + "Message": "An onion service with this private key already exists.", + } + ) if "GroupId" in data: group = Group.query.get(data["GroupId"]) if not group: - errors.append({"Error": "group_id_not_found", "Message": "Invalid group ID."}) + errors.append( + {"Error": "group_id_not_found", "Message": "Invalid group ID."} + ) chain, san_list, tls_errors = validate_tls_keys( - data["TlsPrivateKey"], data["TlsCertificate"], data.get("SkipChainVerification"), + data["TlsPrivateKey"], + data["TlsCertificate"], + data.get("SkipChainVerification"), data.get("SkipNameVerification"), - f"{onion_hostname(onion_public_key)}.onion" + f"{onion_hostname(onion_public_key)}.onion", ) if tls_errors: @@ -123,15 +154,21 @@ def create_onion() -> ResponseReturnValue: added=datetime.now(timezone.utc), updated=datetime.now(timezone.utc), cert_expiry=cert_expiry_date, - cert_sans=",".join(san_list) + cert_sans=",".join(san_list), ) try: db.session.add(onion) db.session.commit() - return jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), 201 + return ( + jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), + 201, + ) except exc.SQLAlchemyError as e: - return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500 + return ( + jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), + 500, + ) class UpdateOnionRequest(TypedDict): @@ -152,8 +189,19 @@ def update_onion(onion_id: int) -> ResponseReturnValue: onion = Onion.query.get(onion_id) if not onion: - return jsonify( - {"Errors": [{"Error": "onion_not_found", "Message": f"No Onion service found with ID {onion_id}"}]}), 404 + return ( + jsonify( + { + "Errors": [ + { + "Error": "onion_not_found", + "Message": f"No Onion service found with ID {onion_id}", + } + ] + } + ), + 404, + ) if "Description" in data: description = data["Description"] @@ -161,7 +209,12 @@ def update_onion(onion_id: int) -> ResponseReturnValue: if validate_description(description): onion.description = description else: - errors.append({"Error": "description_error", "Message": "Description field is invalid"}) + errors.append( + { + "Error": "description_error", + "Message": "Description field is invalid", + } + ) tls_private_key_pem: Optional[str] = None tls_certificate_pem: Optional[str] = None @@ -176,7 +229,9 @@ def update_onion(onion_id: int) -> ResponseReturnValue: tls_private_key_pem = onion.tls_private_key.decode("utf-8") chain, san_list, tls_errors = validate_tls_keys( - tls_private_key_pem, tls_certificate_pem, data.get("SkipChainVerification", False), + tls_private_key_pem, + tls_certificate_pem, + data.get("SkipChainVerification", False), data.get("SkipNameVerification", False), f"{onion_hostname(onion.onion_public_key)}.onion", ) @@ -200,7 +255,10 @@ def update_onion(onion_id: int) -> ResponseReturnValue: db.session.commit() return jsonify({"Message": "Onion service updated successfully."}), 200 except exc.SQLAlchemyError as e: - return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500 + return ( + jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), + 500, + ) @api_onion.route("/onion/", methods=["GET"]) diff --git a/app/api/util.py b/app/api/util.py index fd43a1b..15067a4 100644 --- a/app/api/util.py +++ b/app/api/util.py @@ -12,7 +12,7 @@ from app.extensions import db logger = logging.getLogger(__name__) MAX_DOMAIN_NAME_LENGTH = 255 -DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$') +DOMAIN_NAME_REGEX = re.compile(r"^[a-zA-Z0-9.\-]*$") MAX_ALLOWED_ITEMS = 100 ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]] @@ -24,7 +24,10 @@ def validate_max_items(max_items_str: str, max_allowed: int) -> int: raise ValueError() return max_items except ValueError: - abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.") + abort( + 400, + description=f"MaxItems must be a positive integer not exceeding {max_allowed}.", + ) def validate_marker(marker_str: str) -> int: @@ -47,21 +50,22 @@ TlpMarkings = Union[ def list_resources( # pylint: disable=too-many-arguments,too-many-locals - model: Type[Any], - serialize_func: Callable[[Any], Dict[str, Any]], - *, - filters: Optional[List[ListFilter]] = None, - order_by: Optional[ColumnElement[Any]] = None, - resource_name: str = 'ResourceList', - max_items_param: str = 'MaxItems', - marker_param: str = 'Marker', - max_allowed_items: int = 100, - protective_marking: TlpMarkings = 'default', + model: Type[Any], + serialize_func: Callable[[Any], Dict[str, Any]], + *, + filters: Optional[List[ListFilter]] = None, + order_by: Optional[ColumnElement[Any]] = None, + resource_name: str = "ResourceList", + max_items_param: str = "MaxItems", + marker_param: str = "Marker", + max_allowed_items: int = 100, + protective_marking: TlpMarkings = "default", ) -> ResponseReturnValue: try: marker = request.args.get(marker_param) max_items = validate_max_items( - request.args.get(max_items_param, default='100'), max_allowed_items) + request.args.get(max_items_param, default="100"), max_allowed_items + ) query = select(model) if filters: @@ -101,14 +105,21 @@ def list_resources( # pylint: disable=too-many-arguments,too-many-locals abort(500) -def get_single_resource(model: Type[Any], id_: int, resource_name: str) -> ResponseReturnValue: +def get_single_resource( + model: Type[Any], id_: int, resource_name: str +) -> ResponseReturnValue: try: resource = db.session.get(model, id_) if not resource: - return jsonify({ - "Error": "resource_not_found", - "Message": f"No {resource_name} found with ID {id_}" - }), 404 + return ( + jsonify( + { + "Error": "resource_not_found", + "Message": f"No {resource_name} found with ID {id_}", + } + ), + 404, + ) return jsonify({resource_name: resource.to_dict()}), 200 except Exception: # pylint: disable=broad-exception-caught logger.exception("An unexpected error occurred while retrieving the onion") diff --git a/app/api/web.py b/app/api/web.py index 09ecfaa..09c7191 100644 --- a/app/api/web.py +++ b/app/api/web.py @@ -4,35 +4,43 @@ from typing import List from flask import Blueprint, abort, request from flask.typing import ResponseReturnValue -from app.api.util import (DOMAIN_NAME_REGEX, MAX_ALLOWED_ITEMS, - MAX_DOMAIN_NAME_LENGTH, ListFilter, list_resources) +from app.api.util import ( + DOMAIN_NAME_REGEX, + MAX_ALLOWED_ITEMS, + MAX_DOMAIN_NAME_LENGTH, + ListFilter, + list_resources, +) from app.models.base import Group from app.models.mirrors import Origin, Proxy -api_web = Blueprint('web', __name__) +api_web = Blueprint("web", __name__) -@api_web.route('/group', methods=['GET']) +@api_web.route("/group", methods=["GET"]) def list_groups() -> ResponseReturnValue: return list_resources( Group, lambda group: group.to_dict(), - resource_name='OriginGroupList', + resource_name="OriginGroupList", max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', + protective_marking="amber", ) -@api_web.route('/origin', methods=['GET']) +@api_web.route("/origin", methods=["GET"]) def list_origins() -> ResponseReturnValue: - domain_name_filter = request.args.get('DomainName') - group_id_filter = request.args.get('GroupId') + domain_name_filter = request.args.get("DomainName") + group_id_filter = request.args.get("GroupId") filters: List[ListFilter] = [] if domain_name_filter: if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: - abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") + abort( + 400, + description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.", + ) if not DOMAIN_NAME_REGEX.match(domain_name_filter): abort(400, description="DomainName contains invalid characters.") filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%")) @@ -47,18 +55,18 @@ def list_origins() -> ResponseReturnValue: Origin, lambda origin: origin.to_dict(), filters=filters, - resource_name='OriginsList', + resource_name="OriginsList", max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', + protective_marking="amber", ) -@api_web.route('/mirror', methods=['GET']) +@api_web.route("/mirror", methods=["GET"]) def list_mirrors() -> ResponseReturnValue: filters = [] twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24) - status_filter = request.args.get('Status') + status_filter = request.args.get("Status") if status_filter: if status_filter == "pending": filters.append(Proxy.url.is_(None)) @@ -74,13 +82,15 @@ def list_mirrors() -> ResponseReturnValue: if status_filter == "destroyed": filters.append(Proxy.destroyed > twenty_four_hours_ago) else: - filters.append((Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago)) + filters.append( + (Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago) + ) return list_resources( Proxy, lambda proxy: proxy.to_dict(), filters=filters, - resource_name='MirrorsList', + resource_name="MirrorsList", max_allowed_items=MAX_ALLOWED_ITEMS, - protective_marking='amber', + protective_marking="amber", ) diff --git a/app/brm/brn.py b/app/brm/brn.py index f9a167f..6b4604a 100644 --- a/app/brm/brn.py +++ b/app/brm/brn.py @@ -29,31 +29,37 @@ class BRN: def from_str(cls, string: str) -> BRN: parts = string.split(":") if len(parts) != 6 or parts[0].lower() != "brn" or not is_integer(parts[2]): - raise TypeError(f"Expected a valid BRN but got {repr(string)} (invalid parts).") + raise TypeError( + f"Expected a valid BRN but got {repr(string)} (invalid parts)." + ) resource_parts = parts[5].split("/") if len(resource_parts) != 2: - raise TypeError(f"Expected a valid BRN but got {repr(string)} (invalid resource parts).") + raise TypeError( + f"Expected a valid BRN but got {repr(string)} (invalid resource parts)." + ) return cls( global_namespace=parts[1], group_id=int(parts[2]), product=parts[3], provider=parts[4], resource_type=resource_parts[0], - resource_id=resource_parts[1] + resource_id=resource_parts[1], ) def __eq__(self, other: Any) -> bool: return str(self) == str(other) def __str__(self) -> str: - return ":".join([ - "brn", - self.global_namespace, - str(self.group_id), - self.product, - self.provider, - f"{self.resource_type}/{self.resource_id}" - ]) + return ":".join( + [ + "brn", + self.global_namespace, + str(self.group_id), + self.product, + self.provider, + f"{self.resource_type}/{self.resource_id}", + ] + ) def __repr__(self) -> str: return f"" diff --git a/app/brm/static.py b/app/brm/static.py index c622ec6..3d333fe 100644 --- a/app/brm/static.py +++ b/app/brm/static.py @@ -9,18 +9,18 @@ from app.models.mirrors import StaticOrigin def create_static_origin( - description: str, - group_id: int, - storage_cloud_account_id: int, - source_cloud_account_id: int, - source_project: str, - auto_rotate: bool, - matrix_homeserver: Optional[str], - keanu_convene_path: Optional[str], - keanu_convene_logo: Optional[FileStorage], - keanu_convene_color: Optional[str], - clean_insights_backend: Optional[Union[str, bool]], - db_session_commit: bool = False, + description: str, + group_id: int, + storage_cloud_account_id: int, + source_cloud_account_id: int, + source_project: str, + auto_rotate: bool, + matrix_homeserver: Optional[str], + keanu_convene_path: Optional[str], + keanu_convene_logo: Optional[FileStorage], + keanu_convene_color: Optional[str], + clean_insights_backend: Optional[Union[str, bool]], + db_session_commit: bool = False, ) -> StaticOrigin: """ Create a new static origin. @@ -47,14 +47,18 @@ def create_static_origin( else: raise ValueError("group_id must be an int") if isinstance(storage_cloud_account_id, int): - cloud_account = CloudAccount.query.filter(CloudAccount.id == storage_cloud_account_id).first() + cloud_account = CloudAccount.query.filter( + CloudAccount.id == storage_cloud_account_id + ).first() if cloud_account is None: raise ValueError("storage_cloud_account_id must match an existing provider") static_origin.storage_cloud_account_id = storage_cloud_account_id else: raise ValueError("storage_cloud_account_id must be an int") if isinstance(source_cloud_account_id, int): - cloud_account = CloudAccount.query.filter(CloudAccount.id == source_cloud_account_id).first() + cloud_account = CloudAccount.query.filter( + CloudAccount.id == source_cloud_account_id + ).first() if cloud_account is None: raise ValueError("source_cloud_account_id must match an existing provider") static_origin.source_cloud_account_id = source_cloud_account_id @@ -69,7 +73,7 @@ def create_static_origin( keanu_convene_logo, keanu_convene_color, clean_insights_backend, - False + False, ) if db_session_commit: db.session.add(static_origin) diff --git a/app/brm/utils.py b/app/brm/utils.py index c87e3a1..037e624 100644 --- a/app/brm/utils.py +++ b/app/brm/utils.py @@ -26,7 +26,9 @@ def is_integer(contender: Any) -> bool: return float(contender).is_integer() -def thumbnail_uploaded_image(file: FileStorage, max_size: Tuple[int, int] = (256, 256)) -> bytes: +def thumbnail_uploaded_image( + file: FileStorage, max_size: Tuple[int, int] = (256, 256) +) -> bytes: """ Process an uploaded image file into a resized image of a specific size. @@ -39,7 +41,9 @@ def thumbnail_uploaded_image(file: FileStorage, max_size: Tuple[int, int] = (256 img = Image.open(file) img.thumbnail(max_size) byte_arr = BytesIO() - img.save(byte_arr, format='PNG' if file.filename.lower().endswith('.png') else 'JPEG') + img.save( + byte_arr, format="PNG" if file.filename.lower().endswith(".png") else "JPEG" + ) return byte_arr.getvalue() @@ -52,9 +56,11 @@ def create_data_uri(bytes_data: bytes, file_extension: str) -> str: :return: A data URI representing the image. """ # base64 encode - encoded = base64.b64encode(bytes_data).decode('ascii') + encoded = base64.b64encode(bytes_data).decode("ascii") # create data URI - data_uri = "data:image/{};base64,{}".format('jpeg' if file_extension == 'jpg' else file_extension, encoded) + data_uri = "data:image/{};base64,{}".format( + "jpeg" if file_extension == "jpg" else file_extension, encoded + ) return data_uri @@ -80,7 +86,7 @@ def normalize_color(color: str) -> str: return webcolors.name_to_hex(color) # type: ignore[no-any-return] except ValueError: pass - if color.startswith('#'): + if color.startswith("#"): color = color[1:].lower() if len(color) in [3, 6]: try: diff --git a/app/cli/__init__.py b/app/cli/__init__.py index 8db1e74..7119ea7 100644 --- a/app/cli/__init__.py +++ b/app/cli/__init__.py @@ -3,7 +3,9 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - _SubparserType = argparse._SubParsersAction[argparse.ArgumentParser] # pylint: disable=protected-access + _SubparserType = argparse._SubParsersAction[ + argparse.ArgumentParser + ] # pylint: disable=protected-access else: _SubparserType = Any diff --git a/app/cli/__main__.py b/app/cli/__main__.py index ff38806..9a1f733 100644 --- a/app/cli/__main__.py +++ b/app/cli/__main__.py @@ -13,7 +13,9 @@ def parse_args(argv: List[str]) -> None: if basename(argv[0]) == "__main__.py": argv[0] = "bypass" parser = argparse.ArgumentParser() - parser.add_argument("-v", "--verbose", help="increase logging verbosity", action="store_true") + parser.add_argument( + "-v", "--verbose", help="increase logging verbosity", action="store_true" + ) subparsers = parser.add_subparsers(title="command", help="command to run") AutomateCliHandler.add_subparser_to(subparsers) DbCliHandler.add_subparser_to(subparsers) @@ -28,7 +30,6 @@ def parse_args(argv: List[str]) -> None: if __name__ == "__main__": VERBOSE = "-v" in sys.argv or "--verbose" in sys.argv - logging.basicConfig( - level=logging.DEBUG if VERBOSE else logging.INFO) + logging.basicConfig(level=logging.DEBUG if VERBOSE else logging.INFO) logging.debug("Arguments: %s", sys.argv) parse_args(sys.argv) diff --git a/app/cli/automate.py b/app/cli/automate.py index a4a198e..a0146e6 100644 --- a/app/cli/automate.py +++ b/app/cli/automate.py @@ -14,18 +14,14 @@ from app.models.automation import Automation, AutomationLogs, AutomationState from app.terraform import BaseAutomation from app.terraform.alarms.eotk_aws import AlarmEotkAwsAutomation from app.terraform.alarms.proxy_azure_cdn import AlarmProxyAzureCdnAutomation -from app.terraform.alarms.proxy_cloudfront import \ - AlarmProxyCloudfrontAutomation -from app.terraform.alarms.proxy_http_status import \ - AlarmProxyHTTPStatusAutomation +from app.terraform.alarms.proxy_cloudfront import AlarmProxyCloudfrontAutomation +from app.terraform.alarms.proxy_http_status import AlarmProxyHTTPStatusAutomation from app.terraform.alarms.smart_aws import AlarmSmartAwsAutomation from app.terraform.block.block_blocky import BlockBlockyAutomation -from app.terraform.block.block_scriptzteam import \ - BlockBridgeScriptzteamAutomation +from app.terraform.block.block_scriptzteam import BlockBridgeScriptzteamAutomation from app.terraform.block.bridge_github import BlockBridgeGitHubAutomation from app.terraform.block.bridge_gitlab import BlockBridgeGitlabAutomation -from app.terraform.block.bridge_roskomsvoboda import \ - BlockBridgeRoskomsvobodaAutomation +from app.terraform.block.bridge_roskomsvoboda import BlockBridgeRoskomsvobodaAutomation from app.terraform.block_external import BlockExternalAutomation from app.terraform.block_ooni import BlockOONIAutomation from app.terraform.block_roskomsvoboda import BlockRoskomsvobodaAutomation @@ -58,12 +54,10 @@ jobs = { BlockExternalAutomation, BlockOONIAutomation, BlockRoskomsvobodaAutomation, - # Create new resources BridgeMetaAutomation, StaticMetaAutomation, ProxyMetaAutomation, - # Terraform BridgeAWSAutomation, BridgeGandiAutomation, @@ -74,14 +68,12 @@ jobs = { ProxyAzureCdnAutomation, ProxyCloudfrontAutomation, ProxyFastlyAutomation, - # Import alarms AlarmEotkAwsAutomation, AlarmProxyAzureCdnAutomation, AlarmProxyCloudfrontAutomation, AlarmProxyHTTPStatusAutomation, AlarmSmartAwsAutomation, - # Update lists ListGithubAutomation, ListGitlabAutomation, @@ -103,9 +95,12 @@ def run_all(**kwargs: bool) -> None: run_job(job, **kwargs) -def run_job(job_cls: Type[BaseAutomation], *, - force: bool = False, ignore_schedule: bool = False) -> None: - automation = Automation.query.filter(Automation.short_name == job_cls.short_name).first() +def run_job( + job_cls: Type[BaseAutomation], *, force: bool = False, ignore_schedule: bool = False +) -> None: + automation = Automation.query.filter( + Automation.short_name == job_cls.short_name + ).first() if automation is None: automation = Automation() automation.short_name = job_cls.short_name @@ -121,18 +116,24 @@ def run_job(job_cls: Type[BaseAutomation], *, logging.warning("Not running an already running automation") return if not ignore_schedule and not force: - if automation.next_run is not None and automation.next_run > datetime.now(tz=timezone.utc): + if automation.next_run is not None and automation.next_run > datetime.now( + tz=timezone.utc + ): logging.warning("Not time to run this job yet") return if not automation.enabled and not force: - logging.warning("job %s is disabled and --force not specified", job_cls.short_name) + logging.warning( + "job %s is disabled and --force not specified", job_cls.short_name + ) return automation.state = AutomationState.RUNNING db.session.commit() try: - if 'TERRAFORM_DIRECTORY' in app.config: - working_dir = os.path.join(app.config['TERRAFORM_DIRECTORY'], - job_cls.short_name or job_cls.__class__.__name__.lower()) + if "TERRAFORM_DIRECTORY" in app.config: + working_dir = os.path.join( + app.config["TERRAFORM_DIRECTORY"], + job_cls.short_name or job_cls.__class__.__name__.lower(), + ) else: working_dir = tempfile.mkdtemp() job: BaseAutomation = job_cls(working_dir) @@ -150,8 +151,9 @@ def run_job(job_cls: Type[BaseAutomation], *, if job is not None and success: automation.state = AutomationState.IDLE automation.next_run = datetime.now(tz=timezone.utc) + timedelta( - minutes=getattr(job, "frequency", 7)) - if 'TERRAFORM_DIRECTORY' not in app.config and working_dir is not None: + minutes=getattr(job, "frequency", 7) + ) + if "TERRAFORM_DIRECTORY" not in app.config and working_dir is not None: # We used a temporary working directory shutil.rmtree(working_dir) else: @@ -165,7 +167,7 @@ def run_job(job_cls: Type[BaseAutomation], *, "list_gitlab", "block_blocky", "block_external", - "block_ooni" + "block_ooni", ] if job.short_name not in safe_jobs: automation.enabled = False @@ -179,10 +181,12 @@ def run_job(job_cls: Type[BaseAutomation], *, db.session.commit() activity = Activity( activity_type="automation", - text=(f"[{automation.short_name}] 🚨 Automation failure: It was not possible to handle this failure safely " - "and so the automation task has been automatically disabled. It may be possible to simply re-enable " - "the task, but repeated failures will usually require deeper investigation. See logs for full " - "details.") + text=( + f"[{automation.short_name}] 🚨 Automation failure: It was not possible to handle this failure safely " + "and so the automation task has been automatically disabled. It may be possible to simply re-enable " + "the task, but repeated failures will usually require deeper investigation. See logs for full " + "details." + ), ) db.session.add(activity) activity.notify() # Notify before commit because the failure occurred even if we can't commit. @@ -194,20 +198,43 @@ class AutomateCliHandler(BaseCliHandler): @classmethod def add_subparser_to(cls, subparsers: _SubparserType) -> None: parser = subparsers.add_parser("automate", help="automation operations") - parser.add_argument("-a", "--all", dest="all", help="run all automation jobs", action="store_true") - parser.add_argument("-j", "--job", dest="job", choices=sorted(jobs.keys()), - help="run a specific automation job") - parser.add_argument("--force", help="run job even if disabled and it's not time yet", action="store_true") - parser.add_argument("--ignore-schedule", help="run job even if it's not time yet", action="store_true") + parser.add_argument( + "-a", + "--all", + dest="all", + help="run all automation jobs", + action="store_true", + ) + parser.add_argument( + "-j", + "--job", + dest="job", + choices=sorted(jobs.keys()), + help="run a specific automation job", + ) + parser.add_argument( + "--force", + help="run job even if disabled and it's not time yet", + action="store_true", + ) + parser.add_argument( + "--ignore-schedule", + help="run job even if it's not time yet", + action="store_true", + ) parser.set_defaults(cls=cls) def run(self) -> None: with app.app_context(): if self.args.job: - run_job(jobs[self.args.job], - force=self.args.force, - ignore_schedule=self.args.ignore_schedule) + run_job( + jobs[self.args.job], + force=self.args.force, + ignore_schedule=self.args.ignore_schedule, + ) elif self.args.all: - run_all(force=self.args.force, ignore_schedule=self.args.ignore_schedule) + run_all( + force=self.args.force, ignore_schedule=self.args.ignore_schedule + ) else: logging.error("No action requested") diff --git a/app/cli/db.py b/app/cli/db.py index c4f518f..480ebf0 100644 --- a/app/cli/db.py +++ b/app/cli/db.py @@ -40,7 +40,7 @@ models: List[Model] = [ Eotk, MirrorList, TerraformState, - Webhook + Webhook, ] @@ -53,7 +53,7 @@ class ExportEncoder(json.JSONEncoder): if isinstance(o, AutomationState): return o.name if isinstance(o, bytes): - return base64.encodebytes(o).decode('utf-8') + return base64.encodebytes(o).decode("utf-8") if isinstance(o, (datetime.datetime, datetime.date, datetime.time)): return o.isoformat() return super().default(o) @@ -82,7 +82,7 @@ def db_export() -> None: decoder: Dict[str, Callable[[Any], Any]] = { "AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x), "AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x), - "bytes": lambda x: base64.decodebytes(x.encode('utf-8')), + "bytes": lambda x: base64.decodebytes(x.encode("utf-8")), "datetime": datetime.datetime.fromisoformat, "int": int, "str": lambda x: x, @@ -110,8 +110,12 @@ class DbCliHandler(BaseCliHandler): @classmethod def add_subparser_to(cls, subparsers: _SubparserType) -> None: parser = subparsers.add_parser("db", help="database operations") - parser.add_argument("--export", help="export data to JSON format", action="store_true") - parser.add_argument("--import", help="import data from JSON format", action="store_true") + parser.add_argument( + "--export", help="export data to JSON format", action="store_true" + ) + parser.add_argument( + "--import", help="import data from JSON format", action="store_true" + ) parser.set_defaults(cls=cls) def run(self) -> None: diff --git a/app/cli/list.py b/app/cli/list.py index fad43dc..11e8b73 100644 --- a/app/cli/list.py +++ b/app/cli/list.py @@ -17,8 +17,9 @@ class ListCliHandler(BaseCliHandler): @classmethod def add_subparser_to(cls, subparsers: _SubparserType) -> None: parser = subparsers.add_parser("list", help="list operations") - parser.add_argument("--dump", choices=sorted(lists.keys()), - help="dump a list in JSON format") + parser.add_argument( + "--dump", choices=sorted(lists.keys()), help="dump a list in JSON format" + ) parser.set_defaults(cls=cls) def run(self) -> None: diff --git a/app/extensions.py b/app/extensions.py index 257b535..f2d8865 100644 --- a/app/extensions.py +++ b/app/extensions.py @@ -4,11 +4,11 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy import MetaData convention = { - "ix": 'ix_%(column_0_label)s', + "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s" + "pk": "pk_%(table_name)s", } metadata = MetaData(naming_convention=convention) diff --git a/app/lists/bc2.py b/app/lists/bc2.py index c82580e..8942433 100644 --- a/app/lists/bc2.py +++ b/app/lists/bc2.py @@ -26,13 +26,15 @@ def onion_alternative(origin: Origin) -> List[BC2Alternative]: url: Optional[str] = origin.onion() if url is None: return [] - return [{ - "proto": "tor", - "type": "eotk", - "created_at": str(origin.added), - "updated_at": str(origin.updated), - "url": url - }] + return [ + { + "proto": "tor", + "type": "eotk", + "created_at": str(origin.added), + "updated_at": str(origin.updated), + "url": url, + } + ] def proxy_alternative(proxy: Proxy) -> Optional[BC2Alternative]: @@ -43,43 +45,51 @@ def proxy_alternative(proxy: Proxy) -> Optional[BC2Alternative]: "type": "mirror", "created_at": proxy.added.isoformat(), "updated_at": proxy.updated.isoformat(), - "url": proxy.url + "url": proxy.url, } def main_domain(origin: Origin) -> str: description: str = origin.description if description.startswith("proxy:"): - return description[len("proxy:"):].replace("www.", "") + return description[len("proxy:") :].replace("www.", "") domain_name: str = origin.domain_name return domain_name.replace("www.", "") def active_proxies(origin: Origin, pool: Pool) -> List[Proxy]: return [ - proxy for proxy in origin.proxies - if proxy.url is not None and not proxy.deprecated and not proxy.destroyed and proxy.pool_id == pool.id + proxy + for proxy in origin.proxies + if proxy.url is not None + and not proxy.deprecated + and not proxy.destroyed + and proxy.pool_id == pool.id ] def mirror_sites(pool: Pool) -> BypassCensorship2: - origins = Origin.query.filter(Origin.destroyed.is_(None)).order_by(Origin.domain_name).all() + origins = ( + Origin.query.filter(Origin.destroyed.is_(None)) + .order_by(Origin.domain_name) + .all() + ) sites: List[BC2Site] = [] for origin in origins: # Gather alternatives, filtering out None values from proxy_alternative alternatives = onion_alternative(origin) + [ - alt for proxy in active_proxies(origin, pool) + alt + for proxy in active_proxies(origin, pool) if (alt := proxy_alternative(proxy)) is not None ] # Add the site dictionary to the list - sites.append({ - "main_domain": main_domain(origin), - "available_alternatives": list(alternatives) - }) + sites.append( + { + "main_domain": main_domain(origin), + "available_alternatives": list(alternatives), + } + ) - return { - "version": "2.0", - "sites": sites - } + return {"version": "2.0", "sites": sites} diff --git a/app/lists/bridgelines.py b/app/lists/bridgelines.py index 98de134..baadc7f 100644 --- a/app/lists/bridgelines.py +++ b/app/lists/bridgelines.py @@ -11,12 +11,14 @@ class BridgelinesDict(TypedDict): bridgelines: List[str] -def bridgelines(pool: Pool, *, distribution_method: Optional[str] = None) -> BridgelinesDict: +def bridgelines( + pool: Pool, *, distribution_method: Optional[str] = None +) -> BridgelinesDict: # Fetch bridges with selectinload for related data query = Bridge.query.options(selectinload(Bridge.conf)).filter( Bridge.destroyed.is_(None), Bridge.deprecated.is_(None), - Bridge.bridgeline.is_not(None) + Bridge.bridgeline.is_not(None), ) if distribution_method is not None: @@ -26,7 +28,4 @@ def bridgelines(pool: Pool, *, distribution_method: Optional[str] = None) -> Bri bridgelines = [b.bridgeline for b in query.all() if b.conf.pool_id == pool.id] # Return dictionary directly, inlining the previous `to_dict` functionality - return { - "version": "1.0", - "bridgelines": bridgelines - } + return {"version": "1.0", "bridgelines": bridgelines} diff --git a/app/lists/mirror_mapping.py b/app/lists/mirror_mapping.py index f191d8a..faf0f01 100644 --- a/app/lists/mirror_mapping.py +++ b/app/lists/mirror_mapping.py @@ -48,7 +48,9 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping: countries = proxy.origin.risk_level if countries: - highest_risk_country_code, highest_risk_level = max(countries.items(), key=lambda x: x[1]) + highest_risk_country_code, highest_risk_level = max( + countries.items(), key=lambda x: x[1] + ) else: highest_risk_country_code = "ZZ" highest_risk_level = 0 @@ -61,7 +63,7 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping: "valid_to": proxy.destroyed.isoformat() if proxy.destroyed else None, "countries": countries, "country": highest_risk_country_code, - "risk": highest_risk_level + "risk": highest_risk_level, } groups = db.session.query(Group).options(selectinload(Group.pools)) @@ -70,8 +72,4 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping: for g in groups.filter(Group.destroyed.is_(None)).all() ] - return { - "version": "1.2", - "mappings": result, - "s3_buckets": s3_buckets - } + return {"version": "1.2", "mappings": result, "s3_buckets": s3_buckets} diff --git a/app/lists/redirector.py b/app/lists/redirector.py index 263ec3e..93fa949 100644 --- a/app/lists/redirector.py +++ b/app/lists/redirector.py @@ -26,15 +26,17 @@ def redirector_pool_origins(pool: Pool) -> Dict[str, str]: Proxy.deprecated.is_(None), Proxy.destroyed.is_(None), Proxy.url.is_not(None), - Proxy.pool_id == pool.id + Proxy.pool_id == pool.id, ) } def redirector_data(_: Optional[Pool]) -> RedirectorData: - active_pools = Pool.query.options( - selectinload(Pool.proxies) - ).filter(Pool.destroyed.is_(None)).all() + active_pools = ( + Pool.query.options(selectinload(Pool.proxies)) + .filter(Pool.destroyed.is_(None)) + .all() + ) pools: List[RedirectorPool] = [ { @@ -42,12 +44,9 @@ def redirector_data(_: Optional[Pool]) -> RedirectorData: "description": pool.description, "api_key": pool.api_key, "redirector_domain": pool.redirector_domain, - "origins": redirector_pool_origins(pool) + "origins": redirector_pool_origins(pool), } for pool in active_pools ] - return { - "version": "1.0", - "pools": pools - } + return {"version": "1.0", "pools": pools} diff --git a/app/models/__init__.py b/app/models/__init__.py index 73207ad..fabbe13 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -17,7 +17,9 @@ class AbstractConfiguration(db.Model): # type: ignore description: Mapped[str] added: Mapped[datetime] = mapped_column(AwareDateTime()) updated: Mapped[datetime] = mapped_column(AwareDateTime()) - destroyed: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) + destroyed: Mapped[Optional[datetime]] = mapped_column( + AwareDateTime(), nullable=True + ) @property @abstractmethod @@ -30,14 +32,10 @@ class AbstractConfiguration(db.Model): # type: ignore @classmethod def csv_header(cls) -> List[str]: - return [ - "id", "description", "added", "updated", "destroyed" - ] + return ["id", "description", "added", "updated", "destroyed"] def csv_row(self) -> List[Any]: - return [ - getattr(self, x) for x in self.csv_header() - ] + return [getattr(self, x) for x in self.csv_header()] class Deprecation(db.Model): # type: ignore[name-defined,misc] @@ -51,7 +49,8 @@ class Deprecation(db.Model): # type: ignore[name-defined,misc] @property def resource(self) -> "AbstractResource": from app.models.mirrors import Proxy # pylint: disable=R0401 - model = {'Proxy': Proxy}[self.resource_type] + + model = {"Proxy": Proxy}[self.resource_type] return model.query.get(self.resource_id) # type: ignore[no-any-return] @@ -61,29 +60,38 @@ class AbstractResource(db.Model): # type: ignore id: Mapped[int] = mapped_column(db.Integer, primary_key=True) added: Mapped[datetime] = mapped_column(AwareDateTime()) updated: Mapped[datetime] = mapped_column(AwareDateTime()) - deprecated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) + deprecated: Mapped[Optional[datetime]] = mapped_column( + AwareDateTime(), nullable=True + ) deprecation_reason: Mapped[Optional[str]] - destroyed: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) + destroyed: Mapped[Optional[datetime]] = mapped_column( + AwareDateTime(), nullable=True + ) - def __init__(self, *, - id: Optional[int] = None, - added: Optional[datetime] = None, - updated: Optional[datetime] = None, - deprecated: Optional[datetime] = None, - deprecation_reason: Optional[str] = None, - destroyed: Optional[datetime] = None, - **kwargs: Any) -> None: + def __init__( + self, + *, + id: Optional[int] = None, + added: Optional[datetime] = None, + updated: Optional[datetime] = None, + deprecated: Optional[datetime] = None, + deprecation_reason: Optional[str] = None, + destroyed: Optional[datetime] = None, + **kwargs: Any + ) -> None: if added is None: added = datetime.now(tz=timezone.utc) if updated is None: updated = datetime.now(tz=timezone.utc) - super().__init__(id=id, - added=added, - updated=updated, - deprecated=deprecated, - deprecation_reason=deprecation_reason, - destroyed=destroyed, - **kwargs) + super().__init__( + id=id, + added=added, + updated=updated, + deprecated=deprecated, + deprecation_reason=deprecation_reason, + destroyed=destroyed, + **kwargs + ) @property @abstractmethod @@ -110,19 +118,21 @@ class AbstractResource(db.Model): # type: ignore resource_type=type(self).__name__, resource_id=self.id, reason=reason, - meta=meta + meta=meta, ) db.session.add(new_deprecation) return True - logging.info("Not deprecating %s (reason=%s) because it's already deprecated with that reason.", - self.brn, reason) + logging.info( + "Not deprecating %s (reason=%s) because it's already deprecated with that reason.", + self.brn, + reason, + ) return False @property def deprecations(self) -> List[Deprecation]: return Deprecation.query.filter_by( # type: ignore[no-any-return] - resource_type='Proxy', - resource_id=self.id + resource_type="Proxy", resource_id=self.id ).all() def destroy(self) -> None: @@ -139,10 +149,13 @@ class AbstractResource(db.Model): # type: ignore @classmethod def csv_header(cls) -> List[str]: return [ - "id", "added", "updated", "deprecated", "deprecation_reason", "destroyed" + "id", + "added", + "updated", + "deprecated", + "deprecation_reason", + "destroyed", ] def csv_row(self) -> List[Union[datetime, bool, int, str]]: - return [ - getattr(self, x) for x in self.csv_header() - ] + return [getattr(self, x) for x in self.csv_header()] diff --git a/app/models/activity.py b/app/models/activity.py index 0f3cb1f..71c4a5e 100644 --- a/app/models/activity.py +++ b/app/models/activity.py @@ -17,31 +17,40 @@ class Activity(db.Model): # type: ignore text: Mapped[str] added: Mapped[datetime] = mapped_column(AwareDateTime()) - def __init__(self, *, - id: Optional[int] = None, - group_id: Optional[int] = None, - activity_type: str, - text: str, - added: Optional[datetime] = None, - **kwargs: Any) -> None: - if not isinstance(activity_type, str) or len(activity_type) > 20 or activity_type == "": - raise TypeError("expected string for activity type between 1 and 20 characters") + def __init__( + self, + *, + id: Optional[int] = None, + group_id: Optional[int] = None, + activity_type: str, + text: str, + added: Optional[datetime] = None, + **kwargs: Any + ) -> None: + if ( + not isinstance(activity_type, str) + or len(activity_type) > 20 + or activity_type == "" + ): + raise TypeError( + "expected string for activity type between 1 and 20 characters" + ) if not isinstance(text, str): raise TypeError("expected string for text") if added is None: added = datetime.now(tz=timezone.utc) - super().__init__(id=id, - group_id=group_id, - activity_type=activity_type, - text=text, - added=added, - **kwargs) + super().__init__( + id=id, + group_id=group_id, + activity_type=activity_type, + text=text, + added=added, + **kwargs + ) def notify(self) -> int: count = 0 - hooks = Webhook.query.filter( - Webhook.destroyed.is_(None) - ) + hooks = Webhook.query.filter(Webhook.destroyed.is_(None)) for hook in hooks: hook.send(self.text) count += 1 @@ -59,7 +68,7 @@ class Webhook(AbstractConfiguration): product="notify", provider=self.format, resource_type="conf", - resource_id=str(self.id) + resource_id=str(self.id), ) def send(self, text: str) -> None: diff --git a/app/models/alarms.py b/app/models/alarms.py index 2f89f40..d483d41 100644 --- a/app/models/alarms.py +++ b/app/models/alarms.py @@ -37,7 +37,15 @@ class Alarm(db.Model): # type: ignore @classmethod def csv_header(cls) -> List[str]: - return ["id", "target", "alarm_type", "alarm_state", "state_changed", "last_updated", "text"] + return [ + "id", + "target", + "alarm_type", + "alarm_state", + "state_changed", + "last_updated", + "text", + ] def csv_row(self) -> List[Any]: return [getattr(self, x) for x in self.csv_header()] @@ -45,11 +53,15 @@ class Alarm(db.Model): # type: ignore def update_state(self, state: AlarmState, text: str) -> None: if self.alarm_state != state or self.state_changed is None: self.state_changed = datetime.now(tz=timezone.utc) - activity = Activity(activity_type="alarm_state", - text=f"[{self.aspect}] {state.emoji} Alarm state changed from " - f"{self.alarm_state.name} to {state.name} on {self.target}: {text}.") - if (self.alarm_state.name in ["WARNING", "CRITICAL"] - or state.name in ["WARNING", "CRITICAL"]): + activity = Activity( + activity_type="alarm_state", + text=f"[{self.aspect}] {state.emoji} Alarm state changed from " + f"{self.alarm_state.name} to {state.name} on {self.target}: {text}.", + ) + if self.alarm_state.name in ["WARNING", "CRITICAL"] or state.name in [ + "WARNING", + "CRITICAL", + ]: # Notifications are only sent on recovery from warning/critical state or on entry # to warning/critical states. This should reduce alert fatigue. activity.notify() diff --git a/app/models/automation.py b/app/models/automation.py index c050118..d33745c 100644 --- a/app/models/automation.py +++ b/app/models/automation.py @@ -33,7 +33,7 @@ class Automation(AbstractConfiguration): product="core", provider="", resource_type="automation", - resource_id=self.short_name + resource_id=self.short_name, ) def kick(self) -> None: @@ -55,5 +55,5 @@ class AutomationLogs(AbstractResource): product="core", provider="", resource_type="automationlog", - resource_id=str(self.id) + resource_id=str(self.id), ) diff --git a/app/models/base.py b/app/models/base.py index 9ecb35f..5f492d0 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -26,17 +26,21 @@ class Group(AbstractConfiguration): eotk: Mapped[bool] origins: Mapped[List["Origin"]] = relationship("Origin", back_populates="group") - statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="group") + statics: Mapped[List["StaticOrigin"]] = relationship( + "StaticOrigin", back_populates="group" + ) eotks: Mapped[List["Eotk"]] = relationship("Eotk", back_populates="group") onions: Mapped[List["Onion"]] = relationship("Onion", back_populates="group") - smart_proxies: Mapped[List["SmartProxy"]] = relationship("SmartProxy", back_populates="group") - pools: Mapped[List["Pool"]] = relationship("Pool", secondary="pool_group", back_populates="groups") + smart_proxies: Mapped[List["SmartProxy"]] = relationship( + "SmartProxy", back_populates="group" + ) + pools: Mapped[List["Pool"]] = relationship( + "Pool", secondary="pool_group", back_populates="groups" + ) @classmethod def csv_header(cls) -> List[str]: - return super().csv_header() + [ - "group_name", "eotk" - ] + return super().csv_header() + ["group_name", "eotk"] @property def brn(self) -> BRN: @@ -45,16 +49,15 @@ class Group(AbstractConfiguration): product="group", provider="", resource_type="group", - resource_id=str(self.id) + resource_id=str(self.id), ) def to_dict(self) -> GroupDict: if not TYPE_CHECKING: from app.models.mirrors import Origin # to prevent circular import - active_origins_query = ( - db.session.query(aliased(Origin)) - .filter(and_(Origin.group_id == self.id, Origin.destroyed.is_(None))) + active_origins_query = db.session.query(aliased(Origin)).filter( + and_(Origin.group_id == self.id, Origin.destroyed.is_(None)) ) active_origins_count = active_origins_query.count() return { @@ -70,16 +73,20 @@ class Pool(AbstractConfiguration): api_key: Mapped[str] redirector_domain: Mapped[Optional[str]] - bridgeconfs: Mapped[List["BridgeConf"]] = relationship("BridgeConf", back_populates="pool") + bridgeconfs: Mapped[List["BridgeConf"]] = relationship( + "BridgeConf", back_populates="pool" + ) proxies: Mapped[List["Proxy"]] = relationship("Proxy", back_populates="pool") - lists: Mapped[List["MirrorList"]] = relationship("MirrorList", back_populates="pool") - groups: Mapped[List[Group]] = relationship("Group", secondary="pool_group", back_populates="pools") + lists: Mapped[List["MirrorList"]] = relationship( + "MirrorList", back_populates="pool" + ) + groups: Mapped[List[Group]] = relationship( + "Group", secondary="pool_group", back_populates="pools" + ) @classmethod def csv_header(cls) -> List[str]: - return super().csv_header() + [ - "pool_name" - ] + return super().csv_header() + ["pool_name"] @property def brn(self) -> BRN: @@ -88,7 +95,7 @@ class Pool(AbstractConfiguration): product="pool", provider="", resource_type="pool", - resource_id=str(self.pool_name) + resource_id=str(self.pool_name), ) @@ -121,14 +128,14 @@ class MirrorList(AbstractConfiguration): "bc3": "Bypass Censorship v3", "bca": "Bypass Censorship Analytics", "bridgelines": "Tor Bridge Lines", - "rdr": "Redirector Data" + "rdr": "Redirector Data", } encodings_supported = { "json": "JSON (Plain)", "jsno": "JSON (Obfuscated)", "js": "JavaScript (Plain)", - "jso": "JavaScript (Obfuscated)" + "jso": "JavaScript (Obfuscated)", } def destroy(self) -> None: @@ -149,7 +156,11 @@ class MirrorList(AbstractConfiguration): @classmethod def csv_header(cls) -> List[str]: return super().csv_header() + [ - "provider", "format", "container", "branch", "filename" + "provider", + "format", + "container", + "branch", + "filename", ] @property @@ -159,5 +170,5 @@ class MirrorList(AbstractConfiguration): product="list", provider=self.provider, resource_type="list", - resource_id=str(self.id) + resource_id=str(self.id), ) diff --git a/app/models/bridges.py b/app/models/bridges.py index b2e538d..59883cd 100644 --- a/app/models/bridges.py +++ b/app/models/bridges.py @@ -34,7 +34,7 @@ class BridgeConf(AbstractConfiguration): product="bridge", provider="", resource_type="bridgeconf", - resource_id=str(self.id) + resource_id=str(self.id), ) def destroy(self) -> None: @@ -48,14 +48,22 @@ class BridgeConf(AbstractConfiguration): @classmethod def csv_header(cls) -> List[str]: return super().csv_header() + [ - "pool_id", "provider", "method", "description", "target_number", "max_number", "expiry_hours" + "pool_id", + "provider", + "method", + "description", + "target_number", + "max_number", + "expiry_hours", ] class Bridge(AbstractResource): conf_id: Mapped[int] = mapped_column(db.ForeignKey("bridge_conf.id")) cloud_account_id: Mapped[int] = mapped_column(db.ForeignKey("cloud_account.id")) - terraform_updated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) + terraform_updated: Mapped[Optional[datetime]] = mapped_column( + AwareDateTime(), nullable=True + ) nickname: Mapped[Optional[str]] fingerprint: Mapped[Optional[str]] hashed_fingerprint: Mapped[Optional[str]] @@ -71,11 +79,16 @@ class Bridge(AbstractResource): product="bridge", provider=self.cloud_account.provider.key, resource_type="bridge", - resource_id=str(self.id) + resource_id=str(self.id), ) @classmethod def csv_header(cls) -> List[str]: return super().csv_header() + [ - "conf_id", "terraform_updated", "nickname", "fingerprint", "hashed_fingerprint", "bridgeline" + "conf_id", + "terraform_updated", + "nickname", + "fingerprint", + "hashed_fingerprint", + "bridgeline", ] diff --git a/app/models/cloud.py b/app/models/cloud.py index d10101e..3d92696 100644 --- a/app/models/cloud.py +++ b/app/models/cloud.py @@ -42,9 +42,14 @@ class CloudAccount(AbstractConfiguration): # Compute Quotas max_instances: Mapped[int] - bridges: Mapped[List["Bridge"]] = relationship("Bridge", back_populates="cloud_account") - statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="storage_cloud_account", foreign_keys=[ - StaticOrigin.storage_cloud_account_id]) + bridges: Mapped[List["Bridge"]] = relationship( + "Bridge", back_populates="cloud_account" + ) + statics: Mapped[List["StaticOrigin"]] = relationship( + "StaticOrigin", + back_populates="storage_cloud_account", + foreign_keys=[StaticOrigin.storage_cloud_account_id], + ) @property def brn(self) -> BRN: diff --git a/app/models/mirrors.py b/app/models/mirrors.py index 8e0d64c..f3875ec 100644 --- a/app/models/mirrors.py +++ b/app/models/mirrors.py @@ -10,8 +10,7 @@ from tldextract import extract from werkzeug.datastructures import FileStorage from app.brm.brn import BRN -from app.brm.utils import (create_data_uri, normalize_color, - thumbnail_uploaded_image) +from app.brm.utils import create_data_uri, normalize_color, thumbnail_uploaded_image from app.extensions import db from app.models import AbstractConfiguration, AbstractResource, Deprecation from app.models.base import Group, Pool @@ -19,10 +18,10 @@ from app.models.onions import Onion from app.models.types import AwareDateTime country_origin = db.Table( - 'country_origin', + "country_origin", db.metadata, - db.Column('country_id', db.ForeignKey('country.id'), primary_key=True), - db.Column('origin_id', db.ForeignKey('origin.id'), primary_key=True), + db.Column("country_id", db.ForeignKey("country.id"), primary_key=True), + db.Column("origin_id", db.ForeignKey("origin.id"), primary_key=True), extend_existing=True, ) @@ -45,7 +44,9 @@ class Origin(AbstractConfiguration): group: Mapped[Group] = relationship("Group", back_populates="origins") proxies: Mapped[List[Proxy]] = relationship("Proxy", back_populates="origin") - countries: Mapped[List[Country]] = relationship("Country", secondary=country_origin, back_populates='origins') + countries: Mapped[List[Country]] = relationship( + "Country", secondary=country_origin, back_populates="origins" + ) @property def brn(self) -> BRN: @@ -54,13 +55,18 @@ class Origin(AbstractConfiguration): product="mirror", provider="conf", resource_type="origin", - resource_id=self.domain_name + resource_id=self.domain_name, ) @classmethod def csv_header(cls) -> List[str]: return super().csv_header() + [ - "group_id", "domain_name", "auto_rotation", "smart", "assets", "country" + "group_id", + "domain_name", + "auto_rotation", + "smart", + "assets", + "country", ] def destroy(self) -> None: @@ -84,30 +90,41 @@ class Origin(AbstractConfiguration): @property def risk_level(self) -> Dict[str, int]: if self.risk_level_override: - return {country.country_code: self.risk_level_override for country in self.countries} + return { + country.country_code: self.risk_level_override + for country in self.countries + } frequency_factor = 0.0 recency_factor = 0.0 recent_deprecations = ( db.session.query(Deprecation) - .join(Proxy, - Deprecation.resource_id == Proxy.id) + .join(Proxy, Deprecation.resource_id == Proxy.id) .join(Origin, Origin.id == Proxy.origin_id) .filter( Origin.id == self.id, - Deprecation.resource_type == 'Proxy', - Deprecation.deprecated_at >= datetime.now(tz=timezone.utc) - timedelta(hours=168), - Deprecation.reason != "destroyed" + Deprecation.resource_type == "Proxy", + Deprecation.deprecated_at + >= datetime.now(tz=timezone.utc) - timedelta(hours=168), + Deprecation.reason != "destroyed", ) .distinct(Proxy.id) .all() ) for deprecation in recent_deprecations: - recency_factor += 1 / max((datetime.now(tz=timezone.utc) - deprecation.deprecated_at).total_seconds() // 3600, 1) + recency_factor += 1 / max( + ( + datetime.now(tz=timezone.utc) - deprecation.deprecated_at + ).total_seconds() + // 3600, + 1, + ) frequency_factor += 1 risk_levels: Dict[str, int] = {} for country in self.countries: - risk_levels[country.country_code.upper()] = int( - max(1, min(10, frequency_factor * recency_factor))) + country.risk_level + risk_levels[country.country_code.upper()] = ( + int(max(1, min(10, frequency_factor * recency_factor))) + + country.risk_level + ) return risk_levels def to_dict(self) -> OriginDict: @@ -128,13 +145,15 @@ class Country(AbstractConfiguration): product="country", provider="iso3166-1", resource_type="alpha2", - resource_id=self.country_code + resource_id=self.country_code, ) country_code: Mapped[str] risk_level_override: Mapped[Optional[int]] - origins = db.relationship("Origin", secondary=country_origin, back_populates='countries') + origins = db.relationship( + "Origin", secondary=country_origin, back_populates="countries" + ) @property def risk_level(self) -> int: @@ -144,29 +163,39 @@ class Country(AbstractConfiguration): recency_factor = 0.0 recent_deprecations = ( db.session.query(Deprecation) - .join(Proxy, - Deprecation.resource_id == Proxy.id) + .join(Proxy, Deprecation.resource_id == Proxy.id) .join(Origin, Origin.id == Proxy.origin_id) .join(Origin.countries) .filter( Country.id == self.id, - Deprecation.resource_type == 'Proxy', - Deprecation.deprecated_at >= datetime.now(tz=timezone.utc) - timedelta(hours=168), - Deprecation.reason != "destroyed" + Deprecation.resource_type == "Proxy", + Deprecation.deprecated_at + >= datetime.now(tz=timezone.utc) - timedelta(hours=168), + Deprecation.reason != "destroyed", ) .distinct(Proxy.id) .all() ) for deprecation in recent_deprecations: - recency_factor += 1 / max((datetime.now(tz=timezone.utc) - deprecation.deprecated_at).total_seconds() // 3600, 1) + recency_factor += 1 / max( + ( + datetime.now(tz=timezone.utc) - deprecation.deprecated_at + ).total_seconds() + // 3600, + 1, + ) frequency_factor += 1 return int(max(1, min(10, frequency_factor * recency_factor))) class StaticOrigin(AbstractConfiguration): group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False) - storage_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) - source_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) + storage_cloud_account_id = mapped_column( + db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False + ) + source_cloud_account_id = mapped_column( + db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False + ) source_project = mapped_column(db.String(255), nullable=False) auto_rotate = mapped_column(db.Boolean, nullable=False) matrix_homeserver = mapped_column(db.String(255), nullable=True) @@ -182,30 +211,34 @@ class StaticOrigin(AbstractConfiguration): product="mirror", provider="aws", resource_type="static", - resource_id=self.domain_name + resource_id=self.domain_name, ) group = db.relationship("Group", back_populates="statics") - storage_cloud_account = db.relationship("CloudAccount", back_populates="statics", - foreign_keys=[storage_cloud_account_id]) - source_cloud_account = db.relationship("CloudAccount", back_populates="statics", - foreign_keys=[source_cloud_account_id]) + storage_cloud_account = db.relationship( + "CloudAccount", + back_populates="statics", + foreign_keys=[storage_cloud_account_id], + ) + source_cloud_account = db.relationship( + "CloudAccount", back_populates="statics", foreign_keys=[source_cloud_account_id] + ) def destroy(self) -> None: # TODO: The StaticMetaAutomation will clean up for now, but it should probably happen here for consistency super().destroy() def update( - self, - source_project: str, - description: str, - auto_rotate: bool, - matrix_homeserver: Optional[str], - keanu_convene_path: Optional[str], - keanu_convene_logo: Optional[FileStorage], - keanu_convene_color: Optional[str], - clean_insights_backend: Optional[Union[str, bool]], - db_session_commit: bool, + self, + source_project: str, + description: str, + auto_rotate: bool, + matrix_homeserver: Optional[str], + keanu_convene_path: Optional[str], + keanu_convene_logo: Optional[FileStorage], + keanu_convene_color: Optional[str], + clean_insights_backend: Optional[Union[str, bool]], + db_session_commit: bool, ) -> None: if isinstance(source_project, str): self.source_project = source_project @@ -235,19 +268,29 @@ class StaticOrigin(AbstractConfiguration): elif isinstance(keanu_convene_logo, FileStorage): if keanu_convene_logo.filename: # if False, no file was uploaded keanu_convene_config["logo"] = create_data_uri( - thumbnail_uploaded_image(keanu_convene_logo), keanu_convene_logo.filename) + thumbnail_uploaded_image(keanu_convene_logo), + keanu_convene_logo.filename, + ) else: raise ValueError("keanu_convene_logo must be a FileStorage") try: if isinstance(keanu_convene_color, str): - keanu_convene_config["color"] = normalize_color(keanu_convene_color) # can raise ValueError + keanu_convene_config["color"] = normalize_color( + keanu_convene_color + ) # can raise ValueError else: raise ValueError() # re-raised below with message except ValueError: - raise ValueError("keanu_convene_path must be a str containing an HTML color (CSS name or hex)") - self.keanu_convene_config = json.dumps(keanu_convene_config, separators=(',', ':')) + raise ValueError( + "keanu_convene_path must be a str containing an HTML color (CSS name or hex)" + ) + self.keanu_convene_config = json.dumps( + keanu_convene_config, separators=(",", ":") + ) del keanu_convene_config # done with this temporary variable - if clean_insights_backend is None or (isinstance(clean_insights_backend, bool) and not clean_insights_backend): + if clean_insights_backend is None or ( + isinstance(clean_insights_backend, bool) and not clean_insights_backend + ): self.clean_insights_backend = None elif isinstance(clean_insights_backend, bool) and clean_insights_backend: self.clean_insights_backend = "metrics.cleaninsights.org" @@ -260,7 +303,9 @@ class StaticOrigin(AbstractConfiguration): self.updated = datetime.now(tz=timezone.utc) -ResourceStatus = Union[Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"]] +ResourceStatus = Union[ + Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"] +] class ProxyDict(TypedDict): @@ -271,12 +316,16 @@ class ProxyDict(TypedDict): class Proxy(AbstractResource): - origin_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("origin.id"), nullable=False) + origin_id: Mapped[int] = mapped_column( + db.Integer, db.ForeignKey("origin.id"), nullable=False + ) pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id")) provider: Mapped[str] = mapped_column(db.String(20), nullable=False) psg: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) slug: Mapped[Optional[str]] = mapped_column(db.String(20), nullable=True) - terraform_updated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) + terraform_updated: Mapped[Optional[datetime]] = mapped_column( + AwareDateTime(), nullable=True + ) url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) origin: Mapped[Origin] = relationship("Origin", back_populates="proxies") @@ -289,13 +338,18 @@ class Proxy(AbstractResource): product="mirror", provider=self.provider, resource_type="proxy", - resource_id=str(self.id) + resource_id=str(self.id), ) @classmethod def csv_header(cls) -> List[str]: return super().csv_header() + [ - "origin_id", "provider", "psg", "slug", "terraform_updated", "url" + "origin_id", + "provider", + "psg", + "slug", + "terraform_updated", + "url", ] def to_dict(self) -> ProxyDict: @@ -329,5 +383,5 @@ class SmartProxy(AbstractResource): product="mirror", provider=self.provider, resource_type="smart_proxy", - resource_id=str(1) + resource_id=str(1), ) diff --git a/app/models/onions.py b/app/models/onions.py index 7efd601..c3c7e71 100644 --- a/app/models/onions.py +++ b/app/models/onions.py @@ -32,7 +32,7 @@ class Onion(AbstractConfiguration): product="eotk", provider="*", resource_type="onion", - resource_id=self.onion_name + resource_id=self.onion_name, ) group_id: Mapped[int] = mapped_column(db.ForeignKey("group.id")) @@ -80,5 +80,5 @@ class Eotk(AbstractResource): provider=self.provider, product="eotk", resource_type="instance", - resource_id=self.region + resource_id=self.region, ) diff --git a/app/portal/__init__.py b/app/portal/__init__.py index d56a088..d9299b0 100644 --- a/app/portal/__init__.py +++ b/app/portal/__init__.py @@ -32,7 +32,9 @@ from app.portal.static import bp as static from app.portal.storage import bp as storage from app.portal.webhook import bp as webhook -portal = Blueprint("portal", __name__, template_folder="templates", static_folder="static") +portal = Blueprint( + "portal", __name__, template_folder="templates", static_folder="static" +) portal.register_blueprint(automation, url_prefix="/automation") portal.register_blueprint(bridgeconf, url_prefix="/bridgeconf") portal.register_blueprint(bridge, url_prefix="/bridge") @@ -54,7 +56,10 @@ portal.register_blueprint(webhook, url_prefix="/webhook") @portal.app_template_filter("bridge_expiry") def calculate_bridge_expiry(b: Bridge) -> str: if b.deprecated is None: - logging.warning("Bridge expiry requested by template for a bridge %s that was not expiring.", b.id) + logging.warning( + "Bridge expiry requested by template for a bridge %s that was not expiring.", + b.id, + ) return "Not expiring" expiry = b.deprecated + timedelta(hours=b.conf.expiry_hours) countdown = expiry - datetime.now(tz=timezone.utc) @@ -85,27 +90,27 @@ def describe_brn(s: str) -> ResponseReturnValue: if parts[3] == "mirror": if parts[5].startswith("origin/"): origin = Origin.query.filter( - Origin.domain_name == parts[5][len("origin/"):] + Origin.domain_name == parts[5][len("origin/") :] ).first() if not origin: return s return f"Origin: {origin.domain_name} ({origin.group.group_name})" if parts[5].startswith("proxy/"): proxy = Proxy.query.filter( - Proxy.id == int(parts[5][len("proxy/"):]) + Proxy.id == int(parts[5][len("proxy/") :]) ).first() if not proxy: return s return Markup( - f"Proxy: {proxy.url}
({proxy.origin.group.group_name}: {proxy.origin.domain_name})") + f"Proxy: {proxy.url}
({proxy.origin.group.group_name}: {proxy.origin.domain_name})" + ) if parts[5].startswith("quota/"): if parts[4] == "cloudfront": return f"Quota: CloudFront {parts[5][len('quota/'):]}" if parts[3] == "eotk": if parts[5].startswith("instance/"): eotk = Eotk.query.filter( - Eotk.group_id == parts[2], - Eotk.region == parts[5][len("instance/"):] + Eotk.group_id == parts[2], Eotk.region == parts[5][len("instance/") :] ).first() if not eotk: return s @@ -138,9 +143,16 @@ def portal_home() -> ResponseReturnValue: proxies = Proxy.query.filter(Proxy.destroyed.is_(None)).all() last24 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=1))).all()) last72 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=3))).all()) - lastweek = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=7))).all()) + lastweek = len( + Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=7))).all() + ) alarms = { - s: len(Alarm.query.filter(Alarm.alarm_state == s.upper(), Alarm.last_updated > (now - timedelta(days=1))).all()) + s: len( + Alarm.query.filter( + Alarm.alarm_state == s.upper(), + Alarm.last_updated > (now - timedelta(days=1)), + ).all() + ) for s in ["critical", "warning", "ok", "unknown"] } bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all() @@ -148,13 +160,36 @@ def portal_home() -> ResponseReturnValue: d: len(Bridge.query.filter(Bridge.deprecated > (now - timedelta(days=d))).all()) for d in [1, 3, 7] } - activity = Activity.query.filter(Activity.added > (now - timedelta(days=2))).order_by(desc(Activity.added)).all() - onionified = len([o for o in Origin.query.filter(Origin.destroyed.is_(None)).all() if o.onion() is not None]) + activity = ( + Activity.query.filter(Activity.added > (now - timedelta(days=2))) + .order_by(desc(Activity.added)) + .all() + ) + onionified = len( + [ + o + for o in Origin.query.filter(Origin.destroyed.is_(None)).all() + if o.onion() is not None + ] + ) ooni_blocked = total_origins_blocked() total_origins = len(Origin.query.filter(Origin.destroyed.is_(None)).all()) - return render_template("home.html.j2", section="home", groups=groups, last24=last24, last72=last72, - lastweek=lastweek, proxies=proxies, **alarms, activity=activity, total_origins=total_origins, - onionified=onionified, br_last=br_last, ooni_blocked=ooni_blocked, bridges=bridges) + return render_template( + "home.html.j2", + section="home", + groups=groups, + last24=last24, + last72=last72, + lastweek=lastweek, + proxies=proxies, + **alarms, + activity=activity, + total_origins=total_origins, + onionified=onionified, + br_last=br_last, + ooni_blocked=ooni_blocked, + bridges=bridges, + ) @portal.route("/search") @@ -163,19 +198,27 @@ def search() -> ResponseReturnValue: if query is None: return redirect(url_for("portal.portal_home")) proxies = Proxy.query.filter( - or_(func.lower(Proxy.url).contains(query.lower())), Proxy.destroyed.is_(None)).all() + or_(func.lower(Proxy.url).contains(query.lower())), Proxy.destroyed.is_(None) + ).all() origins = Origin.query.filter( - or_(func.lower(Origin.description).contains(query.lower()), - func.lower(Origin.domain_name).contains(query.lower()))).all() - return render_template("search.html.j2", section="home", proxies=proxies, origins=origins) + or_( + func.lower(Origin.description).contains(query.lower()), + func.lower(Origin.domain_name).contains(query.lower()), + ) + ).all() + return render_template( + "search.html.j2", section="home", proxies=proxies, origins=origins + ) -@portal.route('/alarms') +@portal.route("/alarms") def view_alarms() -> ResponseReturnValue: one_day_ago = datetime.now(timezone.utc) - timedelta(days=1) - alarms = Alarm.query.filter(Alarm.last_updated >= one_day_ago).order_by( - desc(Alarm.alarm_state), desc(Alarm.state_changed)).all() - return render_template("list.html.j2", - section="alarm", - title="Alarms", - items=alarms) + alarms = ( + Alarm.query.filter(Alarm.last_updated >= one_day_ago) + .order_by(desc(Alarm.alarm_state), desc(Alarm.state_changed)) + .all() + ) + return render_template( + "list.html.j2", section="alarm", title="Alarms", items=alarms + ) diff --git a/app/portal/automation.py b/app/portal/automation.py index b1e878d..a0c51b3 100644 --- a/app/portal/automation.py +++ b/app/portal/automation.py @@ -17,40 +17,52 @@ bp = Blueprint("automation", __name__) _SECTION_TEMPLATE_VARS = { "section": "automation", - "help_url": "https://bypass.censorship.guide/user/automation.html" + "help_url": "https://bypass.censorship.guide/user/automation.html", } class EditAutomationForm(FlaskForm): # type: ignore - enabled = BooleanField('Enabled') - submit = SubmitField('Save Changes') + enabled = BooleanField("Enabled") + submit = SubmitField("Save Changes") @bp.route("/list") def automation_list() -> ResponseReturnValue: - automations = list(filter( - lambda a: a.short_name not in current_app.config.get('HIDDEN_AUTOMATIONS', []), - Automation.query.filter( - Automation.destroyed.is_(None)).order_by(Automation.description).all() - )) + automations = list( + filter( + lambda a: a.short_name + not in current_app.config.get("HIDDEN_AUTOMATIONS", []), + Automation.query.filter(Automation.destroyed.is_(None)) + .order_by(Automation.description) + .all(), + ) + ) states = {tfs.key: tfs for tfs in TerraformState.query.all()} - return render_template("list.html.j2", - title="Automation Jobs", - item="automation", - items=automations, - states=states, - **_SECTION_TEMPLATE_VARS) + return render_template( + "list.html.j2", + title="Automation Jobs", + item="automation", + items=automations, + states=states, + **_SECTION_TEMPLATE_VARS + ) -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def automation_edit(automation_id: int) -> ResponseReturnValue: - automation: Optional[Automation] = Automation.query.filter(Automation.id == automation_id).first() + automation: Optional[Automation] = Automation.query.filter( + Automation.id == automation_id + ).first() if automation is None: - return Response(render_template("error.html.j2", - header="404 Automation Job Not Found", - message="The requested automation job could not be found.", - **_SECTION_TEMPLATE_VARS), - status=404) + return Response( + render_template( + "error.html.j2", + header="404 Automation Job Not Found", + message="The requested automation job could not be found.", + **_SECTION_TEMPLATE_VARS + ), + status=404, + ) form = EditAutomationForm(enabled=automation.enabled) if form.validate_on_submit(): automation.enabled = form.enabled.data @@ -59,21 +71,30 @@ def automation_edit(automation_id: int) -> ResponseReturnValue: db.session.commit() flash("Saved changes to bridge configuration.", "success") except exc.SQLAlchemyError: - flash("An error occurred saving the changes to the bridge configuration.", "danger") - logs = AutomationLogs.query.filter(AutomationLogs.automation_id == automation.id).order_by( - desc(AutomationLogs.added)).limit(5).all() - return render_template("automation.html.j2", - automation=automation, - logs=logs, - form=form, - **_SECTION_TEMPLATE_VARS) + flash( + "An error occurred saving the changes to the bridge configuration.", + "danger", + ) + logs = ( + AutomationLogs.query.filter(AutomationLogs.automation_id == automation.id) + .order_by(desc(AutomationLogs.added)) + .limit(5) + .all() + ) + return render_template( + "automation.html.j2", + automation=automation, + logs=logs, + form=form, + **_SECTION_TEMPLATE_VARS + ) -@bp.route("/kick/", methods=['GET', 'POST']) +@bp.route("/kick/", methods=["GET", "POST"]) def automation_kick(automation_id: int) -> ResponseReturnValue: automation = Automation.query.filter( - Automation.id == automation_id, - Automation.destroyed.is_(None)).first() + Automation.id == automation_id, Automation.destroyed.is_(None) + ).first() if automation is None: return response_404("The requested bridge configuration could not be found.") return view_lifecycle( @@ -83,5 +104,5 @@ def automation_kick(automation_id: int) -> ResponseReturnValue: success_view="portal.automation.automation_list", success_message="This automation job will next run within 1 minute.", resource=automation, - action="kick" + action="kick", ) diff --git a/app/portal/bridge.py b/app/portal/bridge.py index 5ac7013..3b4c9be 100644 --- a/app/portal/bridge.py +++ b/app/portal/bridge.py @@ -1,7 +1,6 @@ from typing import Optional -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from app.extensions import db @@ -12,57 +11,79 @@ bp = Blueprint("bridge", __name__) _SECTION_TEMPLATE_VARS = { "section": "bridge", - "help_url": "https://bypass.censorship.guide/user/bridges.html" + "help_url": "https://bypass.censorship.guide/user/bridges.html", } @bp.route("/list") def bridge_list() -> ResponseReturnValue: bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all() - return render_template("list.html.j2", - title="Tor Bridges", - item="bridge", - items=bridges, - **_SECTION_TEMPLATE_VARS) + return render_template( + "list.html.j2", + title="Tor Bridges", + item="bridge", + items=bridges, + **_SECTION_TEMPLATE_VARS, + ) -@bp.route("/block/", methods=['GET', 'POST']) +@bp.route("/block/", methods=["GET", "POST"]) def bridge_blocked(bridge_id: int) -> ResponseReturnValue: - bridge: Optional[Bridge] = Bridge.query.filter(Bridge.id == bridge_id, Bridge.destroyed.is_(None)).first() + bridge: Optional[Bridge] = Bridge.query.filter( + Bridge.id == bridge_id, Bridge.destroyed.is_(None) + ).first() if bridge is None: - return Response(render_template("error.html.j2", - header="404 Proxy Not Found", - message="The requested bridge could not be found.", - **_SECTION_TEMPLATE_VARS)) + return Response( + render_template( + "error.html.j2", + header="404 Proxy Not Found", + message="The requested bridge could not be found.", + **_SECTION_TEMPLATE_VARS, + ) + ) form = LifecycleForm() if form.validate_on_submit(): bridge.deprecate(reason="manual") db.session.commit() flash("Bridge will be shortly replaced.", "success") - return redirect(url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)) - return render_template("lifecycle.html.j2", - header=f"Mark bridge {bridge.hashed_fingerprint} as blocked?", - message=bridge.hashed_fingerprint, - form=form, - **_SECTION_TEMPLATE_VARS) + return redirect( + url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id) + ) + return render_template( + "lifecycle.html.j2", + header=f"Mark bridge {bridge.hashed_fingerprint} as blocked?", + message=bridge.hashed_fingerprint, + form=form, + **_SECTION_TEMPLATE_VARS, + ) -@bp.route("/expire/", methods=['GET', 'POST']) +@bp.route("/expire/", methods=["GET", "POST"]) def bridge_expire(bridge_id: int) -> ResponseReturnValue: - bridge: Optional[Bridge] = Bridge.query.filter(Bridge.id == bridge_id, Bridge.destroyed.is_(None)).first() + bridge: Optional[Bridge] = Bridge.query.filter( + Bridge.id == bridge_id, Bridge.destroyed.is_(None) + ).first() if bridge is None: - return Response(render_template("error.html.j2", - header="404 Proxy Not Found", - message="The requested bridge could not be found.", - **_SECTION_TEMPLATE_VARS)) + return Response( + render_template( + "error.html.j2", + header="404 Proxy Not Found", + message="The requested bridge could not be found.", + **_SECTION_TEMPLATE_VARS, + ) + ) form = LifecycleForm() if form.validate_on_submit(): bridge.destroy() db.session.commit() flash("Bridge will be shortly destroyed.", "success") - return redirect(url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)) - return render_template("lifecycle.html.j2", - header=f"Destroy bridge {bridge.hashed_fingerprint}?", - message=bridge.hashed_fingerprint, - form=form, - **_SECTION_TEMPLATE_VARS) + return redirect( + url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id) + ) + return render_template( + "lifecycle.html.j2", + header=f"Destroy bridge {bridge.hashed_fingerprint}?", + message=bridge.hashed_fingerprint, + form=form, + **_SECTION_TEMPLATE_VARS, + ) diff --git a/app/portal/bridgeconf.py b/app/portal/bridgeconf.py index 360f185..40baf17 100644 --- a/app/portal/bridgeconf.py +++ b/app/portal/bridgeconf.py @@ -1,8 +1,7 @@ from datetime import datetime, timezone from typing import List, Optional -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm from sqlalchemy import exc @@ -19,77 +18,109 @@ bp = Blueprint("bridgeconf", __name__) _SECTION_TEMPLATE_VARS = { "section": "bridgeconf", - "help_url": "https://bypass.censorship.guide/user/bridges.html" + "help_url": "https://bypass.censorship.guide/user/bridges.html", } class NewBridgeConfForm(FlaskForm): # type: ignore - method = SelectField('Distribution Method', validators=[DataRequired()]) - description = StringField('Description') - pool = SelectField('Pool', validators=[DataRequired()]) - target_number = IntegerField('Target Number', - description="The number of active bridges to deploy (excluding deprecated bridges).", - validators=[NumberRange(1, message="One or more bridges must be created.")]) - max_number = IntegerField('Maximum Number', - description="The maximum number of bridges to deploy (including deprecated bridges).", - validators=[ - NumberRange(1, message="Must be at least 1, ideally greater than target number.")]) - expiry_hours = IntegerField('Expiry Timer (hours)', - description=("The number of hours to wait after a bridge is deprecated before its " - "destruction.")) - provider_allocation = SelectField('Provider Allocation Method', - description="How to allocate new bridges to providers.", - choices=[ - ("COST", "Use cheapest provider first"), - ("RANDOM", "Use providers randomly"), - ]) - submit = SubmitField('Save Changes') + method = SelectField("Distribution Method", validators=[DataRequired()]) + description = StringField("Description") + pool = SelectField("Pool", validators=[DataRequired()]) + target_number = IntegerField( + "Target Number", + description="The number of active bridges to deploy (excluding deprecated bridges).", + validators=[NumberRange(1, message="One or more bridges must be created.")], + ) + max_number = IntegerField( + "Maximum Number", + description="The maximum number of bridges to deploy (including deprecated bridges).", + validators=[ + NumberRange( + 1, message="Must be at least 1, ideally greater than target number." + ) + ], + ) + expiry_hours = IntegerField( + "Expiry Timer (hours)", + description=( + "The number of hours to wait after a bridge is deprecated before its " + "destruction." + ), + ) + provider_allocation = SelectField( + "Provider Allocation Method", + description="How to allocate new bridges to providers.", + choices=[ + ("COST", "Use cheapest provider first"), + ("RANDOM", "Use providers randomly"), + ], + ) + submit = SubmitField("Save Changes") class EditBridgeConfForm(FlaskForm): # type: ignore - description = StringField('Description') - target_number = IntegerField('Target Number', - description="The number of active bridges to deploy (excluding deprecated bridges).", - validators=[NumberRange(1, message="One or more bridges must be created.")]) - max_number = IntegerField('Maximum Number', - description="The maximum number of bridges to deploy (including deprecated bridges).", - validators=[ - NumberRange(1, message="Must be at least 1, ideally greater than target number.")]) - expiry_hours = IntegerField('Expiry Timer (hours)', - description=("The number of hours to wait after a bridge is deprecated before its " - "destruction.")) - provider_allocation = SelectField('Provider Allocation Method', - description="How to allocate new bridges to providers.", - choices=[ - ("COST", "Use cheapest provider first"), - ("RANDOM", "Use providers randomly"), - ]) - submit = SubmitField('Save Changes') + description = StringField("Description") + target_number = IntegerField( + "Target Number", + description="The number of active bridges to deploy (excluding deprecated bridges).", + validators=[NumberRange(1, message="One or more bridges must be created.")], + ) + max_number = IntegerField( + "Maximum Number", + description="The maximum number of bridges to deploy (including deprecated bridges).", + validators=[ + NumberRange( + 1, message="Must be at least 1, ideally greater than target number." + ) + ], + ) + expiry_hours = IntegerField( + "Expiry Timer (hours)", + description=( + "The number of hours to wait after a bridge is deprecated before its " + "destruction." + ), + ) + provider_allocation = SelectField( + "Provider Allocation Method", + description="How to allocate new bridges to providers.", + choices=[ + ("COST", "Use cheapest provider first"), + ("RANDOM", "Use providers randomly"), + ], + ) + submit = SubmitField("Save Changes") @bp.route("/list") def bridgeconf_list() -> ResponseReturnValue: - bridgeconfs: List[BridgeConf] = BridgeConf.query.filter(BridgeConf.destroyed.is_(None)).all() - return render_template("list.html.j2", - title="Tor Bridge Configurations", - item="bridge configuration", - items=bridgeconfs, - new_link=url_for("portal.bridgeconf.bridgeconf_new"), - **_SECTION_TEMPLATE_VARS) + bridgeconfs: List[BridgeConf] = BridgeConf.query.filter( + BridgeConf.destroyed.is_(None) + ).all() + return render_template( + "list.html.j2", + title="Tor Bridge Configurations", + item="bridge configuration", + items=bridgeconfs, + new_link=url_for("portal.bridgeconf.bridgeconf_new"), + **_SECTION_TEMPLATE_VARS, + ) -@bp.route("/new", methods=['GET', 'POST']) -@bp.route("/new/", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) +@bp.route("/new/", methods=["GET", "POST"]) def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue: form = NewBridgeConfForm() - form.pool.choices = [(x.id, x.pool_name) for x in Pool.query.filter(Pool.destroyed.is_(None)).all()] + form.pool.choices = [ + (x.id, x.pool_name) for x in Pool.query.filter(Pool.destroyed.is_(None)).all() + ] form.method.choices = [ ("any", "Any (BridgeDB)"), ("email", "E-Mail (BridgeDB)"), ("moat", "Moat (BridgeDB)"), ("settings", "Settings (BridgeDB)"), ("https", "HTTPS (BridgeDB)"), - ("none", "None (Private)") + ("none", "None (Private)"), ] if form.validate_on_submit(): bridgeconf = BridgeConf() @@ -99,7 +130,9 @@ def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue: bridgeconf.target_number = form.target_number.data bridgeconf.max_number = form.max_number.data bridgeconf.expiry_hours = form.expiry_hours.data - bridgeconf.provider_allocation = ProviderAllocation[form.provider_allocation.data] + bridgeconf.provider_allocation = ProviderAllocation[ + form.provider_allocation.data + ] bridgeconf.added = datetime.now(tz=timezone.utc) bridgeconf.updated = datetime.now(tz=timezone.utc) try: @@ -112,47 +145,56 @@ def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue: return redirect(url_for("portal.bridgeconf.bridgeconf_list")) if group_id: form.group.data = group_id - return render_template("new.html.j2", - form=form, - **_SECTION_TEMPLATE_VARS) + return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS) -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def bridgeconf_edit(bridgeconf_id: int) -> ResponseReturnValue: bridgeconf = BridgeConf.query.filter(BridgeConf.id == bridgeconf_id).first() if bridgeconf is None: - return Response(render_template("error.html.j2", - header="404 Bridge Configuration Not Found", - message="The requested bridge configuration could not be found.", - **_SECTION_TEMPLATE_VARS), - status=404) - form = EditBridgeConfForm(description=bridgeconf.description, - target_number=bridgeconf.target_number, - max_number=bridgeconf.max_number, - expiry_hours=bridgeconf.expiry_hours, - provider_allocation=bridgeconf.provider_allocation.name, - ) + return Response( + render_template( + "error.html.j2", + header="404 Bridge Configuration Not Found", + message="The requested bridge configuration could not be found.", + **_SECTION_TEMPLATE_VARS, + ), + status=404, + ) + form = EditBridgeConfForm( + description=bridgeconf.description, + target_number=bridgeconf.target_number, + max_number=bridgeconf.max_number, + expiry_hours=bridgeconf.expiry_hours, + provider_allocation=bridgeconf.provider_allocation.name, + ) if form.validate_on_submit(): bridgeconf.description = form.description.data bridgeconf.target_number = form.target_number.data bridgeconf.max_number = form.max_number.data bridgeconf.expiry_hours = form.expiry_hours.data - bridgeconf.provider_allocation = ProviderAllocation[form.provider_allocation.data] + bridgeconf.provider_allocation = ProviderAllocation[ + form.provider_allocation.data + ] bridgeconf.updated = datetime.now(tz=timezone.utc) try: db.session.commit() flash("Saved changes to bridge configuration.", "success") except exc.SQLAlchemyError: - flash("An error occurred saving the changes to the bridge configuration.", "danger") - return render_template("bridgeconf.html.j2", - bridgeconf=bridgeconf, - form=form, - **_SECTION_TEMPLATE_VARS) + flash( + "An error occurred saving the changes to the bridge configuration.", + "danger", + ) + return render_template( + "bridgeconf.html.j2", bridgeconf=bridgeconf, form=form, **_SECTION_TEMPLATE_VARS + ) -@bp.route("/destroy/", methods=['GET', 'POST']) +@bp.route("/destroy/", methods=["GET", "POST"]) def bridgeconf_destroy(bridgeconf_id: int) -> ResponseReturnValue: - bridgeconf = BridgeConf.query.filter(BridgeConf.id == bridgeconf_id, BridgeConf.destroyed.is_(None)).first() + bridgeconf = BridgeConf.query.filter( + BridgeConf.id == bridgeconf_id, BridgeConf.destroyed.is_(None) + ).first() if bridgeconf is None: return response_404("The requested bridge configuration could not be found.") return view_lifecycle( @@ -162,5 +204,5 @@ def bridgeconf_destroy(bridgeconf_id: int) -> ResponseReturnValue: success_message="All bridges from the destroyed configuration will shortly be destroyed at their providers.", section="bridgeconf", resource=bridgeconf, - action="destroy" + action="destroy", ) diff --git a/app/portal/cloud.py b/app/portal/cloud.py index 06d9767..83a5f64 100644 --- a/app/portal/cloud.py +++ b/app/portal/cloud.py @@ -3,8 +3,15 @@ from typing import Dict, List, Optional, Type, Union from flask import Blueprint, redirect, render_template, url_for from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm -from wtforms import (BooleanField, Form, FormField, IntegerField, SelectField, - StringField, SubmitField) +from wtforms import ( + BooleanField, + Form, + FormField, + IntegerField, + SelectField, + StringField, + SubmitField, +) from wtforms.validators import InputRequired from app.extensions import db @@ -14,54 +21,72 @@ bp = Blueprint("cloud", __name__) _SECTION_TEMPLATE_VARS = { "section": "cloud", - "help_url": "https://bypass.censorship.guide/user/cloud.html" + "help_url": "https://bypass.censorship.guide/user/cloud.html", } class NewCloudAccountForm(FlaskForm): # type: ignore - provider = SelectField('Cloud Provider', validators=[InputRequired()]) - submit = SubmitField('Next') + provider = SelectField("Cloud Provider", validators=[InputRequired()]) + submit = SubmitField("Next") class AWSAccountForm(FlaskForm): # type: ignore - provider = StringField('Platform', render_kw={"disabled": ""}) - description = StringField('Description', validators=[InputRequired()]) - aws_access_key = StringField('AWS Access Key', validators=[InputRequired()]) - aws_secret_key = StringField('AWS Secret Key', validators=[InputRequired()]) - aws_region = StringField('AWS Region', default='us-east-2', validators=[InputRequired()]) - max_distributions = IntegerField('Cloudfront Distributions Quota', default=200, - description="This is the quota for number of distributions per account.", - validators=[InputRequired()]) - max_instances = IntegerField('EC2 Instance Quota', default=2, - description="This can be impacted by a number of quotas including instance limits " - "and IP address limits.", - validators=[InputRequired()]) - enabled = BooleanField('Enable this account', default=True, - description="New resources will not be deployed to disabled accounts, however existing " - "resources will persist until destroyed at the end of their lifecycle.") - submit = SubmitField('Save Changes') + provider = StringField("Platform", render_kw={"disabled": ""}) + description = StringField("Description", validators=[InputRequired()]) + aws_access_key = StringField("AWS Access Key", validators=[InputRequired()]) + aws_secret_key = StringField("AWS Secret Key", validators=[InputRequired()]) + aws_region = StringField( + "AWS Region", default="us-east-2", validators=[InputRequired()] + ) + max_distributions = IntegerField( + "Cloudfront Distributions Quota", + default=200, + description="This is the quota for number of distributions per account.", + validators=[InputRequired()], + ) + max_instances = IntegerField( + "EC2 Instance Quota", + default=2, + description="This can be impacted by a number of quotas including instance limits " + "and IP address limits.", + validators=[InputRequired()], + ) + enabled = BooleanField( + "Enable this account", + default=True, + description="New resources will not be deployed to disabled accounts, however existing " + "resources will persist until destroyed at the end of their lifecycle.", + ) + submit = SubmitField("Save Changes") class HcloudAccountForm(FlaskForm): # type: ignore - provider = StringField('Platform', render_kw={"disabled": ""}) - description = StringField('Description', validators=[InputRequired()]) - hcloud_token = StringField('Hetzner Cloud Token', validators=[InputRequired()]) - max_instances = IntegerField('Server Limit', default=10, - validators=[InputRequired()]) - enabled = BooleanField('Enable this account', default=True, - description="New resources will not be deployed to disabled accounts, however existing " - "resources will persist until destroyed at the end of their lifecycle.") - submit = SubmitField('Save Changes') + provider = StringField("Platform", render_kw={"disabled": ""}) + description = StringField("Description", validators=[InputRequired()]) + hcloud_token = StringField("Hetzner Cloud Token", validators=[InputRequired()]) + max_instances = IntegerField( + "Server Limit", default=10, validators=[InputRequired()] + ) + enabled = BooleanField( + "Enable this account", + default=True, + description="New resources will not be deployed to disabled accounts, however existing " + "resources will persist until destroyed at the end of their lifecycle.", + ) + submit = SubmitField("Save Changes") class GitlabAccountForm(FlaskForm): # type: ignore - provider = StringField('Platform', render_kw={"disabled": ""}) - description = StringField('Description', validators=[InputRequired()]) - gitlab_token = StringField('GitLab Access Token', validators=[InputRequired()]) - enabled = BooleanField('Enable this account', default=True, - description="New resources will not be deployed to disabled accounts, however existing " - "resources will persist until destroyed at the end of their lifecycle.") - submit = SubmitField('Save Changes') + provider = StringField("Platform", render_kw={"disabled": ""}) + description = StringField("Description", validators=[InputRequired()]) + gitlab_token = StringField("GitLab Access Token", validators=[InputRequired()]) + enabled = BooleanField( + "Enable this account", + default=True, + description="New resources will not be deployed to disabled accounts, however existing " + "resources will persist until destroyed at the end of their lifecycle.", + ) + submit = SubmitField("Save Changes") class OvhHorizonForm(Form): # type: ignore[misc] @@ -77,16 +102,20 @@ class OvhApiForm(Form): # type: ignore[misc] class OvhAccountForm(FlaskForm): # type: ignore - provider = StringField('Platform', render_kw={"disabled": ""}) - description = StringField('Description', validators=[InputRequired()]) - horizon = FormField(OvhHorizonForm, 'OpenStack Horizon API') - ovh_api = FormField(OvhApiForm, 'OVH API') - max_instances = IntegerField('Server Limit', default=10, - validators=[InputRequired()]) - enabled = BooleanField('Enable this account', default=True, - description="New resources will not be deployed to disabled accounts, however existing " - "resources will persist until destroyed at the end of their lifecycle.") - submit = SubmitField('Save Changes') + provider = StringField("Platform", render_kw={"disabled": ""}) + description = StringField("Description", validators=[InputRequired()]) + horizon = FormField(OvhHorizonForm, "OpenStack Horizon API") + ovh_api = FormField(OvhApiForm, "OVH API") + max_instances = IntegerField( + "Server Limit", default=10, validators=[InputRequired()] + ) + enabled = BooleanField( + "Enable this account", + default=True, + description="New resources will not be deployed to disabled accounts, however existing " + "resources will persist until destroyed at the end of their lifecycle.", + ) + submit = SubmitField("Save Changes") class GandiHorizonForm(Form): # type: ignore[misc] @@ -96,18 +125,24 @@ class GandiHorizonForm(Form): # type: ignore[misc] class GandiAccountForm(FlaskForm): # type: ignore - provider = StringField('Platform', render_kw={"disabled": ""}) - description = StringField('Description', validators=[InputRequired()]) - horizon = FormField(GandiHorizonForm, 'OpenStack Horizon API') - max_instances = IntegerField('Server Limit', default=10, - validators=[InputRequired()]) - enabled = BooleanField('Enable this account', default=True, - description="New resources will not be deployed to disabled accounts, however existing " - "resources will persist until destroyed at the end of their lifecycle.") - submit = SubmitField('Save Changes') + provider = StringField("Platform", render_kw={"disabled": ""}) + description = StringField("Description", validators=[InputRequired()]) + horizon = FormField(GandiHorizonForm, "OpenStack Horizon API") + max_instances = IntegerField( + "Server Limit", default=10, validators=[InputRequired()] + ) + enabled = BooleanField( + "Enable this account", + default=True, + description="New resources will not be deployed to disabled accounts, however existing " + "resources will persist until destroyed at the end of their lifecycle.", + ) + submit = SubmitField("Save Changes") -CloudAccountForm = Union[AWSAccountForm, HcloudAccountForm, GandiAccountForm, OvhAccountForm] +CloudAccountForm = Union[ + AWSAccountForm, HcloudAccountForm, GandiAccountForm, OvhAccountForm +] provider_forms: Dict[str, Type[CloudAccountForm]] = { CloudProvider.AWS.name: AWSAccountForm, @@ -118,7 +153,9 @@ provider_forms: Dict[str, Type[CloudAccountForm]] = { } -def cloud_account_save(account: Optional[CloudAccount], provider: CloudProvider, form: CloudAccountForm) -> None: +def cloud_account_save( + account: Optional[CloudAccount], provider: CloudProvider, form: CloudAccountForm +) -> None: if not account: account = CloudAccount() account.provider = provider @@ -162,7 +199,9 @@ def cloud_account_save(account: Optional[CloudAccount], provider: CloudProvider, "ovh_openstack_password": form.horizon.data["ovh_openstack_password"], "ovh_openstack_tenant_id": form.horizon.data["ovh_openstack_tenant_id"], "ovh_cloud_application_key": form.ovh_api.data["ovh_cloud_application_key"], - "ovh_cloud_application_secret": form.ovh_api.data["ovh_cloud_application_secret"], + "ovh_cloud_application_secret": form.ovh_api.data[ + "ovh_cloud_application_secret" + ], "ovh_cloud_consumer_key": form.ovh_api.data["ovh_cloud_consumer_key"], } account.max_distributions = 0 @@ -182,53 +221,82 @@ def cloud_account_populate(form: CloudAccountForm, account: CloudAccount) -> Non form.aws_region.data = account.credentials["aws_region"] form.max_distributions.data = account.max_distributions form.max_instances.data = account.max_instances - elif account.provider == CloudProvider.HCLOUD and isinstance(form, HcloudAccountForm): + elif account.provider == CloudProvider.HCLOUD and isinstance( + form, HcloudAccountForm + ): form.hcloud_token.data = account.credentials["hcloud_token"] form.max_instances.data = account.max_instances elif account.provider == CloudProvider.GANDI and isinstance(form, GandiAccountForm): - form.horizon.form.gandi_openstack_user.data = account.credentials["gandi_openstack_user"] - form.horizon.form.gandi_openstack_password.data = account.credentials["gandi_openstack_password"] - form.horizon.form.gandi_openstack_tenant_id.data = account.credentials["gandi_openstack_tenant_id"] + form.horizon.form.gandi_openstack_user.data = account.credentials[ + "gandi_openstack_user" + ] + form.horizon.form.gandi_openstack_password.data = account.credentials[ + "gandi_openstack_password" + ] + form.horizon.form.gandi_openstack_tenant_id.data = account.credentials[ + "gandi_openstack_tenant_id" + ] form.max_instances.data = account.max_instances - elif account.provider == CloudProvider.GITLAB and isinstance(form, GitlabAccountForm): + elif account.provider == CloudProvider.GITLAB and isinstance( + form, GitlabAccountForm + ): form.gitlab_token.data = account.credentials["gitlab_token"] elif account.provider == CloudProvider.OVH and isinstance(form, OvhAccountForm): - form.horizon.form.ovh_openstack_user.data = account.credentials["ovh_openstack_user"] - form.horizon.form.ovh_openstack_password.data = account.credentials["ovh_openstack_password"] - form.horizon.form.ovh_openstack_tenant_id.data = account.credentials["ovh_openstack_tenant_id"] - form.ovh_api.form.ovh_cloud_application_key.data = account.credentials["ovh_cloud_application_key"] - form.ovh_api.form.ovh_cloud_application_secret.data = account.credentials["ovh_cloud_application_secret"] - form.ovh_api.form.ovh_cloud_consumer_key.data = account.credentials["ovh_cloud_consumer_key"] + form.horizon.form.ovh_openstack_user.data = account.credentials[ + "ovh_openstack_user" + ] + form.horizon.form.ovh_openstack_password.data = account.credentials[ + "ovh_openstack_password" + ] + form.horizon.form.ovh_openstack_tenant_id.data = account.credentials[ + "ovh_openstack_tenant_id" + ] + form.ovh_api.form.ovh_cloud_application_key.data = account.credentials[ + "ovh_cloud_application_key" + ] + form.ovh_api.form.ovh_cloud_application_secret.data = account.credentials[ + "ovh_cloud_application_secret" + ] + form.ovh_api.form.ovh_cloud_consumer_key.data = account.credentials[ + "ovh_cloud_consumer_key" + ] form.max_instances.data = account.max_instances else: - raise RuntimeError(f"Unknown provider {account.provider} or form data {type(form)} did not match provider.") + raise RuntimeError( + f"Unknown provider {account.provider} or form data {type(form)} did not match provider." + ) @bp.route("/list") def cloud_account_list() -> ResponseReturnValue: - accounts: List[CloudAccount] = CloudAccount.query.filter(CloudAccount.destroyed.is_(None)).all() - return render_template("list.html.j2", - title="Cloud Accounts", - item="cloud account", - items=accounts, - new_link=url_for("portal.cloud.cloud_account_new"), - **_SECTION_TEMPLATE_VARS) + accounts: List[CloudAccount] = CloudAccount.query.filter( + CloudAccount.destroyed.is_(None) + ).all() + return render_template( + "list.html.j2", + title="Cloud Accounts", + item="cloud account", + items=accounts, + new_link=url_for("portal.cloud.cloud_account_new"), + **_SECTION_TEMPLATE_VARS, + ) -@bp.route("/new", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) def cloud_account_new() -> ResponseReturnValue: form = NewCloudAccountForm() - form.provider.choices = sorted([ - (provider.name, provider.description) for provider in CloudProvider - ], key=lambda p: p[1].lower()) + form.provider.choices = sorted( + [(provider.name, provider.description) for provider in CloudProvider], + key=lambda p: p[1].lower(), + ) if form.validate_on_submit(): - return redirect(url_for("portal.cloud.cloud_account_new_for", provider=form.provider.data)) - return render_template("new.html.j2", - form=form, - **_SECTION_TEMPLATE_VARS) + return redirect( + url_for("portal.cloud.cloud_account_new_for", provider=form.provider.data) + ) + return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS) -@bp.route("/new/", methods=['GET', 'POST']) +@bp.route("/new/", methods=["GET", "POST"]) def cloud_account_new_for(provider: str) -> ResponseReturnValue: form = provider_forms[provider]() form.provider.data = CloudProvider[provider].description @@ -236,12 +304,10 @@ def cloud_account_new_for(provider: str) -> ResponseReturnValue: cloud_account_save(None, CloudProvider[provider], form) db.session.commit() return redirect(url_for("portal.cloud.cloud_account_list")) - return render_template("new.html.j2", - form=form, - **_SECTION_TEMPLATE_VARS) + return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS) -@bp.route("/edit/", methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def cloud_account_edit(account_id: int) -> ResponseReturnValue: account = CloudAccount.query.filter( CloudAccount.id == account_id, @@ -256,6 +322,4 @@ def cloud_account_edit(account_id: int) -> ResponseReturnValue: db.session.commit() return redirect(url_for("portal.cloud.cloud_account_list")) cloud_account_populate(form, account) - return render_template("new.html.j2", - form=form, - **_SECTION_TEMPLATE_VARS) + return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS) diff --git a/app/portal/country.py b/app/portal/country.py index c824208..1720819 100644 --- a/app/portal/country.py +++ b/app/portal/country.py @@ -13,7 +13,7 @@ bp = Blueprint("country", __name__) _SECTION_TEMPLATE_VARS = { "section": "country", - "help_url": "https://bypass.censorship.guide/user/countries.html" + "help_url": "https://bypass.censorship.guide/user/countries.html", } @@ -22,42 +22,51 @@ def filter_country_flag(country_code: str) -> str: country_code = country_code.upper() # Calculate the regional indicator symbol for each letter in the country code - base = ord('\U0001F1E6') - ord('A') - flag = ''.join([chr(ord(char) + base) for char in country_code]) + base = ord("\U0001F1E6") - ord("A") + flag = "".join([chr(ord(char) + base) for char in country_code]) return flag -@bp.route('/list') +@bp.route("/list") def country_list() -> ResponseReturnValue: countries = Country.query.filter(Country.destroyed.is_(None)).all() print(len(countries)) - return render_template("list.html.j2", - title="Countries", - item="country", - new_link=None, - items=sorted(countries, key=lambda x: x.country_code), - **_SECTION_TEMPLATE_VARS - ) + return render_template( + "list.html.j2", + title="Countries", + item="country", + new_link=None, + items=sorted(countries, key=lambda x: x.country_code), + **_SECTION_TEMPLATE_VARS + ) class EditCountryForm(FlaskForm): # type: ignore[misc] risk_level_override = BooleanField("Force Risk Level Override?") - risk_level_override_number = IntegerField("Forced Risk Level", description="Number from 0 to 20", default=0) - submit = SubmitField('Save Changes') + risk_level_override_number = IntegerField( + "Forced Risk Level", description="Number from 0 to 20", default=0 + ) + submit = SubmitField("Save Changes") -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def country_edit(country_id: int) -> ResponseReturnValue: country = Country.query.filter(Country.id == country_id).first() if country is None: - return Response(render_template("error.html.j2", - section="country", - header="404 Country Not Found", - message="The requested country could not be found."), - status=404) - form = EditCountryForm(risk_level_override=country.risk_level_override is not None, - risk_level_override_number=country.risk_level_override) + return Response( + render_template( + "error.html.j2", + section="country", + header="404 Country Not Found", + message="The requested country could not be found.", + ), + status=404, + ) + form = EditCountryForm( + risk_level_override=country.risk_level_override is not None, + risk_level_override_number=country.risk_level_override, + ) if form.validate_on_submit(): if form.risk_level_override.data: country.risk_level_override = form.risk_level_override_number.data @@ -69,6 +78,6 @@ def country_edit(country_id: int) -> ResponseReturnValue: flash("Saved changes to country.", "success") except sqlalchemy.exc.SQLAlchemyError: flash("An error occurred saving the changes to the country.", "danger") - return render_template("country.html.j2", - section="country", - country=country, form=form) + return render_template( + "country.html.j2", section="country", country=country, form=form + ) diff --git a/app/portal/eotk.py b/app/portal/eotk.py index ad40f5e..851030d 100644 --- a/app/portal/eotk.py +++ b/app/portal/eotk.py @@ -10,23 +10,32 @@ bp = Blueprint("eotk", __name__) _SECTION_TEMPLATE_VARS = { "section": "eotk", - "help_url": "https://bypass.censorship.guide/user/eotk.html" + "help_url": "https://bypass.censorship.guide/user/eotk.html", } @bp.route("/list") def eotk_list() -> ResponseReturnValue: - instances = Eotk.query.filter(Eotk.destroyed.is_(None)).order_by(desc(Eotk.added)).all() - return render_template("list.html.j2", - title="EOTK Instances", - item="eotk", - items=instances, - **_SECTION_TEMPLATE_VARS) + instances = ( + Eotk.query.filter(Eotk.destroyed.is_(None)).order_by(desc(Eotk.added)).all() + ) + return render_template( + "list.html.j2", + title="EOTK Instances", + item="eotk", + items=instances, + **_SECTION_TEMPLATE_VARS + ) @bp.route("/conf/") def eotk_conf(group_id: int) -> ResponseReturnValue: group = Group.query.filter(Group.id == group_id).first() - return Response(render_template("sites.conf.j2", - bypass_token=current_app.config["BYPASS_TOKEN"], - group=group), content_type="text/plain") + return Response( + render_template( + "sites.conf.j2", + bypass_token=current_app.config["BYPASS_TOKEN"], + group=group, + ), + content_type="text/plain", + ) diff --git a/app/portal/forms.py b/app/portal/forms.py index 2d52d33..e469230 100644 --- a/app/portal/forms.py +++ b/app/portal/forms.py @@ -3,6 +3,6 @@ from wtforms import SelectField, StringField, SubmitField class EditMirrorForm(FlaskForm): # type: ignore - origin = SelectField('Origin') - url = StringField('URL') - submit = SubmitField('Save Changes') + origin = SelectField("Origin") + url = StringField("URL") + submit = SubmitField("Save Changes") diff --git a/app/portal/group.py b/app/portal/group.py index a1aceed..5d9dd2e 100644 --- a/app/portal/group.py +++ b/app/portal/group.py @@ -1,8 +1,7 @@ from datetime import datetime, timezone import sqlalchemy -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm from wtforms import BooleanField, StringField, SubmitField @@ -18,27 +17,29 @@ class NewGroupForm(FlaskForm): # type: ignore group_name = StringField("Short Name", validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()]) eotk = BooleanField("Deploy EOTK instances?") - submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) + submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"}) class EditGroupForm(FlaskForm): # type: ignore - description = StringField('Description', validators=[DataRequired()]) + description = StringField("Description", validators=[DataRequired()]) eotk = BooleanField("Deploy EOTK instances?") - submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) + submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"}) @bp.route("/list") def group_list() -> ResponseReturnValue: groups = Group.query.order_by(Group.group_name).all() - return render_template("list.html.j2", - section="group", - title="Groups", - item="group", - items=groups, - new_link=url_for("portal.group.group_new")) + return render_template( + "list.html.j2", + section="group", + title="Groups", + item="group", + items=groups, + new_link=url_for("portal.group.group_new"), + ) -@bp.route("/new", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) def group_new() -> ResponseReturnValue: form = NewGroupForm() if form.validate_on_submit(): @@ -59,17 +60,20 @@ def group_new() -> ResponseReturnValue: return render_template("new.html.j2", section="group", form=form) -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def group_edit(group_id: int) -> ResponseReturnValue: group = Group.query.filter(Group.id == group_id).first() if group is None: - return Response(render_template("error.html.j2", - section="group", - header="404 Group Not Found", - message="The requested group could not be found."), - status=404) - form = EditGroupForm(description=group.description, - eotk=group.eotk) + return Response( + render_template( + "error.html.j2", + section="group", + header="404 Group Not Found", + message="The requested group could not be found.", + ), + status=404, + ) + form = EditGroupForm(description=group.description, eotk=group.eotk) if form.validate_on_submit(): group.description = form.description.data group.eotk = form.eotk.data @@ -79,6 +83,4 @@ def group_edit(group_id: int) -> ResponseReturnValue: flash("Saved changes to group.", "success") except sqlalchemy.exc.SQLAlchemyError: flash("An error occurred saving the changes to the group.", "danger") - return render_template("group.html.j2", - section="group", - group=group, form=form) + return render_template("group.html.j2", section="group", group=group, form=form) diff --git a/app/portal/list.py b/app/portal/list.py index a147a69..26d2863 100644 --- a/app/portal/list.py +++ b/app/portal/list.py @@ -2,8 +2,7 @@ import json from datetime import datetime, timezone from typing import Any, Optional -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm from sqlalchemy import exc @@ -23,7 +22,7 @@ bp = Blueprint("list", __name__) _SECTION_TEMPLATE_VARS = { "section": "list", - "help_url": "https://bypass.censorship.guide/user/lists.html" + "help_url": "https://bypass.censorship.guide/user/lists.html", } @@ -42,37 +41,44 @@ def list_encoding_name(key: str) -> str: return MirrorList.encodings_supported.get(key, "Unknown") -@bp.route('/list') +@bp.route("/list") def list_list() -> ResponseReturnValue: lists = MirrorList.query.filter(MirrorList.destroyed.is_(None)).all() - return render_template("list.html.j2", - title="Distribution Lists", - item="distribution list", - new_link=url_for("portal.list.list_new"), - items=lists, - **_SECTION_TEMPLATE_VARS - ) + return render_template( + "list.html.j2", + title="Distribution Lists", + item="distribution list", + new_link=url_for("portal.list.list_new"), + items=lists, + **_SECTION_TEMPLATE_VARS + ) -@bp.route('/preview//') +@bp.route("/preview//") def list_preview(format_: str, pool_id: int) -> ResponseReturnValue: pool = Pool.query.filter(Pool.id == pool_id).first() if not pool: return response_404(message="Pool not found") if format_ == "bca": - return Response(json.dumps(mirror_mapping(pool)), content_type="application/json") + return Response( + json.dumps(mirror_mapping(pool)), content_type="application/json" + ) if format_ == "bc2": return Response(json.dumps(mirror_sites(pool)), content_type="application/json") if format_ == "bridgelines": return Response(json.dumps(bridgelines(pool)), content_type="application/json") if format_ == "rdr": - return Response(json.dumps(redirector_data(pool)), content_type="application/json") + return Response( + json.dumps(redirector_data(pool)), content_type="application/json" + ) return response_404(message="Format not found") -@bp.route("/destroy/", methods=['GET', 'POST']) +@bp.route("/destroy/", methods=["GET", "POST"]) def list_destroy(list_id: int) -> ResponseReturnValue: - list_ = MirrorList.query.filter(MirrorList.id == list_id, MirrorList.destroyed.is_(None)).first() + list_ = MirrorList.query.filter( + MirrorList.id == list_id, MirrorList.destroyed.is_(None) + ).first() if list_ is None: return response_404("The requested bridge configuration could not be found.") return view_lifecycle( @@ -82,12 +88,12 @@ def list_destroy(list_id: int) -> ResponseReturnValue: success_message="This list will no longer be updated and may be deleted depending on the provider.", section="list", resource=list_, - action="destroy" + action="destroy", ) -@bp.route("/new", methods=['GET', 'POST']) -@bp.route("/new/", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) +@bp.route("/new/", methods=["GET", "POST"]) def list_new(group_id: Optional[int] = None) -> ResponseReturnValue: form = NewMirrorListForm() form.provider.choices = list(MirrorList.providers_supported.items()) @@ -116,43 +122,53 @@ def list_new(group_id: Optional[int] = None) -> ResponseReturnValue: return redirect(url_for("portal.list.list_list")) if group_id: form.group.data = group_id - return render_template("new.html.j2", - form=form, - **_SECTION_TEMPLATE_VARS) + return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS) class NewMirrorListForm(FlaskForm): # type: ignore - pool = SelectField('Resource Pool', validators=[DataRequired()]) - provider = SelectField('Provider', validators=[DataRequired()]) - format = SelectField('Distribution Method', validators=[DataRequired()]) - encoding = SelectField('Encoding', validators=[DataRequired()]) - description = StringField('Description', validators=[DataRequired()]) - container = StringField('Container', validators=[DataRequired()], - description="GitHub Project, GitLab Project or AWS S3 bucket name.") - branch = StringField('Git Branch/AWS Region', validators=[DataRequired()], - description="For GitHub/GitLab, set this to the desired branch name, e.g. main. For AWS S3, " - "set this field to the desired region, e.g. us-east-1.") - role = StringField('Role ARN', - description="(Optional) ARN for IAM role to assume for interaction with the S3 bucket.") - filename = StringField('Filename', validators=[DataRequired()]) - submit = SubmitField('Save Changes') + pool = SelectField("Resource Pool", validators=[DataRequired()]) + provider = SelectField("Provider", validators=[DataRequired()]) + format = SelectField("Distribution Method", validators=[DataRequired()]) + encoding = SelectField("Encoding", validators=[DataRequired()]) + description = StringField("Description", validators=[DataRequired()]) + container = StringField( + "Container", + validators=[DataRequired()], + description="GitHub Project, GitLab Project or AWS S3 bucket name.", + ) + branch = StringField( + "Git Branch/AWS Region", + validators=[DataRequired()], + description="For GitHub/GitLab, set this to the desired branch name, e.g. main. For AWS S3, " + "set this field to the desired region, e.g. us-east-1.", + ) + role = StringField( + "Role ARN", + description="(Optional) ARN for IAM role to assume for interaction with the S3 bucket.", + ) + filename = StringField("Filename", validators=[DataRequired()]) + submit = SubmitField("Save Changes") def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.pool.choices = [ - (pool.id, pool.pool_name) for pool in Pool.query.all() - ] + self.pool.choices = [(pool.id, pool.pool_name) for pool in Pool.query.all()] -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def list_edit(list_id: int) -> ResponseReturnValue: - list_: Optional[MirrorList] = MirrorList.query.filter(MirrorList.id == list_id).first() + list_: Optional[MirrorList] = MirrorList.query.filter( + MirrorList.id == list_id + ).first() if list_ is None: - return Response(render_template("error.html.j2", - header="404 Distribution List Not Found", - message="The requested distribution list could not be found.", - **_SECTION_TEMPLATE_VARS), - status=404) + return Response( + render_template( + "error.html.j2", + header="404 Distribution List Not Found", + message="The requested distribution list could not be found.", + **_SECTION_TEMPLATE_VARS + ), + status=404, + ) form = NewMirrorListForm( pool=list_.pool_id, provider=list_.provider, @@ -162,7 +178,7 @@ def list_edit(list_id: int) -> ResponseReturnValue: container=list_.container, branch=list_.branch, role=list_.role, - filename=list_.filename + filename=list_.filename, ) form.provider.choices = list(MirrorList.providers_supported.items()) form.format.choices = list(MirrorList.formats_supported.items()) @@ -182,7 +198,10 @@ def list_edit(list_id: int) -> ResponseReturnValue: db.session.commit() flash("Saved changes to group.", "success") except exc.SQLAlchemyError: - flash("An error occurred saving the changes to the distribution list.", "danger") - return render_template("distlist.html.j2", - list=list_, form=form, - **_SECTION_TEMPLATE_VARS) + flash( + "An error occurred saving the changes to the distribution list.", + "danger", + ) + return render_template( + "distlist.html.j2", list=list_, form=form, **_SECTION_TEMPLATE_VARS + ) diff --git a/app/portal/onion.py b/app/portal/onion.py index fa8f759..0e779ad 100644 --- a/app/portal/onion.py +++ b/app/portal/onion.py @@ -9,13 +9,13 @@ from app.portal.util import response_404, view_lifecycle bp = Blueprint("onion", __name__) -@bp.route("/new", methods=['GET', 'POST']) -@bp.route("/new/", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) +@bp.route("/new/", methods=["GET", "POST"]) def onion_new(group_id: Optional[int] = None) -> ResponseReturnValue: return redirect("/ui/web/onions/new") -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def onion_edit(onion_id: int) -> ResponseReturnValue: return redirect("/ui/web/onions/edit/{}".format(onion_id)) @@ -25,9 +25,11 @@ def onion_list() -> ResponseReturnValue: return redirect("/ui/web/onions") -@bp.route("/destroy/", methods=['GET', 'POST']) +@bp.route("/destroy/", methods=["GET", "POST"]) def onion_destroy(onion_id: str) -> ResponseReturnValue: - onion = Onion.query.filter(Onion.id == int(onion_id), Onion.destroyed.is_(None)).first() + onion = Onion.query.filter( + Onion.id == int(onion_id), Onion.destroyed.is_(None) + ).first() if onion is None: return response_404("The requested onion service could not be found.") return view_lifecycle( @@ -37,5 +39,5 @@ def onion_destroy(onion_id: str) -> ResponseReturnValue: success_view="portal.onion.onion_list", section="onion", resource=onion, - action="destroy" + action="destroy", ) diff --git a/app/portal/origin.py b/app/portal/origin.py index 2de0f50..4f52e04 100644 --- a/app/portal/origin.py +++ b/app/portal/origin.py @@ -4,13 +4,11 @@ from typing import List, Optional import requests import sqlalchemy -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm from sqlalchemy import exc -from wtforms import (BooleanField, IntegerField, SelectField, StringField, - SubmitField) +from wtforms import BooleanField, IntegerField, SelectField, StringField, SubmitField from wtforms.validators import DataRequired from app.extensions import db @@ -22,29 +20,31 @@ bp = Blueprint("origin", __name__) class NewOriginForm(FlaskForm): # type: ignore - domain_name = StringField('Domain Name', validators=[DataRequired()]) - description = StringField('Description', validators=[DataRequired()]) - group = SelectField('Group', validators=[DataRequired()]) + domain_name = StringField("Domain Name", validators=[DataRequired()]) + description = StringField("Description", validators=[DataRequired()]) + group = SelectField("Group", validators=[DataRequired()]) auto_rotate = BooleanField("Enable auto-rotation?", default=True) smart_proxy = BooleanField("Requires smart proxy?", default=False) asset_domain = BooleanField("Used to host assets for other domains?", default=False) - submit = SubmitField('Save Changes') + submit = SubmitField("Save Changes") class EditOriginForm(FlaskForm): # type: ignore[misc] - description = StringField('Description', validators=[DataRequired()]) - group = SelectField('Group', validators=[DataRequired()]) + description = StringField("Description", validators=[DataRequired()]) + group = SelectField("Group", validators=[DataRequired()]) auto_rotate = BooleanField("Enable auto-rotation?") smart_proxy = BooleanField("Requires smart proxy?") asset_domain = BooleanField("Used to host assets for other domains?", default=False) risk_level_override = BooleanField("Force Risk Level Override?") - risk_level_override_number = IntegerField("Forced Risk Level", description="Number from 0 to 20", default=0) - submit = SubmitField('Save Changes') + risk_level_override_number = IntegerField( + "Forced Risk Level", description="Number from 0 to 20", default=0 + ) + submit = SubmitField("Save Changes") class CountrySelectForm(FlaskForm): # type: ignore[misc] country = SelectField("Country", validators=[DataRequired()]) - submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) + submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"}) def final_domain_name(domain_name: str) -> str: @@ -53,8 +53,8 @@ def final_domain_name(domain_name: str) -> str: return urllib.parse.urlparse(r.url).netloc -@bp.route("/new", methods=['GET', 'POST']) -@bp.route("/new/", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) +@bp.route("/new/", methods=["GET", "POST"]) def origin_new(group_id: Optional[int] = None) -> ResponseReturnValue: form = NewOriginForm() form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] @@ -81,22 +81,28 @@ def origin_new(group_id: Optional[int] = None) -> ResponseReturnValue: return render_template("new.html.j2", section="origin", form=form) -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def origin_edit(origin_id: int) -> ResponseReturnValue: origin: Optional[Origin] = Origin.query.filter(Origin.id == origin_id).first() if origin is None: - return Response(render_template("error.html.j2", - section="origin", - header="404 Origin Not Found", - message="The requested origin could not be found."), - status=404) - form = EditOriginForm(group=origin.group_id, - description=origin.description, - auto_rotate=origin.auto_rotation, - smart_proxy=origin.smart, - asset_domain=origin.assets, - risk_level_override=origin.risk_level_override is not None, - risk_level_override_number=origin.risk_level_override) + return Response( + render_template( + "error.html.j2", + section="origin", + header="404 Origin Not Found", + message="The requested origin could not be found.", + ), + status=404, + ) + form = EditOriginForm( + group=origin.group_id, + description=origin.description, + auto_rotate=origin.auto_rotation, + smart_proxy=origin.smart, + asset_domain=origin.assets, + risk_level_override=origin.risk_level_override is not None, + risk_level_override_number=origin.risk_level_override, + ) form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] if form.validate_on_submit(): origin.group_id = form.group.data @@ -114,41 +120,47 @@ def origin_edit(origin_id: int) -> ResponseReturnValue: flash(f"Saved changes for origin {origin.domain_name}.", "success") except exc.SQLAlchemyError: flash("An error occurred saving the changes to the origin.", "danger") - return render_template("origin.html.j2", - section="origin", - origin=origin, form=form) + return render_template("origin.html.j2", section="origin", origin=origin, form=form) @bp.route("/list") def origin_list() -> ResponseReturnValue: origins: List[Origin] = Origin.query.order_by(Origin.domain_name).all() - return render_template("list.html.j2", - section="origin", - title="Web Origins", - item="origin", - new_link=url_for("portal.origin.origin_new"), - items=origins, - extra_buttons=[{ - "link": url_for("portal.origin.origin_onion"), - "text": "Onion services", - "style": "onion" - }]) + return render_template( + "list.html.j2", + section="origin", + title="Web Origins", + item="origin", + new_link=url_for("portal.origin.origin_new"), + items=origins, + extra_buttons=[ + { + "link": url_for("portal.origin.origin_onion"), + "text": "Onion services", + "style": "onion", + } + ], + ) @bp.route("/onion") def origin_onion() -> ResponseReturnValue: origins = Origin.query.order_by(Origin.domain_name).all() - return render_template("list.html.j2", - section="origin", - title="Onion Sites", - item="onion service", - new_link=url_for("portal.onion.onion_new"), - items=origins) + return render_template( + "list.html.j2", + section="origin", + title="Onion Sites", + item="onion service", + new_link=url_for("portal.onion.onion_new"), + items=origins, + ) -@bp.route("/destroy/", methods=['GET', 'POST']) +@bp.route("/destroy/", methods=["GET", "POST"]) def origin_destroy(origin_id: int) -> ResponseReturnValue: - origin = Origin.query.filter(Origin.id == origin_id, Origin.destroyed.is_(None)).first() + origin = Origin.query.filter( + Origin.id == origin_id, Origin.destroyed.is_(None) + ).first() if origin is None: return response_404("The requested origin could not be found.") return view_lifecycle( @@ -158,32 +170,44 @@ def origin_destroy(origin_id: int) -> ResponseReturnValue: success_view="portal.origin.origin_list", section="origin", resource=origin, - action="destroy" + action="destroy", ) -@bp.route('/country_remove//', methods=['GET', 'POST']) +@bp.route("/country_remove//", methods=["GET", "POST"]) def origin_country_remove(origin_id: int, country_id: int) -> ResponseReturnValue: origin = Origin.query.filter(Origin.id == origin_id).first() if origin is None: - return Response(render_template("error.html.j2", - section="origin", - header="404 Pool Not Found", - message="The requested origin could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="origin", + header="404 Pool Not Found", + message="The requested origin could not be found.", + ), + status=404, + ) country = Country.query.filter(Country.id == country_id).first() if country is None: - return Response(render_template("error.html.j2", - section="origin", - header="404 Country Not Found", - message="The requested country could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="origin", + header="404 Country Not Found", + message="The requested country could not be found.", + ), + status=404, + ) if country not in origin.countries: - return Response(render_template("error.html.j2", - section="origin", - header="404 Country Not In Pool", - message="The requested country could not be found in the specified origin."), - status=404) + return Response( + render_template( + "error.html.j2", + section="origin", + header="404 Country Not In Pool", + message="The requested country could not be found in the specified origin.", + ), + status=404, + ) form = LifecycleForm() if form.validate_on_submit(): origin.countries.remove(country) @@ -193,32 +217,45 @@ def origin_country_remove(origin_id: int, country_id: int) -> ResponseReturnValu return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id)) except sqlalchemy.exc.SQLAlchemyError: flash("An error occurred saving the changes to the origin.", "danger") - return render_template("lifecycle.html.j2", - header=f"Remove {country.description} from the {origin.domain_name} origin?", - message="Stop monitoring in this country.", - section="origin", - origin=origin, form=form) + return render_template( + "lifecycle.html.j2", + header=f"Remove {country.description} from the {origin.domain_name} origin?", + message="Stop monitoring in this country.", + section="origin", + origin=origin, + form=form, + ) -@bp.route('/country_add/', methods=['GET', 'POST']) +@bp.route("/country_add/", methods=["GET", "POST"]) def origin_country_add(origin_id: int) -> ResponseReturnValue: origin = Origin.query.filter(Origin.id == origin_id).first() if origin is None: - return Response(render_template("error.html.j2", - section="origin", - header="404 Origin Not Found", - message="The requested origin could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="origin", + header="404 Origin Not Found", + message="The requested origin could not be found.", + ), + status=404, + ) form = CountrySelectForm() - form.country.choices = [(x.id, f"{x.country_code} - {x.description}") for x in Country.query.all()] + form.country.choices = [ + (x.id, f"{x.country_code} - {x.description}") for x in Country.query.all() + ] if form.validate_on_submit(): country = Country.query.filter(Country.id == form.country.data).first() if country is None: - return Response(render_template("error.html.j2", - section="origin", - header="404 Country Not Found", - message="The requested country could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="origin", + header="404 Country Not Found", + message="The requested country could not be found.", + ), + status=404, + ) origin.countries.append(country) try: db.session.commit() @@ -226,8 +263,11 @@ def origin_country_add(origin_id: int) -> ResponseReturnValue: return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id)) except sqlalchemy.exc.SQLAlchemyError: flash("An error occurred saving the changes to the origin.", "danger") - return render_template("lifecycle.html.j2", - header=f"Add a country to {origin.domain_name}", - message="Enable monitoring from this country:", - section="origin", - origin=origin, form=form) + return render_template( + "lifecycle.html.j2", + header=f"Add a country to {origin.domain_name}", + message="Enable monitoring from this country:", + section="origin", + origin=origin, + form=form, + ) diff --git a/app/portal/pool.py b/app/portal/pool.py index 665ff11..bc39811 100644 --- a/app/portal/pool.py +++ b/app/portal/pool.py @@ -3,8 +3,7 @@ import secrets from datetime import datetime, timezone import sqlalchemy -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm from wtforms import SelectField, StringField, SubmitField @@ -21,41 +20,50 @@ class NewPoolForm(FlaskForm): # type: ignore[misc] group_name = StringField("Short Name", validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()]) redirector_domain = StringField("Redirector Domain") - submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) + submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"}) class EditPoolForm(FlaskForm): # type: ignore[misc] description = StringField("Description", validators=[DataRequired()]) redirector_domain = StringField("Redirector Domain") - api_key = StringField("API Key", description=("Any change to this field (e.g. clearing it) will result in the " - "API key being regenerated.")) - submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) + api_key = StringField( + "API Key", + description=( + "Any change to this field (e.g. clearing it) will result in the " + "API key being regenerated." + ), + ) + submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"}) class GroupSelectForm(FlaskForm): # type: ignore[misc] group = SelectField("Group", validators=[DataRequired()]) - submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) + submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"}) @bp.route("/list") def pool_list() -> ResponseReturnValue: pools = Pool.query.order_by(Pool.pool_name).all() - return render_template("list.html.j2", - section="pool", - title="Resource Pools", - item="pool", - items=pools, - new_link=url_for("portal.pool.pool_new")) + return render_template( + "list.html.j2", + section="pool", + title="Resource Pools", + item="pool", + items=pools, + new_link=url_for("portal.pool.pool_new"), + ) -@bp.route("/new", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) def pool_new() -> ResponseReturnValue: form = NewPoolForm() if form.validate_on_submit(): pool = Pool() pool.pool_name = form.group_name.data pool.description = form.description.data - pool.redirector_domain = form.redirector_domain.data if form.redirector_domain.data != "" else None + pool.redirector_domain = ( + form.redirector_domain.data if form.redirector_domain.data != "" else None + ) pool.api_key = secrets.token_urlsafe(nbytes=32) pool.added = datetime.now(timezone.utc) pool.updated = datetime.now(timezone.utc) @@ -71,21 +79,29 @@ def pool_new() -> ResponseReturnValue: return render_template("new.html.j2", section="pool", form=form) -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def pool_edit(pool_id: int) -> ResponseReturnValue: pool = Pool.query.filter(Pool.id == pool_id).first() if pool is None: - return Response(render_template("error.html.j2", - section="pool", - header="404 Pool Not Found", - message="The requested pool could not be found."), - status=404) - form = EditPoolForm(description=pool.description, - api_key=pool.api_key, - redirector_domain=pool.redirector_domain) + return Response( + render_template( + "error.html.j2", + section="pool", + header="404 Pool Not Found", + message="The requested pool could not be found.", + ), + status=404, + ) + form = EditPoolForm( + description=pool.description, + api_key=pool.api_key, + redirector_domain=pool.redirector_domain, + ) if form.validate_on_submit(): pool.description = form.description.data - pool.redirector_domain = form.redirector_domain.data if form.redirector_domain.data != "" else None + pool.redirector_domain = ( + form.redirector_domain.data if form.redirector_domain.data != "" else None + ) if form.api_key.data != pool.api_key: pool.api_key = secrets.token_urlsafe(nbytes=32) form.api_key.data = pool.api_key @@ -95,33 +111,43 @@ def pool_edit(pool_id: int) -> ResponseReturnValue: flash("Saved changes to pool.", "success") except sqlalchemy.exc.SQLAlchemyError: flash("An error occurred saving the changes to the pool.", "danger") - return render_template("pool.html.j2", - section="pool", - pool=pool, form=form) + return render_template("pool.html.j2", section="pool", pool=pool, form=form) -@bp.route('/group_remove//', methods=['GET', 'POST']) +@bp.route("/group_remove//", methods=["GET", "POST"]) def pool_group_remove(pool_id: int, group_id: int) -> ResponseReturnValue: pool = Pool.query.filter(Pool.id == pool_id).first() if pool is None: - return Response(render_template("error.html.j2", - section="pool", - header="404 Pool Not Found", - message="The requested pool could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="pool", + header="404 Pool Not Found", + message="The requested pool could not be found.", + ), + status=404, + ) group = Group.query.filter(Group.id == group_id).first() if group is None: - return Response(render_template("error.html.j2", - section="pool", - header="404 Group Not Found", - message="The requested group could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="pool", + header="404 Group Not Found", + message="The requested group could not be found.", + ), + status=404, + ) if group not in pool.groups: - return Response(render_template("error.html.j2", - section="pool", - header="404 Group Not In Pool", - message="The requested group could not be found in the specified pool."), - status=404) + return Response( + render_template( + "error.html.j2", + section="pool", + header="404 Group Not In Pool", + message="The requested group could not be found in the specified pool.", + ), + status=404, + ) form = LifecycleForm() if form.validate_on_submit(): pool.groups.remove(group) @@ -131,32 +157,43 @@ def pool_group_remove(pool_id: int, group_id: int) -> ResponseReturnValue: return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id)) except sqlalchemy.exc.SQLAlchemyError: flash("An error occurred saving the changes to the pool.", "danger") - return render_template("lifecycle.html.j2", - header=f"Remove {group.group_name} from the {pool.pool_name} pool?", - message="Resources deployed and available in the pool will be destroyed soon.", - section="pool", - pool=pool, form=form) + return render_template( + "lifecycle.html.j2", + header=f"Remove {group.group_name} from the {pool.pool_name} pool?", + message="Resources deployed and available in the pool will be destroyed soon.", + section="pool", + pool=pool, + form=form, + ) -@bp.route('/group_add/', methods=['GET', 'POST']) +@bp.route("/group_add/", methods=["GET", "POST"]) def pool_group_add(pool_id: int) -> ResponseReturnValue: pool = Pool.query.filter(Pool.id == pool_id).first() if pool is None: - return Response(render_template("error.html.j2", - section="pool", - header="404 Pool Not Found", - message="The requested pool could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="pool", + header="404 Pool Not Found", + message="The requested pool could not be found.", + ), + status=404, + ) form = GroupSelectForm() form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] if form.validate_on_submit(): group = Group.query.filter(Group.id == form.group.data).first() if group is None: - return Response(render_template("error.html.j2", - section="pool", - header="404 Group Not Found", - message="The requested group could not be found."), - status=404) + return Response( + render_template( + "error.html.j2", + section="pool", + header="404 Group Not Found", + message="The requested group could not be found.", + ), + status=404, + ) pool.groups.append(group) try: db.session.commit() @@ -164,8 +201,11 @@ def pool_group_add(pool_id: int) -> ResponseReturnValue: return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id)) except sqlalchemy.exc.SQLAlchemyError: flash("An error occurred saving the changes to the pool.", "danger") - return render_template("lifecycle.html.j2", - header=f"Add a group to {pool.pool_name}", - message="Resources will shortly be deployed and available for all origins in this group.", - section="pool", - pool=pool, form=form) + return render_template( + "lifecycle.html.j2", + header=f"Add a group to {pool.pool_name}", + message="Resources will shortly be deployed and available for all origins in this group.", + section="pool", + pool=pool, + form=form, + ) diff --git a/app/portal/proxy.py b/app/portal/proxy.py index 1efda70..7a28cc8 100644 --- a/app/portal/proxy.py +++ b/app/portal/proxy.py @@ -1,5 +1,4 @@ -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from sqlalchemy import desc @@ -12,51 +11,63 @@ bp = Blueprint("proxy", __name__) @bp.route("/list") def proxy_list() -> ResponseReturnValue: - proxies = Proxy.query.filter(Proxy.destroyed.is_(None)).order_by(desc(Proxy.added)).all() - return render_template("list.html.j2", - section="proxy", - title="Proxies", - item="proxy", - items=proxies) + proxies = ( + Proxy.query.filter(Proxy.destroyed.is_(None)).order_by(desc(Proxy.added)).all() + ) + return render_template( + "list.html.j2", section="proxy", title="Proxies", item="proxy", items=proxies + ) -@bp.route("/expire/", methods=['GET', 'POST']) +@bp.route("/expire/", methods=["GET", "POST"]) def proxy_expire(proxy_id: int) -> ResponseReturnValue: proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first() if proxy is None: - return Response(render_template("error.html.j2", - header="404 Proxy Not Found", - message="The requested proxy could not be found. It may have already been " - "destroyed.")) + return Response( + render_template( + "error.html.j2", + header="404 Proxy Not Found", + message="The requested proxy could not be found. It may have already been " + "destroyed.", + ) + ) form = LifecycleForm() if form.validate_on_submit(): proxy.destroy() db.session.commit() flash("Proxy will be shortly retired.", "success") return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id)) - return render_template("lifecycle.html.j2", - header=f"Expire proxy for {proxy.origin.domain_name} immediately?", - message=proxy.url, - section="proxy", - form=form) + return render_template( + "lifecycle.html.j2", + header=f"Expire proxy for {proxy.origin.domain_name} immediately?", + message=proxy.url, + section="proxy", + form=form, + ) -@bp.route("/block/", methods=['GET', 'POST']) +@bp.route("/block/", methods=["GET", "POST"]) def proxy_block(proxy_id: int) -> ResponseReturnValue: proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first() if proxy is None: - return Response(render_template("error.html.j2", - header="404 Proxy Not Found", - message="The requested proxy could not be found. It may have already been " - "destroyed.")) + return Response( + render_template( + "error.html.j2", + header="404 Proxy Not Found", + message="The requested proxy could not be found. It may have already been " + "destroyed.", + ) + ) form = LifecycleForm() if form.validate_on_submit(): proxy.deprecate(reason="manual") db.session.commit() flash("Proxy will be shortly replaced.", "success") return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id)) - return render_template("lifecycle.html.j2", - header=f"Mark proxy for {proxy.origin.domain_name} as blocked?", - message=proxy.url, - section="proxy", - form=form) + return render_template( + "lifecycle.html.j2", + header=f"Mark proxy for {proxy.origin.domain_name} as blocked?", + message=proxy.url, + section="proxy", + form=form, + ) diff --git a/app/portal/report.py b/app/portal/report.py index a8ec4bb..4d5a7ef 100644 --- a/app/portal/report.py +++ b/app/portal/report.py @@ -20,12 +20,12 @@ def generate_subqueries(): deprecations_24hr_subquery = ( db.session.query( DeprecationAlias.resource_id, - func.count(DeprecationAlias.resource_id).label('deprecations_24hr') + func.count(DeprecationAlias.resource_id).label("deprecations_24hr"), ) .filter( - DeprecationAlias.reason.like('block_%'), + DeprecationAlias.reason.like("block_%"), DeprecationAlias.deprecated_at >= now - timedelta(hours=24), - DeprecationAlias.resource_type == 'Proxy' + DeprecationAlias.resource_type == "Proxy", ) .group_by(DeprecationAlias.resource_id) .subquery() @@ -33,12 +33,12 @@ def generate_subqueries(): deprecations_72hr_subquery = ( db.session.query( DeprecationAlias.resource_id, - func.count(DeprecationAlias.resource_id).label('deprecations_72hr') + func.count(DeprecationAlias.resource_id).label("deprecations_72hr"), ) .filter( - DeprecationAlias.reason.like('block_%'), + DeprecationAlias.reason.like("block_%"), DeprecationAlias.deprecated_at >= now - timedelta(hours=72), - DeprecationAlias.resource_type == 'Proxy' + DeprecationAlias.resource_type == "Proxy", ) .group_by(DeprecationAlias.resource_id) .subquery() @@ -52,13 +52,23 @@ def countries_report(): return ( db.session.query( Country, - func.coalesce(func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0).label('total_deprecations_24hr'), - func.coalesce(func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0).label('total_deprecations_72hr') + func.coalesce( + func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0 + ).label("total_deprecations_24hr"), + func.coalesce( + func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0 + ).label("total_deprecations_72hr"), ) .join(Origin, Country.origins) .join(Proxy, Origin.proxies) - .outerjoin(deprecations_24hr_subquery, Proxy.id == deprecations_24hr_subquery.c.resource_id) - .outerjoin(deprecations_72hr_subquery, Proxy.id == deprecations_72hr_subquery.c.resource_id) + .outerjoin( + deprecations_24hr_subquery, + Proxy.id == deprecations_24hr_subquery.c.resource_id, + ) + .outerjoin( + deprecations_72hr_subquery, + Proxy.id == deprecations_72hr_subquery.c.resource_id, + ) .group_by(Country.id) .all() ) @@ -70,12 +80,22 @@ def origins_report(): return ( db.session.query( Origin, - func.coalesce(func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0).label('total_deprecations_24hr'), - func.coalesce(func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0).label('total_deprecations_72hr') + func.coalesce( + func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0 + ).label("total_deprecations_24hr"), + func.coalesce( + func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0 + ).label("total_deprecations_72hr"), ) .outerjoin(Proxy, Origin.proxies) - .outerjoin(deprecations_24hr_subquery, Proxy.id == deprecations_24hr_subquery.c.resource_id) - .outerjoin(deprecations_72hr_subquery, Proxy.id == deprecations_72hr_subquery.c.resource_id) + .outerjoin( + deprecations_24hr_subquery, + Proxy.id == deprecations_24hr_subquery.c.resource_id, + ) + .outerjoin( + deprecations_72hr_subquery, + Proxy.id == deprecations_72hr_subquery.c.resource_id, + ) .filter(Origin.destroyed.is_(None)) .group_by(Origin.id) .order_by(desc("total_deprecations_24hr")) @@ -83,26 +103,37 @@ def origins_report(): ) -@report.app_template_filter('country_name') +@report.app_template_filter("country_name") def country_description_filter(country_code): country = Country.query.filter_by(country_code=country_code).first() return country.description if country else None -@report.route("/blocks", methods=['GET']) +@report.route("/blocks", methods=["GET"]) def report_blocks() -> ResponseReturnValue: - blocked_today = db.session.query( # type: ignore[no-untyped-call] - Origin.domain_name, - Origin.description, - Proxy.added, - Proxy.deprecated, - Proxy.deprecation_reason - ).join(Origin, Origin.id == Proxy.origin_id - ).filter(and_(Proxy.deprecated > datetime.now(tz=timezone.utc) - timedelta(days=1), - Proxy.deprecation_reason.like('block_%'))).all() + blocked_today = ( + db.session.query( # type: ignore[no-untyped-call] + Origin.domain_name, + Origin.description, + Proxy.added, + Proxy.deprecated, + Proxy.deprecation_reason, + ) + .join(Origin, Origin.id == Proxy.origin_id) + .filter( + and_( + Proxy.deprecated > datetime.now(tz=timezone.utc) - timedelta(days=1), + Proxy.deprecation_reason.like("block_%"), + ) + ) + .all() + ) - return render_template("report_blocks.html.j2", - blocked_today=blocked_today, - origins=sorted(origins_report(), key=lambda o: o[1], reverse=True), - countries=sorted(countries_report(), key=lambda c: c[0].risk_level, reverse=True), - ) + return render_template( + "report_blocks.html.j2", + blocked_today=blocked_today, + origins=sorted(origins_report(), key=lambda o: o[1], reverse=True), + countries=sorted( + countries_report(), key=lambda c: c[0].risk_level, reverse=True + ), + ) diff --git a/app/portal/smart_proxy.py b/app/portal/smart_proxy.py index bfa7b15..2aa0988 100644 --- a/app/portal/smart_proxy.py +++ b/app/portal/smart_proxy.py @@ -9,9 +9,15 @@ bp = Blueprint("smart_proxy", __name__) @bp.route("/list") def smart_proxy_list() -> ResponseReturnValue: - instances = SmartProxy.query.filter(SmartProxy.destroyed.is_(None)).order_by(desc(SmartProxy.added)).all() - return render_template("list.html.j2", - section="smart_proxy", - title="Smart Proxy Instances", - item="smart proxy", - items=instances) + instances = ( + SmartProxy.query.filter(SmartProxy.destroyed.is_(None)) + .order_by(desc(SmartProxy.added)) + .all() + ) + return render_template( + "list.html.j2", + section="smart_proxy", + title="Smart Proxy Instances", + item="smart proxy", + items=instances, + ) diff --git a/app/portal/static.py b/app/portal/static.py index 3bf5b45..6cff742 100644 --- a/app/portal/static.py +++ b/app/portal/static.py @@ -2,13 +2,19 @@ import logging from typing import Any, List, Optional import sqlalchemy.exc -from flask import (Blueprint, Response, current_app, flash, redirect, - render_template, url_for) +from flask import ( + Blueprint, + Response, + current_app, + flash, + redirect, + render_template, + url_for, +) from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm from sqlalchemy import exc -from wtforms import (BooleanField, FileField, SelectField, StringField, - SubmitField) +from wtforms import BooleanField, FileField, SelectField, StringField, SubmitField from wtforms.validators import DataRequired from app.brm.static import create_static_origin @@ -22,87 +28,99 @@ bp = Blueprint("static", __name__) class StaticOriginForm(FlaskForm): # type: ignore description = StringField( - 'Description', + "Description", validators=[DataRequired()], - description='Enter a brief description of the static website that you are creating in this field. This is ' - 'also a required field.' + description="Enter a brief description of the static website that you are creating in this field. This is " + "also a required field.", ) group = SelectField( - 'Group', + "Group", validators=[DataRequired()], - description='Select the group that you want the origin to belong to from the drop-down menu in this field. ' - 'This is a required field.' + description="Select the group that you want the origin to belong to from the drop-down menu in this field. " + "This is a required field.", ) storage_cloud_account = SelectField( - 'Storage Cloud Account', + "Storage Cloud Account", validators=[DataRequired()], - description='Select the cloud account that you want the origin to be deployed to from the drop-down menu in ' - 'this field. This is a required field.' + description="Select the cloud account that you want the origin to be deployed to from the drop-down menu in " + "this field. This is a required field.", ) source_cloud_account = SelectField( - 'Source Cloud Account', + "Source Cloud Account", validators=[DataRequired()], - description='Select the cloud account that will be used to modify the source repository for the web content ' - 'for this static origin. This is a required field.' + description="Select the cloud account that will be used to modify the source repository for the web content " + "for this static origin. This is a required field.", ) source_project = StringField( - 'Source Project', + "Source Project", validators=[DataRequired()], - description='GitLab project path.' + description="GitLab project path.", ) auto_rotate = BooleanField( - 'Auto-Rotate', + "Auto-Rotate", default=True, - description='Select this field if you want to enable auto-rotation for the mirror. This means that the mirror ' - 'will automatically redeploy with a new domain name if it is detected to be blocked. This field ' - 'is optional and is enabled by default.' + description="Select this field if you want to enable auto-rotation for the mirror. This means that the mirror " + "will automatically redeploy with a new domain name if it is detected to be blocked. This field " + "is optional and is enabled by default.", ) matrix_homeserver = SelectField( - 'Matrix Homeserver', - description='Select the Matrix homeserver from the drop-down box to enable Keanu Convene on mirrors of this ' - 'static origin.' + "Matrix Homeserver", + description="Select the Matrix homeserver from the drop-down box to enable Keanu Convene on mirrors of this " + "static origin.", ) keanu_convene_path = StringField( - 'Keanu Convene Path', - default='talk', - description='Enter the subdirectory to present the Keanu Convene application at on the mirror. This defaults ' - 'to "talk".' + "Keanu Convene Path", + default="talk", + description="Enter the subdirectory to present the Keanu Convene application at on the mirror. This defaults " + 'to "talk".', ) keanu_convene_logo = FileField( - 'Keanu Convene Logo', - description='Logo to use for Keanu Convene' + "Keanu Convene Logo", description="Logo to use for Keanu Convene" ) keanu_convene_color = StringField( - 'Keanu Convene Accent Color', - default='#0047ab', - description='Accent color to use for Keanu Convene (HTML hex code)' + "Keanu Convene Accent Color", + default="#0047ab", + description="Accent color to use for Keanu Convene (HTML hex code)", ) enable_clean_insights = BooleanField( - 'Enable Clean Insights', - description='When enabled, a Clean Insights Measurement Proxy endpoint is deployed on the mirror to allow for ' - 'submission of results from any of the supported Clean Insights SDKs.' + "Enable Clean Insights", + description="When enabled, a Clean Insights Measurement Proxy endpoint is deployed on the mirror to allow for " + "submission of results from any of the supported Clean Insights SDKs.", ) - submit = SubmitField('Save Changes') + submit = SubmitField("Save Changes") def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.group.choices = [(x.id, x.group_name) for x in Group.query.all()] - self.storage_cloud_account.choices = [(x.id, f"{x.provider.description} - {x.description}") for x in - CloudAccount.query.filter( - CloudAccount.provider == CloudProvider.AWS).all()] - self.source_cloud_account.choices = [(x.id, f"{x.provider.description} - {x.description}") for x in - CloudAccount.query.filter( - CloudAccount.provider == CloudProvider.GITLAB).all()] - self.matrix_homeserver.choices = [(x, x) for x in current_app.config['MATRIX_HOMESERVERS']] + self.storage_cloud_account.choices = [ + (x.id, f"{x.provider.description} - {x.description}") + for x in CloudAccount.query.filter( + CloudAccount.provider == CloudProvider.AWS + ).all() + ] + self.source_cloud_account.choices = [ + (x.id, f"{x.provider.description} - {x.description}") + for x in CloudAccount.query.filter( + CloudAccount.provider == CloudProvider.GITLAB + ).all() + ] + self.matrix_homeserver.choices = [ + (x, x) for x in current_app.config["MATRIX_HOMESERVERS"] + ] -@bp.route("/new", methods=['GET', 'POST']) -@bp.route("/new/", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) +@bp.route("/new/", methods=["GET", "POST"]) def static_new(group_id: Optional[int] = None) -> ResponseReturnValue: form = StaticOriginForm() - if len(form.source_cloud_account.choices) == 0 or len(form.storage_cloud_account.choices) == 0: - flash("You must add at least one AWS account and at least one GitLab account before creating static origins.", - "warning") + if ( + len(form.source_cloud_account.choices) == 0 + or len(form.storage_cloud_account.choices) == 0 + ): + flash( + "You must add at least one AWS account and at least one GitLab account before creating static origins.", + "warning", + ) return redirect(url_for("portal.cloud.cloud_account_list")) if form.validate_on_submit(): try: @@ -118,16 +136,22 @@ def static_new(group_id: Optional[int] = None) -> ResponseReturnValue: form.keanu_convene_logo.data, form.keanu_convene_color.data, form.enable_clean_insights.data, - True + True, ) flash(f"Created new static origin #{static.id}.", "success") return redirect(url_for("portal.static.static_edit", static_id=static.id)) - except ValueError as e: # may be returned by create_static_origin and from the int conversion + except ( + ValueError + ) as e: # may be returned by create_static_origin and from the int conversion logging.warning(e) - flash("Failed to create new static origin due to an invalid input.", "danger") + flash( + "Failed to create new static origin due to an invalid input.", "danger" + ) return redirect(url_for("portal.static.static_list")) except exc.SQLAlchemyError as e: - flash("Failed to create new static origin due to a database error.", "danger") + flash( + "Failed to create new static origin due to a database error.", "danger" + ) logging.warning(e) return redirect(url_for("portal.static.static_list")) if group_id: @@ -135,24 +159,32 @@ def static_new(group_id: Optional[int] = None) -> ResponseReturnValue: return render_template("new.html.j2", section="static", form=form) -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def static_edit(static_id: int) -> ResponseReturnValue: - static_origin: Optional[StaticOrigin] = StaticOrigin.query.filter(StaticOrigin.id == static_id).first() + static_origin: Optional[StaticOrigin] = StaticOrigin.query.filter( + StaticOrigin.id == static_id + ).first() if static_origin is None: - return Response(render_template("error.html.j2", - section="static", - header="404 Origin Not Found", - message="The requested static origin could not be found."), - status=404) - form = StaticOriginForm(description=static_origin.description, - group=static_origin.group_id, - storage_cloud_account=static_origin.storage_cloud_account_id, - source_cloud_account=static_origin.source_cloud_account_id, - source_project=static_origin.source_project, - matrix_homeserver=static_origin.matrix_homeserver, - keanu_convene_path=static_origin.keanu_convene_path, - auto_rotate=static_origin.auto_rotate, - enable_clean_insights=bool(static_origin.clean_insights_backend)) + return Response( + render_template( + "error.html.j2", + section="static", + header="404 Origin Not Found", + message="The requested static origin could not be found.", + ), + status=404, + ) + form = StaticOriginForm( + description=static_origin.description, + group=static_origin.group_id, + storage_cloud_account=static_origin.storage_cloud_account_id, + source_cloud_account=static_origin.source_cloud_account_id, + source_project=static_origin.source_project, + matrix_homeserver=static_origin.matrix_homeserver, + keanu_convene_path=static_origin.keanu_convene_path, + auto_rotate=static_origin.auto_rotate, + enable_clean_insights=bool(static_origin.clean_insights_backend), + ) form.group.render_kw = {"disabled": ""} form.storage_cloud_account.render_kw = {"disabled": ""} form.source_cloud_account.render_kw = {"disabled": ""} @@ -167,50 +199,68 @@ def static_edit(static_id: int) -> ResponseReturnValue: form.keanu_convene_logo.data, form.keanu_convene_color.data, form.enable_clean_insights.data, - True + True, ) flash("Saved changes to group.", "success") - except ValueError as e: # may be returned by create_static_origin and from the int conversion + except ( + ValueError + ) as e: # may be returned by create_static_origin and from the int conversion logging.warning(e) - flash("An error occurred saving the changes to the static origin due to an invalid input.", "danger") + flash( + "An error occurred saving the changes to the static origin due to an invalid input.", + "danger", + ) except exc.SQLAlchemyError as e: logging.warning(e) - flash("An error occurred saving the changes to the static origin due to a database error.", "danger") + flash( + "An error occurred saving the changes to the static origin due to a database error.", + "danger", + ) try: - origin = Origin.query.filter_by(domain_name=static_origin.origin_domain_name).one() + origin = Origin.query.filter_by( + domain_name=static_origin.origin_domain_name + ).one() proxies = origin.proxies except sqlalchemy.exc.NoResultFound: proxies = [] - return render_template("static.html.j2", - section="static", - static=static_origin, form=form, - proxies=proxies) + return render_template( + "static.html.j2", + section="static", + static=static_origin, + form=form, + proxies=proxies, + ) @bp.route("/list") def static_list() -> ResponseReturnValue: - statics: List[StaticOrigin] = StaticOrigin.query.order_by(StaticOrigin.description).all() - return render_template("list.html.j2", - section="static", - title="Static Origins", - item="static", - new_link=url_for("portal.static.static_new"), - items=statics - ) + statics: List[StaticOrigin] = StaticOrigin.query.order_by( + StaticOrigin.description + ).all() + return render_template( + "list.html.j2", + section="static", + title="Static Origins", + item="static", + new_link=url_for("portal.static.static_new"), + items=statics, + ) -@bp.route("/destroy/", methods=['GET', 'POST']) +@bp.route("/destroy/", methods=["GET", "POST"]) def static_destroy(static_id: int) -> ResponseReturnValue: - static = StaticOrigin.query.filter(StaticOrigin.id == static_id, StaticOrigin.destroyed.is_(None)).first() + static = StaticOrigin.query.filter( + StaticOrigin.id == static_id, StaticOrigin.destroyed.is_(None) + ).first() if static is None: return response_404("The requested static origin could not be found.") return view_lifecycle( header=f"Destroy static origin {static.description}", message=static.description, success_message="All proxies from the destroyed static origin will shortly be destroyed at their providers, " - "and the static content will be removed from the cloud provider.", + "and the static content will be removed from the cloud provider.", success_view="portal.static.static_list", section="static", resource=static, - action="destroy" + action="destroy", ) diff --git a/app/portal/storage.py b/app/portal/storage.py index 8fe4b90..76a8396 100644 --- a/app/portal/storage.py +++ b/app/portal/storage.py @@ -17,24 +17,30 @@ bp = Blueprint("storage", __name__) _SECTION_TEMPLATE_VARS = { "section": "automation", - "help_url": "https://bypass.censorship.guide/user/automation.html" + "help_url": "https://bypass.censorship.guide/user/automation.html", } class EditStorageForm(FlaskForm): # type: ignore - force_unlock = BooleanField('Force Unlock') - submit = SubmitField('Save Changes') + force_unlock = BooleanField("Force Unlock") + submit = SubmitField("Save Changes") -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def storage_edit(storage_key: str) -> ResponseReturnValue: - storage: Optional[TerraformState] = TerraformState.query.filter(TerraformState.key == storage_key).first() + storage: Optional[TerraformState] = TerraformState.query.filter( + TerraformState.key == storage_key + ).first() if storage is None: - return Response(render_template("error.html.j2", - header="404 Storage Key Not Found", - message="The requested storage could not be found.", - **_SECTION_TEMPLATE_VARS), - status=404) + return Response( + render_template( + "error.html.j2", + header="404 Storage Key Not Found", + message="The requested storage could not be found.", + **_SECTION_TEMPLATE_VARS + ), + status=404, + ) form = EditStorageForm() if form.validate_on_submit(): if form.force_unlock.data: @@ -45,17 +51,16 @@ def storage_edit(storage_key: str) -> ResponseReturnValue: flash("Storage has been force unlocked.", "success") except exc.SQLAlchemyError: flash("An error occurred unlocking the storage.", "danger") - return render_template("storage.html.j2", - storage=storage, - form=form, - **_SECTION_TEMPLATE_VARS) + return render_template( + "storage.html.j2", storage=storage, form=form, **_SECTION_TEMPLATE_VARS + ) -@bp.route("/kick/", methods=['GET', 'POST']) +@bp.route("/kick/", methods=["GET", "POST"]) def automation_kick(automation_id: int) -> ResponseReturnValue: automation = Automation.query.filter( - Automation.id == automation_id, - Automation.destroyed.is_(None)).first() + Automation.id == automation_id, Automation.destroyed.is_(None) + ).first() if automation is None: return response_404("The requested bridge configuration could not be found.") return view_lifecycle( @@ -65,5 +70,5 @@ def automation_kick(automation_id: int) -> ResponseReturnValue: success_view="portal.automation.automation_list", success_message="This automation job will next run within 1 minute.", resource=automation, - action="kick" + action="kick", ) diff --git a/app/portal/util.py b/app/portal/util.py index 58c27ac..9a60c4e 100644 --- a/app/portal/util.py +++ b/app/portal/util.py @@ -9,19 +9,21 @@ from app.models.activity import Activity def response_404(message: str) -> ResponseReturnValue: - return Response(render_template("error.html.j2", - header="404 Not Found", - message=message)) + return Response( + render_template("error.html.j2", header="404 Not Found", message=message) + ) -def view_lifecycle(*, - header: str, - message: str, - success_message: str, - success_view: str, - section: str, - resource: AbstractResource, - action: str) -> ResponseReturnValue: +def view_lifecycle( + *, + header: str, + message: str, + success_message: str, + success_view: str, + section: str, + resource: AbstractResource, + action: str, +) -> ResponseReturnValue: form = LifecycleForm() if action == "destroy": form.submit.render_kw = {"class": "btn btn-danger"} @@ -41,19 +43,17 @@ def view_lifecycle(*, return redirect(url_for("portal.portal_home")) activity = Activity( activity_type="lifecycle", - text=f"Portal action: {message}. {success_message}" + text=f"Portal action: {message}. {success_message}", ) db.session.add(activity) db.session.commit() activity.notify() flash(success_message, "success") return redirect(url_for(success_view)) - return render_template("lifecycle.html.j2", - header=header, - message=message, - section=section, - form=form) + return render_template( + "lifecycle.html.j2", header=header, message=message, section=section, form=form + ) class LifecycleForm(FlaskForm): # type: ignore - submit = SubmitField('Confirm') + submit = SubmitField("Confirm") diff --git a/app/portal/webhook.py b/app/portal/webhook.py index 4568744..df870f8 100644 --- a/app/portal/webhook.py +++ b/app/portal/webhook.py @@ -1,8 +1,7 @@ from datetime import datetime, timezone from typing import Optional -from flask import (Blueprint, Response, flash, redirect, render_template, - url_for) +from flask import Blueprint, Response, flash, redirect, render_template, url_for from flask.typing import ResponseReturnValue from flask_wtf import FlaskForm from sqlalchemy import exc @@ -26,47 +25,54 @@ def webhook_format_name(key: str) -> str: class NewWebhookForm(FlaskForm): # type: ignore - description = StringField('Description', validators=[DataRequired()]) - format = SelectField('Format', choices=[ - ("telegram", "Telegram"), - ("matrix", "Matrix") - ], validators=[DataRequired()]) - url = StringField('URL', validators=[DataRequired()]) - submit = SubmitField('Save Changes') + description = StringField("Description", validators=[DataRequired()]) + format = SelectField( + "Format", + choices=[("telegram", "Telegram"), ("matrix", "Matrix")], + validators=[DataRequired()], + ) + url = StringField("URL", validators=[DataRequired()]) + submit = SubmitField("Save Changes") -@bp.route("/new", methods=['GET', 'POST']) +@bp.route("/new", methods=["GET", "POST"]) def webhook_new() -> ResponseReturnValue: form = NewWebhookForm() if form.validate_on_submit(): webhook = Webhook( description=form.description.data, format=form.format.data, - url=form.url.data + url=form.url.data, ) try: db.session.add(webhook) db.session.commit() flash(f"Created new webhook {webhook.url}.", "success") - return redirect(url_for("portal.webhook.webhook_edit", webhook_id=webhook.id)) + return redirect( + url_for("portal.webhook.webhook_edit", webhook_id=webhook.id) + ) except exc.SQLAlchemyError: flash("Failed to create new webhook.", "danger") return redirect(url_for("portal.webhook.webhook_list")) return render_template("new.html.j2", section="webhook", form=form) -@bp.route('/edit/', methods=['GET', 'POST']) +@bp.route("/edit/", methods=["GET", "POST"]) def webhook_edit(webhook_id: int) -> ResponseReturnValue: webhook = Webhook.query.filter(Webhook.id == webhook_id).first() if webhook is None: - return Response(render_template("error.html.j2", - section="webhook", - header="404 Webhook Not Found", - message="The requested webhook could not be found."), - status=404) - form = NewWebhookForm(description=webhook.description, - format=webhook.format, - url=webhook.url) + return Response( + render_template( + "error.html.j2", + section="webhook", + header="404 Webhook Not Found", + message="The requested webhook could not be found.", + ), + status=404, + ) + form = NewWebhookForm( + description=webhook.description, format=webhook.format, url=webhook.url + ) if form.validate_on_submit(): webhook.description = form.description.data webhook.format = form.description.data @@ -77,26 +83,29 @@ def webhook_edit(webhook_id: int) -> ResponseReturnValue: flash("Saved changes to webhook.", "success") except exc.SQLAlchemyError: flash("An error occurred saving the changes to the webhook.", "danger") - return render_template("edit.html.j2", - section="webhook", - title="Edit Webhook", - item=webhook, form=form) + return render_template( + "edit.html.j2", section="webhook", title="Edit Webhook", item=webhook, form=form + ) @bp.route("/list") def webhook_list() -> ResponseReturnValue: webhooks = Webhook.query.all() - return render_template("list.html.j2", - section="webhook", - title="Webhooks", - item="webhook", - new_link=url_for("portal.webhook.webhook_new"), - items=webhooks) + return render_template( + "list.html.j2", + section="webhook", + title="Webhooks", + item="webhook", + new_link=url_for("portal.webhook.webhook_new"), + items=webhooks, + ) -@bp.route("/destroy/", methods=['GET', 'POST']) +@bp.route("/destroy/", methods=["GET", "POST"]) def webhook_destroy(webhook_id: int) -> ResponseReturnValue: - webhook: Optional[Webhook] = Webhook.query.filter(Webhook.id == webhook_id, Webhook.destroyed.is_(None)).first() + webhook: Optional[Webhook] = Webhook.query.filter( + Webhook.id == webhook_id, Webhook.destroyed.is_(None) + ).first() if webhook is None: return response_404("The requested webhook could not be found.") return view_lifecycle( @@ -106,5 +115,5 @@ def webhook_destroy(webhook_id: int) -> ResponseReturnValue: success_view="portal.webhook.webhook_list", section="webhook", resource=webhook, - action="destroy" + action="destroy", ) diff --git a/app/terraform/__init__.py b/app/terraform/__init__.py index 09be884..c4a8b9a 100644 --- a/app/terraform/__init__.py +++ b/app/terraform/__init__.py @@ -12,6 +12,7 @@ class DeterministicZip: Heavily inspired by https://github.com/bboe/deterministic_zip. """ + zipfile: ZipFile def __init__(self, filename: str): @@ -67,15 +68,22 @@ class BaseAutomation: if not self.working_dir: raise RuntimeError("No working directory specified.") tmpl = jinja2.Template(template) - with open(os.path.join(self.working_dir, filename), 'w', encoding="utf-8") as tfconf: + with open( + os.path.join(self.working_dir, filename), "w", encoding="utf-8" + ) as tfconf: tfconf.write(tmpl.render(**kwargs)) - def bin_write(self, filename: str, data: bytes, group_id: Optional[int] = None) -> None: + def bin_write( + self, filename: str, data: bytes, group_id: Optional[int] = None + ) -> None: if not self.working_dir: raise RuntimeError("No working directory specified.") try: os.mkdir(os.path.join(self.working_dir, str(group_id))) except FileExistsError: pass - with open(os.path.join(self.working_dir, str(group_id) if group_id else "", filename), 'wb') as binfile: + with open( + os.path.join(self.working_dir, str(group_id) if group_id else "", filename), + "wb", + ) as binfile: binfile.write(data) diff --git a/app/terraform/alarms/eotk_aws.py b/app/terraform/alarms/eotk_aws.py index c5bc3aa..a7abc0c 100644 --- a/app/terraform/alarms/eotk_aws.py +++ b/app/terraform/alarms/eotk_aws.py @@ -13,33 +13,38 @@ from app.terraform import BaseAutomation def alarms_in_region(region: str, prefix: str, aspect: str) -> None: - cloudwatch = boto3.client('cloudwatch', - aws_access_key_id=app.config['AWS_ACCESS_KEY'], - aws_secret_access_key=app.config['AWS_SECRET_KEY'], - region_name=region) - dist_paginator = cloudwatch.get_paginator('describe_alarms') + cloudwatch = boto3.client( + "cloudwatch", + aws_access_key_id=app.config["AWS_ACCESS_KEY"], + aws_secret_access_key=app.config["AWS_SECRET_KEY"], + region_name=region, + ) + dist_paginator = cloudwatch.get_paginator("describe_alarms") page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix) for page in page_iterator: - for cw_alarm in page['MetricAlarms']: - eotk_id = cw_alarm["AlarmName"][len(prefix):].split("-") - group: Optional[Group] = Group.query.filter(func.lower(Group.group_name) == eotk_id[1]).first() + for cw_alarm in page["MetricAlarms"]: + eotk_id = cw_alarm["AlarmName"][len(prefix) :].split("-") + group: Optional[Group] = Group.query.filter( + func.lower(Group.group_name) == eotk_id[1] + ).first() if group is None: - print("Unable to find group for " + cw_alarm['AlarmName']) + print("Unable to find group for " + cw_alarm["AlarmName"]) continue eotk = Eotk.query.filter( - Eotk.group_id == group.id, - Eotk.region == region + Eotk.group_id == group.id, Eotk.region == region ).first() if eotk is None: - print("Skipping unknown instance " + cw_alarm['AlarmName']) + print("Skipping unknown instance " + cw_alarm["AlarmName"]) continue alarm = get_or_create_alarm(eotk.brn, aspect) - if cw_alarm['StateValue'] == "OK": + if cw_alarm["StateValue"] == "OK": alarm.update_state(AlarmState.OK, "CloudWatch alarm OK") - elif cw_alarm['StateValue'] == "ALARM": + elif cw_alarm["StateValue"] == "ALARM": alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM") else: - alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}") + alarm.update_state( + AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}" + ) class AlarmEotkAwsAutomation(BaseAutomation): diff --git a/app/terraform/alarms/proxy_azure_cdn.py b/app/terraform/alarms/proxy_azure_cdn.py index 2828f46..18ad2f8 100644 --- a/app/terraform/alarms/proxy_azure_cdn.py +++ b/app/terraform/alarms/proxy_azure_cdn.py @@ -16,20 +16,19 @@ class AlarmProxyAzureCdnAutomation(BaseAutomation): def automate(self, full: bool = False) -> Tuple[bool, str]: credential = ClientSecretCredential( - tenant_id=app.config['AZURE_TENANT_ID'], - client_id=app.config['AZURE_CLIENT_ID'], - client_secret=app.config['AZURE_CLIENT_SECRET']) - client = AlertsManagementClient( - credential, - app.config['AZURE_SUBSCRIPTION_ID'] + tenant_id=app.config["AZURE_TENANT_ID"], + client_id=app.config["AZURE_CLIENT_ID"], + client_secret=app.config["AZURE_CLIENT_SECRET"], ) - firing = [x.name[len("bandwidth-out-high-bc-"):] - for x in client.alerts.get_all() - if x.name.startswith("bandwidth-out-high-bc-") - and x.properties.essentials.monitor_condition == "Fired"] + client = AlertsManagementClient(credential, app.config["AZURE_SUBSCRIPTION_ID"]) + firing = [ + x.name[len("bandwidth-out-high-bc-") :] + for x in client.alerts.get_all() + if x.name.startswith("bandwidth-out-high-bc-") + and x.properties.essentials.monitor_condition == "Fired" + ] for proxy in Proxy.query.filter( - Proxy.provider == "azure_cdn", - Proxy.destroyed.is_(None) + Proxy.provider == "azure_cdn", Proxy.destroyed.is_(None) ): alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high") if proxy.origin.group.group_name.lower() not in firing: diff --git a/app/terraform/alarms/proxy_cloudfront.py b/app/terraform/alarms/proxy_cloudfront.py index f91ce40..e450bac 100644 --- a/app/terraform/alarms/proxy_cloudfront.py +++ b/app/terraform/alarms/proxy_cloudfront.py @@ -16,9 +16,8 @@ def _cloudfront_quota() -> None: # It would be nice to learn this from the Service Quotas API, however # at the time of writing this comment, the current value for this quota # is not available from the API. It just doesn't return anything. - max_count = int(current_app.config.get('AWS_CLOUDFRONT_MAX_DISTRIBUTIONS', 200)) - deployed_count = len(Proxy.query.filter( - Proxy.destroyed.is_(None)).all()) + max_count = int(current_app.config.get("AWS_CLOUDFRONT_MAX_DISTRIBUTIONS", 200)) + deployed_count = len(Proxy.query.filter(Proxy.destroyed.is_(None)).all()) message = f"{deployed_count} distributions deployed of {max_count} quota" alarm = get_or_create_alarm( BRN( @@ -26,9 +25,9 @@ def _cloudfront_quota() -> None: product="mirror", provider="cloudfront", resource_type="quota", - resource_id="distributions" + resource_id="distributions", ), - "quota-usage" + "quota-usage", ) if deployed_count > max_count * 0.9: alarm.update_state(AlarmState.CRITICAL, message) @@ -39,26 +38,30 @@ def _cloudfront_quota() -> None: def _proxy_alarms() -> None: - cloudwatch = boto3.client('cloudwatch', - aws_access_key_id=app.config['AWS_ACCESS_KEY'], - aws_secret_access_key=app.config['AWS_SECRET_KEY'], - region_name='us-east-2') - dist_paginator = cloudwatch.get_paginator('describe_alarms') + cloudwatch = boto3.client( + "cloudwatch", + aws_access_key_id=app.config["AWS_ACCESS_KEY"], + aws_secret_access_key=app.config["AWS_SECRET_KEY"], + region_name="us-east-2", + ) + dist_paginator = cloudwatch.get_paginator("describe_alarms") page_iterator = dist_paginator.paginate(AlarmNamePrefix="bandwidth-out-high-") for page in page_iterator: - for cw_alarm in page['MetricAlarms']: - dist_id = cw_alarm["AlarmName"][len("bandwidth-out-high-"):] + for cw_alarm in page["MetricAlarms"]: + dist_id = cw_alarm["AlarmName"][len("bandwidth-out-high-") :] proxy = Proxy.query.filter(Proxy.slug == dist_id).first() if proxy is None: print("Skipping unknown proxy " + dist_id) continue alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high") - if cw_alarm['StateValue'] == "OK": + if cw_alarm["StateValue"] == "OK": alarm.update_state(AlarmState.OK, "CloudWatch alarm OK") - elif cw_alarm['StateValue'] == "ALARM": + elif cw_alarm["StateValue"] == "ALARM": alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM") else: - alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}") + alarm.update_state( + AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}" + ) class AlarmProxyCloudfrontAutomation(BaseAutomation): diff --git a/app/terraform/alarms/proxy_http_status.py b/app/terraform/alarms/proxy_http_status.py index a98d0cb..919a4ae 100644 --- a/app/terraform/alarms/proxy_http_status.py +++ b/app/terraform/alarms/proxy_http_status.py @@ -16,39 +16,25 @@ class AlarmProxyHTTPStatusAutomation(BaseAutomation): frequency = 45 def automate(self, full: bool = False) -> Tuple[bool, str]: - proxies = Proxy.query.filter( - Proxy.destroyed.is_(None) - ) + proxies = Proxy.query.filter(Proxy.destroyed.is_(None)) for proxy in proxies: try: if proxy.url is None: continue - r = requests.get(proxy.url, - allow_redirects=False, - timeout=5) + r = requests.get(proxy.url, allow_redirects=False, timeout=5) r.raise_for_status() alarm = get_or_create_alarm(proxy.brn, "http-status") if r.is_redirect: alarm.update_state( - AlarmState.CRITICAL, - f"{r.status_code} {r.reason}" + AlarmState.CRITICAL, f"{r.status_code} {r.reason}" ) else: - alarm.update_state( - AlarmState.OK, - f"{r.status_code} {r.reason}" - ) + alarm.update_state(AlarmState.OK, f"{r.status_code} {r.reason}") except requests.HTTPError: alarm = get_or_create_alarm(proxy.brn, "http-status") - alarm.update_state( - AlarmState.CRITICAL, - f"{r.status_code} {r.reason}" - ) + alarm.update_state(AlarmState.CRITICAL, f"{r.status_code} {r.reason}") except RequestException as e: alarm = get_or_create_alarm(proxy.brn, "http-status") - alarm.update_state( - AlarmState.CRITICAL, - repr(e) - ) + alarm.update_state(AlarmState.CRITICAL, repr(e)) db.session.commit() return True, "" diff --git a/app/terraform/alarms/smart_aws.py b/app/terraform/alarms/smart_aws.py index 94737cc..4ad0457 100644 --- a/app/terraform/alarms/smart_aws.py +++ b/app/terraform/alarms/smart_aws.py @@ -13,33 +13,38 @@ from app.terraform import BaseAutomation def alarms_in_region(region: str, prefix: str, aspect: str) -> None: - cloudwatch = boto3.client('cloudwatch', - aws_access_key_id=app.config['AWS_ACCESS_KEY'], - aws_secret_access_key=app.config['AWS_SECRET_KEY'], - region_name=region) - dist_paginator = cloudwatch.get_paginator('describe_alarms') + cloudwatch = boto3.client( + "cloudwatch", + aws_access_key_id=app.config["AWS_ACCESS_KEY"], + aws_secret_access_key=app.config["AWS_SECRET_KEY"], + region_name=region, + ) + dist_paginator = cloudwatch.get_paginator("describe_alarms") page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix) for page in page_iterator: - for cw_alarm in page['MetricAlarms']: - smart_id = cw_alarm["AlarmName"][len(prefix):].split("-") - group: Optional[Group] = Group.query.filter(func.lower(Group.group_name) == smart_id[1]).first() + for cw_alarm in page["MetricAlarms"]: + smart_id = cw_alarm["AlarmName"][len(prefix) :].split("-") + group: Optional[Group] = Group.query.filter( + func.lower(Group.group_name) == smart_id[1] + ).first() if group is None: - print("Unable to find group for " + cw_alarm['AlarmName']) + print("Unable to find group for " + cw_alarm["AlarmName"]) continue smart_proxy = SmartProxy.query.filter( - SmartProxy.group_id == group.id, - SmartProxy.region == region + SmartProxy.group_id == group.id, SmartProxy.region == region ).first() if smart_proxy is None: - print("Skipping unknown instance " + cw_alarm['AlarmName']) + print("Skipping unknown instance " + cw_alarm["AlarmName"]) continue alarm = get_or_create_alarm(smart_proxy.brn, aspect) - if cw_alarm['StateValue'] == "OK": + if cw_alarm["StateValue"] == "OK": alarm.update_state(AlarmState.OK, "CloudWatch alarm OK") - elif cw_alarm['StateValue'] == "ALARM": + elif cw_alarm["StateValue"] == "ALARM": alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM") else: - alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}") + alarm.update_state( + AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}" + ) class AlarmSmartAwsAutomation(BaseAutomation): diff --git a/app/terraform/block/block_blocky.py b/app/terraform/block/block_blocky.py index ce9499e..3dcc6b2 100644 --- a/app/terraform/block/block_blocky.py +++ b/app/terraform/block/block_blocky.py @@ -16,7 +16,7 @@ def clean_json_response(raw_response: str) -> Dict[str, Any]: """ end_index = raw_response.rfind("}") if end_index != -1: - raw_response = raw_response[:end_index + 1] + raw_response = raw_response[: end_index + 1] response: Dict[str, Any] = json.loads(raw_response) return response @@ -27,20 +27,21 @@ def request_test_now(test_url: str) -> str: "User-Agent": "bypasscensorship.org", "Content-Type": "application/json;charset=utf-8", "Pragma": "no-cache", - "Cache-Control": "no-cache" + "Cache-Control": "no-cache", } request_count = 0 while request_count < 180: - params = { - "url": test_url, - "timestamp": str(int(time.time())) # unix timestamp - } - response = requests.post(api_url, params=params, headers=headers, json={}, timeout=30) + params = {"url": test_url, "timestamp": str(int(time.time()))} # unix timestamp + response = requests.post( + api_url, params=params, headers=headers, json={}, timeout=30 + ) response_data = clean_json_response(response.text) print(f"Response: {response_data}") if "url_test_id" in response_data.get("d", {}): url_test_id: str = response_data["d"]["url_test_id"] - logging.debug("Test result for %s has test result ID %s", test_url, url_test_id) + logging.debug( + "Test result for %s has test result ID %s", test_url, url_test_id + ) return url_test_id request_count += 1 time.sleep(2) @@ -52,13 +53,19 @@ def request_test_result(url_test_id: str) -> int: headers = { "User-Agent": "bypasscensorship.org", "Pragma": "no-cache", - "Cache-Control": "no-cache" + "Cache-Control": "no-cache", } response = requests.get(url, headers=headers, timeout=30) response_data = response.json() tests = response_data.get("d", []) - non_zero_curl_exit_count: int = sum(1 for test in tests if test.get("curl_exit_value") != "0") - logging.debug("Test result for %s has %s non-zero exit values", url_test_id, non_zero_curl_exit_count) + non_zero_curl_exit_count: int = sum( + 1 for test in tests if test.get("curl_exit_value") != "0" + ) + logging.debug( + "Test result for %s has %s non-zero exit values", + url_test_id, + non_zero_curl_exit_count, + ) return non_zero_curl_exit_count @@ -81,7 +88,7 @@ class BlockBlockyAutomation(BlockMirrorAutomation): Proxy.url.is_not(None), Proxy.deprecated.is_(None), Proxy.destroyed.is_(None), - Proxy.pool_id != -1 + Proxy.pool_id != -1, ) .all() ) diff --git a/app/terraform/block/block_scriptzteam.py b/app/terraform/block/block_scriptzteam.py index 94e853b..f163193 100644 --- a/app/terraform/block/block_scriptzteam.py +++ b/app/terraform/block/block_scriptzteam.py @@ -15,7 +15,8 @@ class BlockBridgeScriptzteamAutomation(BlockBridgelinesAutomation): def fetch(self) -> None: r = requests.get( "https://raw.githubusercontent.com/scriptzteam/Tor-Bridges-Collector/main/bridges-obfs4", - timeout=60) + timeout=60, + ) r.encoding = "utf-8" contents = r.text self._lines = contents.splitlines() diff --git a/app/terraform/block/bridge.py b/app/terraform/block/bridge.py index ed505f8..fe3245b 100644 --- a/app/terraform/block/bridge.py +++ b/app/terraform/block/bridge.py @@ -24,8 +24,9 @@ class BlockBridgeAutomation(BaseAutomation): self.hashed_fingerprints = [] super().__init__(*args, **kwargs) - def perform_deprecations(self, ids: List[str], bridge_select_func: Callable[[str], Optional[Bridge]] - ) -> List[Tuple[Optional[str], Any, Any]]: + def perform_deprecations( + self, ids: List[str], bridge_select_func: Callable[[str], Optional[Bridge]] + ) -> List[Tuple[Optional[str], Any, Any]]: rotated = [] for id_ in ids: bridge = bridge_select_func(id_) @@ -37,7 +38,13 @@ class BlockBridgeAutomation(BaseAutomation): continue if bridge.deprecate(reason=self.short_name): logging.info("Rotated %s", bridge.hashed_fingerprint) - rotated.append((bridge.fingerprint, bridge.cloud_account.provider, bridge.cloud_account.description)) + rotated.append( + ( + bridge.fingerprint, + bridge.cloud_account.provider, + bridge.cloud_account.description, + ) + ) else: logging.debug("Not rotating a bridge that is already deprecated") return rotated @@ -50,15 +57,28 @@ class BlockBridgeAutomation(BaseAutomation): rotated = [] rotated.extend(self.perform_deprecations(self.ips, get_bridge_by_ip)) logging.debug("Blocked by IP") - rotated.extend(self.perform_deprecations(self.fingerprints, get_bridge_by_fingerprint)) + rotated.extend( + self.perform_deprecations(self.fingerprints, get_bridge_by_fingerprint) + ) logging.debug("Blocked by fingerprint") - rotated.extend(self.perform_deprecations(self.hashed_fingerprints, get_bridge_by_hashed_fingerprint)) + rotated.extend( + self.perform_deprecations( + self.hashed_fingerprints, get_bridge_by_hashed_fingerprint + ) + ) logging.debug("Blocked by hashed fingerprint") if rotated: activity = Activity( activity_type="block", - text=(f"[{self.short_name}] ♻ Rotated {len(rotated)} bridges: \n" - + "\n".join([f"* {fingerprint} ({provider}: {provider_description})" for fingerprint, provider, provider_description in rotated])) + text=( + f"[{self.short_name}] ♻ Rotated {len(rotated)} bridges: \n" + + "\n".join( + [ + f"* {fingerprint} ({provider}: {provider_description})" + for fingerprint, provider, provider_description in rotated + ] + ) + ), ) db.session.add(activity) activity.notify() @@ -87,7 +107,7 @@ def get_bridge_by_ip(ip: str) -> Optional[Bridge]: return Bridge.query.filter( # type: ignore[no-any-return] Bridge.deprecated.is_(None), Bridge.destroyed.is_(None), - Bridge.bridgeline.contains(f" {ip} ") + Bridge.bridgeline.contains(f" {ip} "), ).first() @@ -95,7 +115,7 @@ def get_bridge_by_fingerprint(fingerprint: str) -> Optional[Bridge]: return Bridge.query.filter( # type: ignore[no-any-return] Bridge.deprecated.is_(None), Bridge.destroyed.is_(None), - Bridge.fingerprint == fingerprint + Bridge.fingerprint == fingerprint, ).first() @@ -103,5 +123,5 @@ def get_bridge_by_hashed_fingerprint(hashed_fingerprint: str) -> Optional[Bridge return Bridge.query.filter( # type: ignore[no-any-return] Bridge.deprecated.is_(None), Bridge.destroyed.is_(None), - Bridge.hashed_fingerprint == hashed_fingerprint + Bridge.hashed_fingerprint == hashed_fingerprint, ).first() diff --git a/app/terraform/block/bridge_bridgelines.py b/app/terraform/block/bridge_bridgelines.py index 5ad31d6..102f2ed 100644 --- a/app/terraform/block/bridge_bridgelines.py +++ b/app/terraform/block/bridge_bridgelines.py @@ -17,6 +17,8 @@ class BlockBridgelinesAutomation(BlockBridgeAutomation, ABC): fingerprint = parts[2] self.ips.append(ip_address) self.fingerprints.append(fingerprint) - logging.debug(f"Added blocked bridge with IP {ip_address} and fingerprint {fingerprint}") + logging.debug( + f"Added blocked bridge with IP {ip_address} and fingerprint {fingerprint}" + ) except IndexError: logging.warning("A parsing error occured.") diff --git a/app/terraform/block/bridge_github.py b/app/terraform/block/bridge_github.py index c1528da..26f5563 100644 --- a/app/terraform/block/bridge_github.py +++ b/app/terraform/block/bridge_github.py @@ -1,8 +1,7 @@ from flask import current_app from github import Github -from app.terraform.block.bridge_reachability import \ - BlockBridgeReachabilityAutomation +from app.terraform.block.bridge_reachability import BlockBridgeReachabilityAutomation class BlockBridgeGitHubAutomation(BlockBridgeReachabilityAutomation): @@ -15,12 +14,13 @@ class BlockBridgeGitHubAutomation(BlockBridgeReachabilityAutomation): frequency = 30 def fetch(self) -> None: - github = Github(current_app.config['GITHUB_API_KEY']) - repo = github.get_repo(current_app.config['GITHUB_BRIDGE_REPO']) - for vantage_point in current_app.config['GITHUB_BRIDGE_VANTAGE_POINTS']: + github = Github(current_app.config["GITHUB_API_KEY"]) + repo = github.get_repo(current_app.config["GITHUB_BRIDGE_REPO"]) + for vantage_point in current_app.config["GITHUB_BRIDGE_VANTAGE_POINTS"]: contents = repo.get_contents(f"recentResult_{vantage_point}") if isinstance(contents, list): raise RuntimeError( f"Expected a file at recentResult_{vantage_point}" - " but got a directory.") - self._lines = contents.decoded_content.decode('utf-8').splitlines() + " but got a directory." + ) + self._lines = contents.decoded_content.decode("utf-8").splitlines() diff --git a/app/terraform/block/bridge_gitlab.py b/app/terraform/block/bridge_gitlab.py index f1c1590..2902db9 100644 --- a/app/terraform/block/bridge_gitlab.py +++ b/app/terraform/block/bridge_gitlab.py @@ -1,8 +1,7 @@ from flask import current_app from gitlab import Gitlab -from app.terraform.block.bridge_reachability import \ - BlockBridgeReachabilityAutomation +from app.terraform.block.bridge_reachability import BlockBridgeReachabilityAutomation class BlockBridgeGitlabAutomation(BlockBridgeReachabilityAutomation): @@ -16,15 +15,15 @@ class BlockBridgeGitlabAutomation(BlockBridgeReachabilityAutomation): def fetch(self) -> None: self._lines = list() - credentials = {"private_token": current_app.config['GITLAB_TOKEN']} + credentials = {"private_token": current_app.config["GITLAB_TOKEN"]} if "GITLAB_URL" in current_app.config: - credentials['url'] = current_app.config['GITLAB_URL'] + credentials["url"] = current_app.config["GITLAB_URL"] gitlab = Gitlab(**credentials) - project = gitlab.projects.get(current_app.config['GITLAB_BRIDGE_PROJECT']) - for vantage_point in current_app.config['GITHUB_BRIDGE_VANTAGE_POINTS']: + project = gitlab.projects.get(current_app.config["GITLAB_BRIDGE_PROJECT"]) + for vantage_point in current_app.config["GITHUB_BRIDGE_VANTAGE_POINTS"]: contents = project.files.get( file_path=f"recentResult_{vantage_point}", - ref=current_app.config["GITLAB_BRIDGE_BRANCH"] + ref=current_app.config["GITLAB_BRIDGE_BRANCH"], ) # Decode the base64 first, then decode the UTF-8 string - self._lines.extend(contents.decode().decode('utf-8').splitlines()) + self._lines.extend(contents.decode().decode("utf-8").splitlines()) diff --git a/app/terraform/block/bridge_reachability.py b/app/terraform/block/bridge_reachability.py index 52d8789..3f0a979 100644 --- a/app/terraform/block/bridge_reachability.py +++ b/app/terraform/block/bridge_reachability.py @@ -14,8 +14,10 @@ class BlockBridgeReachabilityAutomation(BlockBridgeAutomation, ABC): def parse(self) -> None: for line in self._lines: parts = line.split("\t") - if isoparse(parts[2]) < (datetime.datetime.now(datetime.timezone.utc) - - datetime.timedelta(days=3)): + if isoparse(parts[2]) < ( + datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(days=3) + ): # Skip results older than 3 days continue if int(parts[1]) < 40: diff --git a/app/terraform/block/bridge_roskomsvoboda.py b/app/terraform/block/bridge_roskomsvoboda.py index 96b7dca..86cf6dc 100644 --- a/app/terraform/block/bridge_roskomsvoboda.py +++ b/app/terraform/block/bridge_roskomsvoboda.py @@ -13,7 +13,9 @@ class BlockBridgeRoskomsvobodaAutomation(BlockBridgeAutomation): _data: Any def fetch(self) -> None: - self._data = requests.get("https://reestr.rublacklist.net/api/v3/ips/", timeout=180).json() + self._data = requests.get( + "https://reestr.rublacklist.net/api/v3/ips/", timeout=180 + ).json() def parse(self) -> None: self.ips.extend(self._data) diff --git a/app/terraform/block_external.py b/app/terraform/block_external.py index 10a2bf2..2f96242 100644 --- a/app/terraform/block_external.py +++ b/app/terraform/block_external.py @@ -9,7 +9,7 @@ from app.terraform.block_mirror import BlockMirrorAutomation def _trim_prefix(s: str, prefix: str) -> str: if s.startswith(prefix): - return s[len(prefix):] + return s[len(prefix) :] return s @@ -20,30 +20,31 @@ def trim_http_https(s: str) -> str: :param s: String to modify. :return: Modified string. """ - return _trim_prefix( - _trim_prefix(s, "https://"), - "http://") + return _trim_prefix(_trim_prefix(s, "https://"), "http://") class BlockExternalAutomation(BlockMirrorAutomation): """ Automation task to import proxy reachability results from external source. """ + short_name = "block_external" description = "Import proxy reachability results from external source" _content: bytes def fetch(self) -> None: - user_agent = {'User-agent': 'BypassCensorship/1.0'} - check_urls_config = app.config.get('EXTERNAL_CHECK_URL', []) + user_agent = {"User-agent": "BypassCensorship/1.0"} + check_urls_config = app.config.get("EXTERNAL_CHECK_URL", []) if isinstance(check_urls_config, dict): # Config is already a dictionary, use as is. check_urls = check_urls_config elif isinstance(check_urls_config, list): # Convert list of strings to a dictionary with "external_N" keys. - check_urls = {f"external_{i}": url for i, url in enumerate(check_urls_config)} + check_urls = { + f"external_{i}": url for i, url in enumerate(check_urls_config) + } elif isinstance(check_urls_config, str): # Single string, convert to a dictionary with key "external". check_urls = {"external": check_urls_config} @@ -53,9 +54,13 @@ class BlockExternalAutomation(BlockMirrorAutomation): for source, check_url in check_urls.items(): if self._data is None: self._data = defaultdict(list) - self._data[source].extend(requests.get(check_url, headers=user_agent, timeout=30).json()) + self._data[source].extend( + requests.get(check_url, headers=user_agent, timeout=30).json() + ) def parse(self) -> None: for source, patterns in self._data.items(): - self.patterns[source].extend(["https://" + trim_http_https(pattern) for pattern in patterns]) + self.patterns[source].extend( + ["https://" + trim_http_https(pattern) for pattern in patterns] + ) logging.debug("Found URLs: %s", self.patterns) diff --git a/app/terraform/block_mirror.py b/app/terraform/block_mirror.py index 6322894..df96510 100644 --- a/app/terraform/block_mirror.py +++ b/app/terraform/block_mirror.py @@ -52,8 +52,15 @@ class BlockMirrorAutomation(BaseAutomation): if rotated: activity = Activity( activity_type="block", - text=(f"[{self.short_name}] ♻ Rotated {len(rotated)} proxies️️: \n" - + "\n".join([f"* {proxy_domain} ({origin_domain})" for proxy_domain, origin_domain in rotated])) + text=( + f"[{self.short_name}] ♻ Rotated {len(rotated)} proxies️️: \n" + + "\n".join( + [ + f"* {proxy_domain} ({origin_domain})" + for proxy_domain, origin_domain in rotated + ] + ) + ), ) db.session.add(activity) activity.notify() @@ -79,15 +86,15 @@ class BlockMirrorAutomation(BaseAutomation): def active_proxy_urls() -> List[str]: - return [proxy.url for proxy in Proxy.query.filter( - Proxy.deprecated.is_(None), - Proxy.destroyed.is_(None) - ).all()] + return [ + proxy.url + for proxy in Proxy.query.filter( + Proxy.deprecated.is_(None), Proxy.destroyed.is_(None) + ).all() + ] def proxy_by_url(url: str) -> Optional[Proxy]: return Proxy.query.filter( # type: ignore[no-any-return] - Proxy.deprecated.is_(None), - Proxy.destroyed.is_(None), - Proxy.url == url + Proxy.deprecated.is_(None), Proxy.destroyed.is_(None), Proxy.url == url ).first() diff --git a/app/terraform/block_ooni.py b/app/terraform/block_ooni.py index 87ca310..069d640 100644 --- a/app/terraform/block_ooni.py +++ b/app/terraform/block_ooni.py @@ -12,19 +12,23 @@ from app.terraform import BaseAutomation def check_origin(domain_name: str) -> Dict[str, Any]: - start_date = (datetime.now(tz=timezone.utc) - timedelta(days=1)).strftime("%Y-%m-%dT%H%%3A%M") + start_date = (datetime.now(tz=timezone.utc) - timedelta(days=1)).strftime( + "%Y-%m-%dT%H%%3A%M" + ) end_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H%%3A%M") api_url = f"https://api.ooni.io/api/v1/measurements?domain={domain_name}&since={start_date}&until={end_date}" - result: Dict[str, Dict[str, int]] = defaultdict(lambda: {"anomaly": 0, "confirmed": 0, "failure": 0, "ok": 0}) + result: Dict[str, Dict[str, int]] = defaultdict( + lambda: {"anomaly": 0, "confirmed": 0, "failure": 0, "ok": 0} + ) return _check_origin(api_url, result) def _check_origin(api_url: str, result: Dict[str, Any]) -> Dict[str, Any]: print(f"Processing {api_url}") req = requests.get(api_url, timeout=30).json() - if 'results' not in req or not req['results']: + if "results" not in req or not req["results"]: return result - for r in req['results']: + for r in req["results"]: not_ok = False for status in ["anomaly", "confirmed", "failure"]: if status in r and r[status]: @@ -33,27 +37,28 @@ def _check_origin(api_url: str, result: Dict[str, Any]) -> Dict[str, Any]: break if not not_ok: result[r["probe_cc"]]["ok"] += 1 - if req['metadata']['next_url']: - return _check_origin(req['metadata']['next_url'], result) + if req["metadata"]["next_url"]: + return _check_origin(req["metadata"]["next_url"], result) return result def threshold_origin(domain_name: str) -> Dict[str, Any]: ooni = check_origin(domain_name) for country in ooni: - total = sum([ - ooni[country]["anomaly"], - ooni[country]["confirmed"], - ooni[country]["failure"], - ooni[country]["ok"] - ]) - total_blocks = sum([ - ooni[country]["anomaly"], - ooni[country]["confirmed"] - ]) + total = sum( + [ + ooni[country]["anomaly"], + ooni[country]["confirmed"], + ooni[country]["failure"], + ooni[country]["ok"], + ] + ) + total_blocks = sum([ooni[country]["anomaly"], ooni[country]["confirmed"]]) block_perc = round((total_blocks / total * 100), 1) ooni[country]["block_perc"] = block_perc - ooni[country]["state"] = AlarmState.WARNING if block_perc > 20 else AlarmState.OK + ooni[country]["state"] = ( + AlarmState.WARNING if block_perc > 20 else AlarmState.OK + ) ooni[country]["message"] = f"Blocked in {block_perc}% of measurements" return ooni @@ -72,8 +77,9 @@ class BlockOONIAutomation(BaseAutomation): for origin in origins: ooni = threshold_origin(origin.domain_name) for country in ooni: - alarm = get_or_create_alarm(origin.brn, - f"origin-block-ooni-{country.lower()}") + alarm = get_or_create_alarm( + origin.brn, f"origin-block-ooni-{country.lower()}" + ) alarm.update_state(ooni[country]["state"], ooni[country]["message"]) db.session.commit() return True, "" diff --git a/app/terraform/block_roskomsvoboda.py b/app/terraform/block_roskomsvoboda.py index 0346b95..39bf61a 100644 --- a/app/terraform/block_roskomsvoboda.py +++ b/app/terraform/block_roskomsvoboda.py @@ -32,6 +32,7 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation): Where proxies are found to be blocked they will be rotated. """ + short_name = "block_roskomsvoboda" description = "Import Russian blocklist from RosKomSvoboda" frequency = 300 @@ -43,7 +44,11 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation): try: # This endpoint routinely has an expired certificate, and it's more useful that we are consuming the # data than that we are verifying the certificate. - r = requests.get(f"https://dumps.rublacklist.net/fetch/{latest_rev}", timeout=180, verify=False) # nosec: B501 + r = requests.get( + f"https://dumps.rublacklist.net/fetch/{latest_rev}", + timeout=180, + verify=False, + ) # nosec: B501 r.raise_for_status() zip_file = ZipFile(BytesIO(r.content)) self._data = zip_file.read("dump.xml") @@ -51,26 +56,33 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation): except requests.HTTPError: activity = Activity( activity_type="automation", - text=(f"[{self.short_name}] 🚨 Unable to download dump {latest_rev} due to HTTP error {r.status_code}. " - "The automation task has not been disabled and will attempt to download the next dump when the " - "latest dump revision is incremented at the server.")) + text=( + f"[{self.short_name}] 🚨 Unable to download dump {latest_rev} due to HTTP error {r.status_code}. " + "The automation task has not been disabled and will attempt to download the next dump when the " + "latest dump revision is incremented at the server." + ), + ) activity.notify() db.session.add(activity) db.session.commit() except BadZipFile: activity = Activity( activity_type="automation", - text=(f"[{self.short_name}] 🚨 Unable to extract zip file from dump {latest_rev}. There was an error " - "related to the format of the zip file. " - "The automation task has not been disabled and will attempt to download the next dump when the " - "latest dump revision is incremented at the server.")) + text=( + f"[{self.short_name}] 🚨 Unable to extract zip file from dump {latest_rev}. There was an error " + "related to the format of the zip file. " + "The automation task has not been disabled and will attempt to download the next dump when the " + "latest dump revision is incremented at the server." + ), + ) activity.notify() db.session.add(activity) db.session.commit() def fetch(self) -> None: state: Optional[TerraformState] = TerraformState.query.filter( - TerraformState.key == "block_roskomsvoboda").first() + TerraformState.key == "block_roskomsvoboda" + ).first() if state is None: state = TerraformState() state.key = "block_roskomsvoboda" @@ -80,8 +92,14 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation): latest_metadata = json.loads(state.state) # This endpoint routinely has an expired certificate, and it's more useful that we are consuming the # data than that we are verifying the certificate. - latest_rev = requests.get("https://dumps.rublacklist.net/fetch/latest", timeout=30, verify=False).text.strip() # nosec: B501 - logging.debug("Latest revision is %s, already got %s", latest_rev, latest_metadata["dump_rev"]) + latest_rev = requests.get( + "https://dumps.rublacklist.net/fetch/latest", timeout=30, verify=False + ).text.strip() # nosec: B501 + logging.debug( + "Latest revision is %s, already got %s", + latest_rev, + latest_metadata["dump_rev"], + ) if latest_rev != latest_metadata["dump_rev"]: state.state = json.dumps({"dump_rev": latest_rev}) db.session.commit() @@ -94,18 +112,24 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation): logging.debug("No new data to parse") return try: - for _event, element in lxml.etree.iterparse(BytesIO(self._data), - resolve_entities=False): + for _event, element in lxml.etree.iterparse( + BytesIO(self._data), resolve_entities=False + ): if element.tag == "domain": - self.patterns["roskomsvoboda"].append("https://" + element.text.strip()) + self.patterns["roskomsvoboda"].append( + "https://" + element.text.strip() + ) except XMLSyntaxError: activity = Activity( activity_type="automation", - text=(f"[{self.short_name}] 🚨 Unable to parse XML file from dump. There was an error " - "related to the format of the XML file within the zip file. Interestingly we were able to " - "extract the file from the zip file fine. " - "The automation task has not been disabled and will attempt to download the next dump when the " - "latest dump revision is incremented at the server.")) + text=( + f"[{self.short_name}] 🚨 Unable to parse XML file from dump. There was an error " + "related to the format of the XML file within the zip file. Interestingly we were able to " + "extract the file from the zip file fine. " + "The automation task has not been disabled and will attempt to download the next dump when the " + "latest dump revision is incremented at the server." + ), + ) activity.notify() db.session.add(activity) db.session.commit() diff --git a/app/terraform/bridge/__init__.py b/app/terraform/bridge/__init__.py index d432dd1..142b050 100644 --- a/app/terraform/bridge/__init__.py +++ b/app/terraform/bridge/__init__.py @@ -16,20 +16,32 @@ BridgeResourceRow = Row[Tuple[AbstractResource, BridgeConf, CloudAccount]] def active_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]: - stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where( - CloudAccount.provider == provider, - Bridge.destroyed.is_(None), + stmt = ( + select(Bridge, BridgeConf, CloudAccount) + .join_from(Bridge, BridgeConf) + .join_from(Bridge, CloudAccount) + .where( + CloudAccount.provider == provider, + Bridge.destroyed.is_(None), + ) ) bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all() return bridges -def recently_destroyed_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]: +def recently_destroyed_bridges_by_provider( + provider: CloudProvider, +) -> Sequence[BridgeResourceRow]: cutoff = datetime.now(tz=timezone.utc) - timedelta(hours=72) - stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where( - CloudAccount.provider == provider, - Bridge.destroyed.is_not(None), - Bridge.destroyed >= cutoff, + stmt = ( + select(Bridge, BridgeConf, CloudAccount) + .join_from(Bridge, BridgeConf) + .join_from(Bridge, CloudAccount) + .where( + CloudAccount.provider == provider, + Bridge.destroyed.is_not(None), + Bridge.destroyed >= cutoff, + ) ) bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all() return bridges @@ -60,35 +72,38 @@ class BridgeAutomation(TerraformAutomation): self.template, active_resources=active_bridges_by_provider(self.provider), destroyed_resources=recently_destroyed_bridges_by_provider(self.provider), - global_namespace=app.config['GLOBAL_NAMESPACE'], - terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), + global_namespace=app.config["GLOBAL_NAMESPACE"], + terraform_modules_path=os.path.join( + *list(os.path.split(app.root_path))[:-1], "terraform-modules" + ), backend_config=f"""backend "http" {{ lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" }}""", - **{ - k: app.config[k.upper()] - for k in self.template_parameters - } + **{k: app.config[k.upper()] for k in self.template_parameters}, ) def tf_posthook(self, *, prehook_result: Any = None) -> None: outputs = self.tf_output() for output in outputs: - if output.startswith('bridge_hashed_fingerprint_'): - parts = outputs[output]['value'].split(" ") + if output.startswith("bridge_hashed_fingerprint_"): + parts = outputs[output]["value"].split(" ") if len(parts) < 2: continue - bridge = Bridge.query.filter(Bridge.id == output[len('bridge_hashed_fingerprint_'):]).first() + bridge = Bridge.query.filter( + Bridge.id == output[len("bridge_hashed_fingerprint_") :] + ).first() bridge.nickname = parts[0] bridge.hashed_fingerprint = parts[1] bridge.terraform_updated = datetime.now(tz=timezone.utc) - if output.startswith('bridge_bridgeline_'): - parts = outputs[output]['value'].split(" ") + if output.startswith("bridge_bridgeline_"): + parts = outputs[output]["value"].split(" ") if len(parts) < 4: continue - bridge = Bridge.query.filter(Bridge.id == output[len('bridge_bridgeline_'):]).first() + bridge = Bridge.query.filter( + Bridge.id == output[len("bridge_bridgeline_") :] + ).first() del parts[3] bridge.bridgeline = " ".join(parts) bridge.terraform_updated = datetime.now(tz=timezone.utc) diff --git a/app/terraform/bridge/gandi.py b/app/terraform/bridge/gandi.py index 4a5eab1..d466456 100644 --- a/app/terraform/bridge/gandi.py +++ b/app/terraform/bridge/gandi.py @@ -7,10 +7,7 @@ class BridgeGandiAutomation(BridgeAutomation): description = "Deploy Tor bridges on GandiCloud VPS" provider = CloudProvider.GANDI - template_parameters = [ - "ssh_public_key_path", - "ssh_private_key_path" - ] + template_parameters = ["ssh_public_key_path", "ssh_private_key_path"] template = """ terraform { diff --git a/app/terraform/bridge/hcloud.py b/app/terraform/bridge/hcloud.py index b52bf5c..cc3b15f 100644 --- a/app/terraform/bridge/hcloud.py +++ b/app/terraform/bridge/hcloud.py @@ -7,10 +7,7 @@ class BridgeHcloudAutomation(BridgeAutomation): description = "Deploy Tor bridges on Hetzner Cloud" provider = CloudProvider.HCLOUD - template_parameters = [ - "ssh_private_key_path", - "ssh_public_key_path" - ] + template_parameters = ["ssh_private_key_path", "ssh_public_key_path"] template = """ terraform { diff --git a/app/terraform/bridge/meta.py b/app/terraform/bridge/meta.py index be46ef1..902270d 100644 --- a/app/terraform/bridge/meta.py +++ b/app/terraform/bridge/meta.py @@ -25,10 +25,17 @@ def active_bridges_in_account(account: CloudAccount) -> List[Bridge]: return bridges -def create_bridges_in_account(bridgeconf: BridgeConf, account: CloudAccount, count: int) -> int: +def create_bridges_in_account( + bridgeconf: BridgeConf, account: CloudAccount, count: int +) -> int: created = 0 - while created < count and len(active_bridges_in_account(account)) < account.max_instances: - logging.debug("Creating bridge for configuration %s in account %s", bridgeconf.id, account) + while ( + created < count + and len(active_bridges_in_account(account)) < account.max_instances + ): + logging.debug( + "Creating bridge for configuration %s in account %s", bridgeconf.id, account + ) bridge = Bridge() bridge.pool_id = bridgeconf.pool.id bridge.conf_id = bridgeconf.id @@ -45,16 +52,18 @@ def create_bridges_by_cost(bridgeconf: BridgeConf, count: int) -> int: """ Creates bridge resources for the given bridge configuration using the cheapest available provider. """ - logging.debug("Creating %s bridges by cost for configuration %s", count, bridgeconf.id) + logging.debug( + "Creating %s bridges by cost for configuration %s", count, bridgeconf.id + ) created = 0 for provider in BRIDGE_PROVIDERS: if created >= count: break logging.info("Creating bridges in %s accounts", provider.description) for account in CloudAccount.query.filter( - CloudAccount.destroyed.is_(None), - CloudAccount.enabled.is_(True), - CloudAccount.provider == provider, + CloudAccount.destroyed.is_(None), + CloudAccount.enabled.is_(True), + CloudAccount.provider == provider, ).all(): logging.info("Creating bridges in %s", account) created += create_bridges_in_account(bridgeconf, account, count - created) @@ -78,7 +87,9 @@ def create_bridges_by_random(bridgeconf: BridgeConf, count: int) -> int: """ Creates bridge resources for the given bridge configuration using random providers. """ - logging.debug("Creating %s bridges by random for configuration %s", count, bridgeconf.id) + logging.debug( + "Creating %s bridges by random for configuration %s", count, bridgeconf.id + ) created = 0 while candidate_accounts := _accounts_with_room(): # Not security-critical random number generation @@ -97,16 +108,24 @@ def create_bridges(bridgeconf: BridgeConf, count: int) -> int: return create_bridges_by_random(bridgeconf, count) -def deprecate_bridges(bridgeconf: BridgeConf, count: int, reason: str = "redundant") -> int: - logging.debug("Deprecating %s bridges (%s) for configuration %s", count, reason, bridgeconf.id) +def deprecate_bridges( + bridgeconf: BridgeConf, count: int, reason: str = "redundant" +) -> int: + logging.debug( + "Deprecating %s bridges (%s) for configuration %s", count, reason, bridgeconf.id + ) deprecated = 0 - active_conf_bridges = iter(Bridge.query.filter( - Bridge.conf_id == bridgeconf.id, - Bridge.deprecated.is_(None), - Bridge.destroyed.is_(None), - ).all()) + active_conf_bridges = iter( + Bridge.query.filter( + Bridge.conf_id == bridgeconf.id, + Bridge.deprecated.is_(None), + Bridge.destroyed.is_(None), + ).all() + ) while deprecated < count: - logging.debug("Deprecating bridge %s for configuration %s", deprecated + 1, bridgeconf.id) + logging.debug( + "Deprecating bridge %s for configuration %s", deprecated + 1, bridgeconf.id + ) bridge = next(active_conf_bridges) logging.debug("Bridge %r", bridge) bridge.deprecate(reason=reason) @@ -129,7 +148,9 @@ class BridgeMetaAutomation(BaseAutomation): for bridge in deprecated_bridges: if bridge.deprecated is None: continue # Possible due to SQLAlchemy lazy loading - cutoff = datetime.now(tz=timezone.utc) - timedelta(hours=bridge.conf.expiry_hours) + cutoff = datetime.now(tz=timezone.utc) - timedelta( + hours=bridge.conf.expiry_hours + ) if bridge.deprecated < cutoff: logging.debug("Destroying expired bridge") bridge.destroy() @@ -146,7 +167,9 @@ class BridgeMetaAutomation(BaseAutomation): activate_bridgeconfs = BridgeConf.query.filter( BridgeConf.destroyed.is_(None), ).all() - logging.debug("Found %s active bridge configurations", len(activate_bridgeconfs)) + logging.debug( + "Found %s active bridge configurations", len(activate_bridgeconfs) + ) for bridgeconf in activate_bridgeconfs: active_conf_bridges = Bridge.query.filter( Bridge.conf_id == bridgeconf.id, @@ -157,16 +180,18 @@ class BridgeMetaAutomation(BaseAutomation): Bridge.conf_id == bridgeconf.id, Bridge.destroyed.is_(None), ).all() - logging.debug("Generating new bridges for %s (active: %s, total: %s, target: %s, max: %s)", - bridgeconf.id, - len(active_conf_bridges), - len(total_conf_bridges), - bridgeconf.target_number, - bridgeconf.max_number - ) + logging.debug( + "Generating new bridges for %s (active: %s, total: %s, target: %s, max: %s)", + bridgeconf.id, + len(active_conf_bridges), + len(total_conf_bridges), + bridgeconf.target_number, + bridgeconf.max_number, + ) missing = min( bridgeconf.target_number - len(active_conf_bridges), - bridgeconf.max_number - len(total_conf_bridges)) + bridgeconf.max_number - len(total_conf_bridges), + ) if missing > 0: create_bridges(bridgeconf, missing) elif missing < 0: diff --git a/app/terraform/bridge/ovh.py b/app/terraform/bridge/ovh.py index 55f16c3..b98d5a9 100644 --- a/app/terraform/bridge/ovh.py +++ b/app/terraform/bridge/ovh.py @@ -7,10 +7,7 @@ class BridgeOvhAutomation(BridgeAutomation): description = "Deploy Tor bridges on OVH Public Cloud" provider = CloudProvider.OVH - template_parameters = [ - "ssh_public_key_path", - "ssh_private_key_path" - ] + template_parameters = ["ssh_public_key_path", "ssh_private_key_path"] template = """ terraform { diff --git a/app/terraform/eotk/aws.py b/app/terraform/eotk/aws.py index 7d4ab1a..e467782 100644 --- a/app/terraform/eotk/aws.py +++ b/app/terraform/eotk/aws.py @@ -11,14 +11,12 @@ from app.terraform.eotk import eotk_configuration from app.terraform.terraform import TerraformAutomation -def update_eotk_instance(group_id: int, - region: str, - instance_id: str) -> None: +def update_eotk_instance(group_id: int, region: str, instance_id: str) -> None: instance = Eotk.query.filter( Eotk.group_id == group_id, Eotk.region == region, Eotk.provider == "aws", - Eotk.destroyed.is_(None) + Eotk.destroyed.is_(None), ).first() if instance is None: instance = Eotk() @@ -35,10 +33,7 @@ class EotkAWSAutomation(TerraformAutomation): short_name = "eotk_aws" description = "Deploy EOTK instances to AWS" - template_parameters = [ - "aws_access_key", - "aws_secret_key" - ] + template_parameters = ["aws_access_key", "aws_secret_key"] template = """ terraform { @@ -81,32 +76,41 @@ class EotkAWSAutomation(TerraformAutomation): self.tf_write( self.template, groups=Group.query.filter( - Group.eotk.is_(True), - Group.destroyed.is_(None) + Group.eotk.is_(True), Group.destroyed.is_(None) ).all(), - global_namespace=app.config['GLOBAL_NAMESPACE'], - terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), + global_namespace=app.config["GLOBAL_NAMESPACE"], + terraform_modules_path=os.path.join( + *list(os.path.split(app.root_path))[:-1], "terraform-modules" + ), backend_config=f"""backend "http" {{ lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" }}""", - **{ - k: app.config[k.upper()] - for k in self.template_parameters - } + **{k: app.config[k.upper()] for k in self.template_parameters}, ) - for group in Group.query.filter( - Group.eotk.is_(True), - Group.destroyed.is_(None) - ).order_by(Group.id).all(): - with DeterministicZip(os.path.join(self.working_dir, f"{group.id}.zip")) as dzip: - dzip.add_file("sites.conf", eotk_configuration(group).encode('utf-8')) + for group in ( + Group.query.filter(Group.eotk.is_(True), Group.destroyed.is_(None)) + .order_by(Group.id) + .all() + ): + with DeterministicZip( + os.path.join(self.working_dir, f"{group.id}.zip") + ) as dzip: + dzip.add_file("sites.conf", eotk_configuration(group).encode("utf-8")) for onion in sorted(group.onions, key=lambda o: o.onion_name): - dzip.add_file(f"{onion.onion_name}.v3pub.key", onion.onion_public_key) - dzip.add_file(f"{onion.onion_name}.v3sec.key", onion.onion_private_key) - dzip.add_file(f"{onion.onion_name[:20]}-v3.cert", onion.tls_public_key) - dzip.add_file(f"{onion.onion_name[:20]}-v3.pem", onion.tls_private_key) + dzip.add_file( + f"{onion.onion_name}.v3pub.key", onion.onion_public_key + ) + dzip.add_file( + f"{onion.onion_name}.v3sec.key", onion.onion_private_key + ) + dzip.add_file( + f"{onion.onion_name[:20]}-v3.cert", onion.tls_public_key + ) + dzip.add_file( + f"{onion.onion_name[:20]}-v3.pem", onion.tls_private_key + ) def tf_posthook(self, *, prehook_result: Any = None) -> None: for e in Eotk.query.all(): @@ -115,9 +119,9 @@ class EotkAWSAutomation(TerraformAutomation): for output in outputs: if output.startswith("eotk_instances_"): try: - group_id = int(output[len("eotk_instance_") + 1:]) - for az in outputs[output]['value']: - update_eotk_instance(group_id, az, outputs[output]['value'][az]) + group_id = int(output[len("eotk_instance_") + 1 :]) + for az in outputs[output]["value"]: + update_eotk_instance(group_id, az, outputs[output]["value"][az]) except ValueError: pass db.session.commit() diff --git a/app/terraform/list/__init__.py b/app/terraform/list/__init__.py index 50c500f..0683944 100644 --- a/app/terraform/list/__init__.py +++ b/app/terraform/list/__init__.py @@ -55,26 +55,36 @@ class ListAutomation(TerraformAutomation): MirrorList.destroyed.is_(None), MirrorList.provider == self.provider, ).all(), - global_namespace=app.config['GLOBAL_NAMESPACE'], - terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), + global_namespace=app.config["GLOBAL_NAMESPACE"], + terraform_modules_path=os.path.join( + *list(os.path.split(app.root_path))[:-1], "terraform-modules" + ), backend_config=f"""backend "http" {{ lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" }}""", - **{ - k: app.config[k.upper()] - for k in self.template_parameters - } + **{k: app.config[k.upper()] for k in self.template_parameters}, ) for pool in Pool.query.filter(Pool.destroyed.is_(None)).all(): for key, formatter in lists.items(): formatted_pool = formatter(pool) for obfuscate in [True, False]: - with open(os.path.join( - self.working_dir, f"{key}.{pool.pool_name}{'.jsno' if obfuscate else '.json'}"), - 'w', encoding="utf-8") as out: + with open( + os.path.join( + self.working_dir, + f"{key}.{pool.pool_name}{'.jsno' if obfuscate else '.json'}", + ), + "w", + encoding="utf-8", + ) as out: out.write(json_encode(formatted_pool, obfuscate)) - with open(os.path.join(self.working_dir, f"{key}.{pool.pool_name}{'.jso' if obfuscate else '.js'}"), - 'w', encoding="utf-8") as out: + with open( + os.path.join( + self.working_dir, + f"{key}.{pool.pool_name}{'.jso' if obfuscate else '.js'}", + ), + "w", + encoding="utf-8", + ) as out: out.write(javascript_encode(formatted_pool, obfuscate)) diff --git a/app/terraform/list/github.py b/app/terraform/list/github.py index c83179f..1c3623e 100644 --- a/app/terraform/list/github.py +++ b/app/terraform/list/github.py @@ -11,9 +11,7 @@ class ListGithubAutomation(ListAutomation): # TODO: file an issue in the github about this, GitLab had a similar issue but fixed it parallelism = 1 - template_parameters = [ - "github_api_key" - ] + template_parameters = ["github_api_key"] template = """ terraform { diff --git a/app/terraform/list/gitlab.py b/app/terraform/list/gitlab.py index 0fd7633..757d88f 100644 --- a/app/terraform/list/gitlab.py +++ b/app/terraform/list/gitlab.py @@ -15,7 +15,7 @@ class ListGitlabAutomation(ListAutomation): "gitlab_token", "gitlab_author_email", "gitlab_author_name", - "gitlab_commit_message" + "gitlab_commit_message", ] template = """ @@ -56,5 +56,5 @@ class ListGitlabAutomation(ListAutomation): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - if 'GITLAB_URL' in current_app.config: + if "GITLAB_URL" in current_app.config: self.template_parameters.append("gitlab_url") diff --git a/app/terraform/list/s3.py b/app/terraform/list/s3.py index 639c117..bbe905d 100644 --- a/app/terraform/list/s3.py +++ b/app/terraform/list/s3.py @@ -6,10 +6,7 @@ class ListS3Automation(ListAutomation): description = "Update mirror lists in AWS S3 buckets" provider = "s3" - template_parameters = [ - "aws_access_key", - "aws_secret_key" - ] + template_parameters = ["aws_access_key", "aws_secret_key"] template = """ terraform { diff --git a/app/terraform/proxy/__init__.py b/app/terraform/proxy/__init__.py index 34c66d8..af15f4a 100644 --- a/app/terraform/proxy/__init__.py +++ b/app/terraform/proxy/__init__.py @@ -15,15 +15,14 @@ from app.models.mirrors import Origin, Proxy, SmartProxy from app.terraform.terraform import TerraformAutomation -def update_smart_proxy_instance(group_id: int, - provider: str, - region: str, - instance_id: str) -> None: +def update_smart_proxy_instance( + group_id: int, provider: str, region: str, instance_id: str +) -> None: instance = SmartProxy.query.filter( SmartProxy.group_id == group_id, SmartProxy.region == region, SmartProxy.provider == provider, - SmartProxy.destroyed.is_(None) + SmartProxy.destroyed.is_(None), ).first() if instance is None: instance = SmartProxy() @@ -93,16 +92,21 @@ class ProxyAutomation(TerraformAutomation): self.template, groups=groups, proxies=Proxy.query.filter( - Proxy.provider == self.provider, Proxy.destroyed.is_(None)).all(), + Proxy.provider == self.provider, Proxy.destroyed.is_(None) + ).all(), subgroups=self.get_subgroups(), - global_namespace=app.config['GLOBAL_NAMESPACE'], bypass_token=app.config['BYPASS_TOKEN'], - terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), + global_namespace=app.config["GLOBAL_NAMESPACE"], + bypass_token=app.config["BYPASS_TOKEN"], + terraform_modules_path=os.path.join( + *list(os.path.split(app.root_path))[:-1], "terraform-modules" + ), backend_config=f"""backend "http" {{ lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" }}""", - **{k: app.config[k.upper()] for k in self.template_parameters}) + **{k: app.config[k.upper()] for k in self.template_parameters}, + ) if self.smart_proxies: for group in groups: self.sp_config(group) @@ -111,9 +115,11 @@ class ProxyAutomation(TerraformAutomation): group_origins: List[Origin] = Origin.query.filter( Origin.group_id == group.id, Origin.destroyed.is_(None), - Origin.smart.is_(True) + Origin.smart.is_(True), ).all() - self.tmpl_write(f"smart_proxy.{group.id}.conf", """ + self.tmpl_write( + f"smart_proxy.{group.id}.conf", + """ {% for origin in origins %} server { listen 443 ssl; @@ -173,23 +179,28 @@ class ProxyAutomation(TerraformAutomation): } {% endfor %} """, - provider=self.provider, - origins=group_origins, - smart_zone=app.config['SMART_ZONE']) + provider=self.provider, + origins=group_origins, + smart_zone=app.config["SMART_ZONE"], + ) @classmethod def get_subgroups(cls) -> Dict[int, Dict[int, int]]: conn = db.engine.connect() - stmt = text(""" + stmt = text( + """ SELECT origin.group_id, proxy.psg, COUNT(proxy.id) FROM proxy, origin WHERE proxy.origin_id = origin.id AND proxy.destroyed IS NULL AND proxy.provider = :provider GROUP BY origin.group_id, proxy.psg; - """) + """ + ) stmt = stmt.bindparams(provider=cls.provider) result = conn.execute(stmt).all() - subgroups: Dict[int, Dict[int, int]] = defaultdict(lambda: defaultdict(lambda: 0)) + subgroups: Dict[int, Dict[int, int]] = defaultdict( + lambda: defaultdict(lambda: 0) + ) for row in result: subgroups[row[0]][row[1]] = row[2] return subgroups diff --git a/app/terraform/proxy/azure_cdn.py b/app/terraform/proxy/azure_cdn.py index b580aa5..916487e 100644 --- a/app/terraform/proxy/azure_cdn.py +++ b/app/terraform/proxy/azure_cdn.py @@ -21,7 +21,7 @@ class ProxyAzureCdnAutomation(ProxyAutomation): "azure_client_secret", "azure_subscription_id", "azure_tenant_id", - "smart_zone" + "smart_zone", ] template = """ @@ -162,8 +162,7 @@ class ProxyAzureCdnAutomation(ProxyAutomation): def import_state(self, state: Optional[Any]) -> None: proxies = Proxy.query.filter( - Proxy.provider == self.provider, - Proxy.destroyed.is_(None) + Proxy.provider == self.provider, Proxy.destroyed.is_(None) ).all() for proxy in proxies: proxy.url = f"https://{proxy.slug}.azureedge.net" diff --git a/app/terraform/proxy/cloudfront.py b/app/terraform/proxy/cloudfront.py index bba8172..436e54d 100644 --- a/app/terraform/proxy/cloudfront.py +++ b/app/terraform/proxy/cloudfront.py @@ -17,7 +17,7 @@ class ProxyCloudfrontAutomation(ProxyAutomation): "admin_email", "aws_access_key", "aws_secret_key", - "smart_zone" + "smart_zone", ] template = """ @@ -111,26 +111,35 @@ class ProxyCloudfrontAutomation(ProxyAutomation): def import_state(self, state: Any) -> None: if not isinstance(state, dict): raise RuntimeError("The Terraform state object returned was not a dict.") - if "child_modules" not in state['values']['root_module']: + if "child_modules" not in state["values"]["root_module"]: # There are no CloudFront proxies deployed to import state for return # CloudFront distributions (proxies) - for mod in state['values']['root_module']['child_modules']: - if mod['address'].startswith('module.cloudfront_'): - for res in mod['resources']: - if res['address'].endswith('aws_cloudfront_distribution.this'): - proxy = Proxy.query.filter(Proxy.id == mod['address'][len('module.cloudfront_'):]).first() - proxy.url = "https://" + res['values']['domain_name'] - proxy.slug = res['values']['id'] + for mod in state["values"]["root_module"]["child_modules"]: + if mod["address"].startswith("module.cloudfront_"): + for res in mod["resources"]: + if res["address"].endswith("aws_cloudfront_distribution.this"): + proxy = Proxy.query.filter( + Proxy.id == mod["address"][len("module.cloudfront_") :] + ).first() + proxy.url = "https://" + res["values"]["domain_name"] + proxy.slug = res["values"]["id"] proxy.terraform_updated = datetime.now(tz=timezone.utc) break # EC2 instances (smart proxies) for g in state["values"]["root_module"]["child_modules"]: if g["address"].startswith("module.smart_proxy_"): - group_id = int(g["address"][len("module.smart_proxy_"):]) + group_id = int(g["address"][len("module.smart_proxy_") :]) for s in g["child_modules"]: if s["address"].endswith(".module.instance"): for x in s["resources"]: - if x["address"].endswith(".module.instance.aws_instance.default[0]"): - update_smart_proxy_instance(group_id, self.provider, "us-east-2a", x['values']['id']) + if x["address"].endswith( + ".module.instance.aws_instance.default[0]" + ): + update_smart_proxy_instance( + group_id, + self.provider, + "us-east-2a", + x["values"]["id"], + ) db.session.commit() diff --git a/app/terraform/proxy/fastly.py b/app/terraform/proxy/fastly.py index 89f83bd..d923e64 100644 --- a/app/terraform/proxy/fastly.py +++ b/app/terraform/proxy/fastly.py @@ -14,11 +14,7 @@ class ProxyFastlyAutomation(ProxyAutomation): subgroup_members_max = 20 cloud_name = "fastly" - template_parameters = [ - "aws_access_key", - "aws_secret_key", - "fastly_api_key" - ] + template_parameters = ["aws_access_key", "aws_secret_key", "fastly_api_key"] template = """ terraform { @@ -125,13 +121,14 @@ class ProxyFastlyAutomation(ProxyAutomation): Constructor method. """ # Requires Flask application context to read configuration - self.subgroup_members_max = min(current_app.config.get("FASTLY_MAX_BACKENDS", 5), 20) + self.subgroup_members_max = min( + current_app.config.get("FASTLY_MAX_BACKENDS", 5), 20 + ) super().__init__(*args, **kwargs) def import_state(self, state: Optional[Any]) -> None: proxies = Proxy.query.filter( - Proxy.provider == self.provider, - Proxy.destroyed.is_(None) + Proxy.provider == self.provider, Proxy.destroyed.is_(None) ).all() for proxy in proxies: proxy.url = f"https://{proxy.slug}.global.ssl.fastly.net" diff --git a/app/terraform/proxy/meta.py b/app/terraform/proxy/meta.py index 55c8f36..62a79d9 100644 --- a/app/terraform/proxy/meta.py +++ b/app/terraform/proxy/meta.py @@ -18,12 +18,16 @@ from app.terraform.proxy.azure_cdn import ProxyAzureCdnAutomation from app.terraform.proxy.cloudfront import ProxyCloudfrontAutomation from app.terraform.proxy.fastly import ProxyFastlyAutomation -PROXY_PROVIDERS: Dict[str, Type[ProxyAutomation]] = {p.provider: p for p in [ # type: ignore[attr-defined] - # In order of preference - ProxyCloudfrontAutomation, - ProxyFastlyAutomation, - ProxyAzureCdnAutomation -] if p.enabled} # type: ignore[attr-defined] +PROXY_PROVIDERS: Dict[str, Type[ProxyAutomation]] = { + p.provider: p # type: ignore[attr-defined] + for p in [ + # In order of preference + ProxyCloudfrontAutomation, + ProxyFastlyAutomation, + ProxyAzureCdnAutomation, + ] + if p.enabled # type: ignore[attr-defined] +} SubgroupCount = OrderedDictT[str, OrderedDictT[int, OrderedDictT[int, int]]] @@ -61,8 +65,9 @@ def random_slug(origin_domain_name: str) -> str: "exampasdfghjkl" """ # The random slug doesn't need to be cryptographically secure, hence the use of `# nosec` - return tldextract.extract(origin_domain_name).domain[:5] + ''.join( - random.choices(string.ascii_lowercase, k=12)) # nosec + return tldextract.extract(origin_domain_name).domain[:5] + "".join( + random.choices(string.ascii_lowercase, k=12) # nosec: B311 + ) def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupCount: @@ -95,8 +100,13 @@ def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupC return subgroup_count -def next_subgroup(subgroup_count: SubgroupCount, provider: str, group_id: int, max_subgroup_count: int, - max_subgroup_members: int) -> Optional[int]: +def next_subgroup( + subgroup_count: SubgroupCount, + provider: str, + group_id: int, + max_subgroup_count: int, + max_subgroup_members: int, +) -> Optional[int]: """ Find the first available subgroup with less than the specified maximum count in the specified provider and group. If the last subgroup in the group is full, return the next subgroup number as long as it doesn't exceed @@ -137,27 +147,36 @@ def auto_deprecate_proxies() -> None: - The "max_age_reached" reason means the proxy has been in use for longer than the maximum allowed period. The maximum age cutoff is randomly set to a time between 24 and 48 hours. """ - origin_destroyed_proxies = (db.session.query(Proxy) - .join(Origin, Proxy.origin_id == Origin.id) - .filter(Proxy.destroyed.is_(None), - Proxy.deprecated.is_(None), - Origin.destroyed.is_not(None)) - .all()) + origin_destroyed_proxies = ( + db.session.query(Proxy) + .join(Origin, Proxy.origin_id == Origin.id) + .filter( + Proxy.destroyed.is_(None), + Proxy.deprecated.is_(None), + Origin.destroyed.is_not(None), + ) + .all() + ) logging.debug("Origin destroyed: %s", origin_destroyed_proxies) for proxy in origin_destroyed_proxies: proxy.deprecate(reason="origin_destroyed") - max_age_proxies = (db.session.query(Proxy) - .join(Origin, Proxy.origin_id == Origin.id) - .filter(Proxy.destroyed.is_(None), - Proxy.deprecated.is_(None), - Proxy.pool_id != -1, # do not rotate hotspare proxies - Origin.assets, - Origin.auto_rotation) - .all()) + max_age_proxies = ( + db.session.query(Proxy) + .join(Origin, Proxy.origin_id == Origin.id) + .filter( + Proxy.destroyed.is_(None), + Proxy.deprecated.is_(None), + Proxy.pool_id != -1, # do not rotate hotspare proxies + Origin.assets, + Origin.auto_rotation, + ) + .all() + ) logging.debug("Max age: %s", max_age_proxies) for proxy in max_age_proxies: max_age_cutoff = datetime.now(timezone.utc) - timedelta( - days=1, seconds=86400 * random.random()) # nosec: B311 + days=1, seconds=86400 * random.random() # nosec: B311 + ) if proxy.added < max_age_cutoff: proxy.deprecate(reason="max_age_reached") @@ -171,8 +190,7 @@ def destroy_expired_proxies() -> None: """ expiry_cutoff = datetime.now(timezone.utc) - timedelta(days=4) proxies = Proxy.query.filter( - Proxy.destroyed.is_(None), - Proxy.deprecated < expiry_cutoff + Proxy.destroyed.is_(None), Proxy.deprecated < expiry_cutoff ).all() for proxy in proxies: logging.debug("Destroying expired proxy") @@ -244,12 +262,17 @@ class ProxyMetaAutomation(BaseAutomation): if origin.destroyed is not None: continue proxies = [ - x for x in origin.proxies - if x.pool_id == pool.id and x.deprecated is None and x.destroyed is None + x + for x in origin.proxies + if x.pool_id == pool.id + and x.deprecated is None + and x.destroyed is None ] logging.debug("Proxies for group %s: %s", group.group_name, proxies) if not proxies: - logging.debug("Creating new proxy for %s in pool %s", origin, pool) + logging.debug( + "Creating new proxy for %s in pool %s", origin, pool + ) if not promote_hot_spare_proxy(pool.id, origin): # No "hot spare" available self.create_proxy(pool.id, origin) @@ -270,8 +293,13 @@ class ProxyMetaAutomation(BaseAutomation): """ for provider in PROXY_PROVIDERS.values(): logging.debug("Looking at provider %s", provider.provider) - subgroup = next_subgroup(self.subgroup_count, provider.provider, origin.group_id, - provider.subgroup_members_max, provider.subgroup_count_max) + subgroup = next_subgroup( + self.subgroup_count, + provider.provider, + origin.group_id, + provider.subgroup_members_max, + provider.subgroup_count_max, + ) if subgroup is None: continue # Exceeded maximum number of subgroups and last subgroup is full self.increment_subgroup(provider.provider, origin.group_id, subgroup) @@ -317,9 +345,7 @@ class ProxyMetaAutomation(BaseAutomation): If an origin is not destroyed and lacks active proxies (not deprecated and not destroyed), a new 'hot spare' proxy for this origin is created in the reserve pool (with pool_id = -1). """ - origins = Origin.query.filter( - Origin.destroyed.is_(None) - ).all() + origins = Origin.query.filter(Origin.destroyed.is_(None)).all() for origin in origins: if origin.countries: risk_levels = origin.risk_level.items() @@ -328,7 +354,10 @@ class ProxyMetaAutomation(BaseAutomation): if highest_risk_level < 4: for proxy in origin.proxies: if proxy.destroyed is None and proxy.pool_id == -1: - logging.debug("Destroying hot spare proxy for origin %s (low risk)", origin) + logging.debug( + "Destroying hot spare proxy for origin %s (low risk)", + origin, + ) proxy.destroy() continue if origin.destroyed is not None: diff --git a/app/terraform/static/aws.py b/app/terraform/static/aws.py index 3286580..dab4cc2 100644 --- a/app/terraform/static/aws.py +++ b/app/terraform/static/aws.py @@ -15,21 +15,26 @@ from app.terraform.terraform import TerraformAutomation def import_state(state: Any) -> None: if not isinstance(state, dict): raise RuntimeError("The Terraform state object returned was not a dict.") - if "values" not in state or "child_modules" not in state['values']['root_module']: + if "values" not in state or "child_modules" not in state["values"]["root_module"]: # There are no CloudFront origins deployed to import state for return # CloudFront distributions (origins) - for mod in state['values']['root_module']['child_modules']: - if mod['address'].startswith('module.static_'): - static_id = mod['address'][len('module.static_'):] + for mod in state["values"]["root_module"]["child_modules"]: + if mod["address"].startswith("module.static_"): + static_id = mod["address"][len("module.static_") :] logging.debug("Found static module in state: %s", static_id) - for res in mod['resources']: - if res['address'].endswith('aws_cloudfront_distribution.this'): + for res in mod["resources"]: + if res["address"].endswith("aws_cloudfront_distribution.this"): logging.debug("and found related cloudfront distribution") - static = StaticOrigin.query.filter(StaticOrigin.id == static_id).first() - static.origin_domain_name = res['values']['domain_name'] - logging.debug("and found static origin: %s to update with domain name: %s", static.id, - static.origin_domain_name) + static = StaticOrigin.query.filter( + StaticOrigin.id == static_id + ).first() + static.origin_domain_name = res["values"]["domain_name"] + logging.debug( + "and found static origin: %s to update with domain name: %s", + static.id, + static.origin_domain_name, + ) static.terraform_updated = datetime.now(tz=timezone.utc) break db.session.commit() @@ -128,14 +133,18 @@ class StaticAWSAutomation(TerraformAutomation): groups=groups, storage_cloud_accounts=storage_cloud_accounts, source_cloud_accounts=source_cloud_accounts, - global_namespace=current_app.config['GLOBAL_NAMESPACE'], bypass_token=current_app.config['BYPASS_TOKEN'], - terraform_modules_path=os.path.join(*list(os.path.split(current_app.root_path))[:-1], 'terraform-modules'), + global_namespace=current_app.config["GLOBAL_NAMESPACE"], + bypass_token=current_app.config["BYPASS_TOKEN"], + terraform_modules_path=os.path.join( + *list(os.path.split(current_app.root_path))[:-1], "terraform-modules" + ), backend_config=f"""backend "http" {{ lock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}" }}""", - **{k: current_app.config[k.upper()] for k in self.template_parameters}) + **{k: current_app.config[k.upper()] for k in self.template_parameters}, + ) def tf_posthook(self, *, prehook_result: Any = None) -> None: import_state(self.tf_show()) diff --git a/app/terraform/static/meta.py b/app/terraform/static/meta.py index 1efe61f..a0b6a14 100644 --- a/app/terraform/static/meta.py +++ b/app/terraform/static/meta.py @@ -27,7 +27,9 @@ class StaticMetaAutomation(BaseAutomation): if static_origin.origin_domain_name is not None: try: # Check if an Origin with the same domain name already exists - origin = Origin.query.filter_by(domain_name=static_origin.origin_domain_name).one() + origin = Origin.query.filter_by( + domain_name=static_origin.origin_domain_name + ).one() # Keep auto rotation value in sync origin.auto_rotation = static_origin.auto_rotate except NoResultFound: @@ -35,17 +37,21 @@ class StaticMetaAutomation(BaseAutomation): origin = Origin( group_id=static_origin.group_id, description=f"PORTAL !! DO NOT DELETE !! Automatically created web origin for static origin " - f"#{static_origin.id}", + f"#{static_origin.id}", domain_name=static_origin.origin_domain_name, auto_rotation=static_origin.auto_rotate, smart=False, assets=False, ) db.session.add(origin) - logging.debug(f"Created Origin with domain name {origin.domain_name}") + logging.debug( + f"Created Origin with domain name {origin.domain_name}" + ) # Step 2: Remove Origins for StaticOrigins with non-null destroy value - static_origins_with_destroyed = StaticOrigin.query.filter(StaticOrigin.destroyed.isnot(None)).all() + static_origins_with_destroyed = StaticOrigin.query.filter( + StaticOrigin.destroyed.isnot(None) + ).all() for static_origin in static_origins_with_destroyed: try: origin = Origin.query.filter_by( diff --git a/app/terraform/terraform.py b/app/terraform/terraform.py index aabf956..4b082b7 100644 --- a/app/terraform/terraform.py +++ b/app/terraform/terraform.py @@ -51,14 +51,20 @@ class TerraformAutomation(BaseAutomation): prehook_result = self.tf_prehook() # pylint: disable=assignment-from-no-return self.tf_generate() self.tf_init() - returncode, logs = self.tf_apply(self.working_dir, refresh=self.always_refresh or full) + returncode, logs = self.tf_apply( + self.working_dir, refresh=self.always_refresh or full + ) self.tf_posthook(prehook_result=prehook_result) return returncode == 0, logs - def tf_apply(self, working_dir: str, *, - refresh: bool = True, - parallelism: Optional[int] = None, - lock_timeout: int = 15) -> Tuple[int, str]: + def tf_apply( + self, + working_dir: str, + *, + refresh: bool = True, + parallelism: Optional[int] = None, + lock_timeout: int = 15, + ) -> Tuple[int, str]: if not parallelism: parallelism = self.parallelism if not self.working_dir: @@ -67,17 +73,19 @@ class TerraformAutomation(BaseAutomation): # the argument list as an array such that argument injection would be # ineffective. tfcmd = subprocess.run( # nosec - ['terraform', - 'apply', - '-auto-approve', - '-json', - f'-refresh={str(refresh).lower()}', - f'-parallelism={str(parallelism)}', - f'-lock-timeout={str(lock_timeout)}m', - ], + [ + "terraform", + "apply", + "-auto-approve", + "-json", + f"-refresh={str(refresh).lower()}", + f"-parallelism={str(parallelism)}", + f"-lock-timeout={str(lock_timeout)}m", + ], cwd=working_dir, - stdout=subprocess.PIPE) - return tfcmd.returncode, tfcmd.stdout.decode('utf-8') + stdout=subprocess.PIPE, + ) + return tfcmd.returncode, tfcmd.stdout.decode("utf-8") @abstractmethod def tf_generate(self) -> None: @@ -91,41 +99,49 @@ class TerraformAutomation(BaseAutomation): # the argument list as an array such that argument injection would be # ineffective. subprocess.run( # nosec - ['terraform', - 'init', - f'-lock-timeout={str(lock_timeout)}m', - ], - cwd=self.working_dir) + [ + "terraform", + "init", + f"-lock-timeout={str(lock_timeout)}m", + ], + cwd=self.working_dir, + ) def tf_output(self) -> Any: if not self.working_dir: raise RuntimeError("No working directory specified.") # The following subprocess call does not take any user input. tfcmd = subprocess.run( # nosec - ['terraform', 'output', '-json'], + ["terraform", "output", "-json"], cwd=self.working_dir, - stdout=subprocess.PIPE) + stdout=subprocess.PIPE, + ) return json.loads(tfcmd.stdout) - def tf_plan(self, *, - refresh: bool = True, - parallelism: Optional[int] = None, - lock_timeout: int = 15) -> Tuple[int, str]: + def tf_plan( + self, + *, + refresh: bool = True, + parallelism: Optional[int] = None, + lock_timeout: int = 15, + ) -> Tuple[int, str]: if not self.working_dir: raise RuntimeError("No working directory specified.") # The following subprocess call takes external input, but is providing # the argument list as an array such that argument injection would be # ineffective. tfcmd = subprocess.run( # nosec - ['terraform', - 'plan', - '-json', - f'-refresh={str(refresh).lower()}', - f'-parallelism={str(parallelism)}', - f'-lock-timeout={str(lock_timeout)}m', - ], - cwd=self.working_dir) - return tfcmd.returncode, tfcmd.stdout.decode('utf-8') + [ + "terraform", + "plan", + "-json", + f"-refresh={str(refresh).lower()}", + f"-parallelism={str(parallelism)}", + f"-lock-timeout={str(lock_timeout)}m", + ], + cwd=self.working_dir, + ) + return tfcmd.returncode, tfcmd.stdout.decode("utf-8") def tf_posthook(self, *, prehook_result: Any = None) -> None: """ @@ -154,9 +170,8 @@ class TerraformAutomation(BaseAutomation): raise RuntimeError("No working directory specified.") # This subprocess call doesn't take any user input. terraform = subprocess.run( # nosec - ['terraform', 'show', '-json'], - cwd=self.working_dir, - stdout=subprocess.PIPE) + ["terraform", "show", "-json"], cwd=self.working_dir, stdout=subprocess.PIPE + ) return json.loads(terraform.stdout) def tf_write(self, template: str, **kwargs: Any) -> None: diff --git a/app/tfstate.py b/app/tfstate.py index 5a6028f..b9f70d6 100644 --- a/app/tfstate.py +++ b/app/tfstate.py @@ -9,7 +9,7 @@ from app.models.tfstate import TerraformState tfstate = Blueprint("tfstate", __name__) -@tfstate.route("/", methods=['GET']) +@tfstate.route("/", methods=["GET"]) def handle_get(key: str) -> ResponseReturnValue: state = TerraformState.query.filter(TerraformState.key == key).first() if state is None or state.state is None: @@ -17,16 +17,18 @@ def handle_get(key: str) -> ResponseReturnValue: return Response(state.state, content_type="application/json") -@tfstate.route("/", methods=['POST', 'DELETE', 'UNLOCK']) +@tfstate.route("/", methods=["POST", "DELETE", "UNLOCK"]) def handle_update(key: str) -> ResponseReturnValue: state = TerraformState.query.filter(TerraformState.key == key).first() if not state: if request.method in ["DELETE", "UNLOCK"]: return "OK", 200 state = TerraformState(key=key) - if state.lock and not (request.method == "UNLOCK" and request.args.get('ID') is None): + if state.lock and not ( + request.method == "UNLOCK" and request.args.get("ID") is None + ): # force-unlock seems to not give an ID to verify so accept no ID being present - if json.loads(state.lock)['ID'] != request.args.get('ID'): + if json.loads(state.lock)["ID"] != request.args.get("ID"): return Response(state.lock, status=409, content_type="application/json") if request.method == "POST": state.state = json.dumps(request.json) @@ -38,9 +40,11 @@ def handle_update(key: str) -> ResponseReturnValue: return "OK", 200 -@tfstate.route("/", methods=['LOCK']) +@tfstate.route("/", methods=["LOCK"]) def handle_lock(key: str) -> ResponseReturnValue: - state = TerraformState.query.filter(TerraformState.key == key).with_for_update().first() + state = ( + TerraformState.query.filter(TerraformState.key == key).with_for_update().first() + ) if state is None: state = TerraformState(key=key, state="") db.session.add(state) diff --git a/app/util/onion.py b/app/util/onion.py index 3c180aa..8eef970 100644 --- a/app/util/onion.py +++ b/app/util/onion.py @@ -20,8 +20,9 @@ def onion_hostname(onion_public_key: bytes) -> str: return onion.lower() -def decode_onion_keys(onion_private_key_base64: str, onion_public_key_base64: str) -> Tuple[ - Optional[bytes], Optional[bytes], List[Dict[str, str]]]: +def decode_onion_keys( + onion_private_key_base64: str, onion_public_key_base64: str +) -> Tuple[Optional[bytes], Optional[bytes], List[Dict[str, str]]]: try: onion_private_key = base64.b64decode(onion_private_key_base64) onion_public_key = base64.b64decode(onion_public_key_base64) diff --git a/app/util/x509.py b/app/util/x509.py index cd005d4..e0f0650 100644 --- a/app/util/x509.py +++ b/app/util/x509.py @@ -22,14 +22,20 @@ def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]: return certificates -def build_certificate_chain(certificates: list[x509.Certificate]) -> list[x509.Certificate]: +def build_certificate_chain( + certificates: list[x509.Certificate], +) -> list[x509.Certificate]: if len(certificates) == 1: return certificates chain = [] cert_map = {cert.subject.rfc4514_string(): cert for cert in certificates} end_entity = next( - (cert for cert in certificates if cert.subject.rfc4514_string() not in cert_map), - None + ( + cert + for cert in certificates + if cert.subject.rfc4514_string() not in cert_map + ), + None, ) if not end_entity: raise ValueError("Cannot identify the end-entity certificate.") @@ -51,7 +57,9 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool: for i in range(len(chain) - 1): next_public_key = chain[i + 1].public_key() if not (isinstance(next_public_key, RSAPublicKey)): - raise ValueError(f"Certificate using unsupported algorithm: {type(next_public_key)}") + raise ValueError( + f"Certificate using unsupported algorithm: {type(next_public_key)}" + ) hash_algorithm = chain[i].signature_hash_algorithm if hash_algorithm is None: raise ValueError("Certificate missing hash algorithm") @@ -59,23 +67,23 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool: chain[i].signature, chain[i].tbs_certificate_bytes, PKCS1v15(), - hash_algorithm + hash_algorithm, ) end_cert = chain[-1] if not any( - end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates + end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates ): raise ValueError("Certificate chain does not terminate at a trusted root CA.") return True def validate_tls_keys( - tls_private_key_pem: Optional[str], - tls_certificate_pem: Optional[str], - skip_chain_verification: Optional[bool], - skip_name_verification: Optional[bool], - hostname: str + tls_private_key_pem: Optional[str], + tls_certificate_pem: Optional[str], + skip_chain_verification: Optional[bool], + skip_name_verification: Optional[bool], + hostname: str, ) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]: errors = [] san_list = [] @@ -90,31 +98,55 @@ def validate_tls_keys( private_key = serialization.load_pem_private_key( tls_private_key_pem.encode("utf-8"), password=None, - backend=default_backend() + backend=default_backend(), ) if not isinstance(private_key, rsa.RSAPrivateKey): - errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be RSA."}) + errors.append( + { + "Error": "tls_private_key_invalid", + "Message": "Private key must be RSA.", + } + ) if tls_certificate_pem: - certificates = list(load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))) + certificates = list( + load_certificates_from_pem(tls_certificate_pem.encode("utf-8")) + ) if not certificates: - errors.append({"Error": "tls_certificate_invalid", "Message": "No valid certificate found."}) + errors.append( + { + "Error": "tls_certificate_invalid", + "Message": "No valid certificate found.", + } + ) else: chain = build_certificate_chain(certificates) end_entity_cert = chain[0] if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc): - errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."}) + errors.append( + { + "Error": "tls_public_key_expired", + "Message": "TLS public key is expired.", + } + ) if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc): - errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."}) + errors.append( + { + "Error": "tls_public_key_future", + "Message": "TLS public key is not yet valid.", + } + ) if private_key: public_key = end_entity_cert.public_key() if TYPE_CHECKING: assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101 assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101 - assert end_entity_cert.signature_hash_algorithm is not None # nosec: B101 + assert ( + end_entity_cert.signature_hash_algorithm is not None + ) # nosec: B101 try: test_message = b"test" signature = private_key.sign( @@ -130,20 +162,30 @@ def validate_tls_keys( ) except Exception: errors.append( - {"Error": "tls_key_mismatch", "Message": "Private key does not match certificate."}) + { + "Error": "tls_key_mismatch", + "Message": "Private key does not match certificate.", + } + ) if not skip_chain_verification: try: validate_certificate_chain(chain) except ValueError as e: - errors.append({"Error": "certificate_chain_invalid", "Message": str(e)}) + errors.append( + {"Error": "certificate_chain_invalid", "Message": str(e)} + ) if not skip_name_verification: san_list = extract_sans(end_entity_cert) for expected_hostname in [hostname, f"*.{hostname}"]: if expected_hostname not in san_list: errors.append( - {"Error": "hostname_not_in_san", "Message": f"{expected_hostname} not found in SANs."}) + { + "Error": "hostname_not_in_san", + "Message": f"{expected_hostname} not found in SANs.", + } + ) except Exception as e: errors.append({"Error": "tls_validation_error", "Message": str(e)}) @@ -153,7 +195,9 @@ def validate_tls_keys( def extract_sans(cert: x509.Certificate) -> List[str]: try: - san_extension = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + san_extension = cert.extensions.get_extension_for_oid( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ) sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined] return sans except Exception: diff --git a/setup.cfg b/setup.cfg index a1cce8e..692adb8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [flake8] -ignore = E501,W503 +ignore = E203,E501,W503