lint: reformat python code with black

This commit is contained in:
Iain Learmonth 2024-12-06 18:15:47 +00:00
parent 331beb01b4
commit a406a7974b
88 changed files with 2579 additions and 1608 deletions

View file

@ -6,8 +6,7 @@ import yaml
from flask import Flask, redirect, send_from_directory, url_for
from flask.typing import ResponseReturnValue
from prometheus_client import CollectorRegistry, Metric, make_wsgi_app
from prometheus_client.metrics_core import (CounterMetricFamily,
GaugeMetricFamily)
from prometheus_client.metrics_core import CounterMetricFamily, GaugeMetricFamily
from prometheus_client.registry import REGISTRY, Collector
from prometheus_flask_exporter import PrometheusMetrics
from sqlalchemy import text
@ -28,9 +27,9 @@ app.config.from_file("../config.yaml", load=yaml.safe_load)
registry = CollectorRegistry()
metrics = PrometheusMetrics(app, registry=registry)
app.wsgi_app = DispatcherMiddleware(app.wsgi_app, { # type: ignore[method-assign]
'/metrics': make_wsgi_app(registry)
})
app.wsgi_app = DispatcherMiddleware( # type: ignore[method-assign]
app.wsgi_app, {"/metrics": make_wsgi_app(registry)}
)
# register default collectors to our new registry
collectors = list(REGISTRY._collector_to_names.keys())
@ -54,12 +53,16 @@ def not_migrating() -> bool:
class DefinedProxiesCollector(Collector):
def collect(self) -> Iterator[Metric]:
with app.app_context():
ok = GaugeMetricFamily("database_collector",
ok = GaugeMetricFamily(
"database_collector",
"Status of a database collector (0: bad, 1: good)",
labels=["collector"])
labels=["collector"],
)
try:
with db.engine.connect() as conn:
result = conn.execute(text("""
result = conn.execute(
text(
"""
SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name,
COUNT(proxy.id) FROM proxy, origin, pool, "group"
WHERE proxy.origin_id = origin.id
@ -67,13 +70,24 @@ class DefinedProxiesCollector(Collector):
AND proxy.pool_id = pool.id
AND proxy.destroyed IS NULL
GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name;
"""))
c = GaugeMetricFamily("defined_proxies", "Number of proxies currently defined for deployment",
labels=['group_id', 'group_name', 'provider', 'pool_id',
'pool_name'])
"""
)
)
c = GaugeMetricFamily(
"defined_proxies",
"Number of proxies currently defined for deployment",
labels=[
"group_id",
"group_name",
"provider",
"pool_id",
"pool_name",
],
)
for row in result:
c.add_metric([str(row[0]), row[1], row[2], str(row[3]), row[4]],
row[5])
c.add_metric(
[str(row[0]), row[1], row[2], str(row[3]), row[4]], row[5]
)
yield c
ok.add_metric(["defined_proxies"], 1)
except SQLAlchemyError:
@ -84,12 +98,16 @@ class DefinedProxiesCollector(Collector):
class BlockedProxiesCollector(Collector):
def collect(self) -> Iterator[Metric]:
with app.app_context():
ok = GaugeMetricFamily("database_collector",
ok = GaugeMetricFamily(
"database_collector",
"Status of a database collector (0: bad, 1: good)",
labels=["collector"])
labels=["collector"],
)
try:
with db.engine.connect() as conn:
result = conn.execute(text("""
result = conn.execute(
text(
"""
SELECT origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name,
proxy.deprecation_reason, COUNT(proxy.id) FROM proxy, origin, pool, "group"
WHERE proxy.origin_id = origin.id
@ -98,14 +116,26 @@ class BlockedProxiesCollector(Collector):
AND proxy.deprecated IS NOT NULL
GROUP BY origin.group_id, "group".group_name, proxy.provider, proxy.pool_id, pool.pool_name,
proxy.deprecation_reason;
"""))
c = CounterMetricFamily("deprecated_proxies",
"""
)
)
c = CounterMetricFamily(
"deprecated_proxies",
"Number of proxies deprecated",
labels=['group_id', 'group_name', 'provider', 'pool_id', 'pool_name',
'deprecation_reason'])
labels=[
"group_id",
"group_name",
"provider",
"pool_id",
"pool_name",
"deprecation_reason",
],
)
for row in result:
c.add_metric([str(row[0]), row[1], row[2], str(row[3]), row[4], row[5]],
row[6])
c.add_metric(
[str(row[0]), row[1], row[2], str(row[3]), row[4], row[5]],
row[6],
)
yield c
ok.add_metric(["deprecated_proxies"], 0)
except SQLAlchemyError:
@ -116,24 +146,36 @@ class BlockedProxiesCollector(Collector):
class AutomationCollector(Collector):
def collect(self) -> Iterator[Metric]:
with app.app_context():
ok = GaugeMetricFamily("database_collector",
ok = GaugeMetricFamily(
"database_collector",
"Status of a database collector (0: bad, 1: good)",
labels=["collector"])
labels=["collector"],
)
try:
state = GaugeMetricFamily("automation_state", "The automation state (0: idle, 1: running, 2: error)",
labels=['automation_name'])
enabled = GaugeMetricFamily("automation_enabled",
state = GaugeMetricFamily(
"automation_state",
"The automation state (0: idle, 1: running, 2: error)",
labels=["automation_name"],
)
enabled = GaugeMetricFamily(
"automation_enabled",
"Whether an automation is enabled (0: disabled, 1: enabled)",
labels=['automation_name'])
next_run = GaugeMetricFamily("automation_next_run", "The timestamp of the next run of the automation",
labels=['automation_name'])
last_run_start = GaugeMetricFamily("automation_last_run_start",
labels=["automation_name"],
)
next_run = GaugeMetricFamily(
"automation_next_run",
"The timestamp of the next run of the automation",
labels=["automation_name"],
)
last_run_start = GaugeMetricFamily(
"automation_last_run_start",
"The timestamp of the last run of the automation ",
labels=['automation_name'])
labels=["automation_name"],
)
automations = Automation.query.all()
for automation in automations:
if automation.short_name in app.config['HIDDEN_AUTOMATIONS']:
if automation.short_name in app.config["HIDDEN_AUTOMATIONS"]:
continue
if automation.state == AutomationState.IDLE:
state.add_metric([automation.short_name], 0)
@ -141,13 +183,19 @@ class AutomationCollector(Collector):
state.add_metric([automation.short_name], 1)
else:
state.add_metric([automation.short_name], 2)
enabled.add_metric([automation.short_name], 1 if automation.enabled else 0)
enabled.add_metric(
[automation.short_name], 1 if automation.enabled else 0
)
if automation.next_run:
next_run.add_metric([automation.short_name], automation.next_run.timestamp())
next_run.add_metric(
[automation.short_name], automation.next_run.timestamp()
)
else:
next_run.add_metric([automation.short_name], 0)
if automation.last_run:
last_run_start.add_metric([automation.short_name], automation.last_run.timestamp())
last_run_start.add_metric(
[automation.short_name], automation.last_run.timestamp()
)
else:
last_run_start.add_metric([automation.short_name], 0)
yield state
@ -161,31 +209,31 @@ class AutomationCollector(Collector):
# register all custom collectors to registry
if not_migrating() and 'DISABLE_METRICS' not in os.environ:
if not_migrating() and "DISABLE_METRICS" not in os.environ:
registry.register(DefinedProxiesCollector())
registry.register(BlockedProxiesCollector())
registry.register(AutomationCollector())
@app.route('/ui')
@app.route("/ui")
def redirect_ui() -> ResponseReturnValue:
return redirect("/ui/")
@app.route('/ui/', defaults={'path': ''})
@app.route('/ui/<path:path>')
@app.route("/ui/", defaults={"path": ""})
@app.route("/ui/<path:path>")
def serve_ui(path: str) -> ResponseReturnValue:
if path != "" and os.path.exists("app/static/ui/" + path):
return send_from_directory('static/ui', path)
return send_from_directory("static/ui", path)
else:
return send_from_directory('static/ui', 'index.html')
return send_from_directory("static/ui", "index.html")
@app.route('/')
@app.route("/")
def index() -> ResponseReturnValue:
# TODO: update to point at new UI when ready
return redirect(url_for("portal.portal_home"))
if __name__ == '__main__':
if __name__ == "__main__":
app.run()

View file

@ -7,18 +7,15 @@ from app.models.alarms import Alarm
def alarms_for(target: BRN) -> List[Alarm]:
return list(Alarm.query.filter(
Alarm.target == str(target)
).all())
return list(Alarm.query.filter(Alarm.target == str(target)).all())
def _get_alarm(target: BRN,
aspect: str,
create_if_missing: bool = True) -> Optional[Alarm]:
def _get_alarm(
target: BRN, aspect: str, create_if_missing: bool = True
) -> Optional[Alarm]:
target_str = str(target)
alarm: Optional[Alarm] = Alarm.query.filter(
Alarm.aspect == aspect,
Alarm.target == target_str
Alarm.aspect == aspect, Alarm.target == target_str
).first()
if create_if_missing and alarm is None:
alarm = Alarm()

View file

@ -5,34 +5,38 @@ from werkzeug.exceptions import HTTPException
from app.api.onion import api_onion
from app.api.web import api_web
api = Blueprint('api', __name__)
api.register_blueprint(api_onion, url_prefix='/onion')
api.register_blueprint(api_web, url_prefix='/web')
api = Blueprint("api", __name__)
api.register_blueprint(api_onion, url_prefix="/onion")
api.register_blueprint(api_web, url_prefix="/web")
@api.errorhandler(400)
def bad_request(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Bad Request', 'message': error.description})
response = jsonify({"error": "Bad Request", "message": error.description})
response.status_code = 400
return response
@api.errorhandler(401)
def unauthorized(error: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Unauthorized', 'message': error.description})
response = jsonify({"error": "Unauthorized", "message": error.description})
response.status_code = 401
return response
@api.errorhandler(404)
def not_found(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Not found', 'message': 'Resource could not be found.'})
response = jsonify(
{"error": "Not found", "message": "Resource could not be found."}
)
response.status_code = 404
return response
@api.errorhandler(500)
def internal_server_error(_: HTTPException) -> ResponseReturnValue:
response = jsonify({'error': 'Internal Server Error', 'message': 'An unexpected error occurred.'})
response = jsonify(
{"error": "Internal Server Error", "message": "An unexpected error occurred."}
)
response.status_code = 500
return response

View file

@ -7,31 +7,37 @@ from flask import Blueprint, abort, jsonify, request
from flask.typing import ResponseReturnValue
from sqlalchemy import exc
from app.api.util import (DOMAIN_NAME_REGEX, MAX_ALLOWED_ITEMS,
MAX_DOMAIN_NAME_LENGTH, ListFilter,
get_single_resource, list_resources,
validate_description)
from app.api.util import (
DOMAIN_NAME_REGEX,
MAX_ALLOWED_ITEMS,
MAX_DOMAIN_NAME_LENGTH,
ListFilter,
get_single_resource,
list_resources,
validate_description,
)
from app.extensions import db
from app.models.base import Group
from app.models.onions import Onion
from app.util.onion import decode_onion_keys, onion_hostname
from app.util.x509 import validate_tls_keys
api_onion = Blueprint('api_onion', __name__)
api_onion = Blueprint("api_onion", __name__)
@api_onion.route('/onion', methods=['GET'])
@api_onion.route("/onion", methods=["GET"])
def list_onions() -> ResponseReturnValue:
domain_name_filter = request.args.get('DomainName')
group_id_filter = request.args.get('GroupId')
domain_name_filter = request.args.get("DomainName")
group_id_filter = request.args.get("GroupId")
filters: List[ListFilter] = [
(Onion.destroyed.is_(None))
]
filters: List[ListFilter] = [(Onion.destroyed.is_(None))]
if domain_name_filter:
if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.")
abort(
400,
description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.",
)
if not DOMAIN_NAME_REGEX.match(domain_name_filter):
abort(400, description="DomainName contains invalid characters.")
filters.append(Onion.domain_name.ilike(f"%{domain_name_filter}%"))
@ -46,9 +52,9 @@ def list_onions() -> ResponseReturnValue:
Onion,
lambda onion: onion.to_dict(),
filters=filters,
resource_name='OnionsList',
resource_name="OnionsList",
max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber',
protective_marking="amber",
)
@ -71,13 +77,26 @@ def create_onion() -> ResponseReturnValue:
abort(400)
errors = []
for field in ["DomainName", "Description", "OnionPrivateKey", "OnionPublicKey", "GroupId", "TlsPrivateKey",
"TlsCertificate"]:
for field in [
"DomainName",
"Description",
"OnionPrivateKey",
"OnionPublicKey",
"GroupId",
"TlsPrivateKey",
"TlsCertificate",
]:
if not data.get(field):
errors.append({"Error": f"{field}_missing", "Message": f"Missing required field: {field}"})
errors.append(
{
"Error": f"{field}_missing",
"Message": f"Missing required field: {field}",
}
)
onion_private_key, onion_public_key, onion_errors = decode_onion_keys(data["OnionPrivateKey"],
data["OnionPublicKey"])
onion_private_key, onion_public_key, onion_errors = decode_onion_keys(
data["OnionPrivateKey"], data["OnionPublicKey"]
)
if onion_errors:
errors.extend(onion_errors)
@ -85,23 +104,35 @@ def create_onion() -> ResponseReturnValue:
return jsonify({"Errors": errors}), 400
if onion_private_key:
existing_onion = db.session.query(Onion).where(
existing_onion = (
db.session.query(Onion)
.where(
Onion.onion_private_key == onion_private_key,
Onion.destroyed.is_(None),
).first()
)
.first()
)
if existing_onion:
errors.append(
{"Error": "duplicate_onion_key", "Message": "An onion service with this private key already exists."})
{
"Error": "duplicate_onion_key",
"Message": "An onion service with this private key already exists.",
}
)
if "GroupId" in data:
group = Group.query.get(data["GroupId"])
if not group:
errors.append({"Error": "group_id_not_found", "Message": "Invalid group ID."})
errors.append(
{"Error": "group_id_not_found", "Message": "Invalid group ID."}
)
chain, san_list, tls_errors = validate_tls_keys(
data["TlsPrivateKey"], data["TlsCertificate"], data.get("SkipChainVerification"),
data["TlsPrivateKey"],
data["TlsCertificate"],
data.get("SkipChainVerification"),
data.get("SkipNameVerification"),
f"{onion_hostname(onion_public_key)}.onion"
f"{onion_hostname(onion_public_key)}.onion",
)
if tls_errors:
@ -123,15 +154,21 @@ def create_onion() -> ResponseReturnValue:
added=datetime.now(timezone.utc),
updated=datetime.now(timezone.utc),
cert_expiry=cert_expiry_date,
cert_sans=",".join(san_list)
cert_sans=",".join(san_list),
)
try:
db.session.add(onion)
db.session.commit()
return jsonify({"Message": "Onion service created successfully.", "Id": onion.id}), 201
return (
jsonify({"Message": "Onion service created successfully.", "Id": onion.id}),
201,
)
except exc.SQLAlchemyError as e:
return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500
return (
jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}),
500,
)
class UpdateOnionRequest(TypedDict):
@ -152,8 +189,19 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
onion = Onion.query.get(onion_id)
if not onion:
return jsonify(
{"Errors": [{"Error": "onion_not_found", "Message": f"No Onion service found with ID {onion_id}"}]}), 404
return (
jsonify(
{
"Errors": [
{
"Error": "onion_not_found",
"Message": f"No Onion service found with ID {onion_id}",
}
]
}
),
404,
)
if "Description" in data:
description = data["Description"]
@ -161,7 +209,12 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
if validate_description(description):
onion.description = description
else:
errors.append({"Error": "description_error", "Message": "Description field is invalid"})
errors.append(
{
"Error": "description_error",
"Message": "Description field is invalid",
}
)
tls_private_key_pem: Optional[str] = None
tls_certificate_pem: Optional[str] = None
@ -176,7 +229,9 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
tls_private_key_pem = onion.tls_private_key.decode("utf-8")
chain, san_list, tls_errors = validate_tls_keys(
tls_private_key_pem, tls_certificate_pem, data.get("SkipChainVerification", False),
tls_private_key_pem,
tls_certificate_pem,
data.get("SkipChainVerification", False),
data.get("SkipNameVerification", False),
f"{onion_hostname(onion.onion_public_key)}.onion",
)
@ -200,7 +255,10 @@ def update_onion(onion_id: int) -> ResponseReturnValue:
db.session.commit()
return jsonify({"Message": "Onion service updated successfully."}), 200
except exc.SQLAlchemyError as e:
return jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}), 500
return (
jsonify({"Errors": [{"Error": "database_error", "Message": str(e)}]}),
500,
)
@api_onion.route("/onion/<int:onion_id>", methods=["GET"])

View file

