initial working version

This commit is contained in:
Abel Luck 2026-05-05 13:43:02 +02:00
parent 4d8b83cbb6
commit 8318f9fe70
15 changed files with 917 additions and 4 deletions

View file

@ -0,0 +1,50 @@
//go:build linux
package dnstt
import (
"fmt"
"syscall"
)
// OpenRawSocket creates an AF_PACKET raw socket for sniffing IPv4 packets.
func OpenRawSocket() (int, error) {
fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_IP)))
if err != nil {
return -1, fmt.Errorf("open raw socket: %w", err)
}
tv := syscall.Timeval{Sec: 0, Usec: 500000}
_ = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv)
return fd, nil
}
// CaptureLoop records matching packets until stop is closed.
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
buf := make([]byte, 65535)
for {
select {
case <-stop:
return
default:
}
n, _, err := syscall.Recvfrom(fd, buf, 0)
if err != nil {
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK || err == syscall.EINTR {
continue
}
return
}
ProcessIPv4Packet(buf[:n], port, collector)
}
}
// CloseRawSocket closes a socket opened by OpenRawSocket.
func CloseRawSocket(fd int) error {
return syscall.Close(fd)
}
func htons(v uint16) uint16 {
return (v << 8) | (v >> 8)
}

View file

@ -0,0 +1,20 @@
//go:build !linux
package dnstt
import "fmt"
// OpenRawSocket is only implemented on Linux.
func OpenRawSocket() (int, error) {
return -1, fmt.Errorf("raw packet capture is only supported on Linux")
}
// CaptureLoop is only implemented on Linux.
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
<-stop
}
// CloseRawSocket is only implemented on Linux.
func CloseRawSocket(fd int) error {
return nil
}

199
internal/dnstt/collector.go Normal file
View file

@ -0,0 +1,199 @@
package dnstt
import (
"strings"
"sync"
"time"
)
// ClientTimeout is how long since the last query before a session is considered inactive.
const ClientTimeout = 30 * time.Second
// Collector accumulates observed DNSTT DNS traffic.
type Collector struct {
mu sync.Mutex
domains []string
tunnels map[string]*tunnelState
now func() time.Time
}
type tunnelState struct {
queries uint64
bytesIn uint64
bytesOut uint64
peakClients int
clients map[string]*clientState
}
type clientState struct {
firstSeen time.Time
lastSeen time.Time
queries uint64
bytesIn uint64
}
// Snapshot is a point-in-time copy of observed traffic.
type Snapshot struct {
Tunnels map[string]TunnelSnapshot
}
// TunnelSnapshot is a point-in-time copy of one DNSTT domain.
type TunnelSnapshot struct {
Domain string
ActiveClients int
PeakClients int
TotalSessions int
TotalQueries uint64
BytesIn uint64
BytesOut uint64
}
// CollectorOption configures a Collector.
type CollectorOption func(*Collector)
// WithNow overrides the clock. It is intended for deterministic tests.
func WithNow(now func() time.Time) CollectorOption {
return func(c *Collector) {
c.now = now
}
}
// NewCollector creates a collector for the provided DNSTT domains.
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
c := &Collector{
tunnels: make(map[string]*tunnelState),
now: time.Now,
}
for _, domain := range domains {
domain = normalizeDomain(domain)
if domain == "" {
continue
}
if _, exists := c.tunnels[domain]; exists {
continue
}
c.domains = append(c.domains, domain)
c.tunnels[domain] = &tunnelState{
clients: make(map[string]*clientState),
}
}
for _, opt := range opts {
opt(c)
}
return c
}
// Domains returns the normalized DNSTT domains this collector recognizes.
func (c *Collector) Domains() []string {
c.mu.Lock()
defer c.mu.Unlock()
domains := make([]string, len(c.domains))
copy(domains, c.domains)
return domains
}
// RecordQuery records an observed DNSTT DNS query.
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
c.mu.Lock()
defer c.mu.Unlock()
tunnel, ok := c.findTunnelLocked(domain)
if !ok {
return
}
tunnel.queries++
if size > 0 {
tunnel.bytesIn += uint64(size)
}
if clientID == "" {
return
}
now := c.now()
client, exists := tunnel.clients[clientID]
if !exists {
client = &clientState{firstSeen: now}
tunnel.clients[clientID] = client
}
client.lastSeen = now
client.queries++
if size > 0 {
client.bytesIn += uint64(size)
}
active := activeClients(tunnel, now)
if active > tunnel.peakClients {
tunnel.peakClients = active
}
}
// RecordResponse records an observed DNSTT DNS response.
func (c *Collector) RecordResponse(domain string, size int) {
c.mu.Lock()
defer c.mu.Unlock()
tunnel, ok := c.findTunnelLocked(domain)
if !ok {
return
}
if size > 0 {
tunnel.bytesOut += uint64(size)
}
}
// Snapshot returns a stable copy of all current metrics.
func (c *Collector) Snapshot() Snapshot {
c.mu.Lock()
defer c.mu.Unlock()
now := c.now()
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
for _, domain := range c.domains {
tunnel := c.tunnels[domain]
active := activeClients(tunnel, now)
if active > tunnel.peakClients {
tunnel.peakClients = active
}
snapshot.Tunnels[domain] = TunnelSnapshot{
Domain: domain,
ActiveClients: active,
PeakClients: tunnel.peakClients,
TotalSessions: len(tunnel.clients),
TotalQueries: tunnel.queries,
BytesIn: tunnel.bytesIn,
BytesOut: tunnel.bytesOut,
}
}
return snapshot
}
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
queryDomain = normalizeDomain(queryDomain)
if tunnel, ok := c.tunnels[queryDomain]; ok {
return tunnel, true
}
for _, domain := range c.domains {
if strings.HasSuffix(queryDomain, "."+domain) {
return c.tunnels[domain], true
}
}
return nil, false
}
func activeClients(tunnel *tunnelState, now time.Time) int {
active := 0
for _, client := range tunnel.clients {
if now.Sub(client.lastSeen) < ClientTimeout {
active++
}
}
return active
}

