# majuna/app/cli/db.py

import base64
import datetime
import json
import logging
import sys
from collections import defaultdict
from typing import Any, Callable, Dict, List, Type
from sqlalchemy import inspect
from app import app
from app.cli import BaseCliHandler, _SubparserType
from app.extensions import db
from app.models.activity import Activity, Webhook
from app.models.alarms import Alarm, AlarmState
from app.models.automation import Automation, AutomationLogs, AutomationState
from app.models.base import Group, MirrorList, Pool, PoolGroup
from app.models.bridges import Bridge, BridgeConf
from app.models.mirrors import Origin, Proxy, SmartProxy
from app.models.onions import Eotk, Onion
from app.models.tfstate import TerraformState
# Alias for a model *class* (Type[...]), used to annotate the entries of `models`.
Model = Type[db.Model] # type: ignore[name-defined]
# order matters due to foreign key constraints
# Every model handled by `db --export` / `db --import`. db_export dumps them and
# db_import re-adds rows model-by-model in exactly this sequence.
# NOTE(review): presumably each table must come after the tables it references
# via foreign keys — confirm that constraint when adding a new model here.
models: List[Model] = [
    Group,
    Activity,
    Pool,
    PoolGroup,
    SmartProxy,
    Origin,
    Proxy,
    Onion,
    Alarm,
    Automation,
    AutomationLogs,
    BridgeConf,
    Bridge,
    Eotk,
    MirrorList,
    TerraformState,
    Webhook
]
class ExportEncoder(json.JSONEncoder):
    """JSON encoder covering the non-JSON column types found in the models.

    Alarm/automation state enums are written as their member name, raw bytes
    as base64 text, and datetime/date/time values as ISO 8601 strings. Any
    other type is handed to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, o: Any) -> Any:
        if isinstance(o, (AlarmState, AutomationState)):
            # Enum members travel by name so they can be looked up again on import.
            serialised: Any = o.name
        elif isinstance(o, bytes):
            serialised = base64.encodebytes(o).decode('utf-8')
        elif isinstance(o, (datetime.datetime, datetime.date, datetime.time)):
            serialised = o.isoformat()
        else:
            serialised = super().default(o)
        return serialised
def model_to_dict(model: Model) -> Dict[str, Any]:
    """Flatten one mapped row into a dict keyed by "<type name>_<column name>".

    Although annotated with ``Model``, ``model`` is a row instance: its class is
    what gets inspected. The Python type name of each value is prefixed onto
    the column name so that ``db_import_model`` can choose the matching
    converter from ``decoder`` when the data is read back.

    Raises:
        RuntimeError: if SQLAlchemy's ``inspect`` returns a falsy result.
    """
    mapper = inspect(type(model))
    if not mapper:
        raise RuntimeError(f"Could not inspect model {model}")
    pairs = ((col.name, getattr(model, col.name)) for col in mapper.columns)
    return {f"{type(value).__name__}_{name}": value for name, value in pairs}
def db_export() -> None:
    """Print every row of every model in ``models`` to stdout as one JSON object.

    The object maps each model's class name to the list of row dicts produced
    by ``model_to_dict``; models without any rows do not appear as keys.
    """
    tables: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for model in models:
        serialised_rows = [model_to_dict(row) for row in model.query.all()]
        # Only touch the defaultdict when there is data, so empty tables are omitted.
        if serialised_rows:
            tables[model.__name__].extend(serialised_rows)
    print(ExportEncoder().encode(tables))
# Inverse of ExportEncoder: maps the type-name prefix written by model_to_dict
# to a converter that turns the JSON value back into the original Python type.
# Prefixes not listed here (e.g. float, bool or NoneType values) are passed
# through unchanged by db_import_model.
decoder: Dict[str, Callable[[Any], Any]] = {
    # Enum members are exported by name; Enum[name] only accepts member names,
    # unlike a generic attribute lookup which would return any class attribute.
    "AlarmState": lambda x: AlarmState[x],
    "AutomationState": lambda x: AutomationState[x],
    "bytes": lambda x: base64.decodebytes(x.encode('utf-8')),
    "datetime": datetime.datetime.fromisoformat,
    # ExportEncoder also writes bare dates and times with isoformat(); without
    # these entries they would be re-imported as plain strings.
    "date": datetime.date.fromisoformat,
    "time": datetime.time.fromisoformat,
    "int": int,
    "str": lambda x: x,
}
def db_import_model(model: Model, data: List[Dict[str, Any]]) -> None:
    """Rebuild ``model`` rows from exported dicts and add them to the session.

    Each key has the form "<type name>_<attribute>" (as written by
    ``model_to_dict``). The prefix before the first underscore selects a
    converter from ``decoder``; unknown prefixes keep the JSON value as-is.
    Nothing is committed here.
    """
    def unchanged(value: Any) -> Any:
        return value

    for record in data:
        instance = model()
        for key, raw_value in record.items():
            # Split only once: attribute names may themselves contain underscores.
            type_name, attr_name = key.split("_", 1)
            convert = decoder.get(type_name, unchanged)
            setattr(instance, attr_name, convert(raw_value))
        db.session.add(instance)
def db_import() -> None:
    """Read a JSON export from stdin and insert its rows in one commit.

    The input is expected to map model class names to lists of row dicts, the
    shape produced by ``db_export``. Models missing from the input are skipped.
    """
    exported = json.load(sys.stdin)
    # import order matters due to foreign key constraints, so follow `models`.
    batches = ((model, exported.get(model.__name__, [])) for model in models)
    for model, rows in batches:
        db_import_model(model, rows)
    db.session.commit()
class DbCliHandler(BaseCliHandler):
    """``db`` sub-command: export the database as JSON or import it back."""

    @classmethod
    def add_subparser_to(cls, subparsers: _SubparserType) -> None:
        """Register the ``db`` sub-parser with its ``--export`` / ``--import`` flags."""
        parser = subparsers.add_parser("db", help="database operations")
        # The actions are alternatives: make argparse reject "--export --import"
        # instead of silently running only the export.
        action = parser.add_mutually_exclusive_group()
        action.add_argument("--export", help="export data to JSON format", action="store_true")
        action.add_argument("--import", help="import data from JSON format", action="store_true")
        parser.set_defaults(cls=cls)

    def run(self) -> None:
        """Run the requested action inside ``app.app_context()``.

        Export prints JSON to stdout; import reads JSON from stdin. With neither
        flag set, an error is logged and nothing else happens.
        """
        with app.app_context():
            if self.args.export:
                db_export()
            # "import" is a Python keyword, so it cannot be read as self.args.import.
            elif vars(self.args)["import"]:
                db_import()
            else:
                logging.error("No action requested")