346 lines
7.9 KiB
Go
346 lines
7.9 KiB
Go
package dnstt
|
|
|
|
import (
|
|
"net/netip"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// ClientTimeout is the inactivity window for a client session: a client whose
// most recent query happened ClientTimeout or more ago no longer counts as active.
const ClientTimeout time.Duration = 30 * time.Second
|
|
|
|
// Collector accumulates observed DNSTT DNS traffic.
|
|
type Collector struct {
|
|
mu sync.Mutex
|
|
domains []string
|
|
tunnels map[string]*tunnelState
|
|
now func() time.Time
|
|
geo GeoResolver
|
|
}
|
|
|
|
type tunnelState struct {
|
|
series map[geoKey]*seriesState
|
|
clients map[string]*clientState
|
|
}
|
|
|
|
// seriesState is the running set of counters for one (domain, GeoIP labels) series.
type seriesState struct {
	// queries counts queries; bytesIn sums positive query sizes; bytesOut sums
	// positive response sizes.
	queries, bytesIn, bytesOut uint64

	// peakClients is the highest active-client count seen for this series.
	peakClients int
}
|
|
|
|
type clientState struct {
|
|
firstSeen time.Time
|
|
lastSeen time.Time
|
|
queries uint64
|
|
bytesIn uint64
|
|
firstKey geoKey
|
|
lastKey geoKey
|
|
}
|
|
|
|
// geoKey identifies a metric series by its GeoIP labels. A label that is not
// enabled (or no resolver at all) leaves the corresponding field empty.
type geoKey struct {
	country, asn string
}
|
|
|
|
// Snapshot is a point-in-time copy of observed traffic.
|
|
type Snapshot struct {
|
|
Tunnels map[string]TunnelSnapshot
|
|
}
|
|
|
|
// TunnelSnapshot is a point-in-time copy of one DNSTT domain.
|
|
type TunnelSnapshot struct {
|
|
Domain string
|
|
ActiveClients int
|
|
PeakClients int
|
|
TotalSessions int
|
|
TotalQueries uint64
|
|
BytesIn uint64
|
|
BytesOut uint64
|
|
Series []SeriesSnapshot
|
|
}
|
|
|
|
// SeriesSnapshot is a per-label metric series for a tunnel.
type SeriesSnapshot struct {
	// Domain is the tunnel domain; Country and ASN are the GeoIP labels and are
	// empty when the corresponding label is not enabled.
	Domain, Country, ASN string

	// ActiveClients counts clients whose most recent query mapped to this
	// series; PeakClients is the highest such count recorded; TotalSessions
	// counts clients whose first query mapped to this series.
	ActiveClients, PeakClients, TotalSessions int

	// TotalQueries and BytesIn come from queries, BytesOut from responses.
	TotalQueries, BytesIn, BytesOut uint64
}
|
|
|
|
// CollectorOption configures a Collector.
|
|
type CollectorOption func(*Collector)
|
|
|
|
// WithNow overrides the clock. It is intended for deterministic tests.
|
|
func WithNow(now func() time.Time) CollectorOption {
|
|
return func(c *Collector) {
|
|
c.now = now
|
|
}
|
|
}
|
|
|
|
// WithGeoResolver enables optional GeoIP labels for resolver IP addresses.
|
|
func WithGeoResolver(resolver GeoResolver) CollectorOption {
|
|
return func(c *Collector) {
|
|
c.geo = resolver
|
|
}
|
|
}
|
|
|
|
// NewCollector creates a collector for the provided DNSTT domains.
|
|
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
|
|
c := &Collector{
|
|
tunnels: make(map[string]*tunnelState),
|
|
now: time.Now,
|
|
}
|
|
|
|
for _, domain := range domains {
|
|
domain = normalizeDomain(domain)
|
|
if domain == "" {
|
|
continue
|
|
}
|
|
if _, exists := c.tunnels[domain]; exists {
|
|
continue
|
|
}
|
|
c.domains = append(c.domains, domain)
|
|
c.tunnels[domain] = &tunnelState{
|
|
series: make(map[geoKey]*seriesState),
|
|
clients: make(map[string]*clientState),
|
|
}
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(c)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// GeoLabelNames returns the optional GeoIP label names used by this collector.
|
|
func (c *Collector) GeoLabelNames() []string {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.geoLabelNamesLocked()
|
|
}
|
|
|
|
// Domains returns the normalized DNSTT domains this collector recognizes.
|
|
func (c *Collector) Domains() []string {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
domains := make([]string, len(c.domains))
|
|
copy(domains, c.domains)
|
|
return domains
|
|
}
|
|
|
|
// RecordQuery records an observed DNSTT DNS query.
|
|
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
|
|
c.RecordQueryFrom(domain, clientID, netip.Addr{}, size)
|
|
}
|
|
|
|
// RecordQueryFrom records an observed DNSTT DNS query from a resolver address.
|
|
func (c *Collector) RecordQueryFrom(domain string, clientID string, resolverIP netip.Addr, size int) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
tunnel, ok := c.findTunnelLocked(domain)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
key := c.geoKeyLocked(resolverIP)
|
|
series := ensureSeries(tunnel, key)
|
|
series.queries++
|
|
if size > 0 {
|
|
series.bytesIn += uint64(size)
|
|
}
|
|
|
|
if clientID == "" {
|
|
return
|
|
}
|
|
|
|
now := c.now()
|
|
client, exists := tunnel.clients[clientID]
|
|
if !exists {
|
|
client = &clientState{firstSeen: now, firstKey: key}
|
|
tunnel.clients[clientID] = client
|
|
}
|
|
client.lastSeen = now
|
|
client.lastKey = key
|
|
client.queries++
|
|
if size > 0 {
|
|
client.bytesIn += uint64(size)
|
|
}
|
|
|
|
updatePeaks(tunnel, now)
|
|
}
|
|
|
|
// RecordResponse records an observed DNSTT DNS response.
|
|
func (c *Collector) RecordResponse(domain string, size int) {
|
|
c.RecordResponseFrom(domain, netip.Addr{}, size)
|
|
}
|
|
|
|
// RecordResponseFrom records an observed DNSTT DNS response to a resolver address.
|
|
func (c *Collector) RecordResponseFrom(domain string, resolverIP netip.Addr, size int) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
tunnel, ok := c.findTunnelLocked(domain)
|
|
if !ok {
|
|
return
|
|
}
|
|
key := c.geoKeyLocked(resolverIP)
|
|
series := ensureSeries(tunnel, key)
|
|
if size > 0 {
|
|
series.bytesOut += uint64(size)
|
|
}
|
|
}
|
|
|
|
// Snapshot returns a stable copy of all current metrics.
|
|
func (c *Collector) Snapshot() Snapshot {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
now := c.now()
|
|
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
|
|
for _, domain := range c.domains {
|
|
tunnel := c.tunnels[domain]
|
|
updatePeaks(tunnel, now)
|
|
series := c.seriesSnapshotsLocked(domain, tunnel, now)
|
|
|
|
tunnelSnapshot := TunnelSnapshot{Domain: domain, Series: series}
|
|
for _, s := range series {
|
|
tunnelSnapshot.ActiveClients += s.ActiveClients
|
|
tunnelSnapshot.TotalSessions += s.TotalSessions
|
|
tunnelSnapshot.TotalQueries += s.TotalQueries
|
|
tunnelSnapshot.BytesIn += s.BytesIn
|
|
tunnelSnapshot.BytesOut += s.BytesOut
|
|
if s.PeakClients > tunnelSnapshot.PeakClients {
|
|
tunnelSnapshot.PeakClients = s.PeakClients
|
|
}
|
|
}
|
|
snapshot.Tunnels[domain] = tunnelSnapshot
|
|
}
|
|
|
|
return snapshot
|
|
}
|
|
|
|
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
|
|
queryDomain = normalizeDomain(queryDomain)
|
|
if tunnel, ok := c.tunnels[queryDomain]; ok {
|
|
return tunnel, true
|
|
}
|
|
for _, domain := range c.domains {
|
|
if strings.HasSuffix(queryDomain, "."+domain) {
|
|
return c.tunnels[domain], true
|
|
}
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func ensureSeries(tunnel *tunnelState, key geoKey) *seriesState {
|
|
series, ok := tunnel.series[key]
|
|
if !ok {
|
|
series = &seriesState{}
|
|
tunnel.series[key] = series
|
|
}
|
|
return series
|
|
}
|
|
|
|
func updatePeaks(tunnel *tunnelState, now time.Time) {
|
|
activeByKey := make(map[geoKey]int)
|
|
for _, client := range tunnel.clients {
|
|
if now.Sub(client.lastSeen) < ClientTimeout {
|
|
activeByKey[client.lastKey]++
|
|
}
|
|
}
|
|
for key, active := range activeByKey {
|
|
series := ensureSeries(tunnel, key)
|
|
if active > series.peakClients {
|
|
series.peakClients = active
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Collector) seriesSnapshotsLocked(domain string, tunnel *tunnelState, now time.Time) []SeriesSnapshot {
|
|
activeByKey := make(map[geoKey]int)
|
|
sessionsByKey := make(map[geoKey]int)
|
|
for _, client := range tunnel.clients {
|
|
sessionsByKey[client.firstKey]++
|
|
if now.Sub(client.lastSeen) < ClientTimeout {
|
|
activeByKey[client.lastKey]++
|
|
}
|
|
}
|
|
|
|
keys := make([]geoKey, 0, len(tunnel.series))
|
|
for key := range tunnel.series {
|
|
keys = append(keys, key)
|
|
}
|
|
sort.Slice(keys, func(i, j int) bool {
|
|
if keys[i].country != keys[j].country {
|
|
return keys[i].country < keys[j].country
|
|
}
|
|
return keys[i].asn < keys[j].asn
|
|
})
|
|
|
|
series := make([]SeriesSnapshot, 0, len(keys))
|
|
for _, key := range keys {
|
|
state := tunnel.series[key]
|
|
series = append(series, SeriesSnapshot{
|
|
Domain: domain,
|
|
Country: key.country,
|
|
ASN: key.asn,
|
|
ActiveClients: activeByKey[key],
|
|
PeakClients: state.peakClients,
|
|
TotalSessions: sessionsByKey[key],
|
|
TotalQueries: state.queries,
|
|
BytesIn: state.bytesIn,
|
|
BytesOut: state.bytesOut,
|
|
})
|
|
}
|
|
return series
|
|
}
|
|
|
|
func (c *Collector) geoLabelNamesLocked() []string {
|
|
if c.geo == nil {
|
|
return nil
|
|
}
|
|
names := c.geo.LabelNames()
|
|
out := make([]string, 0, len(names))
|
|
for _, name := range names {
|
|
if name == "country" || name == "asn" {
|
|
out = append(out, name)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (c *Collector) geoKeyLocked(addr netip.Addr) geoKey {
|
|
var key geoKey
|
|
if c.geo == nil {
|
|
return key
|
|
}
|
|
|
|
labels := c.geo.Lookup(addr)
|
|
for _, name := range c.geoLabelNamesLocked() {
|
|
switch name {
|
|
case "country":
|
|
key.country = labels.Country
|
|
if key.country == "" {
|
|
key.country = UnknownCountry
|
|
}
|
|
case "asn":
|
|
key.asn = labels.ASN
|
|
if key.asn == "" {
|
|
key.asn = UnknownASN
|
|
}
|
|
}
|
|
}
|
|
return key
|
|
}
|