diff --git a/.gitignore b/.gitignore index 8de5610..2a79a47 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ .direnv result - +/dnstt_exporter diff --git a/README.md b/README.md index 1a19166..b8cb1cd 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,39 @@ # dnstt_exporter Prometheus exporter for DNSTT client/session metrics. + +`dnstt_exporter` observes DNSTT DNS traffic on a local Linux host and exports +aggregate Prometheus metrics. It does not proxy, terminate, or configure DNSTT; +it passively decodes DNSTT session IDs from DNS query names. + +## Usage + +```sh +sudo dnstt_exporter \ + -dnstt.domain tunnel.example.com \ + -dnstt.port 53 \ + -web.listen-address :9713 +``` + +The exporter needs permission to open an `AF_PACKET` raw socket. Run it as root +or grant the binary `CAP_NET_RAW`. + +Metrics are served at `http://127.0.0.1:9713/metrics` by default. + +## Metrics + +All DNSTT metrics use a `domain` label: + +- `dnstt_active_clients` +- `dnstt_peak_clients` +- `dnstt_queries_total` +- `dnstt_bytes_in_total` +- `dnstt_bytes_out_total` +- `dnstt_sessions_total` + +## Development + +```sh +go test ./... +go build ./cmd/dnstt_exporter +``` diff --git a/cmd/dnstt_exporter/main.go b/cmd/dnstt_exporter/main.go new file mode 100644 index 0000000..88e0002 --- /dev/null +++ b/cmd/dnstt_exporter/main.go @@ -0,0 +1,104 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + + "guardianproject.dev/bypass-censorship/dnstt_exporter/internal/dnstt" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +type domainList []string + +func (d *domainList) String() string { + return strings.Join(*d, ",") +} + +func (d *domainList) Set(value string) error { + value = strings.TrimSpace(value) + if value == "" { + return errors.New("domain cannot be empty") + } + *d = append(*d, value) + return nil +} + +func main() { + var domains domainList + var ( + listenAddress = flag.String("web.listen-address", ":9713", "Address on which to expose metrics and web interface.") + metricsPath = flag.String("web.telemetry-path", "/metrics", "Path under which to expose metrics.") + dnsPort = flag.Int("dnstt.port", 53, "UDP port where DNSTT DNS traffic is observed.") + ) + flag.Var(&domains, "dnstt.domain", "DNSTT tunnel domain to observe. Repeat for multiple domains.") + flag.Parse() + + if len(domains) == 0 { + log.Fatal("at least one -dnstt.domain is required") + } + if *dnsPort < 1 || *dnsPort > 65535 { + log.Fatalf("-dnstt.port must be between 1 and 65535, got %d", *dnsPort) + } + + collector := dnstt.NewCollector(domains) + fd, err := dnstt.OpenRawSocket() + if err != nil { + log.Fatalf("failed to open raw socket: %v", err) + } + defer dnstt.CloseRawSocket(fd) + + stop := make(chan struct{}) + go dnstt.CaptureLoop(fd, *dnsPort, collector, stop) + + registry := prometheus.NewRegistry() + registry.MustRegister(dnstt.NewExporter(collector)) + registry.MustRegister(collectors.NewGoCollector()) + registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + + mux := http.NewServeMux() + mux.Handle(*metricsPath, promhttp.HandlerFor(registry, promhttp.HandlerOpts{})) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + fmt.Fprintf(w, ` +DNSTT Exporter + +

DNSTT Exporter

+

Metrics

