402 lines
13 KiB
Python
402 lines
13 KiB
Python
import asyncio
import logging
import os
import secrets
import sys
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from functools import lru_cache
from ipaddress import ip_address
from pathlib import Path
from typing import Annotated, Dict, List, Optional

import httpx
import json_logging  # type: ignore
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from prometheus_client import Counter, Gauge
from prometheus_fastapi_instrumentator import Instrumentator
from pydantic import Field, SecretStr, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
# Tailscale tag that selects the devices polled for matrix workers.
MATRIX_TAG = "tag:matrix"

# Optional env file handed to ``Settings``; ``None`` when the variable is unset.
env_path = os.environ.get("TAILSCALESD_ENV_FILE")

# Debug logging is enabled by any of these spellings, compared case-insensitively.
debug = os.environ.get("TAILSCALESD_DEBUG", "false").lower() in {
    "1",
    "true",
    "yes",
    "on",
}

# Module logger: writes to stdout, verbosity follows the debug flag above.
log = logging.getLogger("tailscalesd")
if debug:
    log.setLevel(logging.DEBUG)
else:
    log.setLevel(logging.INFO)
log.addHandler(logging.StreamHandler(stream=sys.stdout))
|
|
|
|
# Incremented whenever the background poller hits an error it did not expect.
counter_unhandled_background_task_crashes = Counter(
    name="tailscalesd_unhandled_background_task_crashes",
    documentation="The number of unhandled background task crashes",
)

# Incremented per device (labelled by hostname) when its matrix endpoint
# could not be polled.
counter_matrix_sd_down = Counter(
    name="tailscalesd_matrix_sd_down",
    documentation="The number times a matrix sd host was unreachable",
    labelnames=["device_hostname"],
)

# Set to 1 after a completed polling round and to 0 after a failed one.
gauge_tailscale_polling_up = Gauge(
    name="tailscalesd_polling_up",
    documentation="Indicates if tailscalesd can access the tailscale devices API up (1) or down (0)",
)
|
|
|
|
|
|
def ipv4_only(addresses) -> List[str]:
    """Return the IPv4 entries of ``addresses``, keeping their original order.

    Each entry is parsed with ``ipaddress.ip_address``; an entry that is not
    a valid IP address therefore raises ``ValueError``.
    """
    return [addr for addr in addresses if ip_address(addr).version == 4]
|
|
|
|
|
|
class Settings(BaseSettings):
    """Service configuration read from ``TAILSCALESD_*`` environment variables.

    Values may also come from the env file referenced by ``env_path``. Each
    secret can be supplied inline or through its companion ``*_file`` path;
    the inline value wins when both are given. Validation fails when any of
    the three secrets is still unset after both sources were consulted.
    """

    model_config = SettingsConfigDict(
        env_prefix="TAILSCALESD_",
        env_file=env_path,
        env_file_encoding="utf-8",
    )

    # Address and port the HTTP server binds to.
    host: str = "0.0.0.0"  # nosec B104
    port: int = 9242
    # Seconds the poller sleeps between rounds.
    interval: int = 60
    # Tailnet whose devices are listed; required, no default.
    tailnet: str = Field()
    # Token clients must present to read the discovery endpoint.
    bearer_token: Optional[SecretStr] = Field(default=None)
    bearer_token_file: Optional[str] = Field(default=None)
    # When true the background poller returns immediately.
    test_mode: bool = False
    # OAuth client credentials used to obtain API access tokens.
    client_id: Optional[SecretStr] = Field(default=None)
    client_id_file: Optional[str] = Field(default=None)
    client_secret: Optional[SecretStr] = Field(default=None)
    client_secret_file: Optional[str] = Field(default=None)

    @staticmethod
    def _load_secret(
        value: Optional[SecretStr], file_path: Optional[str]
    ) -> Optional[SecretStr]:
        """Return ``value`` if set, else the stripped contents of ``file_path``.

        Returns ``None`` when neither source is provided. Errors from reading
        the file propagate to the caller.
        """
        if value is None and file_path is not None:
            contents = Path(file_path).read_text(encoding="utf-8")
            return SecretStr(contents.strip())
        return value

    @model_validator(mode="after")
    def resolve_secret_sources(self):
        """Fill secrets from their file sources and reject incomplete configs.

        Raises:
            ValueError: listing every secret that has neither an inline value
                nor a file path.
        """
        self.bearer_token = self._load_secret(self.bearer_token, self.bearer_token_file)
        self.client_id = self._load_secret(self.client_id, self.client_id_file)
        self.client_secret = self._load_secret(
            self.client_secret, self.client_secret_file
        )

        # Pair each resolved secret with the hint shown when it is missing.
        required = (
            (
                self.bearer_token,
                "TAILSCALESD_BEARER_TOKEN or TAILSCALESD_BEARER_TOKEN_FILE",
            ),
            (
                self.client_id,
                "TAILSCALESD_CLIENT_ID or TAILSCALESD_CLIENT_ID_FILE",
            ),
            (
                self.client_secret,
                "TAILSCALESD_CLIENT_SECRET or TAILSCALESD_CLIENT_SECRET_FILE",
            ),
        )
        missing = [hint for secret, hint in required if secret is None]

        if missing:
            raise ValueError(f"Missing required settings: {', '.join(missing)}")

        return self
