lint: reformat python code with black

This commit is contained in:
Iain Learmonth 2024-12-06 18:15:47 +00:00
parent 331beb01b4
commit a406a7974b
88 changed files with 2579 additions and 1608 deletions

View file

@ -6,8 +6,7 @@ import yaml
from flask import Flask, redirect, send_from_directory, url_for from flask import Flask, redirect, send_from_directory, url_for
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from prometheus_client import CollectorRegistry, Metric, make_wsgi_app from prometheus_client import CollectorRegistry, Metric, make_wsgi_app
from prometheus_client.metrics_core import (CounterMetricFamily, from prometheus_client.metrics_core import CounterMetricFamily, GaugeMetricFamily
GaugeMetricFamily)
from prometheus_client.registry import REGISTRY, Collector from prometheus_client.registry import REGISTRY, Collector
from prometheus_flask_exporter import PrometheusMetrics from prometheus_flask_exporter import PrometheusMetrics
from sqlalchemy import text from sqlalchemy import text
@ -28,9 +27,9 @@ app.config.from_file("../config.yaml", load=yaml.safe_load)
registry = CollectorRegistry() registry = CollectorRegistry()
metrics = PrometheusMetrics(app, registry=registry) metrics = PrometheusMetrics(app, registry=registry)
app.wsgi_app = DispatcherMiddleware(app.wsgi_app, { # type: ignore[method-assign] app.wsgi_app = DispatcherMiddleware( # type: ignore[method-assign]
'/metrics': make_wsgi_app(registry) app.wsgi_app, {"/metrics": make_wsgi_app(registry)}
}) )
# register default collectors to our new registry # register default collectors to our new registry
collectors = list(REGISTRY._collector_to_names.keys()) collectors = list(REGISTRY._collector_to_names.keys())
@ -54,12 +53,16 @@ def not_migrating() -> bool:
class DefinedProxiesCollector(Collector): class DefinedProxiesCollector(Collector):
def collect(self) -> Iterator[Metric]: def collect(self) -> Iterator[Metric]:
with app.app_context(): with app.app_context():
ok = GaugeMetricFamily("database_collector", ok = GaugeMetricFamily(
"Status of a database collector (0: bad, 1: good)", "database_collector",
labels=["collector"]) "Status of a database collector (0: bad, 1: good)",
labels=["collector"],
)
try: try:
with db.engine.connect() as conn: with db.engine.connect() as conn:
result = conn.execute(text(""" result = conn.execute(
text(
"""
SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name, SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name,
COUNT(proxy.id) FROM proxy, origin, pool, "group" COUNT(proxy.id) FROM proxy, origin, pool, "group"
WHERE proxy.origin_id = origin.id WHERE proxy.origin_id = origin.id
@ -67,13 +70,24 @@ class DefinedProxiesCollector(Collector):
AND proxy.pool_id = pool.id AND proxy.pool_id = pool.id
AND proxy.destroyed IS NULL AND proxy.destroyed IS NULL
GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name; GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name;
""")) """
c = GaugeMetricFamily("defined_proxies", "Number of proxies currently defined for deployment", )
labels=['group_id', 'group_name', 'provider', 'pool_id', )
'pool_name']) c = GaugeMetricFamily(
"defined_proxies",
"Number of proxies currently defined for deployment",
labels=[
"group_id",
"group_name",
"provider",
"pool_id",
"pool_name",
],
)
for row in result: for row in result:
c.add_metric([str(row[0]), row[1], row[2], str(row[3]), row[4]], c.add_metric(
row[5]) [str(row[0]), row[1], row[2], str(row[3]), row[4]], row[5]
)
yield c yield c
ok.add_metric(["defined_proxies"], 1) ok.add_metric(["defined_proxies"], 1)
except SQLAlchemyError: except SQLAlchemyError:
@ -84,12 +98,16 @@ class DefinedProxiesCollector(Collector):
class BlockedProxiesCollector(Collector): class BlockedProxiesCollector(Collector):
def collect(self) -> Iterator[Metric]: def collect(self) -> Iterator[Metric]:
with app.app_context(): with app.app_context():
ok = GaugeMetricFamily("database_collector", ok = GaugeMetricFamily(
"Status of a database collector (0: bad, 1: good)", "database_collector",
labels=["collector"]) "Status of a database collector (0: bad, 1: good)",
labels=["collector"],
)
try: try:
with db.engine.connect() as conn: with db.engine.connect() as conn:
result = conn.execute(text(""" result = conn.execute(
text(
"""
SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name, SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name,
proxy.deprecation_reason, COUNT(proxy.id) FROM proxy, origin, pool, "group" proxy.deprecation_reason, COUNT(proxy.id) FROM proxy, origin, pool, "group"
WHERE proxy.origin_id = origin.id WHERE proxy.origin_id = origin.id
@ -98,14 +116,26 @@ class BlockedProxiesCollector(Collector):
AND proxy.deprecated IS NOT NULL AND proxy.deprecated IS NOT NULL
GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name, GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name,
proxy.deprecation_reason; proxy.deprecation_reason;
""")) """
c = CounterMetricFamily("deprecated_proxies", )
"Number of proxies deprecated", )
labels=['group_id', 'group_name', 'provider', 'pool_id', 'pool_name', c = CounterMetricFamily(
'deprecation_reason']) "deprecated_proxies",
"Number of proxies deprecated",
labels=[
"group_id",
"group_name",
"provider",
"pool_id",
"pool_name",
"deprecation_reason",
],
)
for row in result: for row in result:
c.add_metric([str(row[0]), row[1], row[2], str(row[3]), row[4], row[5]], c.add_metric(
row[6]) [str(row[0]), row[1], row[2], str(row[3]), row[4], row[5]],
row[6],
)
yield c yield c
ok.add_metric(["deprecated_proxies"], 0) ok.add_metric(["deprecated_proxies"], 0)
except SQLAlchemyError: except SQLAlchemyError:
@ -116,24 +146,36 @@ class BlockedProxiesCollector(Collector):
class AutomationCollector(Collector): class AutomationCollector(Collector):
def collect(self) -> Iterator[Metric]: def collect(self) -> Iterator[Metric]:
with app.app_context(): with app.app_context():
ok = GaugeMetricFamily("database_collector", ok = GaugeMetricFamily(
"Status of a database collector (0: bad, 1: good)", "database_collector",
labels=["collector"]) "Status of a database collector (0: bad, 1: good)",
labels=["collector"],
)
try: try:
state = GaugeMetricFamily("automation_state", "The automation state (0: idle, 1: running, 2: error)", state = GaugeMetricFamily(
labels=['automation_name']) "automation_state",
enabled = GaugeMetricFamily("automation_enabled", "The automation state (0: idle, 1: running, 2: error)",
"Whether an automation is enabled (0: disabled, 1: enabled)", labels=["automation_name"],
labels=['automation_name']) )
next_run = GaugeMetricFamily("automation_next_run", "The timestamp of the next run of the automation", enabled = GaugeMetricFamily(
labels=['automation_name']) "automation_enabled",
last_run_start = GaugeMetricFamily("automation_last_run_start", "Whether an automation is enabled (0: disabled, 1: enabled)",
"The timestamp of the last run of the automation ", labels=["automation_name"],
labels=['automation_name']) )
next_run = GaugeMetricFamily(
"automation_next_run",
"The timestamp of the next run of the automation",
labels=["automation_name"],
)
last_run_start = GaugeMetricFamily(
"automation_last_run_start",
"The timestamp of the last run of the automation ",
labels=["automation_name"],
)
automations = Automation.query.all() automations = Automation.query.all()
for automation in automations: for automation in automations:
if automation.short_name in app.config['HIDDEN_AUTOMATIONS']: if automation.short_name in app.config["HIDDEN_AUTOMATIONS"]:
continue continue
if automation.state == AutomationState.IDLE: if automation.state == AutomationState.IDLE:
state.add_metric([automation.short_name], 0) state.add_metric([automation.short_name], 0)
@ -141,13 +183,19 @@ class AutomationCollector(Collector):
state.add_metric([automation.short_name], 1) state.add_metric([automation.short_name], 1)
else: else:
state.add_metric([automation.short_name], 2) state.add_metric([automation.short_name], 2)
enabled.add_metric([automation.short_name], 1 if automation.enabled else 0) enabled.add_metric(
[automation.short_name], 1 if automation.enabled else 0
)
if automation.next_run: if automation.next_run:
next_run.add_metric([automation.short_name], automation.next_run.timestamp()) next_run.add_metric(
[automation.short_name], automation.next_run.timestamp()
)
else: else:
next_run.add_metric([automation.short_name], 0) next_run.add_metric([automation.short_name], 0)
if automation.last_run: if automation.last_run:
last_run_start.add_metric([automation.short_name], automation.last_run.timestamp()) last_run_start.add_metric(
[automation.short_name], automation.last_run.timestamp()
)
else: else:
last_run_start.add_metric([automation.short_name], 0) last_run_start.add_metric([automation.short_name], 0)
yield state yield state
@ -161,31 +209,31 @@ class AutomationCollector(Collector):
# register all custom collectors to registry # register all custom collectors to registry
if not_migrating() and 'DISABLE_METRICS' not in os.environ: if not_migrating() and "DISABLE_METRICS" not in os.environ:
registry.register(DefinedProxiesCollector()) registry.register(DefinedProxiesCollector())
registry.register(BlockedProxiesCollector()) registry.register(BlockedProxiesCollector())
registry.register(AutomationCollector()) registry.register(AutomationCollector())
@app.route('/ui') @app.route("/ui")
def redirect_ui() -> ResponseReturnValue: def redirect_ui() -> ResponseReturnValue:
return redirect("/ui/") return redirect("/ui/")
@app.route('/ui/', defaults={'path': ''}) @app.route("/ui/", defaults={"path": ""})
@app.route('/ui/<path:path>') @app.route("/ui/<path:path>")
def serve_ui(path: str) -> ResponseReturnValue: def serve_ui(path: str) -> ResponseReturnValue:
if path != "" and os.path.exists("app/static/ui/" + path): if path != "" and os.path.exists("app/static/ui/" + path):
return send_from_directory('static/ui', path) return send_from_directory("static/ui", path)
else: else:
return send_from_directory('static/ui', 'index.html') return send_from_directory("static/ui", "index.html")
@app.route('/') @app.route("/")
def index() -> ResponseReturnValue: def index() -> ResponseReturnValue:
# TODO: update to point at new UI when ready # TODO: update to point at new UI when ready
return redirect(url_for("portal.portal_home")) return redirect(url_for("portal.portal_home"))
if __name__ == '__main__': if __name__ == "__main__":
app.run() app.run()

View file

@ -7,18 +7,15 @@ from app.models.alarms import Alarm
def alarms_for(target: BRN) -> List[Alarm]: def alarms_for(target: BRN) -> List[Alarm]:
return list(Alarm.query.filter( return list(Alarm.query.filter(Alarm.target == str(target)).all())
Alarm.target == str(target)
).all())
def _get_alarm(target: BRN, def _get_alarm(
aspect: str, target: BRN, aspect: str, create_if_missing: bool = True
create_if_missing: bool = True) -> Optional[Alarm]: ) -> Optional[Alarm]:
target_str = str(target) target_str = str(target)
alarm: Optional[Alarm] = Alarm.query.filter( alarm: Optional[Alarm] = Alarm.query.filter(
Alarm.aspect == aspect, Alarm.aspect == aspect, Alarm.target == target_str
Alarm.target == target_str
).first() ).first()
if create_if_missing and alarm is None: if create_if_missing and alarm is None:
alarm = Alarm() alarm = Alarm()

View file

@ -5,34 +5,38 @@ from werkzeug.exceptions import HTTPException
from app.api.onion import api_onion from app.api.onion import api_onion
from app.api.web import api_web from app.api.web import api_web
api = Blueprint('api', __name__) api = Blueprint("api", __name__)
api.register_blueprint(api_onion, url_prefix='/onion') api.register_blueprint(api_onion, url_prefix="/onion")
api.register_blueprint(api_web, url_prefix='/web') api.register_blueprint(api_web, url_prefix="/web")
@api.errorhandler(400) @api.errorhandler(400)
def bad_request(error: HTTPException) -> ResponseReturnValue: def bad_request(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Bad Request', 'message': error.description}) response = jsonify({"error": "Bad Request", "message": error.description})
response.status_code = 400 response.status_code = 400
return response return response
@api.errorhandler(401) @api.errorhandler(401)
def unauthorized(error: HTTPException) -> ResponseReturnValue: def unauthorized(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Unauthorized', 'message': error.description}) response = jsonify({"error": "Unauthorized", "message": error.description})
response.status_code = 401 response.status_code = 401
return response return response
@api.errorhandler(404) @api.errorhandler(404)
def not_found(_: HTTPException) -> ResponseReturnValue: def not_found(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'}) response = jsonify(
{"error": "Not found", "message": "Resource could not be found."}
)
response.status_code = 404 response.status_code = 404
return response return response
@api.errorhandler(500) @api.errorhandler(500)
def internal_server_error(_: HTTPException) -> ResponseReturnValue: def internal_server_error(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'}) response = jsonify(
{"error": "Internal Server Error", "message": "An unexpected error occurred."}
)
response.status_code = 500 response.status_code = 500
return response return response

View file

@ -7,31 +7,37 @@ from flask import Blueprint, abort, jsonify, request
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from sqlalchemy import exc from sqlalchemy import exc
from app.api.util import (DOMAIN_NAME_REGEX, MAX_ALLOWED_ITEMS, from app.api.util import (
MAX_DOMAIN_NAME_LENGTH, ListFilter, DOMAIN_NAME_REGEX,
get_single_resource, list_resources, MAX_ALLOWED_ITEMS,
validate_description) MAX_DOMAIN_NAME_LENGTH,
ListFilter,
get_single_resource,
list_resources,
validate_description,
)
from app.extensions import db from app.extensions import db
from app.models.base import Group from app.models.base import Group
from app.models.onions import Onion from app.models.onions import Onion
from app.util.onion import decode_onion_keys, onion_hostname from app.util.onion import decode_onion_keys, onion_hostname
from app.util.x509 import validate_tls_keys from app.util.x509 import validate_tls_keys
api_onion = Blueprint('api_onion', __name__) api_onion = Blueprint("api_onion", __name__)
@api_onion.route('/onion', methods=['GET']) @api_onion.route("/onion", methods=["GET"])
def list_onions() -> ResponseReturnValue: def list_onions() -> ResponseReturnValue:
domain_name_filter = request.args.get('DomainName') domain_name_filter = request.args.get("DomainName")
group_id_filter = request.args.get('GroupId') group_id_filter = request.args.get("GroupId")
filters: List[ListFilter] = [ filters: List[ListFilter] = [(Onion.destroyed.is_(None))]
(Onion.destroyed.is_(None))
]
if domain_name_filter: if domain_name_filter:
if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") abort(
400,
description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.",
)
if not DOMAIN_NAME_REGEX.match(domain_name_filter): if not DOMAIN_NAME_REGEX.match(domain_name_filter):
abort(400, description="DomainName contains invalid characters.") abort(400, description="DomainName contains invalid characters.")
filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%")) filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%"))
@ -46,9 +52,9 @@ def list_onions() -> ResponseReturnValue:
Onion, Onion,
lambda onion: onion.to_dict(), lambda onion: onion.to_dict(),
filters=filters, filters=filters,
resource_name='OnionsList', resource_name="OnionsList",
max_allowed_items=MAX_ALLOWED_ITEMS, max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber', protective_marking="amber",
) )
@ -71,13 +77,26 @@ def create_onion() -> ResponseReturnValue:
abort(400) abort(400)
errors = [] errors = []
for field in ["DomainName", "Description", "OnionPrivateKey", "OnionPublicKey", "GroupId", "TlsPrivateKey", for field in [
"TlsCertificate"]: "DomainName",
"Description",
"OnionPrivateKey",
"OnionPublicKey",
"GroupId",
"TlsPrivateKey",
"TlsCertificate",
]:
if not data.get(field): if not data.get(field):
errors.append({"Error": f"{field}_missing", "Message": f"Missing required field: {field}"}) errors.append(
{
"Error": f"{field}_missing",
"Message": f"Missing required field: {field}",
}
)
onion_private_key, onion_public_key, onion_errors = decode_onion_keys(data["OnionPrivateKey"], onion_private_key, onion_public_key, onion_errors = decode_onion_keys(
data["OnionPublicKey"]) data["OnionPrivateKey"], data["OnionPublicKey"]
)
if onion_errors: if onion_errors:
errors.extend(onion_errors) errors.extend(onion_errors)
@ -85,23 +104,35 @@ def create_onion() -> ResponseReturnValue:
return jsonify({"Errors": errors}), 400 return jsonify({"Errors": errors}), 400
if onion_private_key: if onion_private_key:
existing_onion = db.session.query(Onion).where( existing_onion = (
Onion.onion_private_key == onion_private_key, db.session.query(Onion)
Onion.destroyed.is_(None), .where(
).first() Onion.onion_private_key == onion_private_key,
Onion.destroyed.is_(None),
)
.first()
)
if existing_onion: if existing_onion:
errors.append( errors.append(
{"Error": "duplicate_onion_key", "Message": "An onion service with this private key already exists."}) {
"Error": "duplicate_onion_key",
"Message": "An onion service with this private key already exists.",
}
)
if "GroupId" in data: if "GroupId" in data:
group = Group.query.get(data["GroupId"]) group = Group.query.get(data["GroupId"])
if not group: if not group:
errors.append({"Error": "group_id_not_found", "Message": "Invalid group ID."}) errors.append(
{"Error": "group_id_not_found", "Message": "Invalid group ID."}
)
chain, san_list, tls_errors = validate_tls_keys( chain, san_list, tls_errors = validate_tls_keys(
data["TlsPrivateKey"], data["TlsCertificate"], data.get("SkipChainVerification"), data["TlsPrivateKey"],
data["TlsCertificate"],
data.get("SkipChainVerification"),
data.get("SkipNameVerification"), data.get("SkipNameVerification"),
f"{onion_hostname(onion_public_key)}.onion" f"{onion_hostname(onion_public_key)}.onion",
) )
if tls_errors: if tls_errors:
@ -123,15 +154,21 @@ def create_onion() -> ResponseReturnValue:
added=datetime.now(timezone.utc), added=datetime.now(timezone.utc),
updated=datetime.now(timezone.utc), updated=datetime.now(timezone.utc),
cert_expiry=cert_expiry_date, cert_expiry=cert_expiry_date,
cert_sans=",".join(san_list) cert_sans=",".join(san_list),
) )
try: try:
db.session.add(onion) db.session.add(onion)
db.session.commit() db.session.commit()
return jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), 201 return (
jsonify({"Message": "Onion service created successfully.", "Id": onion.id}),
201,
)
except exc.SQLAlchemyError as e: except exc.SQLAlchemyError as e:
return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500 return (
jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}),
500,
)
class UpdateOnionRequest(TypedDict): class UpdateOnionRequest(TypedDict):
@ -152,8 +189,19 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
onion = Onion.query.get(onion_id) onion = Onion.query.get(onion_id)
if not onion: if not onion:
return jsonify( return (
{"Errors": [{"Error": "onion_not_found", "Message": f"No Onion service found with ID {onion_id}"}]}), 404 jsonify(
{
"Errors": [
{
"Error": "onion_not_found",
"Message": f"No Onion service found with ID {onion_id}",
}
]
}
),
404,
)
if "Description" in data: if "Description" in data:
description = data["Description"] description = data["Description"]
@ -161,7 +209,12 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
if validate_description(description): if validate_description(description):
onion.description = description onion.description = description
else: else:
errors.append({"Error": "description_error", "Message": "Description field is invalid"}) errors.append(
{
"Error": "description_error",
"Message": "Description field is invalid",
}
)
tls_private_key_pem: Optional[str] = None tls_private_key_pem: Optional[str] = None
tls_certificate_pem: Optional[str] = None tls_certificate_pem: Optional[str] = None
@ -176,7 +229,9 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
tls_private_key_pem = onion.tls_private_key.decode("utf-8") tls_private_key_pem = onion.tls_private_key.decode("utf-8")
chain, san_list, tls_errors = validate_tls_keys( chain, san_list, tls_errors = validate_tls_keys(
tls_private_key_pem, tls_certificate_pem, data.get("SkipChainVerification", False), tls_private_key_pem,
tls_certificate_pem,
data.get("SkipChainVerification", False),
data.get("SkipNameVerification", False), data.get("SkipNameVerification", False),
f"{onion_hostname(onion.onion_public_key)}.onion", f"{onion_hostname(onion.onion_public_key)}.onion",
) )
@ -200,7 +255,10 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
db.session.commit() db.session.commit()
return jsonify({"Message": "Onion service updated successfully."}), 200 return jsonify({"Message": "Onion service updated successfully."}), 200
except exc.SQLAlchemyError as e: except exc.SQLAlchemyError as e:
return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500 return (
jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}),
500,
)
@api_onion.route("/onion/<int:onion_id>", methods=["GET"]) @api_onion.route("/onion/<int:onion_id>", methods=["GET"])

View file

@ -12,7 +12,7 @@ from app.extensions import db
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_DOMAIN_NAME_LENGTH = 255 MAX_DOMAIN_NAME_LENGTH = 255
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$') DOMAIN_NAME_REGEX = re.compile(r"^[a-zA-Z0-9.\-]*$")
MAX_ALLOWED_ITEMS = 100 MAX_ALLOWED_ITEMS = 100
ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]] ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]]
@ -24,7 +24,10 @@ def validate_max_items(max_items_str: str, max_allowed: int) -> int:
raise ValueError() raise ValueError()
return max_items return max_items
except ValueError: except ValueError:
abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.") abort(
400,
description=f"MaxItems must be a positive integer not exceeding {max_allowed}.",
)
def validate_marker(marker_str: str) -> int: def validate_marker(marker_str: str) -> int:
@ -47,21 +50,22 @@ TlpMarkings = Union[
def list_resources( # pylint: disable=too-many-arguments,too-many-locals def list_resources( # pylint: disable=too-many-arguments,too-many-locals
model: Type[Any], model: Type[Any],
serialize_func: Callable[[Any], Dict[str, Any]], serialize_func: Callable[[Any], Dict[str, Any]],
*, *,
filters: Optional[List[ListFilter]] = None, filters: Optional[List[ListFilter]] = None,
order_by: Optional[ColumnElement[Any]] = None, order_by: Optional[ColumnElement[Any]] = None,
resource_name: str = 'ResourceList', resource_name: str = "ResourceList",
max_items_param: str = 'MaxItems', max_items_param: str = "MaxItems",
marker_param: str = 'Marker', marker_param: str = "Marker",
max_allowed_items: int = 100, max_allowed_items: int = 100,
protective_marking: TlpMarkings = 'default', protective_marking: TlpMarkings = "default",
) -> ResponseReturnValue: ) -> ResponseReturnValue:
try: try:
marker = request.args.get(marker_param) marker = request.args.get(marker_param)
max_items = validate_max_items( max_items = validate_max_items(
request.args.get(max_items_param, default='100'), max_allowed_items) request.args.get(max_items_param, default="100"), max_allowed_items
)
query = select(model) query = select(model)
if filters: if filters:
@ -101,14 +105,21 @@ def list_resources( # pylint: disable=too-many-arguments,too-many-locals
abort(500) abort(500)
def get_single_resource(model: Type[Any], id_: int, resource_name: str) -> ResponseReturnValue: def get_single_resource(
model: Type[Any], id_: int, resource_name: str
) -> ResponseReturnValue:
try: try:
resource = db.session.get(model, id_) resource = db.session.get(model, id_)
if not resource: if not resource:
return jsonify({ return (
"Error": "resource_not_found", jsonify(
"Message": f"No {resource_name} found with ID {id_}" {
}), 404 "Error": "resource_not_found",
"Message": f"No {resource_name} found with ID {id_}",
}
),
404,
)
return jsonify({resource_name: resource.to_dict()}), 200 return jsonify({resource_name: resource.to_dict()}), 200
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
logger.exception("An unexpected error occurred while retrieving the onion") logger.exception("An unexpected error occurred while retrieving the onion")

View file

@ -4,35 +4,43 @@ from typing import List
from flask import Blueprint, abort, request from flask import Blueprint, abort, request
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from app.api.util import (DOMAIN_NAME_REGEX, MAX_ALLOWED_ITEMS, from app.api.util import (
MAX_DOMAIN_NAME_LENGTH, ListFilter, list_resources) DOMAIN_NAME_REGEX,
MAX_ALLOWED_ITEMS,
MAX_DOMAIN_NAME_LENGTH,
ListFilter,
list_resources,
)
from app.models.base import Group from app.models.base import Group
from app.models.mirrors import Origin, Proxy from app.models.mirrors import Origin, Proxy
api_web = Blueprint('web', __name__) api_web = Blueprint("web", __name__)
@api_web.route('/group', methods=['GET']) @api_web.route("/group", methods=["GET"])
def list_groups() -> ResponseReturnValue: def list_groups() -> ResponseReturnValue:
return list_resources( return list_resources(
Group, Group,
lambda group: group.to_dict(), lambda group: group.to_dict(),
resource_name='OriginGroupList', resource_name="OriginGroupList",
max_allowed_items=MAX_ALLOWED_ITEMS, max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber', protective_marking="amber",
) )
@api_web.route('/origin', methods=['GET']) @api_web.route("/origin", methods=["GET"])
def list_origins() -> ResponseReturnValue: def list_origins() -> ResponseReturnValue:
domain_name_filter = request.args.get('DomainName') domain_name_filter = request.args.get("DomainName")
group_id_filter = request.args.get('GroupId') group_id_filter = request.args.get("GroupId")
filters: List[ListFilter] = [] filters: List[ListFilter] = []
if domain_name_filter: if domain_name_filter:
if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH: if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.") abort(
400,
description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.",
)
if not DOMAIN_NAME_REGEX.match(domain_name_filter): if not DOMAIN_NAME_REGEX.match(domain_name_filter):
abort(400, description="DomainName contains invalid characters.") abort(400, description="DomainName contains invalid characters.")
filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%")) filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%"))
@ -47,18 +55,18 @@ def list_origins() -> ResponseReturnValue:
Origin, Origin,
lambda origin: origin.to_dict(), lambda origin: origin.to_dict(),
filters=filters, filters=filters,
resource_name='OriginsList', resource_name="OriginsList",
max_allowed_items=MAX_ALLOWED_ITEMS, max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber', protective_marking="amber",
) )
@api_web.route('/mirror', methods=['GET']) @api_web.route("/mirror", methods=["GET"])
def list_mirrors() -> ResponseReturnValue: def list_mirrors() -> ResponseReturnValue:
filters = [] filters = []
twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24) twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
status_filter = request.args.get('Status') status_filter = request.args.get("Status")
if status_filter: if status_filter:
if status_filter == "pending": if status_filter == "pending":
filters.append(Proxy.url.is_(None)) filters.append(Proxy.url.is_(None))
@ -74,13 +82,15 @@ def list_mirrors() -> ResponseReturnValue:
if status_filter == "destroyed": if status_filter == "destroyed":
filters.append(Proxy.destroyed > twenty_four_hours_ago) filters.append(Proxy.destroyed > twenty_four_hours_ago)
else: else:
filters.append((Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago)) filters.append(
(Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago)
)
return list_resources( return list_resources(
Proxy, Proxy,
lambda proxy: proxy.to_dict(), lambda proxy: proxy.to_dict(),
filters=filters, filters=filters,
resource_name='MirrorsList', resource_name="MirrorsList",
max_allowed_items=MAX_ALLOWED_ITEMS, max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber', protective_marking="amber",
) )

View file

