cli/automate: fix up errors found with mypy

This commit is contained in:
Iain Learmonth 2022-05-16 09:24:37 +01:00
parent 55a0b19c8c
commit ccf0ce6a06
3 changed files with 55 additions and 22 deletions

View file

@ -3,6 +3,7 @@ import datetime
import json import json
import logging import logging
from traceback import TracebackException from traceback import TracebackException
from typing import Type, TYPE_CHECKING, Any
from app import app from app import app
from app.extensions import db from app.extensions import db
@ -28,6 +29,11 @@ from app.terraform.proxy.azure_cdn import ProxyAzureCdnAutomation
from app.terraform.proxy.cloudfront import ProxyCloudfrontAutomation from app.terraform.proxy.cloudfront import ProxyCloudfrontAutomation
if TYPE_CHECKING:
_SubparserType = argparse._SubParsersAction[argparse.ArgumentParser]
else:
_SubparserType = Any
jobs = { jobs = {
x.short_name: x x.short_name: x
for x in [ for x in [
@ -52,17 +58,18 @@ jobs = {
} }
def run_all(**kwargs): def run_all(**kwargs: bool) -> None:
for job in jobs.values(): for job in jobs.values():
run_job(job, **kwargs) run_job(job, **kwargs)
def run_job(job: BaseAutomation, *, force: bool = False, ignore_schedule: bool = False): def run_job(job_cls: Type[BaseAutomation], *,
automation = Automation.query.filter(Automation.short_name == job.short_name).first() force: bool = False, ignore_schedule: bool = False) -> None:
automation = Automation.query.filter(Automation.short_name == job_cls.short_name).first()
if automation is None: if automation is None:
automation = Automation() automation = Automation()
automation.short_name = job.short_name automation.short_name = job_cls.short_name
automation.description = job.description automation.description = job_cls.description
automation.enabled = False automation.enabled = False
automation.next_is_full = False automation.next_is_full = False
automation.added = datetime.datetime.utcnow() automation.added = datetime.datetime.utcnow()
@ -78,11 +85,11 @@ def run_job(job: BaseAutomation, *, force: bool = False, ignore_schedule: bool =
logging.warning("Not time to run this job yet") logging.warning("Not time to run this job yet")
return return
if not automation.enabled and not force: if not automation.enabled and not force:
logging.warning(f"job {job.short_name} is disabled and --force not specified") logging.warning(f"job {job_cls.short_name} is disabled and --force not specified")
return return
automation.state = AutomationState.RUNNING automation.state = AutomationState.RUNNING
db.session.commit() db.session.commit()
job = job() job: BaseAutomation = job_cls()
try: try:
success, logs = job.automate() success, logs = job.automate()
except Exception as e: except Exception as e:
@ -114,7 +121,7 @@ def run_job(job: BaseAutomation, *, force: bool = False, ignore_schedule: bool =
class AutomateCliHandler: class AutomateCliHandler:
@classmethod @classmethod
def add_subparser_to(cls, subparsers: argparse._SubParsersAction) -> None: def add_subparser_to(cls, subparsers: _SubparserType) -> None:
parser = subparsers.add_parser("automate", help="automation operations") parser = subparsers.add_parser("automate", help="automation operations")
parser.add_argument("-a", "--all", dest="all", help="run all automation jobs", action="store_true") parser.add_argument("-a", "--all", dest="all", help="run all automation jobs", action="store_true")
parser.add_argument("-j", "--job", dest="job", choices=sorted(jobs.keys()), parser.add_argument("-j", "--job", dest="job", choices=sorted(jobs.keys()),
@ -123,10 +130,10 @@ class AutomateCliHandler:
parser.add_argument("--ignore-schedule", help="run job even if it's not time yet", action="store_true") parser.add_argument("--ignore-schedule", help="run job even if it's not time yet", action="store_true")
parser.set_defaults(cls=cls) parser.set_defaults(cls=cls)
def __init__(self, args): def __init__(self, args: argparse.Namespace) -> None:
self.args = args self.args = args
def run(self): def run(self) -> None:
with app.app_context(): with app.app_context():
if self.args.job: if self.args.job:
run_job(jobs[self.args.job], force=self.args.force, ignore_schedule=self.args.ignore_schedule) run_job(jobs[self.args.job], force=self.args.force, ignore_schedule=self.args.ignore_schedule)

View file

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Union, List, Optional, Any
from app.extensions import db from app.extensions import db
@ -38,31 +39,44 @@ class AbstractResource(db.Model):
deprecation_reason = db.Column(db.String(), nullable=True) deprecation_reason = db.Column(db.String(), nullable=True)
destroyed = db.Column(db.DateTime(), nullable=True) destroyed = db.Column(db.DateTime(), nullable=True)
def __init__(self, **kwargs): def __init__(self, *,
super().__init__(**kwargs) id: Optional[int] = None,
added: Optional[datetime] = None,
updated: Optional[datetime] = None,
deprecated: Optional[datetime] = None,
deprecation_reason: Optional[str] = None,
destroyed: Optional[datetime] = None,
**kwargs: Any) -> None:
super().__init__(id=id,
added=added,
updated=updated,
deprecated=deprecated,
deprecation_reason=deprecation_reason,
destroyed=destroyed,
**kwargs)
if self.added is None: if self.added is None:
self.added = datetime.utcnow() self.added = datetime.utcnow()
if self.updated is None: if self.updated is None:
self.updated = datetime.utcnow() self.updated = datetime.utcnow()
def deprecate(self, *, reason: str): def deprecate(self, *, reason: str) -> None:
self.deprecated = datetime.utcnow() self.deprecated = datetime.utcnow()
self.deprecation_reason = reason self.deprecation_reason = reason
self.updated = datetime.utcnow() self.updated = datetime.utcnow()
def destroy(self): def destroy(self) -> None:
if self.deprecated is None: if self.deprecated is None:
self.deprecated = datetime.utcnow() self.deprecated = datetime.utcnow()
self.destroyed = datetime.utcnow() self.destroyed = datetime.utcnow()
self.updated = datetime.utcnow() self.updated = datetime.utcnow()
@classmethod @classmethod
def csv_header(cls): def csv_header(cls) -> List[str]:
return [ return [
"id", "added", "updated", "deprecated", "deprecation_reason", "destroyed" "id", "added", "updated", "deprecated", "deprecation_reason", "destroyed"
] ]
def csv_row(self): def csv_row(self) -> List[Union[datetime, bool, int, str]]:
return [ return [
getattr(self, x) for x in self.csv_header() getattr(self, x) for x in self.csv_header()
] ]

View file

@ -1,4 +1,5 @@
import datetime import datetime
from typing import Any, Optional
import requests import requests
@ -13,14 +14,25 @@ class Activity(db.Model):
text = db.Column(db.Text(), nullable=False) text = db.Column(db.Text(), nullable=False)
added = db.Column(db.DateTime(), nullable=False) added = db.Column(db.DateTime(), nullable=False)
def __init__(self, **kwargs): def __init__(self, *,
if type(kwargs["activity_type"]) != str or len(kwargs["activity_type"]) > 20 or kwargs["activity_type"] == "": id: Optional[int] = None,
group_id: Optional[int] = None,
activity_type: str,
text: str,
added: Optional[datetime.datetime] = None,
**kwargs: Any) -> None:
if type(activity_type) != str or len(activity_type) > 20 or activity_type == "":
raise TypeError("expected string for activity type between 1 and 20 characters") raise TypeError("expected string for activity type between 1 and 20 characters")
if type(kwargs["text"]) != str: if type(text) != str:
raise TypeError("expected string for text") raise TypeError("expected string for text")
if "added" not in kwargs: super().__init__(id=id,
kwargs["added"] = datetime.datetime.utcnow() group_id=group_id,
super().__init__(**kwargs) activity_type=activity_type,
text=text,
added=added,
**kwargs)
if self.added is None:
self.added = datetime.datetime.utcnow()
def notify(self) -> int: def notify(self) -> int:
count = 0 count = 0