View file

@ -0,0 +1,61 @@
package dnstt
import (
"testing"
"time"
)
func TestCollectorTracksActiveAndPeakClients(t *testing.T) {
now := time.Unix(1000, 0)
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
c.RecordQuery("tunnel.example.com", "client-a", 120)
c.RecordQuery("tunnel.example.com", "client-b", 80)
c.RecordResponse("tunnel.example.com", 200)
snapshot := c.Snapshot()
tunnel := snapshot.Tunnels["tunnel.example.com"]
if tunnel.ActiveClients != 2 {
t.Fatalf("active clients = %d, want 2", tunnel.ActiveClients)
}
if tunnel.PeakClients != 2 {
t.Fatalf("peak clients = %d, want 2", tunnel.PeakClients)
}
if tunnel.TotalSessions != 2 {
t.Fatalf("total sessions = %d, want 2", tunnel.TotalSessions)
}
if tunnel.TotalQueries != 2 {
t.Fatalf("queries = %d, want 2", tunnel.TotalQueries)
}
if tunnel.BytesIn != 200 {
t.Fatalf("bytes in = %d, want 200", tunnel.BytesIn)
}
if tunnel.BytesOut != 200 {
t.Fatalf("bytes out = %d, want 200", tunnel.BytesOut)
}
now = now.Add(ClientTimeout + time.Second)
snapshot = c.Snapshot()
tunnel = snapshot.Tunnels["tunnel.example.com"]
if tunnel.ActiveClients != 0 {
t.Fatalf("active clients after timeout = %d, want 0", tunnel.ActiveClients)
}
if tunnel.PeakClients != 2 {
t.Fatalf("peak clients after timeout = %d, want 2", tunnel.PeakClients)
}
}
func TestCollectorMatchesSubdomainsToRegisteredTunnel(t *testing.T) {
now := time.Unix(1000, 0)
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
c.RecordQuery("abcd.tunnel.example.com", "client-a", 120)
tunnel := c.Snapshot().Tunnels["tunnel.example.com"]
if tunnel.TotalQueries != 1 {
t.Fatalf("queries = %d, want 1", tunnel.TotalQueries)
}
if tunnel.ActiveClients != 1 {
t.Fatalf("active clients = %d, want 1", tunnel.ActiveClients)
}
}