@ -29,31 +29,37 @@ class BRN:
def from_str(cls, string: str) -> BRN: def from_str(cls, string: str) -> BRN:
parts = string.split(":") parts = string.split(":")
if len(parts) != 6 or parts[0].lower() != "brn" or not is_integer(parts[2]): if len(parts) != 6 or parts[0].lower() != "brn" or not is_integer(parts[2]):
raise TypeError(f"Expected a valid BRN but got {repr(string)} (invalid parts).") raise TypeError(
f"Expected a valid BRN but got {repr(string)} (invalid parts)."
)
resource_parts = parts[5].split("/") resource_parts = parts[5].split("/")
if len(resource_parts) != 2: if len(resource_parts) != 2:
raise TypeError(f"Expected a valid BRN but got {repr(string)} (invalid resource parts).") raise TypeError(
f"Expected a valid BRN but got {repr(string)} (invalid resource parts)."
)
return cls( return cls(
global_namespace=parts[1], global_namespace=parts[1],
group_id=int(parts[2]), group_id=int(parts[2]),
product=parts[3], product=parts[3],
provider=parts[4], provider=parts[4],
resource_type=resource_parts[0], resource_type=resource_parts[0],
resource_id=resource_parts[1] resource_id=resource_parts[1],
) )
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return str(self) == str(other) return str(self) == str(other)
def __str__(self) -> str: def __str__(self) -> str:
return ":".join([ return ":".join(
"brn", [
self.global_namespace, "brn",
str(self.group_id), self.global_namespace,
self.product, str(self.group_id),
self.provider, self.product,
f"{self.resource_type}/{self.resource_id}" self.provider,
]) f"{self.resource_type}/{self.resource_id}",
]
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<BRN {str(self)}>" return f"<BRN {str(self)}>"

View file

@ -9,18 +9,18 @@ from app.models.mirrors import StaticOrigin
def create_static_origin( def create_static_origin(
description: str, description: str,
group_id: int, group_id: int,
storage_cloud_account_id: int, storage_cloud_account_id: int,
source_cloud_account_id: int, source_cloud_account_id: int,
source_project: str, source_project: str,
auto_rotate: bool, auto_rotate: bool,
matrix_homeserver: Optional[str], matrix_homeserver: Optional[str],
keanu_convene_path: Optional[str], keanu_convene_path: Optional[str],
keanu_convene_logo: Optional[FileStorage], keanu_convene_logo: Optional[FileStorage],
keanu_convene_color: Optional[str], keanu_convene_color: Optional[str],
clean_insights_backend: Optional[Union[str, bool]], clean_insights_backend: Optional[Union[str, bool]],
db_session_commit: bool = False, db_session_commit: bool = False,
) -> StaticOrigin: ) -> StaticOrigin:
""" """
Create a new static origin. Create a new static origin.
@ -47,14 +47,18 @@ def create_static_origin(
else: else:
raise ValueError("group_id must be an int") raise ValueError("group_id must be an int")
if isinstance(storage_cloud_account_id, int): if isinstance(storage_cloud_account_id, int):
cloud_account = CloudAccount.query.filter(CloudAccount.id == storage_cloud_account_id).first() cloud_account = CloudAccount.query.filter(
CloudAccount.id == storage_cloud_account_id
).first()
if cloud_account is None: if cloud_account is None:
raise ValueError("storage_cloud_account_id must match an existing provider") raise ValueError("storage_cloud_account_id must match an existing provider")
static_origin.storage_cloud_account_id = storage_cloud_account_id static_origin.storage_cloud_account_id = storage_cloud_account_id
else: else:
raise ValueError("storage_cloud_account_id must be an int") raise ValueError("storage_cloud_account_id must be an int")
if isinstance(source_cloud_account_id, int): if isinstance(source_cloud_account_id, int):
cloud_account = CloudAccount.query.filter(CloudAccount.id == source_cloud_account_id).first() cloud_account = CloudAccount.query.filter(
CloudAccount.id == source_cloud_account_id
).first()
if cloud_account is None: if cloud_account is None:
raise ValueError("source_cloud_account_id must match an existing provider") raise ValueError("source_cloud_account_id must match an existing provider")
static_origin.source_cloud_account_id = source_cloud_account_id static_origin.source_cloud_account_id = source_cloud_account_id
@ -69,7 +73,7 @@ def create_static_origin(
keanu_convene_logo, keanu_convene_logo,
keanu_convene_color, keanu_convene_color,
clean_insights_backend, clean_insights_backend,
False False,
) )
if db_session_commit: if db_session_commit:
db.session.add(static_origin) db.session.add(static_origin)

View file

@ -26,7 +26,9 @@ def is_integer(contender: Any) -> bool:
return float(contender).is_integer() return float(contender).is_integer()
def thumbnail_uploaded_image(file: FileStorage, max_size: Tuple[int, int] = (256, 256)) -> bytes: def thumbnail_uploaded_image(
file: FileStorage, max_size: Tuple[int, int] = (256, 256)
) -> bytes:
""" """
Process an uploaded image file into a resized image of a specific size. Process an uploaded image file into a resized image of a specific size.
@ -39,7 +41,9 @@ def thumbnail_uploaded_image(file: FileStorage, max_size: Tuple[int, int] = (256
img = Image.open(file) img = Image.open(file)
img.thumbnail(max_size) img.thumbnail(max_size)
byte_arr = BytesIO() byte_arr = BytesIO()
img.save(byte_arr, format='PNG' if file.filename.lower().endswith('.png') else 'JPEG') img.save(
byte_arr, format="PNG" if file.filename.lower().endswith(".png") else "JPEG"
)
return byte_arr.getvalue() return byte_arr.getvalue()
@ -52,9 +56,11 @@ def create_data_uri(bytes_data: bytes, file_extension: str) -> str:
:return: A data URI representing the image. :return: A data URI representing the image.
""" """
# base64 encode # base64 encode
encoded = base64.b64encode(bytes_data).decode('ascii') encoded = base64.b64encode(bytes_data).decode("ascii")
# create data URI # create data URI
data_uri = "data:image/{};base64,{}".format('jpeg' if file_extension == 'jpg' else file_extension, encoded) data_uri = "data:image/{};base64,{}".format(
"jpeg" if file_extension == "jpg" else file_extension, encoded
)
return data_uri return data_uri
@ -80,7 +86,7 @@ def normalize_color(color: str) -> str:
return webcolors.name_to_hex(color) # type: ignore[no-any-return] return webcolors.name_to_hex(color) # type: ignore[no-any-return]
except ValueError: except ValueError:
pass pass
if color.startswith('#'): if color.startswith("#"):
color = color[1:].lower() color = color[1:].lower()
if len(color) in [3, 6]: if len(color) in [3, 6]:
try: try:

View file

@ -3,7 +3,9 @@ from abc import abstractmethod
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
if TYPE_CHECKING: if TYPE_CHECKING:
_SubparserType = argparse._SubParsersAction[argparse.ArgumentParser] # pylint: disable=protected-access _SubparserType = argparse._SubParsersAction[
argparse.ArgumentParser
] # pylint: disable=protected-access
else: else:
_SubparserType = Any _SubparserType = Any

View file

@ -13,7 +13,9 @@ def parse_args(argv: List[str]) -> None:
if basename(argv[0]) == "__main__.py": if basename(argv[0]) == "__main__.py":
argv[0] = "bypass" argv[0] = "bypass"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-v", "--verbose", help="increase logging verbosity", action="store_true") parser.add_argument(
"-v", "--verbose", help="increase logging verbosity", action="store_true"
)
subparsers = parser.add_subparsers(title="command", help="command to run") subparsers = parser.add_subparsers(title="command", help="command to run")
AutomateCliHandler.add_subparser_to(subparsers) AutomateCliHandler.add_subparser_to(subparsers)
DbCliHandler.add_subparser_to(subparsers) DbCliHandler.add_subparser_to(subparsers)
@ -28,7 +30,6 @@ def parse_args(argv: List[str]) -> None:
if __name__ == "__main__": if __name__ == "__main__":
VERBOSE = "-v" in sys.argv or "--verbose" in sys.argv VERBOSE = "-v" in sys.argv or "--verbose" in sys.argv
logging.basicConfig( logging.basicConfig(level=logging.DEBUG if VERBOSE else logging.INFO)
level=logging.DEBUG if VERBOSE else logging.INFO)
logging.debug("Arguments: %s", sys.argv) logging.debug("Arguments: %s", sys.argv)
parse_args(sys.argv) parse_args(sys.argv)

View file

@ -14,18 +14,14 @@ from app.models.automation import Automation, AutomationLogs, AutomationState
from app.terraform import BaseAutomation from app.terraform import BaseAutomation
from app.terraform.alarms.eotk_aws import AlarmEotkAwsAutomation from app.terraform.alarms.eotk_aws import AlarmEotkAwsAutomation
from app.terraform.alarms.proxy_azure_cdn import AlarmProxyAzureCdnAutomation from app.terraform.alarms.proxy_azure_cdn import AlarmProxyAzureCdnAutomation
from app.terraform.alarms.proxy_cloudfront import \ from app.terraform.alarms.proxy_cloudfront import AlarmProxyCloudfrontAutomation
AlarmProxyCloudfrontAutomation from app.terraform.alarms.proxy_http_status import AlarmProxyHTTPStatusAutomation
from app.terraform.alarms.proxy_http_status import \
AlarmProxyHTTPStatusAutomation
from app.terraform.alarms.smart_aws import AlarmSmartAwsAutomation from app.terraform.alarms.smart_aws import AlarmSmartAwsAutomation
from app.terraform.block.block_blocky import BlockBlockyAutomation from app.terraform.block.block_blocky import BlockBlockyAutomation
from app.terraform.block.block_scriptzteam import \ from app.terraform.block.block_scriptzteam import BlockBridgeScriptzteamAutomation
BlockBridgeScriptzteamAutomation
from app.terraform.block.bridge_github import BlockBridgeGitHubAutomation from app.terraform.block.bridge_github import BlockBridgeGitHubAutomation
from app.terraform.block.bridge_gitlab import BlockBridgeGitlabAutomation from app.terraform.block.bridge_gitlab import BlockBridgeGitlabAutomation
from app.terraform.block.bridge_roskomsvoboda import \ from app.terraform.block.bridge_roskomsvoboda import BlockBridgeRoskomsvobodaAutomation
BlockBridgeRoskomsvobodaAutomation
from app.terraform.block_external import BlockExternalAutomation from app.terraform.block_external import BlockExternalAutomation
from app.terraform.block_ooni import BlockOONIAutomation from app.terraform.block_ooni import BlockOONIAutomation
from app.terraform.block_roskomsvoboda import BlockRoskomsvobodaAutomation from app.terraform.block_roskomsvoboda import BlockRoskomsvobodaAutomation
@ -58,12 +54,10 @@ jobs = {
BlockExternalAutomation, BlockExternalAutomation,
BlockOONIAutomation, BlockOONIAutomation,
BlockRoskomsvobodaAutomation, BlockRoskomsvobodaAutomation,
# Create new resources # Create new resources
BridgeMetaAutomation, BridgeMetaAutomation,
StaticMetaAutomation, StaticMetaAutomation,
ProxyMetaAutomation, ProxyMetaAutomation,
# Terraform # Terraform
BridgeAWSAutomation, BridgeAWSAutomation,
BridgeGandiAutomation, BridgeGandiAutomation,
@ -74,14 +68,12 @@ jobs = {
ProxyAzureCdnAutomation, ProxyAzureCdnAutomation,
ProxyCloudfrontAutomation, ProxyCloudfrontAutomation,
ProxyFastlyAutomation, ProxyFastlyAutomation,
# Import alarms # Import alarms
AlarmEotkAwsAutomation, AlarmEotkAwsAutomation,
AlarmProxyAzureCdnAutomation, AlarmProxyAzureCdnAutomation,
AlarmProxyCloudfrontAutomation, AlarmProxyCloudfrontAutomation,
AlarmProxyHTTPStatusAutomation, AlarmProxyHTTPStatusAutomation,
AlarmSmartAwsAutomation, AlarmSmartAwsAutomation,
# Update lists # Update lists
ListGithubAutomation, ListGithubAutomation,
ListGitlabAutomation, ListGitlabAutomation,
@ -103,9 +95,12 @@ def run_all(**kwargs: bool) -> None:
run_job(job, **kwargs) run_job(job, **kwargs)
def run_job(job_cls: Type[BaseAutomation], *, def run_job(
force: bool = False, ignore_schedule: bool = False) -> None: job_cls: Type[BaseAutomation], *, force: bool = False, ignore_schedule: bool = False
automation = Automation.query.filter(Automation.short_name == job_cls.short_name).first() ) -> None:
automation = Automation.query.filter(
Automation.short_name == job_cls.short_name
).first()
if automation is None: if automation is None:
automation = Automation() automation = Automation()
automation.short_name = job_cls.short_name automation.short_name = job_cls.short_name
@ -121,18 +116,24 @@ def run_job(job_cls: Type[BaseAutomation], *,
logging.warning("Not running an already running automation") logging.warning("Not running an already running automation")
return return
if not ignore_schedule and not force: if not ignore_schedule and not force:
if automation.next_run is not None and automation.next_run > datetime.now(tz=timezone.utc): if automation.next_run is not None and automation.next_run > datetime.now(
tz=timezone.utc
):
logging.warning("Not time to run this job yet") logging.warning("Not time to run this job yet")
return return
if not automation.enabled and not force: if not automation.enabled and not force:
logging.warning("job %s is disabled and --force not specified", job_cls.short_name) logging.warning(
"job %s is disabled and --force not specified", job_cls.short_name
)
return return
automation.state = AutomationState.RUNNING automation.state = AutomationState.RUNNING
db.session.commit() db.session.commit()
try: try:
if 'TERRAFORM_DIRECTORY' in app.config: if "TERRAFORM_DIRECTORY" in app.config:
working_dir = os.path.join(app.config['TERRAFORM_DIRECTORY'], working_dir = os.path.join(
job_cls.short_name or job_cls.__class__.__name__.lower()) app.config["TERRAFORM_DIRECTORY"],
job_cls.short_name or job_cls.__class__.__name__.lower(),
)
else: else:
working_dir = tempfile.mkdtemp() working_dir = tempfile.mkdtemp()
job: BaseAutomation = job_cls(working_dir) job: BaseAutomation = job_cls(working_dir)
@ -150,8 +151,9 @@ def run_job(job_cls: Type[BaseAutomation], *,
if job is not None and success: if job is not None and success:
automation.state = AutomationState.IDLE automation.state = AutomationState.IDLE
automation.next_run = datetime.now(tz=timezone.utc) + timedelta( automation.next_run = datetime.now(tz=timezone.utc) + timedelta(
minutes=getattr(job, "frequency", 7)) minutes=getattr(job, "frequency", 7)
if 'TERRAFORM_DIRECTORY' not in app.config and working_dir is not None: )
if "TERRAFORM_DIRECTORY" not in app.config and working_dir is not None:
# We used a temporary working directory # We used a temporary working directory
shutil.rmtree(working_dir) shutil.rmtree(working_dir)
else: else:
@ -165,7 +167,7 @@ def run_job(job_cls: Type[BaseAutomation], *,
"list_gitlab", "list_gitlab",
"block_blocky", "block_blocky",
"block_external", "block_external",
"block_ooni" "block_ooni",
] ]
if job.short_name not in safe_jobs: if job.short_name not in safe_jobs:
automation.enabled = False automation.enabled = False
@ -179,10 +181,12 @@ def run_job(job_cls: Type[BaseAutomation], *,
db.session.commit() db.session.commit()
activity = Activity( activity = Activity(
activity_type="automation", activity_type="automation",
text=(f"[{automation.short_name}] 🚨 Automation failure: It was not possible to handle this failure safely " text=(
"and so the automation task has been automatically disabled. It may be possible to simply re-enable " f"[{automation.short_name}] 🚨 Automation failure: It was not possible to handle this failure safely "
"the task, but repeated failures will usually require deeper investigation. See logs for full " "and so the automation task has been automatically disabled. It may be possible to simply re-enable "
"details.") "the task, but repeated failures will usually require deeper investigation. See logs for full "
"details."
),
) )
db.session.add(activity) db.session.add(activity)
activity.notify() # Notify before commit because the failure occurred even if we can't commit. activity.notify() # Notify before commit because the failure occurred even if we can't commit.
@ -194,20 +198,43 @@ class AutomateCliHandler(BaseCliHandler):
@classmethod @classmethod
def add_subparser_to(cls, subparsers: _SubparserType) -> None: def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("automate", help="automation operations") parser = subparsers.add_parser("automate", help="automation operations")
parser.add_argument("-a", "--all", dest="all", help="run all automation jobs", action="store_true") parser.add_argument(
parser.add_argument("-j", "--job", dest="job", choices=sorted(jobs.keys()), "-a",
help="run a specific automation job") "--all",
parser.add_argument("--force", help="run job even if disabled and it's not time yet", action="store_true") dest="all",
parser.add_argument("--ignore-schedule", help="run job even if it's not time yet", action="store_true") help="run all automation jobs",
action="store_true",
)
parser.add_argument(
"-j",
"--job",
dest="job",
choices=sorted(jobs.keys()),
help="run a specific automation job",
)
parser.add_argument(
"--force",
help="run job even if disabled and it's not time yet",
action="store_true",
)
parser.add_argument(
"--ignore-schedule",
help="run job even if it's not time yet",
action="store_true",
)
parser.set_defaults(cls=cls) parser.set_defaults(cls=cls)
def run(self) -> None: def run(self) -> None:
with app.app_context(): with app.app_context():
if self.args.job: if self.args.job:
run_job(jobs[self.args.job], run_job(
force=self.args.force, jobs[self.args.job],
ignore_schedule=self.args.ignore_schedule) force=self.args.force,
ignore_schedule=self.args.ignore_schedule,
)
elif self.args.all: elif self.args.all:
run_all(force=self.args.force, ignore_schedule=self.args.ignore_schedule) run_all(
force=self.args.force, ignore_schedule=self.args.ignore_schedule
)
else: else:
logging.error("No action requested") logging.error("No action requested")

View file

@ -40,7 +40,7 @@ models: List[Model] = [
Eotk, Eotk,
MirrorList, MirrorList,
TerraformState, TerraformState,
Webhook Webhook,
] ]
@ -53,7 +53,7 @@ class ExportEncoder(json.JSONEncoder):
if isinstance(o, AutomationState): if isinstance(o, AutomationState):
return o.name return o.name
if isinstance(o, bytes): if isinstance(o, bytes):
return base64.encodebytes(o).decode('utf-8') return base64.encodebytes(o).decode("utf-8")
if isinstance(o, (datetime.datetime, datetime.date, datetime.time)): if isinstance(o, (datetime.datetime, datetime.date, datetime.time)):
return o.isoformat() return o.isoformat()
return super().default(o) return super().default(o)
@ -82,7 +82,7 @@ def db_export() -> None:
decoder: Dict[str, Callable[[Any], Any]] = { decoder: Dict[str, Callable[[Any], Any]] = {
"AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x), "AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x),
"AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x), "AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x),
"bytes": lambda x: base64.decodebytes(x.encode('utf-8')), "bytes": lambda x: base64.decodebytes(x.encode("utf-8")),
"datetime": datetime.datetime.fromisoformat, "datetime": datetime.datetime.fromisoformat,
"int": int, "int": int,
"str": lambda x: x, "str": lambda x: x,
@ -110,8 +110,12 @@ class DbCliHandler(BaseCliHandler):
@classmethod @classmethod
def add_subparser_to(cls, subparsers: _SubparserType) -> None: def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("db", help="database operations") parser = subparsers.add_parser("db", help="database operations")
parser.add_argument("--export", help="export data to JSON format", action="store_true") parser.add_argument(
parser.add_argument("--import", help="import data from JSON format", action="store_true") "--export", help="export data to JSON format", action="store_true"
)
parser.add_argument(
"--import", help="import data from JSON format", action="store_true"
)
parser.set_defaults(cls=cls) parser.set_defaults(cls=cls)
def run(self) -> None: def run(self) -> None:

View file

@ -17,8 +17,9 @@ class ListCliHandler(BaseCliHandler):
@classmethod @classmethod
def add_subparser_to(cls, subparsers: _SubparserType) -> None: def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("list", help="list operations") parser = subparsers.add_parser("list", help="list operations")
parser.add_argument("--dump", choices=sorted(lists.keys()), parser.add_argument(
help="dump a list in JSON format") "--dump", choices=sorted(lists.keys()), help="dump a list in JSON format"
)
parser.set_defaults(cls=cls) parser.set_defaults(cls=cls)
def run(self) -> None: def run(self) -> None:

View file

@ -4,11 +4,11 @@ from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData from sqlalchemy import MetaData
convention = { convention = {
"ix": 'ix_%(column_0_label)s', "ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s", "uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s" "pk": "pk_%(table_name)s",
} }
metadata = MetaData(naming_convention=convention) metadata = MetaData(naming_convention=convention)

View file

@ -26,13 +26,15 @@ def onion_alternative(origin: Origin) -> List[BC2Alternative]:
url: Optional[str] = origin.onion() url: Optional[str] = origin.onion()
if url is None: if url is None:
return [] return []
return [{ return [
"proto": "tor", {
"type": "eotk", "proto": "tor",
"created_at": str(origin.added), "type": "eotk",
"updated_at": str(origin.updated), "created_at": str(origin.added),
"url": url "updated_at": str(origin.updated),
}] "url": url,
}
]
def proxy_alternative(proxy: Proxy) -> Optional[BC2Alternative]: def proxy_alternative(proxy: Proxy) -> Optional[BC2Alternative]:
@ -43,43 +45,51 @@ def proxy_alternative(proxy: Proxy) -> Optional[BC2Alternative]:
"type": "mirror", "type": "mirror",
"created_at": proxy.added.isoformat(), "created_at": proxy.added.isoformat(),
"updated_at": proxy.updated.isoformat(), "updated_at": proxy.updated.isoformat(),
"url": proxy.url "url": proxy.url,
} }
def main_domain(origin: Origin) -> str: def main_domain(origin: Origin) -> str:
description: str = origin.description description: str = origin.description
if description.startswith("proxy:"): if description.startswith("proxy:"):
return description[len("proxy:"):].replace("www.", "") return description[len("proxy:") :].replace("www.", "")
domain_name: str = origin.domain_name domain_name: str = origin.domain_name
return domain_name.replace("www.", "") return domain_name.replace("www.", "")
def active_proxies(origin: Origin, pool: Pool) -> List[Proxy]: def active_proxies(origin: Origin, pool: Pool) -> List[Proxy]:
return [ return [
proxy for proxy in origin.proxies proxy
if proxy.url is not None and not proxy.deprecated and not proxy.destroyed and proxy.pool_id == pool.id for proxy in origin.proxies
if proxy.url is not None
and not proxy.deprecated
and not proxy.destroyed
and proxy.pool_id == pool.id
] ]
def mirror_sites(pool: Pool) -> BypassCensorship2: def mirror_sites(pool: Pool) -> BypassCensorship2:
origins = Origin.query.filter(Origin.destroyed.is_(None)).order_by(Origin.domain_name).all() origins = (
Origin.query.filter(Origin.destroyed.is_(None))
.order_by(Origin.domain_name)
.all()
)
sites: List[BC2Site] = [] sites: List[BC2Site] = []
for origin in origins: for origin in origins:
# Gather alternatives, filtering out None values from proxy_alternative # Gather alternatives, filtering out None values from proxy_alternative
alternatives = onion_alternative(origin) + [ alternatives = onion_alternative(origin) + [
alt for proxy in active_proxies(origin, pool) alt
for proxy in active_proxies(origin, pool)
if (alt := proxy_alternative(proxy)) is not None if (alt := proxy_alternative(proxy)) is not None
] ]
# Add the site dictionary to the list # Add the site dictionary to the list
sites.append({ sites.append(
"main_domain": main_domain(origin), {
"available_alternatives": list(alternatives) "main_domain": main_domain(origin),
}) "available_alternatives": list(alternatives),
}
)
return { return {"version": "2.0", "sites": sites}
"version": "2.0",
"sites": sites
}

View file

@ -11,12 +11,14 @@ class BridgelinesDict(TypedDict):
bridgelines: List[str] bridgelines: List[str]
def bridgelines(pool: Pool, *, distribution_method: Optional[str] = None) -> BridgelinesDict: def bridgelines(
pool: Pool, *, distribution_method: Optional[str] = None
) -> BridgelinesDict:
# Fetch bridges with selectinload for related data # Fetch bridges with selectinload for related data
query = Bridge.query.options(selectinload(Bridge.conf)).filter( query = Bridge.query.options(selectinload(Bridge.conf)).filter(
Bridge.destroyed.is_(None), Bridge.destroyed.is_(None),
Bridge.deprecated.is_(None), Bridge.deprecated.is_(None),
Bridge.bridgeline.is_not(None) Bridge.bridgeline.is_not(None),
) )
if distribution_method is not None: if distribution_method is not None:
@ -26,7 +28,4 @@ def bridgelines(pool: Pool, *, distribution_method: Optional[str] = None) -> Bri
bridgelines = [b.bridgeline for b in query.all() if b.conf.pool_id == pool.id] bridgelines = [b.bridgeline for b in query.all() if b.conf.pool_id == pool.id]
# Return dictionary directly, inlining the previous `to_dict` functionality # Return dictionary directly, inlining the previous `to_dict` functionality
return { return {"version": "1.0", "bridgelines": bridgelines}
"version": "1.0",
"bridgelines": bridgelines
}

View file

@ -48,7 +48,9 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping:
countries = proxy.origin.risk_level countries = proxy.origin.risk_level
if countries: if countries:
highest_risk_country_code, highest_risk_level = max(countries.items(), key=lambda x: x[1]) highest_risk_country_code, highest_risk_level = max(
countries.items(), key=lambda x: x[1]
)
else: else:
highest_risk_country_code = "ZZ" highest_risk_country_code = "ZZ"
highest_risk_level = 0 highest_risk_level = 0
@ -61,7 +63,7 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping:
"valid_to": proxy.destroyed.isoformat() if proxy.destroyed else None, "valid_to": proxy.destroyed.isoformat() if proxy.destroyed else None,
"countries": countries, "countries": countries,
"country": highest_risk_country_code, "country": highest_risk_country_code,
"risk": highest_risk_level "risk": highest_risk_level,
} }
groups = db.session.query(Group).options(selectinload(Group.pools)) groups = db.session.query(Group).options(selectinload(Group.pools))
@ -70,8 +72,4 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping:
for g in groups.filter(Group.destroyed.is_(None)).all() for g in groups.filter(Group.destroyed.is_(None)).all()
] ]
return { return {"version": "1.2", "mappings": result, "s3_buckets": s3_buckets}
"version": "1.2",
"mappings": result,
"s3_buckets": s3_buckets
}

View file

@ -26,15 +26,17 @@ def redirector_pool_origins(pool: Pool) -> Dict[str, str]:
Proxy.deprecated.is_(None), Proxy.deprecated.is_(None),
Proxy.destroyed.is_(None), Proxy.destroyed.is_(None),
Proxy.url.is_not(None), Proxy.url.is_not(None),
Proxy.pool_id == pool.id Proxy.pool_id == pool.id,
) )
} }
def redirector_data(_: Optional[Pool]) -> RedirectorData: def redirector_data(_: Optional[Pool]) -> RedirectorData:
active_pools = Pool.query.options( active_pools = (
selectinload(Pool.proxies) Pool.query.options(selectinload(Pool.proxies))
).filter(Pool.destroyed.is_(None)).all() .filter(Pool.destroyed.is_(None))
.all()
)
pools: List[RedirectorPool] = [ pools: List[RedirectorPool] = [
{ {
@ -42,12 +44,9 @@ def redirector_data(_: Optional[Pool]) -> RedirectorData:
"description": pool.description, "description": pool.description,
"api_key": pool.api_key, "api_key": pool.api_key,
"redirector_domain": pool.redirector_domain, "redirector_domain": pool.redirector_domain,
"origins": redirector_pool_origins(pool) "origins": redirector_pool_origins(pool),
} }
for pool in active_pools for pool in active_pools
] ]
return { return {"version": "1.0", "pools": pools}
"version": "1.0",
"pools": pools
}

View file

@ -17,7 +17,9 @@ class AbstractConfiguration(db.Model): # type: ignore
description: Mapped[str] description: Mapped[str]
added: Mapped[datetime] = mapped_column(AwareDateTime()) added: Mapped[datetime] = mapped_column(AwareDateTime())
updated: Mapped[datetime] = mapped_column(AwareDateTime()) updated: Mapped[datetime] = mapped_column(AwareDateTime())
destroyed: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) destroyed: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
@property @property
@abstractmethod @abstractmethod
@ -30,14 +32,10 @@ class AbstractConfiguration(db.Model): # type: ignore
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return [ return ["id", "description", "added", "updated", "destroyed"]
"id", "description", "added", "updated", "destroyed"
]
def csv_row(self) -> List[Any]: def csv_row(self) -> List[Any]:
return [ return [getattr(self, x) for x in self.csv_header()]
getattr(self, x) for x in self.csv_header()
]
class Deprecation(db.Model): # type: ignore[name-defined,misc] class Deprecation(db.Model): # type: ignore[name-defined,misc]
@ -51,7 +49,8 @@ class Deprecation(db.Model): # type: ignore[name-defined,misc]
@property @property
def resource(self) -> "AbstractResource": def resource(self) -> "AbstractResource":
from app.models.mirrors import Proxy # pylint: disable=R0401 from app.models.mirrors import Proxy # pylint: disable=R0401
model = {'Proxy': Proxy}[self.resource_type]
model = {"Proxy": Proxy}[self.resource_type]
return model.query.get(self.resource_id) # type: ignore[no-any-return] return model.query.get(self.resource_id) # type: ignore[no-any-return]
@ -61,29 +60,38 @@ class AbstractResource(db.Model): # type: ignore
id: Mapped[int] = mapped_column(db.Integer, primary_key=True) id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
added: Mapped[datetime] = mapped_column(AwareDateTime()) added: Mapped[datetime] = mapped_column(AwareDateTime())
updated: Mapped[datetime] = mapped_column(AwareDateTime()) updated: Mapped[datetime] = mapped_column(AwareDateTime())
deprecated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) deprecated: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
deprecation_reason: Mapped[Optional[str]] deprecation_reason: Mapped[Optional[str]]
destroyed: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) destroyed: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
def __init__(self, *, def __init__(
id: Optional[int] = None, self,
added: Optional[datetime] = None, *,
updated: Optional[datetime] = None, id: Optional[int] = None,
deprecated: Optional[datetime] = None, added: Optional[datetime] = None,
deprecation_reason: Optional[str] = None, updated: Optional[datetime] = None,
destroyed: Optional[datetime] = None, deprecated: Optional[datetime] = None,
**kwargs: Any) -> None: deprecation_reason: Optional[str] = None,
destroyed: Optional[datetime] = None,
**kwargs: Any
) -> None:
if added is None: if added is None:
added = datetime.now(tz=timezone.utc) added = datetime.now(tz=timezone.utc)
if updated is None: if updated is None:
updated = datetime.now(tz=timezone.utc) updated = datetime.now(tz=timezone.utc)
super().__init__(id=id, super().__init__(
added=added, id=id,
updated=updated, added=added,
deprecated=deprecated, updated=updated,
deprecation_reason=deprecation_reason, deprecated=deprecated,
destroyed=destroyed, deprecation_reason=deprecation_reason,
**kwargs) destroyed=destroyed,
**kwargs
)
@property @property
@abstractmethod @abstractmethod
@ -110,19 +118,21 @@ class AbstractResource(db.Model): # type: ignore
resource_type=type(self).__name__, resource_type=type(self).__name__,
resource_id=self.id, resource_id=self.id,
reason=reason, reason=reason,
meta=meta meta=meta,
) )
db.session.add(new_deprecation) db.session.add(new_deprecation)
return True return True
logging.info("Not deprecating %s (reason=%s) because it's already deprecated with that reason.", logging.info(
self.brn, reason) "Not deprecating %s (reason=%s) because it's already deprecated with that reason.",
self.brn,
reason,
)
return False return False
@property @property
def deprecations(self) -> List[Deprecation]: def deprecations(self) -> List[Deprecation]:
return Deprecation.query.filter_by( # type: ignore[no-any-return] return Deprecation.query.filter_by( # type: ignore[no-any-return]
resource_type='Proxy', resource_type="Proxy", resource_id=self.id
resource_id=self.id
).all() ).all()
def destroy(self) -> None: def destroy(self) -> None:
@ -139,10 +149,13 @@ class AbstractResource(db.Model): # type: ignore
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return [ return [
"id", "added", "updated", "deprecated", "deprecation_reason", "destroyed" "id",
"added",
"updated",
"deprecated",
"deprecation_reason",
"destroyed",
] ]
def csv_row(self) -> List[Union[datetime, bool, int, str]]: def csv_row(self) -> List[Union[datetime, bool, int, str]]:
return [ return [getattr(self, x) for x in self.csv_header()]
getattr(self, x) for x in self.csv_header()
]

View file

@ -17,31 +17,40 @@ class Activity(db.Model): # type: ignore
text: Mapped[str] text: Mapped[str]
added: Mapped[datetime] = mapped_column(AwareDateTime()) added: Mapped[datetime] = mapped_column(AwareDateTime())
def __init__(self, *, def __init__(
id: Optional[int] = None, self,
group_id: Optional[int] = None, *,
activity_type: str, id: Optional[int] = None,
text: str, group_id: Optional[int] = None,
added: Optional[datetime] = None, activity_type: str,
**kwargs: Any) -> None: text: str,
if not isinstance(activity_type, str) or len(activity_type) > 20 or activity_type == "": added: Optional[datetime] = None,
raise TypeError("expected string for activity type between 1 and 20 characters") **kwargs: Any
) -> None:
if (
not isinstance(activity_type, str)
or len(activity_type) > 20
or activity_type == ""
):
raise TypeError(
"expected string for activity type between 1 and 20 characters"
)
if not isinstance(text, str): if not isinstance(text, str):
raise TypeError("expected string for text") raise TypeError("expected string for text")
if added is None: if added is None:
added = datetime.now(tz=timezone.utc) added = datetime.now(tz=timezone.utc)
super().__init__(id=id, super().__init__(
group_id=group_id, id=id,
activity_type=activity_type, group_id=group_id,
text=text, activity_type=activity_type,
added=added, text=text,
**kwargs) added=added,
**kwargs
)
def notify(self) -> int: def notify(self) -> int:
count = 0 count = 0
hooks = Webhook.query.filter( hooks = Webhook.query.filter(Webhook.destroyed.is_(None))
Webhook.destroyed.is_(None)
)
for hook in hooks: for hook in hooks:
hook.send(self.text) hook.send(self.text)
count += 1 count += 1
@ -59,7 +68,7 @@ class Webhook(AbstractConfiguration):
product="notify", product="notify",
provider=self.format, provider=self.format,
resource_type="conf", resource_type="conf",
resource_id=str(self.id) resource_id=str(self.id),
) )
def send(self, text: str) -> None: def send(self, text: str) -> None:

View file

@ -37,7 +37,15 @@ class Alarm(db.Model): # type: ignore
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return ["id", "target", "alarm_type", "alarm_state", "state_changed", "last_updated", "text"] return [
"id",
"target",
"alarm_type",
"alarm_state",
"state_changed",
"last_updated",
"text",
]
def csv_row(self) -> List[Any]: def csv_row(self) -> List[Any]:
return [getattr(self, x) for x in self.csv_header()] return [getattr(self, x) for x in self.csv_header()]
@ -45,11 +53,15 @@ class Alarm(db.Model): # type: ignore
def update_state(self, state: AlarmState, text: str) -> None: def update_state(self, state: AlarmState, text: str) -> None:
if self.alarm_state != state or self.state_changed is None: if self.alarm_state != state or self.state_changed is None:
self.state_changed = datetime.now(tz=timezone.utc) self.state_changed = datetime.now(tz=timezone.utc)
activity = Activity(activity_type="alarm_state", activity = Activity(
text=f"[{self.aspect}] {state.emoji} Alarm state changed from " activity_type="alarm_state",
f"{self.alarm_state.name} to {state.name} on {self.target}: {text}.") text=f"[{self.aspect}] {state.emoji} Alarm state changed from "
if (self.alarm_state.name in ["WARNING", "CRITICAL"] f"{self.alarm_state.name} to {state.name} on {self.target}: {text}.",
or state.name in ["WARNING", "CRITICAL"]): )
if self.alarm_state.name in ["WARNING", "CRITICAL"] or state.name in [
"WARNING",
"CRITICAL",
]:
# Notifications are only sent on recovery from warning/critical state or on entry # Notifications are only sent on recovery from warning/critical state or on entry
# to warning/critical states. This should reduce alert fatigue. # to warning/critical states. This should reduce alert fatigue.
activity.notify() activity.notify()

View file

@ -33,7 +33,7 @@ class Automation(AbstractConfiguration):
product="core", product="core",
provider="", provider="",
resource_type="automation", resource_type="automation",
resource_id=self.short_name resource_id=self.short_name,
) )
def kick(self) -> None: def kick(self) -> None:
@ -55,5 +55,5 @@ class AutomationLogs(AbstractResource):
product="core", product="core",
provider="", provider="",
resource_type="automationlog", resource_type="automationlog",
resource_id=str(self.id) resource_id=str(self.id),
) )

View file

@ -26,17 +26,21 @@ class Group(AbstractConfiguration):
eotk: Mapped[bool] eotk: Mapped[bool]
origins: Mapped[List["Origin"]] = relationship("Origin", back_populates="group") origins: Mapped[List["Origin"]] = relationship("Origin", back_populates="group")
statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="group") statics: Mapped[List["StaticOrigin"]] = relationship(
"StaticOrigin", back_populates="group"
)
eotks: Mapped[List["Eotk"]] = relationship("Eotk", back_populates="group") eotks: Mapped[List["Eotk"]] = relationship("Eotk", back_populates="group")
onions: Mapped[List["Onion"]] = relationship("Onion", back_populates="group") onions: Mapped[List["Onion"]] = relationship("Onion", back_populates="group")
smart_proxies: Mapped[List["SmartProxy"]] = relationship("SmartProxy", back_populates="group") smart_proxies: Mapped[List["SmartProxy"]] = relationship(
pools: Mapped[List["Pool"]] = relationship("Pool", secondary="pool_group", back_populates="groups") "SmartProxy", back_populates="group"
)
pools: Mapped[List["Pool"]] = relationship(
"Pool", secondary="pool_group", back_populates="groups"
)
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return super().csv_header() + [ return super().csv_header() + ["group_name", "eotk"]
"group_name", "eotk"
]
@property @property
def brn(self) -> BRN: def brn(self) -> BRN:
@ -45,16 +49,15 @@ class Group(AbstractConfiguration):
product="group", product="group",
provider="", provider="",
resource_type="group", resource_type="group",
resource_id=str(self.id) resource_id=str(self.id),
) )
def to_dict(self) -> GroupDict: def to_dict(self) -> GroupDict:
if not TYPE_CHECKING: if not TYPE_CHECKING:
from app.models.mirrors import Origin # to prevent circular import from app.models.mirrors import Origin # to prevent circular import
active_origins_query = ( active_origins_query = db.session.query(aliased(Origin)).filter(
db.session.query(aliased(Origin)) and_(Origin.group_id == self.id, Origin.destroyed.is_(None))
.filter(and_(Origin.group_id == self.id, Origin.destroyed.is_(None)))
) )
active_origins_count = active_origins_query.count() active_origins_count = active_origins_query.count()
return { return {
@ -70,16 +73,20 @@ class Pool(AbstractConfiguration):
api_key: Mapped[str] api_key: Mapped[str]
redirector_domain: Mapped[Optional[str]] redirector_domain: Mapped[Optional[str]]
bridgeconfs: Mapped[List["BridgeConf"]] = relationship("BridgeConf", back_populates="pool") bridgeconfs: Mapped[List["BridgeConf"]] = relationship(
"BridgeConf", back_populates="pool"
)
proxies: Mapped[List["Proxy"]] = relationship("Proxy", back_populates="pool") proxies: Mapped[List["Proxy"]] = relationship("Proxy", back_populates="pool")
lists: Mapped[List["MirrorList"]] = relationship("MirrorList", back_populates="pool") lists: Mapped[List["MirrorList"]] = relationship(
groups: Mapped[List[Group]] = relationship("Group", secondary="pool_group", back_populates="pools") "MirrorList", back_populates="pool"
)
groups: Mapped[List[Group]] = relationship(
"Group", secondary="pool_group", back_populates="pools"
)
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return super().csv_header() + [ return super().csv_header() + ["pool_name"]
"pool_name"
]
@property @property
def brn(self) -> BRN: def brn(self) -> BRN:
@ -88,7 +95,7 @@ class Pool(AbstractConfiguration):
product="pool", product="pool",
provider="", provider="",
resource_type="pool", resource_type="pool",
resource_id=str(self.pool_name) resource_id=str(self.pool_name),
) )
@ -121,14 +128,14 @@ class MirrorList(AbstractConfiguration):
"bc3": "Bypass Censorship v3", "bc3": "Bypass Censorship v3",
"bca": "Bypass Censorship Analytics", "bca": "Bypass Censorship Analytics",
"bridgelines": "Tor Bridge Lines", "bridgelines": "Tor Bridge Lines",
"rdr": "Redirector Data" "rdr": "Redirector Data",
} }
encodings_supported = { encodings_supported = {
"json": "JSON (Plain)", "json": "JSON (Plain)",
"jsno": "JSON (Obfuscated)", "jsno": "JSON (Obfuscated)",
"js": "JavaScript (Plain)", "js": "JavaScript (Plain)",
"jso": "JavaScript (Obfuscated)" "jso": "JavaScript (Obfuscated)",
} }
def destroy(self) -> None: def destroy(self) -> None:
@ -149,7 +156,11 @@ class MirrorList(AbstractConfiguration):
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return super().csv_header() + [ return super().csv_header() + [
"provider", "format", "container", "branch", "filename" "provider",
"format",
"container",
"branch",
"filename",
] ]
@property @property
@ -159,5 +170,5 @@ class MirrorList(AbstractConfiguration):
product="list", product="list",
provider=self.provider, provider=self.provider,
resource_type="list", resource_type="list",
resource_id=str(self.id) resource_id=str(self.id),
) )

View file

@ -34,7 +34,7 @@ class BridgeConf(AbstractConfiguration):
product="bridge", product="bridge",
provider="", provider="",
resource_type="bridgeconf", resource_type="bridgeconf",
resource_id=str(self.id) resource_id=str(self.id),
) )
def destroy(self) -> None: def destroy(self) -> None:
@ -48,14 +48,22 @@ class BridgeConf(AbstractConfiguration):
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return super().csv_header() + [ return super().csv_header() + [
"pool_id", "provider", "method", "description", "target_number", "max_number", "expiry_hours" "pool_id",
"provider",
"method",
"description",
"target_number",
"max_number",
"expiry_hours",
] ]
class Bridge(AbstractResource): class Bridge(AbstractResource):
conf_id: Mapped[int] = mapped_column(db.ForeignKey("bridge_conf.id")) conf_id: Mapped[int] = mapped_column(db.ForeignKey("bridge_conf.id"))
cloud_account_id: Mapped[int] = mapped_column(db.ForeignKey("cloud_account.id")) cloud_account_id: Mapped[int] = mapped_column(db.ForeignKey("cloud_account.id"))
terraform_updated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) terraform_updated: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
nickname: Mapped[Optional[str]] nickname: Mapped[Optional[str]]
fingerprint: Mapped[Optional[str]] fingerprint: Mapped[Optional[str]]
hashed_fingerprint: Mapped[Optional[str]] hashed_fingerprint: Mapped[Optional[str]]
@ -71,11 +79,16 @@ class Bridge(AbstractResource):
product="bridge", product="bridge",
provider=self.cloud_account.provider.key, provider=self.cloud_account.provider.key,
resource_type="bridge", resource_type="bridge",
resource_id=str(self.id) resource_id=str(self.id),
) )
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return super().csv_header() + [ return super().csv_header() + [
"conf_id", "terraform_updated", "nickname", "fingerprint", "hashed_fingerprint", "bridgeline" "conf_id",
"terraform_updated",
"nickname",
"fingerprint",
"hashed_fingerprint",
"bridgeline",
] ]

View file

@ -42,9 +42,14 @@ class CloudAccount(AbstractConfiguration):
# Compute Quotas # Compute Quotas
max_instances: Mapped[int] max_instances: Mapped[int]
bridges: Mapped[List["Bridge"]] = relationship("Bridge", back_populates="cloud_account") bridges: Mapped[List["Bridge"]] = relationship(
statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="storage_cloud_account", foreign_keys=[ "Bridge", back_populates="cloud_account"
StaticOrigin.storage_cloud_account_id]) )
statics: Mapped[List["StaticOrigin"]] = relationship(
"StaticOrigin",
back_populates="storage_cloud_account",
foreign_keys=[StaticOrigin.storage_cloud_account_id],
)
@property @property
def brn(self) -> BRN: def brn(self) -> BRN:

View file

