tailscalesd/tailscalesd/main.py

402 lines
13 KiB
Python

import asyncio
import logging
import os
import secrets
import sys
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from functools import lru_cache
from ipaddress import ip_address
from pathlib import Path
from typing import Annotated, Dict, List, Optional

import httpx
import json_logging  # type: ignore
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from prometheus_client import Counter, Gauge
from prometheus_fastapi_instrumentator import Instrumentator
from pydantic import Field, SecretStr, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Devices carrying this tag are polled for matrix worker targets.
MATRIX_TAG = "tag:matrix"

# Optional dotenv file path, handed to Settings as its env_file.
env_path = os.getenv("TAILSCALESD_ENV_FILE")

# Any common truthy spelling of TAILSCALESD_DEBUG turns on debug logging.
debug = os.getenv("TAILSCALESD_DEBUG", "false").lower() in {"1", "true", "yes", "on"}

# Service logger: DEBUG level when debug is on, INFO otherwise; writes to stdout.
log = logging.getLogger("tailscalesd")
log.setLevel(logging.DEBUG if debug else logging.INFO)
log.addHandler(logging.StreamHandler(stream=sys.stdout))
# Incremented whenever a background failure is caught and logged.
counter_unhandled_background_task_crashes = Counter(
    name="tailscalesd_unhandled_background_task_crashes",
    documentation="The number of unhandled background task crashes",
)

# Per-host count of failed attempts to poll a matrix node.
counter_matrix_sd_down = Counter(
    name="tailscalesd_matrix_sd_down",
    documentation="The number times a matrix sd host was unreachable",
    labelnames=["device_hostname"],
)

# Set to 1 after a successful devices poll and to 0 after a failed one.
gauge_tailscale_polling_up = Gauge(
    name="tailscalesd_polling_up",
    documentation="Indicates if tailscalesd can access the tailscale devices API up (1) or down (0)",
)
def ipv4_only(addresses) -> List[str]:
    """Return the IPv4 entries of ``addresses``, keeping their original order.

    Each entry is parsed with ``ipaddress.ip_address``, so an entry that is not
    a valid IP address raises ``ValueError``.
    """
    return [addr for addr in addresses if ip_address(addr).version == 4]
class Settings(BaseSettings):
    """Runtime configuration, read from ``TAILSCALESD_*`` environment variables.

    Each secret (bearer token, OAuth client id, OAuth client secret) may be
    given directly or through a ``*_FILE`` path; a direct value takes
    precedence. Validation fails unless all three secrets resolve to a
    non-empty value.
    """

    model_config = SettingsConfigDict(
        env_prefix="TAILSCALESD_", env_file=env_path, env_file_encoding="utf-8"
    )

    host: str = "0.0.0.0"  # nosec B104
    port: int = 9242
    # Seconds the poller sleeps between rounds.
    interval: int = 60
    tailnet: str = Field()
    bearer_token: Optional[SecretStr] = Field(default=None)
    bearer_token_file: Optional[str] = Field(default=None)
    # When true, the background poller returns without polling.
    test_mode: bool = False
    client_id: Optional[SecretStr] = Field(default=None)
    client_id_file: Optional[str] = Field(default=None)
    client_secret: Optional[SecretStr] = Field(default=None)
    client_secret_file: Optional[str] = Field(default=None)

    @staticmethod
    def _load_secret(
        value: Optional[SecretStr], file_path: Optional[str]
    ) -> Optional[SecretStr]:
        """Resolve a secret from an explicit value or from a file.

        Returns ``value`` when it is set. Otherwise returns the stripped
        contents of ``file_path``. Returns ``None`` when neither is set. A
        file that is empty after stripping is treated as unset, so it is
        reported as a missing setting rather than producing an empty secret.

        Raises:
            ValueError: if ``file_path`` cannot be read.
        """
        if value is not None:
            return value
        if file_path is None:
            return None
        try:
            secret = Path(file_path).read_text(encoding="utf-8").strip()
        except OSError as e:
            # pydantic turns ValueError raised by a validator into a
            # ValidationError; a bare OSError would escape unformatted.
            raise ValueError(f"Could not read secret file {file_path!r}: {e}") from e
        if not secret:
            return None
        return SecretStr(secret)

    @model_validator(mode="after")
    def resolve_secret_sources(self):
        """Fill each secret from its file when needed; fail if any is still unset."""
        self.bearer_token = self._load_secret(self.bearer_token, self.bearer_token_file)
        self.client_id = self._load_secret(self.client_id, self.client_id_file)
        self.client_secret = self._load_secret(
            self.client_secret, self.client_secret_file
        )
        missing = []
        if self.bearer_token is None:
            missing.append("TAILSCALESD_BEARER_TOKEN or TAILSCALESD_BEARER_TOKEN_FILE")
        if self.client_id is None:
            missing.append("TAILSCALESD_CLIENT_ID or TAILSCALESD_CLIENT_ID_FILE")
        if self.client_secret is None:
            missing.append(
                "TAILSCALESD_CLIENT_SECRET or TAILSCALESD_CLIENT_SECRET_FILE"
            )
        if missing:
            raise ValueError(f"Missing required settings: {', '.join(missing)}")
        return self
