package dnstt

import (
	"net/netip"
	"sort"
	"strings"
	"sync"
	"time"
)

// ClientTimeout is how long since the last query before a session is considered inactive.
const ClientTimeout = 30 * time.Second

// Collector accumulates observed DNSTT DNS traffic.
//
// Every exported method acquires mu, so a Collector may be shared between
// goroutines. The clock and the optional GeoResolver are only invoked while
// mu is held.
type Collector struct {
	mu      sync.Mutex
	domains []string
	tunnels map[string]*tunnelState
	now     func() time.Time
	geo     GeoResolver
}

// tunnelState holds the accumulated state of one configured DNSTT domain.
type tunnelState struct {
	series map[geoKey]*seriesState

	// clients is keyed by client ID. Entries are never removed, so this map
	// grows with the number of distinct client IDs ever observed.
	// NOTE(review): evicting idle clients would bound memory, but it would
	// also change TotalSessions for IDs that reappear after ClientTimeout.
	clients map[string]*clientState

	// peakClients is the highest number of simultaneously active clients
	// across all series. It is tracked separately because per-series peaks
	// may have been reached at different moments, so no combination of them
	// yields the tunnel-wide peak.
	peakClients int
}

// seriesState holds the counters of one label combination (geoKey).
type seriesState struct {
	queries     uint64
	bytesIn     uint64
	bytesOut    uint64
	peakClients int
}

// clientState tracks one client ID within a tunnel. firstKey selects the
// series credited with the session; lastKey selects the series in which the
// client currently counts as active.
type clientState struct {
	firstSeen time.Time
	lastSeen  time.Time
	queries   uint64
	bytesIn   uint64
	firstKey  geoKey
	lastKey   geoKey
}

// geoKey identifies a metric series. It is the zero value when no
// GeoResolver is configured.
type geoKey struct {
	country string
	asn     string
}

// Snapshot is a point-in-time copy of observed traffic.
type Snapshot struct {
	Tunnels map[string]TunnelSnapshot
}

// TunnelSnapshot is a point-in-time copy of one DNSTT domain.
//
// ActiveClients, TotalSessions, TotalQueries, BytesIn and BytesOut are sums
// over Series. PeakClients is the highest number of clients that were active
// at the same time across all series of the tunnel.
type TunnelSnapshot struct {
	Domain        string
	ActiveClients int
	PeakClients   int
	TotalSessions int
	TotalQueries  uint64
	BytesIn       uint64
	BytesOut      uint64
	Series        []SeriesSnapshot
}

// SeriesSnapshot is a per-label metric series for a tunnel.
type SeriesSnapshot struct {
	Domain        string
	Country       string
	ASN           string
	ActiveClients int
	PeakClients   int
	TotalSessions int
	TotalQueries  uint64
	BytesIn       uint64
	BytesOut      uint64
}

// CollectorOption configures a Collector.
type CollectorOption func(*Collector)

// WithNow overrides the clock. It is intended for deterministic tests.
func WithNow(now func() time.Time) CollectorOption {
	return func(c *Collector) {
		c.now = now
	}
}

// WithGeoResolver enables optional GeoIP labels for resolver IP addresses.
func WithGeoResolver(resolver GeoResolver) CollectorOption {
	return func(c *Collector) {
		c.geo = resolver
	}
}

// NewCollector creates a collector for the provided DNSTT domains.
//
// Each domain is passed through normalizeDomain; empty results and duplicates
// are skipped, and the remaining domains keep their input order. Options are
// applied after the domains are registered.
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
	c := &Collector{
		tunnels: make(map[string]*tunnelState),
		now:     time.Now,
	}
	for _, domain := range domains {
		domain = normalizeDomain(domain)
		if domain == "" {
			continue
		}
		if _, exists := c.tunnels[domain]; exists {
			continue
		}
		c.domains = append(c.domains, domain)
		c.tunnels[domain] = &tunnelState{
			series:  make(map[geoKey]*seriesState),
			clients: make(map[string]*clientState),
		}
	}
	for _, opt := range opts {
		opt(c)
	}
	return c
}

// GeoLabelNames returns the optional GeoIP label names used by this collector.
func (c *Collector) GeoLabelNames() []string {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.geoLabelNamesLocked()
}

