db: import and export with generalised json format

This commit is contained in:
Iain Learmonth 2022-11-13 20:09:48 +00:00
parent 25bbdf8e2b
commit 356495beaf

View file

@@ -1,88 +1,139 @@
import base64
import csv
import datetime
import json
import logging
import sys
from collections import defaultdict
from typing import Any, Dict, List

from sqlalchemy import inspect

from app import app
from app.cli import _SubparserType, BaseCliHandler
from app.extensions import db
from app.models.activity import Webhook, Activity
from app.models.alarms import Alarm, AlarmState
from app.models.automation import AutomationLogs, Automation, AutomationState
from app.models.base import Group, MirrorList, PoolGroup, Pool
from app.models.bridges import Bridge, BridgeConf
from app.models.mirrors import Origin, Proxy, SmartProxy
from app.models.onions import Onion, Eotk
from app.models.tfstate import TerraformState
models = { models = {
"activity": Activity,
"alarm": Alarm,
"automation": Automation,
"automation_logs": AutomationLogs,
"bridge": Bridge, "bridge": Bridge,
"bridgeconf": BridgeConf, "bridgeconf": BridgeConf,
"alarm": Alarm, "eotk": Eotk,
"group": Group, "group": Group,
"list": MirrorList, "list": MirrorList,
"onion": Onion,
"origin": Origin, "origin": Origin,
"proxy": Proxy "pool": Pool,
"pool_group": PoolGroup,
"proxy": Proxy,
"smart_proxy": SmartProxy,
"terraform_state": TerraformState,
"webhook": Webhook
} }
def export(model: db.Model) -> None: # type: ignore[name-defined] class ExportEncoder(json.JSONEncoder):
out = csv.writer(sys.stdout) """Encoder to serialise all types used in the database."""
out.writerow(model.csv_header())
for row in model.query.all(): def default(self, obj: Any) -> Any:
out.writerow(row.csv_row()) if isinstance(obj, AlarmState):
return obj.name
elif isinstance(obj, AutomationState):
return obj.name
elif isinstance(obj, bytes):
return base64.encodebytes(obj).decode('utf-8')
elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()
elif isinstance(obj, datetime.timedelta):
return (datetime.datetime.min + obj).time().isoformat()
return super().default(obj)
def impot(model: db.Model) -> None: # type: ignore[name-defined] def model_to_dict(model: db.Model) -> Dict[str, Any]: # type: ignore[name-defined]
first = True output = {}
header = model.csv_header() for column in inspect(type(model)).columns:
try: item = getattr(model, column.name)
for line in csv.reader(sys.stdin): output[f"{type(item).__name__}_{column.name}"] = item
if first: return output
if line != header:
logging.error("CSV header mismatch")
sys.exit(1) def db_export() -> None:
first = False encoder = ExportEncoder()
continue output = defaultdict(list)
new_entity = model() for model in models:
for idx, field_name in header: for row in models[model].query.all(): # type: ignore[attr-defined]
if field_name in ["added", "updated", "destroyed", "deprecated", "last_updated", "terraform_updated"]: output[model].append(model_to_dict(row))
# datetime fields print(encoder.encode(output))
if line[idx] == "":
line[idx] = None # type: ignore
else: decoder = {
line[idx] = datetime.datetime.strptime(line[idx], "%Y-%m-%d %H:%M:%S.%f") # type: ignore "AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x),
elif field_name in ["eotk", "auto_rotation", "smart"]: "AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x),
# boolean fields "bytes": lambda x: base64.decodebytes(x.encode('utf-8')),
line[idx] = line[idx] == "True" "datetime": lambda x: datetime.datetime.fromisoformat(x),
elif field_name.endswith("_id") and line[idx] == "": "int": lambda x: int(x),
# integer foreign keys "str": lambda x: x,
line[idx] = None # type: ignore # TODO: timedelta (not currently used but could be in the future)
elif field_name in ["alarm_state"]: }
# alarm states
line[idx] = getattr(AlarmState, line[idx][len("AlarmState."):])
setattr(new_entity, field_name, line[idx]) def db_import_model(model: str, data: List[Dict[str, Any]]) -> None:
db.session.add(new_entity) for row in data:
new = models[model]()
for col in row:
type_name, col_name = col.split("_", 1)
new.__setattr__(col_name, decoder.get(type_name, lambda x: x)(row[col])) # type: ignore[no-untyped-call]
db.session.add(new)
def db_import() -> None:
data = json.load(sys.stdin)
# import order matters due to foreign key constraints
for model in [
"group",
"pool",
"pool_group",
"smart_proxy",
"origin",
"proxy",
"onion",
"alarm",
"automation",
"automation_logs",
"bridgeconf",
"bridge",
"eotk",
"list",
"terraform_state",
"webhook"
]:
db_import_model(model, data[model])
db.session.commit() db.session.commit()
logging.info("Import completed successfully")
# Many things can go wrong in the above, like IO, format or database errors.
# We catch all the errors and ensure the database transaction is rolled back, and log it.
except Exception as exc: # pylint: disable=broad-except
logging.exception(exc)
db.session.rollback()
class DbCliHandler(BaseCliHandler): class DbCliHandler(BaseCliHandler):
@classmethod @classmethod
def add_subparser_to(cls, subparsers: _SubparserType) -> None: def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("db", help="database operations") parser = subparsers.add_parser("db", help="database operations")
parser.add_argument("--export", choices=sorted(models.keys()), parser.add_argument("--export", help="export data to JSON format", action="store_true")
help="export data to CSV format") parser.add_argument("--import", help="import data from JSON format", action="store_true")
parser.add_argument("--import", choices=sorted(models.keys()),
help="import data from CSV format", dest="impot")
parser.set_defaults(cls=cls) parser.set_defaults(cls=cls)
def run(self) -> None: def run(self) -> None:
with app.app_context(): with app.app_context():
if self.args.export: if self.args.export:
export(models[self.args.export]) db_export()
elif self.args.impot: elif vars(self.args)["import"]:
impot(models[self.args.impot]) db_import()
else: else:
logging.error("No action requested") logging.error("No action requested")