# Latest service-discovery target groups (dicts with "targets" and "labels").
# poll_sd replaces it wholesale after each successful round; "/" returns it.
CACHE_SD: List[Dict] = []
class AccessToken:
    """An OAuth access token decoded from the tailscale token endpoint response.

    The expiry is kept as a timezone-aware UTC datetime. A naive local time
    would shift by an hour across DST transitions and misjudge when the token
    expires.
    """

    _token: str
    expires_in: int
    scope: Optional[str]
    expiration_time: datetime

    def __init__(self, api_resp):
        """Build the token from the decoded JSON token response.

        Args:
            api_resp: mapping with ``access_token`` and ``expires_in``
                (seconds). ``scope`` is optional.

        Raises:
            KeyError: if ``access_token`` or ``expires_in`` is missing.
        """
        self._token = api_resp["access_token"]
        self.expires_in = api_resp["expires_in"]
        # scope is not read elsewhere in this module, so tolerate its absence.
        self.scope = api_resp.get("scope")
        self.expiration_time = self._calculate_expiration_time(self.expires_in)

    def _calculate_expiration_time(self, expires_in):
        """Return the absolute UTC time ``expires_in`` seconds from now."""
        return datetime.now(timezone.utc) + timedelta(seconds=expires_in)

    def is_expiring_soon(self, buffer_seconds=300):
        """Return True once at most ``buffer_seconds`` remain before expiry (or it has passed)."""
        return datetime.now(timezone.utc) >= self.expiration_time - timedelta(
            seconds=buffer_seconds
        )

    @property
    def token(self):
        """The raw token string.

        Raises:
            Exception: if the token is within the default expiry buffer. This
                presumably exists so that a nearly-expired token is never sent.
        """
        if self.is_expiring_soon():
            raise Exception("Access Token is expiring soon")
        return self._token
async def get_access_token(
    tailscale_client_id: str, tailscale_client_secret: str
) -> AccessToken:
    """Exchange OAuth client credentials for a tailscale API access token.

    Posts a ``client_credentials`` grant to the tailscale OAuth token endpoint.

    Raises:
        Exception: wrapping (via ``__cause__``) any transport error, non-2xx
            response, or response body that cannot be turned into an
            ``AccessToken``.
    """
    form = dict(
        client_id=tailscale_client_id,
        client_secret=tailscale_client_secret,
        grant_type="client_credentials",
    )
    token_endpoint = "https://api.tailscale.com/api/v2/oauth/token"
    async with httpx.AsyncClient() as http:
        try:
            response = await http.post(token_endpoint, data=form)
            response.raise_for_status()
            payload = response.json()
            return AccessToken(payload)
        except Exception as err:
            raise Exception("Error requesting API key using OAUTH client.") from err