@ -12,7 +12,7 @@ from app.extensions import db
logger = logging.getLogger(__name__)
MAX_DOMAIN_NAME_LENGTH = 255
DOMAIN_NAME_REGEX = re.compile(r'^[a-zA-Z0-9.\-]*$')
DOMAIN_NAME_REGEX = re.compile(r"^[a-zA-Z0-9.\-]*$")
MAX_ALLOWED_ITEMS = 100
ListFilter = Union[BinaryExpression[Any], ColumnElement[Any]]
@ -24,7 +24,10 @@ def validate_max_items(max_items_str: str, max_allowed: int) -> int:
raise ValueError()
return max_items
except ValueError:
abort(400, description=f"MaxItems must be a positive integer not exceeding {max_allowed}.")
abort(
400,
description=f"MaxItems must be a positive integer not exceeding {max_allowed}.",
)
def validate_marker(marker_str: str) -> int:
@ -52,16 +55,17 @@ def list_resources( # pylint: disable=too-many-arguments,too-many-locals
*,
filters: Optional[List[ListFilter]] = None,
order_by: Optional[ColumnElement[Any]] = None,
resource_name: str = 'ResourceList',
max_items_param: str = 'MaxItems',
marker_param: str = 'Marker',
resource_name: str = "ResourceList",
max_items_param: str = "MaxItems",
marker_param: str = "Marker",
max_allowed_items: int = 100,
protective_marking: TlpMarkings = 'default',
protective_marking: TlpMarkings = "default",
) -> ResponseReturnValue:
try:
marker = request.args.get(marker_param)
max_items = validate_max_items(
request.args.get(max_items_param, default='100'), max_allowed_items)
request.args.get(max_items_param, default="100"), max_allowed_items
)
query = select(model)
if filters:
@ -101,14 +105,21 @@ def list_resources( # pylint: disable=too-many-arguments,too-many-locals
abort(500)
def get_single_resource(model: Type[Any], id_: int, resource_name: str) -> ResponseReturnValue:
def get_single_resource(
model: Type[Any], id_: int, resource_name: str
) -> ResponseReturnValue:
try:
resource = db.session.get(model, id_)
if not resource:
return jsonify({
return (
jsonify(
{
"Error": "resource_not_found",
"Message": f"No {resource_name} found with ID {id_}"
}), 404
"Message": f"No {resource_name} found with ID {id_}",
}
),
404,
)
return jsonify({resource_name: resource.to_dict()}), 200
except Exception: # pylint: disable=broad-exception-caught
logger.exception("An unexpected error occurred while retrieving the onion")

View file

@ -4,35 +4,43 @@ from typing import List
from flask import Blueprint, abort, request
from flask.typing import ResponseReturnValue
from app.api.util import (DOMAIN_NAME_REGEX, MAX_ALLOWED_ITEMS,
MAX_DOMAIN_NAME_LENGTH, ListFilter, list_resources)
from app.api.util import (
DOMAIN_NAME_REGEX,
MAX_ALLOWED_ITEMS,
MAX_DOMAIN_NAME_LENGTH,
ListFilter,
list_resources,
)
from app.models.base import Group
from app.models.mirrors import Origin, Proxy
api_web = Blueprint('web', __name__)
api_web = Blueprint("web", __name__)
@api_web.route('/group', methods=['GET'])
@api_web.route("/group", methods=["GET"])
def list_groups() -> ResponseReturnValue:
return list_resources(
Group,
lambda group: group.to_dict(),
resource_name='OriginGroupList',
resource_name="OriginGroupList",
max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber',
protective_marking="amber",
)
@api_web.route('/origin', methods=['GET'])
@api_web.route("/origin", methods=["GET"])
def list_origins() -> ResponseReturnValue:
domain_name_filter = request.args.get('DomainName')
group_id_filter = request.args.get('GroupId')
domain_name_filter = request.args.get("DomainName")
group_id_filter = request.args.get("GroupId")
filters: List[ListFilter] = []
if domain_name_filter:
if len(domain_name_filter) > MAX_DOMAIN_NAME_LENGTH:
abort(400, description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.")
abort(
400,
description=f"DomainName cannot exceed {MAX_DOMAIN_NAME_LENGTH} characters.",
)
if not DOMAIN_NAME_REGEX.match(domain_name_filter):
abort(400, description="DomainName contains invalid characters.")
filters.append(Origin.domain_name.ilike(f"%{domain_name_filter}%"))
@ -47,18 +55,18 @@ def list_origins() -> ResponseReturnValue:
Origin,
lambda origin: origin.to_dict(),
filters=filters,
resource_name='OriginsList',
resource_name="OriginsList",
max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber',
protective_marking="amber",
)
@api_web.route('/mirror', methods=['GET'])
@api_web.route("/mirror", methods=["GET"])
def list_mirrors() -> ResponseReturnValue:
filters = []
twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
status_filter = request.args.get('Status')
status_filter = request.args.get("Status")
if status_filter:
if status_filter == "pending":
filters.append(Proxy.url.is_(None))
@ -74,13 +82,15 @@ def list_mirrors() -> ResponseReturnValue:
if status_filter == "destroyed":
filters.append(Proxy.destroyed > twenty_four_hours_ago)
else:
filters.append((Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago))
filters.append(
(Proxy.destroyed.is_(None)) | (Proxy.destroyed > twenty_four_hours_ago)
)
return list_resources(
Proxy,
lambda proxy: proxy.to_dict(),
filters=filters,
resource_name='MirrorsList',
resource_name="MirrorsList",
max_allowed_items=MAX_ALLOWED_ITEMS,
protective_marking='amber',
protective_marking="amber",
)

View file

@ -29,31 +29,37 @@ class BRN:
def from_str(cls, string: str) -> BRN:
parts = string.split(":")
if len(parts) != 6 or parts[0].lower() != "brn" or not is_integer(parts[2]):
raise TypeError(f"Expected a valid BRN but got {repr(string)} (invalid parts).")
raise TypeError(
f"Expected a valid BRN but got {repr(string)} (invalid parts)."
)
resource_parts = parts[5].split("/")
if len(resource_parts) != 2:
raise TypeError(f"Expected a valid BRN but got {repr(string)} (invalid resource parts).")
raise TypeError(
f"Expected a valid BRN but got {repr(string)} (invalid resource parts)."
)
return cls(
global_namespace=parts[1],
group_id=int(parts[2]),
product=parts[3],
provider=parts[4],
resource_type=resource_parts[0],
resource_id=resource_parts[1]
resource_id=resource_parts[1],
)
def __eq__(self, other: Any) -> bool:
return str(self) == str(other)
def __str__(self) -> str:
return ":".join([
return ":".join(
[
"brn",
self.global_namespace,
str(self.group_id),
self.product,
self.provider,
f"{self.resource_type}/{self.resource_id}"
])
f"{self.resource_type}/{self.resource_id}",
]
)
def __repr__(self) -> str:
return f"<BRN {str(self)}>"

View file

@ -47,14 +47,18 @@ def create_static_origin(
else:
raise ValueError("group_id must be an int")
if isinstance(storage_cloud_account_id, int):
cloud_account = CloudAccount.query.filter(CloudAccount.id == storage_cloud_account_id).first()
cloud_account = CloudAccount.query.filter(
CloudAccount.id == storage_cloud_account_id
).first()
if cloud_account is None:
raise ValueError("storage_cloud_account_id must match an existing provider")
static_origin.storage_cloud_account_id = storage_cloud_account_id
else:
raise ValueError("storage_cloud_account_id must be an int")
if isinstance(source_cloud_account_id, int):
cloud_account = CloudAccount.query.filter(CloudAccount.id == source_cloud_account_id).first()
cloud_account = CloudAccount.query.filter(
CloudAccount.id == source_cloud_account_id
).first()
if cloud_account is None:
raise ValueError("source_cloud_account_id must match an existing provider")
static_origin.source_cloud_account_id = source_cloud_account_id
@ -69,7 +73,7 @@ def create_static_origin(
keanu_convene_logo,
keanu_convene_color,
clean_insights_backend,
False
False,
)
if db_session_commit:
db.session.add(static_origin)

View file

@ -26,7 +26,9 @@ def is_integer(contender: Any) -> bool:
return float(contender).is_integer()
def thumbnail_uploaded_image(file: FileStorage, max_size: Tuple[int, int] = (256, 256)) -> bytes:
def thumbnail_uploaded_image(
file: FileStorage, max_size: Tuple[int, int] = (256, 256)
) -> bytes:
"""
Process an uploaded image file into a resized image of a specific size.
@ -39,7 +41,9 @@ def thumbnail_uploaded_image(file: FileStorage, max_size: Tuple[int, int] = (256
img = Image.open(file)
img.thumbnail(max_size)
byte_arr = BytesIO()
img.save(byte_arr, format='PNG' if file.filename.lower().endswith('.png') else 'JPEG')
img.save(
byte_arr, format="PNG" if file.filename.lower().endswith(".png") else "JPEG"
)
return byte_arr.getvalue()
@ -52,9 +56,11 @@ def create_data_uri(bytes_data: bytes, file_extension: str) -> str:
:return: A data URI representing the image.
"""
# base64 encode
encoded = base64.b64encode(bytes_data).decode('ascii')
encoded = base64.b64encode(bytes_data).decode("ascii")
# create data URI
data_uri = "data:image/{};base64,{}".format('jpeg' if file_extension == 'jpg' else file_extension, encoded)
data_uri = "data:image/{};base64,{}".format(
"jpeg" if file_extension == "jpg" else file_extension, encoded
)
return data_uri
@ -80,7 +86,7 @@ def normalize_color(color: str) -> str:
return webcolors.name_to_hex(color) # type: ignore[no-any-return]
except ValueError:
pass
if color.startswith('#'):
if color.startswith("#"):
color = color[1:].lower()
if len(color) in [3, 6]:
try:

View file

@ -3,7 +3,9 @@ from abc import abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
_SubparserType = argparse._SubParsersAction[argparse.ArgumentParser] # pylint: disable=protected-access
_SubparserType = argparse._SubParsersAction[
argparse.ArgumentParser
] # pylint: disable=protected-access
else:
_SubparserType = Any

View file

@ -13,7 +13,9 @@ def parse_args(argv: List[str]) -> None:
if basename(argv[0]) == "__main__.py":
argv[0] = "bypass"
parser = argparse.ArgumentParser()
parser.add_argument("-v", "--verbose", help="increase logging verbosity", action="store_true")
parser.add_argument(
"-v", "--verbose", help="increase logging verbosity", action="store_true"
)
subparsers = parser.add_subparsers(title="command", help="command to run")
AutomateCliHandler.add_subparser_to(subparsers)
DbCliHandler.add_subparser_to(subparsers)
@ -28,7 +30,6 @@ def parse_args(argv: List[str]) -> None:
if __name__ == "__main__":
VERBOSE = "-v" in sys.argv or "--verbose" in sys.argv
logging.basicConfig(
level=logging.DEBUG if VERBOSE else logging.INFO)
logging.basicConfig(level=logging.DEBUG if VERBOSE else logging.INFO)
logging.debug("Arguments: %s", sys.argv)
parse_args(sys.argv)

View file

@ -14,18 +14,14 @@ from app.models.automation import Automation, AutomationLogs, AutomationState
from app.terraform import BaseAutomation
from app.terraform.alarms.eotk_aws import AlarmEotkAwsAutomation
from app.terraform.alarms.proxy_azure_cdn import AlarmProxyAzureCdnAutomation
from app.terraform.alarms.proxy_cloudfront import \
AlarmProxyCloudfrontAutomation
from app.terraform.alarms.proxy_http_status import \
AlarmProxyHTTPStatusAutomation
from app.terraform.alarms.proxy_cloudfront import AlarmProxyCloudfrontAutomation
from app.terraform.alarms.proxy_http_status import AlarmProxyHTTPStatusAutomation
from app.terraform.alarms.smart_aws import AlarmSmartAwsAutomation
from app.terraform.block.block_blocky import BlockBlockyAutomation
from app.terraform.block.block_scriptzteam import \
BlockBridgeScriptzteamAutomation
from app.terraform.block.block_scriptzteam import BlockBridgeScriptzteamAutomation
from app.terraform.block.bridge_github import BlockBridgeGitHubAutomation
from app.terraform.block.bridge_gitlab import BlockBridgeGitlabAutomation
from app.terraform.block.bridge_roskomsvoboda import \
BlockBridgeRoskomsvobodaAutomation
from app.terraform.block.bridge_roskomsvoboda import BlockBridgeRoskomsvobodaAutomation
from app.terraform.block_external import BlockExternalAutomation
from app.terraform.block_ooni import BlockOONIAutomation
from app.terraform.block_roskomsvoboda import BlockRoskomsvobodaAutomation
@ -58,12 +54,10 @@ jobs = {
BlockExternalAutomation,
BlockOONIAutomation,
BlockRoskomsvobodaAutomation,
# Create new resources
BridgeMetaAutomation,
StaticMetaAutomation,
ProxyMetaAutomation,
# Terraform
BridgeAWSAutomation,
BridgeGandiAutomation,
@ -74,14 +68,12 @@ jobs = {
ProxyAzureCdnAutomation,
ProxyCloudfrontAutomation,
ProxyFastlyAutomation,
# Import alarms
AlarmEotkAwsAutomation,
AlarmProxyAzureCdnAutomation,
AlarmProxyCloudfrontAutomation,
AlarmProxyHTTPStatusAutomation,
AlarmSmartAwsAutomation,
# Update lists
ListGithubAutomation,
ListGitlabAutomation,
@ -103,9 +95,12 @@ def run_all(**kwargs: bool) -> None:
run_job(job, **kwargs)
def run_job(job_cls: Type[BaseAutomation], *,
force: bool = False, ignore_schedule: bool = False) -> None:
automation = Automation.query.filter(Automation.short_name == job_cls.short_name).first()
def run_job(
job_cls: Type[BaseAutomation], *, force: bool = False, ignore_schedule: bool = False
) -> None:
automation = Automation.query.filter(
Automation.short_name == job_cls.short_name
).first()
if automation is None:
automation = Automation()
automation.short_name = job_cls.short_name
@ -121,18 +116,24 @@ def run_job(job_cls: Type[BaseAutomation], *,
logging.warning("Not running an already running automation")
return
if not ignore_schedule and not force:
if automation.next_run is not None and automation.next_run > datetime.now(tz=timezone.utc):
if automation.next_run is not None and automation.next_run > datetime.now(
tz=timezone.utc
):
logging.warning("Not time to run this job yet")
return
if not automation.enabled and not force:
logging.warning("job %s is disabled and --force not specified", job_cls.short_name)
logging.warning(
"job %s is disabled and --force not specified", job_cls.short_name
)
return
automation.state = AutomationState.RUNNING
db.session.commit()
try:
if 'TERRAFORM_DIRECTORY' in app.config:
working_dir = os.path.join(app.config['TERRAFORM_DIRECTORY'],
job_cls.short_name or job_cls.__class__.__name__.lower())
if "TERRAFORM_DIRECTORY" in app.config:
working_dir = os.path.join(
app.config["TERRAFORM_DIRECTORY"],
job_cls.short_name or job_cls.__class__.__name__.lower(),
)
else:
working_dir = tempfile.mkdtemp()
job: BaseAutomation = job_cls(working_dir)
@ -150,8 +151,9 @@ def run_job(job_cls: Type[BaseAutomation], *,
if job is not None and success:
automation.state = AutomationState.IDLE
automation.next_run = datetime.now(tz=timezone.utc) + timedelta(
minutes=getattr(job, "frequency", 7))
if 'TERRAFORM_DIRECTORY' not in app.config and working_dir is not None:
minutes=getattr(job, "frequency", 7)
)
if "TERRAFORM_DIRECTORY" not in app.config and working_dir is not None:
# We used a temporary working directory
shutil.rmtree(working_dir)
else:
@ -165,7 +167,7 @@ def run_job(job_cls: Type[BaseAutomation], *,
"list_gitlab",
"block_blocky",
"block_external",
"block_ooni"
"block_ooni",
]
if job.short_name not in safe_jobs:
automation.enabled = False
@ -179,10 +181,12 @@ def run_job(job_cls: Type[BaseAutomation], *,
db.session.commit()
activity = Activity(
activity_type="automation",
text=(f"[{automation.short_name}] 🚨 Automation failure: It was not possible to handle this failure safely "
text=(
f"[{automation.short_name}] 🚨 Automation failure: It was not possible to handle this failure safely "
"and so the automation task has been automatically disabled. It may be possible to simply re-enable "
"the task, but repeated failures will usually require deeper investigation. See logs for full "
"details.")
"details."
),
)
db.session.add(activity)
activity.notify() # Notify before commit because the failure occurred even if we can't commit.
@ -194,20 +198,43 @@ class AutomateCliHandler(BaseCliHandler):
@classmethod
def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("automate", help="automation operations")
parser.add_argument("-a", "--all", dest="all", help="run all automation jobs", action="store_true")
parser.add_argument("-j", "--job", dest="job", choices=sorted(jobs.keys()),
help="run a specific automation job")
parser.add_argument("--force", help="run job even if disabled and it's not time yet", action="store_true")
parser.add_argument("--ignore-schedule", help="run job even if it's not time yet", action="store_true")
parser.add_argument(
"-a",
"--all",
dest="all",
help="run all automation jobs",
action="store_true",
)
parser.add_argument(
"-j",
"--job",
dest="job",
choices=sorted(jobs.keys()),
help="run a specific automation job",
)
parser.add_argument(
"--force",
help="run job even if disabled and it's not time yet",
action="store_true",
)
parser.add_argument(
"--ignore-schedule",
help="run job even if it's not time yet",
action="store_true",
)
parser.set_defaults(cls=cls)
def run(self) -> None:
with app.app_context():
if self.args.job:
run_job(jobs[self.args.job],
run_job(
jobs[self.args.job],
force=self.args.force,
ignore_schedule=self.args.ignore_schedule)
ignore_schedule=self.args.ignore_schedule,
)
elif self.args.all:
run_all(force=self.args.force, ignore_schedule=self.args.ignore_schedule)
run_all(
force=self.args.force, ignore_schedule=self.args.ignore_schedule
)
else:
logging.error("No action requested")

View file

@ -40,7 +40,7 @@ models: List[Model] = [
Eotk,
MirrorList,
TerraformState,
Webhook
Webhook,
]
@ -53,7 +53,7 @@ class ExportEncoder(json.JSONEncoder):
if isinstance(o, AutomationState):
return o.name
if isinstance(o, bytes):
return base64.encodebytes(o).decode('utf-8')
return base64.encodebytes(o).decode("utf-8")
if isinstance(o, (datetime.datetime, datetime.date, datetime.time)):
return o.isoformat()
return super().default(o)
@ -82,7 +82,7 @@ def db_export() -> None:
decoder: Dict[str, Callable[[Any], Any]] = {
"AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x),
"AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x),
"bytes": lambda x: base64.decodebytes(x.encode('utf-8')),
"bytes": lambda x: base64.decodebytes(x.encode("utf-8")),
"datetime": datetime.datetime.fromisoformat,
"int": int,
"str": lambda x: x,
@ -110,8 +110,12 @@ class DbCliHandler(BaseCliHandler):
@classmethod
def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("db", help="database operations")
parser.add_argument("--export", help="export data to JSON format", action="store_true")
parser.add_argument("--import", help="import data from JSON format", action="store_true")
parser.add_argument(
"--export", help="export data to JSON format", action="store_true"
)
parser.add_argument(
"--import", help="import data from JSON format", action="store_true"
)
parser.set_defaults(cls=cls)
def run(self) -> None:

View file

@ -17,8 +17,9 @@ class ListCliHandler(BaseCliHandler):
@classmethod
def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("list", help="list operations")
parser.add_argument("--dump", choices=sorted(lists.keys()),
help="dump a list in JSON format")
parser.add_argument(
"--dump", choices=sorted(lists.keys()), help="dump a list in JSON format"
)
parser.set_defaults(cls=cls)
def run(self) -> None:

View file

@ -4,11 +4,11 @@ from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
convention = {
"ix": 'ix_%(column_0_label)s',
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s"
"pk": "pk_%(table_name)s",
}
metadata = MetaData(naming_convention=convention)

View file

@ -26,13 +26,15 @@ def onion_alternative(origin: Origin) -> List[BC2Alternative]:
url: Optional[str] = origin.onion()
if url is None:
return []
return [{
return [
{
"proto": "tor",
"type": "eotk",
"created_at": str(origin.added),
"updated_at": str(origin.updated),
"url": url
}]
"url": url,
}
]
def proxy_alternative(proxy: Proxy) -> Optional[BC2Alternative]:
@ -43,43 +45,51 @@ def proxy_alternative(proxy: Proxy) -> Optional[BC2Alternative]:
"type": "mirror",
"created_at": proxy.added.isoformat(),
"updated_at": proxy.updated.isoformat(),
"url": proxy.url
"url": proxy.url,
}
def main_domain(origin: Origin) -> str:
description: str = origin.description
if description.startswith("proxy:"):
return description[len("proxy:"):].replace("www.", "")
return description[len("proxy:") :].replace("www.", "")
domain_name: str = origin.domain_name
return domain_name.replace("www.", "")
def active_proxies(origin: Origin, pool: Pool) -> List[Proxy]:
return [
proxy for proxy in origin.proxies
if proxy.url is not None and not proxy.deprecated and not proxy.destroyed and proxy.pool_id == pool.id
proxy
for proxy in origin.proxies
if proxy.url is not None
and not proxy.deprecated
and not proxy.destroyed
and proxy.pool_id == pool.id
]
def mirror_sites(pool: Pool) -> BypassCensorship2:
origins = Origin.query.filter(Origin.destroyed.is_(None)).order_by(Origin.domain_name).all()
origins = (
Origin.query.filter(Origin.destroyed.is_(None))
.order_by(Origin.domain_name)
.all()
)
sites: List[BC2Site] = []
for origin in origins:
# Gather alternatives, filtering out None values from proxy_alternative
alternatives = onion_alternative(origin) + [
alt for proxy in active_proxies(origin, pool)
alt
for proxy in active_proxies(origin, pool)
if (alt := proxy_alternative(proxy)) is not None
]
# Add the site dictionary to the list
sites.append({
sites.append(
{
"main_domain": main_domain(origin),
"available_alternatives": list(alternatives)
})
return {
"version": "2.0",
"sites": sites
"available_alternatives": list(alternatives),
}
)
return {"version": "2.0", "sites": sites}

View file

@ -11,12 +11,14 @@ class BridgelinesDict(TypedDict):
bridgelines: List[str]
def bridgelines(pool: Pool, *, distribution_method: Optional[str] = None) -> BridgelinesDict:
def bridgelines(
pool: Pool, *, distribution_method: Optional[str] = None
) -> BridgelinesDict:
# Fetch bridges with selectinload for related data
query = Bridge.query.options(selectinload(Bridge.conf)).filter(
Bridge.destroyed.is_(None),
Bridge.deprecated.is_(None),
Bridge.bridgeline.is_not(None)
Bridge.bridgeline.is_not(None),
)
if distribution_method is not None:
@ -26,7 +28,4 @@ def bridgelines(pool: Pool, *, distribution_method: Optional[str] = None) -> Bri
bridgelines = [b.bridgeline for b in query.all() if b.conf.pool_id == pool.id]
# Return dictionary directly, inlining the previous `to_dict` functionality
return {
"version": "1.0",
"bridgelines": bridgelines
}
return {"version": "1.0", "bridgelines": bridgelines}

View file

@ -48,7 +48,9 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping:
countries = proxy.origin.risk_level
if countries:
highest_risk_country_code, highest_risk_level = max(countries.items(), key=lambda x: x[1])
highest_risk_country_code, highest_risk_level = max(
countries.items(), key=lambda x: x[1]
)
else:
highest_risk_country_code = "ZZ"
highest_risk_level = 0
@ -61,7 +63,7 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping:
"valid_to": proxy.destroyed.isoformat() if proxy.destroyed else None,
"countries": countries,
"country": highest_risk_country_code,
"risk": highest_risk_level
"risk": highest_risk_level,
}
groups = db.session.query(Group).options(selectinload(Group.pools))
@ -70,8 +72,4 @@ def mirror_mapping(_: Optional[Pool]) -> MirrorMapping:
for g in groups.filter(Group.destroyed.is_(None)).all()
]
return {
"version": "1.2",
"mappings": result,
"s3_buckets": s3_buckets
}
return {"version": "1.2", "mappings": result, "s3_buckets": s3_buckets}

View file

@ -26,15 +26,17 @@ def redirector_pool_origins(pool: Pool) -> Dict[str, str]:
Proxy.deprecated.is_(None),
Proxy.destroyed.is_(None),
Proxy.url.is_not(None),
Proxy.pool_id == pool.id
Proxy.pool_id == pool.id,
)
}
def redirector_data(_: Optional[Pool]) -> RedirectorData:
active_pools = Pool.query.options(
selectinload(Pool.proxies)
).filter(Pool.destroyed.is_(None)).all()
active_pools = (
Pool.query.options(selectinload(Pool.proxies))
.filter(Pool.destroyed.is_(None))
.all()
)
pools: List[RedirectorPool] = [
{
@ -42,12 +44,9 @@ def redirector_data(_: Optional[Pool]) -> RedirectorData:
"description": pool.description,
"api_key": pool.api_key,
"redirector_domain": pool.redirector_domain,
"origins": redirector_pool_origins(pool)
"origins": redirector_pool_origins(pool),
}
for pool in active_pools
]
return {
"version": "1.0",
"pools": pools
}
return {"version": "1.0", "pools": pools}

View file

@ -17,7 +17,9 @@ class AbstractConfiguration(db.Model): # type: ignore
description: Mapped[str]
added: Mapped[datetime] = mapped_column(AwareDateTime())
updated: Mapped[datetime] = mapped_column(AwareDateTime())
destroyed: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True)
destroyed: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
@property
@abstractmethod
@ -30,14 +32,10 @@ class AbstractConfiguration(db.Model): # type: ignore
@classmethod
def csv_header(cls) -> List[str]:
return [
"id", "description", "added", "updated", "destroyed"
]
return ["id", "description", "added", "updated", "destroyed"]
def csv_row(self) -> List[Any]:
return [
getattr(self, x) for x in self.csv_header()
]
return [getattr(self, x) for x in self.csv_header()]
class Deprecation(db.Model): # type: ignore[name-defined,misc]
@ -51,7 +49,8 @@ class Deprecation(db.Model): # type: ignore[name-defined,misc]
@property
def resource(self) -> "AbstractResource":
from app.models.mirrors import Proxy # pylint: disable=R0401
model = {'Proxy': Proxy}[self.resource_type]
model = {"Proxy": Proxy}[self.resource_type]
return model.query.get(self.resource_id) # type: ignore[no-any-return]
@ -61,29 +60,38 @@ class AbstractResource(db.Model): # type: ignore
id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
added: Mapped[datetime] = mapped_column(AwareDateTime())
updated: Mapped[datetime] = mapped_column(AwareDateTime())
deprecated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True)
deprecated: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
deprecation_reason: Mapped[Optional[str]]
destroyed: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True)
destroyed: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
def __init__(self, *,
def __init__(
self,
*,
id: Optional[int] = None,
added: Optional[datetime] = None,
updated: Optional[datetime] = None,
deprecated: Optional[datetime] = None,
deprecation_reason: Optional[str] = None,
destroyed: Optional[datetime] = None,
**kwargs: Any) -> None:
**kwargs: Any
) -> None:
if added is None:
added = datetime.now(tz=timezone.utc)
if updated is None:
updated = datetime.now(tz=timezone.utc)
super().__init__(id=id,
super().__init__(
id=id,
added=added,
updated=updated,
deprecated=deprecated,
deprecation_reason=deprecation_reason,
destroyed=destroyed,
**kwargs)
**kwargs
)
@property
@abstractmethod
@ -110,19 +118,21 @@ class AbstractResource(db.Model): # type: ignore
resource_type=type(self).__name__,
resource_id=self.id,
reason=reason,
meta=meta
meta=meta,
)
db.session.add(new_deprecation)
return True
logging.info("Not deprecating %s (reason=%s) because it's already deprecated with that reason.",
self.brn, reason)
logging.info(
"Not deprecating %s (reason=%s) because it's already deprecated with that reason.",
self.brn,
reason,
)
return False
@property
def deprecations(self) -> List[Deprecation]:
return Deprecation.query.filter_by( # type: ignore[no-any-return]
resource_type='Proxy',
resource_id=self.id
resource_type="Proxy", resource_id=self.id
).all()
def destroy(self) -> None:
@ -139,10 +149,13 @@ class AbstractResource(db.Model): # type: ignore
@classmethod
def csv_header(cls) -> List[str]:
return [
"id", "added", "updated", "deprecated", "deprecation_reason", "destroyed"
"id",
"added",
"updated",
"deprecated",
"deprecation_reason",
"destroyed",
]
def csv_row(self) -> List[Union[datetime, bool, int, str]]:
return [
getattr(self, x) for x in self.csv_header()
]
return [getattr(self, x) for x in self.csv_header()]

View file

@ -17,31 +17,40 @@ class Activity(db.Model): # type: ignore
text: Mapped[str]
added: Mapped[datetime] = mapped_column(AwareDateTime())
def __init__(self, *,
def __init__(
self,
*,
id: Optional[int] = None,
group_id: Optional[int] = None,
activity_type: str,
text: str,
added: Optional[datetime] = None,
**kwargs: Any) -> None:
if not isinstance(activity_type, str) or len(activity_type) > 20 or activity_type == "":
raise TypeError("expected string for activity type between 1 and 20 characters")
**kwargs: Any
) -> None:
if (
not isinstance(activity_type, str)
or len(activity_type) > 20
or activity_type == ""
):
raise TypeError(
"expected string for activity type between 1 and 20 characters"
)
if not isinstance(text, str):
raise TypeError("expected string for text")
if added is None:
added = datetime.now(tz=timezone.utc)
super().__init__(id=id,
super().__init__(
id=id,
group_id=group_id,
activity_type=activity_type,
text=text,
added=added,
**kwargs)
**kwargs
)
def notify(self) -> int:
count = 0
hooks = Webhook.query.filter(
Webhook.destroyed.is_(None)
)
hooks = Webhook.query.filter(Webhook.destroyed.is_(None))
for hook in hooks:
hook.send(self.text)
count += 1
@ -59,7 +68,7 @@ class Webhook(AbstractConfiguration):
product="notify",
provider=self.format,
resource_type="conf",
resource_id=str(self.id)
resource_id=str(self.id),
)
def send(self, text: str) -> None:

View file

@ -37,7 +37,15 @@ class Alarm(db.Model): # type: ignore
@classmethod
def csv_header(cls) -> List[str]:
return ["id", "target", "alarm_type", "alarm_state", "state_changed", "last_updated", "text"]
return [
"id",
"target",
"alarm_type",
"alarm_state",
"state_changed",
"last_updated",
"text",
]
def csv_row(self) -> List[Any]:
return [getattr(self, x) for x in self.csv_header()]
@ -45,11 +53,15 @@ class Alarm(db.Model): # type: ignore
def update_state(self, state: AlarmState, text: str) -> None:
if self.alarm_state != state or self.state_changed is None:
self.state_changed = datetime.now(tz=timezone.utc)
activity = Activity(activity_type="alarm_state",
activity = Activity(
activity_type="alarm_state",
text=f"[{self.aspect}] {state.emoji} Alarm state changed from "
f"{self.alarm_state.name} to {state.name} on {self.target}: {text}.")
if (self.alarm_state.name in ["WARNING", "CRITICAL"]
or state.name in ["WARNING", "CRITICAL"]):
f"{self.alarm_state.name} to {state.name} on {self.target}: {text}.",
)
if self.alarm_state.name in ["WARNING", "CRITICAL"] or state.name in [
"WARNING",
"CRITICAL",
]:
# Notifications are only sent on recovery from warning/critical state or on entry
# to warning/critical states. This should reduce alert fatigue.
activity.notify()

View file

@ -33,7 +33,7 @@ class Automation(AbstractConfiguration):
product="core",
provider="",
resource_type="automation",
resource_id=self.short_name
resource_id=self.short_name,
)
def kick(self) -> None:
@ -55,5 +55,5 @@ class AutomationLogs(AbstractResource):
product="core",
provider="",
resource_type="automationlog",
resource_id=str(self.id)
resource_id=str(self.id),
)

View file

@ -26,17 +26,21 @@ class Group(AbstractConfiguration):
eotk: Mapped[bool]
origins: Mapped[List["Origin"]] = relationship("Origin", back_populates="group")
statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="group")
statics: Mapped[List["StaticOrigin"]] = relationship(
"StaticOrigin", back_populates="group"
)
eotks: Mapped[List["Eotk"]] = relationship("Eotk", back_populates="group")
onions: Mapped[List["Onion"]] = relationship("Onion", back_populates="group")
smart_proxies: Mapped[List["SmartProxy"]] = relationship("SmartProxy", back_populates="group")
pools: Mapped[List["Pool"]] = relationship("Pool", secondary="pool_group", back_populates="groups")
smart_proxies: Mapped[List["SmartProxy"]] = relationship(
"SmartProxy", back_populates="group"
)
pools: Mapped[List["Pool"]] = relationship(
"Pool", secondary="pool_group", back_populates="groups"
)
@classmethod
def csv_header(cls) -> List[str]:
return super().csv_header() + [
"group_name", "eotk"
]
return super().csv_header() + ["group_name", "eotk"]
@property
def brn(self) -> BRN:
@ -45,16 +49,15 @@ class Group(AbstractConfiguration):
product="group",
provider="",
resource_type="group",
resource_id=str(self.id)
resource_id=str(self.id),
)
def to_dict(self) -> GroupDict:
if not TYPE_CHECKING:
from app.models.mirrors import Origin # to prevent circular import
active_origins_query = (
db.session.query(aliased(Origin))
.filter(and_(Origin.group_id == self.id, Origin.destroyed.is_(None)))
active_origins_query = db.session.query(aliased(Origin)).filter(
and_(Origin.group_id == self.id, Origin.destroyed.is_(None))
)
active_origins_count = active_origins_query.count()
return {
@ -70,16 +73,20 @@ class Pool(AbstractConfiguration):
api_key: Mapped[str]
redirector_domain: Mapped[Optional[str]]
bridgeconfs: Mapped[List["BridgeConf"]] = relationship("BridgeConf", back_populates="pool")
bridgeconfs: Mapped[List["BridgeConf"]] = relationship(
"BridgeConf", back_populates="pool"
)
proxies: Mapped[List["Proxy"]] = relationship("Proxy", back_populates="pool")
lists: Mapped[List["MirrorList"]] = relationship("MirrorList", back_populates="pool")
groups: Mapped[List[Group]] = relationship("Group", secondary="pool_group", back_populates="pools")
lists: Mapped[List["MirrorList"]] = relationship(
"MirrorList", back_populates="pool"
)
groups: Mapped[List[Group]] = relationship(
"Group", secondary="pool_group", back_populates="pools"
)
@classmethod
def csv_header(cls) -> List[str]:
return super().csv_header() + [
"pool_name"
]
return super().csv_header() + ["pool_name"]
@property
def brn(self) -> BRN:
@ -88,7 +95,7 @@ class Pool(AbstractConfiguration):
product="pool",
provider="",
resource_type="pool",
resource_id=str(self.pool_name)
resource_id=str(self.pool_name),
)
@ -121,14 +128,14 @@ class MirrorList(AbstractConfiguration):
"bc3": "Bypass Censorship v3",
"bca": "Bypass Censorship Analytics",
"bridgelines": "Tor Bridge Lines",
"rdr": "Redirector Data"
"rdr": "Redirector Data",
}
encodings_supported = {
"json": "JSON (Plain)",
"jsno": "JSON (Obfuscated)",
"js": "JavaScript (Plain)",
"jso": "JavaScript (Obfuscated)"
"jso": "JavaScript (Obfuscated)",
}
def destroy(self) -> None:
@ -149,7 +156,11 @@ class MirrorList(AbstractConfiguration):
@classmethod
def csv_header(cls) -> List[str]:
return super().csv_header() + [
"provider", "format", "container", "branch", "filename"
"provider",
"format",
"container",
"branch",
"filename",
]
@property
@ -159,5 +170,5 @@ class MirrorList(AbstractConfiguration):
product="list",
provider=self.provider,
resource_type="list",
resource_id=str(self.id)
resource_id=str(self.id),
)

View file

@ -34,7 +34,7 @@ class BridgeConf(AbstractConfiguration):
product="bridge",
provider="",
resource_type="bridgeconf",
resource_id=str(self.id)
resource_id=str(self.id),
)
def destroy(self) -> None:
@ -48,14 +48,22 @@ class BridgeConf(AbstractConfiguration):
@classmethod
def csv_header(cls) -> List[str]:
return super().csv_header() + [
"pool_id", "provider", "method", "description", "target_number", "max_number", "expiry_hours"
"pool_id",
"provider",
"method",
"description",
"target_number",
"max_number",
"expiry_hours",
]
class Bridge(AbstractResource):
conf_id: Mapped[int] = mapped_column(db.ForeignKey("bridge_conf.id"))
cloud_account_id: Mapped[int] = mapped_column(db.ForeignKey("cloud_account.id"))
terraform_updated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True)
terraform_updated: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
nickname: Mapped[Optional[str]]
fingerprint: Mapped[Optional[str]]
hashed_fingerprint: Mapped[Optional[str]]
@ -71,11 +79,16 @@ class Bridge(AbstractResource):
product="bridge",
provider=self.cloud_account.provider.key,
resource_type="bridge",
resource_id=str(self.id)
resource_id=str(self.id),
)
@classmethod
def csv_header(cls) -> List[str]:
return super().csv_header() + [
"conf_id", "terraform_updated", "nickname", "fingerprint", "hashed_fingerprint", "bridgeline"
"conf_id",
"terraform_updated",
"nickname",
"fingerprint",
"hashed_fingerprint",
"bridgeline",
]

View file

@ -42,9 +42,14 @@ class CloudAccount(AbstractConfiguration):
# Compute Quotas
max_instances: Mapped[int]
bridges: Mapped[List["Bridge"]] = relationship("Bridge", back_populates="cloud_account")
statics: Mapped[List["StaticOrigin"]] = relationship("StaticOrigin", back_populates="storage_cloud_account", foreign_keys=[
StaticOrigin.storage_cloud_account_id])
bridges: Mapped[List["Bridge"]] = relationship(
"Bridge", back_populates="cloud_account"
)
statics: Mapped[List["StaticOrigin"]] = relationship(
"StaticOrigin",
back_populates="storage_cloud_account",
foreign_keys=[StaticOrigin.storage_cloud_account_id],
)
@property
def brn(self) -> BRN:

View file

