matrix-ops-bot/ops_bot/main.py

237 lines
7.3 KiB
Python

import asyncio
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, List, Optional, Protocol, Tuple, cast
import json_logging
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import (
HTTPAuthorizationCredentials,
HTTPBasic,
HTTPBasicCredentials,
HTTPBearer,
)
from prometheus_fastapi_instrumentator import Instrumentator
from ops_bot import alertmanager, aws, pagerduty
from ops_bot.config import BotSettings, RoutingKey, load_config
from ops_bot.gitlab import hook as gitlab_hook
from ops_bot.matrix import MatrixClient
from ops_bot.metrics import (
CONFIG_LOADED_TOTAL,
EVENT_TO_SEND_SECONDS,
MESSAGES_SENT_TOTAL,
MESSAGE_SEND_FAILURES_TOTAL,
WEBHOOK_EVENTS_TOTAL,
classify_payload_error,
classify_send_failure,
source_label,
)
async def get_matrix_service(request: Request) -> MatrixClient:
    """FastAPI dependency returning the Matrix client stored on the app state.

    The client is read from ``request.app.state.matrix_client``; ``cast`` only
    narrows the static type and performs no runtime check.
    """
    client = request.app.state.matrix_client
    return cast(MatrixClient, client)
async def matrix_main(matrix_client: MatrixClient) -> None:
    """Run the Matrix client's ``start()`` coroutine as a task until it ends.

    Any exception raised by ``start()`` propagates out of the ``gather`` call.
    """
    start_task = asyncio.create_task(matrix_client.start())
    await asyncio.gather(start_task)
def _log_matrix_task_result(task: "asyncio.Task[None]") -> None:
    """Log an unexpected failure of the background Matrix client task."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logging.getLogger("ops_bot").error(
            "matrix client task exited with an error", exc_info=exc
        )


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Load the bot configuration and run the Matrix client for the app's lifetime.

    On startup the config file named by ``BOT_CONFIG_FILE`` (default
    ``config.json``) is loaded, a ``MatrixClient`` is built from it, both are
    stored on ``app.state``, and the client is started as a background task.
    On shutdown the client's ``shutdown()`` is awaited.

    Raises:
        Exception: whatever ``load_config`` or the ``MatrixClient`` constructor
            raises; the failure is counted in ``CONFIG_LOADED_TOTAL`` first.
    """
    config_fname = os.environ.get("BOT_CONFIG_FILE", "config.json")
    try:
        bot_settings = load_config(config_fname)
        c = MatrixClient(
            settings=bot_settings.matrix, join_rooms=bot_settings.get_rooms()
        )
    except Exception:
        CONFIG_LOADED_TOTAL.labels(result="failure").inc()
        raise
    CONFIG_LOADED_TOTAL.labels(result="success").inc()

    app.state.matrix_client = c
    app.state.bot_settings = bot_settings

    # The event loop keeps only a weak reference to tasks, so hold a strong
    # one for the app's lifetime; otherwise the task may be garbage-collected
    # while it is still running.
    matrix_task = asyncio.create_task(matrix_main(c))
    # Surface a crash of the client immediately instead of leaving the
    # exception unretrieved on the (now strongly referenced) task.
    matrix_task.add_done_callback(_log_matrix_task_result)
    app.state.matrix_task = matrix_task

    try:
        yield
    finally:
        # Always release the Matrix client, even if the lifespan context is
        # exited because of an error.
        await app.state.matrix_client.shutdown()
# NOTE(review): leftover commented-out call; metrics are served by the
# instrumentator's /metrics endpoint configured below.
# start_http_server(9000)

# The ASGI application; `lifespan` performs startup and shutdown work.
app = FastAPI(lifespan=lifespan)

# Instrument the app and expose the metrics endpoint, hidden from the schema.
instrumentator = Instrumentator().instrument(app)
instrumentator.expose(app, endpoint="/metrics", include_in_schema=False)

# auto_error=False: the handlers take these credentials as Optional values and
# the per-hook authorizer decides whether a request is accepted.
bearer_security = HTTPBearer(auto_error=False)
basic_security = HTTPBasic(auto_error=False)