async def tailscale_devices(
    settings: Settings, access_token: AccessToken
) -> List[Dict]:
    """Fetch every device in ``settings.tailnet`` from the tailscale API.

    See https://github.com/tailscale/tailscale/blob/main/api.md#tailnet-devices-get

    Failures are raised, not converted to an empty list. An empty list would
    let the poll loop overwrite the served targets with nothing and mark
    polling as up. A raised error instead reaches the poll loop's failure
    handler, which leaves the previous targets in place.

    Raises:
        httpx.HTTPError: on transport errors or a non-2xx response.
        ValueError: if the body is not JSON or has no ``devices`` list.
        Exception: if ``access_token`` is expiring soon.
    """
    url = f"https://api.tailscale.com/api/v2/tailnet/{settings.tailnet}/devices"
    async with httpx.AsyncClient() as client:
        r = await client.get(url, auth=(access_token.token, ""))
        # Without this, an error body would surface only as a confusing KeyError.
        r.raise_for_status()
        payload = r.json()
    devices = payload.get("devices") if isinstance(payload, dict) else None
    if not isinstance(devices, list):
        raise ValueError("tailscale devices response has no 'devices' list")
    return devices
def group_by_type(input_list):
    """Bucket items by their ``"type"`` value, keeping first-seen order.

    Items whose ``"type"`` is missing or falsy are dropped.

    Returns:
        dict mapping each type value to the list of items that carry it.
    """
    grouped = {}
    for entry in input_list:
        kind = entry.get("type")
        if not kind:
            continue
        if kind not in grouped:
            grouped[kind] = []
        grouped[kind].append(entry)
    return grouped
def parse_tag(prefix, device) -> Optional[str]:
    """Return the text after ``prefix`` in the first matching tag of ``device``.

    ``device["tags"]`` is scanned in order. Returns ``None`` when no tag starts
    with ``prefix``.
    """
    matches = (t[len(prefix):] for t in device["tags"] if t.startswith(prefix))
    return next(matches, None)
def namespace(device) -> Optional[str]:
    """Return the name from the device's first ``tag:ns-<name>`` tag, or None."""
    ns_prefix = "tag:ns-"
    return parse_tag(ns_prefix, device)
def tailscale_labels(tailnet, device, tag) -> Dict[str, str]:
    """Build the label set that describes one tailscale device.

    Emits the ``__meta_tailscale_*`` labels plus ``tailscale_hostname``,
    ``tailscale_name``, ``hostname`` and ``instance``. ``instance`` is the
    device hostname. ``__meta_tailscale_device_tag`` is appended only when
    ``tag`` is truthy.
    """
    client_version = device["clientVersion"]
    hostname = device["hostname"]
    authorized = str(device["authorized"]).lower()
    device_id = device["id"]
    device_name = device["name"]
    device_os = device["os"]
    labels = {
        "__meta_tailscale_device_client_version": client_version,
        "__meta_tailscale_device_hostname": hostname,
        "__meta_tailscale_device_authorized": authorized,
        "__meta_tailscale_device_id": device_id,
        "__meta_tailscale_device_name": device_name,
        "__meta_tailscale_device_os": device_os,
        "__meta_tailscale_tailnet": tailnet,
        "tailscale_hostname": hostname,
        "tailscale_name": device_name,
        "hostname": hostname,
        "instance": hostname,
    }
    return labels if not tag else {**labels, "__meta_tailscale_device_tag": tag}
async def matrix_node_sd(device) -> Dict:
    """Fetch a matrix host's worker list and group it by worker type.

    Requests ``http://<first IPv4>:8081/`` on the device and passes the decoded
    JSON to ``group_by_type``. The body presumably is a list of worker dicts
    carrying a ``"type"`` key; TODO confirm against the matrix node API.

    Raises:
        IndexError: if the device has no IPv4 address.
        httpx.HTTPError: on transport errors or a non-2xx response.
        ValueError: if the body is not JSON.
    """
    hostname = device["hostname"]
    log.info("Polling matrix node", extra={"props": {"hostname": hostname}})
    ipv4 = ipv4_only(device["addresses"])[0]
    async with httpx.AsyncClient() as client:
        r = await client.get(f"http://{ipv4}:8081/")
        # Report HTTP errors as such instead of failing later on an error body.
        r.raise_for_status()
        data = r.json()
    log.debug(
        "Found %d workers", len(data), extra={"props": {"hostname": hostname}}
    )
    return group_by_type(data)
