Switch to using tailscale oauth api

This removes the need to update the API key every 90 days.
This commit is contained in:
Abel Luck 2024-07-17 09:48:30 +02:00
parent 2bf137847d
commit b195bd1e8f
3 changed files with 88 additions and 8 deletions

View file

@ -73,7 +73,8 @@ Configuration values can be set using environment variables, or optionally loade
- **interval** (`TAILSCALESD_INTERVAL`): The interval on which the Tailscale API is polled in seconds. Default is `60`. - **interval** (`TAILSCALESD_INTERVAL`): The interval on which the Tailscale API is polled in seconds. Default is `60`.
- **bearer_token** (`TAILSCALESD_BEARER_TOKEN`): The authentication token passed in the Authorization header (required). - **bearer_token** (`TAILSCALESD_BEARER_TOKEN`): The authentication token passed in the Authorization header (required).
- **tailnet** (`TAILSCALESD_TAILNET`): The Tailscale tailnet identifier (required). - **tailnet** (`TAILSCALESD_TAILNET`): The Tailscale tailnet identifier (required).
- **api_key** (`TAILSCALESD_API_KEY`): The Tailscale API key (required). - **client_id** (`TAILSCALESD_CLIENT_ID`): The Tailscale oauth client id (required).
- **client_secret** (`TAILSCALESD_CLIENT_SECRET`): The Tailscale oauth client secret (required).
#### Environment File #### Environment File
@ -83,7 +84,8 @@ You can also specify an environment file to load configuration values. The path
```env ```env
TAILSCALESD_TAILNET=my-tailnet TAILSCALESD_TAILNET=my-tailnet
TAILSCALESD_API_KEY=my-api-key TAILSCALESD_CLIENT_ID=xxxx
TAILSCALESD_CLIENT_SECRET=yyyyy
TAILSCALESD_HOST=127.0.0.1 TAILSCALESD_HOST=127.0.0.1
TAILSCALESD_BEARER_TOKEN=supersecret TAILSCALESD_BEARER_TOKEN=supersecret
``` ```
@ -101,6 +103,9 @@ This service provides the following Prometheus metrics:
- **Description**: The number times a matrix sd host was unreachable. This counter increments each time a connection attempt to a matrix sd host fails. - **Description**: The number times a matrix sd host was unreachable. This counter increments each time a connection attempt to a matrix sd host fails.
- **Labels**: - **Labels**:
- `device_hostname`: The hostname of the device that was unreachable. - `device_hostname`: The hostname of the device that was unreachable.
- `tailscalesd_polling_up`
- **Type**: Gauge
- **Description**: Indicates if tailscalesd can access the tailscale devices API up (1) or down (0)
It also provides HTTP server metrics from [trallnag/prometheus-fastapi-instrumentator](https://github.com/trallnag/prometheus-fastapi-instrumentator) It also provides HTTP server metrics from [trallnag/prometheus-fastapi-instrumentator](https://github.com/trallnag/prometheus-fastapi-instrumentator)

View file

@ -3,6 +3,7 @@ import logging
import os import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from functools import lru_cache from functools import lru_cache
from ipaddress import ip_address from ipaddress import ip_address
from typing import Annotated, Dict, List from typing import Annotated, Dict, List
@ -12,7 +13,7 @@ import json_logging # type: ignore
import uvicorn import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from prometheus_client import Counter from prometheus_client import Counter, Gauge
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
from pydantic import Field, SecretStr from pydantic import Field, SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@ -35,6 +36,10 @@ counter_matrix_sd_down = Counter(
"The number times a matrix sd host was unreachable", "The number times a matrix sd host was unreachable",
["device_hostname"], ["device_hostname"],
) )
gauge_tailscale_polling_up = Gauge(
"tailscalesd_polling_up",
"Indicates if tailscalesd can access the tailscale devices API up (1) or down (0)",
)
def ipv4_only(addresses) -> List[str]: def ipv4_only(addresses) -> List[str]:
@ -51,20 +56,66 @@ class Settings(BaseSettings):
interval: int = 60 interval: int = 60
tailnet: str = Field() tailnet: str = Field()
bearer_token: str = Field() bearer_token: str = Field()
api_key: SecretStr = Field()
test_mode: bool = False test_mode: bool = False
client_id: SecretStr = Field()
client_secret: SecretStr = Field()
CACHE_SD = [] CACHE_SD = []
async def tailscale_devices(settings: Settings) -> List[Dict]: class AccessToken:
access_token: str
def __init__(self, api_resp):
self._token = api_resp["access_token"]
self.expires_in = api_resp["expires_in"]
self.scope = api_resp["scope"]
self.expiration_time = self._calculate_expiration_time(self.expires_in)
def _calculate_expiration_time(self, expires_in):
return datetime.now() + timedelta(seconds=expires_in)
def is_expiring_soon(self, buffer_seconds=300):
return datetime.now() >= self.expiration_time - timedelta(
seconds=buffer_seconds
)
@property
def token(self):
if self.is_expiring_soon():
raise Exception("Access Token is expiring soon")
return self._token
async def get_access_token(
tailscale_client_id: str, tailscale_client_secret: str
) -> AccessToken:
ts_auth_url = "https://api.tailscale.com/api/v2/oauth/token"
data = {
"client_id": tailscale_client_id,
"client_secret": tailscale_client_secret,
"grant_type": "client_credentials",
}
async with httpx.AsyncClient() as client:
try:
r = await client.post(ts_auth_url, data=data)
r.raise_for_status()
return AccessToken(r.json())
except Exception as e:
raise Exception("Error requesting API key using OAUTH client.") from e
async def tailscale_devices(
settings: Settings, access_token: AccessToken
) -> List[Dict]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
# https://github.com/tailscale/tailscale/blob/main/api.md#tailnet-devices-get # https://github.com/tailscale/tailscale/blob/main/api.md#tailnet-devices-get
r = await client.get( r = await client.get(
f"https://api.tailscale.com/api/v2/tailnet/{settings.tailnet}/devices", f"https://api.tailscale.com/api/v2/tailnet/{settings.tailnet}/devices",
auth=(settings.api_key.get_secret_value(), ""), auth=(access_token.token, ""),
) )
return r.json()["devices"] return r.json()["devices"]
except Exception as e: except Exception as e:
@ -194,20 +245,30 @@ async def poll_sd(settings: Settings):
global CACHE_SD global CACHE_SD
if settings.test_mode: if settings.test_mode:
return return
access_token = None
while True: while True:
try: try:
devices = await tailscale_devices(settings) if not access_token or access_token.is_expiring_soon():
access_token = await get_access_token(
settings.client_id.get_secret_value(),
settings.client_secret.get_secret_value(),
)
devices = await tailscale_devices(settings, access_token)
devices = filter_devices(devices) devices = filter_devices(devices)
device_targets = plain_devices_sd(settings.tailnet, devices) device_targets = plain_devices_sd(settings.tailnet, devices)
matrix_targets = await matrix_sd(settings.tailnet, devices) matrix_targets = await matrix_sd(settings.tailnet, devices)
CACHE_SD = matrix_targets + device_targets CACHE_SD = matrix_targets + device_targets
gauge_tailscale_polling_up.set(1)
await asyncio.sleep(settings.interval) await asyncio.sleep(settings.interval)
except Exception as e: except Exception as e:
counter_unhandled_background_task_crashes.inc() counter_unhandled_background_task_crashes.inc()
gauge_tailscale_polling_up.set(0)
log.error( log.error(
"Service Discovery poller failed", "Service Discovery poller failed",
exc_info=e, exc_info=e,
) )
await asyncio.sleep(settings.interval)
@lru_cache @lru_cache
@ -259,5 +320,13 @@ def main():
uvicorn.run(app, host=settings.host, port=settings.port) uvicorn.run(app, host=settings.host, port=settings.port)
# async def test():
# tailscale_client_id = os.environ.get("TAILSCALESD_CLIENT_ID")
# tailscale_client_secret = os.environ.get("TAILSCALESD_CLIENT_SECRET")
# r = await get_access_token(tailscale_client_id, tailscale_client_secret)
# print(r)
if __name__ == "__main__": if __name__ == "__main__":
# asyncio.run(test())
main() main()

View file

@ -5,7 +5,13 @@ client = TestClient(app)
def get_settings_override() -> Settings: def get_settings_override() -> Settings:
return Settings(test_mode=True, tailnet="test", api_key="test", bearer_token="test") return Settings(
test_mode=True,
tailnet="test",
client_id="test",
client_secret="test",
bearer_token="test",
)
app.dependency_overrides[get_settings] = get_settings_override app.dependency_overrides[get_settings] = get_settings_override