@ -10,8 +10,7 @@ from tldextract import extract
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from app.brm.brn import BRN from app.brm.brn import BRN
from app.brm.utils import (create_data_uri, normalize_color, from app.brm.utils import create_data_uri, normalize_color, thumbnail_uploaded_image
thumbnail_uploaded_image)
from app.extensions import db from app.extensions import db
from app.models import AbstractConfiguration, AbstractResource, Deprecation from app.models import AbstractConfiguration, AbstractResource, Deprecation
from app.models.base import Group, Pool from app.models.base import Group, Pool
@ -19,10 +18,10 @@ from app.models.onions import Onion
from app.models.types import AwareDateTime from app.models.types import AwareDateTime
country_origin = db.Table( country_origin = db.Table(
'country_origin', "country_origin",
db.metadata, db.metadata,
db.Column('country_id', db.ForeignKey('country.id'), primary_key=True), db.Column("country_id", db.ForeignKey("country.id"), primary_key=True),
db.Column('origin_id', db.ForeignKey('origin.id'), primary_key=True), db.Column("origin_id", db.ForeignKey("origin.id"), primary_key=True),
extend_existing=True, extend_existing=True,
) )
@ -45,7 +44,9 @@ class Origin(AbstractConfiguration):
group: Mapped[Group] = relationship("Group", back_populates="origins") group: Mapped[Group] = relationship("Group", back_populates="origins")
proxies: Mapped[List[Proxy]] = relationship("Proxy", back_populates="origin") proxies: Mapped[List[Proxy]] = relationship("Proxy", back_populates="origin")
countries: Mapped[List[Country]] = relationship("Country", secondary=country_origin, back_populates='origins') countries: Mapped[List[Country]] = relationship(
"Country", secondary=country_origin, back_populates="origins"
)
@property @property
def brn(self) -> BRN: def brn(self) -> BRN:
@ -54,13 +55,18 @@ class Origin(AbstractConfiguration):
product="mirror", product="mirror",
provider="conf", provider="conf",
resource_type="origin", resource_type="origin",
resource_id=self.domain_name resource_id=self.domain_name,
) )
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return super().csv_header() + [ return super().csv_header() + [
"group_id", "domain_name", "auto_rotation", "smart", "assets", "country" "group_id",
"domain_name",
"auto_rotation",
"smart",
"assets",
"country",
] ]
def destroy(self) -> None: def destroy(self) -> None:
@ -84,30 +90,41 @@ class Origin(AbstractConfiguration):
@property @property
def risk_level(self) -> Dict[str, int]: def risk_level(self) -> Dict[str, int]:
if self.risk_level_override: if self.risk_level_override:
return {country.country_code: self.risk_level_override for country in self.countries} return {
country.country_code: self.risk_level_override
for country in self.countries
}
frequency_factor = 0.0 frequency_factor = 0.0
recency_factor = 0.0 recency_factor = 0.0
recent_deprecations = ( recent_deprecations = (
db.session.query(Deprecation) db.session.query(Deprecation)
.join(Proxy, .join(Proxy, Deprecation.resource_id == Proxy.id)
Deprecation.resource_id == Proxy.id)
.join(Origin, Origin.id == Proxy.origin_id) .join(Origin, Origin.id == Proxy.origin_id)
.filter( .filter(
Origin.id == self.id, Origin.id == self.id,
Deprecation.resource_type == 'Proxy', Deprecation.resource_type == "Proxy",
Deprecation.deprecated_at >= datetime.now(tz=timezone.utc) - timedelta(hours=168), Deprecation.deprecated_at
Deprecation.reason != "destroyed" >= datetime.now(tz=timezone.utc) - timedelta(hours=168),
Deprecation.reason != "destroyed",
) )
.distinct(Proxy.id) .distinct(Proxy.id)
.all() .all()
) )
for deprecation in recent_deprecations: for deprecation in recent_deprecations:
recency_factor += 1 / max((datetime.now(tz=timezone.utc) - deprecation.deprecated_at).total_seconds() // 3600, 1) recency_factor += 1 / max(
(
datetime.now(tz=timezone.utc) - deprecation.deprecated_at
).total_seconds()
// 3600,
1,
)
frequency_factor += 1 frequency_factor += 1
risk_levels: Dict[str, int] = {} risk_levels: Dict[str, int] = {}
for country in self.countries: for country in self.countries:
risk_levels[country.country_code.upper()] = int( risk_levels[country.country_code.upper()] = (
max(1, min(10, frequency_factor * recency_factor))) + country.risk_level int(max(1, min(10, frequency_factor * recency_factor)))
+ country.risk_level
)
return risk_levels return risk_levels
def to_dict(self) -> OriginDict: def to_dict(self) -> OriginDict:
@ -128,13 +145,15 @@ class Country(AbstractConfiguration):
product="country", product="country",
provider="iso3166-1", provider="iso3166-1",
resource_type="alpha2", resource_type="alpha2",
resource_id=self.country_code resource_id=self.country_code,
) )
country_code: Mapped[str] country_code: Mapped[str]
risk_level_override: Mapped[Optional[int]] risk_level_override: Mapped[Optional[int]]
origins = db.relationship("Origin", secondary=country_origin, back_populates='countries') origins = db.relationship(
"Origin", secondary=country_origin, back_populates="countries"
)
@property @property
def risk_level(self) -> int: def risk_level(self) -> int:
@ -144,29 +163,39 @@ class Country(AbstractConfiguration):
recency_factor = 0.0 recency_factor = 0.0
recent_deprecations = ( recent_deprecations = (
db.session.query(Deprecation) db.session.query(Deprecation)
.join(Proxy, .join(Proxy, Deprecation.resource_id == Proxy.id)
Deprecation.resource_id == Proxy.id)
.join(Origin, Origin.id == Proxy.origin_id) .join(Origin, Origin.id == Proxy.origin_id)
.join(Origin.countries) .join(Origin.countries)
.filter( .filter(
Country.id == self.id, Country.id == self.id,
Deprecation.resource_type == 'Proxy', Deprecation.resource_type == "Proxy",
Deprecation.deprecated_at >= datetime.now(tz=timezone.utc) - timedelta(hours=168), Deprecation.deprecated_at
Deprecation.reason != "destroyed" >= datetime.now(tz=timezone.utc) - timedelta(hours=168),
Deprecation.reason != "destroyed",
) )
.distinct(Proxy.id) .distinct(Proxy.id)
.all() .all()
) )
for deprecation in recent_deprecations: for deprecation in recent_deprecations:
recency_factor += 1 / max((datetime.now(tz=timezone.utc) - deprecation.deprecated_at).total_seconds() // 3600, 1) recency_factor += 1 / max(
(
datetime.now(tz=timezone.utc) - deprecation.deprecated_at
).total_seconds()
// 3600,
1,
)
frequency_factor += 1 frequency_factor += 1
return int(max(1, min(10, frequency_factor * recency_factor))) return int(max(1, min(10, frequency_factor * recency_factor)))
class StaticOrigin(AbstractConfiguration): class StaticOrigin(AbstractConfiguration):
group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False) group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False)
storage_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) storage_cloud_account_id = mapped_column(
source_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False) db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False
)
source_cloud_account_id = mapped_column(
db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False
)
source_project = mapped_column(db.String(255), nullable=False) source_project = mapped_column(db.String(255), nullable=False)
auto_rotate = mapped_column(db.Boolean, nullable=False) auto_rotate = mapped_column(db.Boolean, nullable=False)
matrix_homeserver = mapped_column(db.String(255), nullable=True) matrix_homeserver = mapped_column(db.String(255), nullable=True)
@ -182,30 +211,34 @@ class StaticOrigin(AbstractConfiguration):
product="mirror", product="mirror",
provider="aws", provider="aws",
resource_type="static", resource_type="static",
resource_id=self.domain_name resource_id=self.domain_name,
) )
group = db.relationship("Group", back_populates="statics") group = db.relationship("Group", back_populates="statics")
storage_cloud_account = db.relationship("CloudAccount", back_populates="statics", storage_cloud_account = db.relationship(
foreign_keys=[storage_cloud_account_id]) "CloudAccount",
source_cloud_account = db.relationship("CloudAccount", back_populates="statics", back_populates="statics",
foreign_keys=[source_cloud_account_id]) foreign_keys=[storage_cloud_account_id],
)
source_cloud_account = db.relationship(
"CloudAccount", back_populates="statics", foreign_keys=[source_cloud_account_id]
)
def destroy(self) -> None: def destroy(self) -> None:
# TODO: The StaticMetaAutomation will clean up for now, but it should probably happen here for consistency # TODO: The StaticMetaAutomation will clean up for now, but it should probably happen here for consistency
super().destroy() super().destroy()
def update( def update(
self, self,
source_project: str, source_project: str,
description: str, description: str,
auto_rotate: bool, auto_rotate: bool,
matrix_homeserver: Optional[str], matrix_homeserver: Optional[str],
keanu_convene_path: Optional[str], keanu_convene_path: Optional[str],
keanu_convene_logo: Optional[FileStorage], keanu_convene_logo: Optional[FileStorage],
keanu_convene_color: Optional[str], keanu_convene_color: Optional[str],
clean_insights_backend: Optional[Union[str, bool]], clean_insights_backend: Optional[Union[str, bool]],
db_session_commit: bool, db_session_commit: bool,
) -> None: ) -> None:
if isinstance(source_project, str): if isinstance(source_project, str):
self.source_project = source_project self.source_project = source_project
@ -235,19 +268,29 @@ class StaticOrigin(AbstractConfiguration):
elif isinstance(keanu_convene_logo, FileStorage): elif isinstance(keanu_convene_logo, FileStorage):
if keanu_convene_logo.filename: # if False, no file was uploaded if keanu_convene_logo.filename: # if False, no file was uploaded
keanu_convene_config["logo"] = create_data_uri( keanu_convene_config["logo"] = create_data_uri(
thumbnail_uploaded_image(keanu_convene_logo), keanu_convene_logo.filename) thumbnail_uploaded_image(keanu_convene_logo),
keanu_convene_logo.filename,
)
else: else:
raise ValueError("keanu_convene_logo must be a FileStorage") raise ValueError("keanu_convene_logo must be a FileStorage")
try: try:
if isinstance(keanu_convene_color, str): if isinstance(keanu_convene_color, str):
keanu_convene_config["color"] = normalize_color(keanu_convene_color) # can raise ValueError keanu_convene_config["color"] = normalize_color(
keanu_convene_color
) # can raise ValueError
else: else:
raise ValueError() # re-raised below with message raise ValueError() # re-raised below with message
except ValueError: except ValueError:
raise ValueError("keanu_convene_path must be a str containing an HTML color (CSS name or hex)") raise ValueError(
self.keanu_convene_config = json.dumps(keanu_convene_config, separators=(',', ':')) "keanu_convene_path must be a str containing an HTML color (CSS name or hex)"
)
self.keanu_convene_config = json.dumps(
keanu_convene_config, separators=(",", ":")
)
del keanu_convene_config # done with this temporary variable del keanu_convene_config # done with this temporary variable
if clean_insights_backend is None or (isinstance(clean_insights_backend, bool) and not clean_insights_backend): if clean_insights_backend is None or (
isinstance(clean_insights_backend, bool) and not clean_insights_backend
):
self.clean_insights_backend = None self.clean_insights_backend = None
elif isinstance(clean_insights_backend, bool) and clean_insights_backend: elif isinstance(clean_insights_backend, bool) and clean_insights_backend:
self.clean_insights_backend = "metrics.cleaninsights.org" self.clean_insights_backend = "metrics.cleaninsights.org"
@ -260,7 +303,9 @@ class StaticOrigin(AbstractConfiguration):
self.updated = datetime.now(tz=timezone.utc) self.updated = datetime.now(tz=timezone.utc)
ResourceStatus = Union[Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"]] ResourceStatus = Union[
Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"]
]
class ProxyDict(TypedDict): class ProxyDict(TypedDict):
@ -271,12 +316,16 @@ class ProxyDict(TypedDict):
class Proxy(AbstractResource): class Proxy(AbstractResource):
origin_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("origin.id"), nullable=False) origin_id: Mapped[int] = mapped_column(
db.Integer, db.ForeignKey("origin.id"), nullable=False
)
pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id")) pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id"))
provider: Mapped[str] = mapped_column(db.String(20), nullable=False) provider: Mapped[str] = mapped_column(db.String(20), nullable=False)
psg: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) psg: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
slug: Mapped[Optional[str]] = mapped_column(db.String(20), nullable=True) slug: Mapped[Optional[str]] = mapped_column(db.String(20), nullable=True)
terraform_updated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True) terraform_updated: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
origin: Mapped[Origin] = relationship("Origin", back_populates="proxies") origin: Mapped[Origin] = relationship("Origin", back_populates="proxies")
@ -289,13 +338,18 @@ class Proxy(AbstractResource):
product="mirror", product="mirror",
provider=self.provider, provider=self.provider,
resource_type="proxy", resource_type="proxy",
resource_id=str(self.id) resource_id=str(self.id),
) )
@classmethod @classmethod
def csv_header(cls) -> List[str]: def csv_header(cls) -> List[str]:
return super().csv_header() + [ return super().csv_header() + [
"origin_id", "provider", "psg", "slug", "terraform_updated", "url" "origin_id",
"provider",
"psg",
"slug",
"terraform_updated",
"url",
] ]
def to_dict(self) -> ProxyDict: def to_dict(self) -> ProxyDict:
@ -329,5 +383,5 @@ class SmartProxy(AbstractResource):
product="mirror", product="mirror",
provider=self.provider, provider=self.provider,
resource_type="smart_proxy", resource_type="smart_proxy",
resource_id=str(1) resource_id=str(1),
) )

View file

@ -32,7 +32,7 @@ class Onion(AbstractConfiguration):
product="eotk", product="eotk",
provider="*", provider="*",
resource_type="onion", resource_type="onion",
resource_id=self.onion_name resource_id=self.onion_name,
) )
group_id: Mapped[int] = mapped_column(db.ForeignKey("group.id")) group_id: Mapped[int] = mapped_column(db.ForeignKey("group.id"))
@ -80,5 +80,5 @@ class Eotk(AbstractResource):
provider=self.provider, provider=self.provider,
product="eotk", product="eotk",
resource_type="instance", resource_type="instance",
resource_id=self.region resource_id=self.region,
) )

View file

@ -32,7 +32,9 @@ from app.portal.static import bp as static
from app.portal.storage import bp as storage from app.portal.storage import bp as storage
from app.portal.webhook import bp as webhook from app.portal.webhook import bp as webhook
portal = Blueprint("portal", __name__, template_folder="templates", static_folder="static") portal = Blueprint(
"portal", __name__, template_folder="templates", static_folder="static"
)
portal.register_blueprint(automation, url_prefix="/automation") portal.register_blueprint(automation, url_prefix="/automation")
portal.register_blueprint(bridgeconf, url_prefix="/bridgeconf") portal.register_blueprint(bridgeconf, url_prefix="/bridgeconf")
portal.register_blueprint(bridge, url_prefix="/bridge") portal.register_blueprint(bridge, url_prefix="/bridge")
@ -54,7 +56,10 @@ portal.register_blueprint(webhook, url_prefix="/webhook")
@portal.app_template_filter("bridge_expiry") @portal.app_template_filter("bridge_expiry")
def calculate_bridge_expiry(b: Bridge) -> str: def calculate_bridge_expiry(b: Bridge) -> str:
if b.deprecated is None: if b.deprecated is None:
logging.warning("Bridge expiry requested by template for a bridge %s that was not expiring.", b.id) logging.warning(
"Bridge expiry requested by template for a bridge %s that was not expiring.",
b.id,
)
return "Not expiring" return "Not expiring"
expiry = b.deprecated + timedelta(hours=b.conf.expiry_hours) expiry = b.deprecated + timedelta(hours=b.conf.expiry_hours)
countdown = expiry - datetime.now(tz=timezone.utc) countdown = expiry - datetime.now(tz=timezone.utc)
@ -85,27 +90,27 @@ def describe_brn(s: str) -> ResponseReturnValue:
if parts[3] == "mirror": if parts[3] == "mirror":
if parts[5].startswith("origin/"): if parts[5].startswith("origin/"):
origin = Origin.query.filter( origin = Origin.query.filter(
Origin.domain_name == parts[5][len("origin/"):] Origin.domain_name == parts[5][len("origin/") :]
).first() ).first()
if not origin: if not origin:
return s return s
return f"Origin: {origin.domain_name} ({origin.group.group_name})" return f"Origin: {origin.domain_name} ({origin.group.group_name})"
if parts[5].startswith("proxy/"): if parts[5].startswith("proxy/"):
proxy = Proxy.query.filter( proxy = Proxy.query.filter(
Proxy.id == int(parts[5][len("proxy/"):]) Proxy.id == int(parts[5][len("proxy/") :])
).first() ).first()
if not proxy: if not proxy:
return s return s
return Markup( return Markup(
f"Proxy: {proxy.url}<br>({proxy.origin.group.group_name}: {proxy.origin.domain_name})") f"Proxy: {proxy.url}<br>({proxy.origin.group.group_name}: {proxy.origin.domain_name})"
)
if parts[5].startswith("quota/"): if parts[5].startswith("quota/"):
if parts[4] == "cloudfront": if parts[4] == "cloudfront":
return f"Quota: CloudFront {parts[5][len('quota/'):]}" return f"Quota: CloudFront {parts[5][len('quota/'):]}"
if parts[3] == "eotk": if parts[3] == "eotk":
if parts[5].startswith("instance/"): if parts[5].startswith("instance/"):
eotk = Eotk.query.filter( eotk = Eotk.query.filter(
Eotk.group_id == parts[2], Eotk.group_id == parts[2], Eotk.region == parts[5][len("instance/") :]
Eotk.region == parts[5][len("instance/"):]
).first() ).first()
if not eotk: if not eotk:
return s return s
@ -138,9 +143,16 @@ def portal_home() -> ResponseReturnValue:
proxies = Proxy.query.filter(Proxy.destroyed.is_(None)).all() proxies = Proxy.query.filter(Proxy.destroyed.is_(None)).all()
last24 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=1))).all()) last24 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=1))).all())
last72 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=3))).all()) last72 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=3))).all())
lastweek = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=7))).all()) lastweek = len(
Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=7))).all()
)
alarms = { alarms = {
s: len(Alarm.query.filter(Alarm.alarm_state == s.upper(), Alarm.last_updated > (now - timedelta(days=1))).all()) s: len(
Alarm.query.filter(
Alarm.alarm_state == s.upper(),
Alarm.last_updated > (now - timedelta(days=1)),
).all()
)
for s in ["critical", "warning", "ok", "unknown"] for s in ["critical", "warning", "ok", "unknown"]
} }
bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all() bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all()
@ -148,13 +160,36 @@ def portal_home() -> ResponseReturnValue:
d: len(Bridge.query.filter(Bridge.deprecated > (now - timedelta(days=d))).all()) d: len(Bridge.query.filter(Bridge.deprecated > (now - timedelta(days=d))).all())
for d in [1, 3, 7] for d in [1, 3, 7]
} }
activity = Activity.query.filter(Activity.added > (now - timedelta(days=2))).order_by(desc(Activity.added)).all() activity = (
onionified = len([o for o in Origin.query.filter(Origin.destroyed.is_(None)).all() if o.onion() is not None]) Activity.query.filter(Activity.added > (now - timedelta(days=2)))
.order_by(desc(Activity.added))
.all()
)
onionified = len(
[
o
for o in Origin.query.filter(Origin.destroyed.is_(None)).all()
if o.onion() is not None
]
)
ooni_blocked = total_origins_blocked() ooni_blocked = total_origins_blocked()
total_origins = len(Origin.query.filter(Origin.destroyed.is_(None)).all()) total_origins = len(Origin.query.filter(Origin.destroyed.is_(None)).all())
return render_template("home.html.j2", section="home", groups=groups, last24=last24, last72=last72, return render_template(
lastweek=lastweek, proxies=proxies, **alarms, activity=activity, total_origins=total_origins, "home.html.j2",
onionified=onionified, br_last=br_last, ooni_blocked=ooni_blocked, bridges=bridges) section="home",
groups=groups,
last24=last24,
last72=last72,
lastweek=lastweek,
proxies=proxies,
**alarms,
activity=activity,
total_origins=total_origins,
onionified=onionified,
br_last=br_last,
ooni_blocked=ooni_blocked,
bridges=bridges,
)
@portal.route("/search") @portal.route("/search")
@ -163,19 +198,27 @@ def search() -> ResponseReturnValue:
if query is None: if query is None:
return redirect(url_for("portal.portal_home")) return redirect(url_for("portal.portal_home"))
proxies = Proxy.query.filter( proxies = Proxy.query.filter(
or_(func.lower(Proxy.url).contains(query.lower())), Proxy.destroyed.is_(None)).all() or_(func.lower(Proxy.url).contains(query.lower())), Proxy.destroyed.is_(None)
).all()
origins = Origin.query.filter( origins = Origin.query.filter(
or_(func.lower(Origin.description).contains(query.lower()), or_(
func.lower(Origin.domain_name).contains(query.lower()))).all() func.lower(Origin.description).contains(query.lower()),
return render_template("search.html.j2", section="home", proxies=proxies, origins=origins) func.lower(Origin.domain_name).contains(query.lower()),
)
).all()
return render_template(
"search.html.j2", section="home", proxies=proxies, origins=origins
)
@portal.route('/alarms') @portal.route("/alarms")
def view_alarms() -> ResponseReturnValue: def view_alarms() -> ResponseReturnValue:
one_day_ago = datetime.now(timezone.utc) - timedelta(days=1) one_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
alarms = Alarm.query.filter(Alarm.last_updated >= one_day_ago).order_by( alarms = (
desc(Alarm.alarm_state), desc(Alarm.state_changed)).all() Alarm.query.filter(Alarm.last_updated >= one_day_ago)
return render_template("list.html.j2", .order_by(desc(Alarm.alarm_state), desc(Alarm.state_changed))
section="alarm", .all()
title="Alarms", )
items=alarms) return render_template(
"list.html.j2", section="alarm", title="Alarms", items=alarms
)

View file

@ -17,40 +17,52 @@ bp = Blueprint("automation", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "automation", "section": "automation",
"help_url": "https://bypass.censorship.guide/user/automation.html" "help_url": "https://bypass.censorship.guide/user/automation.html",
} }
class EditAutomationForm(FlaskForm): # type: ignore class EditAutomationForm(FlaskForm): # type: ignore
enabled = BooleanField('Enabled') enabled = BooleanField("Enabled")
submit = SubmitField('Save Changes') submit = SubmitField("Save Changes")
@bp.route("/list") @bp.route("/list")
def automation_list() -> ResponseReturnValue: def automation_list() -> ResponseReturnValue:
automations = list(filter( automations = list(
lambda a: a.short_name not in current_app.config.get('HIDDEN_AUTOMATIONS', []), filter(
Automation.query.filter( lambda a: a.short_name
Automation.destroyed.is_(None)).order_by(Automation.description).all() not in current_app.config.get("HIDDEN_AUTOMATIONS", []),
)) Automation.query.filter(Automation.destroyed.is_(None))
.order_by(Automation.description)
.all(),
)
)
states = {tfs.key: tfs for tfs in TerraformState.query.all()} states = {tfs.key: tfs for tfs in TerraformState.query.all()}
return render_template("list.html.j2", return render_template(
title="Automation Jobs", "list.html.j2",
item="automation", title="Automation Jobs",
items=automations, item="automation",
states=states, items=automations,
**_SECTION_TEMPLATE_VARS) states=states,
**_SECTION_TEMPLATE_VARS
)
@bp.route('/edit/<automation_id>', methods=['GET', 'POST']) @bp.route("/edit/<automation_id>", methods=["GET", "POST"])
def automation_edit(automation_id: int) -> ResponseReturnValue: def automation_edit(automation_id: int) -> ResponseReturnValue:
automation: Optional[Automation] = Automation.query.filter(Automation.id == automation_id).first() automation: Optional[Automation] = Automation.query.filter(
Automation.id == automation_id
).first()
if automation is None: if automation is None:
return Response(render_template("error.html.j2", return Response(
header="404 Automation Job Not Found", render_template(
message="The requested automation job could not be found.", "error.html.j2",
**_SECTION_TEMPLATE_VARS), header="404 Automation Job Not Found",
status=404) message="The requested automation job could not be found.",
**_SECTION_TEMPLATE_VARS
),
status=404,
)
form = EditAutomationForm(enabled=automation.enabled) form = EditAutomationForm(enabled=automation.enabled)
if form.validate_on_submit(): if form.validate_on_submit():
automation.enabled = form.enabled.data automation.enabled = form.enabled.data
@ -59,21 +71,30 @@ def automation_edit(automation_id: int) -> ResponseReturnValue:
db.session.commit() db.session.commit()
flash("Saved changes to bridge configuration.", "success") flash("Saved changes to bridge configuration.", "success")
except exc.SQLAlchemyError: except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the bridge configuration.", "danger") flash(
logs = AutomationLogs.query.filter(AutomationLogs.automation_id == automation.id).order_by( "An error occurred saving the changes to the bridge configuration.",
desc(AutomationLogs.added)).limit(5).all() "danger",
return render_template("automation.html.j2", )
automation=automation, logs = (
logs=logs, AutomationLogs.query.filter(AutomationLogs.automation_id == automation.id)
form=form, .order_by(desc(AutomationLogs.added))
**_SECTION_TEMPLATE_VARS) .limit(5)
.all()
)
return render_template(
"automation.html.j2",
automation=automation,
logs=logs,
form=form,
**_SECTION_TEMPLATE_VARS
)
@bp.route("/kick/<automation_id>", methods=['GET', 'POST']) @bp.route("/kick/<automation_id>", methods=["GET", "POST"])
def automation_kick(automation_id: int) -> ResponseReturnValue: def automation_kick(automation_id: int) -> ResponseReturnValue:
automation = Automation.query.filter( automation = Automation.query.filter(
Automation.id == automation_id, Automation.id == automation_id, Automation.destroyed.is_(None)
Automation.destroyed.is_(None)).first() ).first()
if automation is None: if automation is None:
return response_404("The requested bridge configuration could not be found.") return response_404("The requested bridge configuration could not be found.")
return view_lifecycle( return view_lifecycle(
@ -83,5 +104,5 @@ def automation_kick(automation_id: int) -> ResponseReturnValue:
success_view="portal.automation.automation_list", success_view="portal.automation.automation_list",
success_message="This automation job will next run within 1 minute.", success_message="This automation job will next run within 1 minute.",
resource=automation, resource=automation,
action="kick" action="kick",
) )

View file

@ -1,7 +1,6 @@
from typing import Optional from typing import Optional
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from app.extensions import db from app.extensions import db
@ -12,57 +11,79 @@ bp = Blueprint("bridge", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "bridge", "section": "bridge",
"help_url": "https://bypass.censorship.guide/user/bridges.html" "help_url": "https://bypass.censorship.guide/user/bridges.html",
} }
@bp.route("/list") @bp.route("/list")
def bridge_list() -> ResponseReturnValue: def bridge_list() -> ResponseReturnValue:
bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all() bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all()
return render_template("list.html.j2", return render_template(
title="Tor Bridges", "list.html.j2",
item="bridge", title="Tor Bridges",
items=bridges, item="bridge",
**_SECTION_TEMPLATE_VARS) items=bridges,
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/block/<bridge_id>", methods=['GET', 'POST']) @bp.route("/block/<bridge_id>", methods=["GET", "POST"])
def bridge_blocked(bridge_id: int) -> ResponseReturnValue: def bridge_blocked(bridge_id: int) -> ResponseReturnValue:
bridge: Optional[Bridge] = Bridge.query.filter(Bridge.id == bridge_id, Bridge.destroyed.is_(None)).first() bridge: Optional[Bridge] = Bridge.query.filter(
Bridge.id == bridge_id, Bridge.destroyed.is_(None)
).first()
if bridge is None: if bridge is None:
return Response(render_template("error.html.j2", return Response(
header="404 Proxy Not Found", render_template(
message="The requested bridge could not be found.", "error.html.j2",
**_SECTION_TEMPLATE_VARS)) header="404 Proxy Not Found",
message="The requested bridge could not be found.",
**_SECTION_TEMPLATE_VARS,
)
)
form = LifecycleForm() form = LifecycleForm()
if form.validate_on_submit(): if form.validate_on_submit():
bridge.deprecate(reason="manual") bridge.deprecate(reason="manual")
db.session.commit() db.session.commit()
flash("Bridge will be shortly replaced.", "success") flash("Bridge will be shortly replaced.", "success")
return redirect(url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)) return redirect(
return render_template("lifecycle.html.j2", url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)
header=f"Mark bridge {bridge.hashed_fingerprint} as blocked?", )
message=bridge.hashed_fingerprint, return render_template(
form=form, "lifecycle.html.j2",
**_SECTION_TEMPLATE_VARS) header=f"Mark bridge {bridge.hashed_fingerprint} as blocked?",
message=bridge.hashed_fingerprint,
form=form,
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/expire/<bridge_id>", methods=['GET', 'POST']) @bp.route("/expire/<bridge_id>", methods=["GET", "POST"])
def bridge_expire(bridge_id: int) -> ResponseReturnValue: def bridge_expire(bridge_id: int) -> ResponseReturnValue:
bridge: Optional[Bridge] = Bridge.query.filter(Bridge.id == bridge_id, Bridge.destroyed.is_(None)).first() bridge: Optional[Bridge] = Bridge.query.filter(
Bridge.id == bridge_id, Bridge.destroyed.is_(None)
).first()
if bridge is None: if bridge is None:
return Response(render_template("error.html.j2", return Response(
header="404 Proxy Not Found", render_template(
message="The requested bridge could not be found.", "error.html.j2",
**_SECTION_TEMPLATE_VARS)) header="404 Proxy Not Found",
message="The requested bridge could not be found.",
**_SECTION_TEMPLATE_VARS,
)
)
form = LifecycleForm() form = LifecycleForm()
if form.validate_on_submit(): if form.validate_on_submit():
bridge.destroy() bridge.destroy()
db.session.commit() db.session.commit()
flash("Bridge will be shortly destroyed.", "success") flash("Bridge will be shortly destroyed.", "success")
return redirect(url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)) return redirect(
return render_template("lifecycle.html.j2", url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)
header=f"Destroy bridge {bridge.hashed_fingerprint}?", )
message=bridge.hashed_fingerprint, return render_template(
form=form, "lifecycle.html.j2",
**_SECTION_TEMPLATE_VARS) header=f"Destroy bridge {bridge.hashed_fingerprint}?",
message=bridge.hashed_fingerprint,
form=form,
**_SECTION_TEMPLATE_VARS,
)

View file

@ -1,8 +1,7 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional from typing import List, Optional
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from sqlalchemy import exc from sqlalchemy import exc
@ -19,77 +18,109 @@ bp = Blueprint("bridgeconf", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "bridgeconf", "section": "bridgeconf",
"help_url": "https://bypass.censorship.guide/user/bridges.html" "help_url": "https://bypass.censorship.guide/user/bridges.html",
} }
class NewBridgeConfForm(FlaskForm): # type: ignore class NewBridgeConfForm(FlaskForm): # type: ignore
method = SelectField('Distribution Method', validators=[DataRequired()]) method = SelectField("Distribution Method", validators=[DataRequired()])
description = StringField('Description') description = StringField("Description")
pool = SelectField('Pool', validators=[DataRequired()]) pool = SelectField("Pool", validators=[DataRequired()])
target_number = IntegerField('Target Number', target_number = IntegerField(
description="The number of active bridges to deploy (excluding deprecated bridges).", "Target Number",
validators=[NumberRange(1, message="One or more bridges must be created.")]) description="The number of active bridges to deploy (excluding deprecated bridges).",
max_number = IntegerField('Maximum Number', validators=[NumberRange(1, message="One or more bridges must be created.")],
description="The maximum number of bridges to deploy (including deprecated bridges).", )
validators=[ max_number = IntegerField(
NumberRange(1, message="Must be at least 1, ideally greater than target number.")]) "Maximum Number",
expiry_hours = IntegerField('Expiry Timer (hours)', description="The maximum number of bridges to deploy (including deprecated bridges).",
description=("The number of hours to wait after a bridge is deprecated before its " validators=[
"destruction.")) NumberRange(
provider_allocation = SelectField('Provider Allocation Method', 1, message="Must be at least 1, ideally greater than target number."
description="How to allocate new bridges to providers.", )
choices=[ ],
("COST", "Use cheapest provider first"), )
("RANDOM", "Use providers randomly"), expiry_hours = IntegerField(
]) "Expiry Timer (hours)",
submit = SubmitField('Save Changes') description=(
"The number of hours to wait after a bridge is deprecated before its "
"destruction."
),
)
provider_allocation = SelectField(
"Provider Allocation Method",
description="How to allocate new bridges to providers.",
choices=[
("COST", "Use cheapest provider first"),
("RANDOM", "Use providers randomly"),
],
)
submit = SubmitField("Save Changes")
class EditBridgeConfForm(FlaskForm): # type: ignore class EditBridgeConfForm(FlaskForm): # type: ignore
description = StringField('Description') description = StringField("Description")
target_number = IntegerField('Target Number', target_number = IntegerField(
description="The number of active bridges to deploy (excluding deprecated bridges).", "Target Number",
validators=[NumberRange(1, message="One or more bridges must be created.")]) description="The number of active bridges to deploy (excluding deprecated bridges).",
max_number = IntegerField('Maximum Number', validators=[NumberRange(1, message="One or more bridges must be created.")],
description="The maximum number of bridges to deploy (including deprecated bridges).", )
validators=[ max_number = IntegerField(
NumberRange(1, message="Must be at least 1, ideally greater than target number.")]) "Maximum Number",
expiry_hours = IntegerField('Expiry Timer (hours)', description="The maximum number of bridges to deploy (including deprecated bridges).",
description=("The number of hours to wait after a bridge is deprecated before its " validators=[
"destruction.")) NumberRange(
provider_allocation = SelectField('Provider Allocation Method', 1, message="Must be at least 1, ideally greater than target number."
description="How to allocate new bridges to providers.", )
choices=[ ],
("COST", "Use cheapest provider first"), )
("RANDOM", "Use providers randomly"), expiry_hours = IntegerField(
]) "Expiry Timer (hours)",
submit = SubmitField('Save Changes') description=(
"The number of hours to wait after a bridge is deprecated before its "
"destruction."
),
)
provider_allocation = SelectField(
"Provider Allocation Method",
description="How to allocate new bridges to providers.",
choices=[
("COST", "Use cheapest provider first"),
("RANDOM", "Use providers randomly"),
],
)
submit = SubmitField("Save Changes")
@bp.route("/list") @bp.route("/list")
def bridgeconf_list() -> ResponseReturnValue: def bridgeconf_list() -> ResponseReturnValue:
bridgeconfs: List[BridgeConf] = BridgeConf.query.filter(BridgeConf.destroyed.is_(None)).all() bridgeconfs: List[BridgeConf] = BridgeConf.query.filter(
return render_template("list.html.j2", BridgeConf.destroyed.is_(None)
title="Tor Bridge Configurations", ).all()
item="bridge configuration", return render_template(
items=bridgeconfs, "list.html.j2",
new_link=url_for("portal.bridgeconf.bridgeconf_new"), title="Tor Bridge Configurations",
**_SECTION_TEMPLATE_VARS) item="bridge configuration",
items=bridgeconfs,
new_link=url_for("portal.bridgeconf.bridgeconf_new"),
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=['GET', 'POST']) @bp.route("/new/<group_id>", methods=["GET", "POST"])
def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue: def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = NewBridgeConfForm() form = NewBridgeConfForm()
form.pool.choices = [(x.id, x.pool_name) for x in Pool.query.filter(Pool.destroyed.is_(None)).all()] form.pool.choices = [
(x.id, x.pool_name) for x in Pool.query.filter(Pool.destroyed.is_(None)).all()
]
form.method.choices = [ form.method.choices = [
("any", "Any (BridgeDB)"), ("any", "Any (BridgeDB)"),
("email", "E-Mail (BridgeDB)"), ("email", "E-Mail (BridgeDB)"),
("moat", "Moat (BridgeDB)"), ("moat", "Moat (BridgeDB)"),
("settings", "Settings (BridgeDB)"), ("settings", "Settings (BridgeDB)"),
("https", "HTTPS (BridgeDB)"), ("https", "HTTPS (BridgeDB)"),
("none", "None (Private)") ("none", "None (Private)"),
] ]
if form.validate_on_submit(): if form.validate_on_submit():
bridgeconf = BridgeConf() bridgeconf = BridgeConf()
@ -99,7 +130,9 @@ def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue:
bridgeconf.target_number = form.target_number.data bridgeconf.target_number = form.target_number.data
bridgeconf.max_number = form.max_number.data bridgeconf.max_number = form.max_number.data
bridgeconf.expiry_hours = form.expiry_hours.data bridgeconf.expiry_hours = form.expiry_hours.data
bridgeconf.provider_allocation = ProviderAllocation[form.provider_allocation.data] bridgeconf.provider_allocation = ProviderAllocation[
form.provider_allocation.data
]
bridgeconf.added = datetime.now(tz=timezone.utc) bridgeconf.added = datetime.now(tz=timezone.utc)
bridgeconf.updated = datetime.now(tz=timezone.utc) bridgeconf.updated = datetime.now(tz=timezone.utc)
try: try:
@ -112,47 +145,56 @@ def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return redirect(url_for("portal.bridgeconf.bridgeconf_list")) return redirect(url_for("portal.bridgeconf.bridgeconf_list"))
if group_id: if group_id:
form.group.data = group_id form.group.data = group_id
return render_template("new.html.j2", return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
form=form,
**_SECTION_TEMPLATE_VARS)
@bp.route('/edit/<bridgeconf_id>', methods=['GET', 'POST']) @bp.route("/edit/<bridgeconf_id>", methods=["GET", "POST"])
def bridgeconf_edit(bridgeconf_id: int) -> ResponseReturnValue: def bridgeconf_edit(bridgeconf_id: int) -> ResponseReturnValue:
bridgeconf = BridgeConf.query.filter(BridgeConf.id == bridgeconf_id).first() bridgeconf = BridgeConf.query.filter(BridgeConf.id == bridgeconf_id).first()
if bridgeconf is None: if bridgeconf is None:
return Response(render_template("error.html.j2", return Response(
header="404 Bridge Configuration Not Found", render_template(
message="The requested bridge configuration could not be found.", "error.html.j2",
**_SECTION_TEMPLATE_VARS), header="404 Bridge Configuration Not Found",
status=404) message="The requested bridge configuration could not be found.",
form = EditBridgeConfForm(description=bridgeconf.description, **_SECTION_TEMPLATE_VARS,
target_number=bridgeconf.target_number, ),
max_number=bridgeconf.max_number, status=404,
expiry_hours=bridgeconf.expiry_hours, )
provider_allocation=bridgeconf.provider_allocation.name, form = EditBridgeConfForm(
) description=bridgeconf.description,
target_number=bridgeconf.target_number,
max_number=bridgeconf.max_number,
expiry_hours=bridgeconf.expiry_hours,
provider_allocation=bridgeconf.provider_allocation.name,
)
if form.validate_on_submit(): if form.validate_on_submit():
bridgeconf.description = form.description.data bridgeconf.description = form.description.data
bridgeconf.target_number = form.target_number.data bridgeconf.target_number = form.target_number.data
bridgeconf.max_number = form.max_number.data bridgeconf.max_number = form.max_number.data
bridgeconf.expiry_hours = form.expiry_hours.data bridgeconf.expiry_hours = form.expiry_hours.data
bridgeconf.provider_allocation = ProviderAllocation[form.provider_allocation.data] bridgeconf.provider_allocation = ProviderAllocation[
form.provider_allocation.data
]
bridgeconf.updated = datetime.now(tz=timezone.utc) bridgeconf.updated = datetime.now(tz=timezone.utc)
try: try:
db.session.commit() db.session.commit()
flash("Saved changes to bridge configuration.", "success") flash("Saved changes to bridge configuration.", "success")
except exc.SQLAlchemyError: except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the bridge configuration.", "danger") flash(
return render_template("bridgeconf.html.j2", "An error occurred saving the changes to the bridge configuration.",
bridgeconf=bridgeconf, "danger",
form=form, )
**_SECTION_TEMPLATE_VARS) return render_template(
"bridgeconf.html.j2", bridgeconf=bridgeconf, form=form, **_SECTION_TEMPLATE_VARS
)
@bp.route("/destroy/<bridgeconf_id>", methods=['GET', 'POST']) @bp.route("/destroy/<bridgeconf_id>", methods=["GET", "POST"])
def bridgeconf_destroy(bridgeconf_id: int) -> ResponseReturnValue: def bridgeconf_destroy(bridgeconf_id: int) -> ResponseReturnValue:
bridgeconf = BridgeConf.query.filter(BridgeConf.id == bridgeconf_id, BridgeConf.destroyed.is_(None)).first() bridgeconf = BridgeConf.query.filter(
BridgeConf.id == bridgeconf_id, BridgeConf.destroyed.is_(None)
).first()
if bridgeconf is None: if bridgeconf is None:
return response_404("The requested bridge configuration could not be found.") return response_404("The requested bridge configuration could not be found.")
return view_lifecycle( return view_lifecycle(
@ -162,5 +204,5 @@ def bridgeconf_destroy(bridgeconf_id: int) -> ResponseReturnValue:
success_message="All bridges from the destroyed configuration will shortly be destroyed at their providers.", success_message="All bridges from the destroyed configuration will shortly be destroyed at their providers.",
section="bridgeconf", section="bridgeconf",
resource=bridgeconf, resource=bridgeconf,
action="destroy" action="destroy",
) )