// Domains returns the normalized DNSTT domains this collector recognizes.
// The returned slice is a copy and may be modified by the caller.
func (c *Collector) Domains() []string {
	c.mu.Lock()
	defer c.mu.Unlock()
	domains := make([]string, len(c.domains))
	copy(domains, c.domains)
	return domains
}

// RecordQuery records an observed DNSTT DNS query without a resolver address.
// If a GeoResolver is configured it is consulted with the zero netip.Addr.
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
	c.RecordQueryFrom(domain, clientID, netip.Addr{}, size)
}

// RecordQueryFrom records an observed DNSTT DNS query from a resolver address.
//
// Queries for domains that match no configured tunnel are ignored. The query
// and its size (when positive) are always added to the resolver's series; an
// empty clientID skips per-client session and activity tracking.
func (c *Collector) RecordQueryFrom(domain string, clientID string, resolverIP netip.Addr, size int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	tunnel, ok := c.findTunnelLocked(domain)
	if !ok {
		return
	}

	key := c.geoKeyLocked(resolverIP)
	series := ensureSeries(tunnel, key)
	series.queries++
	if size > 0 {
		series.bytesIn += uint64(size)
	}

	if clientID == "" {
		return
	}

	now := c.now()
	client, exists := tunnel.clients[clientID]

	// Active counts can only grow when a client becomes active (first query,
	// or first query after ClientTimeout of silence) or moves to another
	// series. For a client already active in the same series no count can
	// exceed its value at the previous update, so the O(clients) peak scan
	// is skipped. This relies on the clock not moving backwards.
	countsMayGrow := !exists ||
		now.Sub(client.lastSeen) >= ClientTimeout ||
		client.lastKey != key

	if !exists {
		client = &clientState{firstSeen: now, firstKey: key}
		tunnel.clients[clientID] = client
	}
	client.lastSeen = now
	client.lastKey = key
	client.queries++
	if size > 0 {
		client.bytesIn += uint64(size)
	}

	if countsMayGrow {
		updatePeaks(tunnel, now)
	}
}

// RecordResponse records an observed DNSTT DNS response without a resolver
// address. If a GeoResolver is configured it is consulted with the zero
// netip.Addr.
func (c *Collector) RecordResponse(domain string, size int) {
	c.RecordResponseFrom(domain, netip.Addr{}, size)
}

// RecordResponseFrom records an observed DNSTT DNS response to a resolver address.
// Responses for unknown domains are ignored; non-positive sizes still create
// the series but add no bytes.
func (c *Collector) RecordResponseFrom(domain string, resolverIP netip.Addr, size int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	tunnel, ok := c.findTunnelLocked(domain)
	if !ok {
		return
	}

	key := c.geoKeyLocked(resolverIP)
	series := ensureSeries(tunnel, key)
	if size > 0 {
		series.bytesOut += uint64(size)
	}
}

// Snapshot returns a stable copy of all current metrics.
func (c *Collector) Snapshot() Snapshot {
	c.mu.Lock()
	defer c.mu.Unlock()

	now := c.now()
	snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
	for _, domain := range c.domains {
		tunnel := c.tunnels[domain]
		updatePeaks(tunnel, now)
		series := c.seriesSnapshotsLocked(domain, tunnel, now)

		tunnelSnapshot := TunnelSnapshot{
			Domain:      domain,
			PeakClients: tunnel.peakClients,
			Series:      series,
		}
		for _, s := range series {
			tunnelSnapshot.ActiveClients += s.ActiveClients
			tunnelSnapshot.TotalSessions += s.TotalSessions
			tunnelSnapshot.TotalQueries += s.TotalQueries
			tunnelSnapshot.BytesIn += s.BytesIn
			tunnelSnapshot.BytesOut += s.BytesOut
		}
		snapshot.Tunnels[domain] = tunnelSnapshot
	}
	return snapshot
}