|
|
|
|
|
|
# Most recent service-discovery result; rebound by the background poller and
# returned as-is by the HTTP endpoint.
CACHE_SD: List[Dict] = []
|
|
|
|
|
|
class AccessToken:
    """OAuth access token bundled with the moment it stops being usable.

    Built from a token response mapping that must provide the keys
    ``access_token``, ``expires_in`` (seconds) and ``scope``.
    """

    # Declared on the class but never assigned; the token string itself is
    # stored in ``_token`` and read through the ``token`` property.
    access_token: str

    def __init__(self, api_resp):
        """Copy the response fields and compute the absolute expiry time."""
        self._token = api_resp["access_token"]
        lifetime = api_resp["expires_in"]
        self.expires_in = lifetime
        self.scope = api_resp["scope"]
        self.expiration_time = self._calculate_expiration_time(lifetime)

    def _calculate_expiration_time(self, expires_in):
        """Return the local (naive) time ``expires_in`` seconds from now."""
        lifetime = timedelta(seconds=expires_in)
        return datetime.now() + lifetime

    def is_expiring_soon(self, buffer_seconds=300):
        """Tell whether the token expires within ``buffer_seconds`` from now."""
        refresh_deadline = self.expiration_time - timedelta(seconds=buffer_seconds)
        return refresh_deadline <= datetime.now()

    @property
    def token(self):
        """The raw token string.

        Raises:
            Exception: when the token is inside the default expiry buffer.
        """
        if not self.is_expiring_soon():
            return self._token
        raise Exception("Access Token is expiring soon")
|
|
|
|
|
|
async def get_access_token(
    tailscale_client_id: str, tailscale_client_secret: str
) -> AccessToken:
    """Request an access token with the OAuth client-credentials grant.

    Args:
        tailscale_client_id: OAuth client id sent as ``client_id``.
        tailscale_client_secret: OAuth client secret sent as ``client_secret``.

    Returns:
        An ``AccessToken`` built from the JSON response body.

    Raises:
        Exception: for any failure (transport error, non-success status,
            unexpected body); the original error is chained as the cause.
    """
    form = dict(
        client_id=tailscale_client_id,
        client_secret=tailscale_client_secret,
        grant_type="client_credentials",
    )

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                "https://api.tailscale.com/api/v2/oauth/token", data=form
            )
            response.raise_for_status()
            payload = response.json()
            return AccessToken(payload)
        except Exception as e:
            raise Exception("Error requesting API key using OAUTH client.") from e
|
|
|
|
|
|
async def tailscale_devices(
    settings: Settings, access_token: AccessToken
) -> List[Dict]:
    """Fetch the device list of ``settings.tailnet`` from the Tailscale API.

    The access token is sent as the basic-auth username with an empty
    password.

    Failures are logged and re-raised instead of being turned into an empty
    list. An empty list is indistinguishable from "the tailnet has no
    devices", which made ``poll_sd`` overwrite its cache with nothing and
    report the polling gauge as up during an API outage. Re-raising lets the
    poller's own exception handler count the failure, mark polling as down
    and keep the previously discovered targets.

    Returns:
        The value of the ``devices`` key of the JSON response.

    Raises:
        Exception: whatever went wrong — the token being close to expiry,
            a transport error, a non-success HTTP status, or a response body
            without a ``devices`` key.
    """
    # https://github.com/tailscale/tailscale/blob/main/api.md#tailnet-devices-get
    url = f"https://api.tailscale.com/api/v2/tailnet/{settings.tailnet}/devices"
    async with httpx.AsyncClient() as client:
        try:
            r = await client.get(url, auth=(access_token.token, ""))
            # Surface 4xx/5xx explicitly rather than as a KeyError on the
            # error body below.
            r.raise_for_status()
            return r.json()["devices"]
        except Exception as e:
            log.error(
                "Polling tailscale devices failed!",
                exc_info=e,
            )
            raise
