majuna/app/cli/db.py

140 lines
4.2 KiB
Python
Raw Normal View History

import base64
2022-04-22 12:52:41 +01:00
import datetime
import json
2022-04-22 12:52:41 +01:00
import logging
import sys
from collections import defaultdict
from typing import List, Dict, Any
from sqlalchemy import inspect
2022-04-22 12:52:41 +01:00
from app import app
2022-06-17 13:21:35 +01:00
from app.cli import _SubparserType, BaseCliHandler
2022-04-22 12:52:41 +01:00
from app.extensions import db
from app.models.activity import Webhook, Activity
from app.models.automation import AutomationLogs, Automation, AutomationState
from app.models.base import Group, MirrorList, PoolGroup, Pool
2022-05-06 12:28:11 +01:00
from app.models.bridges import Bridge, BridgeConf
from app.models.mirrors import Origin, Proxy, SmartProxy
2022-04-22 14:45:47 +01:00
from app.models.alarms import Alarm, AlarmState
from app.models.onions import Onion, Eotk
from app.models.tfstate import TerraformState
2022-04-22 12:52:41 +01:00
# Registry of the tables handled by this module. Each key is the name used for
# the table in the exported JSON document; each value is the model class
# whose rows are exported (``.query.all()``) and re-created on import.
models = dict(
    activity=Activity,
    alarm=Alarm,
    automation=Automation,
    automation_logs=AutomationLogs,
    bridge=Bridge,
    bridgeconf=BridgeConf,
    eotk=Eotk,
    group=Group,
    list=MirrorList,
    onion=Onion,
    origin=Origin,
    pool=Pool,
    pool_group=PoolGroup,
    proxy=Proxy,
    smart_proxy=SmartProxy,
    terraform_state=TerraformState,
    webhook=Webhook,
)
class ExportEncoder(json.JSONEncoder):
    """Encoder to serialise all types used in the database.

    Adds support for raw ``bytes`` (base64 text), date/time values (ISO 8601
    strings), ``timedelta`` (a clock time, see below) and the ``AlarmState`` /
    ``AutomationState`` states (their ``.name``). Any other unsupported object
    is delegated to :class:`json.JSONEncoder`, which raises ``TypeError``.
    """

    def default(self, obj: Any) -> Any:
        """Return a JSON-serialisable stand-in for ``obj``.

        :raises ValueError: for a ``timedelta`` outside ``[0, 1 day)``, which
            the clock-time encoding cannot represent.
        :raises TypeError: for any other unsupported type (from the base class).
        """
        # The checks below test disjoint types, so their order does not
        # affect the result; stdlib types are tested first.
        if isinstance(obj, bytes):
            return base64.encodebytes(obj).decode('utf-8')
        if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
            return obj.isoformat()
        if isinstance(obj, datetime.timedelta):
            # Encoded as a time of day ("HH:MM:SS[.ffffff]"), which can only
            # hold durations in [0, 1 day). Refuse anything else instead of
            # silently dropping whole days (or overflowing on negatives).
            if not datetime.timedelta(0) <= obj < datetime.timedelta(days=1):
                raise ValueError(
                    f"cannot export timedelta {obj!r}: "
                    "only durations in [0, 1 day) are supported"
                )
            return (datetime.datetime.min + obj).time().isoformat()
        if isinstance(obj, (AlarmState, AutomationState)):
            return obj.name
        return super().default(obj)
def model_to_dict(model: db.Model) -> Dict[str, Any]:  # type: ignore[name-defined]
    """Flatten one model instance into a dict of its column values.

    Every mapped column becomes one entry keyed ``"<type name>_<column name>"``,
    where the type name is that of the Python value read from the instance
    (e.g. ``"int"``, ``"datetime"``, ``"NoneType"``). The prefix lets the
    importer pick a converter for the JSON value.
    """
    columns = inspect(type(model)).columns
    pairs = ((column.name, getattr(model, column.name)) for column in columns)
    return {f"{type(value).__name__}_{name}": value for name, value in pairs}
def db_export() -> None:
    """Print every row of every table in ``models`` to stdout as one JSON document.

    The document maps each key of ``models`` to the list of its rows, each
    flattened by ``model_to_dict`` and encoded with ``ExportEncoder``. Every
    key is always present, using an empty list for a table with no rows, so
    the importer can look up each table by name.
    """
    output = {
        name: [
            model_to_dict(row)
            for row in model.query.all()  # type: ignore[attr-defined]
        ]
        for name, model in models.items()
    }
    print(ExportEncoder().encode(output))
