Format, lint, type

This commit is contained in:
Abel Luck 2022-12-01 14:20:37 +00:00
parent a1ae717c8f
commit c925079e8b
8 changed files with 159 additions and 91 deletions

View file

@ -1,18 +1,14 @@
import re
import attr
import logging
from typing import Any, Tuple
from jinja2 import TemplateNotFound
import re
from typing import Any, List, Tuple
from mautrix.types import (EventType, RoomID, StateEvent, Membership, MessageType, JSON,
TextMessageEventContent, Format, ReactionEventContent, RelationType)
import attr
from jinja2 import TemplateNotFound
from mautrix.types import Format, MessageType, TextMessageEventContent
from mautrix.util.formatter import parse_html
from ..util.template import TemplateManager, TemplateUtil
from .types import EventParse, OTHER_ENUMS, Action
from ..common import COLOR_ALARM, COLOR_OK, COLOR_UNKNOWN
from .types import OTHER_ENUMS, Action, EventParse # type: ignore
spaces = re.compile(" +")
space = " "
@ -21,22 +17,23 @@ space = " "
messages = TemplateManager("gitlab", "messages")
templates = TemplateManager("gitlab", "mixins")
async def parse_event(x_gitlab_event: str, payload: Any) -> Tuple[str, str]:
async def parse_event(x_gitlab_event: str, payload: Any) -> List[Tuple[str, str]]:
evt = EventParse[x_gitlab_event].deserialize(payload)
print("processing", evt)
try:
tpl = messages[evt.template_name]
except TemplateNotFound as e:
except TemplateNotFound:
msg = f"Received unhandled gitlab event type {x_gitlab_event}"
logging.info(msg)
logging.info(payload)
return [(msg, msg)]
logging.error(msg)
logging.debug(payload)
return []
aborted = False
def abort() -> None:
nonlocal aborted
aborted = True
base_args = {
**{field.key: field for field in Action if field.key.isupper()},
**OTHER_ENUMS,
@ -45,14 +42,13 @@ async def parse_event(x_gitlab_event: str, payload: Any) -> Tuple[str, str]:
msgs = []
for subevt in evt.preprocess():
print("preprocessing", subevt)
args = {
**attr.asdict(subevt, recurse=False),
**{key: getattr(subevt, key) for key in subevt.event_properties},
"abort": abort,
**base_args,
**base_args, # type: ignore
}
args["templates"] = templates.proxy(args)
args["templates"] = templates.proxy(args) # type: ignore
html = tpl.render(**args)
if not html or aborted:
@ -60,8 +56,12 @@ async def parse_event(x_gitlab_event: str, payload: Any) -> Tuple[str, str]:
continue
html = spaces.sub(space, html.strip())
content = TextMessageEventContent(msgtype=MessageType.TEXT, format=Format.HTML,
formatted_body=html, body=await parse_html(html))
content = TextMessageEventContent(
msgtype=MessageType.TEXT,
format=Format.HTML,
formatted_body=html,
body=await parse_html(html),
)
content["xyz.maubot.gitlab.webhook"] = {
"event_type": x_gitlab_event,
**subevt.meta,

View file

@ -1,3 +1,4 @@
# type: ignore
# gitlab - A GitLab client and webhook receiver for maubot
# Copyright (C) 2019 Lorenz Steinert
# Copyright (C) 2021 Tulir Asokan
@ -14,22 +15,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Union, Dict, Optional, Type, NewType, ClassVar, Tuple, Iterable
from datetime import datetime
from typing import ClassVar, Dict, Iterable, List, NewType, Optional, Tuple, Type, Union
from jinja2 import TemplateNotFound
from attr import dataclass
from yarl import URL
import attr
from mautrix.types import JSON, ExtensibleEnum, SerializableAttrs, serializer, deserializer
from attr import dataclass
from jinja2 import TemplateNotFound
from mautrix.types import (
JSON,
ExtensibleEnum,
SerializableAttrs,
deserializer,
serializer,
)
from yarl import URL
from ..util.contrast import contrast, hex_to_rgb
@serializer(datetime)
def datetime_serializer(dt: datetime) -> JSON:
return dt.strftime('%Y-%m-%dT%H:%M:%S%z')
return dt.strftime("%Y-%m-%dT%H:%M:%S%z")
@deserializer(datetime)
@ -93,9 +99,12 @@ class GitlabLabel(SerializableAttrs):
@property
def foreground_color(self) -> str:
return (self.white_hex
if contrast(hex_to_rgb(self.color), self.white_rgb) >= self.contrast_threshold
else self.black_hex)
return (
self.white_hex
if contrast(hex_to_rgb(self.color), self.white_rgb)
>= self.contrast_threshold
else self.black_hex
)
@dataclass
@ -133,7 +142,7 @@ class GitlabUser(SerializableAttrs):
def __hash__(self) -> int:
return self.id
def __eq__(self, other: 'GitlabUser') -> bool:
def __eq__(self, other: "GitlabUser") -> bool:
if not isinstance(other, GitlabUser):
return False
return self.id == other.id
@ -221,7 +230,7 @@ class GitlabSource(SerializableAttrs):
http_url: Optional[str] = None
GitlabTarget = NewType('GitlabTarget', GitlabSource)
GitlabTarget = NewType("GitlabTarget", GitlabSource)
class GitlabChangeWrapper:
@ -600,7 +609,7 @@ class GitlabBuild(SerializableAttrs):
@dataclass
class GitlabEvent:
def preprocess(self) -> List['GitlabEvent']:
def preprocess(self) -> List["GitlabEvent"]:
return [self]
@property
@ -641,9 +650,14 @@ class GitlabPushEvent(SerializableAttrs, GitlabEvent):
@property
def user(self) -> GitlabUser:
return GitlabUser(id=self.user_id, name=self.user_name, email=self.user_email,
username=self.user_username, avatar_url=self.user_avatar,
web_url=f"{self.project.gitlab_base_url}/{self.user_username}")
return GitlabUser(
id=self.user_id,
name=self.user_name,
email=self.user_email,
username=self.user_username,
avatar_url=self.user_avatar,
web_url=f"{self.project.gitlab_base_url}/{self.user_username}",
)
@property
def template_name(self) -> str:
@ -651,8 +665,15 @@ class GitlabPushEvent(SerializableAttrs, GitlabEvent):
@property
def event_properties(self) -> Iterable[str]:
return ("user", "is_new_ref", "is_deleted_ref", "ref_name", "ref_type", "ref_url",
"diff_url")
return (
"user",
"is_new_ref",
"is_deleted_ref",
"ref_name",
"ref_type",
"ref_url",
"diff_url",
)
@property
def diff_url(self) -> str:
@ -695,7 +716,9 @@ class GitlabPushEvent(SerializableAttrs, GitlabEvent):
return f"push-{self.project_id}-{self.checkout_sha}-{self.ref_name}"
def split_updates(evt: Union['GitlabIssueEvent', 'GitlabMergeRequestEvent']) -> List[GitlabEvent]:
def split_updates(
evt: Union["GitlabIssueEvent", "GitlabMergeRequestEvent"]
) -> List[GitlabEvent]:
if not evt.changes:
return [evt]
output = []
@ -704,7 +727,9 @@ def split_updates(evt: Union['GitlabIssueEvent', 'GitlabMergeRequestEvent']) ->
for field in attr.fields(GitlabChanges):
value = getattr(evt.changes, field.name)
if value:
output.append(attr.evolve(evt, changes=GitlabChanges(**{field.name: value})))
output.append(
attr.evolve(evt, changes=GitlabChanges(**{field.name: value}))
)
return output
@ -719,7 +744,7 @@ class GitlabIssueEvent(SerializableAttrs, GitlabEvent):
labels: Optional[List[GitlabLabel]] = None
changes: Optional[GitlabChanges] = None
def preprocess(self) -> List['GitlabIssueEvent']:
def preprocess(self) -> List["GitlabIssueEvent"]:
users_to_mutate = [self.user]
if self.changes and self.changes.assignees:
users_to_mutate += self.changes.assignees.previous
@ -737,7 +762,7 @@ class GitlabIssueEvent(SerializableAttrs, GitlabEvent):
@property
def event_properties(self) -> Iterable[str]:
return "action",
return ("action",)
@property
def action(self) -> Action:
@ -757,7 +782,7 @@ class GitlabCommentEvent(SerializableAttrs, GitlabEvent):
issue: Optional[GitlabIssue] = None
snippet: Optional[GitlabSnippet] = None
def preprocess(self) -> List['GitlabCommentEvent']:
def preprocess(self) -> List["GitlabCommentEvent"]:
self.user.web_url = f"{self.project.gitlab_base_url}/{self.user.username}"
return [self]
@ -776,7 +801,7 @@ class GitlabMergeRequestEvent(SerializableAttrs, GitlabEvent):
labels: List[GitlabLabel]
changes: GitlabChanges
def preprocess(self) -> List['GitlabMergeRequestEvent']:
def preprocess(self) -> List["GitlabMergeRequestEvent"]:
users_to_mutate = [self.user]
if self.changes and self.changes.assignees:
users_to_mutate += self.changes.assignees.previous
@ -792,7 +817,7 @@ class GitlabMergeRequestEvent(SerializableAttrs, GitlabEvent):
@property
def event_properties(self) -> Iterable[str]:
return "action",
return ("action",)
@property
def action(self) -> Action:
@ -807,7 +832,7 @@ class GitlabWikiPageEvent(SerializableAttrs, GitlabEvent):
wiki: GitlabWiki
object_attributes: GitlabWikiPageAttributes
def preprocess(self) -> List['GitlabWikiPageEvent']:
def preprocess(self) -> List["GitlabWikiPageEvent"]:
self.user.web_url = f"{self.project.gitlab_base_url}/{self.user.username}"
return [self]
@ -862,7 +887,7 @@ class GitlabJobEvent(SerializableAttrs, GitlabEvent):
repository: GitlabRepository
runner: Optional[GitlabRunner]
def preprocess(self) -> List['GitlabJobEvent']:
def preprocess(self) -> List["GitlabJobEvent"]:
base_url = str(URL(self.repository.homepage).with_path(""))
self.user.web_url = f"{base_url}/{self.user.username}"
return [self]
@ -894,20 +919,22 @@ class GitlabJobEvent(SerializableAttrs, GitlabEvent):
@property
def event_properties(self) -> Iterable[str]:
return "build_url",
return ("build_url",)
@property
def build_url(self) -> str:
return f"{self.repository.homepage}/-/jobs/{self.build_id}"
GitlabEventType = Union[Type[GitlabPushEvent],
GitlabEventType = Union[
Type[GitlabPushEvent],
Type[GitlabIssueEvent],
Type[GitlabCommentEvent],
Type[GitlabMergeRequestEvent],
Type[GitlabWikiPageEvent],
Type[GitlabPipelineEvent],
Type[GitlabJobEvent]]
Type[GitlabJobEvent],
]
EventParse: Dict[str, GitlabEventType] = {
"Push Hook": GitlabPushEvent,
@ -919,7 +946,7 @@ EventParse: Dict[str, GitlabEventType] = {
"Merge Request Hook": GitlabMergeRequestEvent,
"Wiki Page Hook": GitlabWikiPageEvent,
"Pipeline Hook": GitlabPipelineEvent,
"Job Hook": GitlabJobEvent
"Job Hook": GitlabJobEvent,
}
OTHER_ENUMS = {

View file

@ -3,19 +3,20 @@ import json
import logging
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, cast
from dotenv import load_dotenv
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, status, Header
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Header, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import pydantic
from pydantic import BaseSettings
from ops_bot import aws, pagerduty
from ops_bot.matrix import MatrixClient, MatrixClientSettings
from ops_bot.gitlab import hook as gitlab_hook
from ops_bot.matrix import MatrixClient, MatrixClientSettings
load_dotenv()
class BotSettings(BaseSettings):
bearer_token: str
routing_keys: Dict[str, str]
@ -128,18 +129,18 @@ async def aws_sns_hook(
)
return {"message": msg_plain, "message_formatted": msg_formatted}
@app.post("/hook/gitlab/{routing_key}")
async def gitlab_webhook(
request: Request,
x_gitlab_token: str = Header(default=""),
x_gitlab_event: str = Header(default=""),
matrix_client: MatrixClient = Depends(get_matrix_service)
matrix_client: MatrixClient = Depends(get_matrix_service),
) -> Dict[str, str]:
bearer_token = request.app.state.bot_settings.bearer_token
if x_gitlab_token != bearer_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect X-Gitlab-Token"
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect X-Gitlab-Token"
)
room_id, payload = await receive_helper(request)
messages = await gitlab_hook.parse_event(x_gitlab_event, payload)

View file

@ -14,8 +14,8 @@ def hex_to_rgb(color: str) -> RGB:
step = 1 if len(color) == 3 else 2
try:
r = int(color[0:step], 16)
g = int(color[step:2 * step], 16)
b = int(color[2 * step:3 * step], 16)
g = int(color[step : 2 * step], 16)
b = int(color[2 * step : 3 * step], 16)
except ValueError as e:
raise ValueError("Invalid hex value") from e
return r / 255, g / 255, b / 255
@ -59,4 +59,4 @@ def _linearize(v: float) -> float:
if v <= 0.03928:
return v / 12.92
else:
return ((v + 0.055) / 1.055) ** 2.4
return float(((v + 0.055) / 1.055) ** 2.4)

View file

@ -1,8 +1,10 @@
# Copyright (c) 2022 Tulir Asokan
# # Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any
import commonmark
@ -11,12 +13,12 @@ class HtmlEscapingRenderer(commonmark.HtmlRenderer):
super().__init__()
self.allow_html = allow_html
def lit(self, s):
def lit(self, s: str) -> None:
if self.allow_html:
return super().lit(s)
return super().lit(s.replace("<", "&lt;").replace(">", "&gt;"))
def image(self, node, entering):
def image(self, node: Any, entering: Any) -> None:
prev = self.allow_html
self.allow_html = True
super().image(node, entering)
@ -29,8 +31,8 @@ no_html_renderer = HtmlEscapingRenderer()
def render(message: str, allow_html: bool = False) -> str:
parsed = md_parser.parse(message)
parsed = md_parser.parse(message) # type: ignore
if allow_html:
return yes_html_renderer.render(parsed)
return yes_html_renderer.render(parsed) # type: ignore
else:
return no_html_renderer.render(parsed)
return no_html_renderer.render(parsed) # type: ignore

View file

@ -13,13 +13,16 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any, Tuple, Callable, Iterable, List, Union
import os.path
from typing import Any, Callable, Dict, List, Tuple, Union
from jinja2 import Environment as JinjaEnvironment, Template, BaseLoader, TemplateNotFound, FileSystemLoader
from jinja2 import BaseLoader
from jinja2 import Environment as JinjaEnvironment
from jinja2 import Template, TemplateNotFound
from ops_bot.util import markdown
def sync_read_file(path: str) -> str:
with open(path) as file:
return file.read()
@ -28,6 +31,7 @@ def sync_read_file(path: str) -> str:
def sync_list_files(directory: str) -> list[str]:
return os.listdir(directory)
class TemplateUtil:
@staticmethod
def bold_scope(label: str) -> str:
@ -61,20 +65,29 @@ class TemplateUtil:
if minutes > 0:
parts.append(cls.pluralize(minutes, "minute"))
if seconds > 0 or len(parts) == 0:
parts.append(cls.pluralize(seconds + frac_seconds, "second"))
parts.append(cls.pluralize(int(seconds + frac_seconds), "second"))
if len(parts) == 1:
return parts[0]
return ", ".join(parts[:-1]) + f" and {parts[-1]}"
@staticmethod
def join_human_list(data: List[str], *, joiner: str = ", ", final_joiner: str = " and ",
mutate: Callable[[str], str] = lambda val: val) -> str:
def join_human_list(
data: List[str],
*,
joiner: str = ", ",
final_joiner: str = " and ",
mutate: Callable[[str], str] = lambda val: val,
) -> str:
if not data:
return ""
elif len(data) == 1:
return mutate(data[0])
return joiner.join(mutate(val) for val in data[:-1]) + final_joiner + mutate(data[-1])
return (
joiner.join(mutate(val) for val in data[:-1])
+ final_joiner
+ mutate(data[-1])
)
class TemplateProxy:
@ -101,7 +114,9 @@ class PluginTemplateLoader(BaseLoader):
self.directory = os.path.join("templates", base, directory)
self.macros = sync_read_file(os.path.join("templates", base, "macros.html"))
def get_source(self, environment: Any, name: str) -> Tuple[str, str, Callable[[], bool]]:
def get_source(
self, environment: Any, name: str
) -> Tuple[str, str, Callable[[], bool]]:
path = f"{os.path.join(self.directory, name)}.html"
try:
tpl = sync_read_file(path)
@ -109,21 +124,31 @@ class PluginTemplateLoader(BaseLoader):
raise TemplateNotFound(name)
return self.macros + tpl, name, lambda: True
def list_templates(self) -> Iterable[str]:
return [os.path.splitext(os.path.basename(path))[0]
def list_templates(self) -> List[str]:
return [
os.path.splitext(os.path.basename(path))[0]
for path in sync_list_files(self.directory)
if path.endswith(".html")]
if path.endswith(".html")
]
class TemplateManager:
_env: JinjaEnvironment
_loader: PluginTemplateLoader
def __init__(self, base: str, directory: str) -> None:
#self._loader = FileSystemLoader(os.path.join("templates/", base))
# self._loader = FileSystemLoader(os.path.join("templates/", base))
self._loader = PluginTemplateLoader(base, directory)
self._env = JinjaEnvironment(loader=self._loader, lstrip_blocks=True, trim_blocks=True,
extensions=["jinja2.ext.do"])
self._env.filters["markdown"] = lambda message: markdown.render(message, allow_html=True)
self._env = JinjaEnvironment( # nosec B701
loader=self._loader,
lstrip_blocks=True,
trim_blocks=True,
autoescape=False,
extensions=["jinja2.ext.do"],
)
self._env.filters["markdown"] = lambda message: markdown.render(
message, allow_html=True
)
def __getitem__(self, item: str) -> Template:
return self._env.get_template(item)

14
poetry.lock generated
View file

@ -811,6 +811,14 @@ category = "dev"
optional = false
python-versions = ">=3.7"
[[package]]
name = "types-commonmark"
version = "0.9.2"
description = "Typing stubs for commonmark"
category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "types-markdown"
version = "3.4.2.1"
@ -885,7 +893,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "50bb2a7ce02730b129e8bcee3ffad0e1cc7c028ebaff2f9e3d07643907db4f16"
content-hash = "3e5e0fa3501dbbd6f79e37380b75e0f7bf0f8f3668f0ef9e463891bcb62216e2"
[metadata.files]
aiofiles = [
@ -1643,6 +1651,10 @@ tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
types-commonmark = [
{file = "types-commonmark-0.9.2.tar.gz", hash = "sha256:b894b67750c52fd5abc9a40a9ceb9da4652a391d75c1b480bba9cef90f19fc86"},
{file = "types_commonmark-0.9.2-py3-none-any.whl", hash = "sha256:56f20199a1f9a2924443211a0ef97f8b15a8a956a7f4e9186be6950bf38d6d02"},
]
types-markdown = [
{file = "types-Markdown-3.4.2.1.tar.gz", hash = "sha256:03c0904cf5886a7d8193e2f50bcf842afc89e0ab80f060f389f6c2635c65628f"},
{file = "types_Markdown-3.4.2.1-py3-none-any.whl", hash = "sha256:b2333f6f4b8f69af83de359e10a097e4a3f14bbd6d2484e1829d9b0ec56fa0cb"},

View file

@ -27,6 +27,7 @@ flake8-black = "^0.3.5"
types-Markdown = "^3.4.0"
types-termcolor = "^1.1.5"
pytest-asyncio = "^0.20.2"
types-commonmark = "^0.9.2"
[build-system]
requires = ["poetry-core>=1.0.0"]