+ + +`, *metricsPath) + }) + + errCh := make(chan error, 1) + server := &http.Server{Addr: *listenAddress, Handler: mux} + go func() { + log.Printf("listening on %s", *listenAddress) + errCh <- server.ListenAndServe() + }() + + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) + + select { + case sig := <-signalCh: + log.Printf("received %s, shutting down", sig) + close(stop) + case err := <-errCh: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("server error: %v", err) + } + } +} diff --git a/flake.nix b/flake.nix index fd6dbf5..d370277 100644 --- a/flake.nix +++ b/flake.nix @@ -11,9 +11,27 @@ forAllSystems = fn: nixpkgs.lib.genAttrs systems (system: fn nixpkgs.legacyPackages.${system}); in { - #packages = forAllSystems (pkgs: { - # default = pkgs.callPackage ./package.nix { }; - #}); + packages = forAllSystems (pkgs: { + default = pkgs.buildGoModule { + pname = "dnstt_exporter"; + version = "0.1.0"; + src = ./.; + vendorHash = "sha256-+olXxVW9VcvapOuJGas70RIfJgMzUv6mndzJi6Apw+s="; + + subPackages = [ "cmd/dnstt_exporter" ]; + }; + }); + + checks = forAllSystems (pkgs: { + tests = self.packages.${pkgs.stdenv.hostPlatform.system}.default.overrideAttrs (_: { + doCheck = true; + checkPhase = '' + runHook preCheck + go test ./... + runHook postCheck + ''; + }); + }); #checks = forAllSystems ( # pkgs: diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..4872657 --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module guardianproject.dev/bypass-censorship/dnstt_exporter + +go 1.25 + +require github.com/prometheus/client_golang v1.23.2 + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/sys v0.35.0 // indirect + google.golang.org/protobuf v1.36.8 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d6b8ca9 --- /dev/null +++ b/go.sum @@ -0,0 +1,46 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/dnstt/capture_linux.go b/internal/dnstt/capture_linux.go new file mode 100644 index 0000000..b416033 --- /dev/null +++ b/internal/dnstt/capture_linux.go @@ -0,0 +1,50 @@ +//go:build linux + +package dnstt + +import ( + "fmt" + "syscall" +) + +// OpenRawSocket creates an AF_PACKET raw socket for sniffing IPv4 packets. +func OpenRawSocket() (int, error) { + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_IP))) + if err != nil { + return -1, fmt.Errorf("open raw socket: %w", err) + } + + tv := syscall.Timeval{Sec: 0, Usec: 500000} + _ = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv) + return fd, nil +} + +// CaptureLoop records matching packets until stop is closed. +func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) { + buf := make([]byte, 65535) + for { + select { + case <-stop: + return + default: + } + + n, _, err := syscall.Recvfrom(fd, buf, 0) + if err != nil { + if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK || err == syscall.EINTR { + continue + } + return + } + ProcessIPv4Packet(buf[:n], port, collector) + } +} + +// CloseRawSocket closes a socket opened by OpenRawSocket. +func CloseRawSocket(fd int) error { + return syscall.Close(fd) +} + +func htons(v uint16) uint16 { + return (v << 8) | (v >> 8) +} diff --git a/internal/dnstt/capture_unsupported.go b/internal/dnstt/capture_unsupported.go new file mode 100644 index 0000000..3eb64ce --- /dev/null +++ b/internal/dnstt/capture_unsupported.go @@ -0,0 +1,20 @@ +//go:build !linux + +package dnstt + +import "fmt" + +// OpenRawSocket is only implemented on Linux. +func OpenRawSocket() (int, error) { + return -1, fmt.Errorf("raw packet capture is only supported on Linux") +} + +// CaptureLoop is only implemented on Linux. +func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) { + <-stop +} + +// CloseRawSocket is only implemented on Linux. +func CloseRawSocket(fd int) error { + return nil +} diff --git a/internal/dnstt/collector.go b/internal/dnstt/collector.go new file mode 100644 index 0000000..b6207dc --- /dev/null +++ b/internal/dnstt/collector.go @@ -0,0 +1,199 @@ +package dnstt + +import ( + "strings" + "sync" + "time" +) + +// ClientTimeout is how long since the last query before a session is considered inactive. +const ClientTimeout = 30 * time.Second + +// Collector accumulates observed DNSTT DNS traffic. +type Collector struct { + mu sync.Mutex + domains []string + tunnels map[string]*tunnelState + now func() time.Time +} + +type tunnelState struct { + queries uint64 + bytesIn uint64 + bytesOut uint64 + peakClients int + clients map[string]*clientState +} + +type clientState struct { + firstSeen time.Time + lastSeen time.Time + queries uint64 + bytesIn uint64 +} + +// Snapshot is a point-in-time copy of observed traffic. +type Snapshot struct { + Tunnels map[string]TunnelSnapshot +} + +// TunnelSnapshot is a point-in-time copy of one DNSTT domain. +type TunnelSnapshot struct { + Domain string + ActiveClients int + PeakClients int + TotalSessions int + TotalQueries uint64 + BytesIn uint64 + BytesOut uint64 +} + +// CollectorOption configures a Collector. +type CollectorOption func(*Collector) + +// WithNow overrides the clock. It is intended for deterministic tests. +func WithNow(now func() time.Time) CollectorOption { + return func(c *Collector) { + c.now = now + } +} + +// NewCollector creates a collector for the provided DNSTT domains. +func NewCollector(domains []string, opts ...CollectorOption) *Collector { + c := &Collector{ + tunnels: make(map[string]*tunnelState), + now: time.Now, + } + + for _, domain := range domains { + domain = normalizeDomain(domain) + if domain == "" { + continue + } + if _, exists := c.tunnels[domain]; exists { + continue + } + c.domains = append(c.domains, domain) + c.tunnels[domain] = &tunnelState{ + clients: make(map[string]*clientState), + } + } + + for _, opt := range opts { + opt(c) + } + + return c +} + +// Domains returns the normalized DNSTT domains this collector recognizes. +func (c *Collector) Domains() []string { + c.mu.Lock() + defer c.mu.Unlock() + + domains := make([]string, len(c.domains)) + copy(domains, c.domains) + return domains +} + +// RecordQuery records an observed DNSTT DNS query. +func (c *Collector) RecordQuery(domain string, clientID string, size int) { + c.mu.Lock() + defer c.mu.Unlock() + + tunnel, ok := c.findTunnelLocked(domain) + if !ok { + return + } + + tunnel.queries++ + if size > 0 { + tunnel.bytesIn += uint64(size) + } + + if clientID == "" { + return + } + + now := c.now() + client, exists := tunnel.clients[clientID] + if !exists { + client = &clientState{firstSeen: now} + tunnel.clients[clientID] = client + } + client.lastSeen = now + client.queries++ + if size > 0 { + client.bytesIn += uint64(size) + } + + active := activeClients(tunnel, now) + if active > tunnel.peakClients { + tunnel.peakClients = active + } +} + +// RecordResponse records an observed DNSTT DNS response. +func (c *Collector) RecordResponse(domain string, size int) { + c.mu.Lock() + defer c.mu.Unlock() + + tunnel, ok := c.findTunnelLocked(domain) + if !ok { + return + } + if size > 0 { + tunnel.bytesOut += uint64(size) + } +} + +// Snapshot returns a stable copy of all current metrics. +func (c *Collector) Snapshot() Snapshot { + c.mu.Lock() + defer c.mu.Unlock() + + now := c.now() + snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))} + for _, domain := range c.domains { + tunnel := c.tunnels[domain] + active := activeClients(tunnel, now) + if active > tunnel.peakClients { + tunnel.peakClients = active + } + + snapshot.Tunnels[domain] = TunnelSnapshot{ + Domain: domain, + ActiveClients: active, + PeakClients: tunnel.peakClients, + TotalSessions: len(tunnel.clients), + TotalQueries: tunnel.queries, + BytesIn: tunnel.bytesIn, + BytesOut: tunnel.bytesOut, + } + } + + return snapshot +} + +func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) { + queryDomain = normalizeDomain(queryDomain) + if tunnel, ok := c.tunnels[queryDomain]; ok { + return tunnel, true + } + for _, domain := range c.domains { + if strings.HasSuffix(queryDomain, "."+domain) { + return c.tunnels[domain], true + } + } + return nil, false +} + +func activeClients(tunnel *tunnelState, now time.Time) int { + active := 0 + for _, client := range tunnel.clients { + if now.Sub(client.lastSeen) < ClientTimeout { + active++ + } + } + return active +} diff --git a/internal/dnstt/collector_test.go b/internal/dnstt/collector_test.go new file mode 100644 index 0000000..d46c923 --- /dev/null +++ b/internal/dnstt/collector_test.go @@ -0,0 +1,61 @@ +package dnstt + +import ( + "testing" + "time" +) + +func TestCollectorTracksActiveAndPeakClients(t *testing.T) { + now := time.Unix(1000, 0) + c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now })) + + c.RecordQuery("tunnel.example.com", "client-a", 120) + c.RecordQuery("tunnel.example.com", "client-b", 80) + c.RecordResponse("tunnel.example.com", 200) + + snapshot := c.Snapshot() + tunnel := snapshot.Tunnels["tunnel.example.com"] + if tunnel.ActiveClients != 2 { + t.Fatalf("active clients = %d, want 2", tunnel.ActiveClients) + } + if tunnel.PeakClients != 2 { + t.Fatalf("peak clients = %d, want 2", tunnel.PeakClients) + } + if tunnel.TotalSessions != 2 { + t.Fatalf("total sessions = %d, want 2", tunnel.TotalSessions) + } + if tunnel.TotalQueries != 2 { + t.Fatalf("queries = %d, want 2", tunnel.TotalQueries) + } + if tunnel.BytesIn != 200 { + t.Fatalf("bytes in = %d, want 200", tunnel.BytesIn) + } + if tunnel.BytesOut != 200 { + t.Fatalf("bytes out = %d, want 200", tunnel.BytesOut) + } + + now = now.Add(ClientTimeout + time.Second) + snapshot = c.Snapshot() + tunnel = snapshot.Tunnels["tunnel.example.com"] + if tunnel.ActiveClients != 0 { + t.Fatalf("active clients after timeout = %d, want 0", tunnel.ActiveClients) + } + if tunnel.PeakClients != 2 { + t.Fatalf("peak clients after timeout = %d, want 2", tunnel.PeakClients) + } +} + +func TestCollectorMatchesSubdomainsToRegisteredTunnel(t *testing.T) { + now := time.Unix(1000, 0) + c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now })) + + c.RecordQuery("abcd.tunnel.example.com", "client-a", 120) + + tunnel := c.Snapshot().Tunnels["tunnel.example.com"] + if tunnel.TotalQueries != 1 { + t.Fatalf("queries = %d, want 1", tunnel.TotalQueries) + } + if tunnel.ActiveClients != 1 { + t.Fatalf("active clients = %d, want 1", tunnel.ActiveClients) + } +} diff --git a/internal/dnstt/exporter.go b/internal/dnstt/exporter.go new file mode 100644 index 0000000..a36019c --- /dev/null +++ b/internal/dnstt/exporter.go @@ -0,0 +1,83 @@ +package dnstt + +import "github.com/prometheus/client_golang/prometheus" + +const namespace = "dnstt" + +// Exporter exposes aggregate DNSTT traffic metrics from a Collector. +type Exporter struct { + collector *Collector + + activeClients *prometheus.Desc + peakClients *prometheus.Desc + queries *prometheus.Desc + bytesIn *prometheus.Desc + bytesOut *prometheus.Desc + sessions *prometheus.Desc +} + +// NewExporter creates a Prometheus collector for DNSTT metrics. +func NewExporter(collector *Collector) *Exporter { + labels := []string{"domain"} + return &Exporter{ + collector: collector, + activeClients: prometheus.NewDesc( + prometheus.BuildFQName(namespace, "", "active_clients"), + "Number of DNSTT client sessions observed within the active timeout window.", + labels, + nil, + ), + peakClients: prometheus.NewDesc( + prometheus.BuildFQName(namespace, "", "peak_clients"), + "Maximum concurrent active DNSTT client sessions observed.", + labels, + nil, + ), + queries: prometheus.NewDesc( + prometheus.BuildFQName(namespace, "", "queries_total"), + "Total DNSTT DNS queries observed.", + labels, + nil, + ), + bytesIn: prometheus.NewDesc( + prometheus.BuildFQName(namespace, "", "bytes_in_total"), + "Total bytes observed in DNSTT DNS queries.", + labels, + nil, + ), + bytesOut: prometheus.NewDesc( + prometheus.BuildFQName(namespace, "", "bytes_out_total"), + "Total bytes observed in DNSTT DNS responses.", + labels, + nil, + ), + sessions: prometheus.NewDesc( + prometheus.BuildFQName(namespace, "", "sessions_total"), + "Total unique DNSTT client sessions observed.", + labels, + nil, + ), + } +} + +// Describe sends metric descriptors to Prometheus. +func (e *Exporter) Describe(ch chan<- *prometheus.Desc) { + ch <- e.activeClients + ch <- e.peakClients + ch <- e.queries + ch <- e.bytesIn + ch <- e.bytesOut + ch <- e.sessions +} + +// Collect sends current metric values to Prometheus. +func (e *Exporter) Collect(ch chan<- prometheus.Metric) { + for _, tunnel := range e.collector.Snapshot().Tunnels { + ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(tunnel.ActiveClients), tunnel.Domain) + ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(tunnel.PeakClients), tunnel.Domain) + ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(tunnel.TotalQueries), tunnel.Domain) + ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(tunnel.BytesIn), tunnel.Domain) + ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(tunnel.BytesOut), tunnel.Domain) + ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(tunnel.TotalSessions), tunnel.Domain) + } +} diff --git a/internal/dnstt/exporter_test.go b/internal/dnstt/exporter_test.go new file mode 100644 index 0000000..c899464 --- /dev/null +++ b/internal/dnstt/exporter_test.go @@ -0,0 +1,42 @@ +package dnstt + +import ( + "strings" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus/testutil" +) + +func TestExporterCollectsAggregateDNSTTMetrics(t *testing.T) { + now := time.Unix(1000, 0) + c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now })) + c.RecordQuery("tunnel.example.com", "client-a", 100) + c.RecordQuery("tunnel.example.com", "client-b", 300) + c.RecordResponse("tunnel.example.com", 250) + + expected := ` +# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window. +# TYPE dnstt_active_clients gauge +dnstt_active_clients{domain="tunnel.example.com"} 2 +# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries. +# TYPE dnstt_bytes_in_total counter +dnstt_bytes_in_total{domain="tunnel.example.com"} 400 +# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses. +# TYPE dnstt_bytes_out_total counter +dnstt_bytes_out_total{domain="tunnel.example.com"} 250 +# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed. +# TYPE dnstt_peak_clients gauge +dnstt_peak_clients{domain="tunnel.example.com"} 2 +# HELP dnstt_queries_total Total DNSTT DNS queries observed. +# TYPE dnstt_queries_total counter +dnstt_queries_total{domain="tunnel.example.com"} 2 +# HELP dnstt_sessions_total Total unique DNSTT client sessions observed. +# TYPE dnstt_sessions_total counter +dnstt_sessions_total{domain="tunnel.example.com"} 2 +` + + if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil { + t.Fatal(err) + } +} diff --git a/internal/dnstt/packet.go b/internal/dnstt/packet.go new file mode 100644 index 0000000..4b2acd1 --- /dev/null +++ b/internal/dnstt/packet.go @@ -0,0 +1,60 @@ +package dnstt + +import ( + "encoding/binary" + "strings" +) + +// ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet. +func ProcessIPv4Packet(data []byte, port int, collector *Collector) { + if len(data) < 20 || data[0]>>4 != 4 { + return + } + + ihl := int(data[0]&0x0f) * 4 + if ihl < 20 || len(data) < ihl { + return + } + if data[9] != 17 { + return + } + + udpData := data[ihl:] + if len(udpData) < 8 { + return + } + + srcPort := binary.BigEndian.Uint16(udpData[0:2]) + dstPort := binary.BigEndian.Uint16(udpData[2:4]) + udpLen := int(binary.BigEndian.Uint16(udpData[4:6])) + dnsPayload := udpData[8:] + if len(dnsPayload) < 12 { + return + } + + if int(dstPort) == port { + domain, clientID := ExtractQuery(dnsPayload, collector.Domains()) + if domain != "" { + collector.RecordQuery(domain, clientID, udpLen) + } + return + } + + if int(srcPort) == port { + domain := extractQueryDomain(dnsPayload) + if domain != "" { + collector.RecordResponse(domain, udpLen) + } + } +} + +func extractQueryDomain(dns []byte) string { + if len(dns) < 12 { + return "" + } + labels := parseDNSLabels(dns[12:]) + if len(labels) == 0 { + return "" + } + return normalizeDomain(strings.Join(labels, ".")) +} diff --git a/internal/dnstt/parser.go b/internal/dnstt/parser.go new file mode 100644 index 0000000..0447a27 --- /dev/null +++ b/internal/dnstt/parser.go @@ -0,0 +1,82 @@ +package dnstt + +import ( + "encoding/base32" + "fmt" + "strings" +) + +const clientIDLen = 8 + +var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding) + +// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query. +func ExtractQuery(dns []byte, domains []string) (string, string) { + if len(dns) < 12 { + return "", "" + } + + labels := parseDNSLabels(dns[12:]) + if len(labels) == 0 { + return "", "" + } + + fullDomain := strings.ToLower(strings.Join(labels, ".")) + for _, domain := range domains { + domain = normalizeDomain(domain) + if fullDomain == domain { + return domain, "" + } + + suffix := "." + domain + if !strings.HasSuffix(fullDomain, suffix) { + continue + } + + tunnelLabels := strings.Count(domain, ".") + 1 + if len(labels) <= tunnelLabels { + continue + } + + prefixLabels := labels[:len(labels)-tunnelLabels] + encoded := strings.ToUpper(strings.Join(prefixLabels, "")) + decoded := make([]byte, base32Encoding.DecodedLen(len(encoded))) + n, err := base32Encoding.Decode(decoded, []byte(encoded)) + if err != nil || n < clientIDLen { + return domain, "" + } + + return domain, fmt.Sprintf("%x", decoded[:clientIDLen]) + } + + return "", "" +} + +func parseDNSLabels(data []byte) []string { + var labels []string + offset := 0 + + for offset < len(data) { + labelLen := int(data[offset]) + if labelLen == 0 { + break + } + if labelLen&0xc0 == 0xc0 { + break + } + + offset++ + if offset+labelLen > len(data) { + return nil + } + + labels = append(labels, string(data[offset:offset+labelLen])) + offset += labelLen + } + + return labels +} + +func normalizeDomain(domain string) string { + return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".") +} diff --git a/internal/dnstt/parser_test.go b/internal/dnstt/parser_test.go new file mode 100644 index 0000000..02fc8c2 --- /dev/null +++ b/internal/dnstt/parser_test.go @@ -0,0 +1,93 @@ +package dnstt + +import ( + "encoding/base32" + "testing" +) + +func TestExtractQueryDecodesDNSTTClientID(t *testing.T) { + clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb)) + packet := dnsQuery(encoded, "tunnel.example.com") + + domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"}) + + if domain != "tunnel.example.com" { + t.Fatalf("domain = %q, want tunnel.example.com", domain) + } + if gotClientID != "0102030405060708" { + t.Fatalf("client ID = %q, want 0102030405060708", gotClientID) + } +} + +func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) { + packet := dnsQuery("", "tunnel.example.com") + + domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"}) + + if domain != "tunnel.example.com" { + t.Fatalf("domain = %q, want tunnel.example.com", domain) + } + if gotClientID != "" { + t.Fatalf("client ID = %q, want empty", gotClientID) + } +} + +func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) { + packet := dnsQuery("abcd", "other.example.com") + + domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"}) + + if domain != "" || gotClientID != "" { + t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID) + } +} + +func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) { + packet := dnsQuery("not-base32-!", "tunnel.example.com") + + domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"}) + + if domain != "tunnel.example.com" { + t.Fatalf("domain = %q, want tunnel.example.com", domain) + } + if gotClientID != "" { + t.Fatalf("client ID = %q, want empty", gotClientID) + } +} + +func dnsQuery(prefix string, domain string) []byte { + packet := []byte{ + 0x12, 0x34, // transaction ID + 0x01, 0x00, // flags + 0x00, 0x01, // questions + 0x00, 0x00, // answers + 0x00, 0x00, // authority + 0x00, 0x00, // additional + } + + labels := splitDomain(domain) + if prefix != "" { + labels = append([]string{prefix}, labels...) + } + for _, label := range labels { + packet = append(packet, byte(len(label))) + packet = append(packet, label...) + } + packet = append(packet, 0x00) // root label + packet = append(packet, 0x00, 0x10) // TXT + packet = append(packet, 0x00, 0x01) // IN + return packet +} + +func splitDomain(domain string) []string { + var labels []string + start := 0 + for i := 0; i <= len(domain); i++ { + if i == len(domain) || domain[i] == '.' { + labels = append(labels, domain[start:i]) + start = i + 1 + } + } + return labels +}