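"""Database export/import helpers for the application's CLI.

The ``db`` subcommand serialises every model to JSON on stdout (``--export``)
or restores a previous export from stdin (``--import``).
"""
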
import base64
import datetime
import json
import logging
import sys
from collections import defaultdict
from typing import List, Dict, Any

from sqlalchemy import inspect

from app import app
from app.cli import _SubparserType, BaseCliHandler
from app.extensions import db
from app.models.activity import Webhook, Activity
from app.models.automation import AutomationLogs, Automation, AutomationState
from app.models.base import Group, MirrorList, PoolGroup, Pool
from app.models.bridges import Bridge, BridgeConf
from app.models.mirrors import Origin, Proxy, SmartProxy
from app.models.alarms import Alarm, AlarmState
from app.models.onions import Onion, Eotk
from app.models.tfstate import TerraformState

# order matters due to foreign key constraints
models = [
    Group,
    Activity,
    Pool,
    PoolGroup,
    SmartProxy,
    Origin,
    Proxy,
    Onion,
    Alarm,
    Automation,
    AutomationLogs,
    BridgeConf,
    Bridge,
    Eotk,
    MirrorList,
    TerraformState,
    Webhook
]


class ExportEncoder(json.JSONEncoder):
    """Encoder to serialise all types used in the database."""

    def default(self, obj: Any) -> Any:
        if isinstance(obj, AlarmState):
            return obj.name
        elif isinstance(obj, AutomationState):
            return obj.name
        elif isinstance(obj, bytes):
            return base64.encodebytes(obj).decode('utf-8')
        elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
            return obj.isoformat()
        elif isinstance(obj, datetime.timedelta):
            return (datetime.datetime.min + obj).time().isoformat()
        return super().default(obj)
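
# Illustration only (hypothetical field name): enum members serialise by name,
# bytes as base64 text and datetimes as ISO 8601 strings, e.g.
#   ExportEncoder().encode({"datetime_added": datetime.datetime(2024, 1, 1)})
#   -> '{"datetime_added": "2024-01-01T00:00:00"}'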


def model_to_dict(model: db.Model) -> Dict[str, Any]:  # type: ignore[name-defined]
    """Serialise a single model instance into a dict keyed by type-prefixed column name."""
    output = {}
    for column in inspect(type(model)).columns:
        item = getattr(model, column.name)
        output[f"{type(item).__name__}_{column.name}"] = item
    return output
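
# Illustration only (hypothetical column names): keys carry the Python type of the
# value so that db_import_model() below can pick a matching decoder, e.g.
#   {"int_id": 3, "str_description": "example", "datetime_added": datetime.datetime(2024, 1, 1)}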


def db_export() -> None:
    """Dump all rows of every known model to stdout as a single JSON document."""
    encoder = ExportEncoder()
    output = defaultdict(list)
    for model in models:
        for row in model.query.all():  # type: ignore[attr-defined]
            output[model.__name__].append(model_to_dict(row))
    print(encoder.encode(output))
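
# The export maps model names to lists of serialised rows, roughly:
#   {"Group": [{...}, ...], "Activity": [{...}, ...], ..., "Webhook": [...]}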


decoder = {
    "AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x),
    "AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x),
    "bytes": lambda x: base64.decodebytes(x.encode('utf-8')),
    "datetime": lambda x: datetime.datetime.fromisoformat(x),
    "int": lambda x: int(x),
    "str": lambda x: x,
    # TODO: timedelta (not currently used but could be in the future)
}
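
# Illustration only: a key such as "datetime_added" splits into ("datetime", "added");
# the "datetime" prefix selects datetime.datetime.fromisoformat, while unknown prefixes
# fall back to the identity function in db_import_model() below.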


def db_import_model(model: db.Model, data: List[Dict[str, Any]]) -> None:  # type: ignore[name-defined]
    """Recreate rows of ``model`` from a list of previously exported dictionaries."""
    for row in data:
        new = model()
        for col in row:
            type_name, col_name = col.split("_", 1)
            new.__setattr__(col_name, decoder.get(type_name, lambda x: x)(row[col]))  # type: ignore[no-untyped-call]
        db.session.add(new)
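
# Note: rows are only staged on the session here; the single commit happens in
# db_import() once every model has been processed.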


def db_import() -> None:
    """Read a JSON export from stdin and insert its rows into the database."""
    data = json.load(sys.stdin)
    # import order matters due to foreign key constraints
    for model in models:
        db_import_model(model, data.get(model.__name__, []))
    db.session.commit()


class DbCliHandler(BaseCliHandler):
    @classmethod
    def add_subparser_to(cls, subparsers: _SubparserType) -> None:
        parser = subparsers.add_parser("db", help="database operations")
        parser.add_argument("--export", help="export data to JSON format", action="store_true")
        parser.add_argument("--import", help="import data from JSON format", action="store_true")
        parser.set_defaults(cls=cls)

    def run(self) -> None:
        with app.app_context():
            if self.args.export:
                db_export()
            elif vars(self.args)["import"]:
                db_import()
            else:
                logging.error("No action requested")
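
# Example invocation (hypothetical entrypoint; the module that builds the parent
# argument parser may differ):
#   python -m app.cli db --export > backup.json
#   python -m app.cli db --import < backup.json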