@ -10,8 +10,7 @@ from tldextract import extract
from werkzeug.datastructures import FileStorage
from app.brm.brn import BRN
from app.brm.utils import (create_data_uri, normalize_color,
thumbnail_uploaded_image)
from app.brm.utils import create_data_uri, normalize_color, thumbnail_uploaded_image
from app.extensions import db
from app.models import AbstractConfiguration, AbstractResource, Deprecation
from app.models.base import Group, Pool
@ -19,10 +18,10 @@ from app.models.onions import Onion
from app.models.types import AwareDateTime
country_origin = db.Table(
'country_origin',
"country_origin",
db.metadata,
db.Column('country_id', db.ForeignKey('country.id'), primary_key=True),
db.Column('origin_id', db.ForeignKey('origin.id'), primary_key=True),
db.Column("country_id", db.ForeignKey("country.id"), primary_key=True),
db.Column("origin_id", db.ForeignKey("origin.id"), primary_key=True),
extend_existing=True,
)
@ -45,7 +44,9 @@ class Origin(AbstractConfiguration):
group: Mapped[Group] = relationship("Group", back_populates="origins")
proxies: Mapped[List[Proxy]] = relationship("Proxy", back_populates="origin")
countries: Mapped[List[Country]] = relationship("Country", secondary=country_origin, back_populates='origins')
countries: Mapped[List[Country]] = relationship(
"Country", secondary=country_origin, back_populates="origins"
)
@property
def brn(self) -> BRN:
@ -54,13 +55,18 @@ class Origin(AbstractConfiguration):
product="mirror",
provider="conf",
resource_type="origin",
resource_id=self.domain_name
resource_id=self.domain_name,
)
@classmethod
def csv_header(cls) -> List[str]:
return super().csv_header() + [
"group_id", "domain_name", "auto_rotation", "smart", "assets", "country"
"group_id",
"domain_name",
"auto_rotation",
"smart",
"assets",
"country",
]
def destroy(self) -> None:
@ -84,30 +90,41 @@ class Origin(AbstractConfiguration):
@property
def risk_level(self) -> Dict[str, int]:
if self.risk_level_override:
return {country.country_code: self.risk_level_override for country in self.countries}
return {
country.country_code: self.risk_level_override
for country in self.countries
}
frequency_factor = 0.0
recency_factor = 0.0
recent_deprecations = (
db.session.query(Deprecation)
.join(Proxy,
Deprecation.resource_id == Proxy.id)
.join(Proxy, Deprecation.resource_id == Proxy.id)
.join(Origin, Origin.id == Proxy.origin_id)
.filter(
Origin.id == self.id,
Deprecation.resource_type == 'Proxy',
Deprecation.deprecated_at >= datetime.now(tz=timezone.utc) - timedelta(hours=168),
Deprecation.reason != "destroyed"
Deprecation.resource_type == "Proxy",
Deprecation.deprecated_at
>= datetime.now(tz=timezone.utc) - timedelta(hours=168),
Deprecation.reason != "destroyed",
)
.distinct(Proxy.id)
.all()
)
for deprecation in recent_deprecations:
recency_factor += 1 / max((datetime.now(tz=timezone.utc) - deprecation.deprecated_at).total_seconds() // 3600, 1)
recency_factor += 1 / max(
(
datetime.now(tz=timezone.utc) - deprecation.deprecated_at
).total_seconds()
// 3600,
1,
)
frequency_factor += 1
risk_levels: Dict[str, int] = {}
for country in self.countries:
risk_levels[country.country_code.upper()] = int(
max(1, min(10, frequency_factor * recency_factor))) + country.risk_level
risk_levels[country.country_code.upper()] = (
int(max(1, min(10, frequency_factor * recency_factor)))
+ country.risk_level
)
return risk_levels
def to_dict(self) -> OriginDict:
@ -128,13 +145,15 @@ class Country(AbstractConfiguration):
product="country",
provider="iso3166-1",
resource_type="alpha2",
resource_id=self.country_code
resource_id=self.country_code,
)
country_code: Mapped[str]
risk_level_override: Mapped[Optional[int]]
origins = db.relationship("Origin", secondary=country_origin, back_populates='countries')
origins = db.relationship(
"Origin", secondary=country_origin, back_populates="countries"
)
@property
def risk_level(self) -> int:
@ -144,29 +163,39 @@ class Country(AbstractConfiguration):
recency_factor = 0.0
recent_deprecations = (
db.session.query(Deprecation)
.join(Proxy,
Deprecation.resource_id == Proxy.id)
.join(Proxy, Deprecation.resource_id == Proxy.id)
.join(Origin, Origin.id == Proxy.origin_id)
.join(Origin.countries)
.filter(
Country.id == self.id,
Deprecation.resource_type == 'Proxy',
Deprecation.deprecated_at >= datetime.now(tz=timezone.utc) - timedelta(hours=168),
Deprecation.reason != "destroyed"
Deprecation.resource_type == "Proxy",
Deprecation.deprecated_at
>= datetime.now(tz=timezone.utc) - timedelta(hours=168),
Deprecation.reason != "destroyed",
)
.distinct(Proxy.id)
.all()
)
for deprecation in recent_deprecations:
recency_factor += 1 / max((datetime.now(tz=timezone.utc) - deprecation.deprecated_at).total_seconds() // 3600, 1)
recency_factor += 1 / max(
(
datetime.now(tz=timezone.utc) - deprecation.deprecated_at
).total_seconds()
// 3600,
1,
)
frequency_factor += 1
return int(max(1, min(10, frequency_factor * recency_factor)))
class StaticOrigin(AbstractConfiguration):
group_id = mapped_column(db.Integer, db.ForeignKey("group.id"), nullable=False)
storage_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False)
source_cloud_account_id = mapped_column(db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False)
storage_cloud_account_id = mapped_column(
db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False
)
source_cloud_account_id = mapped_column(
db.Integer(), db.ForeignKey("cloud_account.id"), nullable=False
)
source_project = mapped_column(db.String(255), nullable=False)
auto_rotate = mapped_column(db.Boolean, nullable=False)
matrix_homeserver = mapped_column(db.String(255), nullable=True)
@ -182,14 +211,18 @@ class StaticOrigin(AbstractConfiguration):
product="mirror",
provider="aws",
resource_type="static",
resource_id=self.domain_name
resource_id=self.domain_name,
)
group = db.relationship("Group", back_populates="statics")
storage_cloud_account = db.relationship("CloudAccount", back_populates="statics",
foreign_keys=[storage_cloud_account_id])
source_cloud_account = db.relationship("CloudAccount", back_populates="statics",
foreign_keys=[source_cloud_account_id])
storage_cloud_account = db.relationship(
"CloudAccount",
back_populates="statics",
foreign_keys=[storage_cloud_account_id],
)
source_cloud_account = db.relationship(
"CloudAccount", back_populates="statics", foreign_keys=[source_cloud_account_id]
)
def destroy(self) -> None:
# TODO: The StaticMetaAutomation will clean up for now, but it should probably happen here for consistency
@ -235,19 +268,29 @@ class StaticOrigin(AbstractConfiguration):
elif isinstance(keanu_convene_logo, FileStorage):
if keanu_convene_logo.filename: # if False, no file was uploaded
keanu_convene_config["logo"] = create_data_uri(
thumbnail_uploaded_image(keanu_convene_logo), keanu_convene_logo.filename)
thumbnail_uploaded_image(keanu_convene_logo),
keanu_convene_logo.filename,
)
else:
raise ValueError("keanu_convene_logo must be a FileStorage")
try:
if isinstance(keanu_convene_color, str):
keanu_convene_config["color"] = normalize_color(keanu_convene_color) # can raise ValueError
keanu_convene_config["color"] = normalize_color(
keanu_convene_color
) # can raise ValueError
else:
raise ValueError() # re-raised below with message
except ValueError:
raise ValueError("keanu_convene_path must be a str containing an HTML color (CSS name or hex)")
self.keanu_convene_config = json.dumps(keanu_convene_config, separators=(',', ':'))
raise ValueError(
"keanu_convene_path must be a str containing an HTML color (CSS name or hex)"
)
self.keanu_convene_config = json.dumps(
keanu_convene_config, separators=(",", ":")
)
del keanu_convene_config # done with this temporary variable
if clean_insights_backend is None or (isinstance(clean_insights_backend, bool) and not clean_insights_backend):
if clean_insights_backend is None or (
isinstance(clean_insights_backend, bool) and not clean_insights_backend
):
self.clean_insights_backend = None
elif isinstance(clean_insights_backend, bool) and clean_insights_backend:
self.clean_insights_backend = "metrics.cleaninsights.org"
@ -260,7 +303,9 @@ class StaticOrigin(AbstractConfiguration):
self.updated = datetime.now(tz=timezone.utc)
ResourceStatus = Union[Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"]]
ResourceStatus = Union[
Literal["active"], Literal["pending"], Literal["expiring"], Literal["destroyed"]
]
class ProxyDict(TypedDict):
@ -271,12 +316,16 @@ class ProxyDict(TypedDict):
class Proxy(AbstractResource):
origin_id: Mapped[int] = mapped_column(db.Integer, db.ForeignKey("origin.id"), nullable=False)
origin_id: Mapped[int] = mapped_column(
db.Integer, db.ForeignKey("origin.id"), nullable=False
)
pool_id: Mapped[Optional[int]] = mapped_column(db.Integer, db.ForeignKey("pool.id"))
provider: Mapped[str] = mapped_column(db.String(20), nullable=False)
psg: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
slug: Mapped[Optional[str]] = mapped_column(db.String(20), nullable=True)
terraform_updated: Mapped[Optional[datetime]] = mapped_column(AwareDateTime(), nullable=True)
terraform_updated: Mapped[Optional[datetime]] = mapped_column(
AwareDateTime(), nullable=True
)
url: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
origin: Mapped[Origin] = relationship("Origin", back_populates="proxies")
@ -289,13 +338,18 @@ class Proxy(AbstractResource):
product="mirror",
provider=self.provider,
resource_type="proxy",
resource_id=str(self.id)
resource_id=str(self.id),
)
@classmethod
def csv_header(cls) -> List[str]:
return super().csv_header() + [
"origin_id", "provider", "psg", "slug", "terraform_updated", "url"
"origin_id",
"provider",
"psg",
"slug",
"terraform_updated",
"url",
]
def to_dict(self) -> ProxyDict:
@ -329,5 +383,5 @@ class SmartProxy(AbstractResource):
product="mirror",
provider=self.provider,
resource_type="smart_proxy",
resource_id=str(1)
resource_id=str(1),
)

View file

@ -32,7 +32,7 @@ class Onion(AbstractConfiguration):
product="eotk",
provider="*",
resource_type="onion",
resource_id=self.onion_name
resource_id=self.onion_name,
)
group_id: Mapped[int] = mapped_column(db.ForeignKey("group.id"))
@ -80,5 +80,5 @@ class Eotk(AbstractResource):
provider=self.provider,
product="eotk",
resource_type="instance",
resource_id=self.region
resource_id=self.region,
)

View file

@ -32,7 +32,9 @@ from app.portal.static import bp as static
from app.portal.storage import bp as storage
from app.portal.webhook import bp as webhook
portal = Blueprint("portal", __name__, template_folder="templates", static_folder="static")
portal = Blueprint(
"portal", __name__, template_folder="templates", static_folder="static"
)
portal.register_blueprint(automation, url_prefix="/automation")
portal.register_blueprint(bridgeconf, url_prefix="/bridgeconf")
portal.register_blueprint(bridge, url_prefix="/bridge")
@ -54,7 +56,10 @@ portal.register_blueprint(webhook, url_prefix="/webhook")
@portal.app_template_filter("bridge_expiry")
def calculate_bridge_expiry(b: Bridge) -> str:
if b.deprecated is None:
logging.warning("Bridge expiry requested by template for a bridge %s that was not expiring.", b.id)
logging.warning(
"Bridge expiry requested by template for a bridge %s that was not expiring.",
b.id,
)
return "Not expiring"
expiry = b.deprecated + timedelta(hours=b.conf.expiry_hours)
countdown = expiry - datetime.now(tz=timezone.utc)
@ -85,27 +90,27 @@ def describe_brn(s: str) -> ResponseReturnValue:
if parts[3] == "mirror":
if parts[5].startswith("origin/"):
origin = Origin.query.filter(
Origin.domain_name == parts[5][len("origin/"):]
Origin.domain_name == parts[5][len("origin/") :]
).first()
if not origin:
return s
return f"Origin: {origin.domain_name} ({origin.group.group_name})"
if parts[5].startswith("proxy/"):
proxy = Proxy.query.filter(
Proxy.id == int(parts[5][len("proxy/"):])
Proxy.id == int(parts[5][len("proxy/") :])
).first()
if not proxy:
return s
return Markup(
f"Proxy: {proxy.url}<br>({proxy.origin.group.group_name}: {proxy.origin.domain_name})")
f"Proxy: {proxy.url}<br>({proxy.origin.group.group_name}: {proxy.origin.domain_name})"
)
if parts[5].startswith("quota/"):
if parts[4] == "cloudfront":
return f"Quota: CloudFront {parts[5][len('quota/'):]}"
if parts[3] == "eotk":
if parts[5].startswith("instance/"):
eotk = Eotk.query.filter(
Eotk.group_id == parts[2],
Eotk.region == parts[5][len("instance/"):]
Eotk.group_id == parts[2], Eotk.region == parts[5][len("instance/") :]
).first()
if not eotk:
return s
@ -138,9 +143,16 @@ def portal_home() -> ResponseReturnValue:
proxies = Proxy.query.filter(Proxy.destroyed.is_(None)).all()
last24 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=1))).all())
last72 = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=3))).all())
lastweek = len(Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=7))).all())
lastweek = len(
Proxy.query.filter(Proxy.deprecated > (now - timedelta(days=7))).all()
)
alarms = {
s: len(Alarm.query.filter(Alarm.alarm_state == s.upper(), Alarm.last_updated > (now - timedelta(days=1))).all())
s: len(
Alarm.query.filter(
Alarm.alarm_state == s.upper(),
Alarm.last_updated > (now - timedelta(days=1)),
).all()
)
for s in ["critical", "warning", "ok", "unknown"]
}
bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all()
@ -148,13 +160,36 @@ def portal_home() -> ResponseReturnValue:
d: len(Bridge.query.filter(Bridge.deprecated > (now - timedelta(days=d))).all())
for d in [1, 3, 7]
}
activity = Activity.query.filter(Activity.added > (now - timedelta(days=2))).order_by(desc(Activity.added)).all()
onionified = len([o for o in Origin.query.filter(Origin.destroyed.is_(None)).all() if o.onion() is not None])
activity = (
Activity.query.filter(Activity.added > (now - timedelta(days=2)))
.order_by(desc(Activity.added))
.all()
)
onionified = len(
[
o
for o in Origin.query.filter(Origin.destroyed.is_(None)).all()
if o.onion() is not None
]
)
ooni_blocked = total_origins_blocked()
total_origins = len(Origin.query.filter(Origin.destroyed.is_(None)).all())
return render_template("home.html.j2", section="home", groups=groups, last24=last24, last72=last72,
lastweek=lastweek, proxies=proxies, **alarms, activity=activity, total_origins=total_origins,
onionified=onionified, br_last=br_last, ooni_blocked=ooni_blocked, bridges=bridges)
return render_template(
"home.html.j2",
section="home",
groups=groups,
last24=last24,
last72=last72,
lastweek=lastweek,
proxies=proxies,
**alarms,
activity=activity,
total_origins=total_origins,
onionified=onionified,
br_last=br_last,
ooni_blocked=ooni_blocked,
bridges=bridges,
)
@portal.route("/search")
@ -163,19 +198,27 @@ def search() -> ResponseReturnValue:
if query is None:
return redirect(url_for("portal.portal_home"))
proxies = Proxy.query.filter(
or_(func.lower(Proxy.url).contains(query.lower())), Proxy.destroyed.is_(None)).all()
or_(func.lower(Proxy.url).contains(query.lower())), Proxy.destroyed.is_(None)
).all()
origins = Origin.query.filter(
or_(func.lower(Origin.description).contains(query.lower()),
func.lower(Origin.domain_name).contains(query.lower()))).all()
return render_template("search.html.j2", section="home", proxies=proxies, origins=origins)
or_(
func.lower(Origin.description).contains(query.lower()),
func.lower(Origin.domain_name).contains(query.lower()),
)
).all()
return render_template(
"search.html.j2", section="home", proxies=proxies, origins=origins
)
@portal.route('/alarms')
@portal.route("/alarms")
def view_alarms() -> ResponseReturnValue:
one_day_ago = datetime.now(timezone.utc) - timedelta(days=1)
alarms = Alarm.query.filter(Alarm.last_updated >= one_day_ago).order_by(
desc(Alarm.alarm_state), desc(Alarm.state_changed)).all()
return render_template("list.html.j2",
section="alarm",
title="Alarms",
items=alarms)
alarms = (
Alarm.query.filter(Alarm.last_updated >= one_day_ago)
.order_by(desc(Alarm.alarm_state), desc(Alarm.state_changed))
.all()
)
return render_template(
"list.html.j2", section="alarm", title="Alarms", items=alarms
)

View file

@ -17,40 +17,52 @@ bp = Blueprint("automation", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "automation",
"help_url": "https://bypass.censorship.guide/user/automation.html"
"help_url": "https://bypass.censorship.guide/user/automation.html",
}
class EditAutomationForm(FlaskForm): # type: ignore
enabled = BooleanField('Enabled')
submit = SubmitField('Save Changes')
enabled = BooleanField("Enabled")
submit = SubmitField("Save Changes")
@bp.route("/list")
def automation_list() -> ResponseReturnValue:
automations = list(filter(
lambda a: a.short_name not in current_app.config.get('HIDDEN_AUTOMATIONS', []),
Automation.query.filter(
Automation.destroyed.is_(None)).order_by(Automation.description).all()
))
automations = list(
filter(
lambda a: a.short_name
not in current_app.config.get("HIDDEN_AUTOMATIONS", []),
Automation.query.filter(Automation.destroyed.is_(None))
.order_by(Automation.description)
.all(),
)
)
states = {tfs.key: tfs for tfs in TerraformState.query.all()}
return render_template("list.html.j2",
return render_template(
"list.html.j2",
title="Automation Jobs",
item="automation",
items=automations,
states=states,
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS
)
@bp.route('/edit/<automation_id>', methods=['GET', 'POST'])
@bp.route("/edit/<automation_id>", methods=["GET", "POST"])
def automation_edit(automation_id: int) -> ResponseReturnValue:
automation: Optional[Automation] = Automation.query.filter(Automation.id == automation_id).first()
automation: Optional[Automation] = Automation.query.filter(
Automation.id == automation_id
).first()
if automation is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Automation Job Not Found",
message="The requested automation job could not be found.",
**_SECTION_TEMPLATE_VARS),
status=404)
**_SECTION_TEMPLATE_VARS
),
status=404,
)
form = EditAutomationForm(enabled=automation.enabled)
if form.validate_on_submit():
automation.enabled = form.enabled.data
@ -59,21 +71,30 @@ def automation_edit(automation_id: int) -> ResponseReturnValue:
db.session.commit()
flash("Saved changes to bridge configuration.", "success")
except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the bridge configuration.", "danger")
logs = AutomationLogs.query.filter(AutomationLogs.automation_id == automation.id).order_by(
desc(AutomationLogs.added)).limit(5).all()
return render_template("automation.html.j2",
flash(
"An error occurred saving the changes to the bridge configuration.",
"danger",
)
logs = (
AutomationLogs.query.filter(AutomationLogs.automation_id == automation.id)
.order_by(desc(AutomationLogs.added))
.limit(5)
.all()
)
return render_template(
"automation.html.j2",
automation=automation,
logs=logs,
form=form,
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS
)
@bp.route("/kick/<automation_id>", methods=['GET', 'POST'])
@bp.route("/kick/<automation_id>", methods=["GET", "POST"])
def automation_kick(automation_id: int) -> ResponseReturnValue:
automation = Automation.query.filter(
Automation.id == automation_id,
Automation.destroyed.is_(None)).first()
Automation.id == automation_id, Automation.destroyed.is_(None)
).first()
if automation is None:
return response_404("The requested bridge configuration could not be found.")
return view_lifecycle(
@ -83,5 +104,5 @@ def automation_kick(automation_id: int) -> ResponseReturnValue:
success_view="portal.automation.automation_list",
success_message="This automation job will next run within 1 minute.",
resource=automation,
action="kick"
action="kick",
)

View file