# Package logger, with an additional handler writing to stdout.
log = logging.getLogger("ops_bot")
log.addHandler(logging.StreamHandler(sys.stdout))

# Enable JSON-formatted logging and request instrumentation for the app.
json_logging.init_fastapi(enable_json=True)
json_logging.init_request_instrument(app)
@app.get("/")
async def root() -> Dict[str, str]:
    """Return a fixed greeting payload for ``GET /``."""
    greeting: Dict[str, str] = {"message": "Hello World"}
    return greeting
async def bearer_token_authorizer(
    route: RoutingKey,
    request: Request,
    basic_credentials: Optional[HTTPBasicCredentials],
    bearer_credentials: Optional[HTTPAuthorizationCredentials],
) -> bool:
    """Authorize a request by comparing its bearer token to the route's secret.

    Args:
        route: Routing entry whose ``secret_token`` is the expected token.
        request: Unused; present to satisfy the ``Authorizer`` signature.
        basic_credentials: Unused; present to satisfy the ``Authorizer``
            signature.
        bearer_credentials: Bearer credentials from the request, or ``None``
            when none were supplied.

    Returns:
        ``True`` only when a bearer token was supplied and it equals the
        route's configured secret. A route with no (or an empty) secret
        never authorizes.
    """
    import hmac

    expected: Optional[str] = route.secret_token
    if bearer_credentials is None or not expected:
        # Fail closed: no token presented, or no secret configured.
        return False

    # Constant-time comparison so response timing does not leak how much of
    # the secret matched. Compare bytes: compare_digest rejects non-ASCII str.
    return hmac.compare_digest(
        bearer_credentials.credentials.encode("utf-8"),
        expected.encode("utf-8"),
    )
async def nop_authorizer(
    route: RoutingKey,
    request: Request,
    basic_credentials: Optional[HTTPBasicCredentials],
    bearer_credentials: Optional[HTTPAuthorizationCredentials],
) -> bool:
    """Authorizer that accepts every request without inspecting its arguments.

    All parameters exist only to match the ``Authorizer`` call signature.
    """
    is_authorized = True
    return is_authorized
def get_route(bot_settings: BotSettings, path_key: str) -> Optional[RoutingKey]:
    """Return the first routing key whose ``path_key`` equals *path_key*.

    Entries of ``bot_settings.routing_keys`` are scanned in order; ``None`` is
    returned when no entry matches.
    """
    matching = (
        candidate
        for candidate in bot_settings.routing_keys
        if candidate.path_key == path_key
    )
    return next(matching, None)
class Authorizer(Protocol):
    """Structural type of the async callables that authorize a webhook request."""

    async def __call__(
        self,
        route: RoutingKey,
        request: Request,
        basic_credentials: Optional[HTTPBasicCredentials],
        bearer_credentials: Optional[HTTPAuthorizationCredentials],
    ) -> bool:
        """Return whether the request is allowed for *route*."""
class ParseHandler(Protocol):
    """Structural type of the async callables that turn a payload into messages."""

    async def __call__(
        self,
        route: RoutingKey,
        payload: Any,
        request: Request,
    ) -> List[Tuple[str, str]]:
        """Return the messages for *payload* as a list of string pairs."""
