// Source: dnstt_exporter/internal/dnstt/collector.go
// (Go, 346 lines, 7.9 KiB)

package dnstt
import (
"net/netip"
"sort"
"strings"
"sync"
"time"
)
// ClientTimeout is the inactivity window for a client session: a client is
// counted as active only while strictly less than ClientTimeout has elapsed
// since its most recent recorded query.
const ClientTimeout = 30 * time.Second
// Collector accumulates observed DNSTT DNS traffic for a fixed set of tunnel
// domains. The exported methods in this file acquire mu before touching any
// other field.
type Collector struct {
// mu guards every field below.
mu sync.Mutex
// domains holds the normalized tunnel domains in the order NewCollector first
// saw them; Snapshot and the suffix lookup iterate in this order.
domains []string
// tunnels maps each normalized domain to its accumulated state.
tunnels map[string]*tunnelState
// now supplies the current time; NewCollector sets it to time.Now and WithNow
// can replace it.
now func() time.Time
// geo optionally resolves resolver addresses to country/ASN labels. When it
// is nil, all traffic is recorded under the empty geoKey.
geo GeoResolver
}
// tunnelState is the mutable per-domain state kept by a Collector.
type tunnelState struct {
// series holds the traffic counters, one entry per geo label set.
series map[geoKey]*seriesState
// clients holds one entry per client ID seen on this tunnel. Nothing in this
// part of the file removes entries from it.
clients map[string]*clientState
}
// seriesState holds the running counters for one (tunnel, geo label set) pair.
type seriesState struct {
// queries is the number of recorded queries.
queries uint64
// bytesIn is the sum of the positive query sizes recorded.
bytesIn uint64
// bytesOut is the sum of the positive response sizes recorded.
bytesOut uint64
// peakClients is the highest active-client count seen for this label set by
// updatePeaks.
peakClients int
}
// clientState tracks one client ID within a tunnel.
type clientState struct {
// firstSeen is the time of the query that created this entry.
firstSeen time.Time
// lastSeen is the time of the most recent query; it decides whether the
// client is still active (see ClientTimeout).
lastSeen time.Time
// queries is the number of queries recorded for this client.
queries uint64
// bytesIn is the sum of the positive query sizes recorded for this client.
bytesIn uint64
// firstKey is the geo label set of the first query; the session is counted
// under this key.
firstKey geoKey
// lastKey is the geo label set of the most recent query; the client is
// counted as active under this key.
lastKey geoKey
}
// geoKey is the comparable label set used to partition a tunnel's metrics.
// Both fields are empty when no GeoResolver is configured; a field also stays
// empty when the resolver does not advertise the corresponding label.
type geoKey struct {
// country is the country label value.
country string
// asn is the autonomous-system label value.
asn string
}
// Snapshot is a point-in-time copy of observed traffic.
type Snapshot struct {
// Tunnels maps each normalized tunnel domain to its snapshot. Every domain
// the collector was configured with has an entry.
Tunnels map[string]TunnelSnapshot
}
// TunnelSnapshot is a point-in-time copy of one DNSTT domain. Except for
// PeakClients, the numeric fields are sums over Series.
type TunnelSnapshot struct {
// Domain is the normalized tunnel domain.
Domain string
// ActiveClients is the sum of the per-series active client counts.
ActiveClients int
// PeakClients is the largest per-series peak, not the sum of the peaks.
PeakClients int
// TotalSessions is the sum of the per-series session counts.
TotalSessions int
// TotalQueries is the sum of the per-series query counts.
TotalQueries uint64
// BytesIn is the sum of the per-series query bytes.
BytesIn uint64
// BytesOut is the sum of the per-series response bytes.
BytesOut uint64
// Series holds the per-label breakdown, sorted by country and then ASN.
Series []SeriesSnapshot
}
// SeriesSnapshot is a per-label metric series for a tunnel.
type SeriesSnapshot struct {
// Domain is the normalized tunnel domain this series belongs to.
Domain string
// Country is the country label value; empty when the label is not in use.
Country string
// ASN is the autonomous-system label value; empty when the label is not in use.
ASN string
// ActiveClients counts clients whose most recent query carried this label
// set and arrived less than ClientTimeout before the snapshot.
ActiveClients int
// PeakClients is the highest active-client count recorded for this label set.
PeakClients int
// TotalSessions counts client IDs whose first query carried this label set.
TotalSessions int
// TotalQueries is the number of queries recorded under this label set.
TotalQueries uint64
// BytesIn is the sum of positive query sizes recorded under this label set.
BytesIn uint64
// BytesOut is the sum of positive response sizes recorded under this label set.
BytesOut uint64
}
// CollectorOption configures a Collector. NewCollector applies the options in
// the order given, after the domain table has been built.
type CollectorOption func(*Collector)
// WithNow overrides the clock the collector uses to timestamp queries and to
// decide which clients are still active. It is intended for deterministic
// tests.
//
// A nil now is ignored: the collector keeps its current clock rather than
// storing a nil function that would panic on the first recorded query.
func WithNow(now func() time.Time) CollectorOption {
	return func(c *Collector) {
		if now == nil {
			return
		}
		c.now = now
	}
}
// WithGeoResolver installs resolver as the source of the optional GeoIP labels
// attached to resolver IP addresses. Without this option the collector has no
// resolver and records all traffic under the empty label set.
func WithGeoResolver(resolver GeoResolver) CollectorOption {
	install := func(collector *Collector) {
		collector.geo = resolver
	}
	return install
}
// NewCollector builds a collector that tracks the given DNSTT domains.
//
// Each domain is normalized first. Names that normalize to the empty string
// and repeats of an already registered name are dropped, so the collector
// keeps one tunnel per distinct domain in first-seen order. Options run after
// the domains have been registered.
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
	collector := &Collector{
		tunnels: make(map[string]*tunnelState),
		now:     time.Now,
	}

	for _, raw := range domains {
		name := normalizeDomain(raw)
		if name == "" {
			continue
		}
		if _, registered := collector.tunnels[name]; !registered {
			collector.domains = append(collector.domains, name)
			collector.tunnels[name] = &tunnelState{
				series:  make(map[geoKey]*seriesState),
				clients: make(map[string]*clientState),
			}
		}
	}

	for _, apply := range opts {
		apply(collector)
	}
	return collector
}
// GeoLabelNames reports which optional GeoIP labels this collector emits. It
// returns nil when no GeoResolver is configured.
func (c *Collector) GeoLabelNames() []string {
	c.mu.Lock()
	defer c.mu.Unlock()

	names := c.geoLabelNamesLocked()
	return names
}
// Domains returns the normalized DNSTT domains this collector recognizes, in
// registration order. The result is a fresh slice the caller may modify.
func (c *Collector) Domains() []string {
	c.mu.Lock()
	defer c.mu.Unlock()

	out := make([]string, 0, len(c.domains))
	return append(out, c.domains...)
}
// RecordQuery records an observed DNSTT DNS query for which no resolver
// address is available. It forwards to RecordQueryFrom with the zero
// netip.Addr.
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
	var noResolver netip.Addr
	c.RecordQueryFrom(domain, clientID, noResolver, size)
}
// RecordQueryFrom records an observed DNSTT DNS query that arrived from the
// given resolver address.
//
// Queries for names outside every configured tunnel are ignored. Otherwise the
// query is counted in the series selected by the resolver's geo labels, and
// size is added to the series' inbound bytes when it is positive. When
// clientID is non-empty the per-client session is created or refreshed and the
// tunnel's peak client counts are re-evaluated.
func (c *Collector) RecordQueryFrom(domain string, clientID string, resolverIP netip.Addr, size int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	tunnel, known := c.findTunnelLocked(domain)
	if !known {
		return
	}

	// Non-positive sizes contribute nothing to the byte counters.
	var payload uint64
	if size > 0 {
		payload = uint64(size)
	}

	label := c.geoKeyLocked(resolverIP)
	stats := ensureSeries(tunnel, label)
	stats.queries++
	stats.bytesIn += payload

	// Without a client ID there is no session to attribute the query to.
	if clientID == "" {
		return
	}

	seenAt := c.now()
	if _, tracked := tunnel.clients[clientID]; !tracked {
		tunnel.clients[clientID] = &clientState{firstSeen: seenAt, firstKey: label}
	}
	session := tunnel.clients[clientID]
	session.lastSeen = seenAt
	session.lastKey = label
	session.queries++
	session.bytesIn += payload

	updatePeaks(tunnel, seenAt)
}
// RecordResponse records an observed DNSTT DNS response for which no resolver
// address is available. It forwards to RecordResponseFrom with the zero
// netip.Addr.
func (c *Collector) RecordResponse(domain string, size int) {
	var noResolver netip.Addr
	c.RecordResponseFrom(domain, noResolver, size)
}
// RecordResponseFrom records an observed DNSTT DNS response sent to the given
// resolver address.
//
// Responses for names outside every configured tunnel are ignored. The series
// selected by the resolver's geo labels is created if needed even when size is
// not positive; only positive sizes are added to its outbound bytes.
func (c *Collector) RecordResponseFrom(domain string, resolverIP netip.Addr, size int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	tunnel, known := c.findTunnelLocked(domain)
	if !known {
		return
	}

	stats := ensureSeries(tunnel, c.geoKeyLocked(resolverIP))
	if size <= 0 {
		return
	}
	stats.bytesOut += uint64(size)
}
// Snapshot returns a stable copy of all current metrics.
//
// For each configured domain the peak counters are refreshed against the
// current time, the per-label series are copied, and tunnel-level totals are
// derived from them. Counts and byte totals are summed; PeakClients is the
// largest per-series peak.
func (c *Collector) Snapshot() Snapshot {
	c.mu.Lock()
	defer c.mu.Unlock()

	at := c.now()
	result := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}

	for _, domain := range c.domains {
		state := c.tunnels[domain]
		updatePeaks(state, at)

		total := TunnelSnapshot{
			Domain: domain,
			Series: c.seriesSnapshotsLocked(domain, state, at),
		}
		for i := range total.Series {
			entry := &total.Series[i]
			total.ActiveClients += entry.ActiveClients
			total.TotalSessions += entry.TotalSessions
			total.TotalQueries += entry.TotalQueries
			total.BytesIn += entry.BytesIn
			total.BytesOut += entry.BytesOut
			if total.PeakClients < entry.PeakClients {
				total.PeakClients = entry.PeakClients
			}
		}
		result.Tunnels[domain] = total
	}
	return result
}
// findTunnelLocked maps a query name to the tunnel it belongs to. The caller
// must hold c.mu.
//
// The name is normalized and looked up exactly first. Failing that, it is
// matched as a subdomain of a configured domain. When several configured
// domains are suffixes of the name (for example "example.com" and
// "t.example.com"), the longest one wins, so the result does not depend on
// the order the domains were configured in.
//
// It returns (nil, false) when the name falls under no configured domain.
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
	queryDomain = normalizeDomain(queryDomain)
	if tunnel, ok := c.tunnels[queryDomain]; ok {
		return tunnel, true
	}

	var (
		best    string
		matched bool
	)
	for _, domain := range c.domains {
		if !strings.HasSuffix(queryDomain, "."+domain) {
			continue
		}
		// Prefer the most specific (longest) enclosing domain.
		if !matched || len(domain) > len(best) {
			best, matched = domain, true
		}
	}
	if !matched {
		return nil, false
	}
	return c.tunnels[best], true
}
// ensureSeries returns the counters stored for key in tunnel, inserting a
// zeroed entry first when the key has none.
func ensureSeries(tunnel *tunnelState, key geoKey) *seriesState {
	if existing, found := tunnel.series[key]; found {
		return existing
	}
	created := &seriesState{}
	tunnel.series[key] = created
	return created
}
// updatePeaks raises the peakClients of each series to the number of clients
// currently active under it, if that number exceeds the stored peak.
//
// A client is active when less than ClientTimeout has passed between its last
// query and now, and it is counted under the label set of that last query. A
// series is created for any label set that has active clients but no entry.
func updatePeaks(tunnel *tunnelState, now time.Time) {
	current := make(map[geoKey]int)
	for _, session := range tunnel.clients {
		if now.Sub(session.lastSeen) >= ClientTimeout {
			continue
		}
		current[session.lastKey]++
	}

	for label, count := range current {
		if stats := ensureSeries(tunnel, label); stats.peakClients < count {
			stats.peakClients = count
		}
	}
}
// seriesSnapshotsLocked copies the per-label series of one tunnel into
// snapshot form, ordered by country and then by ASN. The caller must hold
// c.mu.
//
// Each client counts as one session under the label set of its first query.
// It counts as active under the label set of its latest query when that query
// is less than ClientTimeout older than now.
func (c *Collector) seriesSnapshotsLocked(domain string, tunnel *tunnelState, now time.Time) []SeriesSnapshot {
	active := make(map[geoKey]int)
	sessions := make(map[geoKey]int)
	for _, session := range tunnel.clients {
		sessions[session.firstKey]++
		if now.Sub(session.lastSeen) >= ClientTimeout {
			continue
		}
		active[session.lastKey]++
	}

	// Map iteration order is random; sort the labels for stable output.
	ordered := make([]geoKey, 0, len(tunnel.series))
	for label := range tunnel.series {
		ordered = append(ordered, label)
	}
	sort.Slice(ordered, func(a, b int) bool {
		left, right := ordered[a], ordered[b]
		if left.country == right.country {
			return left.asn < right.asn
		}
		return left.country < right.country
	})

	out := make([]SeriesSnapshot, len(ordered))
	for i, label := range ordered {
		stats := tunnel.series[label]
		out[i] = SeriesSnapshot{
			Domain:        domain,
			Country:       label.country,
			ASN:           label.asn,
			ActiveClients: active[label],
			PeakClients:   stats.peakClients,
			TotalSessions: sessions[label],
			TotalQueries:  stats.queries,
			BytesIn:       stats.bytesIn,
			BytesOut:      stats.bytesOut,
		}
	}
	return out
}
// geoLabelNamesLocked returns the resolver's label names restricted to the
// two this collector understands, "country" and "asn", in the resolver's
// order. It returns nil when no resolver is configured. The caller must hold
// c.mu.
func (c *Collector) geoLabelNamesLocked() []string {
	if c.geo == nil {
		return nil
	}

	advertised := c.geo.LabelNames()
	supported := make([]string, 0, len(advertised))
	for _, label := range advertised {
		switch label {
		case "country", "asn":
			supported = append(supported, label)
		}
	}
	return supported
}
// geoKeyLocked derives the label set under which traffic for addr is
// recorded. The caller must hold c.mu.
//
// Without a resolver the empty key is returned. Otherwise addr is looked up
// and only the labels the resolver advertises ("country", "asn") are filled
// in; an advertised label with an empty lookup value gets the corresponding
// Unknown* placeholder. Labels the resolver does not advertise stay empty.
func (c *Collector) geoKeyLocked(addr netip.Addr) geoKey {
	if c.geo == nil {
		return geoKey{}
	}

	found := c.geo.Lookup(addr)

	var label geoKey
	for _, name := range c.geo.LabelNames() {
		if name == "country" {
			label.country = found.Country
			if label.country == "" {
				label.country = UnknownCountry
			}
		} else if name == "asn" {
			label.asn = found.ASN
			if label.asn == "" {
				label.asn = UnknownASN
			}
		}
	}
	return label
}