From 4710df25233d3ba639587b1203c53e0c09d0571c Mon Sep 17 00:00:00 2001 From: Abel Luck Date: Tue, 5 May 2026 13:57:12 +0200 Subject: [PATCH] add geoip country/asn labels and ipv6 --- README.md | 25 +++- cmd/dnstt_exporter/main.go | 14 ++- flake.nix | 2 +- go.mod | 8 +- go.sum | 8 +- internal/dnstt/capture_linux.go | 6 +- internal/dnstt/collector.go | 197 ++++++++++++++++++++++++++++---- internal/dnstt/exporter.go | 30 +++-- internal/dnstt/exporter_test.go | 99 ++++++++++++++++ internal/dnstt/geoip.go | 108 +++++++++++++++++ internal/dnstt/packet.go | 49 +++++++- internal/dnstt/packet_test.go | 56 +++++++++ 12 files changed, 559 insertions(+), 43 deletions(-) create mode 100644 internal/dnstt/geoip.go create mode 100644 internal/dnstt/packet_test.go diff --git a/README.md b/README.md index b8cb1cd..cf3219b 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ it passively decodes DNSTT session IDs from DNS query names. sudo dnstt_exporter \ -dnstt.domain tunnel.example.com \ -dnstt.port 53 \ + -geoip.country-database /path/to/GeoLite2-Country.mmdb \ + -geoip.asn-database /path/to/GeoLite2-ASN.mmdb \ -web.listen-address :9713 ``` @@ -20,9 +22,30 @@ or grant the binary `CAP_NET_RAW`. Metrics are served at `http://127.0.0.1:9713/metrics` by default. +## How It Works + +`dnstt_exporter` opens a Linux `AF_PACKET` raw socket and passively watches UDP +DNS traffic on the configured DNSTT port. It parses IPv4 and IPv6 packets, +matches DNS query names against the configured DNSTT domain, and decodes the +DNSTT session ID from the query-name prefix. + +The exporter treats a session as active when it has seen a query for that +session within the last 30 seconds. Peak client counts are the highest active +session counts observed since the exporter started. + +GeoIP labels are based on the resolver address seen by the server. For incoming +queries this is the packet source address; for outgoing responses it is the +packet destination address. This may be a recursive resolver such as an ISP DNS +server, Cloudflare, Google, or Quad9, not the original DNSTT client. + +The exporter does not run `dnstt-server`, proxy traffic, terminate DNSTT, or +decrypt tunnel payloads. + ## Metrics -All DNSTT metrics use a `domain` label: +All DNSTT metrics use a `domain` label. If `-geoip.country-database` is set, +metrics also include `country`. If `-geoip.asn-database` is set, metrics also +include `asn`. Unmapped countries use `ZZ`; unmapped ASNs use `0`. - `dnstt_active_clients` - `dnstt_peak_clients` diff --git a/cmd/dnstt_exporter/main.go b/cmd/dnstt_exporter/main.go index 88e0002..79078b5 100644 --- a/cmd/dnstt_exporter/main.go +++ b/cmd/dnstt_exporter/main.go @@ -39,6 +39,8 @@ func main() { listenAddress = flag.String("web.listen-address", ":9713", "Address on which to expose metrics and web interface.") metricsPath = flag.String("web.telemetry-path", "/metrics", "Path under which to expose metrics.") dnsPort = flag.Int("dnstt.port", 53, "UDP port where DNSTT DNS traffic is observed.") + countryDB = flag.String("geoip.country-database", "", "Optional MaxMind GeoIP2/GeoLite2 Country database path.") + asnDB = flag.String("geoip.asn-database", "", "Optional MaxMind GeoIP2/GeoLite2 ASN database path.") ) flag.Var(&domains, "dnstt.domain", "DNSTT tunnel domain to observe. Repeat for multiple domains.") flag.Parse() @@ -50,7 +52,17 @@ func main() { log.Fatalf("-dnstt.port must be between 1 and 65535, got %d", *dnsPort) } - collector := dnstt.NewCollector(domains) + var opts []dnstt.CollectorOption + if *countryDB != "" || *asnDB != "" { + geoResolver, err := dnstt.OpenGeoIPResolver(*countryDB, *asnDB) + if err != nil { + log.Fatalf("failed to open GeoIP database: %v", err) + } + defer geoResolver.Close() + opts = append(opts, dnstt.WithGeoResolver(geoResolver)) + } + + collector := dnstt.NewCollector(domains, opts...) fd, err := dnstt.OpenRawSocket() if err != nil { log.Fatalf("failed to open raw socket: %v", err) diff --git a/flake.nix b/flake.nix index d370277..6344226 100644 --- a/flake.nix +++ b/flake.nix @@ -16,7 +16,7 @@ pname = "dnstt_exporter"; version = "0.1.0"; src = ./.; - vendorHash = "sha256-+olXxVW9VcvapOuJGas70RIfJgMzUv6mndzJi6Apw+s="; + vendorHash = "sha256-+msAQqz07XVzGjpwi8PvlA7jP4Y+j/BgSZdnIc0LrpA="; subPackages = [ "cmd/dnstt_exporter" ]; }; diff --git a/go.mod b/go.mod index 4872657..78a79af 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,10 @@ module guardianproject.dev/bypass-censorship/dnstt_exporter go 1.25 -require github.com/prometheus/client_golang v1.23.2 +require ( + github.com/oschwald/geoip2-golang/v2 v2.1.0 + github.com/prometheus/client_golang v1.23.2 +) require ( github.com/beorn7/perks v1.0.1 // indirect @@ -10,10 +13,11 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/oschwald/maxminddb-golang/v2 v2.1.1 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sys v0.38.0 // indirect google.golang.org/protobuf v1.36.8 // indirect ) diff --git a/go.sum b/go.sum index d6b8ca9..93c96c8 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/oschwald/geoip2-golang/v2 v2.1.0 h1:DjnLhNJu9WHwTrmoiQFvgmyJoczhdnm7LB23UBI2Amo= +github.com/oschwald/geoip2-golang/v2 v2.1.0/go.mod h1:qdVmcPgrTJ4q2eP9tHq/yldMTdp2VMr33uVdFbHBiBc= +github.com/oschwald/maxminddb-golang/v2 v2.1.1 h1:lA8FH0oOrM4u7mLvowq8IT6a3Q/qEnqRzLQn9eH5ojc= +github.com/oschwald/maxminddb-golang/v2 v2.1.1/go.mod h1:PLdx6PR+siSIoXqqy7C7r3SB3KZnhxWr1Dp6g0Hacl8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= @@ -35,8 +39,8 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/dnstt/capture_linux.go b/internal/dnstt/capture_linux.go index b416033..344f881 100644 --- a/internal/dnstt/capture_linux.go +++ b/internal/dnstt/capture_linux.go @@ -9,7 +9,7 @@ import ( // OpenRawSocket creates an AF_PACKET raw socket for sniffing IPv4 packets. func OpenRawSocket() (int, error) { - fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_IP))) + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(ethPAll))) if err != nil { return -1, fmt.Errorf("open raw socket: %w", err) } @@ -36,7 +36,7 @@ func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) { } return } - ProcessIPv4Packet(buf[:n], port, collector) + ProcessPacket(buf[:n], port, collector) } } @@ -48,3 +48,5 @@ func CloseRawSocket(fd int) error { func htons(v uint16) uint16 { return (v << 8) | (v >> 8) } + +const ethPAll = 0x0003 diff --git a/internal/dnstt/collector.go b/internal/dnstt/collector.go index b6207dc..4928c16 100644 --- a/internal/dnstt/collector.go +++ b/internal/dnstt/collector.go @@ -1,6 +1,8 @@ package dnstt import ( + "net/netip" + "sort" "strings" "sync" "time" @@ -15,14 +17,19 @@ type Collector struct { domains []string tunnels map[string]*tunnelState now func() time.Time + geo GeoResolver } type tunnelState struct { + series map[geoKey]*seriesState + clients map[string]*clientState +} + +type seriesState struct { queries uint64 bytesIn uint64 bytesOut uint64 peakClients int - clients map[string]*clientState } type clientState struct { @@ -30,6 +37,13 @@ type clientState struct { lastSeen time.Time queries uint64 bytesIn uint64 + firstKey geoKey + lastKey geoKey +} + +type geoKey struct { + country string + asn string } // Snapshot is a point-in-time copy of observed traffic. @@ -46,6 +60,20 @@ type TunnelSnapshot struct { TotalQueries uint64 BytesIn uint64 BytesOut uint64 + Series []SeriesSnapshot +} + +// SeriesSnapshot is a per-label metric series for a tunnel. +type SeriesSnapshot struct { + Domain string + Country string + ASN string + ActiveClients int + PeakClients int + TotalSessions int + TotalQueries uint64 + BytesIn uint64 + BytesOut uint64 } // CollectorOption configures a Collector. @@ -58,6 +86,13 @@ func WithNow(now func() time.Time) CollectorOption { } } +// WithGeoResolver enables optional GeoIP labels for resolver IP addresses. +func WithGeoResolver(resolver GeoResolver) CollectorOption { + return func(c *Collector) { + c.geo = resolver + } +} + // NewCollector creates a collector for the provided DNSTT domains. func NewCollector(domains []string, opts ...CollectorOption) *Collector { c := &Collector{ @@ -75,6 +110,7 @@ func NewCollector(domains []string, opts ...CollectorOption) *Collector { } c.domains = append(c.domains, domain) c.tunnels[domain] = &tunnelState{ + series: make(map[geoKey]*seriesState), clients: make(map[string]*clientState), } } @@ -86,6 +122,13 @@ func NewCollector(domains []string, opts ...CollectorOption) *Collector { return c } +// GeoLabelNames returns the optional GeoIP label names used by this collector. +func (c *Collector) GeoLabelNames() []string { + c.mu.Lock() + defer c.mu.Unlock() + return c.geoLabelNamesLocked() +} + // Domains returns the normalized DNSTT domains this collector recognizes. func (c *Collector) Domains() []string { c.mu.Lock() @@ -98,6 +141,11 @@ func (c *Collector) Domains() []string { // RecordQuery records an observed DNSTT DNS query. func (c *Collector) RecordQuery(domain string, clientID string, size int) { + c.RecordQueryFrom(domain, clientID, netip.Addr{}, size) +} + +// RecordQueryFrom records an observed DNSTT DNS query from a resolver address. +func (c *Collector) RecordQueryFrom(domain string, clientID string, resolverIP netip.Addr, size int) { c.mu.Lock() defer c.mu.Unlock() @@ -106,9 +154,11 @@ func (c *Collector) RecordQuery(domain string, clientID string, size int) { return } - tunnel.queries++ + key := c.geoKeyLocked(resolverIP) + series := ensureSeries(tunnel, key) + series.queries++ if size > 0 { - tunnel.bytesIn += uint64(size) + series.bytesIn += uint64(size) } if clientID == "" { @@ -118,23 +168,26 @@ func (c *Collector) RecordQuery(domain string, clientID string, size int) { now := c.now() client, exists := tunnel.clients[clientID] if !exists { - client = &clientState{firstSeen: now} + client = &clientState{firstSeen: now, firstKey: key} tunnel.clients[clientID] = client } client.lastSeen = now + client.lastKey = key client.queries++ if size > 0 { client.bytesIn += uint64(size) } - active := activeClients(tunnel, now) - if active > tunnel.peakClients { - tunnel.peakClients = active - } + updatePeaks(tunnel, now) } // RecordResponse records an observed DNSTT DNS response. func (c *Collector) RecordResponse(domain string, size int) { + c.RecordResponseFrom(domain, netip.Addr{}, size) +} + +// RecordResponseFrom records an observed DNSTT DNS response to a resolver address. +func (c *Collector) RecordResponseFrom(domain string, resolverIP netip.Addr, size int) { c.mu.Lock() defer c.mu.Unlock() @@ -142,8 +195,10 @@ func (c *Collector) RecordResponse(domain string, size int) { if !ok { return } + key := c.geoKeyLocked(resolverIP) + series := ensureSeries(tunnel, key) if size > 0 { - tunnel.bytesOut += uint64(size) + series.bytesOut += uint64(size) } } @@ -156,20 +211,21 @@ func (c *Collector) Snapshot() Snapshot { snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))} for _, domain := range c.domains { tunnel := c.tunnels[domain] - active := activeClients(tunnel, now) - if active > tunnel.peakClients { - tunnel.peakClients = active - } + updatePeaks(tunnel, now) + series := c.seriesSnapshotsLocked(domain, tunnel, now) - snapshot.Tunnels[domain] = TunnelSnapshot{ - Domain: domain, - ActiveClients: active, - PeakClients: tunnel.peakClients, - TotalSessions: len(tunnel.clients), - TotalQueries: tunnel.queries, - BytesIn: tunnel.bytesIn, - BytesOut: tunnel.bytesOut, + tunnelSnapshot := TunnelSnapshot{Domain: domain, Series: series} + for _, s := range series { + tunnelSnapshot.ActiveClients += s.ActiveClients + tunnelSnapshot.TotalSessions += s.TotalSessions + tunnelSnapshot.TotalQueries += s.TotalQueries + tunnelSnapshot.BytesIn += s.BytesIn + tunnelSnapshot.BytesOut += s.BytesOut + if s.PeakClients > tunnelSnapshot.PeakClients { + tunnelSnapshot.PeakClients = s.PeakClients + } } + snapshot.Tunnels[domain] = tunnelSnapshot } return snapshot @@ -188,12 +244,103 @@ func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) { return nil, false } -func activeClients(tunnel *tunnelState, now time.Time) int { - active := 0 +func ensureSeries(tunnel *tunnelState, key geoKey) *seriesState { + series, ok := tunnel.series[key] + if !ok { + series = &seriesState{} + tunnel.series[key] = series + } + return series +} + +func updatePeaks(tunnel *tunnelState, now time.Time) { + activeByKey := make(map[geoKey]int) for _, client := range tunnel.clients { if now.Sub(client.lastSeen) < ClientTimeout { - active++ + activeByKey[client.lastKey]++ + } + } + for key, active := range activeByKey { + series := ensureSeries(tunnel, key) + if active > series.peakClients { + series.peakClients = active } } - return active +} + +func (c *Collector) seriesSnapshotsLocked(domain string, tunnel *tunnelState, now time.Time) []SeriesSnapshot { + activeByKey := make(map[geoKey]int) + sessionsByKey := make(map[geoKey]int) + for _, client := range tunnel.clients { + sessionsByKey[client.firstKey]++ + if now.Sub(client.lastSeen) < ClientTimeout { + activeByKey[client.lastKey]++ + } + } + + keys := make([]geoKey, 0, len(tunnel.series)) + for key := range tunnel.series { + keys = append(keys, key) + } + sort.Slice(keys, func(i, j int) bool { + if keys[i].country != keys[j].country { + return keys[i].country < keys[j].country + } + return keys[i].asn < keys[j].asn + }) + + series := make([]SeriesSnapshot, 0, len(keys)) + for _, key := range keys { + state := tunnel.series[key] + series = append(series, SeriesSnapshot{ + Domain: domain, + Country: key.country, + ASN: key.asn, + ActiveClients: activeByKey[key], + PeakClients: state.peakClients, + TotalSessions: sessionsByKey[key], + TotalQueries: state.queries, + BytesIn: state.bytesIn, + BytesOut: state.bytesOut, + }) + } + return series +} + +func (c *Collector) geoLabelNamesLocked() []string { + if c.geo == nil { + return nil + } + names := c.geo.LabelNames() + out := make([]string, 0, len(names)) + for _, name := range names { + if name == "country" || name == "asn" { + out = append(out, name) + } + } + return out +} + +func (c *Collector) geoKeyLocked(addr netip.Addr) geoKey { + var key geoKey + if c.geo == nil { + return key + } + + labels := c.geo.Lookup(addr) + for _, name := range c.geoLabelNamesLocked() { + switch name { + case "country": + key.country = labels.Country + if key.country == "" { + key.country = UnknownCountry + } + case "asn": + key.asn = labels.ASN + if key.asn == "" { + key.asn = UnknownASN + } + } + } + return key } diff --git a/internal/dnstt/exporter.go b/internal/dnstt/exporter.go index a36019c..deacb6c 100644 --- a/internal/dnstt/exporter.go +++ b/internal/dnstt/exporter.go @@ -18,7 +18,7 @@ type Exporter struct { // NewExporter creates a Prometheus collector for DNSTT metrics. func NewExporter(collector *Collector) *Exporter { - labels := []string{"domain"} + labels := append([]string{"domain"}, collector.GeoLabelNames()...) return &Exporter{ collector: collector, activeClients: prometheus.NewDesc( @@ -73,11 +73,27 @@ func (e *Exporter) Describe(ch chan<- *prometheus.Desc) { // Collect sends current metric values to Prometheus. func (e *Exporter) Collect(ch chan<- prometheus.Metric) { for _, tunnel := range e.collector.Snapshot().Tunnels { - ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(tunnel.ActiveClients), tunnel.Domain) - ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(tunnel.PeakClients), tunnel.Domain) - ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(tunnel.TotalQueries), tunnel.Domain) - ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(tunnel.BytesIn), tunnel.Domain) - ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(tunnel.BytesOut), tunnel.Domain) - ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(tunnel.TotalSessions), tunnel.Domain) + for _, series := range tunnel.Series { + labels := e.labelValues(series) + ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(series.ActiveClients), labels...) + ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(series.PeakClients), labels...) + ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(series.TotalQueries), labels...) + ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(series.BytesIn), labels...) + ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(series.BytesOut), labels...) + ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(series.TotalSessions), labels...) + } } } + +func (e *Exporter) labelValues(series SeriesSnapshot) []string { + values := []string{series.Domain} + for _, label := range e.collector.GeoLabelNames() { + switch label { + case "country": + values = append(values, series.Country) + case "asn": + values = append(values, series.ASN) + } + } + return values +} diff --git a/internal/dnstt/exporter_test.go b/internal/dnstt/exporter_test.go index c899464..b8b1c16 100644 --- a/internal/dnstt/exporter_test.go +++ b/internal/dnstt/exporter_test.go @@ -1,6 +1,7 @@ package dnstt import ( + "net/netip" "strings" "testing" "time" @@ -40,3 +41,101 @@ dnstt_sessions_total{domain="tunnel.example.com"} 2 t.Fatal(err) } } + +func TestExporterCollectsGeoIPCountryAndASNLabels(t *testing.T) { + now := time.Unix(1000, 0) + resolverIP := netip.MustParseAddr("2001:db8::53") + c := NewCollector( + []string{"tunnel.example.com"}, + WithNow(func() time.Time { return now }), + WithGeoResolver(fakeGeoResolver{ + labelNames: []string{"country", "asn"}, + labels: map[netip.Addr]GeoLabels{ + resolverIP: {Country: "DE", ASN: "3320"}, + }, + }), + ) + c.RecordQueryFrom("tunnel.example.com", "client-a", resolverIP, 100) + + expected := ` +# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window. +# TYPE dnstt_active_clients gauge +dnstt_active_clients{asn="3320",country="DE",domain="tunnel.example.com"} 1 +# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries. +# TYPE dnstt_bytes_in_total counter +dnstt_bytes_in_total{asn="3320",country="DE",domain="tunnel.example.com"} 100 +# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses. +# TYPE dnstt_bytes_out_total counter +dnstt_bytes_out_total{asn="3320",country="DE",domain="tunnel.example.com"} 0 +# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed. +# TYPE dnstt_peak_clients gauge +dnstt_peak_clients{asn="3320",country="DE",domain="tunnel.example.com"} 1 +# HELP dnstt_queries_total Total DNSTT DNS queries observed. +# TYPE dnstt_queries_total counter +dnstt_queries_total{asn="3320",country="DE",domain="tunnel.example.com"} 1 +# HELP dnstt_sessions_total Total unique DNSTT client sessions observed. +# TYPE dnstt_sessions_total counter +dnstt_sessions_total{asn="3320",country="DE",domain="tunnel.example.com"} 1 +` + + if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil { + t.Fatal(err) + } +} + +func TestExporterCanCollectASNWithoutCountry(t *testing.T) { + now := time.Unix(1000, 0) + resolverIP := netip.MustParseAddr("192.0.2.53") + c := NewCollector( + []string{"tunnel.example.com"}, + WithNow(func() time.Time { return now }), + WithGeoResolver(fakeGeoResolver{ + labelNames: []string{"asn"}, + labels: map[netip.Addr]GeoLabels{ + resolverIP: {ASN: "15169"}, + }, + }), + ) + c.RecordQueryFrom("tunnel.example.com", "client-a", resolverIP, 100) + + expected := ` +# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window. +# TYPE dnstt_active_clients gauge +dnstt_active_clients{asn="15169",domain="tunnel.example.com"} 1 +# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries. +# TYPE dnstt_bytes_in_total counter +dnstt_bytes_in_total{asn="15169",domain="tunnel.example.com"} 100 +# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses. +# TYPE dnstt_bytes_out_total counter +dnstt_bytes_out_total{asn="15169",domain="tunnel.example.com"} 0 +# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed. +# TYPE dnstt_peak_clients gauge +dnstt_peak_clients{asn="15169",domain="tunnel.example.com"} 1 +# HELP dnstt_queries_total Total DNSTT DNS queries observed. +# TYPE dnstt_queries_total counter +dnstt_queries_total{asn="15169",domain="tunnel.example.com"} 1 +# HELP dnstt_sessions_total Total unique DNSTT client sessions observed. +# TYPE dnstt_sessions_total counter +dnstt_sessions_total{asn="15169",domain="tunnel.example.com"} 1 +` + + if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil { + t.Fatal(err) + } +} + +type fakeGeoResolver struct { + labelNames []string + labels map[netip.Addr]GeoLabels +} + +func (f fakeGeoResolver) Lookup(addr netip.Addr) GeoLabels { + if labels, ok := f.labels[addr]; ok { + return labels + } + return GeoLabels{Country: UnknownCountry, ASN: UnknownASN} +} + +func (f fakeGeoResolver) LabelNames() []string { + return f.labelNames +} diff --git a/internal/dnstt/geoip.go b/internal/dnstt/geoip.go new file mode 100644 index 0000000..77e161c --- /dev/null +++ b/internal/dnstt/geoip.go @@ -0,0 +1,108 @@ +package dnstt + +import ( + "fmt" + "net/netip" + "strconv" + + geoip2 "github.com/oschwald/geoip2-golang/v2" +) + +const ( + // UnknownCountry is used when country lookup is enabled but no country is found. + UnknownCountry = "ZZ" + // UnknownASN is used when ASN lookup is enabled but no ASN is found. + UnknownASN = "0" +) + +// GeoLabels are optional labels derived from a resolver IP address. +type GeoLabels struct { + Country string + ASN string +} + +// GeoResolver resolves optional GeoIP labels for resolver IP addresses. +type GeoResolver interface { + Lookup(netip.Addr) GeoLabels + LabelNames() []string +} + +// GeoIPResolver uses optional MaxMind Country and ASN databases. +type GeoIPResolver struct { + countryDB *geoip2.Reader + asnDB *geoip2.Reader +} + +// OpenGeoIPResolver opens optional MaxMind Country and ASN databases. +func OpenGeoIPResolver(countryDatabase string, asnDatabase string) (*GeoIPResolver, error) { + resolver := &GeoIPResolver{} + if countryDatabase != "" { + db, err := geoip2.Open(countryDatabase) + if err != nil { + return nil, fmt.Errorf("open country database: %w", err) + } + resolver.countryDB = db + } + if asnDatabase != "" { + db, err := geoip2.Open(asnDatabase) + if err != nil { + resolver.Close() + return nil, fmt.Errorf("open ASN database: %w", err) + } + resolver.asnDB = db + } + return resolver, nil +} + +// Close closes any open MaxMind databases. +func (r *GeoIPResolver) Close() { + if r.countryDB != nil { + r.countryDB.Close() + } + if r.asnDB != nil { + r.asnDB.Close() + } +} + +// LabelNames returns the optional Prometheus labels enabled by configured databases. +func (r *GeoIPResolver) LabelNames() []string { + var labels []string + if r.countryDB != nil { + labels = append(labels, "country") + } + if r.asnDB != nil { + labels = append(labels, "asn") + } + return labels +} + +// Lookup returns optional GeoIP labels for a resolver IP address. +func (r *GeoIPResolver) Lookup(addr netip.Addr) GeoLabels { + labels := GeoLabels{} + if !addr.IsValid() { + if r.countryDB != nil { + labels.Country = UnknownCountry + } + if r.asnDB != nil { + labels.ASN = UnknownASN + } + return labels + } + + if r.countryDB != nil { + labels.Country = UnknownCountry + record, err := r.countryDB.Country(addr) + if err == nil && record.HasData() && record.Country.ISOCode != "" { + labels.Country = record.Country.ISOCode + } + } + if r.asnDB != nil { + labels.ASN = UnknownASN + record, err := r.asnDB.ASN(addr) + if err == nil && record.HasData() && record.AutonomousSystemNumber != 0 { + labels.ASN = strconv.FormatUint(uint64(record.AutonomousSystemNumber), 10) + } + } + + return labels +} diff --git a/internal/dnstt/packet.go b/internal/dnstt/packet.go index 4b2acd1..9a9822a 100644 --- a/internal/dnstt/packet.go +++ b/internal/dnstt/packet.go @@ -2,9 +2,23 @@ package dnstt import ( "encoding/binary" + "net/netip" "strings" ) +// ProcessPacket records DNSTT DNS traffic from a raw IP packet. +func ProcessPacket(data []byte, port int, collector *Collector) { + if len(data) == 0 { + return + } + switch data[0] >> 4 { + case 4: + ProcessIPv4Packet(data, port, collector) + case 6: + ProcessIPv6Packet(data, port, collector) + } +} + // ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet. func ProcessIPv4Packet(data []byte, port int, collector *Collector) { if len(data) < 20 || data[0]>>4 != 4 { @@ -18,8 +32,39 @@ func ProcessIPv4Packet(data []byte, port int, collector *Collector) { if data[9] != 17 { return } + srcIP := netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}) + dstIP := netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}) udpData := data[ihl:] + processUDPPayload(udpData, srcIP, dstIP, port, collector) +} + +// ProcessIPv6Packet records DNSTT DNS traffic from a raw IPv6 packet. +func ProcessIPv6Packet(data []byte, port int, collector *Collector) { + if len(data) < 40 || data[0]>>4 != 6 { + return + } + if data[6] != 17 { + return + } + + var srcBytes [16]byte + copy(srcBytes[:], data[8:24]) + srcIP := netip.AddrFrom16(srcBytes) + + var dstBytes [16]byte + copy(dstBytes[:], data[24:40]) + dstIP := netip.AddrFrom16(dstBytes) + + payloadLen := int(binary.BigEndian.Uint16(data[4:6])) + if payloadLen <= 0 || 40+payloadLen > len(data) { + return + } + + processUDPPayload(data[40:40+payloadLen], srcIP, dstIP, port, collector) +} + +func processUDPPayload(udpData []byte, srcIP netip.Addr, dstIP netip.Addr, port int, collector *Collector) { if len(udpData) < 8 { return } @@ -35,7 +80,7 @@ func ProcessIPv4Packet(data []byte, port int, collector *Collector) { if int(dstPort) == port { domain, clientID := ExtractQuery(dnsPayload, collector.Domains()) if domain != "" { - collector.RecordQuery(domain, clientID, udpLen) + collector.RecordQueryFrom(domain, clientID, srcIP, udpLen) } return } @@ -43,7 +88,7 @@ func ProcessIPv4Packet(data []byte, port int, collector *Collector) { if int(srcPort) == port { domain := extractQueryDomain(dnsPayload) if domain != "" { - collector.RecordResponse(domain, udpLen) + collector.RecordResponseFrom(domain, dstIP, udpLen) } } } diff --git a/internal/dnstt/packet_test.go b/internal/dnstt/packet_test.go new file mode 100644 index 0000000..670f44b --- /dev/null +++ b/internal/dnstt/packet_test.go @@ -0,0 +1,56 @@ +package dnstt + +import ( + "encoding/binary" + "net/netip" + "testing" + "time" +) + +func TestProcessIPv6PacketRecordsResolverAddress(t *testing.T) { + now := time.Unix(1000, 0) + resolverIP := netip.MustParseAddr("2001:db8::53") + serverIP := netip.MustParseAddr("2001:db8::1") + c := NewCollector( + []string{"tunnel.example.com"}, + WithNow(func() time.Time { return now }), + WithGeoResolver(fakeGeoResolver{ + labelNames: []string{"asn"}, + labels: map[netip.Addr]GeoLabels{ + resolverIP: {ASN: "13335"}, + }, + }), + ) + + ProcessPacket(ipv6UDPPacket(resolverIP, serverIP, 53000, 53, dnsQuery("MFRGGZDFMZTWQ2LK", "tunnel.example.com")), 53, c) + + series := c.Snapshot().Tunnels["tunnel.example.com"].Series + if len(series) != 1 { + t.Fatalf("series count = %d, want 1", len(series)) + } + if series[0].ASN != "13335" { + t.Fatalf("ASN = %q, want 13335", series[0].ASN) + } + if series[0].TotalQueries != 1 { + t.Fatalf("queries = %d, want 1", series[0].TotalQueries) + } +} + +func ipv6UDPPacket(src, dst netip.Addr, srcPort, dstPort uint16, payload []byte) []byte { + packet := make([]byte, 40+8+len(payload)) + packet[0] = 0x60 + binary.BigEndian.PutUint16(packet[4:6], uint16(8+len(payload))) + packet[6] = 17 + packet[7] = 64 + srcBytes := src.As16() + dstBytes := dst.As16() + copy(packet[8:24], srcBytes[:]) + copy(packet[24:40], dstBytes[:]) + + udp := packet[40:] + binary.BigEndian.PutUint16(udp[0:2], srcPort) + binary.BigEndian.PutUint16(udp[2:4], dstPort) + binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(payload))) + copy(udp[8:], payload) + return packet +}