@ -1,7 +1,6 @@
from typing import Optional
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from app.extensions import db
@ -12,57 +11,79 @@ bp = Blueprint("bridge", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "bridge",
"help_url": "https://bypass.censorship.guide/user/bridges.html"
"help_url": "https://bypass.censorship.guide/user/bridges.html",
}
@bp.route("/list")
def bridge_list() -> ResponseReturnValue:
bridges = Bridge.query.filter(Bridge.destroyed.is_(None)).all()
return render_template("list.html.j2",
return render_template(
"list.html.j2",
title="Tor Bridges",
item="bridge",
items=bridges,
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/block/<bridge_id>", methods=['GET', 'POST'])
@bp.route("/block/<bridge_id>", methods=["GET", "POST"])
def bridge_blocked(bridge_id: int) -> ResponseReturnValue:
bridge: Optional[Bridge] = Bridge.query.filter(Bridge.id == bridge_id, Bridge.destroyed.is_(None)).first()
bridge: Optional[Bridge] = Bridge.query.filter(
Bridge.id == bridge_id, Bridge.destroyed.is_(None)
).first()
if bridge is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Proxy Not Found",
message="The requested bridge could not be found.",
**_SECTION_TEMPLATE_VARS))
**_SECTION_TEMPLATE_VARS,
)
)
form = LifecycleForm()
if form.validate_on_submit():
bridge.deprecate(reason="manual")
db.session.commit()
flash("Bridge will be shortly replaced.", "success")
return redirect(url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id))
return render_template("lifecycle.html.j2",
return redirect(
url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)
)
return render_template(
"lifecycle.html.j2",
header=f"Mark bridge {bridge.hashed_fingerprint} as blocked?",
message=bridge.hashed_fingerprint,
form=form,
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/expire/<bridge_id>", methods=['GET', 'POST'])
@bp.route("/expire/<bridge_id>", methods=["GET", "POST"])
def bridge_expire(bridge_id: int) -> ResponseReturnValue:
bridge: Optional[Bridge] = Bridge.query.filter(Bridge.id == bridge_id, Bridge.destroyed.is_(None)).first()
bridge: Optional[Bridge] = Bridge.query.filter(
Bridge.id == bridge_id, Bridge.destroyed.is_(None)
).first()
if bridge is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Proxy Not Found",
message="The requested bridge could not be found.",
**_SECTION_TEMPLATE_VARS))
**_SECTION_TEMPLATE_VARS,
)
)
form = LifecycleForm()
if form.validate_on_submit():
bridge.destroy()
db.session.commit()
flash("Bridge will be shortly destroyed.", "success")
return redirect(url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id))
return render_template("lifecycle.html.j2",
return redirect(
url_for("portal.bridgeconf.bridgeconf_edit", bridgeconf_id=bridge.conf_id)
)
return render_template(
"lifecycle.html.j2",
header=f"Destroy bridge {bridge.hashed_fingerprint}?",
message=bridge.hashed_fingerprint,
form=form,
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS,
)

View file

@ -1,8 +1,7 @@
from datetime import datetime, timezone
from typing import List, Optional
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from sqlalchemy import exc
@ -19,77 +18,109 @@ bp = Blueprint("bridgeconf", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "bridgeconf",
"help_url": "https://bypass.censorship.guide/user/bridges.html"
"help_url": "https://bypass.censorship.guide/user/bridges.html",
}
class NewBridgeConfForm(FlaskForm): # type: ignore
method = SelectField('Distribution Method', validators=[DataRequired()])
description = StringField('Description')
pool = SelectField('Pool', validators=[DataRequired()])
target_number = IntegerField('Target Number',
method = SelectField("Distribution Method", validators=[DataRequired()])
description = StringField("Description")
pool = SelectField("Pool", validators=[DataRequired()])
target_number = IntegerField(
"Target Number",
description="The number of active bridges to deploy (excluding deprecated bridges).",
validators=[NumberRange(1, message="One or more bridges must be created.")])
max_number = IntegerField('Maximum Number',
validators=[NumberRange(1, message="One or more bridges must be created.")],
)
max_number = IntegerField(
"Maximum Number",
description="The maximum number of bridges to deploy (including deprecated bridges).",
validators=[
NumberRange(1, message="Must be at least 1, ideally greater than target number.")])
expiry_hours = IntegerField('Expiry Timer (hours)',
description=("The number of hours to wait after a bridge is deprecated before its "
"destruction."))
provider_allocation = SelectField('Provider Allocation Method',
NumberRange(
1, message="Must be at least 1, ideally greater than target number."
)
],
)
expiry_hours = IntegerField(
"Expiry Timer (hours)",
description=(
"The number of hours to wait after a bridge is deprecated before its "
"destruction."
),
)
provider_allocation = SelectField(
"Provider Allocation Method",
description="How to allocate new bridges to providers.",
choices=[
("COST", "Use cheapest provider first"),
("RANDOM", "Use providers randomly"),
])
submit = SubmitField('Save Changes')
],
)
submit = SubmitField("Save Changes")
class EditBridgeConfForm(FlaskForm): # type: ignore
description = StringField('Description')
target_number = IntegerField('Target Number',
description = StringField("Description")
target_number = IntegerField(
"Target Number",
description="The number of active bridges to deploy (excluding deprecated bridges).",
validators=[NumberRange(1, message="One or more bridges must be created.")])
max_number = IntegerField('Maximum Number',
validators=[NumberRange(1, message="One or more bridges must be created.")],
)
max_number = IntegerField(
"Maximum Number",
description="The maximum number of bridges to deploy (including deprecated bridges).",
validators=[
NumberRange(1, message="Must be at least 1, ideally greater than target number.")])
expiry_hours = IntegerField('Expiry Timer (hours)',
description=("The number of hours to wait after a bridge is deprecated before its "
"destruction."))
provider_allocation = SelectField('Provider Allocation Method',
NumberRange(
1, message="Must be at least 1, ideally greater than target number."
)
],
)
expiry_hours = IntegerField(
"Expiry Timer (hours)",
description=(
"The number of hours to wait after a bridge is deprecated before its "
"destruction."
),
)
provider_allocation = SelectField(
"Provider Allocation Method",
description="How to allocate new bridges to providers.",
choices=[
("COST", "Use cheapest provider first"),
("RANDOM", "Use providers randomly"),
])
submit = SubmitField('Save Changes')
],
)
submit = SubmitField("Save Changes")
@bp.route("/list")
def bridgeconf_list() -> ResponseReturnValue:
bridgeconfs: List[BridgeConf] = BridgeConf.query.filter(BridgeConf.destroyed.is_(None)).all()
return render_template("list.html.j2",
bridgeconfs: List[BridgeConf] = BridgeConf.query.filter(
BridgeConf.destroyed.is_(None)
).all()
return render_template(
"list.html.j2",
title="Tor Bridge Configurations",
item="bridge configuration",
items=bridgeconfs,
new_link=url_for("portal.bridgeconf.bridgeconf_new"),
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new/<group_id>", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=["GET", "POST"])
def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = NewBridgeConfForm()
form.pool.choices = [(x.id, x.pool_name) for x in Pool.query.filter(Pool.destroyed.is_(None)).all()]
form.pool.choices = [
(x.id, x.pool_name) for x in Pool.query.filter(Pool.destroyed.is_(None)).all()
]
form.method.choices = [
("any", "Any (BridgeDB)"),
("email", "E-Mail (BridgeDB)"),
("moat", "Moat (BridgeDB)"),
("settings", "Settings (BridgeDB)"),
("https", "HTTPS (BridgeDB)"),
("none", "None (Private)")
("none", "None (Private)"),
]
if form.validate_on_submit():
bridgeconf = BridgeConf()
@ -99,7 +130,9 @@ def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue:
bridgeconf.target_number = form.target_number.data
bridgeconf.max_number = form.max_number.data
bridgeconf.expiry_hours = form.expiry_hours.data
bridgeconf.provider_allocation = ProviderAllocation[form.provider_allocation.data]
bridgeconf.provider_allocation = ProviderAllocation[
form.provider_allocation.data
]
bridgeconf.added = datetime.now(tz=timezone.utc)
bridgeconf.updated = datetime.now(tz=timezone.utc)
try:
@ -112,21 +145,24 @@ def bridgeconf_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return redirect(url_for("portal.bridgeconf.bridgeconf_list"))
if group_id:
form.group.data = group_id
return render_template("new.html.j2",
form=form,
**_SECTION_TEMPLATE_VARS)
return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
@bp.route('/edit/<bridgeconf_id>', methods=['GET', 'POST'])
@bp.route("/edit/<bridgeconf_id>", methods=["GET", "POST"])
def bridgeconf_edit(bridgeconf_id: int) -> ResponseReturnValue:
bridgeconf = BridgeConf.query.filter(BridgeConf.id == bridgeconf_id).first()
if bridgeconf is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Bridge Configuration Not Found",
message="The requested bridge configuration could not be found.",
**_SECTION_TEMPLATE_VARS),
status=404)
form = EditBridgeConfForm(description=bridgeconf.description,
**_SECTION_TEMPLATE_VARS,
),
status=404,
)
form = EditBridgeConfForm(
description=bridgeconf.description,
target_number=bridgeconf.target_number,
max_number=bridgeconf.max_number,
expiry_hours=bridgeconf.expiry_hours,
@ -137,22 +173,28 @@ def bridgeconf_edit(bridgeconf_id: int) -> ResponseReturnValue:
bridgeconf.target_number = form.target_number.data
bridgeconf.max_number = form.max_number.data
bridgeconf.expiry_hours = form.expiry_hours.data
bridgeconf.provider_allocation = ProviderAllocation[form.provider_allocation.data]
bridgeconf.provider_allocation = ProviderAllocation[
form.provider_allocation.data
]
bridgeconf.updated = datetime.now(tz=timezone.utc)
try:
db.session.commit()
flash("Saved changes to bridge configuration.", "success")
except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the bridge configuration.", "danger")
return render_template("bridgeconf.html.j2",
bridgeconf=bridgeconf,
form=form,
**_SECTION_TEMPLATE_VARS)
flash(
"An error occurred saving the changes to the bridge configuration.",
"danger",
)
return render_template(
"bridgeconf.html.j2", bridgeconf=bridgeconf, form=form, **_SECTION_TEMPLATE_VARS
)
@bp.route("/destroy/<bridgeconf_id>", methods=['GET', 'POST'])
@bp.route("/destroy/<bridgeconf_id>", methods=["GET", "POST"])
def bridgeconf_destroy(bridgeconf_id: int) -> ResponseReturnValue:
bridgeconf = BridgeConf.query.filter(BridgeConf.id == bridgeconf_id, BridgeConf.destroyed.is_(None)).first()
bridgeconf = BridgeConf.query.filter(
BridgeConf.id == bridgeconf_id, BridgeConf.destroyed.is_(None)
).first()
if bridgeconf is None:
return response_404("The requested bridge configuration could not be found.")
return view_lifecycle(
@ -162,5 +204,5 @@ def bridgeconf_destroy(bridgeconf_id: int) -> ResponseReturnValue:
success_message="All bridges from the destroyed configuration will shortly be destroyed at their providers.",
section="bridgeconf",
resource=bridgeconf,
action="destroy"
action="destroy",
)

View file

@ -3,8 +3,15 @@ from typing import Dict, List, Optional, Type, Union
from flask import Blueprint, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from wtforms import (BooleanField, Form, FormField, IntegerField, SelectField,
StringField, SubmitField)
from wtforms import (
BooleanField,
Form,
FormField,
IntegerField,
SelectField,
StringField,
SubmitField,
)
from wtforms.validators import InputRequired
from app.extensions import db
@ -14,54 +21,72 @@ bp = Blueprint("cloud", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "cloud",
"help_url": "https://bypass.censorship.guide/user/cloud.html"
"help_url": "https://bypass.censorship.guide/user/cloud.html",
}
class NewCloudAccountForm(FlaskForm): # type: ignore
provider = SelectField('Cloud Provider', validators=[InputRequired()])
submit = SubmitField('Next')
provider = SelectField("Cloud Provider", validators=[InputRequired()])
submit = SubmitField("Next")
class AWSAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()])
aws_access_key = StringField('AWS Access Key', validators=[InputRequired()])
aws_secret_key = StringField('AWS Secret Key', validators=[InputRequired()])
aws_region = StringField('AWS Region', default='us-east-2', validators=[InputRequired()])
max_distributions = IntegerField('Cloudfront Distributions Quota', default=200,
provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField("Description", validators=[InputRequired()])
aws_access_key = StringField("AWS Access Key", validators=[InputRequired()])
aws_secret_key = StringField("AWS Secret Key", validators=[InputRequired()])
aws_region = StringField(
"AWS Region", default="us-east-2", validators=[InputRequired()]
)
max_distributions = IntegerField(
"Cloudfront Distributions Quota",
default=200,
description="This is the quota for number of distributions per account.",
validators=[InputRequired()])
max_instances = IntegerField('EC2 Instance Quota', default=2,
validators=[InputRequired()],
)
max_instances = IntegerField(
"EC2 Instance Quota",
default=2,
description="This can be impacted by a number of quotas including instance limits "
"and IP address limits.",
validators=[InputRequired()])
enabled = BooleanField('Enable this account', default=True,
validators=[InputRequired()],
)
enabled = BooleanField(
"Enable this account",
default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.")
submit = SubmitField('Save Changes')
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class HcloudAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()])
hcloud_token = StringField('Hetzner Cloud Token', validators=[InputRequired()])
max_instances = IntegerField('Server Limit', default=10,
validators=[InputRequired()])
enabled = BooleanField('Enable this account', default=True,
provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField("Description", validators=[InputRequired()])
hcloud_token = StringField("Hetzner Cloud Token", validators=[InputRequired()])
max_instances = IntegerField(
"Server Limit", default=10, validators=[InputRequired()]
)
enabled = BooleanField(
"Enable this account",
default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.")
submit = SubmitField('Save Changes')
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class GitlabAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()])
gitlab_token = StringField('GitLab Access Token', validators=[InputRequired()])
enabled = BooleanField('Enable this account', default=True,
provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField("Description", validators=[InputRequired()])
gitlab_token = StringField("GitLab Access Token", validators=[InputRequired()])
enabled = BooleanField(
"Enable this account",
default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.")
submit = SubmitField('Save Changes')
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class OvhHorizonForm(Form): # type: ignore[misc]
@ -77,16 +102,20 @@ class OvhApiForm(Form): # type: ignore[misc]
class OvhAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()])
horizon = FormField(OvhHorizonForm, 'OpenStack Horizon API')
ovh_api = FormField(OvhApiForm, 'OVH API')
max_instances = IntegerField('Server Limit', default=10,
validators=[InputRequired()])
enabled = BooleanField('Enable this account', default=True,
provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField("Description", validators=[InputRequired()])
horizon = FormField(OvhHorizonForm, "OpenStack Horizon API")
ovh_api = FormField(OvhApiForm, "OVH API")
max_instances = IntegerField(
"Server Limit", default=10, validators=[InputRequired()]
)
enabled = BooleanField(
"Enable this account",
default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.")
submit = SubmitField('Save Changes')
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
class GandiHorizonForm(Form): # type: ignore[misc]
@ -96,18 +125,24 @@ class GandiHorizonForm(Form): # type: ignore[misc]
class GandiAccountForm(FlaskForm): # type: ignore
provider = StringField('Platform', render_kw={"disabled": ""})
description = StringField('Description', validators=[InputRequired()])
horizon = FormField(GandiHorizonForm, 'OpenStack Horizon API')
max_instances = IntegerField('Server Limit', default=10,
validators=[InputRequired()])
enabled = BooleanField('Enable this account', default=True,
provider = StringField("Platform", render_kw={"disabled": ""})
description = StringField("Description", validators=[InputRequired()])
horizon = FormField(GandiHorizonForm, "OpenStack Horizon API")
max_instances = IntegerField(
"Server Limit", default=10, validators=[InputRequired()]
)
enabled = BooleanField(
"Enable this account",
default=True,
description="New resources will not be deployed to disabled accounts, however existing "
"resources will persist until destroyed at the end of their lifecycle.")
submit = SubmitField('Save Changes')
"resources will persist until destroyed at the end of their lifecycle.",
)
submit = SubmitField("Save Changes")
CloudAccountForm = Union[AWSAccountForm, HcloudAccountForm, GandiAccountForm, OvhAccountForm]
CloudAccountForm = Union[
AWSAccountForm, HcloudAccountForm, GandiAccountForm, OvhAccountForm
]
provider_forms: Dict[str, Type[CloudAccountForm]] = {
CloudProvider.AWS.name: AWSAccountForm,
@ -118,7 +153,9 @@ provider_forms: Dict[str, Type[CloudAccountForm]] = {
}
def cloud_account_save(account: Optional[CloudAccount], provider: CloudProvider, form: CloudAccountForm) -> None:
def cloud_account_save(
account: Optional[CloudAccount], provider: CloudProvider, form: CloudAccountForm
) -> None:
if not account:
account = CloudAccount()
account.provider = provider
@ -162,7 +199,9 @@ def cloud_account_save(account: Optional[CloudAccount], provider: CloudProvider,
"ovh_openstack_password": form.horizon.data["ovh_openstack_password"],
"ovh_openstack_tenant_id": form.horizon.data["ovh_openstack_tenant_id"],
"ovh_cloud_application_key": form.ovh_api.data["ovh_cloud_application_key"],
"ovh_cloud_application_secret": form.ovh_api.data["ovh_cloud_application_secret"],
"ovh_cloud_application_secret": form.ovh_api.data[
"ovh_cloud_application_secret"
],
"ovh_cloud_consumer_key": form.ovh_api.data["ovh_cloud_consumer_key"],
}
account.max_distributions = 0
@ -182,53 +221,82 @@ def cloud_account_populate(form: CloudAccountForm, account: CloudAccount) -> Non
form.aws_region.data = account.credentials["aws_region"]
form.max_distributions.data = account.max_distributions
form.max_instances.data = account.max_instances
elif account.provider == CloudProvider.HCLOUD and isinstance(form, HcloudAccountForm):
elif account.provider == CloudProvider.HCLOUD and isinstance(
form, HcloudAccountForm
):
form.hcloud_token.data = account.credentials["hcloud_token"]
form.max_instances.data = account.max_instances
elif account.provider == CloudProvider.GANDI and isinstance(form, GandiAccountForm):
form.horizon.form.gandi_openstack_user.data = account.credentials["gandi_openstack_user"]
form.horizon.form.gandi_openstack_password.data = account.credentials["gandi_openstack_password"]
form.horizon.form.gandi_openstack_tenant_id.data = account.credentials["gandi_openstack_tenant_id"]
form.horizon.form.gandi_openstack_user.data = account.credentials[
"gandi_openstack_user"
]
form.horizon.form.gandi_openstack_password.data = account.credentials[
"gandi_openstack_password"
]
form.horizon.form.gandi_openstack_tenant_id.data = account.credentials[
"gandi_openstack_tenant_id"
]
form.max_instances.data = account.max_instances
elif account.provider == CloudProvider.GITLAB and isinstance(form, GitlabAccountForm):
elif account.provider == CloudProvider.GITLAB and isinstance(
form, GitlabAccountForm
):
form.gitlab_token.data = account.credentials["gitlab_token"]
elif account.provider == CloudProvider.OVH and isinstance(form, OvhAccountForm):
form.horizon.form.ovh_openstack_user.data = account.credentials["ovh_openstack_user"]
form.horizon.form.ovh_openstack_password.data = account.credentials["ovh_openstack_password"]
form.horizon.form.ovh_openstack_tenant_id.data = account.credentials["ovh_openstack_tenant_id"]
form.ovh_api.form.ovh_cloud_application_key.data = account.credentials["ovh_cloud_application_key"]
form.ovh_api.form.ovh_cloud_application_secret.data = account.credentials["ovh_cloud_application_secret"]
form.ovh_api.form.ovh_cloud_consumer_key.data = account.credentials["ovh_cloud_consumer_key"]
form.horizon.form.ovh_openstack_user.data = account.credentials[
"ovh_openstack_user"
]
form.horizon.form.ovh_openstack_password.data = account.credentials[
"ovh_openstack_password"
]
form.horizon.form.ovh_openstack_tenant_id.data = account.credentials[
"ovh_openstack_tenant_id"
]
form.ovh_api.form.ovh_cloud_application_key.data = account.credentials[
"ovh_cloud_application_key"
]
form.ovh_api.form.ovh_cloud_application_secret.data = account.credentials[
"ovh_cloud_application_secret"
]
form.ovh_api.form.ovh_cloud_consumer_key.data = account.credentials[
"ovh_cloud_consumer_key"
]
form.max_instances.data = account.max_instances
else:
raise RuntimeError(f"Unknown provider {account.provider} or form data {type(form)} did not match provider.")
raise RuntimeError(
f"Unknown provider {account.provider} or form data {type(form)} did not match provider."
)
@bp.route("/list")
def cloud_account_list() -> ResponseReturnValue:
accounts: List[CloudAccount] = CloudAccount.query.filter(CloudAccount.destroyed.is_(None)).all()
return render_template("list.html.j2",
accounts: List[CloudAccount] = CloudAccount.query.filter(
CloudAccount.destroyed.is_(None)
).all()
return render_template(
"list.html.j2",
title="Cloud Accounts",
item="cloud account",
items=accounts,
new_link=url_for("portal.cloud.cloud_account_new"),
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS,
)
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
def cloud_account_new() -> ResponseReturnValue:
form = NewCloudAccountForm()
form.provider.choices = sorted([
(provider.name, provider.description) for provider in CloudProvider
], key=lambda p: p[1].lower())
form.provider.choices = sorted(
[(provider.name, provider.description) for provider in CloudProvider],
key=lambda p: p[1].lower(),
)
if form.validate_on_submit():
return redirect(url_for("portal.cloud.cloud_account_new_for", provider=form.provider.data))
return render_template("new.html.j2",
form=form,
**_SECTION_TEMPLATE_VARS)
return redirect(
url_for("portal.cloud.cloud_account_new_for", provider=form.provider.data)
)
return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
@bp.route("/new/<provider>", methods=['GET', 'POST'])
@bp.route("/new/<provider>", methods=["GET", "POST"])
def cloud_account_new_for(provider: str) -> ResponseReturnValue:
form = provider_forms[provider]()
form.provider.data = CloudProvider[provider].description
@ -236,12 +304,10 @@ def cloud_account_new_for(provider: str) -> ResponseReturnValue:
cloud_account_save(None, CloudProvider[provider], form)
db.session.commit()
return redirect(url_for("portal.cloud.cloud_account_list"))
return render_template("new.html.j2",
form=form,
**_SECTION_TEMPLATE_VARS)
return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
@bp.route("/edit/<account_id>", methods=['GET', 'POST'])
@bp.route("/edit/<account_id>", methods=["GET", "POST"])
def cloud_account_edit(account_id: int) -> ResponseReturnValue:
account = CloudAccount.query.filter(
CloudAccount.id == account_id,
@ -256,6 +322,4 @@ def cloud_account_edit(account_id: int) -> ResponseReturnValue:
db.session.commit()
return redirect(url_for("portal.cloud.cloud_account_list"))
cloud_account_populate(form, account)
return render_template("new.html.j2",
form=form,
**_SECTION_TEMPLATE_VARS)
return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)

View file

@ -13,7 +13,7 @@ bp = Blueprint("country", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "country",
"help_url": "https://bypass.censorship.guide/user/countries.html"
"help_url": "https://bypass.censorship.guide/user/countries.html",
}
@ -22,17 +22,18 @@ def filter_country_flag(country_code: str) -> str:
country_code = country_code.upper()
# Calculate the regional indicator symbol for each letter in the country code
base = ord('\U0001F1E6') - ord('A')
flag = ''.join([chr(ord(char) + base) for char in country_code])
base = ord("\U0001F1E6") - ord("A")
flag = "".join([chr(ord(char) + base) for char in country_code])
return flag
@bp.route('/list')
@bp.route("/list")
def country_list() -> ResponseReturnValue:
countries = Country.query.filter(Country.destroyed.is_(None)).all()
print(len(countries))
return render_template("list.html.j2",
return render_template(
"list.html.j2",
title="Countries",
item="country",
new_link=None,
@ -43,21 +44,29 @@ def country_list() -> ResponseReturnValue:
class EditCountryForm(FlaskForm): # type: ignore[misc]
risk_level_override = BooleanField("Force Risk Level Override?")
risk_level_override_number = IntegerField("Forced Risk Level", description="Number from 0 to 20", default=0)
submit = SubmitField('Save Changes')
risk_level_override_number = IntegerField(
"Forced Risk Level", description="Number from 0 to 20", default=0
)
submit = SubmitField("Save Changes")
@bp.route('/edit/<country_id>', methods=['GET', 'POST'])
@bp.route("/edit/<country_id>", methods=["GET", "POST"])
def country_edit(country_id: int) -> ResponseReturnValue:
country = Country.query.filter(Country.id == country_id).first()
if country is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="country",
header="404 Country Not Found",
message="The requested country could not be found."),
status=404)
form = EditCountryForm(risk_level_override=country.risk_level_override is not None,
risk_level_override_number=country.risk_level_override)
message="The requested country could not be found.",
),
status=404,
)
form = EditCountryForm(
risk_level_override=country.risk_level_override is not None,
risk_level_override_number=country.risk_level_override,
)
if form.validate_on_submit():
if form.risk_level_override.data:
country.risk_level_override = form.risk_level_override_number.data
@ -69,6 +78,6 @@ def country_edit(country_id: int) -> ResponseReturnValue:
flash("Saved changes to country.", "success")
except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the country.", "danger")
return render_template("country.html.j2",
section="country",
country=country, form=form)
return render_template(
"country.html.j2", section="country", country=country, form=form
)

View file

@ -10,23 +10,32 @@ bp = Blueprint("eotk", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "eotk",
"help_url": "https://bypass.censorship.guide/user/eotk.html"
"help_url": "https://bypass.censorship.guide/user/eotk.html",
}
@bp.route("/list")
def eotk_list() -> ResponseReturnValue:
instances = Eotk.query.filter(Eotk.destroyed.is_(None)).order_by(desc(Eotk.added)).all()
return render_template("list.html.j2",
instances = (
Eotk.query.filter(Eotk.destroyed.is_(None)).order_by(desc(Eotk.added)).all()
)
return render_template(
"list.html.j2",
title="EOTK Instances",
item="eotk",
items=instances,
**_SECTION_TEMPLATE_VARS)
**_SECTION_TEMPLATE_VARS
)
@bp.route("/conf/<group_id>")
def eotk_conf(group_id: int) -> ResponseReturnValue:
group = Group.query.filter(Group.id == group_id).first()
return Response(render_template("sites.conf.j2",
return Response(
render_template(
"sites.conf.j2",
bypass_token=current_app.config["BYPASS_TOKEN"],
group=group), content_type="text/plain")
group=group,
),
content_type="text/plain",
)

View file

@ -3,6 +3,6 @@ from wtforms import SelectField, StringField, SubmitField
class EditMirrorForm(FlaskForm): # type: ignore
origin = SelectField('Origin')
url = StringField('URL')
submit = SubmitField('Save Changes')
origin = SelectField("Origin")
url = StringField("URL")
submit = SubmitField("Save Changes")

View file

@ -1,8 +1,7 @@
from datetime import datetime, timezone
import sqlalchemy
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from wtforms import BooleanField, StringField, SubmitField
@ -18,27 +17,29 @@ class NewGroupForm(FlaskForm): # type: ignore
group_name = StringField("Short Name", validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()])
eotk = BooleanField("Deploy EOTK instances?")
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"})
submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
class EditGroupForm(FlaskForm): # type: ignore
description = StringField('Description', validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()])
eotk = BooleanField("Deploy EOTK instances?")
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"})
submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
@bp.route("/list")
def group_list() -> ResponseReturnValue:
groups = Group.query.order_by(Group.group_name).all()
return render_template("list.html.j2",
return render_template(
"list.html.j2",
section="group",
title="Groups",
item="group",
items=groups,
new_link=url_for("portal.group.group_new"))
new_link=url_for("portal.group.group_new"),
)
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
def group_new() -> ResponseReturnValue:
form = NewGroupForm()
if form.validate_on_submit():
@ -59,17 +60,20 @@ def group_new() -> ResponseReturnValue:
return render_template("new.html.j2", section="group", form=form)
@bp.route('/edit/<group_id>', methods=['GET', 'POST'])
@bp.route("/edit/<group_id>", methods=["GET", "POST"])
def group_edit(group_id: int) -> ResponseReturnValue:
group = Group.query.filter(Group.id == group_id).first()
if group is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="group",
header="404 Group Not Found",
message="The requested group could not be found."),
status=404)
form = EditGroupForm(description=group.description,
eotk=group.eotk)
message="The requested group could not be found.",
),
status=404,
)
form = EditGroupForm(description=group.description, eotk=group.eotk)
if form.validate_on_submit():
group.description = form.description.data
group.eotk = form.eotk.data
@ -79,6 +83,4 @@ def group_edit(group_id: int) -> ResponseReturnValue:
flash("Saved changes to group.", "success")
except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the group.", "danger")
return render_template("group.html.j2",
section="group",
group=group, form=form)
return render_template("group.html.j2", section="group", group=group, form=form)

View file

@ -2,8 +2,7 @@ import json
from datetime import datetime, timezone
from typing import Any, Optional
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from sqlalchemy import exc
@ -23,7 +22,7 @@ bp = Blueprint("list", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "list",
"help_url": "https://bypass.censorship.guide/user/lists.html"
"help_url": "https://bypass.censorship.guide/user/lists.html",
}
@ -42,10 +41,11 @@ def list_encoding_name(key: str) -> str:
return MirrorList.encodings_supported.get(key, "Unknown")
@bp.route('/list')
@bp.route("/list")
def list_list() -> ResponseReturnValue:
lists = MirrorList.query.filter(MirrorList.destroyed.is_(None)).all()
return render_template("list.html.j2",
return render_template(
"list.html.j2",
title="Distribution Lists",
item="distribution list",
new_link=url_for("portal.list.list_new"),
@ -54,25 +54,31 @@ def list_list() -> ResponseReturnValue:
)
@bp.route('/preview/<format_>/<pool_id>')
@bp.route("/preview/<format_>/<pool_id>")
def list_preview(format_: str, pool_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first()
if not pool:
return response_404(message="Pool not found")
if format_ == "bca":
return Response(json.dumps(mirror_mapping(pool)), content_type="application/json")
return Response(
json.dumps(mirror_mapping(pool)), content_type="application/json"
)
if format_ == "bc2":
return Response(json.dumps(mirror_sites(pool)), content_type="application/json")
if format_ == "bridgelines":
return Response(json.dumps(bridgelines(pool)), content_type="application/json")
if format_ == "rdr":
return Response(json.dumps(redirector_data(pool)), content_type="application/json")
return Response(
json.dumps(redirector_data(pool)), content_type="application/json"
)
return response_404(message="Format not found")
@bp.route("/destroy/<list_id>", methods=['GET', 'POST'])
@bp.route("/destroy/<list_id>", methods=["GET", "POST"])
def list_destroy(list_id: int) -> ResponseReturnValue:
list_ = MirrorList.query.filter(MirrorList.id == list_id, MirrorList.destroyed.is_(None)).first()
list_ = MirrorList.query.filter(
MirrorList.id == list_id, MirrorList.destroyed.is_(None)
).first()
if list_ is None:
return response_404("The requested bridge configuration could not be found.")
return view_lifecycle(
@ -82,12 +88,12 @@ def list_destroy(list_id: int) -> ResponseReturnValue:
success_message="This list will no longer be updated and may be deleted depending on the provider.",
section="list",
resource=list_,
action="destroy"
action="destroy",
)
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new/<group_id>", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=["GET", "POST"])
def list_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = NewMirrorListForm()
form.provider.choices = list(MirrorList.providers_supported.items())
@ -116,43 +122,53 @@ def list_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return redirect(url_for("portal.list.list_list"))
if group_id:
form.group.data = group_id
return render_template("new.html.j2",
form=form,
**_SECTION_TEMPLATE_VARS)
return render_template("new.html.j2", form=form, **_SECTION_TEMPLATE_VARS)
class NewMirrorListForm(FlaskForm): # type: ignore
pool = SelectField('Resource Pool', validators=[DataRequired()])
provider = SelectField('Provider', validators=[DataRequired()])
format = SelectField('Distribution Method', validators=[DataRequired()])
encoding = SelectField('Encoding', validators=[DataRequired()])
description = StringField('Description', validators=[DataRequired()])
container = StringField('Container', validators=[DataRequired()],
description="GitHub Project, GitLab Project or AWS S3 bucket name.")
branch = StringField('Git Branch/AWS Region', validators=[DataRequired()],
pool = SelectField("Resource Pool", validators=[DataRequired()])
provider = SelectField("Provider", validators=[DataRequired()])
format = SelectField("Distribution Method", validators=[DataRequired()])
encoding = SelectField("Encoding", validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()])
container = StringField(
"Container",
validators=[DataRequired()],
description="GitHub Project, GitLab Project or AWS S3 bucket name.",
)
branch = StringField(
"Git Branch/AWS Region",
validators=[DataRequired()],
description="For GitHub/GitLab, set this to the desired branch name, e.g. main. For AWS S3, "
"set this field to the desired region, e.g. us-east-1.")
role = StringField('Role ARN',
description="(Optional) ARN for IAM role to assume for interaction with the S3 bucket.")
filename = StringField('Filename', validators=[DataRequired()])
submit = SubmitField('Save Changes')
"set this field to the desired region, e.g. us-east-1.",
)
role = StringField(
"Role ARN",
description="(Optional) ARN for IAM role to assume for interaction with the S3 bucket.",
)
filename = StringField("Filename", validators=[DataRequired()])
submit = SubmitField("Save Changes")
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.pool.choices = [
(pool.id, pool.pool_name) for pool in Pool.query.all()
]
self.pool.choices = [(pool.id, pool.pool_name) for pool in Pool.query.all()]
@bp.route('/edit/<list_id>', methods=['GET', 'POST'])
@bp.route("/edit/<list_id>", methods=["GET", "POST"])
def list_edit(list_id: int) -> ResponseReturnValue:
list_: Optional[MirrorList] = MirrorList.query.filter(MirrorList.id == list_id).first()
list_: Optional[MirrorList] = MirrorList.query.filter(
MirrorList.id == list_id
).first()
if list_ is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Distribution List Not Found",
message="The requested distribution list could not be found.",
**_SECTION_TEMPLATE_VARS),
status=404)
**_SECTION_TEMPLATE_VARS
),
status=404,
)
form = NewMirrorListForm(
pool=list_.pool_id,
provider=list_.provider,
@ -162,7 +178,7 @@ def list_edit(list_id: int) -> ResponseReturnValue:
container=list_.container,
branch=list_.branch,
role=list_.role,
filename=list_.filename
filename=list_.filename,
)
form.provider.choices = list(MirrorList.providers_supported.items())
form.format.choices = list(MirrorList.formats_supported.items())
@ -182,7 +198,10 @@ def list_edit(list_id: int) -> ResponseReturnValue:
db.session.commit()
flash("Saved changes to group.", "success")
except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the distribution list.", "danger")
return render_template("distlist.html.j2",
list=list_, form=form,
**_SECTION_TEMPLATE_VARS)
flash(
"An error occurred saving the changes to the distribution list.",
"danger",
)
return render_template(
"distlist.html.j2", list=list_, form=form, **_SECTION_TEMPLATE_VARS
)

View file

@ -9,13 +9,13 @@ from app.portal.util import response_404, view_lifecycle
bp = Blueprint("onion", __name__)
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new/<group_id>", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=["GET", "POST"])
def onion_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return redirect("/ui/web/onions/new")
@bp.route('/edit/<onion_id>', methods=['GET', 'POST'])
@bp.route("/edit/<onion_id>", methods=["GET", "POST"])
def onion_edit(onion_id: int) -> ResponseReturnValue:
return redirect("/ui/web/onions/edit/{}".format(onion_id))
@ -25,9 +25,11 @@ def onion_list() -> ResponseReturnValue:
return redirect("/ui/web/onions")
@bp.route("/destroy/<onion_id>", methods=['GET', 'POST'])
@bp.route("/destroy/<onion_id>", methods=["GET", "POST"])
def onion_destroy(onion_id: str) -> ResponseReturnValue:
onion = Onion.query.filter(Onion.id == int(onion_id), Onion.destroyed.is_(None)).first()
onion = Onion.query.filter(
Onion.id == int(onion_id), Onion.destroyed.is_(None)
).first()
if onion is None:
return response_404("The requested onion service could not be found.")
return view_lifecycle(
@ -37,5 +39,5 @@ def onion_destroy(onion_id: str) -> ResponseReturnValue:
success_view="portal.onion.onion_list",
section="onion",
resource=onion,
action="destroy"
action="destroy",
)

View file