# Registry mapping a route's hook type to its (authorizer, payload parser) pair.
handlers: Dict[str, Tuple[Authorizer, ParseHandler]] = {}
handlers["gitlab"] = (gitlab_hook.authorize, gitlab_hook.parse_event)
handlers["pagerduty"] = (bearer_token_authorizer, pagerduty.parse_pagerduty_event)
# These two hook types accept every request (no authorization check).
handlers["aws-sns"] = (nop_authorizer, aws.parse_sns_event)
handlers["alertmanager"] = (nop_authorizer, alertmanager.parse_alertmanager_event)
@app.post("/hook/{routing_key}")
async def webhook_handler(
    request: Request,
    routing_key: str,
    basic_credentials: Optional[HTTPBasicCredentials] = Depends(basic_security),
    bearer_credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        bearer_security
    ),
    matrix_client: MatrixClient = Depends(get_matrix_service),
) -> Dict[str, str]:
    """Receive a webhook, authorize it, and relay its messages to a Matrix room.

    The routing key in the path selects a configured route; the route's hook
    type selects an (authorizer, parser) pair from ``handlers``. The parsed
    messages are sent one by one to the route's room.

    Args:
        request: Incoming request; its JSON body is the webhook payload.
        routing_key: Path component identifying the configured route.
        basic_credentials: HTTP basic credentials, if supplied.
        bearer_credentials: HTTP bearer credentials, if supplied.
        matrix_client: Client used to send messages.

    Returns:
        ``{"status": "ok"}`` once every message has been sent.

    Raises:
        HTTPException: 404 for an unknown routing key or hook type, 401 when
            the authorizer rejects the request, 400 when the body is not valid
            JSON, 502 when sending to Matrix fails.
        Exception: anything raised by the parse handler is re-raised after
            being counted in ``WEBHOOK_EVENTS_TOTAL``.
    """
    request_start = time.perf_counter()

    route = get_route(request.app.state.bot_settings, routing_key)
    if not route:
        # Use the module logger with lazy formatting rather than the root
        # logger with an eagerly built f-string.
        log.error("unknown routing key %s", routing_key)
        WEBHOOK_EVENTS_TOTAL.labels(source="unknown", result="unknown_route").inc()
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Unknown routing key"
        )

    source = source_label(route.hook_type)
    handler: Optional[Tuple[Authorizer, ParseHandler]] = handlers.get(route.hook_type)
    if not handler:
        WEBHOOK_EVENTS_TOTAL.labels(source=source, result="handler_error").inc()
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Unknown hook type"
        )

    authorizer, parse_handler = handler
    if not await authorizer(
        route,
        request=request,
        bearer_credentials=bearer_credentials,
        basic_credentials=basic_credentials,
    ):
        WEBHOOK_EVENTS_TOTAL.labels(source=source, result="auth_failed").inc()
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
        )

    try:
        payload: Any = await request.json()
    except Exception as exc:
        WEBHOOK_EVENTS_TOTAL.labels(source=source, result="invalid_payload").inc()
        # Chain the cause so the original decode error stays in the traceback.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid payload"
        ) from exc

    try:
        messages = await parse_handler(route, payload, request=request)
    except Exception as exc:
        WEBHOOK_EVENTS_TOTAL.labels(
            source=source, result=classify_payload_error(exc)
        ).inc()
        raise

    first_send_attempt_observed = False
    for msg_plain, msg_formatted in messages:
        if not first_send_attempt_observed:
            # Latency from request arrival to the first send attempt only.
            EVENT_TO_SEND_SECONDS.labels(source=source).observe(
                time.perf_counter() - request_start
            )
            first_send_attempt_observed = True
        try:
            await matrix_client.room_send(
                route.room_id,
                msg_plain,
                message_formatted=msg_formatted,
            )
        except Exception as exc:
            MESSAGE_SEND_FAILURES_TOTAL.labels(
                source=source, reason=classify_send_failure(exc)
            ).inc()
            # Record the underlying failure; the client only sees a 502.
            log.exception("failed to send message to Matrix room %s", route.room_id)
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Failed to send message to Matrix",
            ) from exc
        MESSAGES_SENT_TOTAL.labels(source=source).inc()

    WEBHOOK_EVENTS_TOTAL.labels(source=source, result="accepted").inc()
    return {"status": "ok"}
def start_dev() -> None:
    """Serve the app with uvicorn on 127.0.0.1:1112 with auto-reload enabled."""
    uvicorn.run(
        "ops_bot.main:app",
        host="127.0.0.1",
        port=1112,
        reload=True,
    )
def main() -> None:
    """Serve the app with uvicorn on the configured host and port.

    ``BOT_LISTEN_HOST`` defaults to ``127.0.0.1`` and ``BOT_LISTEN_PORT`` to
    ``1111``; a non-integer port raises ``ValueError``.
    """
    listen_host = os.getenv("BOT_LISTEN_HOST", "127.0.0.1")
    listen_port = int(os.getenv("BOT_LISTEN_PORT", "1111"))
    uvicorn.run(app, host=listen_host, port=listen_port)  # nosec B104
# Start the server only when this module is executed as a script.
if __name__ == "__main__":
    main()