def matrix_workers_to_sd(tailnet, device, workers) -> List[Dict]:
    """Turn grouped matrix workers into service-discovery target groups.

    Args:
        tailnet: tailnet name, copied into the labels.
        device: tailscale device dict of the matrix host.
        workers: mapping of worker type to a list of worker dicts, as returned
            by ``matrix_node_sd``.

    Returns:
        One target group (``"<ipv4>:<metrics_port>"``) per worker. Workers
        whose ``metrics_port`` is missing or falsy are logged and skipped.
    """
    if not workers:
        return []
    ipv4 = ipv4_only(device["addresses"])[0]
    # Device-level labels are the same for every worker, so compute them once.
    base_labels = tailscale_labels(tailnet, device, None)
    target_groups = []
    for worker_type, typed_workers in workers.items():
        for worker in typed_workers:
            worker_name = worker.get("name", "WORKER_NO_NAME")
            # A missing key is treated like an empty port, so one malformed
            # worker is skipped rather than raising KeyError out of here.
            port = worker.get("metrics_port")
            if not port:
                log.error(
                    f"Error parsing worker {worker_name} on host={device['hostname']}. Port is invalid port={port}"
                )
                continue
            target_groups.append(
                {
                    "targets": [f"{ipv4}:{port}"],
                    "labels": base_labels
                    | {
                        "__meta_matrix_worker_type": worker_type,
                        "__meta_matrix_worker_name": worker_name,
                    },
                }
            )
    return target_groups
def filter_devices(devices) -> List[Dict]:
    """Keep the devices that have a truthy hostname and a ``tags`` key.

    Every rejected device is logged as a warning. Input order is preserved.
    """
    kept: List[Dict] = []
    for device in devices:
        hostname = device.get("hostname", None)
        if hostname and "tags" in device:
            kept.append(device)
        elif not hostname:
            log.warning("device doesn't have a hostname!", extra={"device": device})
        else:
            log.warning(
                f"device does not have any tags! hostname={hostname}",
                extra={"device": device},
            )
    return kept
async def matrix_sd(tailnet, devices) -> List[Dict]:
    """Collect target groups for all workers on every ``MATRIX_TAG`` device.

    Hosts are polled one after another. A host that cannot be polled, or whose
    worker list cannot be converted, is handled as follows: its
    ``counter_matrix_sd_down`` label is incremented, the failure is logged,
    and it contributes no targets. The remaining hosts are still processed.
    """
    sd: List[Dict] = []
    for device in devices:
        if MATRIX_TAG not in device["tags"]:
            continue
        hostname = device["hostname"]
        try:
            workers = await matrix_node_sd(device)
            # Converted inside the try so one malformed host cannot abort the poll.
            targets = matrix_workers_to_sd(tailnet, device, workers)
        except Exception as e:
            counter_matrix_sd_down.labels(device_hostname=hostname).inc()
            log.error(
                f"Failed parsing matrix node sd for device={hostname}",
                exc_info=e,
            )
            continue
        sd.extend(targets)
    log.info(f"Found {len(sd)} matrix services")
    return sd
def plain_devices_sd(tailnet, devices) -> List[Dict]:
    """Emit one target group per (device, tag) pair, targeting the device's IPv4s.

    A device whose ``tags`` value is empty or falsy still yields a single
    untagged target group.
    """
    groups = []
    for device in devices:
        ipv4_targets = ipv4_only(device["addresses"])
        for tag in device["tags"] or [None]:
            groups.append(
                {
                    "labels": tailscale_labels(tailnet, device, tag),
                    "targets": ipv4_targets,
                }
            )
    return groups