// findTunnelLocked maps a query name to its configured tunnel: either an
// exact match of the normalized name, or the longest configured domain of
// which the name is a proper subdomain.
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
	queryDomain = normalizeDomain(queryDomain)
	if tunnel, ok := c.tunnels[queryDomain]; ok {
		return tunnel, true
	}

	// With nested tunnel domains (e.g. "example.com" and "t.example.com")
	// the most specific one must win regardless of configuration order.
	best := ""
	for _, domain := range c.domains {
		if len(domain) > len(best) && strings.HasSuffix(queryDomain, "."+domain) {
			best = domain
		}
	}
	if best == "" {
		return nil, false
	}
	return c.tunnels[best], true
}

// ensureSeries returns the series for key, creating an empty one if needed.
func ensureSeries(tunnel *tunnelState, key geoKey) *seriesState {
	series, ok := tunnel.series[key]
	if !ok {
		series = &seriesState{}
		tunnel.series[key] = series
	}
	return series
}

// updatePeaks counts the clients active at now (last query less than
// ClientTimeout ago), per series of their latest query and in total, and
// raises the per-series and tunnel-wide peaks accordingly.
func updatePeaks(tunnel *tunnelState, now time.Time) {
	activeByKey := make(map[geoKey]int)
	total := 0
	for _, client := range tunnel.clients {
		if now.Sub(client.lastSeen) < ClientTimeout {
			activeByKey[client.lastKey]++
			total++
		}
	}

	if total > tunnel.peakClients {
		tunnel.peakClients = total
	}
	for key, active := range activeByKey {
		series := ensureSeries(tunnel, key)
		if active > series.peakClients {
			series.peakClients = active
		}
	}
}

// seriesSnapshotsLocked builds one SeriesSnapshot per series of tunnel,
// sorted by country and then ASN. Sessions are credited to the series of a
// client's first query; activity to the series of its latest query.
func (c *Collector) seriesSnapshotsLocked(domain string, tunnel *tunnelState, now time.Time) []SeriesSnapshot {
	activeByKey := make(map[geoKey]int)
	sessionsByKey := make(map[geoKey]int)
	for _, client := range tunnel.clients {
		sessionsByKey[client.firstKey]++
		if now.Sub(client.lastSeen) < ClientTimeout {
			activeByKey[client.lastKey]++
		}
	}

	keys := make([]geoKey, 0, len(tunnel.series))
	for key := range tunnel.series {
		keys = append(keys, key)
	}
	sort.Slice(keys, func(i, j int) bool {
		if keys[i].country != keys[j].country {
			return keys[i].country < keys[j].country
		}
		return keys[i].asn < keys[j].asn
	})

	series := make([]SeriesSnapshot, 0, len(keys))
	for _, key := range keys {
		state := tunnel.series[key]
		series = append(series, SeriesSnapshot{
			Domain:        domain,
			Country:       key.country,
			ASN:           key.asn,
			ActiveClients: activeByKey[key],
			PeakClients:   state.peakClients,
			TotalSessions: sessionsByKey[key],
			TotalQueries:  state.queries,
			BytesIn:       state.bytesIn,
			BytesOut:      state.bytesOut,
		})
	}
	return series
}

// geoLabelNamesLocked returns the resolver's label names restricted to the
// ones this collector supports ("country" and "asn"), in the resolver's
// order. It returns nil when no resolver is configured.
func (c *Collector) geoLabelNamesLocked() []string {
	if c.geo == nil {
		return nil
	}
	names := c.geo.LabelNames()
	out := make([]string, 0, len(names))
	for _, name := range names {
		if name == "country" || name == "asn" {
			out = append(out, name)
		}
	}
	return out
}

// geoKeyLocked resolves the series key for a resolver address. Labels the
// resolver does not advertise stay empty; advertised labels with no lookup
// result are set to UnknownCountry / UnknownASN.
func (c *Collector) geoKeyLocked(addr netip.Addr) geoKey {
	var key geoKey
	if c.geo == nil {
		return key
	}
	labels := c.geo.Lookup(addr)
	// Unsupported label names fall through the switch, which gives the same
	// result as iterating geoLabelNamesLocked without allocating per query.
	for _, name := range c.geo.LabelNames() {
		switch name {
		case "country":
			key.country = labels.Country
			if key.country == "" {
				key.country = UnknownCountry
			}
		case "asn":
			key.asn = labels.ASN
			if key.asn == "" {
				key.asn = UnknownASN
			}
		}
	}
	return key
}