View file

@ -3,8 +3,15 @@ from typing import Dict, List, Optional, Type, Union
from flask import Blueprint, redirect, render_template, url_for from flask import Blueprint, redirect, render_template, url_for
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import (BooleanField, Form, FormField, IntegerField, SelectField, from wtforms import (
StringField, SubmitField) BooleanField,
Form,
FormField,
IntegerField,
SelectField,
StringField,
SubmitField,
)
from wtforms.validators import InputRequired from wtforms.validators import InputRequired
from app.extensions import db from app.extensions import db
@ -14,54 +21,72 @@ bp = Blueprint("cloud", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "cloud", "section": "cloud",
"help_url": "https://bypass.censorship.guide/user/cloud.html" "help_url": "https://bypass.censorship.guide/user/cloud.html",
} }
class NewCloudAccountForm(FlaskForm): # type: ignore class NewCloudAccountForm(FlaskForm): # type: ignore
provider = SelectField('Cloud Provider', validators=[InputRequired()]) provider = SelectField("Cloud Provider", validators=[InputRequired()])
submit = SubmitField('Next') submit = SubmitField("Next")
class AWSAccountForm(FlaskForm): # type: ignore class AWSAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""}) provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()]) description = StringField("Description", validators=[InputRequired()])
aws_access_key = StringField('AWS Access Key', validators=[InputRequired()]) aws_access_key = StringField("AWS Access Key", validators=[InputRequired()])
aws_secret_key = StringField('AWS Secret Key', validators=[InputRequired()]) aws_secret_key = StringField("AWS Secret Key", validators=[InputRequired()])
aws_region = StringField('AWS Region', default='us-east-2', validators=[InputRequired()]) aws_region = StringField(
max_distributions = IntegerField('Cloudfront Distributions Quota', default=200, "AWS Region", default="us-east-2", validators=[InputRequired()]
description="This is the quota for number of distributions per account.", )
validators=[InputRequired()]) max_distributions = IntegerField(
max_instances = IntegerField('EC2 Instance Quota', default=2, "Cloudfront Distributions Quota",
description="This can be impacted by a number of quotas including instance limits " default=200,
"and IP address limits.", description="This is the quota for number of distributions per account.",
validators=[InputRequired()]) validators=[InputRequired()],
enabled = BooleanField('Enable this account', default=True, )
description="New resources will not be deployed to disabled accounts, however existing " max_instances = IntegerField(
"resources will persist until destroyed at the end of their lifecycle.") "EC2 Instance Quota",
submit = SubmitField('Save Changes') default=2,
description="This can be impacted by a number of quotas including instance limits "
"and IP address limits.",
validators=[InputRequired()],
)
enabled = BooleanField(
"Enable this account",
default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class HcloudAccountForm(FlaskForm): # type: ignore class HcloudAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""}) provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()]) description = StringField("Description", validators=[InputRequired()])
hcloud_token = StringField('Hetzner Cloud Token', validators=[InputRequired()]) hcloud_token = StringField("Hetzner Cloud Token", validators=[InputRequired()])
max_instances = IntegerField('Server Limit', default=10, max_instances = IntegerField(
validators=[InputRequired()]) "Server Limit", default=10, validators=[InputRequired()]
enabled = BooleanField('Enable this account', default=True, )
description="New resources will not be deployed to disabled accounts, however existing " enabled = BooleanField(
"resources will persist until destroyed at the end of their lifecycle.") "Enable this account",
submit = SubmitField('Save Changes') default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class GitlabAccountForm(FlaskForm): # type: ignore class GitlabAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""}) provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()]) description = StringField("Description", validators=[InputRequired()])
gitlab_token = StringField('GitLab Access Token', validators=[InputRequired()]) gitlab_token = StringField("GitLab Access Token", validators=[InputRequired()])
enabled = BooleanField('Enable this account', default=True, enabled = BooleanField(
description="New resources will not be deployed to disabled accounts, however existing " "Enable this account",
"resources will persist until destroyed at the end of their lifecycle.") default=True,
submit = SubmitField('Save Changes') description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class OvhHorizonForm(Form): # type: ignore[misc] class OvhHorizonForm(Form): # type: ignore[misc]
@ -77,16 +102,20 @@ class OvhApiForm(Form): # type: ignore[misc]
class OvhAccountForm(FlaskForm): # type: ignore class OvhAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""}) provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()]) description = StringField("Description", validators=[InputRequired()])
horizon = FormField(OvhHorizonForm, 'OpenStack Horizon API') horizon = FormField(OvhHorizonForm, "OpenStack Horizon API")
ovh_api = FormField(OvhApiForm, 'OVH API') ovh_api = FormField(OvhApiForm, "OVH API")
max_instances = IntegerField('Server Limit', default=10, max_instances = IntegerField(
validators=[InputRequired()]) "Server Limit", default=10, validators=[InputRequired()]
enabled = BooleanField('Enable this account', default=True, )
description="New resources will not be deployed to disabled accounts, however existing " enabled = BooleanField(
"resources will persist until destroyed at the end of their lifecycle.") "Enable this account",
submit = SubmitField('Save Changes') default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class GandiHorizonForm(Form): # type: ignore[misc] class GandiHorizonForm(Form): # type: ignore[misc]
@ -96,18 +125,24 @@ class GandiHorizonForm(Form): # type: ignore[misc]
class GandiAccountForm(FlaskForm): # type: ignore class GandiAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""}) provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()]) description = StringField("Description", validators=[InputRequired()])
horizon = FormField(GandiHorizonForm, 'OpenStack Horizon API') horizon = FormField(GandiHorizonForm, "OpenStack Horizon API")
max_instances = IntegerField('Server Limit', default=10, max_instances = IntegerField(
validators=[InputRequired()]) "Server Limit", default=10, validators=[InputRequired()]
enabled = BooleanField('Enable this account', default=True, )
description="New resources will not be deployed to disabled accounts, however existing " enabled = BooleanField(
"resources will persist until destroyed at the end of their lifecycle.") "Enable this account",
submit = SubmitField('Save Changes') default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
CloudAccountForm = Union[AWSAccountForm, HcloudAccountForm, GandiAccountForm, OvhAccountForm] CloudAccountForm = Union[
AWSAccountForm, HcloudAccountForm, GandiAccountForm, OvhAccountForm
]
provider_forms: Dict[str, Type[CloudAccountForm]] = { provider_forms: Dict[str, Type[CloudAccountForm]] = {
CloudProvider.AWS.name: AWSAccountForm, CloudProvider.AWS.name: AWSAccountForm,
@ -118,7 +153,9 @@ provider_forms: Dict[str, Type[CloudAccountForm]] = {
} }
def cloud_account_save(account: Optional[CloudAccount], provider: CloudProvider, form: CloudAccountForm) -> None: def cloud_account_save(
account: Optional[CloudAccount], provider: CloudProvider, form: CloudAccountForm
) -> None:
if not account: if not account:
account = CloudAccount() account = CloudAccount()
account.provider = provider account.provider = provider
@ -162,7 +199,9 @@ def cloud_account_save(account: Optional[CloudAccount], provider: CloudProvider,
"ovh_openstack_password": form.horizon.data["ovh_openstack_password"], "ovh_openstack_password": form.horizon.data["ovh_openstack_password"],
"ovh_openstack_tenant_id": form.horizon.data["ovh_openstack_tenant_id"], "ovh_openstack_tenant_id": form.horizon.data["ovh_openstack_tenant_id"],
"ovh_cloud_application_key": form.ovh_api.data["ovh_cloud_application_key"], "ovh_cloud_application_key": form.ovh_api.data["ovh_cloud_application_key"],
"ovh_cloud_application_secret": form.ovh_api.data["ovh_cloud_application_secret"], "ovh_cloud_application_secret": form.ovh_api.data[
"ovh_cloud_application_secret"
],
"ovh_cloud_consumer_key": form.ovh_api.data["ovh_cloud_consumer_key"], "ovh_cloud_consumer_key": form.ovh_api.data["ovh_cloud_consumer_key"],
} }
account.max_distributions = 0 account.max_distributions = 0
@ -182,53 +221,82 @@ def cloud_account_populate(form: CloudAccountForm, account: CloudAccount) -> Non
form.aws_region.data = account.credentials["aws_region"] form.aws_region.data = account.credentials["aws_region"]
form.max_distributions.data = account.max_distributions form.max_distributions.data = account.max_distributions
form.max_instances.data = account.max_instances form.max_instances.data = account.max_instances
elif account.provider == CloudProvider.HCLOUD and isinstance(form, HcloudAccountForm): elif account.provider == CloudProvider.HCLOUD and isinstance(
form, HcloudAccountForm
):
form.hcloud_token.data = account.credentials["hcloud_token"] form.hcloud_token.data = account.credentials["hcloud_token"]
form.max_instances.data = account.max_instances form.max_instances.data = account.max_instances
elif account.provider == CloudProvider.GANDI and isinstance(form, GandiAccountForm): elif account.provider == CloudProvider.GANDI and isinstance(form, GandiAccountForm):
form.horizon.form.gandi_openstack_user.data = account.credentials["gandi_openstack_user"] form.horizon.form.gandi_openstack_user.data = account.credentials[
form.horizon.form.gandi_openstack_password.data = account.credentials["gandi_openstack_password"] "gandi_openstack_user"
form.horizon.form.gandi_openstack_tenant_id.data = account.credentials["gandi_openstack_tenant_id"] ]
form.horizon.form.gandi_openstack_password.data = account.credentials[
"gandi_openstack_password"
]
form.horizon.form.gandi_openstack_tenant_id.data = account.credentials[
"gandi_openstack_tenant_id"
]
form.max_instances.data = account.max_instances form.max_instances.data = account.max_instances
elif account.provider == CloudProvider.GITLAB and isinstance(form, GitlabAccountForm): elif account.provider == CloudProvider.GITLAB and isinstance(
form, GitlabAccountForm
):
form.gitlab_token.data = account.credentials["gitlab_token"] form.gitlab_token.data = account.credentials["gitlab_token"]
elif account.provider == CloudProvider.OVH and isinstance(form, OvhAccountForm): elif account.provider == CloudProvider.OVH and isinstance(form, OvhAccountForm):
form.horizon.form.ovh_openstack_user.data = account.credentials["ovh_openstack_user"] form.horizon.form.ovh_openstack_user.data = account.credentials[
form.horizon.form.ovh_openstack_password.data = account.credentials["ovh_openstack_password"] "ovh_openstack_user"
form.horizon.form.ovh_openstack_tenant_id.data = account.credentials["ovh_openstack_tenant_id"] ]
form.ovh_api.form.ovh_cloud_application_key.data = account.credentials["ovh_cloud_application_key"] form.horizon.form.ovh_openstack_password.data = account.credentials[
form.ovh_api.form.ovh_cloud_application_secret.data = account.credentials["ovh_cloud_application_secret"] "ovh_openstack_password"
form.ovh_api.form.ovh_cloud_consumer_key.data = account.credentials["ovh_cloud_consumer_key"] ]
form.horizon.form.ovh_openstack_tenant_id.data = account.credentials[
"ovh_openstack_tenant_id"
]
form.ovh_api.form.ovh_cloud_application_key.data = account.credentials[
"ovh_cloud_application_key"
]
form.ovh_api.form.ovh_cloud_application_secret.data = account.credentials[
"ovh_cloud_application_secret"
]
form.ovh_api.form.ovh_cloud_consumer_key.data = account.credentials[
"ovh_cloud_consumer_key"
]
form.max_instances.data = account.max_instances form.max_instances.data = account.max_instances
else: else:
raise RuntimeError(f"Unknown provider {account.provider} or form data {type(form)} did not match provider.") raise RuntimeError(
f"Unknown provider {account.provider} or form data {type(form)} did not match provider."
)
@bp.route("/list") @bp.route("/list")
def cloud_account_list() -> ResponseReturnValue: def cloud_account_list() -> ResponseReturnValue:
accounts: List[CloudAccount] = CloudAccount.query.filter(CloudAccount.destroyed.is_(None)).all() accounts: List[CloudAccount] = CloudAccount.query.filter(
return render_template("list.html.j2", CloudAccount.destroyed.is_(None)
title="Cloud Accounts", ).all()
item="cloud account", return render_template(
items=accounts, "list.html.j2",
new_link=url_for("portal.cloud.cloud_account_new"), title="Cloud Accounts",
**_SECTION_TEMPLATE_VARS) item="cloud account",
items=accounts,
new_link=url_for("portal.cloud.cloud_account_new"),
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
def cloud_account_new() -> ResponseReturnValue: def cloud_account_new() -> ResponseReturnValue:
form = NewCloudAccountForm() form = NewCloudAccountForm()
form.provider.choices = sorted([ form.provider.choices = sorted(
(provider.name, provider.description) for provider in CloudProvider [(provider.name, provider.description) for provider in CloudProvider],
], key=lambda p: p[1].lower()) key=lambda p: p[1].lower(),
)
if form.validate_on_submit(): if form.validate_on_submit():
return redirect(url_for("portal.cloud.cloud_account_new_for", provider=form.provider.data)) return redirect(
return render_template("new.html.j2", url_for("portal.cloud.cloud_account_new_for", provider=form.provider.data)
form=form, )
**_SECTION_TEMPLATE_VARS) return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
@bp.route("/new/<provider>", methods=['GET', 'POST']) @bp.route("/new/<provider>", methods=["GET", "POST"])
def cloud_account_new_for(provider: str) -> ResponseReturnValue: def cloud_account_new_for(provider: str) -> ResponseReturnValue:
form = provider_forms[provider]() form = provider_forms[provider]()
form.provider.data = CloudProvider[provider].description form.provider.data = CloudProvider[provider].description
@ -236,12 +304,10 @@ def cloud_account_new_for(provider: str) -> ResponseReturnValue:
cloud_account_save(None, CloudProvider[provider], form) cloud_account_save(None, CloudProvider[provider], form)
db.session.commit() db.session.commit()
return redirect(url_for("portal.cloud.cloud_account_list")) return redirect(url_for("portal.cloud.cloud_account_list"))
return render_template("new.html.j2", return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
form=form,
**_SECTION_TEMPLATE_VARS)
@bp.route("/edit/<account_id>", methods=['GET', 'POST']) @bp.route("/edit/<account_id>", methods=["GET", "POST"])
def cloud_account_edit(account_id: int) -> ResponseReturnValue: def cloud_account_edit(account_id: int) -> ResponseReturnValue:
account = CloudAccount.query.filter( account = CloudAccount.query.filter(
CloudAccount.id == account_id, CloudAccount.id == account_id,
@ -256,6 +322,4 @@ def cloud_account_edit(account_id: int) -> ResponseReturnValue:
db.session.commit() db.session.commit()
return redirect(url_for("portal.cloud.cloud_account_list")) return redirect(url_for("portal.cloud.cloud_account_list"))
cloud_account_populate(form, account) cloud_account_populate(form, account)
return render_template("new.html.j2", return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
form=form,
**_SECTION_TEMPLATE_VARS)

View file

@ -13,7 +13,7 @@ bp = Blueprint("country", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "country", "section": "country",
"help_url": "https://bypass.censorship.guide/user/countries.html" "help_url": "https://bypass.censorship.guide/user/countries.html",
} }
@ -22,42 +22,51 @@ def filter_country_flag(country_code: str) -> str:
country_code = country_code.upper() country_code = country_code.upper()
# Calculate the regional indicator symbol for each letter in the country code # Calculate the regional indicator symbol for each letter in the country code
base = ord('\U0001F1E6') - ord('A') base = ord("\U0001F1E6") - ord("A")
flag = ''.join([chr(ord(char) + base) for char in country_code]) flag = "".join([chr(ord(char) + base) for char in country_code])
return flag return flag
@bp.route('/list') @bp.route("/list")
def country_list() -> ResponseReturnValue: def country_list() -> ResponseReturnValue:
countries = Country.query.filter(Country.destroyed.is_(None)).all() countries = Country.query.filter(Country.destroyed.is_(None)).all()
print(len(countries)) print(len(countries))
return render_template("list.html.j2", return render_template(
title="Countries", "list.html.j2",
item="country", title="Countries",
new_link=None, item="country",
items=sorted(countries, key=lambda x: x.country_code), new_link=None,
**_SECTION_TEMPLATE_VARS items=sorted(countries, key=lambda x: x.country_code),
) **_SECTION_TEMPLATE_VARS
)
class EditCountryForm(FlaskForm): # type: ignore[misc] class EditCountryForm(FlaskForm): # type: ignore[misc]
risk_level_override = BooleanField("Force Risk Level Override?") risk_level_override = BooleanField("Force Risk Level Override?")
risk_level_override_number = IntegerField("Forced Risk Level", description="Number from 0 to 20", default=0) risk_level_override_number = IntegerField(
submit = SubmitField('Save Changes') "Forced Risk Level", description="Number from 0 to 20", default=0
)
submit = SubmitField("Save Changes")
@bp.route('/edit/<country_id>', methods=['GET', 'POST']) @bp.route("/edit/<country_id>", methods=["GET", "POST"])
def country_edit(country_id: int) -> ResponseReturnValue: def country_edit(country_id: int) -> ResponseReturnValue:
country = Country.query.filter(Country.id == country_id).first() country = Country.query.filter(Country.id == country_id).first()
if country is None: if country is None:
return Response(render_template("error.html.j2", return Response(
section="country", render_template(
header="404 Country Not Found", "error.html.j2",
message="The requested country could not be found."), section="country",
status=404) header="404 Country Not Found",
form = EditCountryForm(risk_level_override=country.risk_level_override is not None, message="The requested country could not be found.",
risk_level_override_number=country.risk_level_override) ),
status=404,
)
form = EditCountryForm(
risk_level_override=country.risk_level_override is not None,
risk_level_override_number=country.risk_level_override,
)
if form.validate_on_submit(): if form.validate_on_submit():
if form.risk_level_override.data: if form.risk_level_override.data:
country.risk_level_override = form.risk_level_override_number.data country.risk_level_override = form.risk_level_override_number.data
@ -69,6 +78,6 @@ def country_edit(country_id: int) -> ResponseReturnValue:
flash("Saved changes to country.", "success") flash("Saved changes to country.", "success")
except sqlalchemy.exc.SQLAlchemyError: except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the country.", "danger") flash("An error occurred saving the changes to the country.", "danger")
return render_template("country.html.j2", return render_template(
section="country", "country.html.j2", section="country", country=country, form=form
country=country, form=form) )

View file

@ -10,23 +10,32 @@ bp = Blueprint("eotk", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "eotk", "section": "eotk",
"help_url": "https://bypass.censorship.guide/user/eotk.html" "help_url": "https://bypass.censorship.guide/user/eotk.html",
} }
@bp.route("/list") @bp.route("/list")
def eotk_list() -> ResponseReturnValue: def eotk_list() -> ResponseReturnValue:
instances = Eotk.query.filter(Eotk.destroyed.is_(None)).order_by(desc(Eotk.added)).all() instances = (
return render_template("list.html.j2", Eotk.query.filter(Eotk.destroyed.is_(None)).order_by(desc(Eotk.added)).all()
title="EOTK Instances", )
item="eotk", return render_template(
items=instances, "list.html.j2",
**_SECTION_TEMPLATE_VARS) title="EOTK Instances",
item="eotk",
items=instances,
**_SECTION_TEMPLATE_VARS
)
@bp.route("/conf/<group_id>") @bp.route("/conf/<group_id>")
def eotk_conf(group_id: int) -> ResponseReturnValue: def eotk_conf(group_id: int) -> ResponseReturnValue:
group = Group.query.filter(Group.id == group_id).first() group = Group.query.filter(Group.id == group_id).first()
return Response(render_template("sites.conf.j2", return Response(
bypass_token=current_app.config["BYPASS_TOKEN"], render_template(
group=group), content_type="text/plain") "sites.conf.j2",
bypass_token=current_app.config["BYPASS_TOKEN"],
group=group,
),
content_type="text/plain",
)

View file

@ -3,6 +3,6 @@ from wtforms import SelectField, StringField, SubmitField
class EditMirrorForm(FlaskForm): # type: ignore class EditMirrorForm(FlaskForm): # type: ignore
origin = SelectField('Origin') origin = SelectField("Origin")
url = StringField('URL') url = StringField("URL")
submit = SubmitField('Save Changes') submit = SubmitField("Save Changes")

View file

@ -1,8 +1,7 @@
from datetime import datetime, timezone from datetime import datetime, timezone
import sqlalchemy import sqlalchemy
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import BooleanField, StringField, SubmitField from wtforms import BooleanField, StringField, SubmitField
@ -18,27 +17,29 @@ class NewGroupForm(FlaskForm): # type: ignore
group_name = StringField("Short Name", validators=[DataRequired()]) group_name = StringField("Short Name", validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
eotk = BooleanField("Deploy EOTK instances?") eotk = BooleanField("Deploy EOTK instances?")
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
class EditGroupForm(FlaskForm): # type: ignore class EditGroupForm(FlaskForm): # type: ignore
description = StringField('Description', validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
eotk = BooleanField("Deploy EOTK instances?") eotk = BooleanField("Deploy EOTK instances?")
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
@bp.route("/list") @bp.route("/list")
def group_list() -> ResponseReturnValue: def group_list() -> ResponseReturnValue:
groups = Group.query.order_by(Group.group_name).all() groups = Group.query.order_by(Group.group_name).all()
return render_template("list.html.j2", return render_template(
section="group", "list.html.j2",
title="Groups", section="group",
item="group", title="Groups",
items=groups, item="group",
new_link=url_for("portal.group.group_new")) items=groups,
new_link=url_for("portal.group.group_new"),
)
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
def group_new() -> ResponseReturnValue: def group_new() -> ResponseReturnValue:
form = NewGroupForm() form = NewGroupForm()
if form.validate_on_submit(): if form.validate_on_submit():
@ -59,17 +60,20 @@ def group_new() -> ResponseReturnValue:
return render_template("new.html.j2", section="group", form=form) return render_template("new.html.j2", section="group", form=form)
@bp.route('/edit/<group_id>', methods=['GET', 'POST']) @bp.route("/edit/<group_id>", methods=["GET", "POST"])
def group_edit(group_id: int) -> ResponseReturnValue: def group_edit(group_id: int) -> ResponseReturnValue:
group = Group.query.filter(Group.id == group_id).first() group = Group.query.filter(Group.id == group_id).first()
if group is None: if group is None:
return Response(render_template("error.html.j2", return Response(
section="group", render_template(
header="404 Group Not Found", "error.html.j2",
message="The requested group could not be found."), section="group",
status=404) header="404 Group Not Found",
form = EditGroupForm(description=group.description, message="The requested group could not be found.",
eotk=group.eotk) ),
status=404,
)
form = EditGroupForm(description=group.description, eotk=group.eotk)
if form.validate_on_submit(): if form.validate_on_submit():
group.description = form.description.data group.description = form.description.data
group.eotk = form.eotk.data group.eotk = form.eotk.data
@ -79,6 +83,4 @@ def group_edit(group_id: int) -> ResponseReturnValue:
flash("Saved changes to group.", "success") flash("Saved changes to group.", "success")
except sqlalchemy.exc.SQLAlchemyError: except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the group.", "danger") flash("An error occurred saving the changes to the group.", "danger")
return render_template("group.html.j2", return render_template("group.html.j2", section="group", group=group, form=form)
section="group",
group=group, form=form)

View file

@ -2,8 +2,7 @@ import json
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from sqlalchemy import exc from sqlalchemy import exc
@ -23,7 +22,7 @@ bp = Blueprint("list", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "list", "section": "list",
"help_url": "https://bypass.censorship.guide/user/lists.html" "help_url": "https://bypass.censorship.guide/user/lists.html",
} }
@ -42,37 +41,44 @@ def list_encoding_name(key: str) -> str:
return MirrorList.encodings_supported.get(key, "Unknown") return MirrorList.encodings_supported.get(key, "Unknown")
@bp.route('/list') @bp.route("/list")
def list_list() -> ResponseReturnValue: def list_list() -> ResponseReturnValue:
lists = MirrorList.query.filter(MirrorList.destroyed.is_(None)).all() lists = MirrorList.query.filter(MirrorList.destroyed.is_(None)).all()
return render_template("list.html.j2", return render_template(
title="Distribution Lists", "list.html.j2",
item="distribution list", title="Distribution Lists",
new_link=url_for("portal.list.list_new"), item="distribution list",
items=lists, new_link=url_for("portal.list.list_new"),
**_SECTION_TEMPLATE_VARS items=lists,
) **_SECTION_TEMPLATE_VARS
)
@bp.route('/preview/<format_>/<pool_id>') @bp.route("/preview/<format_>/<pool_id>")
def list_preview(format_: str, pool_id: int) -> ResponseReturnValue: def list_preview(format_: str, pool_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first() pool = Pool.query.filter(Pool.id == pool_id).first()
if not pool: if not pool:
return response_404(message="Pool not found") return response_404(message="Pool not found")
if format_ == "bca": if format_ == "bca":
return Response(json.dumps(mirror_mapping(pool)), content_type="application/json") return Response(
json.dumps(mirror_mapping(pool)), content_type="application/json"
)
if format_ == "bc2": if format_ == "bc2":
return Response(json.dumps(mirror_sites(pool)), content_type="application/json") return Response(json.dumps(mirror_sites(pool)), content_type="application/json")
if format_ == "bridgelines": if format_ == "bridgelines":
return Response(json.dumps(bridgelines(pool)), content_type="application/json") return Response(json.dumps(bridgelines(pool)), content_type="application/json")
if format_ == "rdr": if format_ == "rdr":
return Response(json.dumps(redirector_data(pool)), content_type="application/json") return Response(
json.dumps(redirector_data(pool)), content_type="application/json"
)
return response_404(message="Format not found") return response_404(message="Format not found")
@bp.route("/destroy/<list_id>", methods=['GET', 'POST']) @bp.route("/destroy/<list_id>", methods=["GET", "POST"])
def list_destroy(list_id: int) -> ResponseReturnValue: def list_destroy(list_id: int) -> ResponseReturnValue:
list_ = MirrorList.query.filter(MirrorList.id == list_id, MirrorList.destroyed.is_(None)).first() list_ = MirrorList.query.filter(
MirrorList.id == list_id, MirrorList.destroyed.is_(None)
).first()
if list_ is None: if list_ is None:
return response_404("The requested bridge configuration could not be found.") return response_404("The requested bridge configuration could not be found.")
return view_lifecycle( return view_lifecycle(
@ -82,12 +88,12 @@ def list_destroy(list_id: int) -> ResponseReturnValue:
success_message="This list will no longer be updated and may be deleted depending on the provider.", success_message="This list will no longer be updated and may be deleted depending on the provider.",
section="list", section="list",
resource=list_, resource=list_,
action="destroy" action="destroy",
) )
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=['GET', 'POST']) @bp.route("/new/<group_id>", methods=["GET", "POST"])
def list_new(group_id: Optional[int] = None) -> ResponseReturnValue: def list_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = NewMirrorListForm() form = NewMirrorListForm()
form.provider.choices = list(MirrorList.providers_supported.items()) form.provider.choices = list(MirrorList.providers_supported.items())
@ -116,43 +122,53 @@ def list_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return redirect(url_for("portal.list.list_list")) return redirect(url_for("portal.list.list_list"))
if group_id: if group_id:
form.group.data = group_id form.group.data = group_id
return render_template("new.html.j2", return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
form=form,
**_SECTION_TEMPLATE_VARS)
class NewMirrorListForm(FlaskForm): # type: ignore class NewMirrorListForm(FlaskForm): # type: ignore
pool = SelectField('Resource Pool', validators=[DataRequired()]) pool = SelectField("Resource Pool", validators=[DataRequired()])
provider = SelectField('Provider', validators=[DataRequired()]) provider = SelectField("Provider", validators=[DataRequired()])
format = SelectField('Distribution Method', validators=[DataRequired()]) format = SelectField("Distribution Method", validators=[DataRequired()])
encoding = SelectField('Encoding', validators=[DataRequired()]) encoding = SelectField("Encoding", validators=[DataRequired()])
description = StringField('Description', validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
container = StringField('Container', validators=[DataRequired()], container = StringField(
description="GitHub Project, GitLab Project or AWS S3 bucket name.") "Container",
branch = StringField('Git Branch/AWS Region', validators=[DataRequired()], validators=[DataRequired()],
description="For GitHub/GitLab, set this to the desired branch name, e.g. main. For AWS S3, " description="GitHub Project, GitLab Project or AWS S3 bucket name.",
"set this field to the desired region, e.g. us-east-1.") )
role = StringField('Role ARN', branch = StringField(
description="(Optional) ARN for IAM role to assume for interaction with the S3 bucket.") "Git Branch/AWS Region",
filename = StringField('Filename', validators=[DataRequired()]) validators=[DataRequired()],
submit = SubmitField('Save Changes') description="For GitHub/GitLab, set this to the desired branch name, e.g. main. For AWS S3, "
"set this field to the desired region, e.g. us-east-1.",
)
role = StringField(
"Role ARN",
description="(Optional) ARN for IAM role to assume for interaction with the S3 bucket.",
)
filename = StringField("Filename", validators=[DataRequired()])
submit = SubmitField("Save Changes")
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.pool.choices = [ self.pool.choices = [(pool.id, pool.pool_name) for pool in Pool.query.all()]
(pool.id, pool.pool_name) for pool in Pool.query.all()
]
@bp.route('/edit/<list_id>', methods=['GET', 'POST']) @bp.route("/edit/<list_id>", methods=["GET", "POST"])
def list_edit(list_id: int) -> ResponseReturnValue: def list_edit(list_id: int) -> ResponseReturnValue:
list_: Optional[MirrorList] = MirrorList.query.filter(MirrorList.id == list_id).first() list_: Optional[MirrorList] = MirrorList.query.filter(
MirrorList.id == list_id
).first()
if list_ is None: if list_ is None:
return Response(render_template("error.html.j2", return Response(
header="404 Distribution List Not Found", render_template(
message="The requested distribution list could not be found.", "error.html.j2",
**_SECTION_TEMPLATE_VARS), header="404 Distribution List Not Found",
status=404) message="The requested distribution list could not be found.",
**_SECTION_TEMPLATE_VARS
),
status=404,
)
form = NewMirrorListForm( form = NewMirrorListForm(
pool=list_.pool_id, pool=list_.pool_id,
provider=list_.provider, provider=list_.provider,
@ -162,7 +178,7 @@ def list_edit(list_id: int) -> ResponseReturnValue:
container=list_.container, container=list_.container,
branch=list_.branch, branch=list_.branch,
role=list_.role, role=list_.role,
filename=list_.filename filename=list_.filename,
) )
form.provider.choices = list(MirrorList.providers_supported.items()) form.provider.choices = list(MirrorList.providers_supported.items())
form.format.choices = list(MirrorList.formats_supported.items()) form.format.choices = list(MirrorList.formats_supported.items())
@ -182,7 +198,10 @@ def list_edit(list_id: int) -> ResponseReturnValue:
db.session.commit() db.session.commit()
flash("Saved changes to group.", "success") flash("Saved changes to group.", "success")
except exc.SQLAlchemyError: except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the distribution list.", "danger") flash(
return render_template("distlist.html.j2", "An error occurred saving the changes to the distribution list.",
list=list_, form=form, "danger",
**_SECTION_TEMPLATE_VARS) )
return render_template(
"distlist.html.j2", list=list_, form=form, **_SECTION_TEMPLATE_VARS
)

View file

@ -9,13 +9,13 @@ from app.portal.util import response_404, view_lifecycle
bp = Blueprint("onion", __name__) bp = Blueprint("onion", __name__)
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=['GET', 'POST']) @bp.route("/new/<group_id>", methods=["GET", "POST"])
def onion_new(group_id: Optional[int] = None) -> ResponseReturnValue: def onion_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return redirect("/ui/web/onions/new") return redirect("/ui/web/onions/new")
@bp.route('/edit/<onion_id>', methods=['GET', 'POST']) @bp.route("/edit/<onion_id>", methods=["GET", "POST"])
def onion_edit(onion_id: int) -> ResponseReturnValue: def onion_edit(onion_id: int) -> ResponseReturnValue:
return redirect("/ui/web/onions/edit/{}".format(onion_id)) return redirect("/ui/web/onions/edit/{}".format(onion_id))
@ -25,9 +25,11 @@ def onion_list() -> ResponseReturnValue:
return redirect("/ui/web/onions") return redirect("/ui/web/onions")
@bp.route("/destroy/<onion_id>", methods=['GET', 'POST']) @bp.route("/destroy/<onion_id>", methods=["GET", "POST"])
def onion_destroy(onion_id: str) -> ResponseReturnValue: def onion_destroy(onion_id: str) -> ResponseReturnValue:
onion = Onion.query.filter(Onion.id == int(onion_id), Onion.destroyed.is_(None)).first() onion = Onion.query.filter(
Onion.id == int(onion_id), Onion.destroyed.is_(None)
).first()
if onion is None: if onion is None:
return response_404("The requested onion service could not be found.") return response_404("The requested onion service could not be found.")
return view_lifecycle( return view_lifecycle(
@ -37,5 +39,5 @@ def onion_destroy(onion_id: str) -> ResponseReturnValue:
success_view="portal.onion.onion_list", success_view="portal.onion.onion_list",
section="onion", section="onion",
resource=onion, resource=onion,
action="destroy" action="destroy",
) )

View file

@ -4,13 +4,11 @@ from typing import List, Optional
import requests import requests
import sqlalchemy import sqlalchemy
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from sqlalchemy import exc from sqlalchemy import exc
from wtforms import (BooleanField, IntegerField, SelectField, StringField, from wtforms import BooleanField, IntegerField, SelectField, StringField, SubmitField
SubmitField)
from wtforms.validators import DataRequired from wtforms.validators import DataRequired
from app.extensions import db from app.extensions import db
@ -22,29 +20,31 @@ bp = Blueprint("origin", __name__)
class NewOriginForm(FlaskForm): # type: ignore class NewOriginForm(FlaskForm): # type: ignore
domain_name = StringField('Domain Name', validators=[DataRequired()]) domain_name = StringField("Domain Name", validators=[DataRequired()])
description = StringField('Description', validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
group = SelectField('Group', validators=[DataRequired()]) group = SelectField("Group", validators=[DataRequired()])
auto_rotate = BooleanField("Enable auto-rotation?", default=True) auto_rotate = BooleanField("Enable auto-rotation?", default=True)
smart_proxy = BooleanField("Requires smart proxy?", default=False) smart_proxy = BooleanField("Requires smart proxy?", default=False)
asset_domain = BooleanField("Used to host assets for other domains?", default=False) asset_domain = BooleanField("Used to host assets for other domains?", default=False)
submit = SubmitField('Save Changes') submit = SubmitField("Save Changes")
class EditOriginForm(FlaskForm): # type: ignore[misc] class EditOriginForm(FlaskForm): # type: ignore[misc]
description = StringField('Description', validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
group = SelectField('Group', validators=[DataRequired()]) group = SelectField("Group", validators=[DataRequired()])
auto_rotate = BooleanField("Enable auto-rotation?") auto_rotate = BooleanField("Enable auto-rotation?")
smart_proxy = BooleanField("Requires smart proxy?") smart_proxy = BooleanField("Requires smart proxy?")
asset_domain = BooleanField("Used to host assets for other domains?", default=False) asset_domain = BooleanField("Used to host assets for other domains?", default=False)
risk_level_override = BooleanField("Force Risk Level Override?") risk_level_override = BooleanField("Force Risk Level Override?")
risk_level_override_number = IntegerField("Forced Risk Level", description="Number from 0 to 20", default=0) risk_level_override_number = IntegerField(
submit = SubmitField('Save Changes') "Forced Risk Level", description="Number from 0 to 20", default=0
)
submit = SubmitField("Save Changes")
class CountrySelectForm(FlaskForm): # type: ignore[misc] class CountrySelectForm(FlaskForm): # type: ignore[misc]
country = SelectField("Country", validators=[DataRequired()]) country = SelectField("Country", validators=[DataRequired()])
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
def final_domain_name(domain_name: str) -> str: def final_domain_name(domain_name: str) -> str:
@ -53,8 +53,8 @@ def final_domain_name(domain_name: str) -> str:
return urllib.parse.urlparse(r.url).netloc return urllib.parse.urlparse(r.url).netloc
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=['GET', 'POST']) @bp.route("/new/<group_id>", methods=["GET", "POST"])
def origin_new(group_id: Optional[int] = None) -> ResponseReturnValue: def origin_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = NewOriginForm() form = NewOriginForm()
form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] form.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
@ -81,22 +81,28 @@ def origin_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return render_template("new.html.j2", section="origin", form=form) return render_template("new.html.j2", section="origin", form=form)
@bp.route('/edit/<origin_id>', methods=['GET', 'POST']) @bp.route("/edit/<origin_id>", methods=["GET", "POST"])
def origin_edit(origin_id: int) -> ResponseReturnValue: def origin_edit(origin_id: int) -> ResponseReturnValue:
origin: Optional[Origin] = Origin.query.filter(Origin.id == origin_id).first() origin: Optional[Origin] = Origin.query.filter(Origin.id == origin_id).first()
if origin is None: if origin is None:
return Response(render_template("error.html.j2", return Response(
section="origin", render_template(
header="404 Origin Not Found", "error.html.j2",
message="The requested origin could not be found."), section="origin",
status=404) header="404 Origin Not Found",
form = EditOriginForm(group=origin.group_id, message="The requested origin could not be found.",
description=origin.description, ),
auto_rotate=origin.auto_rotation, status=404,
smart_proxy=origin.smart, )
asset_domain=origin.assets, form = EditOriginForm(
risk_level_override=origin.risk_level_override is not None, group=origin.group_id,
risk_level_override_number=origin.risk_level_override) description=origin.description,
auto_rotate=origin.auto_rotation,
smart_proxy=origin.smart,
asset_domain=origin.assets,
risk_level_override=origin.risk_level_override is not None,
risk_level_override_number=origin.risk_level_override,
)
form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] form.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
if form.validate_on_submit(): if form.validate_on_submit():
origin.group_id = form.group.data origin.group_id = form.group.data
@ -114,41 +120,47 @@ def origin_edit(origin_id: int) -> ResponseReturnValue:
flash(f"Saved changes for origin {origin.domain_name}.", "success") flash(f"Saved changes for origin {origin.domain_name}.", "success")
except exc.SQLAlchemyError: except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the origin.", "danger") flash("An error occurred saving the changes to the origin.", "danger")
return render_template("origin.html.j2", return render_template("origin.html.j2", section="origin", origin=origin, form=form)
section="origin",
origin=origin, form=form)
@bp.route("/list") @bp.route("/list")
def origin_list() -> ResponseReturnValue: def origin_list() -> ResponseReturnValue:
origins: List[Origin] = Origin.query.order_by(Origin.domain_name).all() origins: List[Origin] = Origin.query.order_by(Origin.domain_name).all()
return render_template("list.html.j2", return render_template(
section="origin", "list.html.j2",
title="Web Origins", section="origin",
item="origin", title="Web Origins",
new_link=url_for("portal.origin.origin_new"), item="origin",
items=origins, new_link=url_for("portal.origin.origin_new"),
extra_buttons=[{ items=origins,
"link": url_for("portal.origin.origin_onion"), extra_buttons=[
"text": "Onion services", {
"style": "onion" "link": url_for("portal.origin.origin_onion"),
}]) "text": "Onion services",
"style": "onion",
}
],
)
@bp.route("/onion") @bp.route("/onion")
def origin_onion() -> ResponseReturnValue: def origin_onion() -> ResponseReturnValue:
origins = Origin.query.order_by(Origin.domain_name).all() origins = Origin.query.order_by(Origin.domain_name).all()
return render_template("list.html.j2", return render_template(
section="origin", "list.html.j2",
title="Onion Sites", section="origin",
item="onion service", title="Onion Sites",
new_link=url_for("portal.onion.onion_new"), item="onion service",
items=origins) new_link=url_for("portal.onion.onion_new"),
items=origins,
)
@bp.route("/destroy/<origin_id>", methods=['GET', 'POST']) @bp.route("/destroy/<origin_id>", methods=["GET", "POST"])
def origin_destroy(origin_id: int) -> ResponseReturnValue: def origin_destroy(origin_id: int) -> ResponseReturnValue:
origin = Origin.query.filter(Origin.id == origin_id, Origin.destroyed.is_(None)).first() origin = Origin.query.filter(
Origin.id == origin_id, Origin.destroyed.is_(None)
).first()
if origin is None: if origin is None:
return response_404("The requested origin could not be found.") return response_404("The requested origin could not be found.")
return view_lifecycle( return view_lifecycle(
@ -158,32 +170,44 @@ def origin_destroy(origin_id: int) -> ResponseReturnValue:
success_view="portal.origin.origin_list", success_view="portal.origin.origin_list",
section="origin", section="origin",
resource=origin, resource=origin,
action="destroy" action="destroy",
) )
@bp.route('/country_remove/<origin_id>/<country_id>', methods=['GET', 'POST']) @bp.route("/country_remove/<origin_id>/<country_id>", methods=["GET", "POST"])
def origin_country_remove(origin_id: int, country_id: int) -> ResponseReturnValue: def origin_country_remove(origin_id: int, country_id: int) -> ResponseReturnValue:
origin = Origin.query.filter(Origin.id == origin_id).first() origin = Origin.query.filter(Origin.id == origin_id).first()
if origin is None: if origin is None:
return Response(render_template("error.html.j2", return Response(
section="origin", render_template(
header="404 Pool Not Found", "error.html.j2",
message="The requested origin could not be found."), section="origin",
status=404) header="404 Pool Not Found",
message="The requested origin could not be found.",
),
status=404,
)
country = Country.query.filter(Country.id == country_id).first() country = Country.query.filter(Country.id == country_id).first()
if country is None: if country is None:
return Response(render_template("error.html.j2", return Response(
section="origin", render_template(
header="404 Country Not Found", "error.html.j2",
message="The requested country could not be found."), section="origin",
status=404) header="404 Country Not Found",
message="The requested country could not be found.",
),
status=404,
)
if country not in origin.countries: if country not in origin.countries:
return Response(render_template("error.html.j2", return Response(
section="origin", render_template(
header="404 Country Not In Pool", "error.html.j2",
message="The requested country could not be found in the specified origin."), section="origin",
status=404) header="404 Country Not In Pool",
message="The requested country could not be found in the specified origin.",
),
status=404,
)
form = LifecycleForm() form = LifecycleForm()
if form.validate_on_submit(): if form.validate_on_submit():
origin.countries.remove(country) origin.countries.remove(country)
@ -193,32 +217,45 @@ def origin_country_remove(origin_id: int, country_id: int) -> ResponseReturnValu
return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id)) return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id))
except sqlalchemy.exc.SQLAlchemyError: except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the origin.", "danger") flash("An error occurred saving the changes to the origin.", "danger")
return render_template("lifecycle.html.j2", return render_template(
header=f"Remove {country.description} from the {origin.domain_name} origin?", "lifecycle.html.j2",
message="Stop monitoring in this country.", header=f"Remove {country.description} from the {origin.domain_name} origin?",
section="origin", message="Stop monitoring in this country.",
origin=origin, form=form) section="origin",
origin=origin,
form=form,
)
@bp.route('/country_add/<origin_id>', methods=['GET', 'POST']) @bp.route("/country_add/<origin_id>", methods=["GET", "POST"])
def origin_country_add(origin_id: int) -> ResponseReturnValue: def origin_country_add(origin_id: int) -> ResponseReturnValue:
origin = Origin.query.filter(Origin.id == origin_id).first() origin = Origin.query.filter(Origin.id == origin_id).first()
if origin is None: if origin is None:
return Response(render_template("error.html.j2", return Response(
section="origin", render_template(
header="404 Origin Not Found", "error.html.j2",
message="The requested origin could not be found."), section="origin",
status=404) header="404 Origin Not Found",
message="The requested origin could not be found.",
),
status=404,
)
form = CountrySelectForm() form = CountrySelectForm()
form.country.choices = [(x.id, f"{x.country_code} - {x.description}") for x in Country.query.all()] form.country.choices = [
(x.id, f"{x.country_code} - {x.description}") for x in Country.query.all()
]
if form.validate_on_submit(): if form.validate_on_submit():
country = Country.query.filter(Country.id == form.country.data).first() country = Country.query.filter(Country.id == form.country.data).first()
if country is None: if country is None:
return Response(render_template("error.html.j2", return Response(
section="origin", render_template(
header="404 Country Not Found", "error.html.j2",
message="The requested country could not be found."), section="origin",
status=404) header="404 Country Not Found",
message="The requested country could not be found.",
),
status=404,
)
origin.countries.append(country) origin.countries.append(country)
try: try:
db.session.commit() db.session.commit()
@ -226,8 +263,11 @@ def origin_country_add(origin_id: int) -> ResponseReturnValue:
return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id)) return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id))
except sqlalchemy.exc.SQLAlchemyError: except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the origin.", "danger") flash("An error occurred saving the changes to the origin.", "danger")
return render_template("lifecycle.html.j2", return render_template(
header=f"Add a country to {origin.domain_name}", "lifecycle.html.j2",
message="Enable monitoring from this country:", header=f"Add a country to {origin.domain_name}",
section="origin", message="Enable monitoring from this country:",
origin=origin, form=form) section="origin",
origin=origin,
form=form,
)