async def poll_sd(settings: Settings):
    """Refresh ``CACHE_SD`` from the tailscale API every ``settings.interval`` seconds.

    Returns immediately when ``settings.test_mode`` is set. Otherwise it loops
    forever (until cancelled). Each round:

    - refreshes the OAuth token when it is absent or expiring soon;
    - fetches and filters devices;
    - replaces ``CACHE_SD`` with the matrix worker targets followed by the
      plain device targets.

    A failed round has these effects:

    - ``CACHE_SD`` is left as it was;
    - the polling-up gauge is set to 0 and the failure is counted;
    - the token is discarded, so the next round re-authenticates. Otherwise a
      rejected token would be reused until it neared expiry.
    """
    global CACHE_SD
    if settings.test_mode:
        return
    access_token: Optional[AccessToken] = None
    while True:
        try:
            if access_token is None or access_token.is_expiring_soon():
                client_id = settings.client_id
                client_secret = settings.client_secret
                if client_id is None or client_secret is None:
                    raise RuntimeError(
                        "settings validation failed for oauth credentials"
                    )
                access_token = await get_access_token(
                    client_id.get_secret_value(),
                    client_secret.get_secret_value(),
                )
            devices = filter_devices(await tailscale_devices(settings, access_token))
            device_targets = plain_devices_sd(settings.tailnet, devices)
            matrix_targets = await matrix_sd(settings.tailnet, devices)
            CACHE_SD = matrix_targets + device_targets
            gauge_tailscale_polling_up.set(1)
        except Exception as e:
            counter_unhandled_background_task_crashes.inc()
            gauge_tailscale_polling_up.set(0)
            log.error(
                "Service Discovery poller failed",
                exc_info=e,
            )
            # Force re-authentication next round in case the token was rejected.
            access_token = None
        await asyncio.sleep(settings.interval)
@lru_cache
def get_settings():
    """Build ``Settings`` on first call and return the same instance afterwards."""
    # Fields such as tailnet are filled from the environment, which the type
    # checker cannot see, hence the ignore.
    settings = Settings()  # type: ignore[call-arg]
    return settings
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Expose metrics, then run the SD poller for the lifetime of the app.

    The poller task is cancelled and awaited when the app shuts down.
    """
    instrumentator.expose(app)
    settings = get_settings()
    # The event loop keeps only weak references to tasks. Holding this local
    # for the app's lifetime keeps the poller from being garbage-collected
    # mid-run.
    poller = asyncio.create_task(poll_sd(settings))
    try:
        yield
    finally:
        poller.cancel()
        # gather(return_exceptions=True) waits for the cancelled poller without
        # re-raising its CancelledError. A cancellation of this shutdown itself
        # still propagates.
        await asyncio.gather(poller, return_exceptions=True)
# FastAPI application; its lifespan exposes metrics and starts the SD poller.
app = FastAPI(lifespan=lifespan)
# Extracts bearer credentials for the is_authorized dependency.
security = HTTPBearer()
# Request metrics instrumentation; the metrics endpoint is exposed in lifespan.
instrumentator = Instrumentator().instrument(app)
# JSON log output, plus (presumably) per-request log instrumentation for app.
json_logging.init_fastapi(enable_json=True)
json_logging.init_request_instrument(app)
async def is_authorized(
    settings: Annotated[Settings, Depends(get_settings)],
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """FastAPI dependency that admits only requests carrying the configured bearer token.

    Returns:
        True when the request's Bearer credentials match the configured token.

    Raises:
        HTTPException: 500 if no bearer token is configured; 401 if the scheme
            is not Bearer or the token does not match.
    """
    bearer_token = settings.bearer_token
    if bearer_token is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Service is missing bearer token configuration",
        )
    # HTTP auth schemes are case-insensitive, so "bearer" is accepted as well.
    scheme_ok = credentials.scheme.lower() == "bearer"
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Comparing bytes avoids compare_digest's TypeError on
    # non-ASCII str input.
    token_ok = secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        bearer_token.get_secret_value().encode("utf-8"),
    )
    if not (scheme_ok and token_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
        )
    return True
@app.get("/")
async def sd(
    is_authed: bool = Depends(is_authorized),
):
    """Return the cached service-discovery target groups to authorized callers."""
    return CACHE_SD if is_authed else None
def main():
    """Serve the app with uvicorn on the configured host and port."""
    cfg = get_settings()
    uvicorn.run(app, port=cfg.port, host=cfg.host)
# Script entry point: run the service.
if __name__ == "__main__":
    main()