From 356495beaf93b84953487fe7858d9be0960c2b0b Mon Sep 17 00:00:00 2001 From: Iain Learmonth Date: Sun, 13 Nov 2022 20:09:48 +0000 Subject: [PATCH] db: import and export with generalised json format --- app/cli/db.py | 159 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 105 insertions(+), 54 deletions(-) diff --git a/app/cli/db.py b/app/cli/db.py index eb44a2d..fbe9787 100644 --- a/app/cli/db.py +++ b/app/cli/db.py @@ -1,88 +1,139 @@ -import csv +import base64 import datetime +import json import logging import sys +from collections import defaultdict +from typing import List, Dict, Any + +from sqlalchemy import inspect from app import app from app.cli import _SubparserType, BaseCliHandler from app.extensions import db -from app.models.base import Group, MirrorList +from app.models.activity import Webhook, Activity +from app.models.automation import AutomationLogs, Automation, AutomationState +from app.models.base import Group, MirrorList, PoolGroup, Pool from app.models.bridges import Bridge, BridgeConf -from app.models.mirrors import Origin, Proxy +from app.models.mirrors import Origin, Proxy, SmartProxy from app.models.alarms import Alarm, AlarmState +from app.models.onions import Onion, Eotk +from app.models.tfstate import TerraformState models = { + "activity": Activity, + "alarm": Alarm, + "automation": Automation, + "automation_logs": AutomationLogs, "bridge": Bridge, "bridgeconf": BridgeConf, - "alarm": Alarm, + "eotk": Eotk, "group": Group, "list": MirrorList, + "onion": Onion, "origin": Origin, - "proxy": Proxy + "pool": Pool, + "pool_group": PoolGroup, + "proxy": Proxy, + "smart_proxy": SmartProxy, + "terraform_state": TerraformState, + "webhook": Webhook } -def export(model: db.Model) -> None: # type: ignore[name-defined] - out = csv.writer(sys.stdout) - out.writerow(model.csv_header()) - for row in model.query.all(): - out.writerow(row.csv_row()) +class ExportEncoder(json.JSONEncoder): + """Encoder to serialise all types used in the database.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, AlarmState): + return obj.name + elif isinstance(obj, AutomationState): + return obj.name + elif isinstance(obj, bytes): + return base64.encodebytes(obj).decode('utf-8') + elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): + return obj.isoformat() + elif isinstance(obj, datetime.timedelta): + return (datetime.datetime.min + obj).time().isoformat() + return super().default(obj) -def impot(model: db.Model) -> None: # type: ignore[name-defined] - first = True - header = model.csv_header() - try: - for line in csv.reader(sys.stdin): - if first: - if line != header: - logging.error("CSV header mismatch") - sys.exit(1) - first = False - continue - new_entity = model() - for idx, field_name in header: - if field_name in ["added", "updated", "destroyed", "deprecated", "last_updated", "terraform_updated"]: - # datetime fields - if line[idx] == "": - line[idx] = None # type: ignore - else: - line[idx] = datetime.datetime.strptime(line[idx], "%Y-%m-%d %H:%M:%S.%f") # type: ignore - elif field_name in ["eotk", "auto_rotation", "smart"]: - # boolean fields - line[idx] = line[idx] == "True" - elif field_name.endswith("_id") and line[idx] == "": - # integer foreign keys - line[idx] = None # type: ignore - elif field_name in ["alarm_state"]: - # alarm states - line[idx] = getattr(AlarmState, line[idx][len("AlarmState."):]) - setattr(new_entity, field_name, line[idx]) - db.session.add(new_entity) - db.session.commit() - logging.info("Import completed successfully") - # Many things can go wrong in the above, like IO, format or database errors. - # We catch all the errors and ensure the database transaction is rolled back, and log it. - except Exception as exc: # pylint: disable=broad-except - logging.exception(exc) - db.session.rollback() +def model_to_dict(model: db.Model) -> Dict[str, Any]: # type: ignore[name-defined] + output = {} + for column in inspect(type(model)).columns: + item = getattr(model, column.name) + output[f"{type(item).__name__}_{column.name}"] = item + return output + + +def db_export() -> None: + encoder = ExportEncoder() + output = defaultdict(list) + for model in models: + for row in models[model].query.all(): # type: ignore[attr-defined] + output[model].append(model_to_dict(row)) + print(encoder.encode(output)) + + +decoder = { + "AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x), + "AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x), + "bytes": lambda x: base64.decodebytes(x.encode('utf-8')), + "datetime": lambda x: datetime.datetime.fromisoformat(x), + "int": lambda x: int(x), + "str": lambda x: x, + # TODO: timedelta (not currently used but could be in the future) +} + + +def db_import_model(model: str, data: List[Dict[str, Any]]) -> None: + for row in data: + new = models[model]() + for col in row: + type_name, col_name = col.split("_", 1) + new.__setattr__(col_name, decoder.get(type_name, lambda x: x)(row[col])) # type: ignore[no-untyped-call] + db.session.add(new) + + +def db_import() -> None: + data = json.load(sys.stdin) + # import order matters due to foreign key constraints + for model in [ + "group", + "pool", + "pool_group", + "smart_proxy", + "origin", + "proxy", + "onion", + "alarm", + "automation", + "automation_logs", + "bridgeconf", + "bridge", + "eotk", + "list", + "terraform_state", + "webhook" + ]: + db_import_model(model, data[model]) + + db.session.commit() class DbCliHandler(BaseCliHandler): @classmethod def add_subparser_to(cls, subparsers: _SubparserType) -> None: parser = subparsers.add_parser("db", help="database operations") - parser.add_argument("--export", choices=sorted(models.keys()), - help="export data to CSV format") - parser.add_argument("--import", choices=sorted(models.keys()), - help="import data from CSV format", dest="impot") + parser.add_argument("--export", help="export data to JSON format", action="store_true") + parser.add_argument("--import", help="import data from JSON format", action="store_true") parser.set_defaults(cls=cls) def run(self) -> None: with app.app_context(): if self.args.export: - export(models[self.args.export]) - elif self.args.impot: - impot(models[self.args.impot]) + db_export() + elif vars(self.args)["import"]: + db_import() else: logging.error("No action requested")