View file

@ -3,8 +3,7 @@ import secrets
from datetime import datetime, timezone from datetime import datetime, timezone
import sqlalchemy import sqlalchemy
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import SelectField, StringField, SubmitField from wtforms import SelectField, StringField, SubmitField
@ -21,41 +20,50 @@ class NewPoolForm(FlaskForm): # type: ignore[misc]
group_name = StringField("Short Name", validators=[DataRequired()]) group_name = StringField("Short Name", validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
redirector_domain = StringField("Redirector Domain") redirector_domain = StringField("Redirector Domain")
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
class EditPoolForm(FlaskForm): # type: ignore[misc] class EditPoolForm(FlaskForm): # type: ignore[misc]
description = StringField("Description", validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
redirector_domain = StringField("Redirector Domain") redirector_domain = StringField("Redirector Domain")
api_key = StringField("API Key", description=("Any change to this field (e.g. clearing it) will result in the " api_key = StringField(
"API key being regenerated.")) "API Key",
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) description=(
"Any change to this field (e.g. clearing it) will result in the "
"API key being regenerated."
),
)
submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
class GroupSelectForm(FlaskForm): # type: ignore[misc] class GroupSelectForm(FlaskForm): # type: ignore[misc]
group = SelectField("Group", validators=[DataRequired()]) group = SelectField("Group", validators=[DataRequired()])
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"}) submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
@bp.route("/list") @bp.route("/list")
def pool_list() -> ResponseReturnValue: def pool_list() -> ResponseReturnValue:
pools = Pool.query.order_by(Pool.pool_name).all() pools = Pool.query.order_by(Pool.pool_name).all()
return render_template("list.html.j2", return render_template(
section="pool", "list.html.j2",
title="Resource Pools", section="pool",
item="pool", title="Resource Pools",
items=pools, item="pool",
new_link=url_for("portal.pool.pool_new")) items=pools,
new_link=url_for("portal.pool.pool_new"),
)
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
def pool_new() -> ResponseReturnValue: def pool_new() -> ResponseReturnValue:
form = NewPoolForm() form = NewPoolForm()
if form.validate_on_submit(): if form.validate_on_submit():
pool = Pool() pool = Pool()
pool.pool_name = form.group_name.data pool.pool_name = form.group_name.data
pool.description = form.description.data pool.description = form.description.data
pool.redirector_domain = form.redirector_domain.data if form.redirector_domain.data != "" else None pool.redirector_domain = (
form.redirector_domain.data if form.redirector_domain.data != "" else None
)
pool.api_key = secrets.token_urlsafe(nbytes=32) pool.api_key = secrets.token_urlsafe(nbytes=32)
pool.added = datetime.now(timezone.utc) pool.added = datetime.now(timezone.utc)
pool.updated = datetime.now(timezone.utc) pool.updated = datetime.now(timezone.utc)
@ -71,21 +79,29 @@ def pool_new() -> ResponseReturnValue:
return render_template("new.html.j2", section="pool", form=form) return render_template("new.html.j2", section="pool", form=form)
@bp.route('/edit/<pool_id>', methods=['GET', 'POST']) @bp.route("/edit/<pool_id>", methods=["GET", "POST"])
def pool_edit(pool_id: int) -> ResponseReturnValue: def pool_edit(pool_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first() pool = Pool.query.filter(Pool.id == pool_id).first()
if pool is None: if pool is None:
return Response(render_template("error.html.j2", return Response(
section="pool", render_template(
header="404 Pool Not Found", "error.html.j2",
message="The requested pool could not be found."), section="pool",
status=404) header="404 Pool Not Found",
form = EditPoolForm(description=pool.description, message="The requested pool could not be found.",
api_key=pool.api_key, ),
redirector_domain=pool.redirector_domain) status=404,
)
form = EditPoolForm(
description=pool.description,
api_key=pool.api_key,
redirector_domain=pool.redirector_domain,
)
if form.validate_on_submit(): if form.validate_on_submit():
pool.description = form.description.data pool.description = form.description.data
pool.redirector_domain = form.redirector_domain.data if form.redirector_domain.data != "" else None pool.redirector_domain = (
form.redirector_domain.data if form.redirector_domain.data != "" else None
)
if form.api_key.data != pool.api_key: if form.api_key.data != pool.api_key:
pool.api_key = secrets.token_urlsafe(nbytes=32) pool.api_key = secrets.token_urlsafe(nbytes=32)
form.api_key.data = pool.api_key form.api_key.data = pool.api_key
@ -95,33 +111,43 @@ def pool_edit(pool_id: int) -> ResponseReturnValue:
flash("Saved changes to pool.", "success") flash("Saved changes to pool.", "success")
except sqlalchemy.exc.SQLAlchemyError: except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the pool.", "danger") flash("An error occurred saving the changes to the pool.", "danger")
return render_template("pool.html.j2", return render_template("pool.html.j2", section="pool", pool=pool, form=form)
section="pool",
pool=pool, form=form)
@bp.route('/group_remove/<pool_id>/<group_id>', methods=['GET', 'POST']) @bp.route("/group_remove/<pool_id>/<group_id>", methods=["GET", "POST"])
def pool_group_remove(pool_id: int, group_id: int) -> ResponseReturnValue: def pool_group_remove(pool_id: int, group_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first() pool = Pool.query.filter(Pool.id == pool_id).first()
if pool is None: if pool is None:
return Response(render_template("error.html.j2", return Response(
section="pool", render_template(
header="404 Pool Not Found", "error.html.j2",
message="The requested pool could not be found."), section="pool",
status=404) header="404 Pool Not Found",
message="The requested pool could not be found.",
),
status=404,
)
group = Group.query.filter(Group.id == group_id).first() group = Group.query.filter(Group.id == group_id).first()
if group is None: if group is None:
return Response(render_template("error.html.j2", return Response(
section="pool", render_template(
header="404 Group Not Found", "error.html.j2",
message="The requested group could not be found."), section="pool",
status=404) header="404 Group Not Found",
message="The requested group could not be found.",
),
status=404,
)
if group not in pool.groups: if group not in pool.groups:
return Response(render_template("error.html.j2", return Response(
section="pool", render_template(
header="404 Group Not In Pool", "error.html.j2",
message="The requested group could not be found in the specified pool."), section="pool",
status=404) header="404 Group Not In Pool",
message="The requested group could not be found in the specified pool.",
),
status=404,
)
form = LifecycleForm() form = LifecycleForm()
if form.validate_on_submit(): if form.validate_on_submit():
pool.groups.remove(group) pool.groups.remove(group)
@ -131,32 +157,43 @@ def pool_group_remove(pool_id: int, group_id: int) -> ResponseReturnValue:
return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id)) return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id))
except sqlalchemy.exc.SQLAlchemyError: except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the pool.", "danger") flash("An error occurred saving the changes to the pool.", "danger")
return render_template("lifecycle.html.j2", return render_template(
header=f"Remove {group.group_name} from the {pool.pool_name} pool?", "lifecycle.html.j2",
message="Resources deployed and available in the pool will be destroyed soon.", header=f"Remove {group.group_name} from the {pool.pool_name} pool?",
section="pool", message="Resources deployed and available in the pool will be destroyed soon.",
pool=pool, form=form) section="pool",
pool=pool,
form=form,
)
@bp.route('/group_add/<pool_id>', methods=['GET', 'POST']) @bp.route("/group_add/<pool_id>", methods=["GET", "POST"])
def pool_group_add(pool_id: int) -> ResponseReturnValue: def pool_group_add(pool_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first() pool = Pool.query.filter(Pool.id == pool_id).first()
if pool is None: if pool is None:
return Response(render_template("error.html.j2", return Response(
section="pool", render_template(
header="404 Pool Not Found", "error.html.j2",
message="The requested pool could not be found."), section="pool",
status=404) header="404 Pool Not Found",
message="The requested pool could not be found.",
),
status=404,
)
form = GroupSelectForm() form = GroupSelectForm()
form.group.choices = [(x.id, x.group_name) for x in Group.query.all()] form.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
if form.validate_on_submit(): if form.validate_on_submit():
group = Group.query.filter(Group.id == form.group.data).first() group = Group.query.filter(Group.id == form.group.data).first()
if group is None: if group is None:
return Response(render_template("error.html.j2", return Response(
section="pool", render_template(
header="404 Group Not Found", "error.html.j2",
message="The requested group could not be found."), section="pool",
status=404) header="404 Group Not Found",
message="The requested group could not be found.",
),
status=404,
)
pool.groups.append(group) pool.groups.append(group)
try: try:
db.session.commit() db.session.commit()
@ -164,8 +201,11 @@ def pool_group_add(pool_id: int) -> ResponseReturnValue:
return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id)) return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id))
except sqlalchemy.exc.SQLAlchemyError: except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the pool.", "danger") flash("An error occurred saving the changes to the pool.", "danger")
return render_template("lifecycle.html.j2", return render_template(
header=f"Add a group to {pool.pool_name}", "lifecycle.html.j2",
message="Resources will shortly be deployed and available for all origins in this group.", header=f"Add a group to {pool.pool_name}",
section="pool", message="Resources will shortly be deployed and available for all origins in this group.",
pool=pool, form=form) section="pool",
pool=pool,
form=form,
)

View file

@ -1,5 +1,4 @@
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from sqlalchemy import desc from sqlalchemy import desc
@ -12,51 +11,63 @@ bp = Blueprint("proxy", __name__)
@bp.route("/list") @bp.route("/list")
def proxy_list() -> ResponseReturnValue: def proxy_list() -> ResponseReturnValue:
proxies = Proxy.query.filter(Proxy.destroyed.is_(None)).order_by(desc(Proxy.added)).all() proxies = (
return render_template("list.html.j2", Proxy.query.filter(Proxy.destroyed.is_(None)).order_by(desc(Proxy.added)).all()
section="proxy", )
title="Proxies", return render_template(
item="proxy", "list.html.j2", section="proxy", title="Proxies", item="proxy", items=proxies
items=proxies) )
@bp.route("/expire/<proxy_id>", methods=['GET', 'POST']) @bp.route("/expire/<proxy_id>", methods=["GET", "POST"])
def proxy_expire(proxy_id: int) -> ResponseReturnValue: def proxy_expire(proxy_id: int) -> ResponseReturnValue:
proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first() proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first()
if proxy is None: if proxy is None:
return Response(render_template("error.html.j2", return Response(
header="404 Proxy Not Found", render_template(
message="The requested proxy could not be found. It may have already been " "error.html.j2",
"destroyed.")) header="404 Proxy Not Found",
message="The requested proxy could not be found. It may have already been "
"destroyed.",
)
)
form = LifecycleForm() form = LifecycleForm()
if form.validate_on_submit(): if form.validate_on_submit():
proxy.destroy() proxy.destroy()
db.session.commit() db.session.commit()
flash("Proxy will be shortly retired.", "success") flash("Proxy will be shortly retired.", "success")
return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id)) return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id))
return render_template("lifecycle.html.j2", return render_template(
header=f"Expire proxy for {proxy.origin.domain_name} immediately?", "lifecycle.html.j2",
message=proxy.url, header=f"Expire proxy for {proxy.origin.domain_name} immediately?",
section="proxy", message=proxy.url,
form=form) section="proxy",
form=form,
)
@bp.route("/block/<proxy_id>", methods=['GET', 'POST']) @bp.route("/block/<proxy_id>", methods=["GET", "POST"])
def proxy_block(proxy_id: int) -> ResponseReturnValue: def proxy_block(proxy_id: int) -> ResponseReturnValue:
proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first() proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first()
if proxy is None: if proxy is None:
return Response(render_template("error.html.j2", return Response(
header="404 Proxy Not Found", render_template(
message="The requested proxy could not be found. It may have already been " "error.html.j2",
"destroyed.")) header="404 Proxy Not Found",
message="The requested proxy could not be found. It may have already been "
"destroyed.",
)
)
form = LifecycleForm() form = LifecycleForm()
if form.validate_on_submit(): if form.validate_on_submit():
proxy.deprecate(reason="manual") proxy.deprecate(reason="manual")
db.session.commit() db.session.commit()
flash("Proxy will be shortly replaced.", "success") flash("Proxy will be shortly replaced.", "success")
return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id)) return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id))
return render_template("lifecycle.html.j2", return render_template(
header=f"Mark proxy for {proxy.origin.domain_name} as blocked?", "lifecycle.html.j2",
message=proxy.url, header=f"Mark proxy for {proxy.origin.domain_name} as blocked?",
section="proxy", message=proxy.url,
form=form) section="proxy",
form=form,
)

View file

@ -20,12 +20,12 @@ def generate_subqueries():
deprecations_24hr_subquery = ( deprecations_24hr_subquery = (
db.session.query( db.session.query(
DeprecationAlias.resource_id, DeprecationAlias.resource_id,
func.count(DeprecationAlias.resource_id).label('deprecations_24hr') func.count(DeprecationAlias.resource_id).label("deprecations_24hr"),
) )
.filter( .filter(
DeprecationAlias.reason.like('block_%'), DeprecationAlias.reason.like("block_%"),
DeprecationAlias.deprecated_at >= now - timedelta(hours=24), DeprecationAlias.deprecated_at >= now - timedelta(hours=24),
DeprecationAlias.resource_type == 'Proxy' DeprecationAlias.resource_type == "Proxy",
) )
.group_by(DeprecationAlias.resource_id) .group_by(DeprecationAlias.resource_id)
.subquery() .subquery()
@ -33,12 +33,12 @@ def generate_subqueries():
deprecations_72hr_subquery = ( deprecations_72hr_subquery = (
db.session.query( db.session.query(
DeprecationAlias.resource_id, DeprecationAlias.resource_id,
func.count(DeprecationAlias.resource_id).label('deprecations_72hr') func.count(DeprecationAlias.resource_id).label("deprecations_72hr"),
) )
.filter( .filter(
DeprecationAlias.reason.like('block_%'), DeprecationAlias.reason.like("block_%"),
DeprecationAlias.deprecated_at >= now - timedelta(hours=72), DeprecationAlias.deprecated_at >= now - timedelta(hours=72),
DeprecationAlias.resource_type == 'Proxy' DeprecationAlias.resource_type == "Proxy",
) )
.group_by(DeprecationAlias.resource_id) .group_by(DeprecationAlias.resource_id)
.subquery() .subquery()
@ -52,13 +52,23 @@ def countries_report():
return ( return (
db.session.query( db.session.query(
Country, Country,
func.coalesce(func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0).label('total_deprecations_24hr'), func.coalesce(
func.coalesce(func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0).label('total_deprecations_72hr') func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0
).label("total_deprecations_24hr"),
func.coalesce(
func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0
).label("total_deprecations_72hr"),
) )
.join(Origin, Country.origins) .join(Origin, Country.origins)
.join(Proxy, Origin.proxies) .join(Proxy, Origin.proxies)
.outerjoin(deprecations_24hr_subquery, Proxy.id == deprecations_24hr_subquery.c.resource_id) .outerjoin(
.outerjoin(deprecations_72hr_subquery, Proxy.id == deprecations_72hr_subquery.c.resource_id) deprecations_24hr_subquery,
Proxy.id == deprecations_24hr_subquery.c.resource_id,
)
.outerjoin(
deprecations_72hr_subquery,
Proxy.id == deprecations_72hr_subquery.c.resource_id,
)
.group_by(Country.id) .group_by(Country.id)
.all() .all()
) )
@ -70,12 +80,22 @@ def origins_report():
return ( return (
db.session.query( db.session.query(
Origin, Origin,
func.coalesce(func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0).label('total_deprecations_24hr'), func.coalesce(
func.coalesce(func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0).label('total_deprecations_72hr') func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0
).label("total_deprecations_24hr"),
func.coalesce(
func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0
).label("total_deprecations_72hr"),
) )
.outerjoin(Proxy, Origin.proxies) .outerjoin(Proxy, Origin.proxies)
.outerjoin(deprecations_24hr_subquery, Proxy.id == deprecations_24hr_subquery.c.resource_id) .outerjoin(
.outerjoin(deprecations_72hr_subquery, Proxy.id == deprecations_72hr_subquery.c.resource_id) deprecations_24hr_subquery,
Proxy.id == deprecations_24hr_subquery.c.resource_id,
)
.outerjoin(
deprecations_72hr_subquery,
Proxy.id == deprecations_72hr_subquery.c.resource_id,
)
.filter(Origin.destroyed.is_(None)) .filter(Origin.destroyed.is_(None))
.group_by(Origin.id) .group_by(Origin.id)
.order_by(desc("total_deprecations_24hr")) .order_by(desc("total_deprecations_24hr"))
@ -83,26 +103,37 @@ def origins_report():
) )
@report.app_template_filter('country_name') @report.app_template_filter("country_name")
def country_description_filter(country_code): def country_description_filter(country_code):
country = Country.query.filter_by(country_code=country_code).first() country = Country.query.filter_by(country_code=country_code).first()
return country.description if country else None return country.description if country else None
@report.route("/blocks", methods=['GET']) @report.route("/blocks", methods=["GET"])
def report_blocks() -> ResponseReturnValue: def report_blocks() -> ResponseReturnValue:
blocked_today = db.session.query( # type: ignore[no-untyped-call] blocked_today = (
Origin.domain_name, db.session.query( # type: ignore[no-untyped-call]
Origin.description, Origin.domain_name,
Proxy.added, Origin.description,
Proxy.deprecated, Proxy.added,
Proxy.deprecation_reason Proxy.deprecated,
).join(Origin, Origin.id == Proxy.origin_id Proxy.deprecation_reason,
).filter(and_(Proxy.deprecated > datetime.now(tz=timezone.utc) - timedelta(days=1), )
Proxy.deprecation_reason.like('block_%'))).all() .join(Origin, Origin.id == Proxy.origin_id)
.filter(
and_(
Proxy.deprecated > datetime.now(tz=timezone.utc) - timedelta(days=1),
Proxy.deprecation_reason.like("block_%"),
)
)
.all()
)
return render_template("report_blocks.html.j2", return render_template(
blocked_today=blocked_today, "report_blocks.html.j2",
origins=sorted(origins_report(), key=lambda o: o[1], reverse=True), blocked_today=blocked_today,
countries=sorted(countries_report(), key=lambda c: c[0].risk_level, reverse=True), origins=sorted(origins_report(), key=lambda o: o[1], reverse=True),
) countries=sorted(
countries_report(), key=lambda c: c[0].risk_level, reverse=True
),
)

View file

@ -9,9 +9,15 @@ bp = Blueprint("smart_proxy", __name__)
@bp.route("/list") @bp.route("/list")
def smart_proxy_list() -> ResponseReturnValue: def smart_proxy_list() -> ResponseReturnValue:
instances = SmartProxy.query.filter(SmartProxy.destroyed.is_(None)).order_by(desc(SmartProxy.added)).all() instances = (
return render_template("list.html.j2", SmartProxy.query.filter(SmartProxy.destroyed.is_(None))
section="smart_proxy", .order_by(desc(SmartProxy.added))
title="Smart Proxy Instances", .all()
item="smart proxy", )
items=instances) return render_template(
"list.html.j2",
section="smart_proxy",
title="Smart Proxy Instances",
item="smart proxy",
items=instances,
)

View file

@ -2,13 +2,19 @@ import logging
from typing import Any, List, Optional from typing import Any, List, Optional
import sqlalchemy.exc import sqlalchemy.exc
from flask import (Blueprint, Response, current_app, flash, redirect, from flask import (
render_template, url_for) Blueprint,
Response,
current_app,
flash,
redirect,
render_template,
url_for,
)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from sqlalchemy import exc from sqlalchemy import exc
from wtforms import (BooleanField, FileField, SelectField, StringField, from wtforms import BooleanField, FileField, SelectField, StringField, SubmitField
SubmitField)
from wtforms.validators import DataRequired from wtforms.validators import DataRequired
from app.brm.static import create_static_origin from app.brm.static import create_static_origin
@ -22,87 +28,99 @@ bp = Blueprint("static", __name__)
class StaticOriginForm(FlaskForm): # type: ignore class StaticOriginForm(FlaskForm): # type: ignore
description = StringField( description = StringField(
'Description', "Description",
validators=[DataRequired()], validators=[DataRequired()],
description='Enter a brief description of the static website that you are creating in this field. This is ' description="Enter a brief description of the static website that you are creating in this field. This is "
'also a required field.' "also a required field.",
) )
group = SelectField( group = SelectField(
'Group', "Group",
validators=[DataRequired()], validators=[DataRequired()],
description='Select the group that you want the origin to belong to from the drop-down menu in this field. ' description="Select the group that you want the origin to belong to from the drop-down menu in this field. "
'This is a required field.' "This is a required field.",
) )
storage_cloud_account = SelectField( storage_cloud_account = SelectField(
'Storage Cloud Account', "Storage Cloud Account",
validators=[DataRequired()], validators=[DataRequired()],
description='Select the cloud account that you want the origin to be deployed to from the drop-down menu in ' description="Select the cloud account that you want the origin to be deployed to from the drop-down menu in "
'this field. This is a required field.' "this field. This is a required field.",
) )
source_cloud_account = SelectField( source_cloud_account = SelectField(
'Source Cloud Account', "Source Cloud Account",
validators=[DataRequired()], validators=[DataRequired()],
description='Select the cloud account that will be used to modify the source repository for the web content ' description="Select the cloud account that will be used to modify the source repository for the web content "
'for this static origin. This is a required field.' "for this static origin. This is a required field.",
) )
source_project = StringField( source_project = StringField(
'Source Project', "Source Project",
validators=[DataRequired()], validators=[DataRequired()],
description='GitLab project path.' description="GitLab project path.",
) )
auto_rotate = BooleanField( auto_rotate = BooleanField(
'Auto-Rotate', "Auto-Rotate",
default=True, default=True,
description='Select this field if you want to enable auto-rotation for the mirror. This means that the mirror ' description="Select this field if you want to enable auto-rotation for the mirror. This means that the mirror "
'will automatically redeploy with a new domain name if it is detected to be blocked. This field ' "will automatically redeploy with a new domain name if it is detected to be blocked. This field "
'is optional and is enabled by default.' "is optional and is enabled by default.",
) )
matrix_homeserver = SelectField( matrix_homeserver = SelectField(
'Matrix Homeserver', "Matrix Homeserver",
description='Select the Matrix homeserver from the drop-down box to enable Keanu Convene on mirrors of this ' description="Select the Matrix homeserver from the drop-down box to enable Keanu Convene on mirrors of this "
'static origin.' "static origin.",
) )
keanu_convene_path = StringField( keanu_convene_path = StringField(
'Keanu Convene Path', "Keanu Convene Path",
default='talk', default="talk",
description='Enter the subdirectory to present the Keanu Convene application at on the mirror. This defaults ' description="Enter the subdirectory to present the Keanu Convene application at on the mirror. This defaults "
'to "talk".' 'to "talk".',
) )
keanu_convene_logo = FileField( keanu_convene_logo = FileField(
'Keanu Convene Logo', "Keanu Convene Logo", description="Logo to use for Keanu Convene"
description='Logo to use for Keanu Convene'
) )
keanu_convene_color = StringField( keanu_convene_color = StringField(
'Keanu Convene Accent Color', "Keanu Convene Accent Color",
default='#0047ab', default="#0047ab",
description='Accent color to use for Keanu Convene (HTML hex code)' description="Accent color to use for Keanu Convene (HTML hex code)",
) )
enable_clean_insights = BooleanField( enable_clean_insights = BooleanField(
'Enable Clean Insights', "Enable Clean Insights",
description='When enabled, a Clean Insights Measurement Proxy endpoint is deployed on the mirror to allow for ' description="When enabled, a Clean Insights Measurement Proxy endpoint is deployed on the mirror to allow for "
'submission of results from any of the supported Clean Insights SDKs.' "submission of results from any of the supported Clean Insights SDKs.",
) )
submit = SubmitField('Save Changes') submit = SubmitField("Save Changes")
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.group.choices = [(x.id, x.group_name) for x in Group.query.all()] self.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
self.storage_cloud_account.choices = [(x.id, f"{x.provider.description} - {x.description}") for x in self.storage_cloud_account.choices = [
CloudAccount.query.filter( (x.id, f"{x.provider.description} - {x.description}")
CloudAccount.provider == CloudProvider.AWS).all()] for x in CloudAccount.query.filter(
self.source_cloud_account.choices = [(x.id, f"{x.provider.description} - {x.description}") for x in CloudAccount.provider == CloudProvider.AWS
CloudAccount.query.filter( ).all()
CloudAccount.provider == CloudProvider.GITLAB).all()] ]
self.matrix_homeserver.choices = [(x, x) for x in current_app.config['MATRIX_HOMESERVERS']] self.source_cloud_account.choices = [
(x.id, f"{x.provider.description} - {x.description}")
for x in CloudAccount.query.filter(
CloudAccount.provider == CloudProvider.GITLAB
).all()
]
self.matrix_homeserver.choices = [
(x, x) for x in current_app.config["MATRIX_HOMESERVERS"]
]
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=['GET', 'POST']) @bp.route("/new/<group_id>", methods=["GET", "POST"])
def static_new(group_id: Optional[int] = None) -> ResponseReturnValue: def static_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = StaticOriginForm() form = StaticOriginForm()
if len(form.source_cloud_account.choices) == 0 or len(form.storage_cloud_account.choices) == 0: if (
flash("You must add at least one AWS account and at least one GitLab account before creating static origins.", len(form.source_cloud_account.choices) == 0
"warning") or len(form.storage_cloud_account.choices) == 0
):
flash(
"You must add at least one AWS account and at least one GitLab account before creating static origins.",
"warning",
)
return redirect(url_for("portal.cloud.cloud_account_list")) return redirect(url_for("portal.cloud.cloud_account_list"))
if form.validate_on_submit(): if form.validate_on_submit():
try: try:
@ -118,16 +136,22 @@ def static_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form.keanu_convene_logo.data, form.keanu_convene_logo.data,
form.keanu_convene_color.data, form.keanu_convene_color.data,
form.enable_clean_insights.data, form.enable_clean_insights.data,
True True,
) )
flash(f"Created new static origin #{static.id}.", "success") flash(f"Created new static origin #{static.id}.", "success")
return redirect(url_for("portal.static.static_edit", static_id=static.id)) return redirect(url_for("portal.static.static_edit", static_id=static.id))
except ValueError as e: # may be returned by create_static_origin and from the int conversion except (
ValueError
) as e: # may be returned by create_static_origin and from the int conversion
logging.warning(e) logging.warning(e)
flash("Failed to create new static origin due to an invalid input.", "danger") flash(
"Failed to create new static origin due to an invalid input.", "danger"
)
return redirect(url_for("portal.static.static_list")) return redirect(url_for("portal.static.static_list"))
except exc.SQLAlchemyError as e: except exc.SQLAlchemyError as e:
flash("Failed to create new static origin due to a database error.", "danger") flash(
"Failed to create new static origin due to a database error.", "danger"
)
logging.warning(e) logging.warning(e)
return redirect(url_for("portal.static.static_list")) return redirect(url_for("portal.static.static_list"))
if group_id: if group_id:
@ -135,24 +159,32 @@ def static_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return render_template("new.html.j2", section="static", form=form) return render_template("new.html.j2", section="static", form=form)
@bp.route('/edit/<static_id>', methods=['GET', 'POST']) @bp.route("/edit/<static_id>", methods=["GET", "POST"])
def static_edit(static_id: int) -> ResponseReturnValue: def static_edit(static_id: int) -> ResponseReturnValue:
static_origin: Optional[StaticOrigin] = StaticOrigin.query.filter(StaticOrigin.id == static_id).first() static_origin: Optional[StaticOrigin] = StaticOrigin.query.filter(
StaticOrigin.id == static_id
).first()
if static_origin is None: if static_origin is None:
return Response(render_template("error.html.j2", return Response(
section="static", render_template(
header="404 Origin Not Found", "error.html.j2",
message="The requested static origin could not be found."), section="static",
status=404) header="404 Origin Not Found",
form = StaticOriginForm(description=static_origin.description, message="The requested static origin could not be found.",
group=static_origin.group_id, ),
storage_cloud_account=static_origin.storage_cloud_account_id, status=404,
source_cloud_account=static_origin.source_cloud_account_id, )
source_project=static_origin.source_project, form = StaticOriginForm(
matrix_homeserver=static_origin.matrix_homeserver, description=static_origin.description,
keanu_convene_path=static_origin.keanu_convene_path, group=static_origin.group_id,
auto_rotate=static_origin.auto_rotate, storage_cloud_account=static_origin.storage_cloud_account_id,
enable_clean_insights=bool(static_origin.clean_insights_backend)) source_cloud_account=static_origin.source_cloud_account_id,
source_project=static_origin.source_project,
matrix_homeserver=static_origin.matrix_homeserver,
keanu_convene_path=static_origin.keanu_convene_path,
auto_rotate=static_origin.auto_rotate,
enable_clean_insights=bool(static_origin.clean_insights_backend),
)
form.group.render_kw = {"disabled": ""} form.group.render_kw = {"disabled": ""}
form.storage_cloud_account.render_kw = {"disabled": ""} form.storage_cloud_account.render_kw = {"disabled": ""}
form.source_cloud_account.render_kw = {"disabled": ""} form.source_cloud_account.render_kw = {"disabled": ""}
@ -167,50 +199,68 @@ def static_edit(static_id: int) -> ResponseReturnValue:
form.keanu_convene_logo.data, form.keanu_convene_logo.data,
form.keanu_convene_color.data, form.keanu_convene_color.data,
form.enable_clean_insights.data, form.enable_clean_insights.data,
True True,
) )
flash("Saved changes to group.", "success") flash("Saved changes to group.", "success")
except ValueError as e: # may be returned by create_static_origin and from the int conversion except (
ValueError
) as e: # may be returned by create_static_origin and from the int conversion
logging.warning(e) logging.warning(e)
flash("An error occurred saving the changes to the static origin due to an invalid input.", "danger") flash(
"An error occurred saving the changes to the static origin due to an invalid input.",
"danger",
)
except exc.SQLAlchemyError as e: except exc.SQLAlchemyError as e:
logging.warning(e) logging.warning(e)
flash("An error occurred saving the changes to the static origin due to a database error.", "danger") flash(
"An error occurred saving the changes to the static origin due to a database error.",
"danger",
)
try: try:
origin = Origin.query.filter_by(domain_name=static_origin.origin_domain_name).one() origin = Origin.query.filter_by(
domain_name=static_origin.origin_domain_name
).one()
proxies = origin.proxies proxies = origin.proxies
except sqlalchemy.exc.NoResultFound: except sqlalchemy.exc.NoResultFound:
proxies = [] proxies = []
return render_template("static.html.j2", return render_template(
section="static", "static.html.j2",
static=static_origin, form=form, section="static",
proxies=proxies) static=static_origin,
form=form,
proxies=proxies,
)
@bp.route("/list") @bp.route("/list")
def static_list() -> ResponseReturnValue: def static_list() -> ResponseReturnValue:
statics: List[StaticOrigin] = StaticOrigin.query.order_by(StaticOrigin.description).all() statics: List[StaticOrigin] = StaticOrigin.query.order_by(
return render_template("list.html.j2", StaticOrigin.description
section="static", ).all()
title="Static Origins", return render_template(
item="static", "list.html.j2",
new_link=url_for("portal.static.static_new"), section="static",
items=statics title="Static Origins",
) item="static",
new_link=url_for("portal.static.static_new"),
items=statics,
)
@bp.route("/destroy/<static_id>", methods=['GET', 'POST']) @bp.route("/destroy/<static_id>", methods=["GET", "POST"])
def static_destroy(static_id: int) -> ResponseReturnValue: def static_destroy(static_id: int) -> ResponseReturnValue:
static = StaticOrigin.query.filter(StaticOrigin.id == static_id, StaticOrigin.destroyed.is_(None)).first() static = StaticOrigin.query.filter(
StaticOrigin.id == static_id, StaticOrigin.destroyed.is_(None)
).first()
if static is None: if static is None:
return response_404("The requested static origin could not be found.") return response_404("The requested static origin could not be found.")
return view_lifecycle( return view_lifecycle(
header=f"Destroy static origin {static.description}", header=f"Destroy static origin {static.description}",
message=static.description, message=static.description,
success_message="All proxies from the destroyed static origin will shortly be destroyed at their providers, " success_message="All proxies from the destroyed static origin will shortly be destroyed at their providers, "
"and the static content will be removed from the cloud provider.", "and the static content will be removed from the cloud provider.",
success_view="portal.static.static_list", success_view="portal.static.static_list",
section="static", section="static",
resource=static, resource=static,
action="destroy" action="destroy",
) )