View file

@ -0,0 +1,83 @@
package dnstt
import "github.com/prometheus/client_golang/prometheus"
const namespace = "dnstt"
// Exporter exposes aggregate DNSTT traffic metrics from a Collector.
type Exporter struct {
collector *Collector
activeClients *prometheus.Desc
peakClients *prometheus.Desc
queries *prometheus.Desc
bytesIn *prometheus.Desc
bytesOut *prometheus.Desc
sessions *prometheus.Desc
}
// NewExporter creates a Prometheus collector for DNSTT metrics.
func NewExporter(collector *Collector) *Exporter {
labels := []string{"domain"}
return &Exporter{
collector: collector,
activeClients: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "active_clients"),
"Number of DNSTT client sessions observed within the active timeout window.",
labels,
nil,
),
peakClients: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "peak_clients"),
"Maximum concurrent active DNSTT client sessions observed.",
labels,
nil,
),
queries: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "queries_total"),
"Total DNSTT DNS queries observed.",
labels,
nil,
),
bytesIn: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "bytes_in_total"),
"Total bytes observed in DNSTT DNS queries.",
labels,
nil,
),
bytesOut: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "bytes_out_total"),
"Total bytes observed in DNSTT DNS responses.",
labels,
nil,
),
sessions: prometheus.NewDesc(
prometheus.BuildFQName(namespace, "", "sessions_total"),
"Total unique DNSTT client sessions observed.",
labels,
nil,
),
}
}
// Describe sends metric descriptors to Prometheus.
func (e *Exporter) Describe(ch chan<- *prometheus.Desc) {
ch <- e.activeClients
ch <- e.peakClients
ch <- e.queries
ch <- e.bytesIn
ch <- e.bytesOut
ch <- e.sessions
}
// Collect sends current metric values to Prometheus.
func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
for _, tunnel := range e.collector.Snapshot().Tunnels {
ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(tunnel.ActiveClients), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(tunnel.PeakClients), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(tunnel.TotalQueries), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(tunnel.BytesIn), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(tunnel.BytesOut), tunnel.Domain)
ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(tunnel.TotalSessions), tunnel.Domain)
}
}

View file

@ -0,0 +1,42 @@
package dnstt
import (
"strings"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus/testutil"
)
func TestExporterCollectsAggregateDNSTTMetrics(t *testing.T) {
now := time.Unix(1000, 0)
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
c.RecordQuery("tunnel.example.com", "client-a", 100)
c.RecordQuery("tunnel.example.com", "client-b", 300)
c.RecordResponse("tunnel.example.com", 250)
expected := `
# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window.
# TYPE dnstt_active_clients gauge
dnstt_active_clients{domain="tunnel.example.com"} 2
# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries.
# TYPE dnstt_bytes_in_total counter
dnstt_bytes_in_total{domain="tunnel.example.com"} 400
# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses.
# TYPE dnstt_bytes_out_total counter
dnstt_bytes_out_total{domain="tunnel.example.com"} 250
# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed.
# TYPE dnstt_peak_clients gauge
dnstt_peak_clients{domain="tunnel.example.com"} 2
# HELP dnstt_queries_total Total DNSTT DNS queries observed.
# TYPE dnstt_queries_total counter
dnstt_queries_total{domain="tunnel.example.com"} 2
# HELP dnstt_sessions_total Total unique DNSTT client sessions observed.
# TYPE dnstt_sessions_total counter
dnstt_sessions_total{domain="tunnel.example.com"} 2
`
if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil {
t.Fatal(err)
}
}

60
internal/dnstt/packet.go Normal file
View file