@ -4,13 +4,11 @@ from typing import List, Optional
import requests
import sqlalchemy
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from sqlalchemy import exc
from wtforms import (BooleanField, IntegerField, SelectField, StringField,
SubmitField)
from wtforms import BooleanField, IntegerField, SelectField, StringField, SubmitField
from wtforms.validators import DataRequired
from app.extensions import db
@ -22,29 +20,31 @@ bp = Blueprint("origin", __name__)
class NewOriginForm(FlaskForm): # type: ignore
domain_name = StringField('Domain Name', validators=[DataRequired()])
description = StringField('Description', validators=[DataRequired()])
group = SelectField('Group', validators=[DataRequired()])
domain_name = StringField("Domain Name", validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()])
group = SelectField("Group", validators=[DataRequired()])
auto_rotate = BooleanField("Enable auto-rotation?", default=True)
smart_proxy = BooleanField("Requires smart proxy?", default=False)
asset_domain = BooleanField("Used to host assets for other domains?", default=False)
submit = SubmitField('Save Changes')
submit = SubmitField("Save Changes")
class EditOriginForm(FlaskForm): # type: ignore[misc]
description = StringField('Description', validators=[DataRequired()])
group = SelectField('Group', validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()])
group = SelectField("Group", validators=[DataRequired()])
auto_rotate = BooleanField("Enable auto-rotation?")
smart_proxy = BooleanField("Requires smart proxy?")
asset_domain = BooleanField("Used to host assets for other domains?", default=False)
risk_level_override = BooleanField("Force Risk Level Override?")
risk_level_override_number = IntegerField("Forced Risk Level", description="Number from 0 to 20", default=0)
submit = SubmitField('Save Changes')
risk_level_override_number = IntegerField(
"Forced Risk Level", description="Number from 0 to 20", default=0
)
submit = SubmitField("Save Changes")
class CountrySelectForm(FlaskForm): # type: ignore[misc]
country = SelectField("Country", validators=[DataRequired()])
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"})
submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
def final_domain_name(domain_name: str) -> str:
@ -53,8 +53,8 @@ def final_domain_name(domain_name: str) -> str:
return urllib.parse.urlparse(r.url).netloc
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new/<group_id>", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=["GET", "POST"])
def origin_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = NewOriginForm()
form.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
@ -81,22 +81,28 @@ def origin_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return render_template("new.html.j2", section="origin", form=form)
@bp.route('/edit/<origin_id>', methods=['GET', 'POST'])
@bp.route("/edit/<origin_id>", methods=["GET", "POST"])
def origin_edit(origin_id: int) -> ResponseReturnValue:
origin: Optional[Origin] = Origin.query.filter(Origin.id == origin_id).first()
if origin is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="origin",
header="404 Origin Not Found",
message="The requested origin could not be found."),
status=404)
form = EditOriginForm(group=origin.group_id,
message="The requested origin could not be found.",
),
status=404,
)
form = EditOriginForm(
group=origin.group_id,
description=origin.description,
auto_rotate=origin.auto_rotation,
smart_proxy=origin.smart,
asset_domain=origin.assets,
risk_level_override=origin.risk_level_override is not None,
risk_level_override_number=origin.risk_level_override)
risk_level_override_number=origin.risk_level_override,
)
form.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
if form.validate_on_submit():
origin.group_id = form.group.data
@ -114,41 +120,47 @@ def origin_edit(origin_id: int) -> ResponseReturnValue:
flash(f"Saved changes for origin {origin.domain_name}.", "success")
except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the origin.", "danger")
return render_template("origin.html.j2",
section="origin",
origin=origin, form=form)
return render_template("origin.html.j2", section="origin", origin=origin, form=form)
@bp.route("/list")
def origin_list() -> ResponseReturnValue:
origins: List[Origin] = Origin.query.order_by(Origin.domain_name).all()
return render_template("list.html.j2",
return render_template(
"list.html.j2",
section="origin",
title="Web Origins",
item="origin",
new_link=url_for("portal.origin.origin_new"),
items=origins,
extra_buttons=[{
extra_buttons=[
{
"link": url_for("portal.origin.origin_onion"),
"text": "Onion services",
"style": "onion"
}])
"style": "onion",
}
],
)
@bp.route("/onion")
def origin_onion() -> ResponseReturnValue:
origins = Origin.query.order_by(Origin.domain_name).all()
return render_template("list.html.j2",
return render_template(
"list.html.j2",
section="origin",
title="Onion Sites",
item="onion service",
new_link=url_for("portal.onion.onion_new"),
items=origins)
items=origins,
)
@bp.route("/destroy/<origin_id>", methods=['GET', 'POST'])
@bp.route("/destroy/<origin_id>", methods=["GET", "POST"])
def origin_destroy(origin_id: int) -> ResponseReturnValue:
origin = Origin.query.filter(Origin.id == origin_id, Origin.destroyed.is_(None)).first()
origin = Origin.query.filter(
Origin.id == origin_id, Origin.destroyed.is_(None)
).first()
if origin is None:
return response_404("The requested origin could not be found.")
return view_lifecycle(
@ -158,32 +170,44 @@ def origin_destroy(origin_id: int) -> ResponseReturnValue:
success_view="portal.origin.origin_list",
section="origin",
resource=origin,
action="destroy"
action="destroy",
)
@bp.route('/country_remove/<origin_id>/<country_id>', methods=['GET', 'POST'])
@bp.route("/country_remove/<origin_id>/<country_id>", methods=["GET", "POST"])
def origin_country_remove(origin_id: int, country_id: int) -> ResponseReturnValue:
origin = Origin.query.filter(Origin.id == origin_id).first()
if origin is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="origin",
header="404 Pool Not Found",
message="The requested origin could not be found."),
status=404)
message="The requested origin could not be found.",
),
status=404,
)
country = Country.query.filter(Country.id == country_id).first()
if country is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="origin",
header="404 Country Not Found",
message="The requested country could not be found."),
status=404)
message="The requested country could not be found.",
),
status=404,
)
if country not in origin.countries:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="origin",
header="404 Country Not In Pool",
message="The requested country could not be found in the specified origin."),
status=404)
message="The requested country could not be found in the specified origin.",
),
status=404,
)
form = LifecycleForm()
if form.validate_on_submit():
origin.countries.remove(country)
@ -193,32 +217,45 @@ def origin_country_remove(origin_id: int, country_id: int) -> ResponseReturnValu
return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id))
except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the origin.", "danger")
return render_template("lifecycle.html.j2",
return render_template(
"lifecycle.html.j2",
header=f"Remove {country.description} from the {origin.domain_name} origin?",
message="Stop monitoring in this country.",
section="origin",
origin=origin, form=form)
origin=origin,
form=form,
)
@bp.route('/country_add/<origin_id>', methods=['GET', 'POST'])
@bp.route("/country_add/<origin_id>", methods=["GET", "POST"])
def origin_country_add(origin_id: int) -> ResponseReturnValue:
origin = Origin.query.filter(Origin.id == origin_id).first()
if origin is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="origin",
header="404 Origin Not Found",
message="The requested origin could not be found."),
status=404)
message="The requested origin could not be found.",
),
status=404,
)
form = CountrySelectForm()
form.country.choices = [(x.id, f"{x.country_code} - {x.description}") for x in Country.query.all()]
form.country.choices = [
(x.id, f"{x.country_code} - {x.description}") for x in Country.query.all()
]
if form.validate_on_submit():
country = Country.query.filter(Country.id == form.country.data).first()
if country is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="origin",
header="404 Country Not Found",
message="The requested country could not be found."),
status=404)
message="The requested country could not be found.",
),
status=404,
)
origin.countries.append(country)
try:
db.session.commit()
@ -226,8 +263,11 @@ def origin_country_add(origin_id: int) -> ResponseReturnValue:
return redirect(url_for("portal.origin.origin_edit", origin_id=origin.id))
except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the origin.", "danger")
return render_template("lifecycle.html.j2",
return render_template(
"lifecycle.html.j2",
header=f"Add a country to {origin.domain_name}",
message="Enable monitoring from this country:",
section="origin",
origin=origin, form=form)
origin=origin,
form=form,
)

View file

@ -3,8 +3,7 @@ import secrets
from datetime import datetime, timezone
import sqlalchemy
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from wtforms import SelectField, StringField, SubmitField
@ -21,41 +20,50 @@ class NewPoolForm(FlaskForm): # type: ignore[misc]
group_name = StringField("Short Name", validators=[DataRequired()])
description = StringField("Description", validators=[DataRequired()])
redirector_domain = StringField("Redirector Domain")
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"})
submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
class EditPoolForm(FlaskForm): # type: ignore[misc]
description = StringField("Description", validators=[DataRequired()])
redirector_domain = StringField("Redirector Domain")
api_key = StringField("API Key", description=("Any change to this field (e.g. clearing it) will result in the "
"API key being regenerated."))
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"})
api_key = StringField(
"API Key",
description=(
"Any change to this field (e.g. clearing it) will result in the "
"API key being regenerated."
),
)
submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
class GroupSelectForm(FlaskForm): # type: ignore[misc]
group = SelectField("Group", validators=[DataRequired()])
submit = SubmitField('Save Changes', render_kw={"class": "btn btn-success"})
submit = SubmitField("Save Changes", render_kw={"class": "btn btn-success"})
@bp.route("/list")
def pool_list() -> ResponseReturnValue:
pools = Pool.query.order_by(Pool.pool_name).all()
return render_template("list.html.j2",
return render_template(
"list.html.j2",
section="pool",
title="Resource Pools",
item="pool",
items=pools,
new_link=url_for("portal.pool.pool_new"))
new_link=url_for("portal.pool.pool_new"),
)
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
def pool_new() -> ResponseReturnValue:
form = NewPoolForm()
if form.validate_on_submit():
pool = Pool()
pool.pool_name = form.group_name.data
pool.description = form.description.data
pool.redirector_domain = form.redirector_domain.data if form.redirector_domain.data != "" else None
pool.redirector_domain = (
form.redirector_domain.data if form.redirector_domain.data != "" else None
)
pool.api_key = secrets.token_urlsafe(nbytes=32)
pool.added = datetime.now(timezone.utc)
pool.updated = datetime.now(timezone.utc)
@ -71,21 +79,29 @@ def pool_new() -> ResponseReturnValue:
return render_template("new.html.j2", section="pool", form=form)
@bp.route('/edit/<pool_id>', methods=['GET', 'POST'])
@bp.route("/edit/<pool_id>", methods=["GET", "POST"])
def pool_edit(pool_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first()
if pool is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="pool",
header="404 Pool Not Found",
message="The requested pool could not be found."),
status=404)
form = EditPoolForm(description=pool.description,
message="The requested pool could not be found.",
),
status=404,
)
form = EditPoolForm(
description=pool.description,
api_key=pool.api_key,
redirector_domain=pool.redirector_domain)
redirector_domain=pool.redirector_domain,
)
if form.validate_on_submit():
pool.description = form.description.data
pool.redirector_domain = form.redirector_domain.data if form.redirector_domain.data != "" else None
pool.redirector_domain = (
form.redirector_domain.data if form.redirector_domain.data != "" else None
)
if form.api_key.data != pool.api_key:
pool.api_key = secrets.token_urlsafe(nbytes=32)
form.api_key.data = pool.api_key
@ -95,33 +111,43 @@ def pool_edit(pool_id: int) -> ResponseReturnValue:
flash("Saved changes to pool.", "success")
except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the pool.", "danger")
return render_template("pool.html.j2",
section="pool",
pool=pool, form=form)
return render_template("pool.html.j2", section="pool", pool=pool, form=form)
@bp.route('/group_remove/<pool_id>/<group_id>', methods=['GET', 'POST'])
@bp.route("/group_remove/<pool_id>/<group_id>", methods=["GET", "POST"])
def pool_group_remove(pool_id: int, group_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first()
if pool is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="pool",
header="404 Pool Not Found",
message="The requested pool could not be found."),
status=404)
message="The requested pool could not be found.",
),
status=404,
)
group = Group.query.filter(Group.id == group_id).first()
if group is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="pool",
header="404 Group Not Found",
message="The requested group could not be found."),
status=404)
message="The requested group could not be found.",
),
status=404,
)
if group not in pool.groups:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="pool",
header="404 Group Not In Pool",
message="The requested group could not be found in the specified pool."),
status=404)
message="The requested group could not be found in the specified pool.",
),
status=404,
)
form = LifecycleForm()
if form.validate_on_submit():
pool.groups.remove(group)
@ -131,32 +157,43 @@ def pool_group_remove(pool_id: int, group_id: int) -> ResponseReturnValue:
return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id))
except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the pool.", "danger")
return render_template("lifecycle.html.j2",
return render_template(
"lifecycle.html.j2",
header=f"Remove {group.group_name} from the {pool.pool_name} pool?",
message="Resources deployed and available in the pool will be destroyed soon.",
section="pool",
pool=pool, form=form)
pool=pool,
form=form,
)
@bp.route('/group_add/<pool_id>', methods=['GET', 'POST'])
@bp.route("/group_add/<pool_id>", methods=["GET", "POST"])
def pool_group_add(pool_id: int) -> ResponseReturnValue:
pool = Pool.query.filter(Pool.id == pool_id).first()
if pool is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="pool",
header="404 Pool Not Found",
message="The requested pool could not be found."),
status=404)
message="The requested pool could not be found.",
),
status=404,
)
form = GroupSelectForm()
form.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
if form.validate_on_submit():
group = Group.query.filter(Group.id == form.group.data).first()
if group is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="pool",
header="404 Group Not Found",
message="The requested group could not be found."),
status=404)
message="The requested group could not be found.",
),
status=404,
)
pool.groups.append(group)
try:
db.session.commit()
@ -164,8 +201,11 @@ def pool_group_add(pool_id: int) -> ResponseReturnValue:
return redirect(url_for("portal.pool.pool_edit", pool_id=pool.id))
except sqlalchemy.exc.SQLAlchemyError:
flash("An error occurred saving the changes to the pool.", "danger")
return render_template("lifecycle.html.j2",
return render_template(
"lifecycle.html.j2",
header=f"Add a group to {pool.pool_name}",
message="Resources will shortly be deployed and available for all origins in this group.",
section="pool",
pool=pool, form=form)
pool=pool,
form=form,
)

View file

@ -1,5 +1,4 @@
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from sqlalchemy import desc
@ -12,51 +11,63 @@ bp = Blueprint("proxy", __name__)
@bp.route("/list")
def proxy_list() -> ResponseReturnValue:
proxies = Proxy.query.filter(Proxy.destroyed.is_(None)).order_by(desc(Proxy.added)).all()
return render_template("list.html.j2",
section="proxy",
title="Proxies",
item="proxy",
items=proxies)
proxies = (
Proxy.query.filter(Proxy.destroyed.is_(None)).order_by(desc(Proxy.added)).all()
)
return render_template(
"list.html.j2", section="proxy", title="Proxies", item="proxy", items=proxies
)
@bp.route("/expire/<proxy_id>", methods=['GET', 'POST'])
@bp.route("/expire/<proxy_id>", methods=["GET", "POST"])
def proxy_expire(proxy_id: int) -> ResponseReturnValue:
proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first()
if proxy is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Proxy Not Found",
message="The requested proxy could not be found. It may have already been "
"destroyed."))
"destroyed.",
)
)
form = LifecycleForm()
if form.validate_on_submit():
proxy.destroy()
db.session.commit()
flash("Proxy will be shortly retired.", "success")
return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id))
return render_template("lifecycle.html.j2",
return render_template(
"lifecycle.html.j2",
header=f"Expire proxy for {proxy.origin.domain_name} immediately?",
message=proxy.url,
section="proxy",
form=form)
form=form,
)
@bp.route("/block/<proxy_id>", methods=['GET', 'POST'])
@bp.route("/block/<proxy_id>", methods=["GET", "POST"])
def proxy_block(proxy_id: int) -> ResponseReturnValue:
proxy = Proxy.query.filter(Proxy.id == proxy_id, Proxy.destroyed.is_(None)).first()
if proxy is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Proxy Not Found",
message="The requested proxy could not be found. It may have already been "
"destroyed."))
"destroyed.",
)
)
form = LifecycleForm()
if form.validate_on_submit():
proxy.deprecate(reason="manual")
db.session.commit()
flash("Proxy will be shortly replaced.", "success")
return redirect(url_for("portal.origin.origin_edit", origin_id=proxy.origin.id))
return render_template("lifecycle.html.j2",
return render_template(
"lifecycle.html.j2",
header=f"Mark proxy for {proxy.origin.domain_name} as blocked?",
message=proxy.url,
section="proxy",
form=form)
form=form,
)

View file

@ -20,12 +20,12 @@ def generate_subqueries():
deprecations_24hr_subquery = (
db.session.query(
DeprecationAlias.resource_id,
func.count(DeprecationAlias.resource_id).label('deprecations_24hr')
func.count(DeprecationAlias.resource_id).label("deprecations_24hr"),
)
.filter(
DeprecationAlias.reason.like('block_%'),
DeprecationAlias.reason.like("block_%"),
DeprecationAlias.deprecated_at >= now - timedelta(hours=24),
DeprecationAlias.resource_type == 'Proxy'
DeprecationAlias.resource_type == "Proxy",
)
.group_by(DeprecationAlias.resource_id)
.subquery()
@ -33,12 +33,12 @@ def generate_subqueries():
deprecations_72hr_subquery = (
db.session.query(
DeprecationAlias.resource_id,
func.count(DeprecationAlias.resource_id).label('deprecations_72hr')
func.count(DeprecationAlias.resource_id).label("deprecations_72hr"),
)
.filter(
DeprecationAlias.reason.like('block_%'),
DeprecationAlias.reason.like("block_%"),
DeprecationAlias.deprecated_at >= now - timedelta(hours=72),
DeprecationAlias.resource_type == 'Proxy'
DeprecationAlias.resource_type == "Proxy",
)
.group_by(DeprecationAlias.resource_id)
.subquery()
@ -52,13 +52,23 @@ def countries_report():
return (
db.session.query(
Country,
func.coalesce(func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0).label('total_deprecations_24hr'),
func.coalesce(func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0).label('total_deprecations_72hr')
func.coalesce(
func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0
).label("total_deprecations_24hr"),
func.coalesce(
func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0
).label("total_deprecations_72hr"),
)
.join(Origin, Country.origins)
.join(Proxy, Origin.proxies)
.outerjoin(deprecations_24hr_subquery, Proxy.id == deprecations_24hr_subquery.c.resource_id)
.outerjoin(deprecations_72hr_subquery, Proxy.id == deprecations_72hr_subquery.c.resource_id)
.outerjoin(
deprecations_24hr_subquery,
Proxy.id == deprecations_24hr_subquery.c.resource_id,
)
.outerjoin(
deprecations_72hr_subquery,
Proxy.id == deprecations_72hr_subquery.c.resource_id,
)
.group_by(Country.id)
.all()
)
@ -70,12 +80,22 @@ def origins_report():
return (
db.session.query(
Origin,
func.coalesce(func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0).label('total_deprecations_24hr'),
func.coalesce(func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0).label('total_deprecations_72hr')
func.coalesce(
func.sum(deprecations_24hr_subquery.c.deprecations_24hr), 0
).label("total_deprecations_24hr"),
func.coalesce(
func.sum(deprecations_72hr_subquery.c.deprecations_72hr), 0
).label("total_deprecations_72hr"),
)
.outerjoin(Proxy, Origin.proxies)
.outerjoin(deprecations_24hr_subquery, Proxy.id == deprecations_24hr_subquery.c.resource_id)
.outerjoin(deprecations_72hr_subquery, Proxy.id == deprecations_72hr_subquery.c.resource_id)
.outerjoin(
deprecations_24hr_subquery,
Proxy.id == deprecations_24hr_subquery.c.resource_id,
)
.outerjoin(
deprecations_72hr_subquery,
Proxy.id == deprecations_72hr_subquery.c.resource_id,
)
.filter(Origin.destroyed.is_(None))
.group_by(Origin.id)
.order_by(desc("total_deprecations_24hr"))
@ -83,26 +103,37 @@ def origins_report():
)
@report.app_template_filter('country_name')
@report.app_template_filter("country_name")
def country_description_filter(country_code):
country = Country.query.filter_by(country_code=country_code).first()
return country.description if country else None
@report.route("/blocks", methods=['GET'])
@report.route("/blocks", methods=["GET"])
def report_blocks() -> ResponseReturnValue:
blocked_today = db.session.query( # type: ignore[no-untyped-call]
blocked_today = (
db.session.query( # type: ignore[no-untyped-call]
Origin.domain_name,
Origin.description,
Proxy.added,
Proxy.deprecated,
Proxy.deprecation_reason
).join(Origin, Origin.id == Proxy.origin_id
).filter(and_(Proxy.deprecated > datetime.now(tz=timezone.utc) - timedelta(days=1),
Proxy.deprecation_reason.like('block_%'))).all()
Proxy.deprecation_reason,
)
.join(Origin, Origin.id == Proxy.origin_id)
.filter(
and_(
Proxy.deprecated > datetime.now(tz=timezone.utc) - timedelta(days=1),
Proxy.deprecation_reason.like("block_%"),
)
)
.all()
)
return render_template("report_blocks.html.j2",
return render_template(
"report_blocks.html.j2",
blocked_today=blocked_today,
origins=sorted(origins_report(), key=lambda o: o[1], reverse=True),
countries=sorted(countries_report(), key=lambda c: c[0].risk_level, reverse=True),
countries=sorted(
countries_report(), key=lambda c: c[0].risk_level, reverse=True
),
)

View file

@ -9,9 +9,15 @@ bp = Blueprint("smart_proxy", __name__)
@bp.route("/list")
def smart_proxy_list() -> ResponseReturnValue:
instances = SmartProxy.query.filter(SmartProxy.destroyed.is_(None)).order_by(desc(SmartProxy.added)).all()
return render_template("list.html.j2",
instances = (
SmartProxy.query.filter(SmartProxy.destroyed.is_(None))
.order_by(desc(SmartProxy.added))
.all()
)
return render_template(
"list.html.j2",
section="smart_proxy",
title="Smart Proxy Instances",
item="smart proxy",
items=instances)
items=instances,
)

View file

@ -2,13 +2,19 @@ import logging
from typing import Any, List, Optional
import sqlalchemy.exc
from flask import (Blueprint, Response, current_app, flash, redirect,
render_template, url_for)
from flask import (
Blueprint,
Response,
current_app,
flash,
redirect,
render_template,
url_for,
)
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from sqlalchemy import exc
from wtforms import (BooleanField, FileField, SelectField, StringField,
SubmitField)
from wtforms import BooleanField, FileField, SelectField, StringField, SubmitField
from wtforms.validators import DataRequired
from app.brm.static import create_static_origin
@ -22,87 +28,99 @@ bp = Blueprint("static", __name__)
class StaticOriginForm(FlaskForm): # type: ignore
description = StringField(
'Description',
"Description",
validators=[DataRequired()],
description='Enter a brief description of the static website that you are creating in this field. This is '
'also a required field.'
description="Enter a brief description of the static website that you are creating in this field. This is "
"also a required field.",
)
group = SelectField(
'Group',
"Group",
validators=[DataRequired()],
description='Select the group that you want the origin to belong to from the drop-down menu in this field. '
'This is a required field.'
description="Select the group that you want the origin to belong to from the drop-down menu in this field. "
"This is a required field.",
)
storage_cloud_account = SelectField(
'Storage Cloud Account',
"Storage Cloud Account",
validators=[DataRequired()],
description='Select the cloud account that you want the origin to be deployed to from the drop-down menu in '
'this field. This is a required field.'
description="Select the cloud account that you want the origin to be deployed to from the drop-down menu in "
"this field. This is a required field.",
)
source_cloud_account = SelectField(
'Source Cloud Account',
"Source Cloud Account",
validators=[DataRequired()],
description='Select the cloud account that will be used to modify the source repository for the web content '
'for this static origin. This is a required field.'
description="Select the cloud account that will be used to modify the source repository for the web content "
"for this static origin. This is a required field.",
)
source_project = StringField(
'Source Project',
"Source Project",
validators=[DataRequired()],
description='GitLab project path.'
description="GitLab project path.",
)
auto_rotate = BooleanField(
'Auto-Rotate',
"Auto-Rotate",
default=True,
description='Select this field if you want to enable auto-rotation for the mirror. This means that the mirror '
'will automatically redeploy with a new domain name if it is detected to be blocked. This field '
'is optional and is enabled by default.'
description="Select this field if you want to enable auto-rotation for the mirror. This means that the mirror "
"will automatically redeploy with a new domain name if it is detected to be blocked. This field "
"is optional and is enabled by default.",
)
matrix_homeserver = SelectField(
'Matrix Homeserver',
description='Select the Matrix homeserver from the drop-down box to enable Keanu Convene on mirrors of this '
'static origin.'
"Matrix Homeserver",
description="Select the Matrix homeserver from the drop-down box to enable Keanu Convene on mirrors of this "
"static origin.",
)
keanu_convene_path = StringField(
'Keanu Convene Path',
default='talk',
description='Enter the subdirectory to present the Keanu Convene application at on the mirror. This defaults '
'to "talk".'
"Keanu Convene Path",
default="talk",
description="Enter the subdirectory to present the Keanu Convene application at on the mirror. This defaults "
'to "talk".',
)
keanu_convene_logo = FileField(
'Keanu Convene Logo',
description='Logo to use for Keanu Convene'
"Keanu Convene Logo", description="Logo to use for Keanu Convene"
)
keanu_convene_color = StringField(
'Keanu Convene Accent Color',
default='#0047ab',
description='Accent color to use for Keanu Convene (HTML hex code)'
"Keanu Convene Accent Color",
default="#0047ab",
description="Accent color to use for Keanu Convene (HTML hex code)",
)
enable_clean_insights = BooleanField(
'Enable Clean Insights',
description='When enabled, a Clean Insights Measurement Proxy endpoint is deployed on the mirror to allow for '
'submission of results from any of the supported Clean Insights SDKs.'
"Enable Clean Insights",
description="When enabled, a Clean Insights Measurement Proxy endpoint is deployed on the mirror to allow for "
"submission of results from any of the supported Clean Insights SDKs.",
)
submit = SubmitField('Save Changes')
submit = SubmitField("Save Changes")
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.group.choices = [(x.id, x.group_name) for x in Group.query.all()]
self.storage_cloud_account.choices = [(x.id, f"{x.provider.description} - {x.description}") for x in
CloudAccount.query.filter(
CloudAccount.provider == CloudProvider.AWS).all()]
self.source_cloud_account.choices = [(x.id, f"{x.provider.description} - {x.description}") for x in
CloudAccount.query.filter(
CloudAccount.provider == CloudProvider.GITLAB).all()]
self.matrix_homeserver.choices = [(x, x) for x in current_app.config['MATRIX_HOMESERVERS']]
self.storage_cloud_account.choices = [
(x.id, f"{x.provider.description} - {x.description}")
for x in CloudAccount.query.filter(
CloudAccount.provider == CloudProvider.AWS
).all()
]
self.source_cloud_account.choices = [
(x.id, f"{x.provider.description} - {x.description}")
for x in CloudAccount.query.filter(
CloudAccount.provider == CloudProvider.GITLAB
).all()
]
self.matrix_homeserver.choices = [
(x, x) for x in current_app.config["MATRIX_HOMESERVERS"]
]
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new/<group_id>", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
@bp.route("/new/<group_id>", methods=["GET", "POST"])
def static_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form = StaticOriginForm()
if len(form.source_cloud_account.choices) == 0 or len(form.storage_cloud_account.choices) == 0:
flash("You must add at least one AWS account and at least one GitLab account before creating static origins.",
"warning")
if (
len(form.source_cloud_account.choices) == 0
or len(form.storage_cloud_account.choices) == 0
):
flash(
"You must add at least one AWS account and at least one GitLab account before creating static origins.",
"warning",
)
return redirect(url_for("portal.cloud.cloud_account_list"))
if form.validate_on_submit():
try:
@ -118,16 +136,22 @@ def static_new(group_id: Optional[int] = None) -> ResponseReturnValue:
form.keanu_convene_logo.data,
form.keanu_convene_color.data,
form.enable_clean_insights.data,
True
True,
)
flash(f"Created new static origin #{static.id}.", "success")
return redirect(url_for("portal.static.static_edit", static_id=static.id))
except ValueError as e: # may be returned by create_static_origin and from the int conversion
except (
ValueError
) as e: # may be returned by create_static_origin and from the int conversion
logging.warning(e)
flash("Failed to create new static origin due to an invalid input.", "danger")
flash(
"Failed to create new static origin due to an invalid input.", "danger"
)
return redirect(url_for("portal.static.static_list"))
except exc.SQLAlchemyError as e:
flash("Failed to create new static origin due to a database error.", "danger")
flash(
"Failed to create new static origin due to a database error.", "danger"
)
logging.warning(e)
return redirect(url_for("portal.static.static_list"))
if group_id:
@ -135,16 +159,23 @@ def static_new(group_id: Optional[int] = None) -> ResponseReturnValue:
return render_template("new.html.j2", section="static", form=form)
@bp.route('/edit/<static_id>', methods=['GET', 'POST'])
@bp.route("/edit/<static_id>", methods=["GET", "POST"])
def static_edit(static_id: int) -> ResponseReturnValue:
static_origin: Optional[StaticOrigin] = StaticOrigin.query.filter(StaticOrigin.id == static_id).first()
static_origin: Optional[StaticOrigin] = StaticOrigin.query.filter(
StaticOrigin.id == static_id
).first()
if static_origin is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="static",
header="404 Origin Not Found",
message="The requested static origin could not be found."),
status=404)
form = StaticOriginForm(description=static_origin.description,
message="The requested static origin could not be found.",
),
status=404,
)
form = StaticOriginForm(
description=static_origin.description,
group=static_origin.group_id,
storage_cloud_account=static_origin.storage_cloud_account_id,
source_cloud_account=static_origin.source_cloud_account_id,
@ -152,7 +183,8 @@ def static_edit(static_id: int) -> ResponseReturnValue:
matrix_homeserver=static_origin.matrix_homeserver,
keanu_convene_path=static_origin.keanu_convene_path,
auto_rotate=static_origin.auto_rotate,
enable_clean_insights=bool(static_origin.clean_insights_backend))
enable_clean_insights=bool(static_origin.clean_insights_backend),
)
form.group.render_kw = {"disabled": ""}
form.storage_cloud_account.render_kw = {"disabled": ""}
form.source_cloud_account.render_kw = {"disabled": ""}
@ -167,41 +199,59 @@ def static_edit(static_id: int) -> ResponseReturnValue:
form.keanu_convene_logo.data,
form.keanu_convene_color.data,
form.enable_clean_insights.data,
True
True,
)
flash("Saved changes to group.", "success")
except ValueError as e: # may be returned by create_static_origin and from the int conversion
except (
ValueError
) as e: # may be returned by create_static_origin and from the int conversion
logging.warning(e)
flash("An error occurred saving the changes to the static origin due to an invalid input.", "danger")
flash(
"An error occurred saving the changes to the static origin due to an invalid input.",
"danger",
)
except exc.SQLAlchemyError as e:
logging.warning(e)
flash("An error occurred saving the changes to the static origin due to a database error.", "danger")
flash(
"An error occurred saving the changes to the static origin due to a database error.",
"danger",
)
try:
origin = Origin.query.filter_by(domain_name=static_origin.origin_domain_name).one()
origin = Origin.query.filter_by(
domain_name=static_origin.origin_domain_name
).one()
proxies = origin.proxies
except sqlalchemy.exc.NoResultFound:
proxies = []
return render_template("static.html.j2",
return render_template(
"static.html.j2",
section="static",
static=static_origin, form=form,
proxies=proxies)
static=static_origin,
form=form,
proxies=proxies,
)
@bp.route("/list")
def static_list() -> ResponseReturnValue:
statics: List[StaticOrigin] = StaticOrigin.query.order_by(StaticOrigin.description).all()
return render_template("list.html.j2",
statics: List[StaticOrigin] = StaticOrigin.query.order_by(
StaticOrigin.description
).all()
return render_template(
"list.html.j2",
section="static",
title="Static Origins",
item="static",
new_link=url_for("portal.static.static_new"),
items=statics
items=statics,
)
@bp.route("/destroy/<static_id>", methods=['GET', 'POST'])
@bp.route("/destroy/<static_id>", methods=["GET", "POST"])
def static_destroy(static_id: int) -> ResponseReturnValue:
static = StaticOrigin.query.filter(StaticOrigin.id == static_id, StaticOrigin.destroyed.is_(None)).first()
static = StaticOrigin.query.filter(
StaticOrigin.id == static_id, StaticOrigin.destroyed.is_(None)
).first()
if static is None:
return response_404("The requested static origin could not be found.")
return view_lifecycle(
@ -212,5 +262,5 @@ def static_destroy(static_id: int) -> ResponseReturnValue:
success_view="portal.static.static_list",
section="static",
resource=static,
action="destroy"
action="destroy",
)

View file

@ -17,24 +17,30 @@ bp = Blueprint("storage", __name__)
_SECTION_TEMPLATE_VARS = {
"section": "automation",
"help_url": "https://bypass.censorship.guide/user/automation.html"
"help_url": "https://bypass.censorship.guide/user/automation.html",
}
class EditStorageForm(FlaskForm): # type: ignore
force_unlock = BooleanField('Force Unlock')
submit = SubmitField('Save Changes')
force_unlock = BooleanField("Force Unlock")
submit = SubmitField("Save Changes")
@bp.route('/edit/<storage_key>', methods=['GET', 'POST'])
@bp.route("/edit/<storage_key>", methods=["GET", "POST"])
def storage_edit(storage_key: str) -> ResponseReturnValue:
storage: Optional[TerraformState] = TerraformState.query.filter(TerraformState.key == storage_key).first()
storage: Optional[TerraformState] = TerraformState.query.filter(
TerraformState.key == storage_key
).first()
if storage is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
header="404 Storage Key Not Found",
message="The requested storage could not be found.",
**_SECTION_TEMPLATE_VARS),
status=404)
**_SECTION_TEMPLATE_VARS
),
status=404,
)
form = EditStorageForm()
if form.validate_on_submit():
if form.force_unlock.data:
@ -45,17 +51,16 @@ def storage_edit(storage_key: str) -> ResponseReturnValue:
flash("Storage has been force unlocked.", "success")
except exc.SQLAlchemyError:
flash("An error occurred unlocking the storage.", "danger")
return render_template("storage.html.j2",
storage=storage,
form=form,
**_SECTION_TEMPLATE_VARS)
return render_template(
"storage.html.j2", storage=storage, form=form, **_SECTION_TEMPLATE_VARS
)
@bp.route("/kick/<automation_id>", methods=['GET', 'POST'])
@bp.route("/kick/<automation_id>", methods=["GET", "POST"])
def automation_kick(automation_id: int) -> ResponseReturnValue:
automation = Automation.query.filter(
Automation.id == automation_id,
Automation.destroyed.is_(None)).first()
Automation.id == automation_id, Automation.destroyed.is_(None)
).first()
if automation is None:
return response_404("The requested bridge configuration could not be found.")
return view_lifecycle(
@ -65,5 +70,5 @@ def automation_kick(automation_id: int) -> ResponseReturnValue:
success_view="portal.automation.automation_list",
success_message="This automation job will next run within 1 minute.",
resource=automation,
action="kick"
action="kick",
)