def _decode_timedelta(value: str) -> datetime.timedelta:
    """Turn a clock time ``"HH:MM:SS[.ffffff]"`` back into the duration since midnight.

    This inverts the export format ``(datetime.min + delta).time().isoformat()``.
    """
    clock = datetime.time.fromisoformat(value)
    return datetime.datetime.combine(datetime.date.min, clock) - datetime.datetime.min


# Maps the type-name prefix of an exported key (the part before the first "_")
# to a function converting the JSON value back into a Python value. Prefixes
# without an entry here (e.g. "NoneType", "bool", "float") are looked up with
# a pass-through default by the importer.
decoder = {
    # These states are exported by ``.name``; look the member up by name.
    # Assumes both are ``enum.Enum`` subclasses — TODO confirm if that changes.
    "AlarmState": lambda x: AlarmState[x],
    "AutomationState": lambda x: AutomationState[x],
    "bytes": lambda x: base64.decodebytes(x.encode('utf-8')),
    "date": lambda x: datetime.date.fromisoformat(x),
    "datetime": lambda x: datetime.datetime.fromisoformat(x),
    "int": lambda x: int(x),
    "str": lambda x: x,
    "time": lambda x: datetime.time.fromisoformat(x),
    "timedelta": lambda x: _decode_timedelta(x),
}
def db_import_model(model: str, data: List[Dict[str, Any]]) -> None:
    """Add one table's exported rows to the session (the caller commits).

    ``model`` is a key of ``models``. Each row's keys have the form
    ``"<type name>_<column name>"``: the type name selects a converter from
    ``decoder`` (values with an unknown type name are used unchanged), and the
    converted value is assigned to that attribute of a new model instance.
    """
    for record in data:
        instance = models[model]()
        for key, raw_value in record.items():
            type_name, attribute = key.split("_", 1)
            convert = decoder.get(type_name, lambda value: value)
            setattr(instance, attribute, convert(raw_value))  # type: ignore[no-untyped-call]
        db.session.add(instance)
def db_import() -> None:
    """Read an exported JSON document from stdin and insert its rows.

    Tables are processed in the fixed order below, and the session is
    committed once after every row has been added. A table absent from the
    document is treated as empty, so exports that omit empty tables still
    import cleanly.
    """
    data = json.load(sys.stdin)
    # Import order matters due to foreign key constraints: referenced tables
    # have to be inserted before the tables that point at them.
    import_order = (
        "group",
        "pool",
        "pool_group",
        "smart_proxy",
        "origin",
        "proxy",
        "onion",
        "alarm",
        "automation",
        "automation_logs",
        "bridgeconf",
        "bridge",
        "eotk",
        "list",
        "terraform_state",
        "webhook",
        # "activity" is a key of ``models`` and therefore exported, so it must
        # be restored too. It goes last so anything it might reference already
        # exists. NOTE(review): confirm no other table references activity.
        "activity",
    )
    for model in import_order:
        db_import_model(model, data.get(model, []))
    db.session.commit()
2022-04-22 12:52:41 +01:00
2022-06-17 13:21:35 +01:00
class DbCliHandler(BaseCliHandler):
    """The ``db`` sub-command: dump the database as JSON or load it back."""

    @classmethod
    def add_subparser_to(cls, subparsers: _SubparserType) -> None:
        """Register the ``db`` sub-parser with its ``--export`` / ``--import`` flags."""
        parser = subparsers.add_parser("db", help="database operations")
        # The two actions are alternatives: reject the ambiguous combination
        # instead of silently running only the export.
        actions = parser.add_mutually_exclusive_group()
        actions.add_argument("--export", help="export data to JSON format", action="store_true")
        actions.add_argument("--import", help="import data from JSON format", action="store_true")
        parser.set_defaults(cls=cls)

    def run(self) -> None:
        """Run the requested action inside the Flask application context."""
        with app.app_context():
            if self.args.export:
                db_export()
            elif getattr(self.args, "import"):  # "import" is a keyword: no dotted access
                db_import()
            else:
                logging.error("No action requested")