View file

@ -17,24 +17,30 @@ bp = Blueprint("storage", __name__)
_SECTION_TEMPLATE_VARS = { _SECTION_TEMPLATE_VARS = {
"section": "automation", "section": "automation",
"help_url": "https://bypass.censorship.guide/user/automation.html" "help_url": "https://bypass.censorship.guide/user/automation.html",
} }
class EditStorageForm(FlaskForm): # type: ignore class EditStorageForm(FlaskForm): # type: ignore
force_unlock = BooleanField('Force Unlock') force_unlock = BooleanField("Force Unlock")
submit = SubmitField('Save Changes') submit = SubmitField("Save Changes")
@bp.route('/edit/<storage_key>', methods=['GET', 'POST']) @bp.route("/edit/<storage_key>", methods=["GET", "POST"])
def storage_edit(storage_key: str) -> ResponseReturnValue: def storage_edit(storage_key: str) -> ResponseReturnValue:
storage: Optional[TerraformState] = TerraformState.query.filter(TerraformState.key == storage_key).first() storage: Optional[TerraformState] = TerraformState.query.filter(
TerraformState.key == storage_key
).first()
if storage is None: if storage is None:
return Response(render_template("error.html.j2", return Response(
header="404 Storage Key Not Found", render_template(
message="The requested storage could not be found.", "error.html.j2",
**_SECTION_TEMPLATE_VARS), header="404 Storage Key Not Found",
status=404) message="The requested storage could not be found.",
**_SECTION_TEMPLATE_VARS
),
status=404,
)
form = EditStorageForm() form = EditStorageForm()
if form.validate_on_submit(): if form.validate_on_submit():
if form.force_unlock.data: if form.force_unlock.data:
@ -45,17 +51,16 @@ def storage_edit(storage_key: str) -> ResponseReturnValue:
flash("Storage has been force unlocked.", "success") flash("Storage has been force unlocked.", "success")
except exc.SQLAlchemyError: except exc.SQLAlchemyError:
flash("An error occurred unlocking the storage.", "danger") flash("An error occurred unlocking the storage.", "danger")
return render_template("storage.html.j2", return render_template(
storage=storage, "storage.html.j2", storage=storage, form=form, **_SECTION_TEMPLATE_VARS
form=form, )
**_SECTION_TEMPLATE_VARS)
@bp.route("/kick/<automation_id>", methods=['GET', 'POST']) @bp.route("/kick/<automation_id>", methods=["GET", "POST"])
def automation_kick(automation_id: int) -> ResponseReturnValue: def automation_kick(automation_id: int) -> ResponseReturnValue:
automation = Automation.query.filter( automation = Automation.query.filter(
Automation.id == automation_id, Automation.id == automation_id, Automation.destroyed.is_(None)
Automation.destroyed.is_(None)).first() ).first()
if automation is None: if automation is None:
return response_404("The requested bridge configuration could not be found.") return response_404("The requested bridge configuration could not be found.")
return view_lifecycle( return view_lifecycle(
@ -65,5 +70,5 @@ def automation_kick(automation_id: int) -> ResponseReturnValue:
success_view="portal.automation.automation_list", success_view="portal.automation.automation_list",
success_message="This automation job will next run within 1 minute.", success_message="This automation job will next run within 1 minute.",
resource=automation, resource=automation,
action="kick" action="kick",
) )

View file

@ -9,19 +9,21 @@ from app.models.activity import Activity
def response_404(message: str) -> ResponseReturnValue: def response_404(message: str) -> ResponseReturnValue:
return Response(render_template("error.html.j2", return Response(
header="404 Not Found", render_template("error.html.j2", header="404 Not Found", message=message)
message=message)) )
def view_lifecycle(*, def view_lifecycle(
header: str, *,
message: str, header: str,
success_message: str, message: str,
success_view: str, success_message: str,
section: str, success_view: str,
resource: AbstractResource, section: str,
action: str) -> ResponseReturnValue: resource: AbstractResource,
action: str,
) -> ResponseReturnValue:
form = LifecycleForm() form = LifecycleForm()
if action == "destroy": if action == "destroy":
form.submit.render_kw = {"class": "btn btn-danger"} form.submit.render_kw = {"class": "btn btn-danger"}
@ -41,19 +43,17 @@ def view_lifecycle(*,
return redirect(url_for("portal.portal_home")) return redirect(url_for("portal.portal_home"))
activity = Activity( activity = Activity(
activity_type="lifecycle", activity_type="lifecycle",
text=f"Portal action: {message}. {success_message}" text=f"Portal action: {message}. {success_message}",
) )
db.session.add(activity) db.session.add(activity)
db.session.commit() db.session.commit()
activity.notify() activity.notify()
flash(success_message, "success") flash(success_message, "success")
return redirect(url_for(success_view)) return redirect(url_for(success_view))
return render_template("lifecycle.html.j2", return render_template(
header=header, "lifecycle.html.j2", header=header, message=message, section=section, form=form
message=message, )
section=section,
form=form)
class LifecycleForm(FlaskForm): # type: ignore class LifecycleForm(FlaskForm): # type: ignore
submit = SubmitField('Confirm') submit = SubmitField("Confirm")

View file

@ -1,8 +1,7 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Optional
from flask import (Blueprint, Response, flash, redirect, render_template, from flask import Blueprint, Response, flash, redirect, render_template, url_for
url_for)
from flask.typing import ResponseReturnValue from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from sqlalchemy import exc from sqlalchemy import exc
@ -26,47 +25,54 @@ def webhook_format_name(key: str) -> str:
class NewWebhookForm(FlaskForm): # type: ignore class NewWebhookForm(FlaskForm): # type: ignore
description = StringField('Description', validators=[DataRequired()]) description = StringField("Description", validators=[DataRequired()])
format = SelectField('Format', choices=[ format = SelectField(
("telegram", "Telegram"), "Format",
("matrix", "Matrix") choices=[("telegram", "Telegram"), ("matrix", "Matrix")],
], validators=[DataRequired()]) validators=[DataRequired()],
url = StringField('URL', validators=[DataRequired()]) )
submit = SubmitField('Save Changes') url = StringField("URL", validators=[DataRequired()])
submit = SubmitField("Save Changes")
@bp.route("/new", methods=['GET', 'POST']) @bp.route("/new", methods=["GET", "POST"])
def webhook_new() -> ResponseReturnValue: def webhook_new() -> ResponseReturnValue:
form = NewWebhookForm() form = NewWebhookForm()
if form.validate_on_submit(): if form.validate_on_submit():
webhook = Webhook( webhook = Webhook(
description=form.description.data, description=form.description.data,
format=form.format.data, format=form.format.data,
url=form.url.data url=form.url.data,
) )
try: try:
db.session.add(webhook) db.session.add(webhook)
db.session.commit() db.session.commit()
flash(f"Created new webhook {webhook.url}.", "success") flash(f"Created new webhook {webhook.url}.", "success")
return redirect(url_for("portal.webhook.webhook_edit", webhook_id=webhook.id)) return redirect(
url_for("portal.webhook.webhook_edit", webhook_id=webhook.id)
)
except exc.SQLAlchemyError: except exc.SQLAlchemyError:
flash("Failed to create new webhook.", "danger") flash("Failed to create new webhook.", "danger")
return redirect(url_for("portal.webhook.webhook_list")) return redirect(url_for("portal.webhook.webhook_list"))
return render_template("new.html.j2", section="webhook", form=form) return render_template("new.html.j2", section="webhook", form=form)
@bp.route('/edit/<webhook_id>', methods=['GET', 'POST']) @bp.route("/edit/<webhook_id>", methods=["GET", "POST"])
def webhook_edit(webhook_id: int) -> ResponseReturnValue: def webhook_edit(webhook_id: int) -> ResponseReturnValue:
webhook = Webhook.query.filter(Webhook.id == webhook_id).first() webhook = Webhook.query.filter(Webhook.id == webhook_id).first()
if webhook is None: if webhook is None:
return Response(render_template("error.html.j2", return Response(
section="webhook", render_template(
header="404 Webhook Not Found", "error.html.j2",
message="The requested webhook could not be found."), section="webhook",
status=404) header="404 Webhook Not Found",
form = NewWebhookForm(description=webhook.description, message="The requested webhook could not be found.",
format=webhook.format, ),
url=webhook.url) status=404,
)
form = NewWebhookForm(
description=webhook.description, format=webhook.format, url=webhook.url
)
if form.validate_on_submit(): if form.validate_on_submit():
webhook.description = form.description.data webhook.description = form.description.data
webhook.format = form.description.data webhook.format = form.description.data
@ -77,26 +83,29 @@ def webhook_edit(webhook_id: int) -> ResponseReturnValue:
flash("Saved changes to webhook.", "success") flash("Saved changes to webhook.", "success")
except exc.SQLAlchemyError: except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the webhook.", "danger") flash("An error occurred saving the changes to the webhook.", "danger")
return render_template("edit.html.j2", return render_template(
section="webhook", "edit.html.j2", section="webhook", title="Edit Webhook", item=webhook, form=form
title="Edit Webhook", )
item=webhook, form=form)
@bp.route("/list") @bp.route("/list")
def webhook_list() -> ResponseReturnValue: def webhook_list() -> ResponseReturnValue:
webhooks = Webhook.query.all() webhooks = Webhook.query.all()
return render_template("list.html.j2", return render_template(
section="webhook", "list.html.j2",
title="Webhooks", section="webhook",
item="webhook", title="Webhooks",
new_link=url_for("portal.webhook.webhook_new"), item="webhook",
items=webhooks) new_link=url_for("portal.webhook.webhook_new"),
items=webhooks,
)
@bp.route("/destroy/<webhook_id>", methods=['GET', 'POST']) @bp.route("/destroy/<webhook_id>", methods=["GET", "POST"])
def webhook_destroy(webhook_id: int) -> ResponseReturnValue: def webhook_destroy(webhook_id: int) -> ResponseReturnValue:
webhook: Optional[Webhook] = Webhook.query.filter(Webhook.id == webhook_id, Webhook.destroyed.is_(None)).first() webhook: Optional[Webhook] = Webhook.query.filter(
Webhook.id == webhook_id, Webhook.destroyed.is_(None)
).first()
if webhook is None: if webhook is None:
return response_404("The requested webhook could not be found.") return response_404("The requested webhook could not be found.")
return view_lifecycle( return view_lifecycle(
@ -106,5 +115,5 @@ def webhook_destroy(webhook_id: int) -> ResponseReturnValue:
success_view="portal.webhook.webhook_list", success_view="portal.webhook.webhook_list",
section="webhook", section="webhook",
resource=webhook, resource=webhook,
action="destroy" action="destroy",
) )

View file

@ -12,6 +12,7 @@ class DeterministicZip:
Heavily inspired by https://github.com/bboe/deterministic_zip. Heavily inspired by https://github.com/bboe/deterministic_zip.
""" """
zipfile: ZipFile zipfile: ZipFile
def __init__(self, filename: str): def __init__(self, filename: str):
@ -67,15 +68,22 @@ class BaseAutomation:
if not self.working_dir: if not self.working_dir:
raise RuntimeError("No working directory specified.") raise RuntimeError("No working directory specified.")
tmpl = jinja2.Template(template) tmpl = jinja2.Template(template)
with open(os.path.join(self.working_dir, filename), 'w', encoding="utf-8") as tfconf: with open(
os.path.join(self.working_dir, filename), "w", encoding="utf-8"
) as tfconf:
tfconf.write(tmpl.render(**kwargs)) tfconf.write(tmpl.render(**kwargs))
def bin_write(self, filename: str, data: bytes, group_id: Optional[int] = None) -> None: def bin_write(
self, filename: str, data: bytes, group_id: Optional[int] = None
) -> None:
if not self.working_dir: if not self.working_dir:
raise RuntimeError("No working directory specified.") raise RuntimeError("No working directory specified.")
try: try:
os.mkdir(os.path.join(self.working_dir, str(group_id))) os.mkdir(os.path.join(self.working_dir, str(group_id)))
except FileExistsError: except FileExistsError:
pass pass
with open(os.path.join(self.working_dir, str(group_id) if group_id else "", filename), 'wb') as binfile: with open(
os.path.join(self.working_dir, str(group_id) if group_id else "", filename),
"wb",
) as binfile:
binfile.write(data) binfile.write(data)

View file

@ -13,33 +13,38 @@ from app.terraform import BaseAutomation
def alarms_in_region(region: str, prefix: str, aspect: str) -> None: def alarms_in_region(region: str, prefix: str, aspect: str) -> None:
cloudwatch = boto3.client('cloudwatch', cloudwatch = boto3.client(
aws_access_key_id=app.config['AWS_ACCESS_KEY'], "cloudwatch",
aws_secret_access_key=app.config['AWS_SECRET_KEY'], aws_access_key_id=app.config["AWS_ACCESS_KEY"],
region_name=region) aws_secret_access_key=app.config["AWS_SECRET_KEY"],
dist_paginator = cloudwatch.get_paginator('describe_alarms') region_name=region,
)
dist_paginator = cloudwatch.get_paginator("describe_alarms")
page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix) page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix)
for page in page_iterator: for page in page_iterator:
for cw_alarm in page['MetricAlarms']: for cw_alarm in page["MetricAlarms"]:
eotk_id = cw_alarm["AlarmName"][len(prefix):].split("-") eotk_id = cw_alarm["AlarmName"][len(prefix) :].split("-")
group: Optional[Group] = Group.query.filter(func.lower(Group.group_name) == eotk_id[1]).first() group: Optional[Group] = Group.query.filter(
func.lower(Group.group_name) == eotk_id[1]
).first()
if group is None: if group is None:
print("Unable to find group for " + cw_alarm['AlarmName']) print("Unable to find group for " + cw_alarm["AlarmName"])
continue continue
eotk = Eotk.query.filter( eotk = Eotk.query.filter(
Eotk.group_id == group.id, Eotk.group_id == group.id, Eotk.region == region
Eotk.region == region
).first() ).first()
if eotk is None: if eotk is None:
print("Skipping unknown instance " + cw_alarm['AlarmName']) print("Skipping unknown instance " + cw_alarm["AlarmName"])
continue continue
alarm = get_or_create_alarm(eotk.brn, aspect) alarm = get_or_create_alarm(eotk.brn, aspect)
if cw_alarm['StateValue'] == "OK": if cw_alarm["StateValue"] == "OK":
alarm.update_state(AlarmState.OK, "CloudWatch alarm OK") alarm.update_state(AlarmState.OK, "CloudWatch alarm OK")
elif cw_alarm['StateValue'] == "ALARM": elif cw_alarm["StateValue"] == "ALARM":
alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM") alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM")
else: else:
alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}") alarm.update_state(
AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}"
)
class AlarmEotkAwsAutomation(BaseAutomation): class AlarmEotkAwsAutomation(BaseAutomation):

View file

@ -16,20 +16,19 @@ class AlarmProxyAzureCdnAutomation(BaseAutomation):
def automate(self, full: bool = False) -> Tuple[bool, str]: def automate(self, full: bool = False) -> Tuple[bool, str]:
credential = ClientSecretCredential( credential = ClientSecretCredential(
tenant_id=app.config['AZURE_TENANT_ID'], tenant_id=app.config["AZURE_TENANT_ID"],
client_id=app.config['AZURE_CLIENT_ID'], client_id=app.config["AZURE_CLIENT_ID"],
client_secret=app.config['AZURE_CLIENT_SECRET']) client_secret=app.config["AZURE_CLIENT_SECRET"],
client = AlertsManagementClient(
credential,
app.config['AZURE_SUBSCRIPTION_ID']
) )
firing = [x.name[len("bandwidth-out-high-bc-"):] client = AlertsManagementClient(credential, app.config["AZURE_SUBSCRIPTION_ID"])
for x in client.alerts.get_all() firing = [
if x.name.startswith("bandwidth-out-high-bc-") x.name[len("bandwidth-out-high-bc-") :]
and x.properties.essentials.monitor_condition == "Fired"] for x in client.alerts.get_all()
if x.name.startswith("bandwidth-out-high-bc-")
and x.properties.essentials.monitor_condition == "Fired"
]
for proxy in Proxy.query.filter( for proxy in Proxy.query.filter(
Proxy.provider == "azure_cdn", Proxy.provider == "azure_cdn", Proxy.destroyed.is_(None)
Proxy.destroyed.is_(None)
): ):
alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high") alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high")
if proxy.origin.group.group_name.lower() not in firing: if proxy.origin.group.group_name.lower() not in firing:

View file

@ -16,9 +16,8 @@ def _cloudfront_quota() -> None:
# It would be nice to learn this from the Service Quotas API, however # It would be nice to learn this from the Service Quotas API, however
# at the time of writing this comment, the current value for this quota # at the time of writing this comment, the current value for this quota
# is not available from the API. It just doesn't return anything. # is not available from the API. It just doesn't return anything.
max_count = int(current_app.config.get('AWS_CLOUDFRONT_MAX_DISTRIBUTIONS', 200)) max_count = int(current_app.config.get("AWS_CLOUDFRONT_MAX_DISTRIBUTIONS", 200))
deployed_count = len(Proxy.query.filter( deployed_count = len(Proxy.query.filter(Proxy.destroyed.is_(None)).all())
Proxy.destroyed.is_(None)).all())
message = f"{deployed_count} distributions deployed of {max_count} quota" message = f"{deployed_count} distributions deployed of {max_count} quota"
alarm = get_or_create_alarm( alarm = get_or_create_alarm(
BRN( BRN(
@ -26,9 +25,9 @@ def _cloudfront_quota() -> None:
product="mirror", product="mirror",
provider="cloudfront", provider="cloudfront",
resource_type="quota", resource_type="quota",
resource_id="distributions" resource_id="distributions",
), ),
"quota-usage" "quota-usage",
) )
if deployed_count > max_count * 0.9: if deployed_count > max_count * 0.9:
alarm.update_state(AlarmState.CRITICAL, message) alarm.update_state(AlarmState.CRITICAL, message)
@ -39,26 +38,30 @@ def _cloudfront_quota() -> None:
def _proxy_alarms() -> None: def _proxy_alarms() -> None:
cloudwatch = boto3.client('cloudwatch', cloudwatch = boto3.client(
aws_access_key_id=app.config['AWS_ACCESS_KEY'], "cloudwatch",
aws_secret_access_key=app.config['AWS_SECRET_KEY'], aws_access_key_id=app.config["AWS_ACCESS_KEY"],
region_name='us-east-2') aws_secret_access_key=app.config["AWS_SECRET_KEY"],
dist_paginator = cloudwatch.get_paginator('describe_alarms') region_name="us-east-2",
)
dist_paginator = cloudwatch.get_paginator("describe_alarms")
page_iterator = dist_paginator.paginate(AlarmNamePrefix="bandwidth-out-high-") page_iterator = dist_paginator.paginate(AlarmNamePrefix="bandwidth-out-high-")
for page in page_iterator: for page in page_iterator:
for cw_alarm in page['MetricAlarms']: for cw_alarm in page["MetricAlarms"]:
dist_id = cw_alarm["AlarmName"][len("bandwidth-out-high-"):] dist_id = cw_alarm["AlarmName"][len("bandwidth-out-high-") :]
proxy = Proxy.query.filter(Proxy.slug == dist_id).first() proxy = Proxy.query.filter(Proxy.slug == dist_id).first()
if proxy is None: if proxy is None:
print("Skipping unknown proxy " + dist_id) print("Skipping unknown proxy " + dist_id)
continue continue
alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high") alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high")
if cw_alarm['StateValue'] == "OK": if cw_alarm["StateValue"] == "OK":
alarm.update_state(AlarmState.OK, "CloudWatch alarm OK") alarm.update_state(AlarmState.OK, "CloudWatch alarm OK")
elif cw_alarm['StateValue'] == "ALARM": elif cw_alarm["StateValue"] == "ALARM":
alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM") alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM")
else: else:
alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}") alarm.update_state(
AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}"
)
class AlarmProxyCloudfrontAutomation(BaseAutomation): class AlarmProxyCloudfrontAutomation(BaseAutomation):

View file

@ -16,39 +16,25 @@ class AlarmProxyHTTPStatusAutomation(BaseAutomation):
frequency = 45 frequency = 45
def automate(self, full: bool = False) -> Tuple[bool, str]: def automate(self, full: bool = False) -> Tuple[bool, str]:
proxies = Proxy.query.filter( proxies = Proxy.query.filter(Proxy.destroyed.is_(None))
Proxy.destroyed.is_(None)
)
for proxy in proxies: for proxy in proxies:
try: try:
if proxy.url is None: if proxy.url is None:
continue continue
r = requests.get(proxy.url, r = requests.get(proxy.url, allow_redirects=False, timeout=5)
allow_redirects=False,
timeout=5)
r.raise_for_status() r.raise_for_status()
alarm = get_or_create_alarm(proxy.brn, "http-status") alarm = get_or_create_alarm(proxy.brn, "http-status")
if r.is_redirect: if r.is_redirect:
alarm.update_state( alarm.update_state(
AlarmState.CRITICAL, AlarmState.CRITICAL, f"{r.status_code} {r.reason}"
f"{r.status_code} {r.reason}"
) )
else: else:
alarm.update_state( alarm.update_state(AlarmState.OK, f"{r.status_code} {r.reason}")
AlarmState.OK,
f"{r.status_code} {r.reason}"
)
except requests.HTTPError: except requests.HTTPError:
alarm = get_or_create_alarm(proxy.brn, "http-status") alarm = get_or_create_alarm(proxy.brn, "http-status")
alarm.update_state( alarm.update_state(AlarmState.CRITICAL, f"{r.status_code} {r.reason}")
AlarmState.CRITICAL,
f"{r.status_code} {r.reason}"
)
except RequestException as e: except RequestException as e:
alarm = get_or_create_alarm(proxy.brn, "http-status") alarm = get_or_create_alarm(proxy.brn, "http-status")
alarm.update_state( alarm.update_state(AlarmState.CRITICAL, repr(e))
AlarmState.CRITICAL,
repr(e)
)
db.session.commit() db.session.commit()
return True, "" return True, ""

View file

@ -13,33 +13,38 @@ from app.terraform import BaseAutomation
def alarms_in_region(region: str, prefix: str, aspect: str) -> None: def alarms_in_region(region: str, prefix: str, aspect: str) -> None:
cloudwatch = boto3.client('cloudwatch', cloudwatch = boto3.client(
aws_access_key_id=app.config['AWS_ACCESS_KEY'], "cloudwatch",
aws_secret_access_key=app.config['AWS_SECRET_KEY'], aws_access_key_id=app.config["AWS_ACCESS_KEY"],
region_name=region) aws_secret_access_key=app.config["AWS_SECRET_KEY"],
dist_paginator = cloudwatch.get_paginator('describe_alarms') region_name=region,
)
dist_paginator = cloudwatch.get_paginator("describe_alarms")
page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix) page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix)
for page in page_iterator: for page in page_iterator:
for cw_alarm in page['MetricAlarms']: for cw_alarm in page["MetricAlarms"]:
smart_id = cw_alarm["AlarmName"][len(prefix):].split("-") smart_id = cw_alarm["AlarmName"][len(prefix) :].split("-")
group: Optional[Group] = Group.query.filter(func.lower(Group.group_name) == smart_id[1]).first() group: Optional[Group] = Group.query.filter(
func.lower(Group.group_name) == smart_id[1]
).first()
if group is None: if group is None:
print("Unable to find group for " + cw_alarm['AlarmName']) print("Unable to find group for " + cw_alarm["AlarmName"])
continue continue
smart_proxy = SmartProxy.query.filter( smart_proxy = SmartProxy.query.filter(
SmartProxy.group_id == group.id, SmartProxy.group_id == group.id, SmartProxy.region == region
SmartProxy.region == region
).first() ).first()
if smart_proxy is None: if smart_proxy is None:
print("Skipping unknown instance " + cw_alarm['AlarmName']) print("Skipping unknown instance " + cw_alarm["AlarmName"])
continue continue
alarm = get_or_create_alarm(smart_proxy.brn, aspect) alarm = get_or_create_alarm(smart_proxy.brn, aspect)
if cw_alarm['StateValue'] == "OK": if cw_alarm["StateValue"] == "OK":
alarm.update_state(AlarmState.OK, "CloudWatch alarm OK") alarm.update_state(AlarmState.OK, "CloudWatch alarm OK")
elif cw_alarm['StateValue'] == "ALARM": elif cw_alarm["StateValue"] == "ALARM":
alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM") alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM")
else: else:
alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}") alarm.update_state(
AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}"
)
class AlarmSmartAwsAutomation(BaseAutomation): class AlarmSmartAwsAutomation(BaseAutomation):

View file

@ -16,7 +16,7 @@ def clean_json_response(raw_response: str) -> Dict[str, Any]:
""" """
end_index = raw_response.rfind("}") end_index = raw_response.rfind("}")
if end_index != -1: if end_index != -1:
raw_response = raw_response[:end_index + 1] raw_response = raw_response[: end_index + 1]
response: Dict[str, Any] = json.loads(raw_response) response: Dict[str, Any] = json.loads(raw_response)
return response return response
@ -27,20 +27,21 @@ def request_test_now(test_url: str) -> str:
"User-Agent": "bypasscensorship.org", "User-Agent": "bypasscensorship.org",
"Content-Type": "application/json;charset=utf-8", "Content-Type": "application/json;charset=utf-8",
"Pragma": "no-cache", "Pragma": "no-cache",
"Cache-Control": "no-cache" "Cache-Control": "no-cache",
} }
request_count = 0 request_count = 0
while request_count < 180: while request_count < 180:
params = { params = {"url": test_url, "timestamp": str(int(time.time()))} # unix timestamp
"url": test_url, response = requests.post(
"timestamp": str(int(time.time())) # unix timestamp api_url, params=params, headers=headers, json={}, timeout=30
} )
response = requests.post(api_url, params=params, headers=headers, json={}, timeout=30)
response_data = clean_json_response(response.text) response_data = clean_json_response(response.text)
print(f"Response: {response_data}") print(f"Response: {response_data}")
if "url_test_id" in response_data.get("d", {}): if "url_test_id" in response_data.get("d", {}):
url_test_id: str = response_data["d"]["url_test_id"] url_test_id: str = response_data["d"]["url_test_id"]
logging.debug("Test result for %s has test result ID %s", test_url, url_test_id) logging.debug(
"Test result for %s has test result ID %s", test_url, url_test_id
)
return url_test_id return url_test_id
request_count += 1 request_count += 1
time.sleep(2) time.sleep(2)
@ -52,13 +53,19 @@ def request_test_result(url_test_id: str) -> int:
headers = { headers = {
"User-Agent": "bypasscensorship.org", "User-Agent": "bypasscensorship.org",
"Pragma": "no-cache", "Pragma": "no-cache",
"Cache-Control": "no-cache" "Cache-Control": "no-cache",
} }
response = requests.get(url, headers=headers, timeout=30) response = requests.get(url, headers=headers, timeout=30)
response_data = response.json() response_data = response.json()
tests = response_data.get("d", []) tests = response_data.get("d", [])
non_zero_curl_exit_count: int = sum(1 for test in tests if test.get("curl_exit_value") != "0") non_zero_curl_exit_count: int = sum(
logging.debug("Test result for %s has %s non-zero exit values", url_test_id, non_zero_curl_exit_count) 1 for test in tests if test.get("curl_exit_value") != "0"
)
logging.debug(
"Test result for %s has %s non-zero exit values",
url_test_id,
non_zero_curl_exit_count,
)
return non_zero_curl_exit_count return non_zero_curl_exit_count
@ -81,7 +88,7 @@ class BlockBlockyAutomation(BlockMirrorAutomation):
Proxy.url.is_not(None), Proxy.url.is_not(None),
Proxy.deprecated.is_(None), Proxy.deprecated.is_(None),
Proxy.destroyed.is_(None), Proxy.destroyed.is_(None),
Proxy.pool_id != -1 Proxy.pool_id != -1,
) )
.all() .all()
) )

View file

@ -15,7 +15,8 @@ class BlockBridgeScriptzteamAutomation(BlockBridgelinesAutomation):
def fetch(self) -> None: def fetch(self) -> None:
r = requests.get( r = requests.get(
"https://raw.githubusercontent.com/scriptzteam/Tor-Bridges-Collector/main/bridges-obfs4", "https://raw.githubusercontent.com/scriptzteam/Tor-Bridges-Collector/main/bridges-obfs4",
timeout=60) timeout=60,
)
r.encoding = "utf-8" r.encoding = "utf-8"
contents = r.text contents = r.text
self._lines = contents.splitlines() self._lines = contents.splitlines()

View file

@ -24,8 +24,9 @@ class BlockBridgeAutomation(BaseAutomation):
self.hashed_fingerprints = [] self.hashed_fingerprints = []
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def perform_deprecations(self, ids: List[str], bridge_select_func: Callable[[str], Optional[Bridge]] def perform_deprecations(
) -> List[Tuple[Optional[str], Any, Any]]: self, ids: List[str], bridge_select_func: Callable[[str], Optional[Bridge]]
) -> List[Tuple[Optional[str], Any, Any]]:
rotated = [] rotated = []
for id_ in ids: for id_ in ids:
bridge = bridge_select_func(id_) bridge = bridge_select_func(id_)
@ -37,7 +38,13 @@ class BlockBridgeAutomation(BaseAutomation):
continue continue
if bridge.deprecate(reason=self.short_name): if bridge.deprecate(reason=self.short_name):
logging.info("Rotated %s", bridge.hashed_fingerprint) logging.info("Rotated %s", bridge.hashed_fingerprint)
rotated.append((bridge.fingerprint, bridge.cloud_account.provider, bridge.cloud_account.description)) rotated.append(
(
bridge.fingerprint,
bridge.cloud_account.provider,
bridge.cloud_account.description,
)
)
else: else:
logging.debug("Not rotating a bridge that is already deprecated") logging.debug("Not rotating a bridge that is already deprecated")
return rotated return rotated
@ -50,15 +57,28 @@ class BlockBridgeAutomation(BaseAutomation):
rotated = [] rotated = []
rotated.extend(self.perform_deprecations(self.ips, get_bridge_by_ip)) rotated.extend(self.perform_deprecations(self.ips, get_bridge_by_ip))
logging.debug("Blocked by IP") logging.debug("Blocked by IP")
rotated.extend(self.perform_deprecations(self.fingerprints, get_bridge_by_fingerprint)) rotated.extend(
self.perform_deprecations(self.fingerprints, get_bridge_by_fingerprint)
)
logging.debug("Blocked by fingerprint") logging.debug("Blocked by fingerprint")
rotated.extend(self.perform_deprecations(self.hashed_fingerprints, get_bridge_by_hashed_fingerprint)) rotated.extend(
self.perform_deprecations(
self.hashed_fingerprints, get_bridge_by_hashed_fingerprint
)
)
logging.debug("Blocked by hashed fingerprint") logging.debug("Blocked by hashed fingerprint")
if rotated: if rotated:
activity = Activity( activity = Activity(
activity_type="block", activity_type="block",
text=(f"[{self.short_name}] ♻ Rotated {len(rotated)} bridges: \n" text=(
+ "\n".join([f"* {fingerprint} ({provider}: {provider_description})" for fingerprint, provider, provider_description in rotated])) f"[{self.short_name}] ♻ Rotated {len(rotated)} bridges: \n"
+ "\n".join(
[
f"* {fingerprint} ({provider}: {provider_description})"
for fingerprint, provider, provider_description in rotated
]
)
),
) )
db.session.add(activity) db.session.add(activity)
activity.notify() activity.notify()
@ -87,7 +107,7 @@ def get_bridge_by_ip(ip: str) -> Optional[Bridge]:
return Bridge.query.filter( # type: ignore[no-any-return] return Bridge.query.filter( # type: ignore[no-any-return]
Bridge.deprecated.is_(None), Bridge.deprecated.is_(None),
Bridge.destroyed.is_(None), Bridge.destroyed.is_(None),
Bridge.bridgeline.contains(f" {ip} ") Bridge.bridgeline.contains(f" {ip} "),
).first() ).first()
@ -95,7 +115,7 @@ def get_bridge_by_fingerprint(fingerprint: str) -> Optional[Bridge]:
return Bridge.query.filter( # type: ignore[no-any-return] return Bridge.query.filter( # type: ignore[no-any-return]
Bridge.deprecated.is_(None), Bridge.deprecated.is_(None),
Bridge.destroyed.is_(None), Bridge.destroyed.is_(None),
Bridge.fingerprint == fingerprint Bridge.fingerprint == fingerprint,
).first() ).first()
@ -103,5 +123,5 @@ def get_bridge_by_hashed_fingerprint(hashed_fingerprint: str) -> Optional[Bridge
return Bridge.query.filter( # type: ignore[no-any-return] return Bridge.query.filter( # type: ignore[no-any-return]
Bridge.deprecated.is_(None), Bridge.deprecated.is_(None),
Bridge.destroyed.is_(None), Bridge.destroyed.is_(None),
Bridge.hashed_fingerprint == hashed_fingerprint Bridge.hashed_fingerprint == hashed_fingerprint,
).first() ).first()

View file

@ -17,6 +17,8 @@ class BlockBridgelinesAutomation(BlockBridgeAutomation, ABC):
fingerprint = parts[2] fingerprint = parts[2]
self.ips.append(ip_address) self.ips.append(ip_address)
self.fingerprints.append(fingerprint) self.fingerprints.append(fingerprint)
logging.debug(f"Added blocked bridge with IP {ip_address} and fingerprint {fingerprint}") logging.debug(
f"Added blocked bridge with IP {ip_address} and fingerprint {fingerprint}"
)
except IndexError: except IndexError:
logging.warning("A parsing error occured.") logging.warning("A parsing error occured.")

View file

@ -1,8 +1,7 @@
from flask import current_app from flask import current_app
from github import Github from github import Github
from app.terraform.block.bridge_reachability import \ from app.terraform.block.bridge_reachability import BlockBridgeReachabilityAutomation
BlockBridgeReachabilityAutomation
class BlockBridgeGitHubAutomation(BlockBridgeReachabilityAutomation): class BlockBridgeGitHubAutomation(BlockBridgeReachabilityAutomation):
@ -15,12 +14,13 @@ class BlockBridgeGitHubAutomation(BlockBridgeReachabilityAutomation):
frequency = 30 frequency = 30
def fetch(self) -> None: def fetch(self) -> None:
github = Github(current_app.config['GITHUB_API_KEY']) github = Github(current_app.config["GITHUB_API_KEY"])
repo = github.get_repo(current_app.config['GITHUB_BRIDGE_REPO']) repo = github.get_repo(current_app.config["GITHUB_BRIDGE_REPO"])
for vantage_point in current_app.config['GITHUB_BRIDGE_VANTAGE_POINTS']: for vantage_point in current_app.config["GITHUB_BRIDGE_VANTAGE_POINTS"]:
contents = repo.get_contents(f"recentResult_{vantage_point}") contents = repo.get_contents(f"recentResult_{vantage_point}")
if isinstance(contents, list): if isinstance(contents, list):
raise RuntimeError( raise RuntimeError(
f"Expected a file at recentResult_{vantage_point}" f"Expected a file at recentResult_{vantage_point}"
" but got a directory.") " but got a directory."
self._lines = contents.decoded_content.decode('utf-8').splitlines() )
self._lines = contents.decoded_content.decode("utf-8").splitlines()