|
|
|
|
|
|
def group_by_type(input_list):
    """Bucket mappings by the value of their ``"type"`` key.

    Items whose type is missing or falsy are dropped. Within a bucket the
    items keep their input order.

    Returns:
        A dict mapping each type to the list of items carrying it.
    """
    buckets = {}
    for entry in input_list:
        kind = entry.get("type")
        if not kind:
            continue
        if kind not in buckets:
            buckets[kind] = []
        buckets[kind].append(entry)
    return buckets
|
|
|
|
|
|
def parse_tag(prefix, device) -> Optional[str]:
    """Return the remainder of the first device tag that starts with ``prefix``.

    Tags are read from ``device["tags"]`` in order; ``None`` is returned when
    no tag matches.
    """
    cut = len(prefix)
    matches = (tag[cut:] for tag in device["tags"] if tag.startswith(prefix))
    return next(matches, None)
|
|
|
|
|
|
def namespace(device) -> Optional[str]:
    """Return the suffix of the device's first ``tag:ns-`` tag, or ``None``."""
    ns_prefix = "tag:ns-"
    return parse_tag(ns_prefix, device)
|
|
|
|
|
|
def tailscale_labels(tailnet, device, tag) -> Dict[str, str]:
    """Build the label set describing one Tailscale device.

    Args:
        tailnet: Name stored under ``__meta_tailscale_tailnet``.
        device: Device mapping; must contain ``clientVersion``, ``hostname``,
            ``authorized``, ``id``, ``name`` and ``os``.
        tag: Optional tag; when truthy it is added as
            ``__meta_tailscale_device_tag``.

    Returns:
        A new dict of label names to values.
    """
    # Discovery-time metadata labels.
    meta = {
        "__meta_tailscale_device_client_version": device["clientVersion"],
        "__meta_tailscale_device_hostname": device["hostname"],
        "__meta_tailscale_device_authorized": str(device["authorized"]).lower(),
        "__meta_tailscale_device_id": device["id"],
        "__meta_tailscale_device_name": device["name"],
        "__meta_tailscale_device_os": device["os"],
        "__meta_tailscale_tailnet": tailnet,
    }
    # Plain labels derived from the same device fields.
    plain = {
        "tailscale_hostname": device["hostname"],
        "tailscale_name": device["name"],
        "hostname": device["hostname"],
        "instance": device["hostname"],
    }
    labels = meta | plain
    if not tag:
        return labels
    labels["__meta_tailscale_device_tag"] = tag
    return labels
|
|
|
|
|
|
async def matrix_node_sd(device) -> Dict:
    """Ask a matrix host which workers it runs, grouped by worker type.

    The host is queried over plain HTTP on port 8081 of its first IPv4
    address. The response body is expected to be a JSON list of worker
    mappings, which is passed to ``group_by_type``.

    Args:
        device: Device mapping with ``hostname`` and ``addresses``.

    Returns:
        Dict mapping worker type to the list of workers of that type.

    Raises:
        IndexError: when the device has no IPv4 address.
        httpx.HTTPError: on transport errors or a non-success HTTP status.
    """
    # Both records carry the hostname under the same ``props`` key; the
    # debug record previously used a different ``extra`` shape than the
    # info record.
    log_extra = {"props": {"hostname": device["hostname"]}}
    log.info("Polling matrix node", extra=log_extra)
    ipv4 = ipv4_only(device["addresses"])[0]
    async with httpx.AsyncClient() as client:
        r = await client.get(f"http://{ipv4}:8081/")
        # Do not try to interpret an error page as a worker list.
        r.raise_for_status()
        data = r.json()
    log.debug("Found %d workers", len(data), extra=log_extra)
    return group_by_type(data)
|
|
|
|
|
|
def matrix_workers_to_sd(tailnet, device, workers) -> List[Dict]:
    """Turn a host's grouped workers into scrape target groups.

    One target group is produced per worker that has a usable
    ``metrics_port``; its single target is ``<first IPv4>:<port>`` and its
    labels are the device labels plus the worker's type and name.

    Workers without a ``metrics_port`` (key missing or falsy) are logged and
    skipped; previously a missing key raised ``KeyError`` before the
    validity check could run. A device without any IPv4 address is logged
    and yields no targets instead of raising ``IndexError``.

    Args:
        tailnet: Tailnet name forwarded to ``tailscale_labels``.
        device: Device mapping the workers run on.
        workers: Dict mapping worker type to a list of worker mappings.

    Returns:
        List of ``{"targets": [...], "labels": {...}}`` dicts, possibly empty.
    """
    if not workers:
        return []

    v4_addresses = ipv4_only(device["addresses"])
    if not v4_addresses:
        log.error(
            f"Device host={device['hostname']} has no IPv4 address; skipping its workers"
        )
        return []
    ipv4 = v4_addresses[0]

    # The device part of the labels is identical for every worker.
    device_labels = tailscale_labels(tailnet, device, None)

    target_groups = []
    for worker_type, typed_workers in workers.items():
        for worker in typed_workers:
            port = worker.get("metrics_port")
            worker_name = worker.get("name", "WORKER_NO_NAME")
            if not port:
                log.error(
                    f"Error parsing worker {worker_name} on host={device['hostname']}. Port is invalid port={port}"
                )
                continue
            target_groups.append(
                {
                    "targets": [f"{ipv4}:{port}"],
                    "labels": device_labels
                    | {
                        "__meta_matrix_worker_type": worker_type,
                        "__meta_matrix_worker_name": worker_name,
                    },
                }
            )
    return target_groups
|
|
|
|
|
|
def filter_devices(devices) -> List[Dict]:
    """Keep only devices that have a hostname and a ``tags`` key.

    A device with a missing or empty hostname, or without a ``tags`` key, is
    dropped with a warning. An empty tag list is accepted.
    """
    kept: List[Dict] = []
    for device in devices:
        hostname = device.get("hostname")
        if not hostname:
            log.warning("device doesn't have a hostname!", extra={"device": device})
        elif "tags" not in device:
            log.warning(
                f"device does not have any tags! hostname={hostname}",
                extra={"device": device},
            )
        else:
            kept.append(device)
    return kept
|
|
|
|
|
|
async def matrix_sd(tailnet, devices) -> List[Dict]:
    """Collect scrape targets for the workers of every matrix-tagged device.

    Devices without ``MATRIX_TAG`` are ignored. The remaining devices are
    polled one after another. A device whose poll or whose worker payload
    cannot be processed is counted in ``counter_matrix_sd_down``, logged and
    skipped, so one broken host does not abort discovery for the others
    (the payload conversion used to run outside the error handler).

    Args:
        tailnet: Tailnet name forwarded into the target labels.
        devices: Device mappings; each must have ``tags`` and ``hostname``.

    Returns:
        Flat list of target groups across all matrix devices.
    """
    sd: List[Dict] = []
    for device in devices:
        if MATRIX_TAG not in device["tags"]:
            continue
        try:
            workers = await matrix_node_sd(device)
            targets = matrix_workers_to_sd(tailnet, device, workers)
        except Exception as e:
            counter_matrix_sd_down.labels(device_hostname=device["hostname"]).inc()
            log.error(
                f"Failed parsing matrix node sd for device={device['hostname']}",
                exc_info=e,
            )
            continue
        # Extend in place; rebuilding the list per device was quadratic.
        sd.extend(targets)
    log.info(f"Found {len(sd)} matrix services")
    return sd
|
|
|
|
|
|
def plain_devices_sd(tailnet, devices) -> List[Dict]:
    """Build one target group per (device, tag) pair.

    Every group targets all IPv4 addresses of the device. A device whose tag
    list is empty still yields a single group, labelled without a tag.
    """
    sd = []
    for device in devices:
        targets = ipv4_only(device["addresses"])
        # ``[None]`` stands in for "no tag" so untagged devices go through
        # the same path exactly once.
        for tag in device["tags"] or [None]:
            labels = tailscale_labels(tailnet, device, tag)
            sd.append({"labels": labels, "targets": targets})
    return sd
