add geoip country/asn labels and ipv6

This commit is contained in:
Abel Luck 2026-05-05 13:57:12 +02:00
parent 8318f9fe70
commit 4710df2523
12 changed files with 559 additions and 43 deletions

View file

@ -1,6 +1,8 @@
package dnstt
import (
"net/netip"
"sort"
"strings"
"sync"
"time"
@ -15,14 +17,19 @@ type Collector struct {
domains []string
tunnels map[string]*tunnelState
now func() time.Time
geo GeoResolver
}
type tunnelState struct {
series map[geoKey]*seriesState
clients map[string]*clientState
}
type seriesState struct {
queries uint64
bytesIn uint64
bytesOut uint64
peakClients int
clients map[string]*clientState
}
type clientState struct {
@ -30,6 +37,13 @@ type clientState struct {
lastSeen time.Time
queries uint64
bytesIn uint64
firstKey geoKey
lastKey geoKey
}
type geoKey struct {
country string
asn string
}
// Snapshot is a point-in-time copy of observed traffic.
@ -46,6 +60,20 @@ type TunnelSnapshot struct {
TotalQueries uint64
BytesIn uint64
BytesOut uint64
Series []SeriesSnapshot
}
// SeriesSnapshot is a per-label metric series for a tunnel.
type SeriesSnapshot struct {
Domain string
Country string
ASN string
ActiveClients int
PeakClients int
TotalSessions int
TotalQueries uint64
BytesIn uint64
BytesOut uint64
}
// CollectorOption configures a Collector.
@ -58,6 +86,13 @@ func WithNow(now func() time.Time) CollectorOption {
}
}
// WithGeoResolver enables optional GeoIP labels for resolver IP addresses.
func WithGeoResolver(resolver GeoResolver) CollectorOption {
return func(c *Collector) {
c.geo = resolver
}
}
// NewCollector creates a collector for the provided DNSTT domains.
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
c := &Collector{
@ -75,6 +110,7 @@ func NewCollector(domains []string, opts ...CollectorOption) *Collector {
}
c.domains = append(c.domains, domain)
c.tunnels[domain] = &tunnelState{
series: make(map[geoKey]*seriesState),
clients: make(map[string]*clientState),
}
}
@ -86,6 +122,13 @@ func NewCollector(domains []string, opts ...CollectorOption) *Collector {
return c
}
// GeoLabelNames returns the optional GeoIP label names used by this collector.
func (c *Collector) GeoLabelNames() []string {
c.mu.Lock()
defer c.mu.Unlock()
return c.geoLabelNamesLocked()
}
// Domains returns the normalized DNSTT domains this collector recognizes.
func (c *Collector) Domains() []string {
c.mu.Lock()
@ -98,6 +141,11 @@ func (c *Collector) Domains() []string {
// RecordQuery records an observed DNSTT DNS query.
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
c.RecordQueryFrom(domain, clientID, netip.Addr{}, size)
}
// RecordQueryFrom records an observed DNSTT DNS query from a resolver address.
func (c *Collector) RecordQueryFrom(domain string, clientID string, resolverIP netip.Addr, size int) {
c.mu.Lock()
defer c.mu.Unlock()
@ -106,9 +154,11 @@ func (c *Collector) RecordQuery(domain string, clientID string, size int) {
return
}
tunnel.queries++
key := c.geoKeyLocked(resolverIP)
series := ensureSeries(tunnel, key)
series.queries++
if size > 0 {
tunnel.bytesIn += uint64(size)
series.bytesIn += uint64(size)
}
if clientID == "" {
@ -118,23 +168,26 @@ func (c *Collector) RecordQuery(domain string, clientID string, size int) {
now := c.now()
client, exists := tunnel.clients[clientID]
if !exists {
client = &clientState{firstSeen: now}
client = &clientState{firstSeen: now, firstKey: key}
tunnel.clients[clientID] = client
}
client.lastSeen = now
client.lastKey = key
client.queries++
if size > 0 {
client.bytesIn += uint64(size)
}
active := activeClients(tunnel, now)
if active > tunnel.peakClients {
tunnel.peakClients = active
}
updatePeaks(tunnel, now)
}
// RecordResponse records an observed DNSTT DNS response.
func (c *Collector) RecordResponse(domain string, size int) {
c.RecordResponseFrom(domain, netip.Addr{}, size)
}
// RecordResponseFrom records an observed DNSTT DNS response to a resolver address.
func (c *Collector) RecordResponseFrom(domain string, resolverIP netip.Addr, size int) {
c.mu.Lock()
defer c.mu.Unlock()
@ -142,8 +195,10 @@ func (c *Collector) RecordResponse(domain string, size int) {
if !ok {
return
}
key := c.geoKeyLocked(resolverIP)
series := ensureSeries(tunnel, key)
if size > 0 {
tunnel.bytesOut += uint64(size)
series.bytesOut += uint64(size)
}
}
@ -156,20 +211,21 @@ func (c *Collector) Snapshot() Snapshot {
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
for _, domain := range c.domains {
tunnel := c.tunnels[domain]
active := activeClients(tunnel, now)
if active > tunnel.peakClients {
tunnel.peakClients = active
}
updatePeaks(tunnel, now)
series := c.seriesSnapshotsLocked(domain, tunnel, now)
snapshot.Tunnels[domain] = TunnelSnapshot{
Domain: domain,
ActiveClients: active,
PeakClients: tunnel.peakClients,
TotalSessions: len(tunnel.clients),
TotalQueries: tunnel.queries,
BytesIn: tunnel.bytesIn,
BytesOut: tunnel.bytesOut,
tunnelSnapshot := TunnelSnapshot{Domain: domain, Series: series}
for _, s := range series {
tunnelSnapshot.ActiveClients += s.ActiveClients
tunnelSnapshot.TotalSessions += s.TotalSessions
tunnelSnapshot.TotalQueries += s.TotalQueries
tunnelSnapshot.BytesIn += s.BytesIn
tunnelSnapshot.BytesOut += s.BytesOut
if s.PeakClients > tunnelSnapshot.PeakClients {
tunnelSnapshot.PeakClients = s.PeakClients
}
}
snapshot.Tunnels[domain] = tunnelSnapshot
}
return snapshot
@ -188,12 +244,103 @@ func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
return nil, false
}
func activeClients(tunnel *tunnelState, now time.Time) int {
active := 0
func ensureSeries(tunnel *tunnelState, key geoKey) *seriesState {
series, ok := tunnel.series[key]
if !ok {
series = &seriesState{}
tunnel.series[key] = series
}
return series
}
func updatePeaks(tunnel *tunnelState, now time.Time) {
activeByKey := make(map[geoKey]int)
for _, client := range tunnel.clients {
if now.Sub(client.lastSeen) < ClientTimeout {
active++
activeByKey[client.lastKey]++
}
}
for key, active := range activeByKey {
series := ensureSeries(tunnel, key)
if active > series.peakClients {
series.peakClients = active
}
}
return active
}
func (c *Collector) seriesSnapshotsLocked(domain string, tunnel *tunnelState, now time.Time) []SeriesSnapshot {
activeByKey := make(map[geoKey]int)
sessionsByKey := make(map[geoKey]int)
for _, client := range tunnel.clients {
sessionsByKey[client.firstKey]++
if now.Sub(client.lastSeen) < ClientTimeout {
activeByKey[client.lastKey]++
}
}
keys := make([]geoKey, 0, len(tunnel.series))
for key := range tunnel.series {
keys = append(keys, key)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].country != keys[j].country {
return keys[i].country < keys[j].country
}
return keys[i].asn < keys[j].asn
})
series := make([]SeriesSnapshot, 0, len(keys))
for _, key := range keys {
state := tunnel.series[key]
series = append(series, SeriesSnapshot{
Domain: domain,
Country: key.country,
ASN: key.asn,
ActiveClients: activeByKey[key],
PeakClients: state.peakClients,
TotalSessions: sessionsByKey[key],
TotalQueries: state.queries,
BytesIn: state.bytesIn,
BytesOut: state.bytesOut,
})
}
return series
}
func (c *Collector) geoLabelNamesLocked() []string {
if c.geo == nil {
return nil
}
names := c.geo.LabelNames()
out := make([]string, 0, len(names))
for _, name := range names {
if name == "country" || name == "asn" {
out = append(out, name)
}
}
return out
}
func (c *Collector) geoKeyLocked(addr netip.Addr) geoKey {
var key geoKey
if c.geo == nil {
return key
}
labels := c.geo.Lookup(addr)
for _, name := range c.geoLabelNamesLocked() {
switch name {
case "country":
key.country = labels.Country
if key.country == "" {
key.country = UnknownCountry
}
case "asn":
key.asn = labels.ASN
if key.asn == "" {
key.asn = UnknownASN
}
}
}
return key
}