View file

@ -1,8 +1,7 @@
from flask import current_app from flask import current_app
from gitlab import Gitlab from gitlab import Gitlab
from app.terraform.block.bridge_reachability import \ from app.terraform.block.bridge_reachability import BlockBridgeReachabilityAutomation
BlockBridgeReachabilityAutomation
class BlockBridgeGitlabAutomation(BlockBridgeReachabilityAutomation): class BlockBridgeGitlabAutomation(BlockBridgeReachabilityAutomation):
@ -16,15 +15,15 @@ class BlockBridgeGitlabAutomation(BlockBridgeReachabilityAutomation):
def fetch(self) -> None: def fetch(self) -> None:
self._lines = list() self._lines = list()
credentials = {"private_token": current_app.config['GITLAB_TOKEN']} credentials = {"private_token": current_app.config["GITLAB_TOKEN"]}
if "GITLAB_URL" in current_app.config: if "GITLAB_URL" in current_app.config:
credentials['url'] = current_app.config['GITLAB_URL'] credentials["url"] = current_app.config["GITLAB_URL"]
gitlab = Gitlab(**credentials) gitlab = Gitlab(**credentials)
project = gitlab.projects.get(current_app.config['GITLAB_BRIDGE_PROJECT']) project = gitlab.projects.get(current_app.config["GITLAB_BRIDGE_PROJECT"])
for vantage_point in current_app.config['GITHUB_BRIDGE_VANTAGE_POINTS']: for vantage_point in current_app.config["GITHUB_BRIDGE_VANTAGE_POINTS"]:
contents = project.files.get( contents = project.files.get(
file_path=f"recentResult_{vantage_point}", file_path=f"recentResult_{vantage_point}",
ref=current_app.config["GITLAB_BRIDGE_BRANCH"] ref=current_app.config["GITLAB_BRIDGE_BRANCH"],
) )
# Decode the base64 first, then decode the UTF-8 string # Decode the base64 first, then decode the UTF-8 string
self._lines.extend(contents.decode().decode('utf-8').splitlines()) self._lines.extend(contents.decode().decode("utf-8").splitlines())

View file

@ -14,8 +14,10 @@ class BlockBridgeReachabilityAutomation(BlockBridgeAutomation, ABC):
def parse(self) -> None: def parse(self) -> None:
for line in self._lines: for line in self._lines:
parts = line.split("\t") parts = line.split("\t")
if isoparse(parts[2]) < (datetime.datetime.now(datetime.timezone.utc) if isoparse(parts[2]) < (
- datetime.timedelta(days=3)): datetime.datetime.now(datetime.timezone.utc)
- datetime.timedelta(days=3)
):
# Skip results older than 3 days # Skip results older than 3 days
continue continue
if int(parts[1]) < 40: if int(parts[1]) < 40:

View file

@ -13,7 +13,9 @@ class BlockBridgeRoskomsvobodaAutomation(BlockBridgeAutomation):
_data: Any _data: Any
def fetch(self) -> None: def fetch(self) -> None:
self._data = requests.get("https://reestr.rublacklist.net/api/v3/ips/", timeout=180).json() self._data = requests.get(
"https://reestr.rublacklist.net/api/v3/ips/", timeout=180
).json()
def parse(self) -> None: def parse(self) -> None:
self.ips.extend(self._data) self.ips.extend(self._data)

View file

@ -9,7 +9,7 @@ from app.terraform.block_mirror import BlockMirrorAutomation
def _trim_prefix(s: str, prefix: str) -> str: def _trim_prefix(s: str, prefix: str) -> str:
if s.startswith(prefix): if s.startswith(prefix):
return s[len(prefix):] return s[len(prefix) :]
return s return s
@ -20,30 +20,31 @@ def trim_http_https(s: str) -> str:
:param s: String to modify. :param s: String to modify.
:return: Modified string. :return: Modified string.
""" """
return _trim_prefix( return _trim_prefix(_trim_prefix(s, "https://"), "http://")
_trim_prefix(s, "https://"),
"http://")
class BlockExternalAutomation(BlockMirrorAutomation): class BlockExternalAutomation(BlockMirrorAutomation):
""" """
Automation task to import proxy reachability results from external source. Automation task to import proxy reachability results from external source.
""" """
short_name = "block_external" short_name = "block_external"
description = "Import proxy reachability results from external source" description = "Import proxy reachability results from external source"
_content: bytes _content: bytes
def fetch(self) -> None: def fetch(self) -> None:
user_agent = {'User-agent': 'BypassCensorship/1.0'} user_agent = {"User-agent": "BypassCensorship/1.0"}
check_urls_config = app.config.get('EXTERNAL_CHECK_URL', []) check_urls_config = app.config.get("EXTERNAL_CHECK_URL", [])
if isinstance(check_urls_config, dict): if isinstance(check_urls_config, dict):
# Config is already a dictionary, use as is. # Config is already a dictionary, use as is.
check_urls = check_urls_config check_urls = check_urls_config
elif isinstance(check_urls_config, list): elif isinstance(check_urls_config, list):
# Convert list of strings to a dictionary with "external_N" keys. # Convert list of strings to a dictionary with "external_N" keys.
check_urls = {f"external_{i}": url for i, url in enumerate(check_urls_config)} check_urls = {
f"external_{i}": url for i, url in enumerate(check_urls_config)
}
elif isinstance(check_urls_config, str): elif isinstance(check_urls_config, str):
# Single string, convert to a dictionary with key "external". # Single string, convert to a dictionary with key "external".
check_urls = {"external": check_urls_config} check_urls = {"external": check_urls_config}
@ -53,9 +54,13 @@ class BlockExternalAutomation(BlockMirrorAutomation):
for source, check_url in check_urls.items(): for source, check_url in check_urls.items():
if self._data is None: if self._data is None:
self._data = defaultdict(list) self._data = defaultdict(list)
self._data[source].extend(requests.get(check_url, headers=user_agent, timeout=30).json()) self._data[source].extend(
requests.get(check_url, headers=user_agent, timeout=30).json()
)
def parse(self) -> None: def parse(self) -> None:
for source, patterns in self._data.items(): for source, patterns in self._data.items():
self.patterns[source].extend(["https://" + trim_http_https(pattern) for pattern in patterns]) self.patterns[source].extend(
["https://" + trim_http_https(pattern) for pattern in patterns]
)
logging.debug("Found URLs: %s", self.patterns) logging.debug("Found URLs: %s", self.patterns)

View file

@ -52,8 +52,15 @@ class BlockMirrorAutomation(BaseAutomation):
if rotated: if rotated:
activity = Activity( activity = Activity(
activity_type="block", activity_type="block",
text=(f"[{self.short_name}] ♻ Rotated {len(rotated)} proxies: \n" text=(
+ "\n".join([f"* {proxy_domain} ({origin_domain})" for proxy_domain, origin_domain in rotated])) f"[{self.short_name}] ♻ Rotated {len(rotated)} proxies: \n"
+ "\n".join(
[
f"* {proxy_domain} ({origin_domain})"
for proxy_domain, origin_domain in rotated
]
)
),
) )
db.session.add(activity) db.session.add(activity)
activity.notify() activity.notify()
@ -79,15 +86,15 @@ class BlockMirrorAutomation(BaseAutomation):
def active_proxy_urls() -> List[str]: def active_proxy_urls() -> List[str]:
return [proxy.url for proxy in Proxy.query.filter( return [
Proxy.deprecated.is_(None), proxy.url
Proxy.destroyed.is_(None) for proxy in Proxy.query.filter(
).all()] Proxy.deprecated.is_(None), Proxy.destroyed.is_(None)
).all()
]
def proxy_by_url(url: str) -> Optional[Proxy]: def proxy_by_url(url: str) -> Optional[Proxy]:
return Proxy.query.filter( # type: ignore[no-any-return] return Proxy.query.filter( # type: ignore[no-any-return]
Proxy.deprecated.is_(None), Proxy.deprecated.is_(None), Proxy.destroyed.is_(None), Proxy.url == url
Proxy.destroyed.is_(None),
Proxy.url == url
).first() ).first()

View file

@ -12,19 +12,23 @@ from app.terraform import BaseAutomation
def check_origin(domain_name: str) -> Dict[str, Any]: def check_origin(domain_name: str) -> Dict[str, Any]:
start_date = (datetime.now(tz=timezone.utc) - timedelta(days=1)).strftime("%Y-%m-%dT%H%%3A%M") start_date = (datetime.now(tz=timezone.utc) - timedelta(days=1)).strftime(
"%Y-%m-%dT%H%%3A%M"
)
end_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H%%3A%M") end_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H%%3A%M")
api_url = f"https://api.ooni.io/api/v1/measurements?domain={domain_name}&since={start_date}&until={end_date}" api_url = f"https://api.ooni.io/api/v1/measurements?domain={domain_name}&since={start_date}&until={end_date}"
result: Dict[str, Dict[str, int]] = defaultdict(lambda: {"anomaly": 0, "confirmed": 0, "failure": 0, "ok": 0}) result: Dict[str, Dict[str, int]] = defaultdict(
lambda: {"anomaly": 0, "confirmed": 0, "failure": 0, "ok": 0}
)
return _check_origin(api_url, result) return _check_origin(api_url, result)
def _check_origin(api_url: str, result: Dict[str, Any]) -> Dict[str, Any]: def _check_origin(api_url: str, result: Dict[str, Any]) -> Dict[str, Any]:
print(f"Processing {api_url}") print(f"Processing {api_url}")
req = requests.get(api_url, timeout=30).json() req = requests.get(api_url, timeout=30).json()
if 'results' not in req or not req['results']: if "results" not in req or not req["results"]:
return result return result
for r in req['results']: for r in req["results"]:
not_ok = False not_ok = False
for status in ["anomaly", "confirmed", "failure"]: for status in ["anomaly", "confirmed", "failure"]:
if status in r and r[status]: if status in r and r[status]:
@ -33,27 +37,28 @@ def _check_origin(api_url: str, result: Dict[str, Any]) -> Dict[str, Any]:
break break
if not not_ok: if not not_ok:
result[r["probe_cc"]]["ok"] += 1 result[r["probe_cc"]]["ok"] += 1
if req['metadata']['next_url']: if req["metadata"]["next_url"]:
return _check_origin(req['metadata']['next_url'], result) return _check_origin(req["metadata"]["next_url"], result)
return result return result
def threshold_origin(domain_name: str) -> Dict[str, Any]: def threshold_origin(domain_name: str) -> Dict[str, Any]:
ooni = check_origin(domain_name) ooni = check_origin(domain_name)
for country in ooni: for country in ooni:
total = sum([ total = sum(
ooni[country]["anomaly"], [
ooni[country]["confirmed"], ooni[country]["anomaly"],
ooni[country]["failure"], ooni[country]["confirmed"],
ooni[country]["ok"] ooni[country]["failure"],
]) ooni[country]["ok"],
total_blocks = sum([ ]
ooni[country]["anomaly"], )
ooni[country]["confirmed"] total_blocks = sum([ooni[country]["anomaly"], ooni[country]["confirmed"]])
])
block_perc = round((total_blocks / total * 100), 1) block_perc = round((total_blocks / total * 100), 1)
ooni[country]["block_perc"] = block_perc ooni[country]["block_perc"] = block_perc
ooni[country]["state"] = AlarmState.WARNING if block_perc > 20 else AlarmState.OK ooni[country]["state"] = (
AlarmState.WARNING if block_perc > 20 else AlarmState.OK
)
ooni[country]["message"] = f"Blocked in {block_perc}% of measurements" ooni[country]["message"] = f"Blocked in {block_perc}% of measurements"
return ooni return ooni
@ -72,8 +77,9 @@ class BlockOONIAutomation(BaseAutomation):
for origin in origins: for origin in origins:
ooni = threshold_origin(origin.domain_name) ooni = threshold_origin(origin.domain_name)
for country in ooni: for country in ooni:
alarm = get_or_create_alarm(origin.brn, alarm = get_or_create_alarm(
f"origin-block-ooni-{country.lower()}") origin.brn, f"origin-block-ooni-{country.lower()}"
)
alarm.update_state(ooni[country]["state"], ooni[country]["message"]) alarm.update_state(ooni[country]["state"], ooni[country]["message"])
db.session.commit() db.session.commit()
return True, "" return True, ""

View file

@ -32,6 +32,7 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
Where proxies are found to be blocked they will be rotated. Where proxies are found to be blocked they will be rotated.
""" """
short_name = "block_roskomsvoboda" short_name = "block_roskomsvoboda"
description = "Import Russian blocklist from RosKomSvoboda" description = "Import Russian blocklist from RosKomSvoboda"
frequency = 300 frequency = 300
@ -43,7 +44,11 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
try: try:
# This endpoint routinely has an expired certificate, and it's more useful that we are consuming the # This endpoint routinely has an expired certificate, and it's more useful that we are consuming the
# data than that we are verifying the certificate. # data than that we are verifying the certificate.
r = requests.get(f"https://dumps.rublacklist.net/fetch/{latest_rev}", timeout=180, verify=False) # nosec: B501 r = requests.get(
f"https://dumps.rublacklist.net/fetch/{latest_rev}",
timeout=180,
verify=False,
) # nosec: B501
r.raise_for_status() r.raise_for_status()
zip_file = ZipFile(BytesIO(r.content)) zip_file = ZipFile(BytesIO(r.content))
self._data = zip_file.read("dump.xml") self._data = zip_file.read("dump.xml")
@ -51,26 +56,33 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
except requests.HTTPError: except requests.HTTPError:
activity = Activity( activity = Activity(
activity_type="automation", activity_type="automation",
text=(f"[{self.short_name}] 🚨 Unable to download dump {latest_rev} due to HTTP error {r.status_code}. " text=(
"The automation task has not been disabled and will attempt to download the next dump when the " f"[{self.short_name}] 🚨 Unable to download dump {latest_rev} due to HTTP error {r.status_code}. "
"latest dump revision is incremented at the server.")) "The automation task has not been disabled and will attempt to download the next dump when the "
"latest dump revision is incremented at the server."
),
)
activity.notify() activity.notify()
db.session.add(activity) db.session.add(activity)
db.session.commit() db.session.commit()
except BadZipFile: except BadZipFile:
activity = Activity( activity = Activity(
activity_type="automation", activity_type="automation",
text=(f"[{self.short_name}] 🚨 Unable to extract zip file from dump {latest_rev}. There was an error " text=(
"related to the format of the zip file. " f"[{self.short_name}] 🚨 Unable to extract zip file from dump {latest_rev}. There was an error "
"The automation task has not been disabled and will attempt to download the next dump when the " "related to the format of the zip file. "
"latest dump revision is incremented at the server.")) "The automation task has not been disabled and will attempt to download the next dump when the "
"latest dump revision is incremented at the server."
),
)
activity.notify() activity.notify()
db.session.add(activity) db.session.add(activity)
db.session.commit() db.session.commit()
def fetch(self) -> None: def fetch(self) -> None:
state: Optional[TerraformState] = TerraformState.query.filter( state: Optional[TerraformState] = TerraformState.query.filter(
TerraformState.key == "block_roskomsvoboda").first() TerraformState.key == "block_roskomsvoboda"
).first()
if state is None: if state is None:
state = TerraformState() state = TerraformState()
state.key = "block_roskomsvoboda" state.key = "block_roskomsvoboda"
@ -80,8 +92,14 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
latest_metadata = json.loads(state.state) latest_metadata = json.loads(state.state)
# This endpoint routinely has an expired certificate, and it's more useful that we are consuming the # This endpoint routinely has an expired certificate, and it's more useful that we are consuming the
# data than that we are verifying the certificate. # data than that we are verifying the certificate.
latest_rev = requests.get("https://dumps.rublacklist.net/fetch/latest", timeout=30, verify=False).text.strip() # nosec: B501 latest_rev = requests.get(
logging.debug("Latest revision is %s, already got %s", latest_rev, latest_metadata["dump_rev"]) "https://dumps.rublacklist.net/fetch/latest", timeout=30, verify=False
).text.strip() # nosec: B501
logging.debug(
"Latest revision is %s, already got %s",
latest_rev,
latest_metadata["dump_rev"],
)
if latest_rev != latest_metadata["dump_rev"]: if latest_rev != latest_metadata["dump_rev"]:
state.state = json.dumps({"dump_rev": latest_rev}) state.state = json.dumps({"dump_rev": latest_rev})
db.session.commit() db.session.commit()
@ -94,18 +112,24 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
logging.debug("No new data to parse") logging.debug("No new data to parse")
return return
try: try:
for _event, element in lxml.etree.iterparse(BytesIO(self._data), for _event, element in lxml.etree.iterparse(
resolve_entities=False): BytesIO(self._data), resolve_entities=False
):
if element.tag == "domain": if element.tag == "domain":
self.patterns["roskomsvoboda"].append("https://" + element.text.strip()) self.patterns["roskomsvoboda"].append(
"https://" + element.text.strip()
)
except XMLSyntaxError: except XMLSyntaxError:
activity = Activity( activity = Activity(
activity_type="automation", activity_type="automation",
text=(f"[{self.short_name}] 🚨 Unable to parse XML file from dump. There was an error " text=(
"related to the format of the XML file within the zip file. Interestingly we were able to " f"[{self.short_name}] 🚨 Unable to parse XML file from dump. There was an error "
"extract the file from the zip file fine. " "related to the format of the XML file within the zip file. Interestingly we were able to "
"The automation task has not been disabled and will attempt to download the next dump when the " "extract the file from the zip file fine. "
"latest dump revision is incremented at the server.")) "The automation task has not been disabled and will attempt to download the next dump when the "
"latest dump revision is incremented at the server."
),
)
activity.notify() activity.notify()
db.session.add(activity) db.session.add(activity)
db.session.commit() db.session.commit()

View file

@ -16,20 +16,32 @@ BridgeResourceRow = Row[Tuple[AbstractResource, BridgeConf, CloudAccount]]
def active_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]: def active_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]:
stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where( stmt = (
CloudAccount.provider == provider, select(Bridge, BridgeConf, CloudAccount)
Bridge.destroyed.is_(None), .join_from(Bridge, BridgeConf)
.join_from(Bridge, CloudAccount)
.where(
CloudAccount.provider == provider,
Bridge.destroyed.is_(None),
)
) )
bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all() bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all()
return bridges return bridges
def recently_destroyed_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]: def recently_destroyed_bridges_by_provider(
provider: CloudProvider,
) -> Sequence[BridgeResourceRow]:
cutoff = datetime.now(tz=timezone.utc) - timedelta(hours=72) cutoff = datetime.now(tz=timezone.utc) - timedelta(hours=72)
stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where( stmt = (
CloudAccount.provider == provider, select(Bridge, BridgeConf, CloudAccount)
Bridge.destroyed.is_not(None), .join_from(Bridge, BridgeConf)
Bridge.destroyed >= cutoff, .join_from(Bridge, CloudAccount)
.where(
CloudAccount.provider == provider,
Bridge.destroyed.is_not(None),
Bridge.destroyed >= cutoff,
)
) )
bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all() bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all()
return bridges return bridges
@ -60,35 +72,38 @@ class BridgeAutomation(TerraformAutomation):
self.template, self.template,
active_resources=active_bridges_by_provider(self.provider), active_resources=active_bridges_by_provider(self.provider),
destroyed_resources=recently_destroyed_bridges_by_provider(self.provider), destroyed_resources=recently_destroyed_bridges_by_provider(self.provider),
global_namespace=app.config['GLOBAL_NAMESPACE'], global_namespace=app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{ backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""", }}""",
**{ **{k: app.config[k.upper()] for k in self.template_parameters},
k: app.config[k.upper()]
for k in self.template_parameters
}
) )
def tf_posthook(self, *, prehook_result: Any = None) -> None: def tf_posthook(self, *, prehook_result: Any = None) -> None:
outputs = self.tf_output() outputs = self.tf_output()
for output in outputs: for output in outputs:
if output.startswith('bridge_hashed_fingerprint_'): if output.startswith("bridge_hashed_fingerprint_"):
parts = outputs[output]['value'].split(" ") parts = outputs[output]["value"].split(" ")
if len(parts) < 2: if len(parts) < 2:
continue continue
bridge = Bridge.query.filter(Bridge.id == output[len('bridge_hashed_fingerprint_'):]).first() bridge = Bridge.query.filter(
Bridge.id == output[len("bridge_hashed_fingerprint_") :]
).first()
bridge.nickname = parts[0] bridge.nickname = parts[0]
bridge.hashed_fingerprint = parts[1] bridge.hashed_fingerprint = parts[1]
bridge.terraform_updated = datetime.now(tz=timezone.utc) bridge.terraform_updated = datetime.now(tz=timezone.utc)
if output.startswith('bridge_bridgeline_'): if output.startswith("bridge_bridgeline_"):
parts = outputs[output]['value'].split(" ") parts = outputs[output]["value"].split(" ")
if len(parts) < 4: if len(parts) < 4:
continue continue
bridge = Bridge.query.filter(Bridge.id == output[len('bridge_bridgeline_'):]).first() bridge = Bridge.query.filter(
Bridge.id == output[len("bridge_bridgeline_") :]
).first()
del parts[3] del parts[3]
bridge.bridgeline = " ".join(parts) bridge.bridgeline = " ".join(parts)
bridge.terraform_updated = datetime.now(tz=timezone.utc) bridge.terraform_updated = datetime.now(tz=timezone.utc)

View file