View file

@ -9,19 +9,21 @@ from app.models.activity import Activity
def response_404(message: str) -> ResponseReturnValue:
return Response(render_template("error.html.j2",
header="404 Not Found",
message=message))
return Response(
render_template("error.html.j2", header="404 Not Found", message=message)
)
def view_lifecycle(*,
def view_lifecycle(
*,
header: str,
message: str,
success_message: str,
success_view: str,
section: str,
resource: AbstractResource,
action: str) -> ResponseReturnValue:
action: str,
) -> ResponseReturnValue:
form = LifecycleForm()
if action == "destroy":
form.submit.render_kw = {"class": "btn btn-danger"}
@ -41,19 +43,17 @@ def view_lifecycle(*,
return redirect(url_for("portal.portal_home"))
activity = Activity(
activity_type="lifecycle",
text=f"Portal action: {message}. {success_message}"
text=f"Portal action: {message}. {success_message}",
)
db.session.add(activity)
db.session.commit()
activity.notify()
flash(success_message, "success")
return redirect(url_for(success_view))
return render_template("lifecycle.html.j2",
header=header,
message=message,
section=section,
form=form)
return render_template(
"lifecycle.html.j2", header=header, message=message, section=section, form=form
)
class LifecycleForm(FlaskForm): # type: ignore
submit = SubmitField('Confirm')
submit = SubmitField("Confirm")

View file

@ -1,8 +1,7 @@
from datetime import datetime, timezone
from typing import Optional
from flask import (Blueprint, Response, flash, redirect, render_template,
url_for)
from flask import Blueprint, Response, flash, redirect, render_template, url_for
from flask.typing import ResponseReturnValue
from flask_wtf import FlaskForm
from sqlalchemy import exc
@ -26,47 +25,54 @@ def webhook_format_name(key: str) -> str:
class NewWebhookForm(FlaskForm): # type: ignore
description = StringField('Description', validators=[DataRequired()])
format = SelectField('Format', choices=[
("telegram", "Telegram"),
("matrix", "Matrix")
], validators=[DataRequired()])
url = StringField('URL', validators=[DataRequired()])
submit = SubmitField('Save Changes')
description = StringField("Description", validators=[DataRequired()])
format = SelectField(
"Format",
choices=[("telegram", "Telegram"), ("matrix", "Matrix")],
validators=[DataRequired()],
)
url = StringField("URL", validators=[DataRequired()])
submit = SubmitField("Save Changes")
@bp.route("/new", methods=['GET', 'POST'])
@bp.route("/new", methods=["GET", "POST"])
def webhook_new() -> ResponseReturnValue:
form = NewWebhookForm()
if form.validate_on_submit():
webhook = Webhook(
description=form.description.data,
format=form.format.data,
url=form.url.data
url=form.url.data,
)
try:
db.session.add(webhook)
db.session.commit()
flash(f"Created new webhook {webhook.url}.", "success")
return redirect(url_for("portal.webhook.webhook_edit", webhook_id=webhook.id))
return redirect(
url_for("portal.webhook.webhook_edit", webhook_id=webhook.id)
)
except exc.SQLAlchemyError:
flash("Failed to create new webhook.", "danger")
return redirect(url_for("portal.webhook.webhook_list"))
return render_template("new.html.j2", section="webhook", form=form)
@bp.route('/edit/<webhook_id>', methods=['GET', 'POST'])
@bp.route("/edit/<webhook_id>", methods=["GET", "POST"])
def webhook_edit(webhook_id: int) -> ResponseReturnValue:
webhook = Webhook.query.filter(Webhook.id == webhook_id).first()
if webhook is None:
return Response(render_template("error.html.j2",
return Response(
render_template(
"error.html.j2",
section="webhook",
header="404 Webhook Not Found",
message="The requested webhook could not be found."),
status=404)
form = NewWebhookForm(description=webhook.description,
format=webhook.format,
url=webhook.url)
message="The requested webhook could not be found.",
),
status=404,
)
form = NewWebhookForm(
description=webhook.description, format=webhook.format, url=webhook.url
)
if form.validate_on_submit():
webhook.description = form.description.data
webhook.format = form.description.data
@ -77,26 +83,29 @@ def webhook_edit(webhook_id: int) -> ResponseReturnValue:
flash("Saved changes to webhook.", "success")
except exc.SQLAlchemyError:
flash("An error occurred saving the changes to the webhook.", "danger")
return render_template("edit.html.j2",
section="webhook",
title="Edit Webhook",
item=webhook, form=form)
return render_template(
"edit.html.j2", section="webhook", title="Edit Webhook", item=webhook, form=form
)
@bp.route("/list")
def webhook_list() -> ResponseReturnValue:
webhooks = Webhook.query.all()
return render_template("list.html.j2",
return render_template(
"list.html.j2",
section="webhook",
title="Webhooks",
item="webhook",
new_link=url_for("portal.webhook.webhook_new"),
items=webhooks)
items=webhooks,
)
@bp.route("/destroy/<webhook_id>", methods=['GET', 'POST'])
@bp.route("/destroy/<webhook_id>", methods=["GET", "POST"])
def webhook_destroy(webhook_id: int) -> ResponseReturnValue:
webhook: Optional[Webhook] = Webhook.query.filter(Webhook.id == webhook_id, Webhook.destroyed.is_(None)).first()
webhook: Optional[Webhook] = Webhook.query.filter(
Webhook.id == webhook_id, Webhook.destroyed.is_(None)
).first()
if webhook is None:
return response_404("The requested webhook could not be found.")
return view_lifecycle(
@ -106,5 +115,5 @@ def webhook_destroy(webhook_id: int) -> ResponseReturnValue:
success_view="portal.webhook.webhook_list",
section="webhook",
resource=webhook,
action="destroy"
action="destroy",
)

View file

@ -12,6 +12,7 @@ class DeterministicZip:
Heavily inspired by https://github.com/bboe/deterministic_zip.
"""
zipfile: ZipFile
def __init__(self, filename: str):
@ -67,15 +68,22 @@ class BaseAutomation:
if not self.working_dir:
raise RuntimeError("No working directory specified.")
tmpl = jinja2.Template(template)
with open(os.path.join(self.working_dir, filename), 'w', encoding="utf-8") as tfconf:
with open(
os.path.join(self.working_dir, filename), "w", encoding="utf-8"
) as tfconf:
tfconf.write(tmpl.render(**kwargs))
def bin_write(self, filename: str, data: bytes, group_id: Optional[int] = None) -> None:
def bin_write(
self, filename: str, data: bytes, group_id: Optional[int] = None
) -> None:
if not self.working_dir:
raise RuntimeError("No working directory specified.")
try:
os.mkdir(os.path.join(self.working_dir, str(group_id)))
except FileExistsError:
pass
with open(os.path.join(self.working_dir, str(group_id) if group_id else "", filename), 'wb') as binfile:
with open(
os.path.join(self.working_dir, str(group_id) if group_id else "", filename),
"wb",
) as binfile:
binfile.write(data)

View file

@ -13,33 +13,38 @@ from app.terraform import BaseAutomation
def alarms_in_region(region: str, prefix: str, aspect: str) -> None:
cloudwatch = boto3.client('cloudwatch',
aws_access_key_id=app.config['AWS_ACCESS_KEY'],
aws_secret_access_key=app.config['AWS_SECRET_KEY'],
region_name=region)
dist_paginator = cloudwatch.get_paginator('describe_alarms')
cloudwatch = boto3.client(
"cloudwatch",
aws_access_key_id=app.config["AWS_ACCESS_KEY"],
aws_secret_access_key=app.config["AWS_SECRET_KEY"],
region_name=region,
)
dist_paginator = cloudwatch.get_paginator("describe_alarms")
page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix)
for page in page_iterator:
for cw_alarm in page['MetricAlarms']:
eotk_id = cw_alarm["AlarmName"][len(prefix):].split("-")
group: Optional[Group] = Group.query.filter(func.lower(Group.group_name) == eotk_id[1]).first()
for cw_alarm in page["MetricAlarms"]:
eotk_id = cw_alarm["AlarmName"][len(prefix) :].split("-")
group: Optional[Group] = Group.query.filter(
func.lower(Group.group_name) == eotk_id[1]
).first()
if group is None:
print("Unable to find group for " + cw_alarm['AlarmName'])
print("Unable to find group for " + cw_alarm["AlarmName"])
continue
eotk = Eotk.query.filter(
Eotk.group_id == group.id,
Eotk.region == region
Eotk.group_id == group.id, Eotk.region == region
).first()
if eotk is None:
print("Skipping unknown instance " + cw_alarm['AlarmName'])
print("Skipping unknown instance " + cw_alarm["AlarmName"])
continue
alarm = get_or_create_alarm(eotk.brn, aspect)
if cw_alarm['StateValue'] == "OK":
if cw_alarm["StateValue"] == "OK":
alarm.update_state(AlarmState.OK, "CloudWatch alarm OK")
elif cw_alarm['StateValue'] == "ALARM":
elif cw_alarm["StateValue"] == "ALARM":
alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM")
else:
alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}")
alarm.update_state(
AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}"
)
class AlarmEotkAwsAutomation(BaseAutomation):

View file

@ -16,20 +16,19 @@ class AlarmProxyAzureCdnAutomation(BaseAutomation):
def automate(self, full: bool = False) -> Tuple[bool, str]:
credential = ClientSecretCredential(
tenant_id=app.config['AZURE_TENANT_ID'],
client_id=app.config['AZURE_CLIENT_ID'],
client_secret=app.config['AZURE_CLIENT_SECRET'])
client = AlertsManagementClient(
credential,
app.config['AZURE_SUBSCRIPTION_ID']
tenant_id=app.config["AZURE_TENANT_ID"],
client_id=app.config["AZURE_CLIENT_ID"],
client_secret=app.config["AZURE_CLIENT_SECRET"],
)
firing = [x.name[len("bandwidth-out-high-bc-"):]
client = AlertsManagementClient(credential, app.config["AZURE_SUBSCRIPTION_ID"])
firing = [
x.name[len("bandwidth-out-high-bc-") :]
for x in client.alerts.get_all()
if x.name.startswith("bandwidth-out-high-bc-")
and x.properties.essentials.monitor_condition == "Fired"]
and x.properties.essentials.monitor_condition == "Fired"
]
for proxy in Proxy.query.filter(
Proxy.provider == "azure_cdn",
Proxy.destroyed.is_(None)
Proxy.provider == "azure_cdn", Proxy.destroyed.is_(None)
):
alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high")
if proxy.origin.group.group_name.lower() not in firing:

View file

@ -16,9 +16,8 @@ def _cloudfront_quota() -> None:
# It would be nice to learn this from the Service Quotas API, however
# at the time of writing this comment, the current value for this quota
# is not available from the API. It just doesn't return anything.
max_count = int(current_app.config.get('AWS_CLOUDFRONT_MAX_DISTRIBUTIONS', 200))
deployed_count = len(Proxy.query.filter(
Proxy.destroyed.is_(None)).all())
max_count = int(current_app.config.get("AWS_CLOUDFRONT_MAX_DISTRIBUTIONS", 200))
deployed_count = len(Proxy.query.filter(Proxy.destroyed.is_(None)).all())
message = f"{deployed_count} distributions deployed of {max_count} quota"
alarm = get_or_create_alarm(
BRN(
@ -26,9 +25,9 @@ def _cloudfront_quota() -> None:
product="mirror",
provider="cloudfront",
resource_type="quota",
resource_id="distributions"
resource_id="distributions",
),
"quota-usage"
"quota-usage",
)
if deployed_count > max_count * 0.9:
alarm.update_state(AlarmState.CRITICAL, message)
@ -39,26 +38,30 @@ def _cloudfront_quota() -> None:
def _proxy_alarms() -> None:
cloudwatch = boto3.client('cloudwatch',
aws_access_key_id=app.config['AWS_ACCESS_KEY'],
aws_secret_access_key=app.config['AWS_SECRET_KEY'],
region_name='us-east-2')
dist_paginator = cloudwatch.get_paginator('describe_alarms')
cloudwatch = boto3.client(
"cloudwatch",
aws_access_key_id=app.config["AWS_ACCESS_KEY"],
aws_secret_access_key=app.config["AWS_SECRET_KEY"],
region_name="us-east-2",
)
dist_paginator = cloudwatch.get_paginator("describe_alarms")
page_iterator = dist_paginator.paginate(AlarmNamePrefix="bandwidth-out-high-")
for page in page_iterator:
for cw_alarm in page['MetricAlarms']:
dist_id = cw_alarm["AlarmName"][len("bandwidth-out-high-"):]
for cw_alarm in page["MetricAlarms"]:
dist_id = cw_alarm["AlarmName"][len("bandwidth-out-high-") :]
proxy = Proxy.query.filter(Proxy.slug == dist_id).first()
if proxy is None:
print("Skipping unknown proxy " + dist_id)
continue
alarm = get_or_create_alarm(proxy.brn, "bandwidth-out-high")
if cw_alarm['StateValue'] == "OK":
if cw_alarm["StateValue"] == "OK":
alarm.update_state(AlarmState.OK, "CloudWatch alarm OK")
elif cw_alarm['StateValue'] == "ALARM":
elif cw_alarm["StateValue"] == "ALARM":
alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM")
else:
alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}")
alarm.update_state(
AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}"
)
class AlarmProxyCloudfrontAutomation(BaseAutomation):

View file

@ -16,39 +16,25 @@ class AlarmProxyHTTPStatusAutomation(BaseAutomation):
frequency = 45
def automate(self, full: bool = False) -> Tuple[bool, str]:
proxies = Proxy.query.filter(
Proxy.destroyed.is_(None)
)
proxies = Proxy.query.filter(Proxy.destroyed.is_(None))
for proxy in proxies:
try:
if proxy.url is None:
continue
r = requests.get(proxy.url,
allow_redirects=False,
timeout=5)
r = requests.get(proxy.url, allow_redirects=False, timeout=5)
r.raise_for_status()
alarm = get_or_create_alarm(proxy.brn, "http-status")
if r.is_redirect:
alarm.update_state(
AlarmState.CRITICAL,
f"{r.status_code} {r.reason}"
AlarmState.CRITICAL, f"{r.status_code} {r.reason}"
)
else:
alarm.update_state(
AlarmState.OK,
f"{r.status_code} {r.reason}"
)
alarm.update_state(AlarmState.OK, f"{r.status_code} {r.reason}")
except requests.HTTPError:
alarm = get_or_create_alarm(proxy.brn, "http-status")
alarm.update_state(
AlarmState.CRITICAL,
f"{r.status_code} {r.reason}"
)
alarm.update_state(AlarmState.CRITICAL, f"{r.status_code} {r.reason}")
except RequestException as e:
alarm = get_or_create_alarm(proxy.brn, "http-status")
alarm.update_state(
AlarmState.CRITICAL,
repr(e)
)
alarm.update_state(AlarmState.CRITICAL, repr(e))
db.session.commit()
return True, ""

View file

@ -13,33 +13,38 @@ from app.terraform import BaseAutomation
def alarms_in_region(region: str, prefix: str, aspect: str) -> None:
cloudwatch = boto3.client('cloudwatch',
aws_access_key_id=app.config['AWS_ACCESS_KEY'],
aws_secret_access_key=app.config['AWS_SECRET_KEY'],
region_name=region)
dist_paginator = cloudwatch.get_paginator('describe_alarms')
cloudwatch = boto3.client(
"cloudwatch",
aws_access_key_id=app.config["AWS_ACCESS_KEY"],
aws_secret_access_key=app.config["AWS_SECRET_KEY"],
region_name=region,
)
dist_paginator = cloudwatch.get_paginator("describe_alarms")
page_iterator = dist_paginator.paginate(AlarmNamePrefix=prefix)
for page in page_iterator:
for cw_alarm in page['MetricAlarms']:
smart_id = cw_alarm["AlarmName"][len(prefix):].split("-")
group: Optional[Group] = Group.query.filter(func.lower(Group.group_name) == smart_id[1]).first()
for cw_alarm in page["MetricAlarms"]:
smart_id = cw_alarm["AlarmName"][len(prefix) :].split("-")
group: Optional[Group] = Group.query.filter(
func.lower(Group.group_name) == smart_id[1]
).first()
if group is None:
print("Unable to find group for " + cw_alarm['AlarmName'])
print("Unable to find group for " + cw_alarm["AlarmName"])
continue
smart_proxy = SmartProxy.query.filter(
SmartProxy.group_id == group.id,
SmartProxy.region == region
SmartProxy.group_id == group.id, SmartProxy.region == region
).first()
if smart_proxy is None:
print("Skipping unknown instance " + cw_alarm['AlarmName'])
print("Skipping unknown instance " + cw_alarm["AlarmName"])
continue
alarm = get_or_create_alarm(smart_proxy.brn, aspect)
if cw_alarm['StateValue'] == "OK":
if cw_alarm["StateValue"] == "OK":
alarm.update_state(AlarmState.OK, "CloudWatch alarm OK")
elif cw_alarm['StateValue'] == "ALARM":
elif cw_alarm["StateValue"] == "ALARM":
alarm.update_state(AlarmState.CRITICAL, "CloudWatch alarm ALARM")
else:
alarm.update_state(AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}")
alarm.update_state(
AlarmState.UNKNOWN, f"CloudWatch alarm {cw_alarm['StateValue']}"
)
class AlarmSmartAwsAutomation(BaseAutomation):

View file

@ -16,7 +16,7 @@ def clean_json_response(raw_response: str) -> Dict[str, Any]:
"""
end_index = raw_response.rfind("}")
if end_index != -1:
raw_response = raw_response[:end_index + 1]
raw_response = raw_response[: end_index + 1]
response: Dict[str, Any] = json.loads(raw_response)
return response
@ -27,20 +27,21 @@ def request_test_now(test_url: str) -> str:
"User-Agent": "bypasscensorship.org",
"Content-Type": "application/json;charset=utf-8",
"Pragma": "no-cache",
"Cache-Control": "no-cache"
"Cache-Control": "no-cache",
}
request_count = 0
while request_count < 180:
params = {
"url": test_url,
"timestamp": str(int(time.time())) # unix timestamp
}
response = requests.post(api_url, params=params, headers=headers, json={}, timeout=30)
params = {"url": test_url, "timestamp": str(int(time.time()))} # unix timestamp
response = requests.post(
api_url, params=params, headers=headers, json={}, timeout=30
)
response_data = clean_json_response(response.text)
print(f"Response: {response_data}")
if "url_test_id" in response_data.get("d", {}):
url_test_id: str = response_data["d"]["url_test_id"]
logging.debug("Test result for %s has test result ID %s", test_url, url_test_id)
logging.debug(
"Test result for %s has test result ID %s", test_url, url_test_id
)
return url_test_id
request_count += 1
time.sleep(2)
@ -52,13 +53,19 @@ def request_test_result(url_test_id: str) -> int:
headers = {
"User-Agent": "bypasscensorship.org",
"Pragma": "no-cache",
"Cache-Control": "no-cache"
"Cache-Control": "no-cache",
}
response = requests.get(url, headers=headers, timeout=30)
response_data = response.json()
tests = response_data.get("d", [])
non_zero_curl_exit_count: int = sum(1 for test in tests if test.get("curl_exit_value") != "0")
logging.debug("Test result for %s has %s non-zero exit values", url_test_id, non_zero_curl_exit_count)
non_zero_curl_exit_count: int = sum(
1 for test in tests if test.get("curl_exit_value") != "0"
)
logging.debug(
"Test result for %s has %s non-zero exit values",
url_test_id,
non_zero_curl_exit_count,
)
return non_zero_curl_exit_count
@ -81,7 +88,7 @@ class BlockBlockyAutomation(BlockMirrorAutomation):
Proxy.url.is_not(None),
Proxy.deprecated.is_(None),
Proxy.destroyed.is_(None),
Proxy.pool_id != -1
Proxy.pool_id != -1,
)
.all()
)

View file

@ -15,7 +15,8 @@ class BlockBridgeScriptzteamAutomation(BlockBridgelinesAutomation):
def fetch(self) -> None:
r = requests.get(
"https://raw.githubusercontent.com/scriptzteam/Tor-Bridges-Collector/main/bridges-obfs4",
timeout=60)
timeout=60,
)
r.encoding = "utf-8"
contents = r.text
self._lines = contents.splitlines()

View file

