add geoip country/asn labels and ipv6
This commit is contained in:
parent
8318f9fe70
commit
4710df2523
12 changed files with 559 additions and 43 deletions
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
// OpenRawSocket creates an AF_PACKET raw socket for sniffing IPv4 packets.
|
||||
func OpenRawSocket() (int, error) {
|
||||
fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_IP)))
|
||||
fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(ethPAll)))
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("open raw socket: %w", err)
|
||||
}
|
||||
|
|
@ -36,7 +36,7 @@ func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
|
|||
}
|
||||
return
|
||||
}
|
||||
ProcessIPv4Packet(buf[:n], port, collector)
|
||||
ProcessPacket(buf[:n], port, collector)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -48,3 +48,5 @@ func CloseRawSocket(fd int) error {
|
|||
func htons(v uint16) uint16 {
|
||||
return (v << 8) | (v >> 8)
|
||||
}
|
||||
|
||||
const ethPAll = 0x0003
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -15,14 +17,19 @@ type Collector struct {
|
|||
domains []string
|
||||
tunnels map[string]*tunnelState
|
||||
now func() time.Time
|
||||
geo GeoResolver
|
||||
}
|
||||
|
||||
type tunnelState struct {
|
||||
series map[geoKey]*seriesState
|
||||
clients map[string]*clientState
|
||||
}
|
||||
|
||||
type seriesState struct {
|
||||
queries uint64
|
||||
bytesIn uint64
|
||||
bytesOut uint64
|
||||
peakClients int
|
||||
clients map[string]*clientState
|
||||
}
|
||||
|
||||
type clientState struct {
|
||||
|
|
@ -30,6 +37,13 @@ type clientState struct {
|
|||
lastSeen time.Time
|
||||
queries uint64
|
||||
bytesIn uint64
|
||||
firstKey geoKey
|
||||
lastKey geoKey
|
||||
}
|
||||
|
||||
type geoKey struct {
|
||||
country string
|
||||
asn string
|
||||
}
|
||||
|
||||
// Snapshot is a point-in-time copy of observed traffic.
|
||||
|
|
@ -46,6 +60,20 @@ type TunnelSnapshot struct {
|
|||
TotalQueries uint64
|
||||
BytesIn uint64
|
||||
BytesOut uint64
|
||||
Series []SeriesSnapshot
|
||||
}
|
||||
|
||||
// SeriesSnapshot is a per-label metric series for a tunnel.
|
||||
type SeriesSnapshot struct {
|
||||
Domain string
|
||||
Country string
|
||||
ASN string
|
||||
ActiveClients int
|
||||
PeakClients int
|
||||
TotalSessions int
|
||||
TotalQueries uint64
|
||||
BytesIn uint64
|
||||
BytesOut uint64
|
||||
}
|
||||
|
||||
// CollectorOption configures a Collector.
|
||||
|
|
@ -58,6 +86,13 @@ func WithNow(now func() time.Time) CollectorOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithGeoResolver enables optional GeoIP labels for resolver IP addresses.
|
||||
func WithGeoResolver(resolver GeoResolver) CollectorOption {
|
||||
return func(c *Collector) {
|
||||
c.geo = resolver
|
||||
}
|
||||
}
|
||||
|
||||
// NewCollector creates a collector for the provided DNSTT domains.
|
||||
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
|
||||
c := &Collector{
|
||||
|
|
@ -75,6 +110,7 @@ func NewCollector(domains []string, opts ...CollectorOption) *Collector {
|
|||
}
|
||||
c.domains = append(c.domains, domain)
|
||||
c.tunnels[domain] = &tunnelState{
|
||||
series: make(map[geoKey]*seriesState),
|
||||
clients: make(map[string]*clientState),
|
||||
}
|
||||
}
|
||||
|
|
@ -86,6 +122,13 @@ func NewCollector(domains []string, opts ...CollectorOption) *Collector {
|
|||
return c
|
||||
}
|
||||
|
||||
// GeoLabelNames returns the optional GeoIP label names used by this collector.
|
||||
func (c *Collector) GeoLabelNames() []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.geoLabelNamesLocked()
|
||||
}
|
||||
|
||||
// Domains returns the normalized DNSTT domains this collector recognizes.
|
||||
func (c *Collector) Domains() []string {
|
||||
c.mu.Lock()
|
||||
|
|
@ -98,6 +141,11 @@ func (c *Collector) Domains() []string {
|
|||
|
||||
// RecordQuery records an observed DNSTT DNS query.
|
||||
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
|
||||
c.RecordQueryFrom(domain, clientID, netip.Addr{}, size)
|
||||
}
|
||||
|
||||
// RecordQueryFrom records an observed DNSTT DNS query from a resolver address.
|
||||
func (c *Collector) RecordQueryFrom(domain string, clientID string, resolverIP netip.Addr, size int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
|
@ -106,9 +154,11 @@ func (c *Collector) RecordQuery(domain string, clientID string, size int) {
|
|||
return
|
||||
}
|
||||
|
||||
tunnel.queries++
|
||||
key := c.geoKeyLocked(resolverIP)
|
||||
series := ensureSeries(tunnel, key)
|
||||
series.queries++
|
||||
if size > 0 {
|
||||
tunnel.bytesIn += uint64(size)
|
||||
series.bytesIn += uint64(size)
|
||||
}
|
||||
|
||||
if clientID == "" {
|
||||
|
|
@ -118,23 +168,26 @@ func (c *Collector) RecordQuery(domain string, clientID string, size int) {
|
|||
now := c.now()
|
||||
client, exists := tunnel.clients[clientID]
|
||||
if !exists {
|
||||
client = &clientState{firstSeen: now}
|
||||
client = &clientState{firstSeen: now, firstKey: key}
|
||||
tunnel.clients[clientID] = client
|
||||
}
|
||||
client.lastSeen = now
|
||||
client.lastKey = key
|
||||
client.queries++
|
||||
if size > 0 {
|
||||
client.bytesIn += uint64(size)
|
||||
}
|
||||
|
||||
active := activeClients(tunnel, now)
|
||||
if active > tunnel.peakClients {
|
||||
tunnel.peakClients = active
|
||||
}
|
||||
updatePeaks(tunnel, now)
|
||||
}
|
||||
|
||||
// RecordResponse records an observed DNSTT DNS response.
|
||||
func (c *Collector) RecordResponse(domain string, size int) {
|
||||
c.RecordResponseFrom(domain, netip.Addr{}, size)
|
||||
}
|
||||
|
||||
// RecordResponseFrom records an observed DNSTT DNS response to a resolver address.
|
||||
func (c *Collector) RecordResponseFrom(domain string, resolverIP netip.Addr, size int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
|
@ -142,8 +195,10 @@ func (c *Collector) RecordResponse(domain string, size int) {
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
key := c.geoKeyLocked(resolverIP)
|
||||
series := ensureSeries(tunnel, key)
|
||||
if size > 0 {
|
||||
tunnel.bytesOut += uint64(size)
|
||||
series.bytesOut += uint64(size)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -156,20 +211,21 @@ func (c *Collector) Snapshot() Snapshot {
|
|||
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
|
||||
for _, domain := range c.domains {
|
||||
tunnel := c.tunnels[domain]
|
||||
active := activeClients(tunnel, now)
|
||||
if active > tunnel.peakClients {
|
||||
tunnel.peakClients = active
|
||||
}
|
||||
updatePeaks(tunnel, now)
|
||||
series := c.seriesSnapshotsLocked(domain, tunnel, now)
|
||||
|
||||
snapshot.Tunnels[domain] = TunnelSnapshot{
|
||||
Domain: domain,
|
||||
ActiveClients: active,
|
||||
PeakClients: tunnel.peakClients,
|
||||
TotalSessions: len(tunnel.clients),
|
||||
TotalQueries: tunnel.queries,
|
||||
BytesIn: tunnel.bytesIn,
|
||||
BytesOut: tunnel.bytesOut,
|
||||
tunnelSnapshot := TunnelSnapshot{Domain: domain, Series: series}
|
||||
for _, s := range series {
|
||||
tunnelSnapshot.ActiveClients += s.ActiveClients
|
||||
tunnelSnapshot.TotalSessions += s.TotalSessions
|
||||
tunnelSnapshot.TotalQueries += s.TotalQueries
|
||||
tunnelSnapshot.BytesIn += s.BytesIn
|
||||
tunnelSnapshot.BytesOut += s.BytesOut
|
||||
if s.PeakClients > tunnelSnapshot.PeakClients {
|
||||
tunnelSnapshot.PeakClients = s.PeakClients
|
||||
}
|
||||
}
|
||||
snapshot.Tunnels[domain] = tunnelSnapshot
|
||||
}
|
||||
|
||||
return snapshot
|
||||
|
|
@ -188,12 +244,103 @@ func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func activeClients(tunnel *tunnelState, now time.Time) int {
|
||||
active := 0
|
||||
func ensureSeries(tunnel *tunnelState, key geoKey) *seriesState {
|
||||
series, ok := tunnel.series[key]
|
||||
if !ok {
|
||||
series = &seriesState{}
|
||||
tunnel.series[key] = series
|
||||
}
|
||||
return series
|
||||
}
|
||||
|
||||
func updatePeaks(tunnel *tunnelState, now time.Time) {
|
||||
activeByKey := make(map[geoKey]int)
|
||||
for _, client := range tunnel.clients {
|
||||
if now.Sub(client.lastSeen) < ClientTimeout {
|
||||
active++
|
||||
activeByKey[client.lastKey]++
|
||||
}
|
||||
}
|
||||
for key, active := range activeByKey {
|
||||
series := ensureSeries(tunnel, key)
|
||||
if active > series.peakClients {
|
||||
series.peakClients = active
|
||||
}
|
||||
}
|
||||
return active
|
||||
}
|
||||
|
||||
func (c *Collector) seriesSnapshotsLocked(domain string, tunnel *tunnelState, now time.Time) []SeriesSnapshot {
|
||||
activeByKey := make(map[geoKey]int)
|
||||
sessionsByKey := make(map[geoKey]int)
|
||||
for _, client := range tunnel.clients {
|
||||
sessionsByKey[client.firstKey]++
|
||||
if now.Sub(client.lastSeen) < ClientTimeout {
|
||||
activeByKey[client.lastKey]++
|
||||
}
|
||||
}
|
||||
|
||||
keys := make([]geoKey, 0, len(tunnel.series))
|
||||
for key := range tunnel.series {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
if keys[i].country != keys[j].country {
|
||||
return keys[i].country < keys[j].country
|
||||
}
|
||||
return keys[i].asn < keys[j].asn
|
||||
})
|
||||
|
||||
series := make([]SeriesSnapshot, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
state := tunnel.series[key]
|
||||
series = append(series, SeriesSnapshot{
|
||||
Domain: domain,
|
||||
Country: key.country,
|
||||
ASN: key.asn,
|
||||
ActiveClients: activeByKey[key],
|
||||
PeakClients: state.peakClients,
|
||||
TotalSessions: sessionsByKey[key],
|
||||
TotalQueries: state.queries,
|
||||
BytesIn: state.bytesIn,
|
||||
BytesOut: state.bytesOut,
|
||||
})
|
||||
}
|
||||
return series
|
||||
}
|
||||
|
||||
func (c *Collector) geoLabelNamesLocked() []string {
|
||||
if c.geo == nil {
|
||||
return nil
|
||||
}
|
||||
names := c.geo.LabelNames()
|
||||
out := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
if name == "country" || name == "asn" {
|
||||
out = append(out, name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *Collector) geoKeyLocked(addr netip.Addr) geoKey {
|
||||
var key geoKey
|
||||
if c.geo == nil {
|
||||
return key
|
||||
}
|
||||
|
||||
labels := c.geo.Lookup(addr)
|
||||
for _, name := range c.geoLabelNamesLocked() {
|
||||
switch name {
|
||||
case "country":
|
||||
key.country = labels.Country
|
||||
if key.country == "" {
|
||||
key.country = UnknownCountry
|
||||
}
|
||||
case "asn":
|
||||
key.asn = labels.ASN
|
||||
if key.asn == "" {
|
||||
key.asn = UnknownASN
|
||||
}
|
||||
}
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ type Exporter struct {
|
|||
|
||||
// NewExporter creates a Prometheus collector for DNSTT metrics.
|
||||
func NewExporter(collector *Collector) *Exporter {
|
||||
labels := []string{"domain"}
|
||||
labels := append([]string{"domain"}, collector.GeoLabelNames()...)
|
||||
return &Exporter{
|
||||
collector: collector,
|
||||
activeClients: prometheus.NewDesc(
|
||||
|
|
@ -73,11 +73,27 @@ func (e *Exporter) Describe(ch chan<- *prometheus.Desc) {
|
|||
// Collect sends current metric values to Prometheus.
|
||||
func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
|
||||
for _, tunnel := range e.collector.Snapshot().Tunnels {
|
||||
ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(tunnel.ActiveClients), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(tunnel.PeakClients), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(tunnel.TotalQueries), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(tunnel.BytesIn), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(tunnel.BytesOut), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(tunnel.TotalSessions), tunnel.Domain)
|
||||
for _, series := range tunnel.Series {
|
||||
labels := e.labelValues(series)
|
||||
ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(series.ActiveClients), labels...)
|
||||
ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(series.PeakClients), labels...)
|
||||
ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(series.TotalQueries), labels...)
|
||||
ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(series.BytesIn), labels...)
|
||||
ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(series.BytesOut), labels...)
|
||||
ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(series.TotalSessions), labels...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Exporter) labelValues(series SeriesSnapshot) []string {
|
||||
values := []string{series.Domain}
|
||||
for _, label := range e.collector.GeoLabelNames() {
|
||||
switch label {
|
||||
case "country":
|
||||
values = append(values, series.Country)
|
||||
case "asn":
|
||||
values = append(values, series.ASN)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -40,3 +41,101 @@ dnstt_sessions_total{domain="tunnel.example.com"} 2
|
|||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExporterCollectsGeoIPCountryAndASNLabels(t *testing.T) {
|
||||
now := time.Unix(1000, 0)
|
||||
resolverIP := netip.MustParseAddr("2001:db8::53")
|
||||
c := NewCollector(
|
||||
[]string{"tunnel.example.com"},
|
||||
WithNow(func() time.Time { return now }),
|
||||
WithGeoResolver(fakeGeoResolver{
|
||||
labelNames: []string{"country", "asn"},
|
||||
labels: map[netip.Addr]GeoLabels{
|
||||
resolverIP: {Country: "DE", ASN: "3320"},
|
||||
},
|
||||
}),
|
||||
)
|
||||
c.RecordQueryFrom("tunnel.example.com", "client-a", resolverIP, 100)
|
||||
|
||||
expected := `
|
||||
# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window.
|
||||
# TYPE dnstt_active_clients gauge
|
||||
dnstt_active_clients{asn="3320",country="DE",domain="tunnel.example.com"} 1
|
||||
# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries.
|
||||
# TYPE dnstt_bytes_in_total counter
|
||||
dnstt_bytes_in_total{asn="3320",country="DE",domain="tunnel.example.com"} 100
|
||||
# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses.
|
||||
# TYPE dnstt_bytes_out_total counter
|
||||
dnstt_bytes_out_total{asn="3320",country="DE",domain="tunnel.example.com"} 0
|
||||
# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed.
|
||||
# TYPE dnstt_peak_clients gauge
|
||||
dnstt_peak_clients{asn="3320",country="DE",domain="tunnel.example.com"} 1
|
||||
# HELP dnstt_queries_total Total DNSTT DNS queries observed.
|
||||
# TYPE dnstt_queries_total counter
|
||||
dnstt_queries_total{asn="3320",country="DE",domain="tunnel.example.com"} 1
|
||||
# HELP dnstt_sessions_total Total unique DNSTT client sessions observed.
|
||||
# TYPE dnstt_sessions_total counter
|
||||
dnstt_sessions_total{asn="3320",country="DE",domain="tunnel.example.com"} 1
|
||||
`
|
||||
|
||||
if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExporterCanCollectASNWithoutCountry(t *testing.T) {
|
||||
now := time.Unix(1000, 0)
|
||||
resolverIP := netip.MustParseAddr("192.0.2.53")
|
||||
c := NewCollector(
|
||||
[]string{"tunnel.example.com"},
|
||||
WithNow(func() time.Time { return now }),
|
||||
WithGeoResolver(fakeGeoResolver{
|
||||
labelNames: []string{"asn"},
|
||||
labels: map[netip.Addr]GeoLabels{
|
||||
resolverIP: {ASN: "15169"},
|
||||
},
|
||||
}),
|
||||
)
|
||||
c.RecordQueryFrom("tunnel.example.com", "client-a", resolverIP, 100)
|
||||
|
||||
expected := `
|
||||
# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window.
|
||||
# TYPE dnstt_active_clients gauge
|
||||
dnstt_active_clients{asn="15169",domain="tunnel.example.com"} 1
|
||||
# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries.
|
||||
# TYPE dnstt_bytes_in_total counter
|
||||
dnstt_bytes_in_total{asn="15169",domain="tunnel.example.com"} 100
|
||||
# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses.
|
||||
# TYPE dnstt_bytes_out_total counter
|
||||
dnstt_bytes_out_total{asn="15169",domain="tunnel.example.com"} 0
|
||||
# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed.
|
||||
# TYPE dnstt_peak_clients gauge
|
||||
dnstt_peak_clients{asn="15169",domain="tunnel.example.com"} 1
|
||||
# HELP dnstt_queries_total Total DNSTT DNS queries observed.
|
||||
# TYPE dnstt_queries_total counter
|
||||
dnstt_queries_total{asn="15169",domain="tunnel.example.com"} 1
|
||||
# HELP dnstt_sessions_total Total unique DNSTT client sessions observed.
|
||||
# TYPE dnstt_sessions_total counter
|
||||
dnstt_sessions_total{asn="15169",domain="tunnel.example.com"} 1
|
||||
`
|
||||
|
||||
if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeGeoResolver struct {
|
||||
labelNames []string
|
||||
labels map[netip.Addr]GeoLabels
|
||||
}
|
||||
|
||||
func (f fakeGeoResolver) Lookup(addr netip.Addr) GeoLabels {
|
||||
if labels, ok := f.labels[addr]; ok {
|
||||
return labels
|
||||
}
|
||||
return GeoLabels{Country: UnknownCountry, ASN: UnknownASN}
|
||||
}
|
||||
|
||||
func (f fakeGeoResolver) LabelNames() []string {
|
||||
return f.labelNames
|
||||
}
|
||||
|
|
|
|||
108
internal/dnstt/geoip.go
Normal file
108
internal/dnstt/geoip.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
geoip2 "github.com/oschwald/geoip2-golang/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// UnknownCountry is used when country lookup is enabled but no country is found.
|
||||
UnknownCountry = "ZZ"
|
||||
// UnknownASN is used when ASN lookup is enabled but no ASN is found.
|
||||
UnknownASN = "0"
|
||||
)
|
||||
|
||||
// GeoLabels are optional labels derived from a resolver IP address.
|
||||
type GeoLabels struct {
|
||||
Country string
|
||||
ASN string
|
||||
}
|
||||
|
||||
// GeoResolver resolves optional GeoIP labels for resolver IP addresses.
|
||||
type GeoResolver interface {
|
||||
Lookup(netip.Addr) GeoLabels
|
||||
LabelNames() []string
|
||||
}
|
||||
|
||||
// GeoIPResolver uses optional MaxMind Country and ASN databases.
|
||||
type GeoIPResolver struct {
|
||||
countryDB *geoip2.Reader
|
||||
asnDB *geoip2.Reader
|
||||
}
|
||||
|
||||
// OpenGeoIPResolver opens optional MaxMind Country and ASN databases.
|
||||
func OpenGeoIPResolver(countryDatabase string, asnDatabase string) (*GeoIPResolver, error) {
|
||||
resolver := &GeoIPResolver{}
|
||||
if countryDatabase != "" {
|
||||
db, err := geoip2.Open(countryDatabase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open country database: %w", err)
|
||||
}
|
||||
resolver.countryDB = db
|
||||
}
|
||||
if asnDatabase != "" {
|
||||
db, err := geoip2.Open(asnDatabase)
|
||||
if err != nil {
|
||||
resolver.Close()
|
||||
return nil, fmt.Errorf("open ASN database: %w", err)
|
||||
}
|
||||
resolver.asnDB = db
|
||||
}
|
||||
return resolver, nil
|
||||
}
|
||||
|
||||
// Close closes any open MaxMind databases.
|
||||
func (r *GeoIPResolver) Close() {
|
||||
if r.countryDB != nil {
|
||||
r.countryDB.Close()
|
||||
}
|
||||
if r.asnDB != nil {
|
||||
r.asnDB.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// LabelNames returns the optional Prometheus labels enabled by configured databases.
|
||||
func (r *GeoIPResolver) LabelNames() []string {
|
||||
var labels []string
|
||||
if r.countryDB != nil {
|
||||
labels = append(labels, "country")
|
||||
}
|
||||
if r.asnDB != nil {
|
||||
labels = append(labels, "asn")
|
||||
}
|
||||
return labels
|
||||
}
|
||||
|
||||
// Lookup returns optional GeoIP labels for a resolver IP address.
|
||||
func (r *GeoIPResolver) Lookup(addr netip.Addr) GeoLabels {
|
||||
labels := GeoLabels{}
|
||||
if !addr.IsValid() {
|
||||
if r.countryDB != nil {
|
||||
labels.Country = UnknownCountry
|
||||
}
|
||||
if r.asnDB != nil {
|
||||
labels.ASN = UnknownASN
|
||||
}
|
||||
return labels
|
||||
}
|
||||
|
||||
if r.countryDB != nil {
|
||||
labels.Country = UnknownCountry
|
||||
record, err := r.countryDB.Country(addr)
|
||||
if err == nil && record.HasData() && record.Country.ISOCode != "" {
|
||||
labels.Country = record.Country.ISOCode
|
||||
}
|
||||
}
|
||||
if r.asnDB != nil {
|
||||
labels.ASN = UnknownASN
|
||||
record, err := r.asnDB.ASN(addr)
|
||||
if err == nil && record.HasData() && record.AutonomousSystemNumber != 0 {
|
||||
labels.ASN = strconv.FormatUint(uint64(record.AutonomousSystemNumber), 10)
|
||||
}
|
||||
}
|
||||
|
||||
return labels
|
||||
}
|
||||
|
|
@ -2,9 +2,23 @@ package dnstt
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProcessPacket records DNSTT DNS traffic from a raw IP packet.
|
||||
func ProcessPacket(data []byte, port int, collector *Collector) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
switch data[0] >> 4 {
|
||||
case 4:
|
||||
ProcessIPv4Packet(data, port, collector)
|
||||
case 6:
|
||||
ProcessIPv6Packet(data, port, collector)
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet.
|
||||
func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
|
||||
if len(data) < 20 || data[0]>>4 != 4 {
|
||||
|
|
@ -18,8 +32,39 @@ func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
|
|||
if data[9] != 17 {
|
||||
return
|
||||
}
|
||||
srcIP := netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]})
|
||||
|
||||
udpData := data[ihl:]
|
||||
processUDPPayload(udpData, srcIP, dstIP, port, collector)
|
||||
}
|
||||
|
||||
// ProcessIPv6Packet records DNSTT DNS traffic from a raw IPv6 packet.
|
||||
func ProcessIPv6Packet(data []byte, port int, collector *Collector) {
|
||||
if len(data) < 40 || data[0]>>4 != 6 {
|
||||
return
|
||||
}
|
||||
if data[6] != 17 {
|
||||
return
|
||||
}
|
||||
|
||||
var srcBytes [16]byte
|
||||
copy(srcBytes[:], data[8:24])
|
||||
srcIP := netip.AddrFrom16(srcBytes)
|
||||
|
||||
var dstBytes [16]byte
|
||||
copy(dstBytes[:], data[24:40])
|
||||
dstIP := netip.AddrFrom16(dstBytes)
|
||||
|
||||
payloadLen := int(binary.BigEndian.Uint16(data[4:6]))
|
||||
if payloadLen <= 0 || 40+payloadLen > len(data) {
|
||||
return
|
||||
}
|
||||
|
||||
processUDPPayload(data[40:40+payloadLen], srcIP, dstIP, port, collector)
|
||||
}
|
||||
|
||||
func processUDPPayload(udpData []byte, srcIP netip.Addr, dstIP netip.Addr, port int, collector *Collector) {
|
||||
if len(udpData) < 8 {
|
||||
return
|
||||
}
|
||||
|
|
@ -35,7 +80,7 @@ func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
|
|||
if int(dstPort) == port {
|
||||
domain, clientID := ExtractQuery(dnsPayload, collector.Domains())
|
||||
if domain != "" {
|
||||
collector.RecordQuery(domain, clientID, udpLen)
|
||||
collector.RecordQueryFrom(domain, clientID, srcIP, udpLen)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -43,7 +88,7 @@ func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
|
|||
if int(srcPort) == port {
|
||||
domain := extractQueryDomain(dnsPayload)
|
||||
if domain != "" {
|
||||
collector.RecordResponse(domain, udpLen)
|
||||
collector.RecordResponseFrom(domain, dstIP, udpLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
56
internal/dnstt/packet_test.go
Normal file
56
internal/dnstt/packet_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestProcessIPv6PacketRecordsResolverAddress(t *testing.T) {
|
||||
now := time.Unix(1000, 0)
|
||||
resolverIP := netip.MustParseAddr("2001:db8::53")
|
||||
serverIP := netip.MustParseAddr("2001:db8::1")
|
||||
c := NewCollector(
|
||||
[]string{"tunnel.example.com"},
|
||||
WithNow(func() time.Time { return now }),
|
||||
WithGeoResolver(fakeGeoResolver{
|
||||
labelNames: []string{"asn"},
|
||||
labels: map[netip.Addr]GeoLabels{
|
||||
resolverIP: {ASN: "13335"},
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
ProcessPacket(ipv6UDPPacket(resolverIP, serverIP, 53000, 53, dnsQuery("MFRGGZDFMZTWQ2LK", "tunnel.example.com")), 53, c)
|
||||
|
||||
series := c.Snapshot().Tunnels["tunnel.example.com"].Series
|
||||
if len(series) != 1 {
|
||||
t.Fatalf("series count = %d, want 1", len(series))
|
||||
}
|
||||
if series[0].ASN != "13335" {
|
||||
t.Fatalf("ASN = %q, want 13335", series[0].ASN)
|
||||
}
|
||||
if series[0].TotalQueries != 1 {
|
||||
t.Fatalf("queries = %d, want 1", series[0].TotalQueries)
|
||||
}
|
||||
}
|
||||
|
||||
func ipv6UDPPacket(src, dst netip.Addr, srcPort, dstPort uint16, payload []byte) []byte {
|
||||
packet := make([]byte, 40+8+len(payload))
|
||||
packet[0] = 0x60
|
||||
binary.BigEndian.PutUint16(packet[4:6], uint16(8+len(payload)))
|
||||
packet[6] = 17
|
||||
packet[7] = 64
|
||||
srcBytes := src.As16()
|
||||
dstBytes := dst.As16()
|
||||
copy(packet[8:24], srcBytes[:])
|
||||
copy(packet[24:40], dstBytes[:])
|
||||
|
||||
udp := packet[40:]
|
||||
binary.BigEndian.PutUint16(udp[0:2], srcPort)
|
||||
binary.BigEndian.PutUint16(udp[2:4], dstPort)
|
||||
binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(payload)))
|
||||
copy(udp[8:], payload)
|
||||
return packet
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue