initial working version

This commit is contained in:
Abel Luck 2026-05-05 13:43:02 +02:00
parent 4d8b83cbb6
commit 8318f9fe70
15 changed files with 917 additions and 4 deletions

2
.gitignore vendored
View file

@ -1,3 +1,3 @@
.direnv
result
/dnstt_exporter

View file

@ -1,3 +1,39 @@
# dnstt_exporter
Prometheus exporter for DNSTT client/session metrics.
`dnstt_exporter` observes DNSTT DNS traffic on a local Linux host and exports
aggregate Prometheus metrics. It does not proxy, terminate, or configure DNSTT;
it passively decodes DNSTT session IDs from DNS query names.
## Usage
```sh
sudo dnstt_exporter \
-dnstt.domain tunnel.example.com \
-dnstt.port 53 \
-web.listen-address :9713
```
The exporter needs permission to open an `AF_PACKET` raw socket. Run it as root
or grant the binary `CAP_NET_RAW`.
Metrics are served at `http://127.0.0.1:9713/metrics` by default.
## Metrics
All DNSTT metrics use a `domain` label:
- `dnstt_active_clients`
- `dnstt_peak_clients`
- `dnstt_queries_total`
- `dnstt_bytes_in_total`
- `dnstt_bytes_out_total`
- `dnstt_sessions_total`
## Development
```sh
go test ./...
go build ./cmd/dnstt_exporter
```

104
cmd/dnstt_exporter/main.go Normal file
View file

@ -0,0 +1,104 @@
package main
import (
"errors"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"guardianproject.dev/bypass-censorship/dnstt_exporter/internal/dnstt"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
type domainList []string
func (d *domainList) String() string {
return strings.Join(*d, ",")
}
func (d *domainList) Set(value string) error {
value = strings.TrimSpace(value)
if value == "" {
return errors.New("domain cannot be empty")
}
*d = append(*d, value)
return nil
}
func main() {
var domains domainList
var (
listenAddress = flag.String("web.listen-address", ":9713", "Address on which to expose metrics and web interface.")
metricsPath = flag.String("web.telemetry-path", "/metrics", "Path under which to expose metrics.")
dnsPort = flag.Int("dnstt.port", 53, "UDP port where DNSTT DNS traffic is observed.")
)
flag.Var(&domains, "dnstt.domain", "DNSTT tunnel domain to observe. Repeat for multiple domains.")
flag.Parse()
if len(domains) == 0 {
log.Fatal("at least one -dnstt.domain is required")
}
if *dnsPort < 1 || *dnsPort > 65535 {
log.Fatalf("-dnstt.port must be between 1 and 65535, got %d", *dnsPort)
}
collector := dnstt.NewCollector(domains)
fd, err := dnstt.OpenRawSocket()
if err != nil {
log.Fatalf("failed to open raw socket: %v", err)
}
defer dnstt.CloseRawSocket(fd)
stop := make(chan struct{})
go dnstt.CaptureLoop(fd, *dnsPort, collector, stop)
registry := prometheus.NewRegistry()
registry.MustRegister(dnstt.NewExporter(collector))
registry.MustRegister(collectors.NewGoCollector())
registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
mux := http.NewServeMux()
mux.Handle(*metricsPath, promhttp.HandlerFor(registry, promhttp.HandlerOpts{}))
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
fmt.Fprintf(w, `<html>
<head><title>DNSTT Exporter</title></head>
<body>
<h1>DNSTT Exporter</h1>
<p><a href="%s">Metrics</a></p>
</body>
</html>
`, *metricsPath)
})
errCh := make(chan error, 1)
server := &http.Server{Addr: *listenAddress, Handler: mux}
go func() {
log.Printf("listening on %s", *listenAddress)
errCh <- server.ListenAndServe()
}()
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)
select {
case sig := <-signalCh:
log.Printf("received %s, shutting down", sig)
close(stop)
case err := <-errCh:
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("server error: %v", err)
}
}
}

View file

@ -11,9 +11,27 @@
forAllSystems = fn: nixpkgs.lib.genAttrs systems (system: fn nixpkgs.legacyPackages.${system});
in
{
#packages = forAllSystems (pkgs: {
# default = pkgs.callPackage ./package.nix { };
#});
packages = forAllSystems (pkgs: {
default = pkgs.buildGoModule {
pname = "dnstt_exporter";
version = "0.1.0";
src = ./.;
vendorHash = "sha256-+olXxVW9VcvapOuJGas70RIfJgMzUv6mndzJi6Apw+s=";
subPackages = [ "cmd/dnstt_exporter" ];
};
});
checks = forAllSystems (pkgs: {
tests = self.packages.${pkgs.stdenv.hostPlatform.system}.default.overrideAttrs (_: {
doCheck = true;
checkPhase = ''
runHook preCheck
go test ./...
runHook postCheck
'';
});
});
#checks = forAllSystems (
# pkgs:

19
go.mod Normal file
View file

@ -0,0 +1,19 @@
module guardianproject.dev/bypass-censorship/dnstt_exporter
go 1.25
require github.com/prometheus/client_golang v1.23.2
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sys v0.35.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
)

46
go.sum Normal file
View file

@ -0,0 +1,46 @@
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -0,0 +1,50 @@
//go:build linux
package dnstt
import (
"fmt"
"syscall"
)
// OpenRawSocket creates an AF_PACKET raw socket for sniffing IPv4 packets.
func OpenRawSocket() (int, error) {
fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_IP)))
if err != nil {
return -1, fmt.Errorf("open raw socket: %w", err)
}
tv := syscall.Timeval{Sec: 0, Usec: 500000}
_ = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv)
return fd, nil
}
// CaptureLoop records matching packets until stop is closed.
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
buf := make([]byte, 65535)
for {
select {
case <-stop:
return
default:
}
n, _, err := syscall.Recvfrom(fd, buf, 0)
if err != nil {
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK || err == syscall.EINTR {
continue
}
return
}
ProcessIPv4Packet(buf[:n], port, collector)
}
}
// CloseRawSocket closes a socket opened by OpenRawSocket.
func CloseRawSocket(fd int) error {
return syscall.Close(fd)
}
func htons(v uint16) uint16 {
return (v << 8) | (v >> 8)
}

View file

@ -0,0 +1,20 @@
//go:build !linux
package dnstt
import "fmt"
// OpenRawSocket is only implemented on Linux.
func OpenRawSocket() (int, error) {
return -1, fmt.Errorf("raw packet capture is only supported on Linux")
}
// CaptureLoop is only implemented on Linux.
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
<-stop
}
// CloseRawSocket is only implemented on Linux.
func CloseRawSocket(fd int) error {
return nil
}

199
internal/dnstt/collector.go Normal file
View file

@ -0,0 +1,199 @@
package dnstt
import (
"strings"
"sync"
"time"
)
// ClientTimeout is how long since the last query before a session is considered inactive.
const ClientTimeout = 30 * time.Second
// Collector accumulates observed DNSTT DNS traffic.
type Collector struct {
mu sync.Mutex
domains []string
tunnels map[string]*tunnelState
now func() time.Time
}
type tunnelState struct {
queries uint64
bytesIn uint64
bytesOut uint64
peakClients int
clients map[string]*clientState
}
type clientState struct {
firstSeen time.Time
lastSeen time.Time
queries uint64
bytesIn uint64
}
// Snapshot is a point-in-time copy of observed traffic.
type Snapshot struct {
Tunnels map[string]TunnelSnapshot
}
// TunnelSnapshot is a point-in-time copy of one DNSTT domain.
type TunnelSnapshot struct {
Domain string
ActiveClients int
PeakClients int
TotalSessions int
TotalQueries uint64
BytesIn uint64
BytesOut uint64
}
// CollectorOption configures a Collector.
type CollectorOption func(*Collector)
// WithNow overrides the clock. It is intended for deterministic tests.
func WithNow(now func() time.Time) CollectorOption {
return func(c *Collector) {
c.now = now
}
}
// NewCollector creates a collector for the provided DNSTT domains.
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
c := &Collector{
tunnels: make(map[string]*tunnelState),
now: time.Now,
}
for _, domain := range domains {
domain = normalizeDomain(domain)
if domain == "" {
continue
}
if _, exists := c.tunnels[domain]; exists {
continue
}
c.domains = append(c.domains, domain)
c.tunnels[domain] = &tunnelState{
clients: make(map[string]*clientState),
}
}
for _, opt := range opts {
opt(c)
}
return c
}
// Domains returns the normalized DNSTT domains this collector recognizes.
func (c *Collector) Domains() []string {
c.mu.Lock()
defer c.mu.Unlock()
domains := make([]string, len(c.domains))
copy(domains, c.domains)
return domains
}
// RecordQuery records an observed DNSTT DNS query.
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
c.mu.Lock()
defer c.mu.Unlock()
tunnel, ok := c.findTunnelLocked(domain)
if !ok {
return
}
tunnel.queries++
if size > 0 {
tunnel.bytesIn += uint64(size)
}
if clientID == "" {
return
}
now := c.now()
client, exists := tunnel.clients[clientID]
if !exists {
client = &clientState{firstSeen: now}
tunnel.clients[clientID] = client
}
client.lastSeen = now
client.queries++
if size > 0 {
client.bytesIn += uint64(size)
}
active := activeClients(tunnel, now)
if active > tunnel.peakClients {
tunnel.peakClients = active
}
}
// RecordResponse records an observed DNSTT DNS response.
func (c *Collector) RecordResponse(domain string, size int) {
c.mu.Lock()
defer c.mu.Unlock()
tunnel, ok := c.findTunnelLocked(domain)
if !ok {
return
}
if size > 0 {
tunnel.bytesOut += uint64(size)
}
}
// Snapshot returns a stable copy of all current metrics.
func (c *Collector) Snapshot() Snapshot {
c.mu.Lock()
defer c.mu.Unlock()
now := c.now()
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
for _, domain := range c.domains {
tunnel := c.tunnels[domain]
active := activeClients(tunnel, now)
if active > tunnel.peakClients {
tunnel.peakClients = active
}
snapshot.Tunnels[domain] = TunnelSnapshot{
Domain: domain,
ActiveClients: active,
PeakClients: tunnel.peakClients,
TotalSessions: len(tunnel.clients),
TotalQueries: tunnel.queries,
BytesIn: tunnel.bytesIn,
BytesOut: tunnel.bytesOut,
}
}
return snapshot
}
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
queryDomain = normalizeDomain(queryDomain)
if tunnel, ok := c.tunnels[queryDomain]; ok {
return tunnel, true
}
for _, domain := range c.domains {
if strings.HasSuffix(queryDomain, "."+domain) {
return c.tunnels[domain], true
}
}
return nil, false
}
func activeClients(tunnel *tunnelState, now time.Time) int {
active := 0
for _, client := range tunnel.clients {
if now.Sub(client.lastSeen) < ClientTimeout {
active++
}
}
return active
}

View file

@ -0,0 +1,61 @@
package dnstt
import (
"testing"
"time"
)
func TestCollectorTracksActiveAndPeakClients(t *testing.T) {
now := time.Unix(1000, 0)
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
c.RecordQuery("tunnel.example.com", "client-a", 120)
c.RecordQuery("tunnel.example.com", "client-b", 80)
c.RecordResponse("tunnel.example.com", 200)
snapshot := c.Snapshot()
tunnel := snapshot.Tunnels["tunnel.example.com"]
if tunnel.ActiveClients != 2 {
t.Fatalf("active clients = %d, want 2", tunnel.ActiveClients)
}
if tunnel.PeakClients != 2 {
t.Fatalf("peak clients = %d, want 2", tunnel.PeakClients)
}
if tunnel.TotalSessions != 2 {
t.Fatalf("total sessions = %d, want 2", tunnel.TotalSessions)
}
if tunnel.TotalQueries != 2 {
t.Fatalf("queries = %d, want 2", tunnel.TotalQueries)
}
if tunnel.BytesIn != 200 {
t.Fatalf("bytes in = %d, want 200", tunnel.BytesIn)
}
if tunnel.BytesOut != 200 {
t.Fatalf("bytes out = %d, want 200", tunnel.BytesOut)
}
now = now.Add(ClientTimeout + time.Second)
snapshot = c.Snapshot()
tunnel = snapshot.Tunnels["tunnel.example.com"]
if tunnel.ActiveClients != 0 {
t.Fatalf("active clients after timeout = %d, want 0", tunnel.ActiveClients)
}
if tunnel.PeakClients != 2 {
t.Fatalf("peak clients after timeout = %d, want 2", tunnel.PeakClients)
}
}
func TestCollectorMatchesSubdomainsToRegisteredTunnel(t *testing.T) {
now := time.Unix(1000, 0)
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
c.RecordQuery("abcd.tunnel.example.com", "client-a", 120)
tunnel := c.Snapshot().Tunnels["tunnel.example.com"]
if tunnel.TotalQueries != 1 {
t.Fatalf("queries = %d, want 1", tunnel.TotalQueries)
}
if tunnel.ActiveClients != 1 {
t.Fatalf("active clients = %d, want 1", tunnel.ActiveClients)
}
}

View file

@ -0,0 +1,83 @@
package dnstt
import "github.com/prometheus/client_golang/prometheus"
const namespace = "dnstt"
// Exporter exposes aggregate DNSTT traffic metrics from a Collector.
type Exporter struct {
collector *Collector
activeClients *prometheus.Desc
peakClients *prometheus.Desc
queries *prometheus.Desc
bytesIn *prometheus.Desc
bytesOut *prometheus.Desc
sessions *prometheus.Desc
}
// NewExporter creates a Prometheus collector for DNSTT metrics.
func NewExporter(collector *Collector) *Exporter {
labels := []string{"domain"}
return &Exporter{
collector: collector,
activeClients: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "active_clients"),
"Number of DNSTT client sessions observed within the active timeout window.",
labels,
nil,
),
peakClients: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "peak_clients"),
"Maximum concurrent active DNSTT client sessions observed.",
labels,
nil,
),
queries: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "queries_total"),
"Total DNSTT DNS queries observed.",
labels,
nil,
),
bytesIn: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "bytes_in_total"),
"Total bytes observed in DNSTT DNS queries.",
labels,
nil,
),
bytesOut: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "bytes_out_total"),
"Total bytes observed in DNSTT DNS responses.",
labels,
nil,
),
sessions: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "sessions_total"),
"Total unique DNSTT client sessions observed.",
labels,
nil,
),
}
}
// Describe sends metric descriptors to Prometheus.
func (e *Exporter) Describe(ch chan<- *prometheus.Desc) {
ch <- e.activeClients
ch <- e.peakClients
ch <- e.queries
ch <- e.bytesIn
ch <- e.bytesOut
ch <- e.sessions
}
// Collect sends current metric values to Prometheus.
func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
for _, tunnel := range e.collector.Snapshot().Tunnels {
ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(tunnel.ActiveClients), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(tunnel.PeakClients), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(tunnel.TotalQueries), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(tunnel.BytesIn), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(tunnel.BytesOut), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(tunnel.TotalSessions), tunnel.Domain)
}
}

View file

@ -0,0 +1,42 @@
package dnstt
import (
"strings"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus/testutil"
)
func TestExporterCollectsAggregateDNSTTMetrics(t *testing.T) {
now := time.Unix(1000, 0)
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
c.RecordQuery("tunnel.example.com", "client-a", 100)
c.RecordQuery("tunnel.example.com", "client-b", 300)
c.RecordResponse("tunnel.example.com", 250)
expected := `
# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window.
# TYPE dnstt_active_clients gauge
dnstt_active_clients{domain="tunnel.example.com"} 2
# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries.
# TYPE dnstt_bytes_in_total counter
dnstt_bytes_in_total{domain="tunnel.example.com"} 400
# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses.
# TYPE dnstt_bytes_out_total counter
dnstt_bytes_out_total{domain="tunnel.example.com"} 250
# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed.
# TYPE dnstt_peak_clients gauge
dnstt_peak_clients{domain="tunnel.example.com"} 2
# HELP dnstt_queries_total Total DNSTT DNS queries observed.
# TYPE dnstt_queries_total counter
dnstt_queries_total{domain="tunnel.example.com"} 2
# HELP dnstt_sessions_total Total unique DNSTT client sessions observed.
# TYPE dnstt_sessions_total counter
dnstt_sessions_total{domain="tunnel.example.com"} 2
`
if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil {
t.Fatal(err)
}
}

60
internal/dnstt/packet.go Normal file
View file

@ -0,0 +1,60 @@
package dnstt
import (
"encoding/binary"
"strings"
)
// ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet.
func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
if len(data) < 20 || data[0]>>4 != 4 {
return
}
ihl := int(data[0]&0x0f) * 4
if ihl < 20 || len(data) < ihl {
return
}
if data[9] != 17 {
return
}
udpData := data[ihl:]
if len(udpData) < 8 {
return
}
srcPort := binary.BigEndian.Uint16(udpData[0:2])
dstPort := binary.BigEndian.Uint16(udpData[2:4])
udpLen := int(binary.BigEndian.Uint16(udpData[4:6]))
dnsPayload := udpData[8:]
if len(dnsPayload) < 12 {
return
}
if int(dstPort) == port {
domain, clientID := ExtractQuery(dnsPayload, collector.Domains())
if domain != "" {
collector.RecordQuery(domain, clientID, udpLen)
}
return
}
if int(srcPort) == port {
domain := extractQueryDomain(dnsPayload)
if domain != "" {
collector.RecordResponse(domain, udpLen)
}
}
}
func extractQueryDomain(dns []byte) string {
if len(dns) < 12 {
return ""
}
labels := parseDNSLabels(dns[12:])
if len(labels) == 0 {
return ""
}
return normalizeDomain(strings.Join(labels, "."))
}

82
internal/dnstt/parser.go Normal file
View file

@ -0,0 +1,82 @@
package dnstt
import (
"encoding/base32"
"fmt"
"strings"
)
const clientIDLen = 8
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query.
func ExtractQuery(dns []byte, domains []string) (string, string) {
if len(dns) < 12 {
return "", ""
}
labels := parseDNSLabels(dns[12:])
if len(labels) == 0 {
return "", ""
}
fullDomain := strings.ToLower(strings.Join(labels, "."))
for _, domain := range domains {
domain = normalizeDomain(domain)
if fullDomain == domain {
return domain, ""
}
suffix := "." + domain
if !strings.HasSuffix(fullDomain, suffix) {
continue
}
tunnelLabels := strings.Count(domain, ".") + 1
if len(labels) <= tunnelLabels {
continue
}
prefixLabels := labels[:len(labels)-tunnelLabels]
encoded := strings.ToUpper(strings.Join(prefixLabels, ""))
decoded := make([]byte, base32Encoding.DecodedLen(len(encoded)))
n, err := base32Encoding.Decode(decoded, []byte(encoded))
if err != nil || n < clientIDLen {
return domain, ""
}
return domain, fmt.Sprintf("%x", decoded[:clientIDLen])
}
return "", ""
}
func parseDNSLabels(data []byte) []string {
var labels []string
offset := 0
for offset < len(data) {
labelLen := int(data[offset])
if labelLen == 0 {
break
}
if labelLen&0xc0 == 0xc0 {
break
}
offset++
if offset+labelLen > len(data) {
return nil
}
labels = append(labels, string(data[offset:offset+labelLen]))
offset += labelLen
}
return labels
}
func normalizeDomain(domain string) string {
return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
}

View file

@ -0,0 +1,93 @@
package dnstt
import (
"encoding/base32"
"testing"
)
func TestExtractQueryDecodesDNSTTClientID(t *testing.T) {
clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb))
packet := dnsQuery(encoded, "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "0102030405060708" {
t.Fatalf("client ID = %q, want 0102030405060708", gotClientID)
}
}
func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) {
packet := dnsQuery("", "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "" {
t.Fatalf("client ID = %q, want empty", gotClientID)
}
}
func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) {
packet := dnsQuery("abcd", "other.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "" || gotClientID != "" {
t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID)
}
}
func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) {
packet := dnsQuery("not-base32-!", "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "" {
t.Fatalf("client ID = %q, want empty", gotClientID)
}
}
func dnsQuery(prefix string, domain string) []byte {
packet := []byte{
0x12, 0x34, // transaction ID
0x01, 0x00, // flags
0x00, 0x01, // questions
0x00, 0x00, // answers
0x00, 0x00, // authority
0x00, 0x00, // additional
}
labels := splitDomain(domain)
if prefix != "" {
labels = append([]string{prefix}, labels...)
}
for _, label := range labels {
packet = append(packet, byte(len(label)))
packet = append(packet, label...)
}
packet = append(packet, 0x00) // root label
packet = append(packet, 0x00, 0x10) // TXT
packet = append(packet, 0x00, 0x01) // IN
return packet
}
func splitDomain(domain string) []string {
var labels []string
start := 0
for i := 0; i <= len(domain); i++ {
if i == len(domain) || domain[i] == '.' {
labels = append(labels, domain[start:i])
start = i + 1
}
}
return labels
}