@ -24,7 +24,8 @@ class BlockBridgeAutomation(BaseAutomation):
self.hashed_fingerprints = []
super().__init__(*args, **kwargs)
def perform_deprecations(self, ids: List[str], bridge_select_func: Callable[[str], Optional[Bridge]]
def perform_deprecations(
self, ids: List[str], bridge_select_func: Callable[[str], Optional[Bridge]]
) -> List[Tuple[Optional[str], Any, Any]]:
rotated = []
for id_ in ids:
@ -37,7 +38,13 @@ class BlockBridgeAutomation(BaseAutomation):
continue
if bridge.deprecate(reason=self.short_name):
logging.info("Rotated %s", bridge.hashed_fingerprint)
rotated.append((bridge.fingerprint, bridge.cloud_account.provider, bridge.cloud_account.description))
rotated.append(
(
bridge.fingerprint,
bridge.cloud_account.provider,
bridge.cloud_account.description,
)
)
else:
logging.debug("Not rotating a bridge that is already deprecated")
return rotated
@ -50,15 +57,28 @@ class BlockBridgeAutomation(BaseAutomation):
rotated = []
rotated.extend(self.perform_deprecations(self.ips, get_bridge_by_ip))
logging.debug("Blocked by IP")
rotated.extend(self.perform_deprecations(self.fingerprints, get_bridge_by_fingerprint))
rotated.extend(
self.perform_deprecations(self.fingerprints, get_bridge_by_fingerprint)
)
logging.debug("Blocked by fingerprint")
rotated.extend(self.perform_deprecations(self.hashed_fingerprints, get_bridge_by_hashed_fingerprint))
rotated.extend(
self.perform_deprecations(
self.hashed_fingerprints, get_bridge_by_hashed_fingerprint
)
)
logging.debug("Blocked by hashed fingerprint")
if rotated:
activity = Activity(
activity_type="block",
text=(f"[{self.short_name}] ♻ Rotated {len(rotated)} bridges: \n"
+ "\n".join([f"* {fingerprint} ({provider}: {provider_description})" for fingerprint, provider, provider_description in rotated]))
text=(
f"[{self.short_name}] ♻ Rotated {len(rotated)} bridges: \n"
+ "\n".join(
[
f"* {fingerprint} ({provider}: {provider_description})"
for fingerprint, provider, provider_description in rotated
]
)
),
)
db.session.add(activity)
activity.notify()
@ -87,7 +107,7 @@ def get_bridge_by_ip(ip: str) -> Optional[Bridge]:
return Bridge.query.filter( # type: ignore[no-any-return]
Bridge.deprecated.is_(None),
Bridge.destroyed.is_(None),
Bridge.bridgeline.contains(f" {ip} ")
Bridge.bridgeline.contains(f" {ip} "),
).first()
@ -95,7 +115,7 @@ def get_bridge_by_fingerprint(fingerprint: str) -> Optional[Bridge]:
return Bridge.query.filter( # type: ignore[no-any-return]
Bridge.deprecated.is_(None),
Bridge.destroyed.is_(None),
Bridge.fingerprint == fingerprint
Bridge.fingerprint == fingerprint,
).first()
@ -103,5 +123,5 @@ def get_bridge_by_hashed_fingerprint(hashed_fingerprint: str) -> Optional[Bridge
return Bridge.query.filter( # type: ignore[no-any-return]
Bridge.deprecated.is_(None),
Bridge.destroyed.is_(None),
Bridge.hashed_fingerprint == hashed_fingerprint
Bridge.hashed_fingerprint == hashed_fingerprint,
).first()

View file

@ -17,6 +17,8 @@ class BlockBridgelinesAutomation(BlockBridgeAutomation, ABC):
fingerprint = parts[2]
self.ips.append(ip_address)
self.fingerprints.append(fingerprint)
logging.debug(f"Added blocked bridge with IP {ip_address} and fingerprint {fingerprint}")
logging.debug(
f"Added blocked bridge with IP {ip_address} and fingerprint {fingerprint}"
)
except IndexError:
logging.warning("A parsing error occured.")

View file

@ -1,8 +1,7 @@
from flask import current_app
from github import Github
from app.terraform.block.bridge_reachability import \
BlockBridgeReachabilityAutomation
from app.terraform.block.bridge_reachability import BlockBridgeReachabilityAutomation
class BlockBridgeGitHubAutomation(BlockBridgeReachabilityAutomation):
@ -15,12 +14,13 @@ class BlockBridgeGitHubAutomation(BlockBridgeReachabilityAutomation):
frequency = 30
def fetch(self) -> None:
github = Github(current_app.config['GITHUB_API_KEY'])
repo = github.get_repo(current_app.config['GITHUB_BRIDGE_REPO'])
for vantage_point in current_app.config['GITHUB_BRIDGE_VANTAGE_POINTS']:
github = Github(current_app.config["GITHUB_API_KEY"])
repo = github.get_repo(current_app.config["GITHUB_BRIDGE_REPO"])
for vantage_point in current_app.config["GITHUB_BRIDGE_VANTAGE_POINTS"]:
contents = repo.get_contents(f"recentResult_{vantage_point}")
if isinstance(contents, list):
raise RuntimeError(
f"Expected a file at recentResult_{vantage_point}"
" but got a directory.")
self._lines = contents.decoded_content.decode('utf-8').splitlines()
" but got a directory."
)
self._lines = contents.decoded_content.decode("utf-8").splitlines()

View file

@ -1,8 +1,7 @@
from flask import current_app
from gitlab import Gitlab
from app.terraform.block.bridge_reachability import \
BlockBridgeReachabilityAutomation
from app.terraform.block.bridge_reachability import BlockBridgeReachabilityAutomation
class BlockBridgeGitlabAutomation(BlockBridgeReachabilityAutomation):
@ -16,15 +15,15 @@ class BlockBridgeGitlabAutomation(BlockBridgeReachabilityAutomation):
def fetch(self) -> None:
self._lines = list()
credentials = {"private_token": current_app.config['GITLAB_TOKEN']}
credentials = {"private_token": current_app.config["GITLAB_TOKEN"]}
if "GITLAB_URL" in current_app.config:
credentials['url'] = current_app.config['GITLAB_URL']
credentials["url"] = current_app.config["GITLAB_URL"]
gitlab = Gitlab(**credentials)
project = gitlab.projects.get(current_app.config['GITLAB_BRIDGE_PROJECT'])
for vantage_point in current_app.config['GITHUB_BRIDGE_VANTAGE_POINTS']:
project = gitlab.projects.get(current_app.config["GITLAB_BRIDGE_PROJECT"])
for vantage_point in current_app.config["GITHUB_BRIDGE_VANTAGE_POINTS"]:
contents = project.files.get(
file_path=f"recentResult_{vantage_point}",
ref=current_app.config["GITLAB_BRIDGE_BRANCH"]
ref=current_app.config["GITLAB_BRIDGE_BRANCH"],
)
# Decode the base64 first, then decode the UTF-8 string
self._lines.extend(contents.decode().decode('utf-8').splitlines())
self._lines.extend(contents.decode().decode("utf-8").splitlines())

View file

@ -14,8 +14,10 @@ class BlockBridgeReachabilityAutomation(BlockBridgeAutomation, ABC):
def parse(self) -> None:
for line in self._lines:
parts = line.split("\t")
if isoparse(parts[2]) < (datetime.datetime.now(datetime.timezone.utc)
- datetime.timedelta(days=3)):
if isoparse(parts[2]) < (
datetime.datetime.now(datetime.timezone.utc)
- datetime.timedelta(days=3)
):
# Skip results older than 3 days
continue
if int(parts[1]) < 40:

View file

@ -13,7 +13,9 @@ class BlockBridgeRoskomsvobodaAutomation(BlockBridgeAutomation):
_data: Any
def fetch(self) -> None:
self._data = requests.get("https://reestr.rublacklist.net/api/v3/ips/", timeout=180).json()
self._data = requests.get(
"https://reestr.rublacklist.net/api/v3/ips/", timeout=180
).json()
def parse(self) -> None:
self.ips.extend(self._data)

View file

@ -9,7 +9,7 @@ from app.terraform.block_mirror import BlockMirrorAutomation
def _trim_prefix(s: str, prefix: str) -> str:
if s.startswith(prefix):
return s[len(prefix):]
return s[len(prefix) :]
return s
@ -20,30 +20,31 @@ def trim_http_https(s: str) -> str:
:param s: String to modify.
:return: Modified string.
"""
return _trim_prefix(
_trim_prefix(s, "https://"),
"http://")
return _trim_prefix(_trim_prefix(s, "https://"), "http://")
class BlockExternalAutomation(BlockMirrorAutomation):
"""
Automation task to import proxy reachability results from external source.
"""
short_name = "block_external"
description = "Import proxy reachability results from external source"
_content: bytes
def fetch(self) -> None:
user_agent = {'User-agent': 'BypassCensorship/1.0'}
check_urls_config = app.config.get('EXTERNAL_CHECK_URL', [])
user_agent = {"User-agent": "BypassCensorship/1.0"}
check_urls_config = app.config.get("EXTERNAL_CHECK_URL", [])
if isinstance(check_urls_config, dict):
# Config is already a dictionary, use as is.
check_urls = check_urls_config
elif isinstance(check_urls_config, list):
# Convert list of strings to a dictionary with "external_N" keys.
check_urls = {f"external_{i}": url for i, url in enumerate(check_urls_config)}
check_urls = {
f"external_{i}": url for i, url in enumerate(check_urls_config)
}
elif isinstance(check_urls_config, str):
# Single string, convert to a dictionary with key "external".
check_urls = {"external": check_urls_config}
@ -53,9 +54,13 @@ class BlockExternalAutomation(BlockMirrorAutomation):
for source, check_url in check_urls.items():
if self._data is None:
self._data = defaultdict(list)
self._data[source].extend(requests.get(check_url, headers=user_agent, timeout=30).json())
self._data[source].extend(
requests.get(check_url, headers=user_agent, timeout=30).json()
)
def parse(self) -> None:
for source, patterns in self._data.items():
self.patterns[source].extend(["https://" + trim_http_https(pattern) for pattern in patterns])
self.patterns[source].extend(
["https://" + trim_http_https(pattern) for pattern in patterns]
)
logging.debug("Found URLs: %s", self.patterns)

View file

@ -52,8 +52,15 @@ class BlockMirrorAutomation(BaseAutomation):
if rotated:
activity = Activity(
activity_type="block",
text=(f"[{self.short_name}] ♻ Rotated {len(rotated)} proxies: \n"
+ "\n".join([f"* {proxy_domain} ({origin_domain})" for proxy_domain, origin_domain in rotated]))
text=(
f"[{self.short_name}] ♻ Rotated {len(rotated)} proxies: \n"
+ "\n".join(
[
f"* {proxy_domain} ({origin_domain})"
for proxy_domain, origin_domain in rotated
]
)
),
)
db.session.add(activity)
activity.notify()
@ -79,15 +86,15 @@ class BlockMirrorAutomation(BaseAutomation):
def active_proxy_urls() -> List[str]:
return [proxy.url for proxy in Proxy.query.filter(
Proxy.deprecated.is_(None),
Proxy.destroyed.is_(None)
).all()]
return [
proxy.url
for proxy in Proxy.query.filter(
Proxy.deprecated.is_(None), Proxy.destroyed.is_(None)
).all()
]
def proxy_by_url(url: str) -> Optional[Proxy]:
return Proxy.query.filter( # type: ignore[no-any-return]
Proxy.deprecated.is_(None),
Proxy.destroyed.is_(None),
Proxy.url == url
Proxy.deprecated.is_(None), Proxy.destroyed.is_(None), Proxy.url == url
).first()

View file

@ -12,19 +12,23 @@ from app.terraform import BaseAutomation
def check_origin(domain_name: str) -> Dict[str, Any]:
start_date = (datetime.now(tz=timezone.utc) - timedelta(days=1)).strftime("%Y-%m-%dT%H%%3A%M")
start_date = (datetime.now(tz=timezone.utc) - timedelta(days=1)).strftime(
"%Y-%m-%dT%H%%3A%M"
)
end_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H%%3A%M")
api_url = f"https://api.ooni.io/api/v1/measurements?domain={domain_name}&since={start_date}&until={end_date}"
result: Dict[str, Dict[str, int]] = defaultdict(lambda: {"anomaly": 0, "confirmed": 0, "failure": 0, "ok": 0})
result: Dict[str, Dict[str, int]] = defaultdict(
lambda: {"anomaly": 0, "confirmed": 0, "failure": 0, "ok": 0}
)
return _check_origin(api_url, result)
def _check_origin(api_url: str, result: Dict[str, Any]) -> Dict[str, Any]:
print(f"Processing {api_url}")
req = requests.get(api_url, timeout=30).json()
if 'results' not in req or not req['results']:
if "results" not in req or not req["results"]:
return result
for r in req['results']:
for r in req["results"]:
not_ok = False
for status in ["anomaly", "confirmed", "failure"]:
if status in r and r[status]:
@ -33,27 +37,28 @@ def _check_origin(api_url: str, result: Dict[str, Any]) -> Dict[str, Any]:
break
if not not_ok:
result[r["probe_cc"]]["ok"] += 1
if req['metadata']['next_url']:
return _check_origin(req['metadata']['next_url'], result)
if req["metadata"]["next_url"]:
return _check_origin(req["metadata"]["next_url"], result)
return result
def threshold_origin(domain_name: str) -> Dict[str, Any]:
ooni = check_origin(domain_name)
for country in ooni:
total = sum([
total = sum(
[
ooni[country]["anomaly"],
ooni[country]["confirmed"],
ooni[country]["failure"],
ooni[country]["ok"]
])
total_blocks = sum([
ooni[country]["anomaly"],
ooni[country]["confirmed"]
])
ooni[country]["ok"],
]
)
total_blocks = sum([ooni[country]["anomaly"], ooni[country]["confirmed"]])
block_perc = round((total_blocks / total * 100), 1)
ooni[country]["block_perc"] = block_perc
ooni[country]["state"] = AlarmState.WARNING if block_perc > 20 else AlarmState.OK
ooni[country]["state"] = (
AlarmState.WARNING if block_perc > 20 else AlarmState.OK
)
ooni[country]["message"] = f"Blocked in {block_perc}% of measurements"
return ooni
@ -72,8 +77,9 @@ class BlockOONIAutomation(BaseAutomation):
for origin in origins:
ooni = threshold_origin(origin.domain_name)
for country in ooni:
alarm = get_or_create_alarm(origin.brn,
f"origin-block-ooni-{country.lower()}")
alarm = get_or_create_alarm(
origin.brn, f"origin-block-ooni-{country.lower()}"
)
alarm.update_state(ooni[country]["state"], ooni[country]["message"])
db.session.commit()
return True, ""

View file

@ -32,6 +32,7 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
Where proxies are found to be blocked they will be rotated.
"""
short_name = "block_roskomsvoboda"
description = "Import Russian blocklist from RosKomSvoboda"
frequency = 300
@ -43,7 +44,11 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
try:
# This endpoint routinely has an expired certificate, and it's more useful that we are consuming the
# data than that we are verifying the certificate.
r = requests.get(f"https://dumps.rublacklist.net/fetch/{latest_rev}", timeout=180, verify=False) # nosec: B501
r = requests.get(
f"https://dumps.rublacklist.net/fetch/{latest_rev}",
timeout=180,
verify=False,
) # nosec: B501
r.raise_for_status()
zip_file = ZipFile(BytesIO(r.content))
self._data = zip_file.read("dump.xml")
@ -51,26 +56,33 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
except requests.HTTPError:
activity = Activity(
activity_type="automation",
text=(f"[{self.short_name}] 🚨 Unable to download dump {latest_rev} due to HTTP error {r.status_code}. "
text=(
f"[{self.short_name}] 🚨 Unable to download dump {latest_rev} due to HTTP error {r.status_code}. "
"The automation task has not been disabled and will attempt to download the next dump when the "
"latest dump revision is incremented at the server."))
"latest dump revision is incremented at the server."
),
)
activity.notify()
db.session.add(activity)
db.session.commit()
except BadZipFile:
activity = Activity(
activity_type="automation",
text=(f"[{self.short_name}] 🚨 Unable to extract zip file from dump {latest_rev}. There was an error "
text=(
f"[{self.short_name}] 🚨 Unable to extract zip file from dump {latest_rev}. There was an error "
"related to the format of the zip file. "
"The automation task has not been disabled and will attempt to download the next dump when the "
"latest dump revision is incremented at the server."))
"latest dump revision is incremented at the server."
),
)
activity.notify()
db.session.add(activity)
db.session.commit()
def fetch(self) -> None:
state: Optional[TerraformState] = TerraformState.query.filter(
TerraformState.key == "block_roskomsvoboda").first()
TerraformState.key == "block_roskomsvoboda"
).first()
if state is None:
state = TerraformState()
state.key = "block_roskomsvoboda"
@ -80,8 +92,14 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
latest_metadata = json.loads(state.state)
# This endpoint routinely has an expired certificate, and it's more useful that we are consuming the
# data than that we are verifying the certificate.
latest_rev = requests.get("https://dumps.rublacklist.net/fetch/latest", timeout=30, verify=False).text.strip() # nosec: B501
logging.debug("Latest revision is %s, already got %s", latest_rev, latest_metadata["dump_rev"])
latest_rev = requests.get(
"https://dumps.rublacklist.net/fetch/latest", timeout=30, verify=False
).text.strip() # nosec: B501
logging.debug(
"Latest revision is %s, already got %s",
latest_rev,
latest_metadata["dump_rev"],
)
if latest_rev != latest_metadata["dump_rev"]:
state.state = json.dumps({"dump_rev": latest_rev})
db.session.commit()
@ -94,18 +112,24 @@ class BlockRoskomsvobodaAutomation(BlockMirrorAutomation):
logging.debug("No new data to parse")
return
try:
for _event, element in lxml.etree.iterparse(BytesIO(self._data),
resolve_entities=False):
for _event, element in lxml.etree.iterparse(
BytesIO(self._data), resolve_entities=False
):
if element.tag == "domain":
self.patterns["roskomsvoboda"].append("https://" + element.text.strip())
self.patterns["roskomsvoboda"].append(
"https://" + element.text.strip()
)
except XMLSyntaxError:
activity = Activity(
activity_type="automation",
text=(f"[{self.short_name}] 🚨 Unable to parse XML file from dump. There was an error "
text=(
f"[{self.short_name}] 🚨 Unable to parse XML file from dump. There was an error "
"related to the format of the XML file within the zip file. Interestingly we were able to "
"extract the file from the zip file fine. "
"The automation task has not been disabled and will attempt to download the next dump when the "
"latest dump revision is incremented at the server."))
"latest dump revision is incremented at the server."
),
)
activity.notify()
db.session.add(activity)
db.session.commit()

View file

@ -16,21 +16,33 @@ BridgeResourceRow = Row[Tuple[AbstractResource, BridgeConf, CloudAccount]]
def active_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]:
stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where(
stmt = (
select(Bridge, BridgeConf, CloudAccount)
.join_from(Bridge, BridgeConf)
.join_from(Bridge, CloudAccount)
.where(
CloudAccount.provider == provider,
Bridge.destroyed.is_(None),
)
)
bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all()
return bridges
def recently_destroyed_bridges_by_provider(provider: CloudProvider) -> Sequence[BridgeResourceRow]:
def recently_destroyed_bridges_by_provider(
provider: CloudProvider,
) -> Sequence[BridgeResourceRow]:
cutoff = datetime.now(tz=timezone.utc) - timedelta(hours=72)
stmt = select(Bridge, BridgeConf, CloudAccount).join_from(Bridge, BridgeConf).join_from(Bridge, CloudAccount).where(
stmt = (
select(Bridge, BridgeConf, CloudAccount)
.join_from(Bridge, BridgeConf)
.join_from(Bridge, CloudAccount)
.where(
CloudAccount.provider == provider,
Bridge.destroyed.is_not(None),
Bridge.destroyed >= cutoff,
)
)
bridges: Sequence[BridgeResourceRow] = db.session.execute(stmt).all()
return bridges
@ -60,35 +72,38 @@ class BridgeAutomation(TerraformAutomation):
self.template,
active_resources=active_bridges_by_provider(self.provider),
destroyed_resources=recently_destroyed_bridges_by_provider(self.provider),
global_namespace=app.config['GLOBAL_NAMESPACE'],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'),
global_namespace=app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""",
**{
k: app.config[k.upper()]
for k in self.template_parameters
}
**{k: app.config[k.upper()] for k in self.template_parameters},
)
def tf_posthook(self, *, prehook_result: Any = None) -> None:
outputs = self.tf_output()
for output in outputs:
if output.startswith('bridge_hashed_fingerprint_'):
parts = outputs[output]['value'].split(" ")
if output.startswith("bridge_hashed_fingerprint_"):
parts = outputs[output]["value"].split(" ")
if len(parts) < 2:
continue
bridge = Bridge.query.filter(Bridge.id == output[len('bridge_hashed_fingerprint_'):]).first()
bridge = Bridge.query.filter(
Bridge.id == output[len("bridge_hashed_fingerprint_") :]
).first()
bridge.nickname = parts[0]
bridge.hashed_fingerprint = parts[1]
bridge.terraform_updated = datetime.now(tz=timezone.utc)
if output.startswith('bridge_bridgeline_'):
parts = outputs[output]['value'].split(" ")
if output.startswith("bridge_bridgeline_"):
parts = outputs[output]["value"].split(" ")
if len(parts) < 4:
continue
bridge = Bridge.query.filter(Bridge.id == output[len('bridge_bridgeline_'):]).first()
bridge = Bridge.query.filter(
Bridge.id == output[len("bridge_bridgeline_") :]
).first()
del parts[3]
bridge.bridgeline = " ".join(parts)
bridge.terraform_updated = datetime.now(tz=timezone.utc)

View file

@ -7,10 +7,7 @@ class BridgeGandiAutomation(BridgeAutomation):
description = "Deploy Tor bridges on GandiCloud VPS"
provider = CloudProvider.GANDI
template_parameters = [
"ssh_public_key_path",
"ssh_private_key_path"
]
template_parameters = ["ssh_public_key_path", "ssh_private_key_path"]
template = """
terraform {

View file

@ -7,10 +7,7 @@ class BridgeHcloudAutomation(BridgeAutomation):
description = "Deploy Tor bridges on Hetzner Cloud"
provider = CloudProvider.HCLOUD
template_parameters = [
"ssh_private_key_path",
"ssh_public_key_path"
]
template_parameters = ["ssh_private_key_path", "ssh_public_key_path"]
template = """
terraform {

View file

@ -25,10 +25,17 @@ def active_bridges_in_account(account: CloudAccount) -> List[Bridge]:
return bridges
def create_bridges_in_account(bridgeconf: BridgeConf, account: CloudAccount, count: int) -> int:
def create_bridges_in_account(
bridgeconf: BridgeConf, account: CloudAccount, count: int
) -> int:
created = 0
while created < count and len(active_bridges_in_account(account)) < account.max_instances:
logging.debug("Creating bridge for configuration %s in account %s", bridgeconf.id, account)
while (
created < count
and len(active_bridges_in_account(account)) < account.max_instances
):
logging.debug(
"Creating bridge for configuration %s in account %s", bridgeconf.id, account
)
bridge = Bridge()
bridge.pool_id = bridgeconf.pool.id
bridge.conf_id = bridgeconf.id
@ -45,7 +52,9 @@ def create_bridges_by_cost(bridgeconf: BridgeConf, count: int) -> int:
"""
Creates bridge resources for the given bridge configuration using the cheapest available provider.
"""
logging.debug("Creating %s bridges by cost for configuration %s", count, bridgeconf.id)
logging.debug(
"Creating %s bridges by cost for configuration %s", count, bridgeconf.id
)
created = 0
for provider in BRIDGE_PROVIDERS:
if created >= count:
@ -78,7 +87,9 @@ def create_bridges_by_random(bridgeconf: BridgeConf, count: int) -> int:
"""
Creates bridge resources for the given bridge configuration using random providers.
"""
logging.debug("Creating %s bridges by random for configuration %s", count, bridgeconf.id)
logging.debug(
"Creating %s bridges by random for configuration %s", count, bridgeconf.id
)
created = 0
while candidate_accounts := _accounts_with_room():
# Not security-critical random number generation
@ -97,16 +108,24 @@ def create_bridges(bridgeconf: BridgeConf, count: int) -> int:
return create_bridges_by_random(bridgeconf, count)
def deprecate_bridges(bridgeconf: BridgeConf, count: int, reason: str = "redundant") -> int:
logging.debug("Deprecating %s bridges (%s) for configuration %s", count, reason, bridgeconf.id)
def deprecate_bridges(
bridgeconf: BridgeConf, count: int, reason: str = "redundant"
) -> int:
logging.debug(
"Deprecating %s bridges (%s) for configuration %s", count, reason, bridgeconf.id
)
deprecated = 0
active_conf_bridges = iter(Bridge.query.filter(
active_conf_bridges = iter(
Bridge.query.filter(
Bridge.conf_id == bridgeconf.id,
Bridge.deprecated.is_(None),
Bridge.destroyed.is_(None),
).all())
).all()
)
while deprecated < count:
logging.debug("Deprecating bridge %s for configuration %s", deprecated + 1, bridgeconf.id)
logging.debug(
"Deprecating bridge %s for configuration %s", deprecated + 1, bridgeconf.id
)
bridge = next(active_conf_bridges)
logging.debug("Bridge %r", bridge)
bridge.deprecate(reason=reason)
@ -129,7 +148,9 @@ class BridgeMetaAutomation(BaseAutomation):
for bridge in deprecated_bridges:
if bridge.deprecated is None:
continue # Possible due to SQLAlchemy lazy loading
cutoff = datetime.now(tz=timezone.utc) - timedelta(hours=bridge.conf.expiry_hours)
cutoff = datetime.now(tz=timezone.utc) - timedelta(
hours=bridge.conf.expiry_hours
)
if bridge.deprecated < cutoff:
logging.debug("Destroying expired bridge")
bridge.destroy()
@ -146,7 +167,9 @@ class BridgeMetaAutomation(BaseAutomation):
activate_bridgeconfs = BridgeConf.query.filter(
BridgeConf.destroyed.is_(None),
).all()
logging.debug("Found %s active bridge configurations", len(activate_bridgeconfs))
logging.debug(
"Found %s active bridge configurations", len(activate_bridgeconfs)
)
for bridgeconf in activate_bridgeconfs:
active_conf_bridges = Bridge.query.filter(
Bridge.conf_id == bridgeconf.id,
@ -157,16 +180,18 @@ class BridgeMetaAutomation(BaseAutomation):
Bridge.conf_id == bridgeconf.id,
Bridge.destroyed.is_(None),
).all()
logging.debug("Generating new bridges for %s (active: %s, total: %s, target: %s, max: %s)",
logging.debug(
"Generating new bridges for %s (active: %s, total: %s, target: %s, max: %s)",
bridgeconf.id,
len(active_conf_bridges),
len(total_conf_bridges),
bridgeconf.target_number,
bridgeconf.max_number
bridgeconf.max_number,
)
missing = min(
bridgeconf.target_number - len(active_conf_bridges),
bridgeconf.max_number - len(total_conf_bridges))
bridgeconf.max_number - len(total_conf_bridges),
)
if missing > 0:
create_bridges(bridgeconf, missing)
elif missing < 0:

View file

@ -7,10 +7,7 @@ class BridgeOvhAutomation(BridgeAutomation):
description = "Deploy Tor bridges on OVH Public Cloud"
provider = CloudProvider.OVH
template_parameters = [
"ssh_public_key_path",
"ssh_private_key_path"
]
template_parameters = ["ssh_public_key_path", "ssh_private_key_path"]
template = """
terraform {

View file

@ -11,14 +11,12 @@ from app.terraform.eotk import eotk_configuration
from app.terraform.terraform import TerraformAutomation
def update_eotk_instance(group_id: int,
region: str,
instance_id: str) -> None:
def update_eotk_instance(group_id: int, region: str, instance_id: str) -> None:
instance = Eotk.query.filter(
Eotk.group_id == group_id,
Eotk.region == region,
Eotk.provider == "aws",
Eotk.destroyed.is_(None)
Eotk.destroyed.is_(None),
).first()
if instance is None:
instance = Eotk()
@ -35,10 +33,7 @@ class EotkAWSAutomation(TerraformAutomation):
short_name = "eotk_aws"
description = "Deploy EOTK instances to AWS"
template_parameters = [
"aws_access_key",
"aws_secret_key"
]
template_parameters = ["aws_access_key", "aws_secret_key"]
template = """
terraform {
@ -81,32 +76,41 @@ class EotkAWSAutomation(TerraformAutomation):
self.tf_write(
self.template,
groups=Group.query.filter(
Group.eotk.is_(True),
Group.destroyed.is_(None)
Group.eotk.is_(True), Group.destroyed.is_(None)
).all(),
global_namespace=app.config['GLOBAL_NAMESPACE'],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'),
global_namespace=app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""",
**{
k: app.config[k.upper()]
for k in self.template_parameters
}
**{k: app.config[k.upper()] for k in self.template_parameters},
)
for group in Group.query.filter(
Group.eotk.is_(True),
Group.destroyed.is_(None)
).order_by(Group.id).all():
with DeterministicZip(os.path.join(self.working_dir, f"{group.id}.zip")) as dzip:
dzip.add_file("sites.conf", eotk_configuration(group).encode('utf-8'))
for group in (
Group.query.filter(Group.eotk.is_(True), Group.destroyed.is_(None))
.order_by(Group.id)
.all()
):
with DeterministicZip(
os.path.join(self.working_dir, f"{group.id}.zip")
) as dzip:
dzip.add_file("sites.conf", eotk_configuration(group).encode("utf-8"))
for onion in sorted(group.onions, key=lambda o: o.onion_name):
dzip.add_file(f"{onion.onion_name}.v3pub.key", onion.onion_public_key)
dzip.add_file(f"{onion.onion_name}.v3sec.key", onion.onion_private_key)
dzip.add_file(f"{onion.onion_name[:20]}-v3.cert", onion.tls_public_key)
dzip.add_file(f"{onion.onion_name[:20]}-v3.pem", onion.tls_private_key)
dzip.add_file(
f"{onion.onion_name}.v3pub.key", onion.onion_public_key
)
dzip.add_file(
f"{onion.onion_name}.v3sec.key", onion.onion_private_key
)
dzip.add_file(
f"{onion.onion_name[:20]}-v3.cert", onion.tls_public_key
)
dzip.add_file(
f"{onion.onion_name[:20]}-v3.pem", onion.tls_private_key
)
def tf_posthook(self, *, prehook_result: Any = None) -> None:
for e in Eotk.query.all():
@ -115,9 +119,9 @@ class EotkAWSAutomation(TerraformAutomation):
for output in outputs:
if output.startswith("eotk_instances_"):
try:
group_id = int(output[len("eotk_instance_") + 1:])
for az in outputs[output]['value']:
update_eotk_instance(group_id, az, outputs[output]['value'][az])
group_id = int(output[len("eotk_instance_") + 1 :])
for az in outputs[output]["value"]:
update_eotk_instance(group_id, az, outputs[output]["value"][az])
except ValueError:
pass
db.session.commit()

View file

@ -55,26 +55,36 @@ class ListAutomation(TerraformAutomation):
MirrorList.destroyed.is_(None),
MirrorList.provider == self.provider,
).all(),
global_namespace=app.config['GLOBAL_NAMESPACE'],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'),
global_namespace=app.config["GLOBAL_NAMESPACE"],
terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""",
**{
k: app.config[k.upper()]
for k in self.template_parameters
}
**{k: app.config[k.upper()] for k in self.template_parameters},
)
for pool in Pool.query.filter(Pool.destroyed.is_(None)).all():
for key, formatter in lists.items():
formatted_pool = formatter(pool)
for obfuscate in [True, False]:
with open(os.path.join(
self.working_dir, f"{key}.{pool.pool_name}{'.jsno' if obfuscate else '.json'}"),
'w', encoding="utf-8") as out:
with open(
os.path.join(
self.working_dir,
f"{key}.{pool.pool_name}{'.jsno' if obfuscate else '.json'}",
),
"w",
encoding="utf-8",
) as out:
out.write(json_encode(formatted_pool, obfuscate))
with open(os.path.join(self.working_dir, f"{key}.{pool.pool_name}{'.jso' if obfuscate else '.js'}"),
'w', encoding="utf-8") as out:
with open(
os.path.join(
self.working_dir,
f"{key}.{pool.pool_name}{'.jso' if obfuscate else '.js'}",
),
"w",
encoding="utf-8",
) as out:
out.write(javascript_encode(formatted_pool, obfuscate))

View file

@ -11,9 +11,7 @@ class ListGithubAutomation(ListAutomation):
# TODO: file an issue in the github about this, GitLab had a similar issue but fixed it
parallelism = 1
template_parameters = [
"github_api_key"
]
template_parameters = ["github_api_key"]
template = """
terraform {

View file

@ -15,7 +15,7 @@ class ListGitlabAutomation(ListAutomation):
"gitlab_token",
"gitlab_author_email",
"gitlab_author_name",
"gitlab_commit_message"
"gitlab_commit_message",
]
template = """
@ -56,5 +56,5 @@ class ListGitlabAutomation(ListAutomation):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if 'GITLAB_URL' in current_app.config:
if "GITLAB_URL" in current_app.config:
self.template_parameters.append("gitlab_url")

View file

@ -6,10 +6,7 @@ class ListS3Automation(ListAutomation):
description = "Update mirror lists in AWS S3 buckets"
provider = "s3"
template_parameters = [
"aws_access_key",
"aws_secret_key"
]
template_parameters = ["aws_access_key", "aws_secret_key"]
template = """
terraform {

View file

@ -15,15 +15,14 @@ from app.models.mirrors import Origin, Proxy, SmartProxy
from app.terraform.terraform import TerraformAutomation
def update_smart_proxy_instance(group_id: int,
provider: str,
region: str,
instance_id: str) -> None:
def update_smart_proxy_instance(
group_id: int, provider: str, region: str, instance_id: str
) -> None:
instance = SmartProxy.query.filter(
SmartProxy.group_id == group_id,
SmartProxy.region == region,
SmartProxy.provider == provider,
SmartProxy.destroyed.is_(None)
SmartProxy.destroyed.is_(None),
).first()
if instance is None:
instance = SmartProxy()
@ -93,16 +92,21 @@ class ProxyAutomation(TerraformAutomation):
self.template,
groups=groups,
proxies=Proxy.query.filter(
Proxy.provider == self.provider, Proxy.destroyed.is_(None)).all(),
Proxy.provider == self.provider, Proxy.destroyed.is_(None)
).all(),
subgroups=self.get_subgroups(),
global_namespace=app.config['GLOBAL_NAMESPACE'], bypass_token=app.config['BYPASS_TOKEN'],
terraform_modules_path=os.path.join(*list(os.path.split(app.root_path))[:-1], 'terraform-modules'),
global_namespace=app.config["GLOBAL_NAMESPACE"],
bypass_token=app.config["BYPASS_TOKEN"],
terraform_modules_path=os.path.join(
*list(os.path.split(app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{
lock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""",
**{k: app.config[k.upper()] for k in self.template_parameters})
**{k: app.config[k.upper()] for k in self.template_parameters},
)
if self.smart_proxies:
for group in groups:
self.sp_config(group)
@ -111,9 +115,11 @@ class ProxyAutomation(TerraformAutomation):
group_origins: List[Origin] = Origin.query.filter(
Origin.group_id == group.id,
Origin.destroyed.is_(None),
Origin.smart.is_(True)
Origin.smart.is_(True),
).all()
self.tmpl_write(f"smart_proxy.{group.id}.conf", """
self.tmpl_write(
f"smart_proxy.{group.id}.conf",
"""
{% for origin in origins %}
server {
listen 443 ssl;
@ -175,21 +181,26 @@ class ProxyAutomation(TerraformAutomation):
""",
provider=self.provider,
origins=group_origins,
smart_zone=app.config['SMART_ZONE'])
smart_zone=app.config["SMART_ZONE"],
)
@classmethod
def get_subgroups(cls) -> Dict[int, Dict[int, int]]:
conn = db.engine.connect()
stmt = text("""
stmt = text(
"""
SELECT origin.group_id, proxy.psg, COUNT(proxy.id) FROM proxy, origin
WHERE proxy.origin_id = origin.id
AND proxy.destroyed IS NULL
AND proxy.provider = :provider
GROUP BY origin.group_id, proxy.psg;
""")
"""
)
stmt = stmt.bindparams(provider=cls.provider)
result = conn.execute(stmt).all()
subgroups: Dict[int, Dict[int, int]] = defaultdict(lambda: defaultdict(lambda: 0))
subgroups: Dict[int, Dict[int, int]] = defaultdict(
lambda: defaultdict(lambda: 0)
)
for row in result:
subgroups[row[0]][row[1]] = row[2]
return subgroups

View file

@ -21,7 +21,7 @@ class ProxyAzureCdnAutomation(ProxyAutomation):
"azure_client_secret",
"azure_subscription_id",
"azure_tenant_id",
"smart_zone"
"smart_zone",
]
template = """
@ -162,8 +162,7 @@ class ProxyAzureCdnAutomation(ProxyAutomation):
def import_state(self, state: Optional[Any]) -> None:
proxies = Proxy.query.filter(
Proxy.provider == self.provider,
Proxy.destroyed.is_(None)
Proxy.provider == self.provider, Proxy.destroyed.is_(None)
).all()
for proxy in proxies:
proxy.url = f"https://{proxy.slug}.azureedge.net"

View file

@ -17,7 +17,7 @@ class ProxyCloudfrontAutomation(ProxyAutomation):
"admin_email",
"aws_access_key",
"aws_secret_key",
"smart_zone"
"smart_zone",
]
template = """
@ -111,26 +111,35 @@ class ProxyCloudfrontAutomation(ProxyAutomation):
def import_state(self, state: Any) -> None:
if not isinstance(state, dict):
raise RuntimeError("The Terraform state object returned was not a dict.")
if "child_modules" not in state['values']['root_module']:
if "child_modules" not in state["values"]["root_module"]:
# There are no CloudFront proxies deployed to import state for
return
# CloudFront distributions (proxies)
for mod in state['values']['root_module']['child_modules']:
if mod['address'].startswith('module.cloudfront_'):
for res in mod['resources']:
if res['address'].endswith('aws_cloudfront_distribution.this'):
proxy = Proxy.query.filter(Proxy.id == mod['address'][len('module.cloudfront_'):]).first()
proxy.url = "https://" + res['values']['domain_name']
proxy.slug = res['values']['id']
for mod in state["values"]["root_module"]["child_modules"]:
if mod["address"].startswith("module.cloudfront_"):
for res in mod["resources"]:
if res["address"].endswith("aws_cloudfront_distribution.this"):
proxy = Proxy.query.filter(
Proxy.id == mod["address"][len("module.cloudfront_") :]
).first()
proxy.url = "https://" + res["values"]["domain_name"]
proxy.slug = res["values"]["id"]
proxy.terraform_updated = datetime.now(tz=timezone.utc)
break
# EC2 instances (smart proxies)
for g in state["values"]["root_module"]["child_modules"]:
if g["address"].startswith("module.smart_proxy_"):
group_id = int(g["address"][len("module.smart_proxy_"):])
group_id = int(g["address"][len("module.smart_proxy_") :])
for s in g["child_modules"]:
if s["address"].endswith(".module.instance"):
for x in s["resources"]:
if x["address"].endswith(".module.instance.aws_instance.default[0]"):
update_smart_proxy_instance(group_id, self.provider, "us-east-2a", x['values']['id'])
if x["address"].endswith(
".module.instance.aws_instance.default[0]"
):
update_smart_proxy_instance(
group_id,
self.provider,
"us-east-2a",
x["values"]["id"],
)
db.session.commit()

View file

@ -14,11 +14,7 @@ class ProxyFastlyAutomation(ProxyAutomation):
subgroup_members_max = 20
cloud_name = "fastly"
template_parameters = [
"aws_access_key",
"aws_secret_key",
"fastly_api_key"
]
template_parameters = ["aws_access_key", "aws_secret_key", "fastly_api_key"]
template = """
terraform {
@ -125,13 +121,14 @@ class ProxyFastlyAutomation(ProxyAutomation):
Constructor method.
"""
# Requires Flask application context to read configuration
self.subgroup_members_max = min(current_app.config.get("FASTLY_MAX_BACKENDS", 5), 20)
self.subgroup_members_max = min(
current_app.config.get("FASTLY_MAX_BACKENDS", 5), 20
)
super().__init__(*args, **kwargs)
def import_state(self, state: Optional[Any]) -> None:
proxies = Proxy.query.filter(
Proxy.provider == self.provider,
Proxy.destroyed.is_(None)
Proxy.provider == self.provider, Proxy.destroyed.is_(None)
).all()
for proxy in proxies:
proxy.url = f"https://{proxy.slug}.global.ssl.fastly.net"

View file

@ -18,12 +18,16 @@ from app.terraform.proxy.azure_cdn import ProxyAzureCdnAutomation
from app.terraform.proxy.cloudfront import ProxyCloudfrontAutomation
from app.terraform.proxy.fastly import ProxyFastlyAutomation
PROXY_PROVIDERS: Dict[str, Type[ProxyAutomation]] = {p.provider: p for p in [ # type: ignore[attr-defined]
PROXY_PROVIDERS: Dict[str, Type[ProxyAutomation]] = {
p.provider: p # type: ignore[attr-defined]
for p in [
# In order of preference
ProxyCloudfrontAutomation,
ProxyFastlyAutomation,
ProxyAzureCdnAutomation
] if p.enabled} # type: ignore[attr-defined]
ProxyAzureCdnAutomation,
]
if p.enabled # type: ignore[attr-defined]
}
SubgroupCount = OrderedDictT[str, OrderedDictT[int, OrderedDictT[int, int]]]
@ -61,8 +65,9 @@ def random_slug(origin_domain_name: str) -> str:
"exampasdfghjkl"
"""
# The random slug doesn't need to be cryptographically secure, hence the use of `# nosec`
return tldextract.extract(origin_domain_name).domain[:5] + ''.join(
random.choices(string.ascii_lowercase, k=12)) # nosec
return tldextract.extract(origin_domain_name).domain[:5] + "".join(
random.choices(string.ascii_lowercase, k=12) # nosec: B311
)
def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupCount:
@ -95,8 +100,13 @@ def calculate_subgroup_count(proxies: Optional[List[Proxy]] = None) -> SubgroupC
return subgroup_count
def next_subgroup(subgroup_count: SubgroupCount, provider: str, group_id: int, max_subgroup_count: int,
max_subgroup_members: int) -> Optional[int]:
def next_subgroup(
subgroup_count: SubgroupCount,
provider: str,
group_id: int,
max_subgroup_count: int,
max_subgroup_members: int,
) -> Optional[int]:
"""
Find the first available subgroup with less than the specified maximum count in the specified provider and group.
If the last subgroup in the group is full, return the next subgroup number as long as it doesn't exceed
@ -137,27 +147,36 @@ def auto_deprecate_proxies() -> None:
- The "max_age_reached" reason means the proxy has been in use for longer than the maximum allowed period.
The maximum age cutoff is randomly set to a time between 24 and 48 hours.
"""
origin_destroyed_proxies = (db.session.query(Proxy)
origin_destroyed_proxies = (
db.session.query(Proxy)
.join(Origin, Proxy.origin_id == Origin.id)
.filter(Proxy.destroyed.is_(None),
.filter(
Proxy.destroyed.is_(None),
Proxy.deprecated.is_(None),
Origin.destroyed.is_not(None))
.all())
Origin.destroyed.is_not(None),
)
.all()
)
logging.debug("Origin destroyed: %s", origin_destroyed_proxies)
for proxy in origin_destroyed_proxies:
proxy.deprecate(reason="origin_destroyed")
max_age_proxies = (db.session.query(Proxy)
max_age_proxies = (
db.session.query(Proxy)
.join(Origin, Proxy.origin_id == Origin.id)
.filter(Proxy.destroyed.is_(None),
.filter(
Proxy.destroyed.is_(None),
Proxy.deprecated.is_(None),
Proxy.pool_id != -1, # do not rotate hotspare proxies
Origin.assets,
Origin.auto_rotation)
.all())
Origin.auto_rotation,
)
.all()
)
logging.debug("Max age: %s", max_age_proxies)
for proxy in max_age_proxies:
max_age_cutoff = datetime.now(timezone.utc) - timedelta(
days=1, seconds=86400 * random.random()) # nosec: B311
days=1, seconds=86400 * random.random() # nosec: B311
)
if proxy.added < max_age_cutoff:
proxy.deprecate(reason="max_age_reached")
@ -171,8 +190,7 @@ def destroy_expired_proxies() -> None:
"""
expiry_cutoff = datetime.now(timezone.utc) - timedelta(days=4)
proxies = Proxy.query.filter(
Proxy.destroyed.is_(None),
Proxy.deprecated < expiry_cutoff
Proxy.destroyed.is_(None), Proxy.deprecated < expiry_cutoff
).all()
for proxy in proxies:
logging.debug("Destroying expired proxy")
@ -244,12 +262,17 @@ class ProxyMetaAutomation(BaseAutomation):
if origin.destroyed is not None:
continue
proxies = [
x for x in origin.proxies
if x.pool_id == pool.id and x.deprecated is None and x.destroyed is None
x
for x in origin.proxies
if x.pool_id == pool.id
and x.deprecated is None
and x.destroyed is None
]
logging.debug("Proxies for group %s: %s", group.group_name, proxies)
if not proxies:
logging.debug("Creating new proxy for %s in pool %s", origin, pool)
logging.debug(
"Creating new proxy for %s in pool %s", origin, pool
)
if not promote_hot_spare_proxy(pool.id, origin):
# No "hot spare" available
self.create_proxy(pool.id, origin)
@ -270,8 +293,13 @@ class ProxyMetaAutomation(BaseAutomation):
"""
for provider in PROXY_PROVIDERS.values():
logging.debug("Looking at provider %s", provider.provider)
subgroup = next_subgroup(self.subgroup_count, provider.provider, origin.group_id,
provider.subgroup_members_max, provider.subgroup_count_max)
subgroup = next_subgroup(
self.subgroup_count,
provider.provider,
origin.group_id,
provider.subgroup_members_max,
provider.subgroup_count_max,
)
if subgroup is None:
continue # Exceeded maximum number of subgroups and last subgroup is full
self.increment_subgroup(provider.provider, origin.group_id, subgroup)
@ -317,9 +345,7 @@ class ProxyMetaAutomation(BaseAutomation):
If an origin is not destroyed and lacks active proxies (not deprecated and not destroyed),
a new 'hot spare' proxy for this origin is created in the reserve pool (with pool_id = -1).
"""
origins = Origin.query.filter(
Origin.destroyed.is_(None)
).all()
origins = Origin.query.filter(Origin.destroyed.is_(None)).all()
for origin in origins:
if origin.countries:
risk_levels = origin.risk_level.items()
@ -328,7 +354,10 @@ class ProxyMetaAutomation(BaseAutomation):
if highest_risk_level < 4:
for proxy in origin.proxies:
if proxy.destroyed is None and proxy.pool_id == -1:
logging.debug("Destroying hot spare proxy for origin %s (low risk)", origin)
logging.debug(
"Destroying hot spare proxy for origin %s (low risk)",
origin,
)
proxy.destroy()
continue
if origin.destroyed is not None:

View file

@ -15,21 +15,26 @@ from app.terraform.terraform import TerraformAutomation
def import_state(state: Any) -> None:
if not isinstance(state, dict):
raise RuntimeError("The Terraform state object returned was not a dict.")
if "values" not in state or "child_modules" not in state['values']['root_module']:
if "values" not in state or "child_modules" not in state["values"]["root_module"]:
# There are no CloudFront origins deployed to import state for
return
# CloudFront distributions (origins)
for mod in state['values']['root_module']['child_modules']:
if mod['address'].startswith('module.static_'):
static_id = mod['address'][len('module.static_'):]
for mod in state["values"]["root_module"]["child_modules"]:
if mod["address"].startswith("module.static_"):
static_id = mod["address"][len("module.static_") :]
logging.debug("Found static module in state: %s", static_id)
for res in mod['resources']:
if res['address'].endswith('aws_cloudfront_distribution.this'):
for res in mod["resources"]:
if res["address"].endswith("aws_cloudfront_distribution.this"):
logging.debug("and found related cloudfront distribution")
static = StaticOrigin.query.filter(StaticOrigin.id == static_id).first()
static.origin_domain_name = res['values']['domain_name']
logging.debug("and found static origin: %s to update with domain name: %s", static.id,
static.origin_domain_name)
static = StaticOrigin.query.filter(
StaticOrigin.id == static_id
).first()
static.origin_domain_name = res["values"]["domain_name"]
logging.debug(
"and found static origin: %s to update with domain name: %s",
static.id,
static.origin_domain_name,
)
static.terraform_updated = datetime.now(tz=timezone.utc)
break
db.session.commit()
@ -128,14 +133,18 @@ class StaticAWSAutomation(TerraformAutomation):
groups=groups,
storage_cloud_accounts=storage_cloud_accounts,
source_cloud_accounts=source_cloud_accounts,
global_namespace=current_app.config['GLOBAL_NAMESPACE'], bypass_token=current_app.config['BYPASS_TOKEN'],
terraform_modules_path=os.path.join(*list(os.path.split(current_app.root_path))[:-1], 'terraform-modules'),
global_namespace=current_app.config["GLOBAL_NAMESPACE"],
bypass_token=current_app.config["BYPASS_TOKEN"],
terraform_modules_path=os.path.join(
*list(os.path.split(current_app.root_path))[:-1], "terraform-modules"
),
backend_config=f"""backend "http" {{
lock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}"
unlock_address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}"
address = "{current_app.config['TFSTATE_BACKEND']}/{self.short_name}"
}}""",
**{k: current_app.config[k.upper()] for k in self.template_parameters})
**{k: current_app.config[k.upper()] for k in self.template_parameters},
)
def tf_posthook(self, *, prehook_result: Any = None) -> None:
import_state(self.tf_show())

View file

@ -27,7 +27,9 @@ class StaticMetaAutomation(BaseAutomation):
if static_origin.origin_domain_name is not None:
try:
# Check if an Origin with the same domain name already exists
origin = Origin.query.filter_by(domain_name=static_origin.origin_domain_name).one()
origin = Origin.query.filter_by(
domain_name=static_origin.origin_domain_name
).one()
# Keep auto rotation value in sync
origin.auto_rotation = static_origin.auto_rotate
except NoResultFound:
@ -42,10 +44,14 @@ class StaticMetaAutomation(BaseAutomation):
assets=False,
)
db.session.add(origin)
logging.debug(f"Created Origin with domain name {origin.domain_name}")
logging.debug(
f"Created Origin with domain name {origin.domain_name}"
)
# Step 2: Remove Origins for StaticOrigins with non-null destroy value
static_origins_with_destroyed = StaticOrigin.query.filter(StaticOrigin.destroyed.isnot(None)).all()
static_origins_with_destroyed = StaticOrigin.query.filter(
StaticOrigin.destroyed.isnot(None)
).all()
for static_origin in static_origins_with_destroyed:
try:
origin = Origin.query.filter_by(

View file

@ -51,14 +51,20 @@ class TerraformAutomation(BaseAutomation):
prehook_result = self.tf_prehook() # pylint: disable=assignment-from-no-return
self.tf_generate()
self.tf_init()
returncode, logs = self.tf_apply(self.working_dir, refresh=self.always_refresh or full)
returncode, logs = self.tf_apply(
self.working_dir, refresh=self.always_refresh or full
)
self.tf_posthook(prehook_result=prehook_result)
return returncode == 0, logs
def tf_apply(self, working_dir: str, *,
def tf_apply(
self,
working_dir: str,
*,
refresh: bool = True,
parallelism: Optional[int] = None,
lock_timeout: int = 15) -> Tuple[int, str]:
lock_timeout: int = 15,
) -> Tuple[int, str]:
if not parallelism:
parallelism = self.parallelism
if not self.working_dir:
@ -67,17 +73,19 @@ class TerraformAutomation(BaseAutomation):
# the argument list as an array such that argument injection would be
# ineffective.
tfcmd = subprocess.run( # nosec
['terraform',
'apply',
'-auto-approve',
'-json',
f'-refresh={str(refresh).lower()}',
f'-parallelism={str(parallelism)}',
f'-lock-timeout={str(lock_timeout)}m',
[
"terraform",
"apply",
"-auto-approve",
"-json",
f"-refresh={str(refresh).lower()}",
f"-parallelism={str(parallelism)}",
f"-lock-timeout={str(lock_timeout)}m",
],
cwd=working_dir,
stdout=subprocess.PIPE)
return tfcmd.returncode, tfcmd.stdout.decode('utf-8')
stdout=subprocess.PIPE,
)
return tfcmd.returncode, tfcmd.stdout.decode("utf-8")
@abstractmethod
def tf_generate(self) -> None:
@ -91,41 +99,49 @@ class TerraformAutomation(BaseAutomation):
# the argument list as an array such that argument injection would be
# ineffective.
subprocess.run( # nosec
['terraform',
'init',
f'-lock-timeout={str(lock_timeout)}m',
[
"terraform",
"init",
f"-lock-timeout={str(lock_timeout)}m",
],
cwd=self.working_dir)
cwd=self.working_dir,
)
def tf_output(self) -> Any:
if not self.working_dir:
raise RuntimeError("No working directory specified.")
# The following subprocess call does not take any user input.
tfcmd = subprocess.run( # nosec
['terraform', 'output', '-json'],
["terraform", "output", "-json"],
cwd=self.working_dir,
stdout=subprocess.PIPE)
stdout=subprocess.PIPE,
)
return json.loads(tfcmd.stdout)
def tf_plan(self, *,
def tf_plan(
self,
*,
refresh: bool = True,
parallelism: Optional[int] = None,
lock_timeout: int = 15) -> Tuple[int, str]:
lock_timeout: int = 15,
) -> Tuple[int, str]:
if not self.working_dir:
raise RuntimeError("No working directory specified.")
# The following subprocess call takes external input, but is providing
# the argument list as an array such that argument injection would be
# ineffective.
tfcmd = subprocess.run( # nosec
['terraform',
'plan',
'-json',
f'-refresh={str(refresh).lower()}',
f'-parallelism={str(parallelism)}',
f'-lock-timeout={str(lock_timeout)}m',
[
"terraform",
"plan",
"-json",
f"-refresh={str(refresh).lower()}",
f"-parallelism={str(parallelism)}",
f"-lock-timeout={str(lock_timeout)}m",
],
cwd=self.working_dir)
return tfcmd.returncode, tfcmd.stdout.decode('utf-8')
cwd=self.working_dir,
)
return tfcmd.returncode, tfcmd.stdout.decode("utf-8")
def tf_posthook(self, *, prehook_result: Any = None) -> None:
"""
@ -154,9 +170,8 @@ class TerraformAutomation(BaseAutomation):
raise RuntimeError("No working directory specified.")
# This subprocess call doesn't take any user input.
terraform = subprocess.run( # nosec
['terraform', 'show', '-json'],
cwd=self.working_dir,
stdout=subprocess.PIPE)
["terraform", "show", "-json"], cwd=self.working_dir, stdout=subprocess.PIPE
)
return json.loads(terraform.stdout)
def tf_write(self, template: str, **kwargs: Any) -> None:

View file

@ -9,7 +9,7 @@ from app.models.tfstate import TerraformState
tfstate = Blueprint("tfstate", __name__)
@tfstate.route("/<key>", methods=['GET'])
@tfstate.route("/<key>", methods=["GET"])
def handle_get(key: str) -> ResponseReturnValue:
state = TerraformState.query.filter(TerraformState.key == key).first()
if state is None or state.state is None:
@ -17,16 +17,18 @@ def handle_get(key: str) -> ResponseReturnValue:
return Response(state.state, content_type="application/json")
@tfstate.route("/<key>", methods=['POST', 'DELETE', 'UNLOCK'])
@tfstate.route("/<key>", methods=["POST", "DELETE", "UNLOCK"])
def handle_update(key: str) -> ResponseReturnValue:
state = TerraformState.query.filter(TerraformState.key == key).first()
if not state:
if request.method in ["DELETE", "UNLOCK"]:
return "OK", 200
state = TerraformState(key=key)
if state.lock and not (request.method == "UNLOCK" and request.args.get('ID') is None):
if state.lock and not (
request.method == "UNLOCK" and request.args.get("ID") is None
):
# force-unlock seems to not give an ID to verify so accept no ID being present
if json.loads(state.lock)['ID'] != request.args.get('ID'):
if json.loads(state.lock)["ID"] != request.args.get("ID"):
return Response(state.lock, status=409, content_type="application/json")
if request.method == "POST":
state.state = json.dumps(request.json)
@ -38,9 +40,11 @@ def handle_update(key: str) -> ResponseReturnValue:
return "OK", 200
@tfstate.route("/<key>", methods=['LOCK'])
@tfstate.route("/<key>", methods=["LOCK"])
def handle_lock(key: str) -> ResponseReturnValue:
state = TerraformState.query.filter(TerraformState.key == key).with_for_update().first()
state = (
TerraformState.query.filter(TerraformState.key == key).with_for_update().first()
)
if state is None:
state = TerraformState(key=key, state="")
db.session.add(state)

View file

@ -20,8 +20,9 @@ def onion_hostname(onion_public_key: bytes) -> str:
return onion.lower()
def decode_onion_keys(onion_private_key_base64: str, onion_public_key_base64: str) -> Tuple[
Optional[bytes], Optional[bytes], List[Dict[str, str]]]:
def decode_onion_keys(
onion_private_key_base64: str, onion_public_key_base64: str
) -> Tuple[Optional[bytes], Optional[bytes], List[Dict[str, str]]]:
try:
onion_private_key = base64.b64decode(onion_private_key_base64)
onion_public_key = base64.b64decode(onion_public_key_base64)

View file

@ -22,14 +22,20 @@ def load_certificates_from_pem(pem_data: bytes) -> list[x509.Certificate]:
return certificates
def build_certificate_chain(certificates: list[x509.Certificate]) -> list[x509.Certificate]:
def build_certificate_chain(
certificates: list[x509.Certificate],
) -> list[x509.Certificate]:
if len(certificates) == 1:
return certificates
chain = []
cert_map = {cert.subject.rfc4514_string(): cert for cert in certificates}
end_entity = next(
(cert for cert in certificates if cert.subject.rfc4514_string() not in cert_map),
None
(
cert
for cert in certificates
if cert.subject.rfc4514_string() not in cert_map
),
None,
)
if not end_entity:
raise ValueError("Cannot identify the end-entity certificate.")
@ -51,7 +57,9 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
for i in range(len(chain) - 1):
next_public_key = chain[i + 1].public_key()
if not (isinstance(next_public_key, RSAPublicKey)):
raise ValueError(f"Certificate using unsupported algorithm: {type(next_public_key)}")
raise ValueError(
f"Certificate using unsupported algorithm: {type(next_public_key)}"
)
hash_algorithm = chain[i].signature_hash_algorithm
if hash_algorithm is None:
raise ValueError("Certificate missing hash algorithm")
@ -59,7 +67,7 @@ def validate_certificate_chain(chain: list[x509.Certificate]) -> bool:
chain[i].signature,
chain[i].tbs_certificate_bytes,
PKCS1v15(),
hash_algorithm
hash_algorithm,
)
end_cert = chain[-1]
@ -75,7 +83,7 @@ def validate_tls_keys(
tls_certificate_pem: Optional[str],
skip_chain_verification: Optional[bool],
skip_name_verification: Optional[bool],
hostname: str
hostname: str,
) -> Tuple[Optional[List[x509.Certificate]], List[str], List[Dict[str, str]]]:
errors = []
san_list = []
@ -90,31 +98,55 @@ def validate_tls_keys(
private_key = serialization.load_pem_private_key(
tls_private_key_pem.encode("utf-8"),
password=None,
backend=default_backend()
backend=default_backend(),
)
if not isinstance(private_key, rsa.RSAPrivateKey):
errors.append({"Error": "tls_private_key_invalid", "Message": "Private key must be RSA."})
errors.append(
{
"Error": "tls_private_key_invalid",
"Message": "Private key must be RSA.",
}
)
if tls_certificate_pem:
certificates = list(load_certificates_from_pem(tls_certificate_pem.encode("utf-8")))
certificates = list(
load_certificates_from_pem(tls_certificate_pem.encode("utf-8"))
)
if not certificates:
errors.append({"Error": "tls_certificate_invalid", "Message": "No valid certificate found."})
errors.append(
{
"Error": "tls_certificate_invalid",
"Message": "No valid certificate found.",
}
)
else:
chain = build_certificate_chain(certificates)
end_entity_cert = chain[0]
if end_entity_cert.not_valid_after_utc < datetime.now(timezone.utc):
errors.append({"Error": "tls_public_key_expired", "Message": "TLS public key is expired."})
errors.append(
{
"Error": "tls_public_key_expired",
"Message": "TLS public key is expired.",
}
)
if end_entity_cert.not_valid_before_utc > datetime.now(timezone.utc):
errors.append({"Error": "tls_public_key_future", "Message": "TLS public key is not yet valid."})
errors.append(
{
"Error": "tls_public_key_future",
"Message": "TLS public key is not yet valid.",
}
)
if private_key:
public_key = end_entity_cert.public_key()
if TYPE_CHECKING:
assert isinstance(public_key, rsa.RSAPublicKey) # nosec: B101
assert isinstance(private_key, rsa.RSAPrivateKey) # nosec: B101
assert end_entity_cert.signature_hash_algorithm is not None # nosec: B101
assert (
end_entity_cert.signature_hash_algorithm is not None
) # nosec: B101
try:
test_message = b"test"
signature = private_key.sign(
@ -130,20 +162,30 @@ def validate_tls_keys(
)
except Exception:
errors.append(
{"Error": "tls_key_mismatch", "Message": "Private key does not match certificate."})
{
"Error": "tls_key_mismatch",
"Message": "Private key does not match certificate.",
}
)
if not skip_chain_verification:
try:
validate_certificate_chain(chain)
except ValueError as e:
errors.append({"Error": "certificate_chain_invalid", "Message": str(e)})
errors.append(
{"Error": "certificate_chain_invalid", "Message": str(e)}
)
if not skip_name_verification:
san_list = extract_sans(end_entity_cert)
for expected_hostname in [hostname, f"*.{hostname}"]:
if expected_hostname not in san_list:
errors.append(
{"Error": "hostname_not_in_san", "Message": f"{expected_hostname} not found in SANs."})
{
"Error": "hostname_not_in_san",
"Message": f"{expected_hostname} not found in SANs.",
}
)
except Exception as e:
errors.append({"Error": "tls_validation_error", "Message": str(e)})
@ -153,7 +195,9 @@ def validate_tls_keys(
def extract_sans(cert: x509.Certificate) -> List[str]:
try:
san_extension = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
san_extension = cert.extensions.get_extension_for_oid(
ExtensionOID.SUBJECT_ALTERNATIVE_NAME
)
sans: List[str] = san_extension.value.get_values_for_type(x509.DNSName) # type: ignore[attr-defined]
return sans
except Exception:

View file

@ -1,2 +1,2 @@
[flake8]
ignore = E501,W503
ignore = E203,E501,W503