diff --git a/app/cli/db.py b/app/cli/db.py index c0f2661..3de3e8e 100644 --- a/app/cli/db.py +++ b/app/cli/db.py @@ -4,7 +4,7 @@ import json import logging import sys from collections import defaultdict -from typing import List, Dict, Any +from typing import Any, Callable, Dict, List, Type from sqlalchemy import inspect @@ -20,8 +20,10 @@ from app.models.alarms import Alarm, AlarmState from app.models.onions import Onion, Eotk from app.models.tfstate import TerraformState +Model = Type[db.Model] # type: ignore[name-defined] + # order matters due to foreign key constraints -models = [ +models: List[Model] = [ Group, Activity, Pool, @@ -57,9 +59,12 @@ class ExportEncoder(json.JSONEncoder): return super().default(o) -def model_to_dict(model: db.Model) -> Dict[str, Any]: # type: ignore[name-defined] +def model_to_dict(model: Model) -> Dict[str, Any]: output = {} - for column in inspect(type(model)).columns: + inspection = inspect(type(model)) + if not inspection: + raise RuntimeError(f"Could not inspect model {model}") + for column in inspection.columns: item = getattr(model, column.name) output[f"{type(item).__name__}_{column.name}"] = item return output @@ -69,12 +74,12 @@ def db_export() -> None: encoder = ExportEncoder() output = defaultdict(list) for model in models: - for row in model.query.all(): # type: ignore[attr-defined] + for row in model.query.all(): output[model.__name__].append(model_to_dict(row)) print(encoder.encode(output)) -decoder = { +decoder: Dict[str, Callable[[Any], Any]] = { "AlarmState": lambda x: AlarmState.__getattribute__(AlarmState, x), "AutomationState": lambda x: AutomationState.__getattribute__(AutomationState, x), "bytes": lambda x: base64.decodebytes(x.encode('utf-8')), @@ -84,12 +89,12 @@ decoder = { } -def db_import_model(model: db.Model, data: List[Dict[str, Any]]) -> None: # type: ignore[name-defined] +def db_import_model(model: Model, data: List[Dict[str, Any]]) -> None: for row in data: new = model() for col in row: type_name, col_name = col.split("_", 1) - setattr(new, col_name, decoder.get(type_name, lambda x: x)(row[col])) # type: ignore[no-untyped-call] + setattr(new, col_name, decoder.get(type_name, lambda x: x)(row[col])) db.session.add(new)