|
|
|
|
|
|
async def poll_sd(settings: Settings):
    """Refresh ``CACHE_SD`` forever, sleeping ``settings.interval`` between rounds.

    Returns immediately when ``settings.test_mode`` is set. Each round
    obtains a new access token if there is none yet or the current one is
    about to expire, lists the tailnet's devices, and rebinds ``CACHE_SD``
    to the matrix targets followed by the plain device targets. Any
    exception in a round is counted, logged and reflected in the polling
    gauge; the loop then carries on with the next round.
    """
    global CACHE_SD
    if settings.test_mode:
        return

    access_token = None
    while True:
        try:
            if access_token is None or access_token.is_expiring_soon():
                client_id = settings.client_id
                client_secret = settings.client_secret
                if client_id is None or client_secret is None:
                    raise RuntimeError(
                        "settings validation failed for oauth credentials"
                    )
                access_token = await get_access_token(
                    client_id.get_secret_value(),
                    client_secret.get_secret_value(),
                )

            devices = filter_devices(await tailscale_devices(settings, access_token))
            device_targets = plain_devices_sd(settings.tailnet, devices)
            matrix_targets = await matrix_sd(settings.tailnet, devices)
            CACHE_SD = matrix_targets + device_targets
            gauge_tailscale_polling_up.set(1)
        except Exception as e:
            counter_unhandled_background_task_crashes.inc()
            gauge_tailscale_polling_up.set(0)
            log.error(
                "Service Discovery poller failed",
                exc_info=e,
            )
        # Same pause after a successful and after a failed round.
        await asyncio.sleep(settings.interval)
|
|
|
|
|
|
@lru_cache
def get_settings():
    """Return the process-wide ``Settings``, constructed on first use.

    The result is memoised, so the environment is read only once.
    """
    settings = Settings()  # type: ignore[call-arg]
    return settings
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Expose the metrics endpoint and run the poller for the app's lifetime.

    The poller task is kept in a local variable for as long as the app runs:
    the event loop only holds weak references to tasks, so a task whose
    result is discarded may be garbage-collected before it finishes. On
    shutdown the task is cancelled and awaited so it does not outlive the
    application.
    """
    instrumentator.expose(app)
    settings = get_settings()
    poller = asyncio.create_task(poll_sd(settings))
    try:
        yield
    finally:
        poller.cancel()
        try:
            await poller
        except asyncio.CancelledError:
            # Expected outcome of the cancel() above.
            pass
|
|
|
|
|
|
# Bearer-token extractor used as a dependency by the authorization check.
security = HTTPBearer()

# The ASGI application; ``lifespan`` starts the background poller.
app = FastAPI(lifespan=lifespan)

# Request metrics for the app; the endpoint itself is exposed in ``lifespan``.
instrumentator = Instrumentator().instrument(app)

# Emit logs as JSON and instrument request logging for this app.
json_logging.init_fastapi(enable_json=True)
json_logging.init_request_instrument(app)
|
|
|
|
|
|
async def is_authorized(
    settings: Annotated[Settings, Depends(get_settings)],
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Check the request's bearer token against the configured one.

    Returns:
        ``True`` when the presented token matches.

    Raises:
        HTTPException: 500 when no bearer token is configured; 401 (with a
            ``WWW-Authenticate: Bearer`` challenge) when the scheme is not
            ``Bearer`` or the token does not match.
    """
    bearer_token = settings.bearer_token
    if bearer_token is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Service is missing bearer token configuration",
        )

    # HTTP auth scheme names are case-insensitive, so "bearer" is as valid
    # as "Bearer".
    scheme_ok = credentials.scheme.lower() == "bearer"
    # Compare in constant time so response timing does not leak how much of
    # the token matched. Bytes are used because compare_digest rejects
    # non-ASCII str, which a client could otherwise use to trigger a 500.
    token_ok = secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        bearer_token.get_secret_value().encode("utf-8"),
    )
    if not (scheme_ok and token_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
|
|
|
|
|
|
@app.get("/")
async def sd(is_authed: bool = Depends(is_authorized)):
    """Serve the cached service-discovery targets to authorized callers.

    Returns ``CACHE_SD`` when the authorization dependency yielded a truthy
    value, otherwise ``None``.
    """
    if not is_authed:
        return None
    return CACHE_SD
|
|
|
|
|
|
def main():
    """Run the app under uvicorn on the configured host and port."""
    cfg = get_settings()
    uvicorn.run(app, host=cfg.host, port=cfg.port)
|
|
|
|
|
|
# async def test():
|
|
# tailscale_client_id = os.environ.get("TAILSCALESD_CLIENT_ID")
|
|
# tailscale_client_secret = os.environ.get("TAILSCALESD_CLIENT_SECRET")
|
|
# r = await get_access_token(tailscale_client_id, tailscale_client_secret)
|
|
# print(r)
|
|
|
|
|
|
# Start the server only when the module is executed as a script.
if __name__ == "__main__":
    # asyncio.run(test())
    main()
|