@ -0,0 +1,60 @@
package dnstt
import (
"encoding/binary"
"strings"
)
// ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet.
func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
if len(data) < 20 || data[0]>>4 != 4 {
return
}
ihl := int(data[0]&0x0f) * 4
if ihl < 20 || len(data) < ihl {
return
}
if data[9] != 17 {
return
}
udpData := data[ihl:]
if len(udpData) < 8 {
return
}
srcPort := binary.BigEndian.Uint16(udpData[0:2])
dstPort := binary.BigEndian.Uint16(udpData[2:4])
udpLen := int(binary.BigEndian.Uint16(udpData[4:6]))
dnsPayload := udpData[8:]
if len(dnsPayload) < 12 {
return
}
if int(dstPort) == port {
domain, clientID := ExtractQuery(dnsPayload, collector.Domains())
if domain != "" {
collector.RecordQuery(domain, clientID, udpLen)
}
return
}
if int(srcPort) == port {
domain := extractQueryDomain(dnsPayload)
if domain != "" {
collector.RecordResponse(domain, udpLen)
}
}
}
func extractQueryDomain(dns []byte) string {
if len(dns) < 12 {
return ""
}
labels := parseDNSLabels(dns[12:])
if len(labels) == 0 {
return ""
}
return normalizeDomain(strings.Join(labels, "."))
}

82
internal/dnstt/parser.go Normal file
View file

@ -0,0 +1,82 @@
package dnstt
import (
"encoding/base32"
"fmt"
"strings"
)
const clientIDLen = 8
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query.
func ExtractQuery(dns []byte, domains []string) (string, string) {
if len(dns) < 12 {
return "", ""
}
labels := parseDNSLabels(dns[12:])
if len(labels) == 0 {
return "", ""
}
fullDomain := strings.ToLower(strings.Join(labels, "."))
for _, domain := range domains {
domain = normalizeDomain(domain)
if fullDomain == domain {
return domain, ""
}
suffix := "." + domain
if !strings.HasSuffix(fullDomain, suffix) {
continue
}
tunnelLabels := strings.Count(domain, ".") + 1
if len(labels) <= tunnelLabels {
continue
}
prefixLabels := labels[:len(labels)-tunnelLabels]
encoded := strings.ToUpper(strings.Join(prefixLabels, ""))
decoded := make([]byte, base32Encoding.DecodedLen(len(encoded)))
n, err := base32Encoding.Decode(decoded, []byte(encoded))
if err != nil || n < clientIDLen {
return domain, ""
}
return domain, fmt.Sprintf("%x", decoded[:clientIDLen])
}
return "", ""
}
func parseDNSLabels(data []byte) []string {
var labels []string
offset := 0
for offset < len(data) {
labelLen := int(data[offset])
if labelLen == 0 {
break
}
if labelLen&0xc0 == 0xc0 {
break
}
offset++
if offset+labelLen > len(data) {
return nil
}
labels = append(labels, string(data[offset:offset+labelLen]))
offset += labelLen
}
return labels
}
func normalizeDomain(domain string) string {
return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
}

View file

@ -0,0 +1,93 @@
package dnstt
import (
"encoding/base32"
"testing"
)
func TestExtractQueryDecodesDNSTTClientID(t *testing.T) {
clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb))
packet := dnsQuery(encoded, "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "0102030405060708" {
t.Fatalf("client ID = %q, want 0102030405060708", gotClientID)
}
}
func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) {
packet := dnsQuery("", "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "" {
t.Fatalf("client ID = %q, want empty", gotClientID)
}
}
func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) {
packet := dnsQuery("abcd", "other.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "" || gotClientID != "" {
t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID)
}
}
func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) {
packet := dnsQuery("not-base32-!", "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "" {
t.Fatalf("client ID = %q, want empty", gotClientID)
}
}
func dnsQuery(prefix string, domain string) []byte {
packet := []byte{
0x12, 0x34, // transaction ID
0x01, 0x00, // flags
0x00, 0x01, // questions
0x00, 0x00, // answers
0x00, 0x00, // authority
0x00, 0x00, // additional
}
labels := splitDomain(domain)
if prefix != "" {
labels = append([]string{prefix}, labels...)
}
for _, label := range labels {
packet = append(packet, byte(len(label)))
packet = append(packet, label...)
}
packet = append(packet, 0x00) // root label
packet = append(packet, 0x00, 0x10) // TXT
packet = append(packet, 0x00, 0x01) // IN
return packet
}
func splitDomain(domain string) []string {
var labels []string
start := 0
for i := 0; i <= len(domain); i++ {
if i == len(domain) || domain[i] == '.' {
labels = append(labels, domain[start:i])
start = i + 1
}
}
return labels
}