199 lines
4.2 KiB
Go
199 lines
4.2 KiB
Go
package dnstt
|
|
|
|
import (
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// ClientTimeout is the inactivity window for a client session: once this much
// time has elapsed since a client's most recent query, that client is no
// longer counted as active.
const ClientTimeout time.Duration = 30 * time.Second
|
|
|
|
// Collector accumulates observed DNSTT DNS traffic.
//
// Construct one with NewCollector. The methods defined on Collector in this
// file hold mu while they read or update its state.
type Collector struct {
	// mu serializes access to the collector's state.
	mu sync.Mutex

	// domains lists the normalized tunnel domains in the order they were first
	// supplied to NewCollector, with empty names and duplicates dropped.
	domains []string

	// tunnels maps each entry of domains to its accumulated counters.
	tunnels map[string]*tunnelState

	// now is the clock used to timestamp client activity. NewCollector sets it
	// to time.Now; WithNow can override it.
	now func() time.Time
}
|
|
|
|
type tunnelState struct {
|
|
queries uint64
|
|
bytesIn uint64
|
|
bytesOut uint64
|
|
peakClients int
|
|
clients map[string]*clientState
|
|
}
|
|
|
|
// clientState records the activity of a single client ID within one tunnel.
type clientState struct {
	// Timestamps of the first and the most recent query from this client.
	firstSeen, lastSeen time.Time

	// Number of queries and summed query bytes attributed to this client.
	queries, bytesIn uint64
}
|
|
|
|
// Snapshot is a point-in-time copy of observed traffic.
type Snapshot struct {
	// Tunnels maps each configured (normalized) tunnel domain to a copy of its
	// metrics at the moment the snapshot was taken.
	Tunnels map[string]TunnelSnapshot
}
|
|
|
|
// TunnelSnapshot is a point-in-time copy of the metrics for one DNSTT domain.
type TunnelSnapshot struct {
	// Domain is the normalized tunnel domain these metrics belong to.
	Domain string

	// ActiveClients is the number of clients heard from within ClientTimeout,
	// PeakClients the highest such count recorded, and TotalSessions the
	// number of distinct client IDs seen for this domain.
	ActiveClients, PeakClients, TotalSessions int

	// TotalQueries counts the queries attributed to this domain; BytesIn and
	// BytesOut sum the positive sizes of its queries and responses.
	TotalQueries, BytesIn, BytesOut uint64
}
|
|
|
|
// CollectorOption configures a Collector. Options are passed to NewCollector,
// which applies them in order after it has built the domain table.
type CollectorOption func(*Collector)
|
|
|
|
// WithNow overrides the clock. It is intended for deterministic tests.
|
|
func WithNow(now func() time.Time) CollectorOption {
|
|
return func(c *Collector) {
|
|
c.now = now
|
|
}
|
|
}
|
|
|
|
// NewCollector creates a collector for the provided DNSTT domains.
|
|
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
|
|
c := &Collector{
|
|
tunnels: make(map[string]*tunnelState),
|
|
now: time.Now,
|
|
}
|
|
|
|
for _, domain := range domains {
|
|
domain = normalizeDomain(domain)
|
|
if domain == "" {
|
|
continue
|
|
}
|
|
if _, exists := c.tunnels[domain]; exists {
|
|
continue
|
|
}
|
|
c.domains = append(c.domains, domain)
|
|
c.tunnels[domain] = &tunnelState{
|
|
clients: make(map[string]*clientState),
|
|
}
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(c)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// Domains returns the normalized DNSTT domains this collector recognizes.
|
|
func (c *Collector) Domains() []string {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
domains := make([]string, len(c.domains))
|
|
copy(domains, c.domains)
|
|
return domains
|
|
}
|
|
|
|
// RecordQuery records an observed DNSTT DNS query.
|
|
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
tunnel, ok := c.findTunnelLocked(domain)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
tunnel.queries++
|
|
if size > 0 {
|
|
tunnel.bytesIn += uint64(size)
|
|
}
|
|
|
|
if clientID == "" {
|
|
return
|
|
}
|
|
|
|
now := c.now()
|
|
client, exists := tunnel.clients[clientID]
|
|
if !exists {
|
|
client = &clientState{firstSeen: now}
|
|
tunnel.clients[clientID] = client
|
|
}
|
|
client.lastSeen = now
|
|
client.queries++
|
|
if size > 0 {
|
|
client.bytesIn += uint64(size)
|
|
}
|
|
|
|
active := activeClients(tunnel, now)
|
|
if active > tunnel.peakClients {
|
|
tunnel.peakClients = active
|
|
}
|
|
}
|
|
|
|
// RecordResponse records an observed DNSTT DNS response.
|
|
func (c *Collector) RecordResponse(domain string, size int) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
tunnel, ok := c.findTunnelLocked(domain)
|
|
if !ok {
|
|
return
|
|
}
|
|
if size > 0 {
|
|
tunnel.bytesOut += uint64(size)
|
|
}
|
|
}
|
|
|
|
// Snapshot returns a stable copy of all current metrics.
|
|
func (c *Collector) Snapshot() Snapshot {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
now := c.now()
|
|
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
|
|
for _, domain := range c.domains {
|
|
tunnel := c.tunnels[domain]
|
|
active := activeClients(tunnel, now)
|
|
if active > tunnel.peakClients {
|
|
tunnel.peakClients = active
|
|
}
|
|
|
|
snapshot.Tunnels[domain] = TunnelSnapshot{
|
|
Domain: domain,
|
|
ActiveClients: active,
|
|
PeakClients: tunnel.peakClients,
|
|
TotalSessions: len(tunnel.clients),
|
|
TotalQueries: tunnel.queries,
|
|
BytesIn: tunnel.bytesIn,
|
|
BytesOut: tunnel.bytesOut,
|
|
}
|
|
}
|
|
|
|
return snapshot
|
|
}
|
|
|
|
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
|
|
queryDomain = normalizeDomain(queryDomain)
|
|
if tunnel, ok := c.tunnels[queryDomain]; ok {
|
|
return tunnel, true
|
|
}
|
|
for _, domain := range c.domains {
|
|
if strings.HasSuffix(queryDomain, "."+domain) {
|
|
return c.tunnels[domain], true
|
|
}
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func activeClients(tunnel *tunnelState, now time.Time) int {
|
|
active := 0
|
|
for _, client := range tunnel.clients {
|
|
if now.Sub(client.lastSeen) < ClientTimeout {
|
|
active++
|
|
}
|
|
}
|
|
return active
|
|
}
|