@ -7,10 +7,7 @@ class BridgeGandiAutomation(BridgeAutomation):
description = "Deploy Tor bridges on GandiCloud VPS" description = "Deploy Tor bridges on GandiCloud VPS"
provider = CloudProvider.GANDI provider = CloudProvider.GANDI
template_parameters = [ template_parameters = ["ssh_public_key_path", "ssh_private_key_path"]
"ssh_public_key_path",
"ssh_private_key_path"
]
template = """ template = """
terraform { terraform {

View file

@ -7,10 +7,7 @@ class BridgeHcloudAutomation(BridgeAutomation):
description = "Deploy Tor bridges on Hetzner Cloud" description = "Deploy Tor bridges on Hetzner Cloud"
provider = CloudProvider.HCLOUD provider = CloudProvider.HCLOUD
template_parameters = [ template_parameters = ["ssh_private_key_path", "ssh_public_key_path"]
"ssh_private_key_path",
"ssh_public_key_path"
]
template = """ template = """
terraform { terraform {

View file

@ -25,10 +25,17 @@ def active_bridges_in_account(account: CloudAccount) -> List[Bridge]:
return bridges return bridges
def create_bridges_in_account(bridgeconf: BridgeConf, account: CloudAccount, count: int) -> int: def create_bridges_in_account(
bridgeconf: BridgeConf, account: CloudAccount, count: int
) -> int:
created = 0 created = 0
while created < count and len(active_bridges_in_account(account)) < account.max_instances: while (
logging.debug("Creating bridge for configuration %s in account %s", bridgeconf.id, account) created < count
and len(active_bridges_in_account(account)) < account.max_instances
):
logging.debug(
"Creating bridge for configuration %s in account %s", bridgeconf.id, account
)
bridge = Bridge() bridge = Bridge()
bridge.pool_id = bridgeconf.pool.id bridge.pool_id = bridgeconf.pool.id
bridge.conf_id = bridgeconf.id bridge.conf_id = bridgeconf.id
@ -45,16 +52,18 @@ def create_bridges_by_cost(bridgeconf: BridgeConf, count: int) -> int:
""" """
Creates bridge resources for the given bridge configuration using the cheapest available provider. Creates bridge resources for the given bridge configuration using the cheapest available provider.
""" """
logging.debug("Creating %s bridges by cost for configuration %s", count, bridgeconf.id) logging.debug(
"Creating %s bridges by cost for configuration %s", count, bridgeconf.id
)
created = 0 created = 0
for provider in BRIDGE_PROVIDERS: for provider in BRIDGE_PROVIDERS:
if created >= count: if created >= count:
break break
logging.info("Creating bridges in %s accounts", provider.description) logging.info("Creating bridges in %s accounts", provider.description)
for account in CloudAccount.query.filter( for account in CloudAccount.query.filter(
CloudAccount.destroyed.is_(None), CloudAccount.destroyed.is_(None),
CloudAccount.enabled.is_(True), CloudAccount.enabled.is_(True),
CloudAccount.provider == provider, CloudAccount.provider == provider,
).all(): ).all():
logging.info("Creating bridges in %s", account) logging.info("Creating bridges in %s", account)
created += create_bridges_in_account(bridgeconf, account, count - created) created += create_bridges_in_account(bridgeconf, account, count - created)
@ -78,7 +87,9 @@ def create_bridges_by_random(bridgeconf: BridgeConf, count: int) -> int:
""" """
Creates bridge resources for the given bridge configuration using random providers. Creates bridge resources for the given bridge configuration using random providers.
""" """
logging.debug("Creating %s bridges by random for configuration %s", count, bridgeconf.id) logging.debug(
"Creating %s bridges by random for configuration %s", count, bridgeconf.id
)
created = 0 created = 0
while candidate_accounts := _accounts_with_room(): while candidate_accounts := _accounts_with_room():
# Not security-critical random number generation # Not security-critical random number generation
@ -97,16 +108,24 @@ def create_bridges(bridgeconf: BridgeConf, count: int) -> int:
return create_bridges_by_random(bridgeconf, count) return create_bridges_by_random(bridgeconf, count)
def deprecate_bridges(bridgeconf: BridgeConf, count: int, reason: str = "redundant") -> int: def deprecate_bridges(
logging.debug("Deprecating %s bridges (%s) for configuration %s", count, reason, bridgeconf.id) bridgeconf: BridgeConf, count: int, reason: str = "redundant"
) -> int:
logging.debug(
"Deprecating %s bridges (%s) for configuration %s", count, reason, bridgeconf.id
)
deprecated = 0 deprecated = 0
active_conf_bridges = iter(Bridge.query.filter( active_conf_bridges = iter(
Bridge.conf_id == bridgeconf.id, Bridge.query.filter(
Bridge.deprecated.is_(None), Bridge.conf_id == bridgeconf.id,
Bridge.destroyed.is_(None), Bridge.deprecated.is_(None),
).all()) Bridge.destroyed.is_(None),
).all()
)
while deprecated < count: while deprecated < count:
logging.debug("Deprecating bridge %s for configuration %s", deprecated + 1, bridgeconf.id) logging.debug(
"Deprecating bridge %s for configuration %s", deprecated + 1, bridgeconf.id
)
bridge = next(active_conf_bridges) bridge = next(active_conf_bridges)
logging.debug("Bridge %r", bridge) logging.debug("Bridge %r", bridge)
bridge.deprecate(reason=reason) bridge.deprecate(reason=reason)
@ -129,7 +148,9 @@ class BridgeMetaAutomation(BaseAutomation):
for bridge in deprecated_bridges: for bridge in deprecated_bridges:
if bridge.deprecated is None: if bridge.deprecated is None:
continue # Possible due to SQLAlchemy lazy loading continue # Possible due to SQLAlchemy lazy loading
cutoff = datetime.now(tz=timezone.utc) - timedelta(hours=bridge.conf.expiry_hours) cutoff = datetime.now(tz=timezone.utc) - timedelta(
hours=bridge.conf.expiry_hours
)
if bridge.deprecated < cutoff: if bridge.deprecated < cutoff:
logging.debug("Destroying expired bridge") logging.debug("Destroying expired bridge")
bridge.destroy() bridge.destroy()
@ -146,7 +167,9 @@ class BridgeMetaAutomation(BaseAutomation):
activate_bridgeconfs = BridgeConf.query.filter( activate_bridgeconfs = BridgeConf.query.filter(
BridgeConf.destroyed.is_(None), BridgeConf.destroyed.is_(None),
).all() ).all()
logging.debug("Found %s active bridge configurations", len(activate_bridgeconfs)) logging.debug(
"Found %s active bridge configurations", len(activate_bridgeconfs)
)
for bridgeconf in activate_bridgeconfs: for bridgeconf in activate_bridgeconfs:
active_conf_bridges = Bridge.query.filter( active_conf_bridges = Bridge.query.filter(
Bridge.conf_id == bridgeconf.id, Bridge.conf_id == bridgeconf.id,
@ -157,16 +180,18 @@ class BridgeMetaAutomation(BaseAutomation):
Bridge.conf_id == bridgeconf.id, Bridge.conf_id == bridgeconf.id,
Bridge.destroyed.is_(None), Bridge.destroyed.is_(None),
).all() ).all()
logging.debug("Generating new bridges for %s (active: %s, total: %s, target: %s, max: %s)", logging.debug(
bridgeconf.id, "Generating new bridges for %s (active: %s, total: %s, target: %s, max: %s)",
len(active_conf_bridges), bridgeconf.id,
len(total_conf_bridges), len(active_conf_bridges),
bridgeconf.target_number, len(total_conf_bridges),
bridgeconf.max_number bridgeconf.target_number,
) bridgeconf.max_number,
)
missing = min( missing = min(
bridgeconf.target_number - len(active_conf_bridges), bridgeconf.target_number - len(active_conf_bridges),
bridgeconf.max_number - len(total_conf_bridges)) bridgeconf.max_number - len(total_conf_bridges),
)
if missing > 0: if missing > 0:
create_bridges(bridgeconf, missing) create_bridges(bridgeconf, missing)
elif missing < 0: elif missing < 0:

View file

@ -7,10 +7,7 @@ class BridgeOvhAutomation(BridgeAutomation):
description = "Deploy Tor bridges on OVH Public Cloud" description = "Deploy Tor bridges on OVH Public Cloud"
provider = CloudProvider.OVH provider = CloudProvider.OVH
template_parameters = [ template_parameters = ["ssh_public_key_path", "ssh_private_key_path"]
"ssh_public_key_path",
"ssh_private_key_path"
]
template = """ template = """
terraform { terraform {

View file

@ -11,14 +11,12 @@ from app.terraform.eotk import eotk_configuration
from app.terraform.terraform import TerraformAutomation from app.terraform.terraform import TerraformAutomation
def update_eotk_instance(group_id: int, def update_eotk_instance(group_id: int, region: str, instance_id: str) -> None:
region: str,
instance_id: str) -> None:
instance = Eotk.query.filter( instance = Eotk.query.filter(
Eotk.group_id == group_id, Eotk.group_id == group_id,
Eotk.region == region, Eotk.region == region,
Eotk.provider == "aws", Eotk.provider == "aws",
Eotk.destroyed.is_(None) Eotk.destroyed.is_(None),
).first() ).first()
if instance is None: if instance is None:
instance = Eotk() instance = Eotk()
@ -35,10 +33,7 @@ class EotkAWSAutomation(TerraformAutomation):
short_name = "eotk_aws" short_name = "eotk_aws"
description = "Deploy EOTK instances to AWS" description = "Deploy EOTK instances to AWS"
template_parameters = [ template_parameters = ["aws_access_key", "aws_secret_key"]
"aws_access_key",
"aws_secret_key"
]
template = """ template = """
terraform { terraform {
@ -81,32 +76,41 @@ class EotkAWSAutomation(TerraformAutomation):
self.tf_write( self.tf_write(
self.template, self.template,
groups=Group.query.filter( groups=Group.query.filter(
Group.eotk.is_(True), Group.eotk.is_(True), Group.destroyed.is_(None)
Group.destroyed.is_(None)
).all(), ).all(),
global_namespace=app.config['GLOBAL_NAMESPACE'], global_namespace=app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{ backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""", }}""",
**{ **{k: app.config[k.upper()] for k in self.template_parameters},
k: app.config[k.upper()]
for k in self.template_parameters
}
) )
for group in Group.query.filter( for group in (
Group.eotk.is_(True), Group.query.filter(Group.eotk.is_(True), Group.destroyed.is_(None))
Group.destroyed.is_(None) .order_by(Group.id)
).order_by(Group.id).all(): .all()
with DeterministicZip(os.path.join(self.working_dir, f"{group.id}.zip")) as dzip: ):
dzip.add_file("sites.conf", eotk_configuration(group).encode('utf-8')) with DeterministicZip(
os.path.join(self.working_dir, f"{group.id}.zip")
) as dzip:
dzip.add_file("sites.conf", eotk_configuration(group).encode("utf-8"))
for onion in sorted(group.onions, key=lambda o: o.onion_name): for onion in sorted(group.onions, key=lambda o: o.onion_name):
dzip.add_file(f"{onion.onion_name}.v3pub.key", onion.onion_public_key) dzip.add_file(
dzip.add_file(f"{onion.onion_name}.v3sec.key", onion.onion_private_key) f"{onion.onion_name}.v3pub.key", onion.onion_public_key
dzip.add_file(f"{onion.onion_name[:20]}-v3.cert", onion.tls_public_key) )
dzip.add_file(f"{onion.onion_name[:20]}-v3.pem", onion.tls_private_key) dzip.add_file(
f"{onion.onion_name}.v3sec.key", onion.onion_private_key
)
dzip.add_file(
f"{onion.onion_name[:20]}-v3.cert", onion.tls_public_key
)
dzip.add_file(
f"{onion.onion_name[:20]}-v3.pem", onion.tls_private_key
)
def tf_posthook(self, *, prehook_result: Any = None) -> None: def tf_posthook(self, *, prehook_result: Any = None) -> None:
for e in Eotk.query.all(): for e in Eotk.query.all():
@ -115,9 +119,9 @@ class EotkAWSAutomation(TerraformAutomation):
for output in outputs: for output in outputs:
if output.startswith("eotk_instances_"): if output.startswith("eotk_instances_"):
try: try:
group_id = int(output[len("eotk_instance_") + 1:]) group_id = int(output[len("eotk_instance_") + 1 :])
for az in outputs[output]['value']: for az in outputs[output]["value"]:
update_eotk_instance(group_id, az, outputs[output]['value'][az]) update_eotk_instance(group_id, az, outputs[output]["value"][az])
except ValueError: except ValueError:
pass pass
db.session.commit() db.session.commit()

View file

@ -55,26 +55,36 @@ class ListAutomation(TerraformAutomation):
MirrorList.destroyed.is_(None), MirrorList.destroyed.is_(None),
MirrorList.provider == self.provider, MirrorList.provider == self.provider,
).all(), ).all(),
global_namespace=app.config['GLOBAL_NAMESPACE'], global_namespace=app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{ backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""", }}""",
**{ **{k: app.config[k.upper()] for k in self.template_parameters},
k: app.config[k.upper()]
for k in self.template_parameters
}
) )
for pool in Pool.query.filter(Pool.destroyed.is_(None)).all(): for pool in Pool.query.filter(Pool.destroyed.is_(None)).all():
for key, formatter in lists.items(): for key, formatter in lists.items():
formatted_pool = formatter(pool) formatted_pool = formatter(pool)
for obfuscate in [True, False]: for obfuscate in [True, False]:
with open(os.path.join( with open(
self.working_dir, f"{key}.{pool.pool_name}{'.jsno' if obfuscate else '.json'}"), os.path.join(
'w', encoding="utf-8") as out: self.working_dir,
f"{key}.{pool.pool_name}{'.jsno' if obfuscate else '.json'}",
),
"w",
encoding="utf-8",
) as out:
out.write(json_encode(formatted_pool, obfuscate)) out.write(json_encode(formatted_pool, obfuscate))
with open(os.path.join(self.working_dir, f"{key}.{pool.pool_name}{'.jso' if obfuscate else '.js'}"), with open(
'w', encoding="utf-8") as out: os.path.join(
self.working_dir,
f"{key}.{pool.pool_name}{'.jso' if obfuscate else '.js'}",
),
"w",
encoding="utf-8",
) as out:
out.write(javascript_encode(formatted_pool, obfuscate)) out.write(javascript_encode(formatted_pool, obfuscate))

View file

@ -11,9 +11,7 @@ class ListGithubAutomation(ListAutomation):
# TODO: file an issue in the github about this, GitLab had a similar issue but fixed it # TODO: file an issue in the github about this, GitLab had a similar issue but fixed it
parallelism = 1 parallelism = 1
template_parameters = [ template_parameters = ["github_api_key"]
"github_api_key"
]
template = """ template = """
terraform { terraform {

View file

@ -15,7 +15,7 @@ class ListGitlabAutomation(ListAutomation):
"gitlab_token", "gitlab_token",
"gitlab_author_email", "gitlab_author_email",
"gitlab_author_name", "gitlab_author_name",
"gitlab_commit_message" "gitlab_commit_message",
] ]
template = """ template = """
@ -56,5 +56,5 @@ class ListGitlabAutomation(ListAutomation):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if 'GITLAB_URL' in current_app.config: if "GITLAB_URL" in current_app.config:
self.template_parameters.append("gitlab_url") self.template_parameters.append("gitlab_url")

View file

@ -6,10 +6,7 @@ class ListS3Automation(ListAutomation):
description = "Update mirror lists in AWS S3 buckets" description = "Update mirror lists in AWS S3 buckets"
provider = "s3" provider = "s3"
template_parameters = [ template_parameters = ["aws_access_key", "aws_secret_key"]
"aws_access_key",
"aws_secret_key"
]
template = """ template = """
terraform { terraform {

View file

@ -15,15 +15,14 @@ from app.models.mirrors import Origin, Proxy, SmartProxy
from app.terraform.terraform import TerraformAutomation from app.terraform.terraform import TerraformAutomation
def update_smart_proxy_instance(group_id: int, def update_smart_proxy_instance(
provider: str, group_id: int, provider: str, region: str, instance_id: str
region: str, ) -> None:
instance_id: str) -> None:
instance = SmartProxy.query.filter( instance = SmartProxy.query.filter(
SmartProxy.group_id == group_id, SmartProxy.group_id == group_id,
SmartProxy.region == region, SmartProxy.region == region,
SmartProxy.provider == provider, SmartProxy.provider == provider,
SmartProxy.destroyed.is_(None) SmartProxy.destroyed.is_(None),
).first() ).first()
if instance is None: if instance is None:
instance = SmartProxy() instance = SmartProxy()
@ -93,16 +92,21 @@ class ProxyAutomation(TerraformAutomation):
self.template, self.template,
groups=groups, groups=groups,
proxies=Proxy.query.filter( proxies=Proxy.query.filter(
Proxy.provider == self.provider, Proxy.destroyed.is_(None)).all(), Proxy.provider == self.provider, Proxy.destroyed.is_(None)
).all(),
subgroups=self.get_subgroups(), subgroups=self.get_subgroups(),
global_namespace=app.config['GLOBAL_NAMESPACE'], bypass_token=app.config['BYPASS_TOKEN'], global_namespace=app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'), bypass_token=app.config["BYPASS_TOKEN"],
terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{ backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""", }}""",
**{k: app.config[k.upper()] for k in self.template_parameters}) **{k: app.config[k.upper()] for k in self.template_parameters},
)
if self.smart_proxies: if self.smart_proxies:
for group in groups: for group in groups:
self.sp_config(group) self.sp_config(group)
@ -111,9 +115,11 @@ class ProxyAutomation(TerraformAutomation):
group_origins: List[Origin] = Origin.query.filter( group_origins: List[Origin] = Origin.query.filter(
Origin.group_id == group.id, Origin.group_id == group.id,
Origin.destroyed.is_(None), Origin.destroyed.is_(None),
Origin.smart.is_(True) Origin.smart.is_(True),
).all() ).all()
self.tmpl_write(f"smart_proxy.{group.id}.conf", """ self.tmpl_write(
f"smart_proxy.{group.id}.conf",
"""
{% for origin in origins %} {% for origin in origins %}
server { server {
listen 443 ssl; listen 443 ssl;
@ -173,23 +179,28 @@ class ProxyAutomation(TerraformAutomation):
} }
{% endfor %} {% endfor %}
""", """,
provider=self.provider, provider=self.provider,
origins=group_origins, origins=group_origins,
smart_zone=app.config['SMART_ZONE']) smart_zone=app.config["SMART_ZONE"],
)
@classmethod @classmethod
def get_subgroups(cls) -> Dict[int, Dict[int, int]]: def get_subgroups(cls) -> Dict[int, Dict[int, int]]:
conn = db.engine.connect() conn = db.engine.connect()
stmt = text(""" stmt = text(
"""
SELECT origin.group_id, proxy.psg, COUNT(proxy.id) FROM proxy, origin SELECT origin.group_id, proxy.psg, COUNT(proxy.id) FROM proxy, origin
WHERE proxy.origin_id = origin.id WHERE proxy.origin_id = origin.id
AND proxy.destroyed IS NULL AND proxy.destroyed IS NULL
AND proxy.provider = :provider AND proxy.provider = :provider
GROUP BY origin.group_id, proxy.psg; GROUP BY origin.group_id, proxy.psg;
""") """
)
stmt = stmt.bindparams(provider=cls.provider) stmt = stmt.bindparams(provider=cls.provider)
result = conn.execute(stmt).all() result = conn.execute(stmt).all()
subgroups: Dict[int, Dict[int, int]] = defaultdict(lambda: defaultdict(lambda: 0)) subgroups: Dict[int, Dict[int, int]] = defaultdict(
lambda: defaultdict(lambda: 0)
)
for row in result: for row in result:
subgroups[row[0]][row[1]] = row[2] subgroups[row[0]][row[1]] = row[2]
return subgroups return subgroups

View file

@ -21,7 +21,7 @@ class ProxyAzureCdnAutomation(ProxyAutomation):
"azure_client_secret", "azure_client_secret",
"azure_subscription_id", "azure_subscription_id",
"azure_tenant_id", "azure_tenant_id",
"smart_zone" "smart_zone",
] ]
template = """ template = """
@ -162,8 +162,7 @@ class ProxyAzureCdnAutomation(ProxyAutomation):
def import_state(self, state: Optional[Any]) -> None: def import_state(self, state: Optional[Any]) -> None:
proxies = Proxy.query.filter( proxies = Proxy.query.filter(
Proxy.provider == self.provider, Proxy.provider == self.provider, Proxy.destroyed.is_(None)
Proxy.destroyed.is_(None)
).all() ).all()
for proxy in proxies: for proxy in proxies:
proxy.url = f"https://{proxy.slug}.azureedge.net" proxy.url = f"https://{proxy.slug}.azureedge.net"

View file

@ -17,7 +17,7 @@ class ProxyCloudfrontAutomation(ProxyAutomation):
"admin_email", "admin_email",
"aws_access_key", "aws_access_key",
"aws_secret_key", "aws_secret_key",
"smart_zone" "smart_zone",
] ]
template = """ template = """
@ -111,26 +111,35 @@ class ProxyCloudfrontAutomation(ProxyAutomation):
def import_state(self, state: Any) -> None: def import_state(self, state: Any) -> None:
if not isinstance(state, dict): if not isinstance(state, dict):
raise RuntimeError("The Terraform state object returned was not a dict.") raise RuntimeError("The Terraform state object returned was not a dict.")
if "child_modules" not in state['values']['root_module']: if "child_modules" not in state["values"]["root_module"]:
# There are no CloudFront proxies deployed to import state for # There are no CloudFront proxies deployed to import state for
return return
# CloudFront distributions (proxies) # CloudFront distributions (proxies)
for mod in state['values']['root_module']['child_modules']: for mod in state["values"]["root_module"]["child_modules"]:
if mod['address'].startswith('module.cloudfront_'): if mod["address"].startswith("module.cloudfront_"):
for res in mod['resources']: for res in mod["resources"]:
if res['address'].endswith('aws_cloudfront_distribution.this'): if res["address"].endswith("aws_cloudfront_distribution.this"):
proxy = Proxy.query.filter(Proxy.id == mod['address'][len('module.cloudfront_'):]).first() proxy = Proxy.query.filter(
proxy.url = "https://" + res['values']['domain_name'] Proxy.id == mod["address"][len("module.cloudfront_") :]
proxy.slug = res['values']['id'] ).first()
proxy.url = "https://" + res["values"]["domain_name"]
proxy.slug = res["values"]["id"]
proxy.terraform_updated = datetime.now(tz=timezone.utc) proxy.terraform_updated = datetime.now(tz=timezone.utc)
break break
# EC2 instances (smart proxies) # EC2 instances (smart proxies)
for g in state["values"]["root_module"]["child_modules"]: for g in state["values"]["root_module"]["child_modules"]:
if g["address"].startswith("module.smart_proxy_"): if g["address"].startswith("module.smart_proxy_"):
group_id = int(g["address"][len("module.smart_proxy_"):]) group_id = int(g["address"][len("module.smart_proxy_") :])
for s in g["child_modules"]: for s in g["child_modules"]:
if s["address"].endswith(".module.instance"): if s["address"].endswith(".module.instance"):
for x in s["resources"]: for x in s["resources"]:
if x["address"].endswith(".module.instance.aws_instance.default[0]"): if x["address"].endswith(
update_smart_proxy_instance(group_id, self.provider, "us-east-2a", x['values']['id']) ".module.instance.aws_instance.default[0]"
):
update_smart_proxy_instance(
group_id,
self.provider,
"us-east-2a",
x["values"]["id"],
)
db.session.commit() db.session.commit()

View file

@ -14,11 +14,7 @@ class ProxyFastlyAutomation(ProxyAutomation):
subgroup_members_max = 20 subgroup_members_max = 20
cloud_name = "fastly" cloud_name = "fastly"
template_parameters = [ template_parameters = ["aws_access_key", "aws_secret_key", "fastly_api_key"]
"aws_access_key",
"aws_secret_key",
"fastly_api_key"
]
template = """ template = """
terraform { terraform {
@ -125,13 +121,14 @@ class ProxyFastlyAutomation(ProxyAutomation):
Constructor method. Constructor method.
""" """
# Requires Flask application context to read configuration # Requires Flask application context to read configuration
self.subgroup_members_max = min(current_app.config.get("FASTLY_MAX_BACKENDS", 5), 20) self.subgroup_members_max = min(
current_app.config.get("FASTLY_MAX_BACKENDS", 5), 20
)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def import_state(self, state: Optional[Any]) -> None: def import_state(self, state: Optional[Any]) -> None:
proxies = Proxy.query.filter( proxies = Proxy.query.filter(
Proxy.provider == self.provider, Proxy.provider == self.provider, Proxy.destroyed.is_(None)
Proxy.destroyed.is_(None)
).all() ).all()
for proxy in proxies: for proxy in proxies:
proxy.url = f"https://{proxy.slug}.global.ssl.fastly.net" proxy.url = f"https://{proxy.slug}.global.ssl.fastly.net"

View file

@ -18,12 +18,16 @@ from app.terraform.proxy.azure_cdn import ProxyAzureCdnAutomation
from app.terraform.proxy.cloudfront import ProxyCloudfrontAutomation from app.terraform.proxy.cloudfront import ProxyCloudfrontAutomation
from app.terraform.proxy.fastly import ProxyFastlyAutomation from app.terraform.proxy.fastly import ProxyFastlyAutomation
PROXY_PROVIDERS: Dict[str, Type[ProxyAutomation]] = {p.provider: p for p in [ # type: ignore[attr-defined] PROXY_PROVIDERS: Dict[str, Type[ProxyAutomation]] = {
# In order of preference p.provider: p # type: ignore[attr-defined]
ProxyCloudfrontAutomation, for p in [
ProxyFastlyAutomation, # In order of preference
ProxyAzureCdnAutomation ProxyCloudfrontAutomation,
] if p.enabled} # type: ignore[attr-defined] ProxyFastlyAutomation,
ProxyAzureCdnAutomation,
]
if p.enabled # type: ignore[attr-defined]
}
SubgroupCount = OrderedDictT[str, OrderedDictT[int, OrderedDictT[int, int]]] SubgroupCount = OrderedDictT[str, OrderedDictT[int, OrderedDictT[int, int]]]
@ -61,8 +65,9 @@ def random_slug(origin_domain_name: str) -> str:
"exampasdfghjkl" "exampasdfghjkl"
""" """
# The random slug doesn't need to be cryptographically secure, hence the use of `# nosec` # The random slug doesn't need to be cryptographically secure, hence the use of `# nosec`
return tldextract.extract(origin_domain_name).domain[:5] + ''.join( return tldextract.extract(origin_domain_name).domain[:5] + "".join(
random.choices(string.ascii_lowercase, k=12)) # nosec random.choices(string.ascii_lowercase, k=12) # nosec: B311
)
def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupCount: def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupCount:
@ -95,8 +100,13 @@ def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupC
return subgroup_count return subgroup_count
def next_subgroup(subgroup_count: SubgroupCount, provider: str, group_id: int, max_subgroup_count: int, def next_subgroup(
max_subgroup_members: int) -> Optional[int]: subgroup_count: SubgroupCount,
provider: str,
group_id: int,
max_subgroup_count: int,
max_subgroup_members: int,
) -> Optional[int]:
""" """
Find the first available subgroup with less than the specified maximum count in the specified provider and group. Find the first available subgroup with less than the specified maximum count in the specified provider and group.
If the last subgroup in the group is full, return the next subgroup number as long as it doesn't exceed If the last subgroup in the group is full, return the next subgroup number as long as it doesn't exceed
@ -137,27 +147,36 @@ def auto_deprecate_proxies() -> None:
- The "max_age_reached" reason means the proxy has been in use for longer than the maximum allowed period. - The "max_age_reached" reason means the proxy has been in use for longer than the maximum allowed period.
The maximum age cutoff is randomly set to a time between 24 and 48 hours. The maximum age cutoff is randomly set to a time between 24 and 48 hours.
""" """
origin_destroyed_proxies = (db.session.query(Proxy) origin_destroyed_proxies = (
.join(Origin, Proxy.origin_id == Origin.id) db.session.query(Proxy)
.filter(Proxy.destroyed.is_(None), .join(Origin, Proxy.origin_id == Origin.id)
Proxy.deprecated.is_(None), .filter(
Origin.destroyed.is_not(None)) Proxy.destroyed.is_(None),
.all()) Proxy.deprecated.is_(None),
Origin.destroyed.is_not(None),
)
.all()
)
logging.debug("Origin destroyed: %s", origin_destroyed_proxies) logging.debug("Origin destroyed: %s", origin_destroyed_proxies)
for proxy in origin_destroyed_proxies: for proxy in origin_destroyed_proxies:
proxy.deprecate(reason="origin_destroyed") proxy.deprecate(reason="origin_destroyed")
max_age_proxies = (db.session.query(Proxy) max_age_proxies = (
.join(Origin, Proxy.origin_id == Origin.id) db.session.query(Proxy)
.filter(Proxy.destroyed.is_(None), .join(Origin, Proxy.origin_id == Origin.id)
Proxy.deprecated.is_(None), .filter(
Proxy.pool_id != -1, # do not rotate hotspare proxies Proxy.destroyed.is_(None),
Origin.assets, Proxy.deprecated.is_(None),
Origin.auto_rotation) Proxy.pool_id != -1, # do not rotate hotspare proxies
.all()) Origin.assets,
Origin.auto_rotation,
)
.all()
)
logging.debug("Max age: %s", max_age_proxies) logging.debug("Max age: %s", max_age_proxies)
for proxy in max_age_proxies: for proxy in max_age_proxies:
max_age_cutoff = datetime.now(timezone.utc) - timedelta( max_age_cutoff = datetime.now(timezone.utc) - timedelta(
days=1, seconds=86400 * random.random()) # nosec: B311 days=1, seconds=86400 * random.random() # nosec: B311
)
if proxy.added < max_age_cutoff: if proxy.added < max_age_cutoff:
proxy.deprecate(reason="max_age_reached") proxy.deprecate(reason="max_age_reached")
@ -171,8 +190,7 @@ def destroy_expired_proxies() -> None:
""" """
expiry_cutoff = datetime.now(timezone.utc) - timedelta(days=4) expiry_cutoff = datetime.now(timezone.utc) - timedelta(days=4)
proxies = Proxy.query.filter( proxies = Proxy.query.filter(
Proxy.destroyed.is_(None), Proxy.destroyed.is_(None), Proxy.deprecated < expiry_cutoff
Proxy.deprecated < expiry_cutoff
).all() ).all()
for proxy in proxies: for proxy in proxies:
logging.debug("Destroying expired proxy") logging.debug("Destroying expired proxy")
@ -244,12 +262,17 @@ class ProxyMetaAutomation(BaseAutomation):
if origin.destroyed is not None: if origin.destroyed is not None:
continue continue
proxies = [ proxies = [
x for x in origin.proxies x
if x.pool_id == pool.id and x.deprecated is None and x.destroyed is None for x in origin.proxies
if x.pool_id == pool.id
and x.deprecated is None
and x.destroyed is None
] ]
logging.debug("Proxies for group %s: %s", group.group_name, proxies) logging.debug("Proxies for group %s: %s", group.group_name, proxies)
if not proxies: if not proxies:
logging.debug("Creating new proxy for %s in pool %s", origin, pool) logging.debug(
"Creating new proxy for %s in pool %s", origin, pool
)
if not promote_hot_spare_proxy(pool.id, origin): if not promote_hot_spare_proxy(pool.id, origin):
# No "hot spare" available # No "hot spare" available
self.create_proxy(pool.id, origin) self.create_proxy(pool.id, origin)
@ -270,8 +293,13 @@ class ProxyMetaAutomation(BaseAutomation):
""" """
for provider in PROXY_PROVIDERS.values(): for provider in PROXY_PROVIDERS.values():
logging.debug("Looking at provider %s", provider.provider) logging.debug("Looking at provider %s", provider.provider)
subgroup = next_subgroup(self.subgroup_count, provider.provider, origin.group_id, subgroup = next_subgroup(
provider.subgroup_members_max, provider.subgroup_count_max) self.subgroup_count,
provider.provider,
origin.group_id,
provider.subgroup_members_max,
provider.subgroup_count_max,
)
if subgroup is None: if subgroup is None:
continue # Exceeded maximum number of subgroups and last subgroup is full continue # Exceeded maximum number of subgroups and last subgroup is full
self.increment_subgroup(provider.provider, origin.group_id, subgroup) self.increment_subgroup(provider.provider, origin.group_id, subgroup)
@ -317,9 +345,7 @@ class ProxyMetaAutomation(BaseAutomation):
If an origin is not destroyed and lacks active proxies (not deprecated and not destroyed), If an origin is not destroyed and lacks active proxies (not deprecated and not destroyed),
a new 'hot spare' proxy for this origin is created in the reserve pool (with pool_id = -1). a new 'hot spare' proxy for this origin is created in the reserve pool (with pool_id = -1).
""" """
origins = Origin.query.filter( origins = Origin.query.filter(Origin.destroyed.is_(None)).all()
Origin.destroyed.is_(None)
).all()
for origin in origins: for origin in origins:
if origin.countries: if origin.countries:
risk_levels = origin.risk_level.items() risk_levels = origin.risk_level.items()
@ -328,7 +354,10 @@ class ProxyMetaAutomation(BaseAutomation):
if highest_risk_level < 4: if highest_risk_level < 4:
for proxy in origin.proxies: for proxy in origin.proxies:
if proxy.destroyed is None and proxy.pool_id == -1: if proxy.destroyed is None and proxy.pool_id == -1:
logging.debug("Destroying hot spare proxy for origin %s (low risk)", origin) logging.debug(
"Destroying hot spare proxy for origin %s (low risk)",
origin,
)
proxy.destroy() proxy.destroy()
continue continue
if origin.destroyed is not None: if origin.destroyed is not None:

View file

@ -15,21 +15,26 @@ from app.terraform.terraform import TerraformAutomation
def import_state(state: Any) -> None: def import_state(state: Any) -> None:
if not isinstance(state, dict): if not isinstance(state, dict):
raise RuntimeError("The Terraform state object returned was not a dict.") raise RuntimeError("The Terraform state object returned was not a dict.")
if "values" not in state or "child_modules" not in state['values']['root_module']: if "values" not in state or "child_modules" not in state["values"]["root_module"]:
# There are no CloudFront origins deployed to import state for # There are no CloudFront origins deployed to import state for
return return
# CloudFront distributions (origins) # CloudFront distributions (origins)
for mod in state['values']['root_module']['child_modules']: for mod in state["values"]["root_module"]["child_modules"]:
if mod['address'].startswith('module.static_'): if mod["address"].startswith("module.static_"):
static_id = mod['address'][len('module.static_'):] static_id = mod["address"][len("module.static_") :]
logging.debug("Found static module in state: %s", static_id) logging.debug("Found static module in state: %s", static_id)
for res in mod['resources']: for res in mod["resources"]:
if res['address'].endswith('aws_cloudfront_distribution.this'): if res["address"].endswith("aws_cloudfront_distribution.this"):
logging.debug("and found related cloudfront distribution") logging.debug("and found related cloudfront distribution")
static = StaticOrigin.query.filter(StaticOrigin.id == static_id).first() static = StaticOrigin.query.filter(
static.origin_domain_name = res['values']['domain_name'] StaticOrigin.id == static_id
logging.debug("and found static origin: %s to update with domain name: %s", static.id, ).first()
static.origin_domain_name) static.origin_domain_name = res["values"]["domain_name"]
logging.debug(
"and found static origin: %s to update with domain name: %s",
static.id,
static.origin_domain_name,
)
static.terraform_updated = datetime.now(tz=timezone.utc) static.terraform_updated = datetime.now(tz=timezone.utc)
break break
db.session.commit() db.session.commit()
@ -128,14 +133,18 @@ class StaticAWSAutomation(TerraformAutomation):
groups=groups, groups=groups,
storage_cloud_accounts=storage_cloud_accounts, storage_cloud_accounts=storage_cloud_accounts,
source_cloud_accounts=source_cloud_accounts, source_cloud_accounts=source_cloud_accounts,
global_namespace=current_app.config['GLOBAL_NAMESPACE'], bypass_token=current_app.config['BYPASS_TOKEN'], global_namespace=current_app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(*list(os.path.split(current_app.root_path))[:-1], 'terraform-modules'), bypass_token=current_app.config["BYPASS_TOKEN"],
terraform_modules_path=os.path.join(
*list(os.path.split(current_app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{ backend_config=f"""backend "http" {{
lock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}" lock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}" unlock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}" address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""", }}""",
**{k: current_app.config[k.upper()] for k in self.template_parameters}) **{k: current_app.config[k.upper()] for k in self.template_parameters},
)
def tf_posthook(self, *, prehook_result: Any = None) -> None: def tf_posthook(self, *, prehook_result: Any = None) -> None:
import_state(self.tf_show()) import_state(self.tf_show())

View file

@ -27,7 +27,9 @@ class StaticMetaAutomation(BaseAutomation):
if static_origin.origin_domain_name is not None: if static_origin.origin_domain_name is not None:
try: try:
# Check if an Origin with the same domain name already exists # Check if an Origin with the same domain name already exists
origin = Origin.query.filter_by(domain_name=static_origin.origin_domain_name).one() origin = Origin.query.filter_by(
domain_name=static_origin.origin_domain_name
).one()
# Keep auto rotation value in sync # Keep auto rotation value in sync
origin.auto_rotation = static_origin.auto_rotate origin.auto_rotation = static_origin.auto_rotate
except NoResultFound: except NoResultFound:
@ -35,17 +37,21 @@ class StaticMetaAutomation(BaseAutomation):
origin = Origin( origin = Origin(
group_id=static_origin.group_id, group_id=static_origin.group_id,
description=f"PORTAL !! DO NOT DELETE !! Automatically created web origin for static origin " description=f"PORTAL !! DO NOT DELETE !! Automatically created web origin for static origin "
f"#{static_origin.id}", f"#{static_origin.id}",
domain_name=static_origin.origin_domain_name, domain_name=static_origin.origin_domain_name,
auto_rotation=static_origin.auto_rotate, auto_rotation=static_origin.auto_rotate,
smart=False, smart=False,
assets=False, assets=False,
) )
db.session.add(origin) db.session.add(origin)
logging.debug(f"Created Origin with domain name {origin.domain_name}") logging.debug(
f"Created Origin with domain name {origin.domain_name}"
)
# Step 2: Remove Origins for StaticOrigins with non-null destroy value # Step 2: Remove Origins for StaticOrigins with non-null destroy value
static_origins_with_destroyed = StaticOrigin.query.filter(StaticOrigin.destroyed.isnot(None)).all() static_origins_with_destroyed = StaticOrigin.query.filter(
StaticOrigin.destroyed.isnot(None)
).all()
for static_origin in static_origins_with_destroyed: for static_origin in static_origins_with_destroyed:
try: try:
origin = Origin.query.filter_by( origin = Origin.query.filter_by(

View file

@ -51,14 +51,20 @@ class TerraformAutomation(BaseAutomation):
prehook_result = self.tf_prehook() # pylint: disable=assignment-from-no-return prehook_result = self.tf_prehook() # pylint: disable=assignment-from-no-return
self.tf_generate() self.tf_generate()
self.tf_init() self.tf_init()
returncode, logs = self.tf_apply(self.working_dir, refresh=self.always_refresh or full) returncode, logs = self.tf_apply(
self.working_dir, refresh=self.always_refresh or full
)
self.tf_posthook(prehook_result=prehook_result) self.tf_posthook(prehook_result=prehook_result)
return returncode == 0, logs return returncode == 0, logs
def tf_apply(self, working_dir: str, *, def tf_apply(
refresh: bool = True, self,
parallelism: Optional[int] = None, working_dir: str,
lock_timeout: int = 15) -> Tuple[int, str]: *,
refresh: bool = True,
parallelism: Optional[int] = None,
lock_timeout: int = 15,
) -> Tuple[int, str]:
if not parallelism: if not parallelism:
parallelism = self.parallelism parallelism = self.parallelism
if not self.working_dir: if not self.working_dir:
@ -67,17 +73,19 @@ class TerraformAutomation(BaseAutomation):
# the argument list as an array such that argument injection would be # the argument list as an array such that argument injection would be
# ineffective. # ineffective.
tfcmd = subprocess.run( # nosec tfcmd = subprocess.run( # nosec
['terraform', [
'apply', "terraform",
'-auto-approve', "apply",
'-json', "-auto-approve",
f'-refresh={str(refresh).lower()}', "-json",
f'-parallelism={str(parallelism)}', f"-refresh={str(refresh).lower()}",
f'-lock-timeout={str(lock_timeout)}m', f"-parallelism={str(parallelism)}",
], f"-lock-timeout={str(lock_timeout)}m",
],
cwd=working_dir, cwd=working_dir,
stdout=subprocess.PIPE) stdout=subprocess.PIPE,
return tfcmd.returncode, tfcmd.stdout.decode('utf-8') )
return tfcmd.returncode, tfcmd.stdout.decode("utf-8")
@abstractmethod @abstractmethod
def tf_generate(self) -> None: def tf_generate(self) -> None:
@ -91,41 +99,49 @@ class TerraformAutomation(BaseAutomation):
# the argument list as an array such that argument injection would be # the argument list as an array such that argument injection would be
# ineffective. # ineffective.
subprocess.run( # nosec subprocess.run( # nosec
['terraform', [
'init', "terraform",
f'-lock-timeout={str(lock_timeout)}m', "init",
], f"-lock-timeout={str(lock_timeout)}m",
cwd=self.working_dir) ],
cwd=self.working_dir,
)
def tf_output(self) -> Any: def tf_output(self) -> Any:
if not self.working_dir: if not self.working_dir:
raise RuntimeError("No working directory specified.") raise RuntimeError("No working directory specified.")
# The following subprocess call does not take any user input. # The following subprocess call does not take any user input.
tfcmd = subprocess.run( # nosec tfcmd = subprocess.run( # nosec
['terraform', 'output', '-json'], ["terraform", "output", "-json"],
cwd=self.working_dir, cwd=self.working_dir,
stdout=subprocess.PIPE) stdout=subprocess.PIPE,
)
return json.loads(tfcmd.stdout) return json.loads(tfcmd.stdout)
def tf_plan(self, *, def tf_plan(
refresh: bool = True, self,
parallelism: Optional[int] = None, *,
lock_timeout: int = 15) -> Tuple[int, str]: refresh: bool = True,
parallelism: Optional[int] = None,
lock_timeout: int = 15,
) -> Tuple[int, str]:
if not self.working_dir: if not self.working_dir:
raise RuntimeError("No working directory specified.") raise RuntimeError("No working directory specified.")
# The following subprocess call takes external input, but is providing # The following subprocess call takes external input, but is providing
# the argument list as an array such that argument injection would be # the argument list as an array such that argument injection would be
# ineffective. # ineffective.
tfcmd = subprocess.run( # nosec tfcmd = subprocess.run( # nosec
['terraform', [
'plan', "terraform",
'-json', "plan",
f'-refresh={str(refresh).lower()}', "-json",
f'-parallelism={str(parallelism)}', f"-refresh={str(refresh).lower()}",
f'-lock-timeout={str(lock_timeout)}m', f"-parallelism={str(parallelism)}",
], f"-lock-timeout={str(lock_timeout)}m",
cwd=self.working_dir) ],
return tfcmd.returncode, tfcmd.stdout.decode('utf-8') cwd=self.working_dir,
)
return tfcmd.returncode, tfcmd.stdout.decode("utf-8")
def tf_posthook(self, *, prehook_result: Any = None) -> None: def tf_posthook(self, *, prehook_result: Any = None) -> None:
""" """
@ -154,9 +170,8 @@ class TerraformAutomation(BaseAutomation):
raise RuntimeError("No working directory specified.") raise RuntimeError("No working directory specified.")
# This subprocess call doesn't take any user input. # This subprocess call doesn't take any user input.
terraform = subprocess.run( # nosec terraform = subprocess.run( # nosec
['terraform', 'show', '-json'], ["terraform", "show", "-json"], cwd=self.working_dir, stdout=subprocess.PIPE
cwd=self.working_dir, )
stdout=subprocess.PIPE)
return json.loads(terraform.stdout) return json.loads(terraform.stdout)
def tf_write(self, template: str, **kwargs: Any) -> None: def tf_write(self, template: str, **kwargs: Any) -> None:

View file

@ -9,7 +9,7 @@ from app.models.tfstate import TerraformState
tfstate = Blueprint("tfstate", __name__) tfstate = Blueprint("tfstate", __name__)
@tfstate.route("/<key>", methods=['GET']) @tfstate.route("/<key>", methods=["GET"])
def handle_get(key: str) -> ResponseReturnValue: def handle_get(key: str) -> ResponseReturnValue:
state = TerraformState.query.filter(TerraformState.key == key).first() state = TerraformState.query.filter(TerraformState.key == key).first()
if state is None or state.state is None: if state is None or state.state is None:
@ -17,16 +17,18 @@ def handle_get(key: str) -> ResponseReturnValue:
return Response(state.state, content_type="application/json") return Response(state.state, content_type="application/json")
@tfstate.route("/<key>", methods=['POST', 'DELETE', 'UNLOCK']) @tfstate.route("/<key>", methods=["POST", "DELETE", "UNLOCK"])
def handle_update(key: str) -> ResponseReturnValue: def handle_update(key: str) -> ResponseReturnValue:
state = TerraformState.query.filter(TerraformState.key == key).first() state = TerraformState.query.filter(TerraformState.key == key).first()
if not state: if not state:
if request.method in ["DELETE", "UNLOCK"]: if request.method in ["DELETE", "UNLOCK"]:
return "OK", 200 return "OK", 200
state = TerraformState(key=key) state = TerraformState(key=key)
if state.lock and not (request.method == "UNLOCK" and request.args.get('ID') is None): if state.lock and not (
request.method == "UNLOCK" and request.args.get("ID") is None
):
# force-unlock seems to not give an ID to verify so accept no ID being present # force-unlock seems to not give an ID to verify so accept no ID being present
if json.loads(state.lock)['ID'] != request.args.get('ID'): if json.loads(state.lock)["ID"] != request.args.get("ID"):
return Response(state.lock, status=409, content_type="application/json") return Response(state.lock, status=409, content_type="application/json")
if request.method == "POST": if request.method == "POST":
state.state = json.dumps(request.json) state.state = json.dumps(request.json)
@ -38,9 +40,11 @@ def handle_update(key: str) -> ResponseReturnValue:
return "OK", 200 return "OK", 200
@tfstate.route("/<key>", methods=['LOCK']) @tfstate.route("/<key>", methods=["LOCK"])
def handle_lock(key: str) -> ResponseReturnValue: def handle_lock(key: str) -> ResponseReturnValue:
state = TerraformState.query.filter(TerraformState.key == key).with_for_update().first() state = (
TerraformState.query.filter(TerraformState.key == key).with_for_update().first()
)
if state is None: if state is None:
state = TerraformState(key=key, state="") state = TerraformState(key=key, state="")
db.session.add(state) db.session.add(state)

View file

@ -20,8 +20,9 @@ def onion_hostname(onion_public_key: bytes) -> str:
return onion.lower() return onion.lower()
def decode_onion_keys(onion_private_key_base64: str, onion_public_key_base64: str) -> Tuple[ def decode_onion_keys(
Optional[bytes], Optional[bytes], List[Dict[str, str]]]: onion_private_key_base64: str, onion_public_key_base64: str
) -> Tuple[Optional[bytes], Optional[bytes], List[Dict[str, str]]]:
try: try:
onion_private_key = base64.b64decode(onion_private_key_base64) onion_private_key = base64.b64decode(onion_private_key_base64)
onion_public_key = base64.b64decode(onion_public_key_base64) onion_public_key = base64.b64decode(onion_public_key_base64)

View file

@ -22,14 +22,20 @@ def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]:
return certificates return certificates
def build_certificate_chain(certificates: list[x509.Certificate]) -> list[x509.Certificate]: def build_certificate_chain(
certificates: list[x509.Certificate],
) -> list[x509.Certificate]:
if len(certificates) == 1: if len(certificates) == 1:
return certificates return certificates
chain = [] chain = []
cert_map = {cert.subject.rfc4514_string(): cert for cert in certificates} cert_map = {cert.subject.rfc4514_string(): cert for cert in certificates}
end_entity = next( end_entity = next(
(cert for cert in certificates if cert.subject.rfc4514_string() not in cert_map), (
None cert
for cert in certificates
if cert.subject.rfc4514_string() not in cert_map
),
None,
) )
if not end_entity: if not end_entity:
raise ValueError("Cannot identify the end-entity certificate.") raise ValueError("Cannot identify the end-entity certificate.")
@ -51,7 +57,9 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
for i in range(len(chain) - 1): for i in range(len(chain) - 1):
next_public_key = chain[i + 1].public_key() next_public_key = chain[i + 1].public_key()
if not (isinstance(next_public_key, RSAPublicKey)): if not (isinstance(next_public_key, RSAPublicKey)):
raise ValueError(f"Certificate using unsupported algorithm: {type(next_public_key)}") raise ValueError(
f"Certificate using unsupported algorithm: {type(next_public_key)}"
)
hash_algorithm = chain[i].signature_hash_algorithm hash_algorithm = chain[i].signature_hash_algorithm
if hash_algorithm is None: if hash_algorithm is None:
raise ValueError("Certificate missing hash algorithm") raise ValueError("Certificate missing hash algorithm")
@ -59,23 +67,23 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
chain[i].signature, chain[i].signature,
chain[i].tbs_certificate_bytes, chain[i].tbs_certificate_bytes,
PKCS1v15(), PKCS1v15(),
hash_algorithm hash_algorithm,
) )
end_cert = chain[-1] end_cert = chain[-1]
if not any( if not any(
end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates end_cert.issuer == trusted_cert.subject for trusted_cert in trusted_certificates
): ):
raise ValueError("Certificate chain does not terminate at a trusted root CA.") raise ValueError("Certificate chain does not terminate at a trusted root CA.")
return True return True
def validate_tls_keys( def validate_tls_keys(
tls_private_key_pem: Optional[str], tls_private_key_pem: Optional[str],
tls_certificate_pem: Optional[str], tls_certificate_pem: Optional[str],
skip_chain_verification: Optional[bool], skip_chain_verification: Optional[bool],
skip_name_verification: Optional[bool], skip_name_verification: Optional[bool],
hostname: str hostname: str,
) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]: ) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]:
errors = [] errors = []
san_list = [] san_list = []
@ -90,31 +98,55 @@ def validate_tls_keys(
private_key = serialization.load_pem_private_key( private_key = serialization.load_pem_private_key(
tls_private_key_pem.encode("utf-8"), tls_private_key_pem.encode("utf-8"),
password=None, password=None,
backend=default_backend() backend=default_backend(),
) )
if not isinstance(private_key, rsa.RSAPrivateKey): if not isinstance(private_key, rsa.RSAPrivateKey):
errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be RSA."}) errors.append(
{
"Error": "tls_private_key_invalid",
"Message": "Private key must be RSA.",
}
)
if tls_certificate_pem: if tls_certificate_pem:
certificates = list(load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))) certificates = list(
load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))
)
if not certificates: if not certificates:
errors.append({"Error": "tls_certificate_invalid", "Message": "No valid certificate found."}) errors.append(
{
"Error": "tls_certificate_invalid",
"Message": "No valid certificate found.",
}
)
else: else:
chain = build_certificate_chain(certificates) chain = build_certificate_chain(certificates)
end_entity_cert = chain[0] end_entity_cert = chain[0]
if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc): if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc):
errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."}) errors.append(
{
"Error": "tls_public_key_expired",
"Message": "TLS public key is expired.",
}
)
if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc): if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc):
errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."}) errors.append(
{
"Error": "tls_public_key_future",
"Message": "TLS public key is not yet valid.",
}
)
if private_key: if private_key:
public_key = end_entity_cert.public_key() public_key = end_entity_cert.public_key()
if TYPE_CHECKING: if TYPE_CHECKING:
assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101 assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101
assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101 assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101
assert end_entity_cert.signature_hash_algorithm is not None # nosec: B101 assert (
end_entity_cert.signature_hash_algorithm is not None
) # nosec: B101
try: try:
test_message = b"test" test_message = b"test"
signature = private_key.sign( signature = private_key.sign(
@ -130,20 +162,30 @@ def validate_tls_keys(
) )
except Exception: except Exception:
errors.append( errors.append(
{"Error": "tls_key_mismatch", "Message": "Private key does not match certificate."}) {
"Error": "tls_key_mismatch",
"Message": "Private key does not match certificate.",
}
)
if not skip_chain_verification: if not skip_chain_verification:
try: try:
validate_certificate_chain(chain) validate_certificate_chain(chain)
except ValueError as e: except ValueError as e:
errors.append({"Error": "certificate_chain_invalid", "Message": str(e)}) errors.append(
{"Error": "certificate_chain_invalid", "Message": str(e)}
)
if not skip_name_verification: if not skip_name_verification:
san_list = extract_sans(end_entity_cert) san_list = extract_sans(end_entity_cert)
for expected_hostname in [hostname, f"*.{hostname}"]: for expected_hostname in [hostname, f"*.{hostname}"]:
if expected_hostname not in san_list: if expected_hostname not in san_list:
errors.append( errors.append(
{"Error": "hostname_not_in_san", "Message": f"{expected_hostname} not found in SANs."}) {
"Error": "hostname_not_in_san",
"Message": f"{expected_hostname} not found in SANs.",
}
)
except Exception as e: except Exception as e:
errors.append({"Error": "tls_validation_error", "Message": str(e)}) errors.append({"Error": "tls_validation_error", "Message": str(e)})
@ -153,7 +195,9 @@ def validate_tls_keys(
def extract_sans(cert: x509.Certificate) -> List[str]: def extract_sans(cert: x509.Certificate) -> List[str]:
try: try:
san_extension = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME) san_extension = cert.extensions.get_extension_for_oid(
ExtensionOID.SUBJECT_ALTERNATIVE_NAME
)
sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined] sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined]
return sans return sans
except Exception: except Exception:

View file

@ -1,2 +1,2 @@
[flake8] [flake8]
ignore = E501,W503 ignore = E203,E501,W503