diff --git a/README.md b/README.md index 0c9a6ea..78d2f86 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,8 @@ Configuration values can be set using environment variables, or optionally loade - **interval** (`TAILSCALESD_INTERVAL`): The interval on which the Tailscale API is polled in seconds. Default is `60`. - **bearer_token** (`TAILSCALESD_BEARER_TOKEN`): The authentication token passed in the Authorization header (required). - **tailnet** (`TAILSCALESD_TAILNET`): The Tailscale tailnet identifier (required). -- **api_key** (`TAILSCALESD_API_KEY`): The Tailscale API key (required). +- **client_id** (`TAILSCALESD_CLIENT_ID`): The Tailscale oauth client id (required). +- **client_secret** (`TAILSCALESD_CLIENT_SECRET`): The Tailscale oauth client secret (required). #### Environment File @@ -83,7 +84,8 @@ You can also specify an environment file to load configuration values. The path ```env TAILSCALESD_TAILNET=my-tailnet -TAILSCALESD_API_KEY=my-api-key +TAILSCALESD_CLIENT_ID=xxxx +TAILSCALESD_CLIENT_SECRET=yyyyy TAILSCALESD_HOST=127.0.0.1 TAILSCALESD_BEARER_TOKEN=supersecret ``` @@ -101,6 +103,9 @@ This service provides the following Prometheus metrics: - **Description**: The number times a matrix sd host was unreachable. This counter increments each time a connection attempt to a matrix sd host fails. - **Labels**: - `device_hostname`: The hostname of the device that was unreachable. +- `tailscalesd_polling_up` + - **Type**: Gauge + - **Description**: Indicates if tailscalesd can access the tailscale devices API up (1) or down (0) It also provides HTTP server metrics from [trallnag/prometheus-fastapi-instrumentator](https://github.com/trallnag/prometheus-fastapi-instrumentator) diff --git a/tailscalesd/main.py b/tailscalesd/main.py index 62f4490..118adb1 100644 --- a/tailscalesd/main.py +++ b/tailscalesd/main.py @@ -3,6 +3,7 @@ import logging import os import sys from contextlib import asynccontextmanager +from datetime import datetime, timedelta from functools import lru_cache from ipaddress import ip_address from typing import Annotated, Dict, List @@ -12,7 +13,7 @@ import json_logging # type: ignore import uvicorn from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from prometheus_client import Counter +from prometheus_client import Counter, Gauge from prometheus_fastapi_instrumentator import Instrumentator from pydantic import Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict @@ -35,6 +36,10 @@ counter_matrix_sd_down = Counter( "The number times a matrix sd host was unreachable", ["device_hostname"], ) +gauge_tailscale_polling_up = Gauge( + "tailscalesd_polling_up", + "Indicates if tailscalesd can access the tailscale devices API up (1) or down (0)", +) def ipv4_only(addresses) -> List[str]: @@ -51,20 +56,66 @@ class Settings(BaseSettings): interval: int = 60 tailnet: str = Field() bearer_token: str = Field() - api_key: SecretStr = Field() test_mode: bool = False + client_id: SecretStr = Field() + client_secret: SecretStr = Field() CACHE_SD = [] -async def tailscale_devices(settings: Settings) -> List[Dict]: +class AccessToken: + access_token: str + + def __init__(self, api_resp): + self._token = api_resp["access_token"] + self.expires_in = api_resp["expires_in"] + self.scope = api_resp["scope"] + self.expiration_time = self._calculate_expiration_time(self.expires_in) + + def _calculate_expiration_time(self, expires_in): + return datetime.now() + timedelta(seconds=expires_in) + + def is_expiring_soon(self, buffer_seconds=300): + return datetime.now() >= self.expiration_time - timedelta( + seconds=buffer_seconds + ) + + @property + def token(self): + if self.is_expiring_soon(): + raise Exception("Access Token is expiring soon") + return self._token + + +async def get_access_token( + tailscale_client_id: str, tailscale_client_secret: str +) -> AccessToken: + ts_auth_url = "https://api.tailscale.com/api/v2/oauth/token" + data = { + "client_id": tailscale_client_id, + "client_secret": tailscale_client_secret, + "grant_type": "client_credentials", + } + + async with httpx.AsyncClient() as client: + try: + r = await client.post(ts_auth_url, data=data) + r.raise_for_status() + return AccessToken(r.json()) + except Exception as e: + raise Exception("Error requesting API key using OAUTH client.") from e + + +async def tailscale_devices( + settings: Settings, access_token: AccessToken +) -> List[Dict]: async with httpx.AsyncClient() as client: try: # https://github.com/tailscale/tailscale/blob/main/api.md#tailnet-devices-get r = await client.get( f"https://api.tailscale.com/api/v2/tailnet/{settings.tailnet}/devices", - auth=(settings.api_key.get_secret_value(), ""), + auth=(access_token.token, ""), ) return r.json()["devices"] except Exception as e: @@ -194,20 +245,30 @@ async def poll_sd(settings: Settings): global CACHE_SD if settings.test_mode: return + access_token = None while True: try: - devices = await tailscale_devices(settings) + if not access_token or access_token.is_expiring_soon(): + access_token = await get_access_token( + settings.client_id.get_secret_value(), + settings.client_secret.get_secret_value(), + ) + + devices = await tailscale_devices(settings, access_token) devices = filter_devices(devices) device_targets = plain_devices_sd(settings.tailnet, devices) matrix_targets = await matrix_sd(settings.tailnet, devices) CACHE_SD = matrix_targets + device_targets + gauge_tailscale_polling_up.set(1) await asyncio.sleep(settings.interval) except Exception as e: counter_unhandled_background_task_crashes.inc() + gauge_tailscale_polling_up.set(0) log.error( "Service Discovery poller failed", exc_info=e, ) + await asyncio.sleep(settings.interval) @lru_cache @@ -259,5 +320,13 @@ def main(): uvicorn.run(app, host=settings.host, port=settings.port) +# async def test(): +# tailscale_client_id = os.environ.get("TAILSCALESD_CLIENT_ID") +# tailscale_client_secret = os.environ.get("TAILSCALESD_CLIENT_SECRET") +# r = await get_access_token(tailscale_client_id, tailscale_client_secret) +# print(r) + + if __name__ == "__main__": + # asyncio.run(test()) main() diff --git a/tests/test_auth.py b/tests/test_auth.py index f318ef8..cc13d82 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -5,7 +5,13 @@ client = TestClient(app) def get_settings_override() -> Settings: - return Settings(test_mode=True, tailnet="test", api_key="test", bearer_token="test") + return Settings( + test_mode=True, + tailnet="test", + client_id="test", + client_secret="test", + bearer_token="test", + ) app.dependency_overrides[get_settings] = get_settings_override