add required bearer auth token
This commit is contained in:
parent
5da9d04d7e
commit
02151e49b8
3 changed files with 72 additions and 10 deletions
|
|
@ -10,7 +10,7 @@ See [the prometheus docs][0] for more information on the HTTP service discovery
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|
||||||
`curl http://tailscalesd:9242/`
|
`curl -H "Authorization: Bearer supersecret" http://tailscalesd:9242`
|
||||||
|
|
||||||
``` json
|
``` json
|
||||||
[
|
[
|
||||||
|
|
@ -59,6 +59,7 @@ Configuration values can be set using environment variables, or optionally loade
|
||||||
- **host** (`TAILSCALESD_HOST`): The host address on which the application will bind (designed to be used in a container, BE CAREFUL!). Default is `0.0.0.0`.
|
- **host** (`TAILSCALESD_HOST`): The host address on which the application will bind (designed to be used in a container, BE CAREFUL!). Default is `0.0.0.0`.
|
||||||
- **port** (`TAILSCALESD_PORT`): The port number on which the application will be accessible. Default is `9242`.
|
- **port** (`TAILSCALESD_PORT`): The port number on which the application will be accessible. Default is `9242`.
|
||||||
- **interval** (`TAILSCALESD_INTERVAL`): The interval on which the Tailscale API is polled in seconds. Default is `60`.
|
- **interval** (`TAILSCALESD_INTERVAL`): The interval on which the Tailscale API is polled in seconds. Default is `60`.
|
||||||
|
- **bearer_token** (`TAILSCALESD_BEARER_TOKEN`): The authentication token passed in the Authorization header (required).
|
||||||
- **tailnet** (`TAILSCALESD_TAILNET`): The Tailscale tailnet identifier (required).
|
- **tailnet** (`TAILSCALESD_TAILNET`): The Tailscale tailnet identifier (required).
|
||||||
- **api_key** (`TAILSCALESD_API_KEY`): The Tailscale API key (required).
|
- **api_key** (`TAILSCALESD_API_KEY`): The Tailscale API key (required).
|
||||||
|
|
||||||
|
|
@ -72,6 +73,7 @@ You can also specify an environment file to load configuration values. The path
|
||||||
TAILSCALESD_TAILNET=my-tailnet
|
TAILSCALESD_TAILNET=my-tailnet
|
||||||
TAILSCALESD_API_KEY=my-api-key
|
TAILSCALESD_API_KEY=my-api-key
|
||||||
TAILSCALESD_HOST=127.0.0.1
|
TAILSCALESD_HOST=127.0.0.1
|
||||||
|
TAILSCALESD_BEARER_TOKEN=supersecret
|
||||||
```
|
```
|
||||||
|
|
||||||
### Monitoring
|
### Monitoring
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,15 @@ import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import lru_cache
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
from typing import Dict, List
|
from typing import Annotated, Dict, List
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import json_logging # type: ignore
|
import json_logging # type: ignore
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
from prometheus_fastapi_instrumentator import Instrumentator
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
from pydantic import Field, SecretStr
|
from pydantic import Field, SecretStr
|
||||||
|
|
@ -48,14 +50,15 @@ class Settings(BaseSettings):
|
||||||
port: int = 9242
|
port: int = 9242
|
||||||
interval: int = 60
|
interval: int = 60
|
||||||
tailnet: str = Field()
|
tailnet: str = Field()
|
||||||
|
bearer_token: str = Field()
|
||||||
api_key: SecretStr = Field()
|
api_key: SecretStr = Field()
|
||||||
|
test_mode: bool = False
|
||||||
|
|
||||||
|
|
||||||
settings = Settings() # type: ignore[call-arg]
|
|
||||||
CACHE_SD = []
|
CACHE_SD = []
|
||||||
|
|
||||||
|
|
||||||
async def tailscale_devices() -> List[Dict]:
|
async def tailscale_devices(settings: Settings) -> List[Dict]:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
try:
|
try:
|
||||||
# https://github.com/tailscale/tailscale/blob/main/api.md#tailnet-devices-get
|
# https://github.com/tailscale/tailscale/blob/main/api.md#tailnet-devices-get
|
||||||
|
|
@ -170,11 +173,13 @@ def plain_devices_sd(tailnet, devices) -> List[Dict]:
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
async def poll_sd():
|
async def poll_sd(settings: Settings):
|
||||||
global CACHE_SD
|
global CACHE_SD
|
||||||
|
if settings.test_mode:
|
||||||
|
return
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
devices = await tailscale_devices()
|
devices = await tailscale_devices(settings)
|
||||||
device_targets = plain_devices_sd(settings.tailnet, devices)
|
device_targets = plain_devices_sd(settings.tailnet, devices)
|
||||||
matrix_targets = await matrix_sd(settings.tailnet, devices)
|
matrix_targets = await matrix_sd(settings.tailnet, devices)
|
||||||
CACHE_SD = matrix_targets + device_targets
|
CACHE_SD = matrix_targets + device_targets
|
||||||
|
|
@ -187,26 +192,52 @@ async def poll_sd():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings():
|
||||||
|
return Settings() # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
instrumentator.expose(app)
|
instrumentator.expose(app)
|
||||||
asyncio.create_task(poll_sd())
|
settings = get_settings()
|
||||||
|
asyncio.create_task(poll_sd(settings))
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
security = HTTPBearer()
|
||||||
instrumentator = Instrumentator().instrument(app)
|
instrumentator = Instrumentator().instrument(app)
|
||||||
|
|
||||||
json_logging.init_fastapi(enable_json=True)
|
json_logging.init_fastapi(enable_json=True)
|
||||||
json_logging.init_request_instrument(app)
|
json_logging.init_request_instrument(app)
|
||||||
|
|
||||||
|
|
||||||
|
async def is_authorized(
|
||||||
|
settings: Annotated[Settings, Depends(get_settings)],
|
||||||
|
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
credentials.scheme != "Bearer"
|
||||||
|
or credentials.credentials != settings.bearer_token
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid authentication credentials",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def sd():
|
async def sd(
|
||||||
|
is_authed: bool = Depends(is_authorized),
|
||||||
|
):
|
||||||
|
if is_authed:
|
||||||
return CACHE_SD
|
return CACHE_SD
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
settings = get_settings()
|
||||||
uvicorn.run(app, host=settings.host, port=settings.port)
|
uvicorn.run(app, host=settings.host, port=settings.port)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
29
tests/test_auth.py
Normal file
29
tests/test_auth.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from tailscalesd.main import Settings, app, get_settings
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings_override() -> Settings:
|
||||||
|
return Settings(test_mode=True, tailnet="test", api_key="test", bearer_token="test")
|
||||||
|
|
||||||
|
|
||||||
|
app.dependency_overrides[get_settings] = get_settings_override
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_works():
|
||||||
|
response = client.get("/", headers={"Authorization": "Bearer test"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_unauthorized_wrong_token():
|
||||||
|
response = client.get("/", headers={"Authorization": "Bearer incorrect_token"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {"detail": "Invalid authentication credentials"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_unauthorized_no_token():
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {"detail": "Not authenticated"}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue