initial working version
This commit is contained in:
parent
4d8b83cbb6
commit
8318f9fe70
15 changed files with 917 additions and 4 deletions
50
internal/dnstt/capture_linux.go
Normal file
50
internal/dnstt/capture_linux.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
//go:build linux
|
||||
|
||||
package dnstt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// OpenRawSocket creates an AF_PACKET raw socket for sniffing IPv4 packets.
|
||||
func OpenRawSocket() (int, error) {
|
||||
fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_IP)))
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("open raw socket: %w", err)
|
||||
}
|
||||
|
||||
tv := syscall.Timeval{Sec: 0, Usec: 500000}
|
||||
_ = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv)
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
// CaptureLoop records matching packets until stop is closed.
|
||||
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, _, err := syscall.Recvfrom(fd, buf, 0)
|
||||
if err != nil {
|
||||
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK || err == syscall.EINTR {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
ProcessIPv4Packet(buf[:n], port, collector)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseRawSocket closes a socket opened by OpenRawSocket.
|
||||
func CloseRawSocket(fd int) error {
|
||||
return syscall.Close(fd)
|
||||
}
|
||||
|
||||
func htons(v uint16) uint16 {
|
||||
return (v << 8) | (v >> 8)
|
||||
}
|
||||
20
internal/dnstt/capture_unsupported.go
Normal file
20
internal/dnstt/capture_unsupported.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
//go:build !linux
|
||||
|
||||
package dnstt
|
||||
|
||||
import "fmt"
|
||||
|
||||
// OpenRawSocket is only implemented on Linux.
|
||||
func OpenRawSocket() (int, error) {
|
||||
return -1, fmt.Errorf("raw packet capture is only supported on Linux")
|
||||
}
|
||||
|
||||
// CaptureLoop is only implemented on Linux.
|
||||
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
|
||||
<-stop
|
||||
}
|
||||
|
||||
// CloseRawSocket is only implemented on Linux.
|
||||
func CloseRawSocket(fd int) error {
|
||||
return nil
|
||||
}
|
||||
199
internal/dnstt/collector.go
Normal file
199
internal/dnstt/collector.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientTimeout is how long since the last query before a session is considered inactive.
|
||||
const ClientTimeout = 30 * time.Second
|
||||
|
||||
// Collector accumulates observed DNSTT DNS traffic.
|
||||
type Collector struct {
|
||||
mu sync.Mutex
|
||||
domains []string
|
||||
tunnels map[string]*tunnelState
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type tunnelState struct {
|
||||
queries uint64
|
||||
bytesIn uint64
|
||||
bytesOut uint64
|
||||
peakClients int
|
||||
clients map[string]*clientState
|
||||
}
|
||||
|
||||
type clientState struct {
|
||||
firstSeen time.Time
|
||||
lastSeen time.Time
|
||||
queries uint64
|
||||
bytesIn uint64
|
||||
}
|
||||
|
||||
// Snapshot is a point-in-time copy of observed traffic.
|
||||
type Snapshot struct {
|
||||
Tunnels map[string]TunnelSnapshot
|
||||
}
|
||||
|
||||
// TunnelSnapshot is a point-in-time copy of one DNSTT domain.
|
||||
type TunnelSnapshot struct {
|
||||
Domain string
|
||||
ActiveClients int
|
||||
PeakClients int
|
||||
TotalSessions int
|
||||
TotalQueries uint64
|
||||
BytesIn uint64
|
||||
BytesOut uint64
|
||||
}
|
||||
|
||||
// CollectorOption configures a Collector.
|
||||
type CollectorOption func(*Collector)
|
||||
|
||||
// WithNow overrides the clock. It is intended for deterministic tests.
|
||||
func WithNow(now func() time.Time) CollectorOption {
|
||||
return func(c *Collector) {
|
||||
c.now = now
|
||||
}
|
||||
}
|
||||
|
||||
// NewCollector creates a collector for the provided DNSTT domains.
|
||||
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
|
||||
c := &Collector{
|
||||
tunnels: make(map[string]*tunnelState),
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
domain = normalizeDomain(domain)
|
||||
if domain == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := c.tunnels[domain]; exists {
|
||||
continue
|
||||
}
|
||||
c.domains = append(c.domains, domain)
|
||||
c.tunnels[domain] = &tunnelState{
|
||||
clients: make(map[string]*clientState),
|
||||
}
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Domains returns the normalized DNSTT domains this collector recognizes.
|
||||
func (c *Collector) Domains() []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
domains := make([]string, len(c.domains))
|
||||
copy(domains, c.domains)
|
||||
return domains
|
||||
}
|
||||
|
||||
// RecordQuery records an observed DNSTT DNS query.
|
||||
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
tunnel, ok := c.findTunnelLocked(domain)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tunnel.queries++
|
||||
if size > 0 {
|
||||
tunnel.bytesIn += uint64(size)
|
||||
}
|
||||
|
||||
if clientID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
now := c.now()
|
||||
client, exists := tunnel.clients[clientID]
|
||||
if !exists {
|
||||
client = &clientState{firstSeen: now}
|
||||
tunnel.clients[clientID] = client
|
||||
}
|
||||
client.lastSeen = now
|
||||
client.queries++
|
||||
if size > 0 {
|
||||
client.bytesIn += uint64(size)
|
||||
}
|
||||
|
||||
active := activeClients(tunnel, now)
|
||||
if active > tunnel.peakClients {
|
||||
tunnel.peakClients = active
|
||||
}
|
||||
}
|
||||
|
||||
// RecordResponse records an observed DNSTT DNS response.
|
||||
func (c *Collector) RecordResponse(domain string, size int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
tunnel, ok := c.findTunnelLocked(domain)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if size > 0 {
|
||||
tunnel.bytesOut += uint64(size)
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot returns a stable copy of all current metrics.
|
||||
func (c *Collector) Snapshot() Snapshot {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := c.now()
|
||||
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
|
||||
for _, domain := range c.domains {
|
||||
tunnel := c.tunnels[domain]
|
||||
active := activeClients(tunnel, now)
|
||||
if active > tunnel.peakClients {
|
||||
tunnel.peakClients = active
|
||||
}
|
||||
|
||||
snapshot.Tunnels[domain] = TunnelSnapshot{
|
||||
Domain: domain,
|
||||
ActiveClients: active,
|
||||
PeakClients: tunnel.peakClients,
|
||||
TotalSessions: len(tunnel.clients),
|
||||
TotalQueries: tunnel.queries,
|
||||
BytesIn: tunnel.bytesIn,
|
||||
BytesOut: tunnel.bytesOut,
|
||||
}
|
||||
}
|
||||
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
|
||||
queryDomain = normalizeDomain(queryDomain)
|
||||
if tunnel, ok := c.tunnels[queryDomain]; ok {
|
||||
return tunnel, true
|
||||
}
|
||||
for _, domain := range c.domains {
|
||||
if strings.HasSuffix(queryDomain, "."+domain) {
|
||||
return c.tunnels[domain], true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func activeClients(tunnel *tunnelState, now time.Time) int {
|
||||
active := 0
|
||||
for _, client := range tunnel.clients {
|
||||
if now.Sub(client.lastSeen) < ClientTimeout {
|
||||
active++
|
||||
}
|
||||
}
|
||||
return active
|
||||
}
|
||||
61
internal/dnstt/collector_test.go
Normal file
61
internal/dnstt/collector_test.go
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCollectorTracksActiveAndPeakClients(t *testing.T) {
|
||||
now := time.Unix(1000, 0)
|
||||
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
|
||||
|
||||
c.RecordQuery("tunnel.example.com", "client-a", 120)
|
||||
c.RecordQuery("tunnel.example.com", "client-b", 80)
|
||||
c.RecordResponse("tunnel.example.com", 200)
|
||||
|
||||
snapshot := c.Snapshot()
|
||||
tunnel := snapshot.Tunnels["tunnel.example.com"]
|
||||
if tunnel.ActiveClients != 2 {
|
||||
t.Fatalf("active clients = %d, want 2", tunnel.ActiveClients)
|
||||
}
|
||||
if tunnel.PeakClients != 2 {
|
||||
t.Fatalf("peak clients = %d, want 2", tunnel.PeakClients)
|
||||
}
|
||||
if tunnel.TotalSessions != 2 {
|
||||
t.Fatalf("total sessions = %d, want 2", tunnel.TotalSessions)
|
||||
}
|
||||
if tunnel.TotalQueries != 2 {
|
||||
t.Fatalf("queries = %d, want 2", tunnel.TotalQueries)
|
||||
}
|
||||
if tunnel.BytesIn != 200 {
|
||||
t.Fatalf("bytes in = %d, want 200", tunnel.BytesIn)
|
||||
}
|
||||
if tunnel.BytesOut != 200 {
|
||||
t.Fatalf("bytes out = %d, want 200", tunnel.BytesOut)
|
||||
}
|
||||
|
||||
now = now.Add(ClientTimeout + time.Second)
|
||||
snapshot = c.Snapshot()
|
||||
tunnel = snapshot.Tunnels["tunnel.example.com"]
|
||||
if tunnel.ActiveClients != 0 {
|
||||
t.Fatalf("active clients after timeout = %d, want 0", tunnel.ActiveClients)
|
||||
}
|
||||
if tunnel.PeakClients != 2 {
|
||||
t.Fatalf("peak clients after timeout = %d, want 2", tunnel.PeakClients)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectorMatchesSubdomainsToRegisteredTunnel(t *testing.T) {
|
||||
now := time.Unix(1000, 0)
|
||||
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
|
||||
|
||||
c.RecordQuery("abcd.tunnel.example.com", "client-a", 120)
|
||||
|
||||
tunnel := c.Snapshot().Tunnels["tunnel.example.com"]
|
||||
if tunnel.TotalQueries != 1 {
|
||||
t.Fatalf("queries = %d, want 1", tunnel.TotalQueries)
|
||||
}
|
||||
if tunnel.ActiveClients != 1 {
|
||||
t.Fatalf("active clients = %d, want 1", tunnel.ActiveClients)
|
||||
}
|
||||
}
|
||||
83
internal/dnstt/exporter.go
Normal file
83
internal/dnstt/exporter.go
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
package dnstt
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
const namespace = "dnstt"
|
||||
|
||||
// Exporter exposes aggregate DNSTT traffic metrics from a Collector.
|
||||
type Exporter struct {
|
||||
collector *Collector
|
||||
|
||||
activeClients *prometheus.Desc
|
||||
peakClients *prometheus.Desc
|
||||
queries *prometheus.Desc
|
||||
bytesIn *prometheus.Desc
|
||||
bytesOut *prometheus.Desc
|
||||
sessions *prometheus.Desc
|
||||
}
|
||||
|
||||
// NewExporter creates a Prometheus collector for DNSTT metrics.
|
||||
func NewExporter(collector *Collector) *Exporter {
|
||||
labels := []string{"domain"}
|
||||
return &Exporter{
|
||||
collector: collector,
|
||||
activeClients: prometheus.NewDesc(
|
||||
prometheus.BuildFQName(namespace, "", "active_clients"),
|
||||
"Number of DNSTT client sessions observed within the active timeout window.",
|
||||
labels,
|
||||
nil,
|
||||
),
|
||||
peakClients: prometheus.NewDesc(
|
||||
prometheus.BuildFQName(namespace, "", "peak_clients"),
|
||||
"Maximum concurrent active DNSTT client sessions observed.",
|
||||
labels,
|
||||
nil,
|
||||
),
|
||||
queries: prometheus.NewDesc(
|
||||
prometheus.BuildFQName(namespace, "", "queries_total"),
|
||||
"Total DNSTT DNS queries observed.",
|
||||
labels,
|
||||
nil,
|
||||
),
|
||||
bytesIn: prometheus.NewDesc(
|
||||
prometheus.BuildFQName(namespace, "", "bytes_in_total"),
|
||||
"Total bytes observed in DNSTT DNS queries.",
|
||||
labels,
|
||||
nil,
|
||||
),
|
||||
bytesOut: prometheus.NewDesc(
|
||||
prometheus.BuildFQName(namespace, "", "bytes_out_total"),
|
||||
"Total bytes observed in DNSTT DNS responses.",
|
||||
labels,
|
||||
nil,
|
||||
),
|
||||
sessions: prometheus.NewDesc(
|
||||
prometheus.BuildFQName(namespace, "", "sessions_total"),
|
||||
"Total unique DNSTT client sessions observed.",
|
||||
labels,
|
||||
nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// Describe sends metric descriptors to Prometheus.
|
||||
func (e *Exporter) Describe(ch chan<- *prometheus.Desc) {
|
||||
ch <- e.activeClients
|
||||
ch <- e.peakClients
|
||||
ch <- e.queries
|
||||
ch <- e.bytesIn
|
||||
ch <- e.bytesOut
|
||||
ch <- e.sessions
|
||||
}
|
||||
|
||||
// Collect sends current metric values to Prometheus.
|
||||
func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
|
||||
for _, tunnel := range e.collector.Snapshot().Tunnels {
|
||||
ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(tunnel.ActiveClients), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(tunnel.PeakClients), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(tunnel.TotalQueries), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(tunnel.BytesIn), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(tunnel.BytesOut), tunnel.Domain)
|
||||
ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(tunnel.TotalSessions), tunnel.Domain)
|
||||
}
|
||||
}
|
||||
42
internal/dnstt/exporter_test.go
Normal file
42
internal/dnstt/exporter_test.go
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
)
|
||||
|
||||
func TestExporterCollectsAggregateDNSTTMetrics(t *testing.T) {
|
||||
now := time.Unix(1000, 0)
|
||||
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
|
||||
c.RecordQuery("tunnel.example.com", "client-a", 100)
|
||||
c.RecordQuery("tunnel.example.com", "client-b", 300)
|
||||
c.RecordResponse("tunnel.example.com", 250)
|
||||
|
||||
expected := `
|
||||
# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window.
|
||||
# TYPE dnstt_active_clients gauge
|
||||
dnstt_active_clients{domain="tunnel.example.com"} 2
|
||||
# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries.
|
||||
# TYPE dnstt_bytes_in_total counter
|
||||
dnstt_bytes_in_total{domain="tunnel.example.com"} 400
|
||||
# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses.
|
||||
# TYPE dnstt_bytes_out_total counter
|
||||
dnstt_bytes_out_total{domain="tunnel.example.com"} 250
|
||||
# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed.
|
||||
# TYPE dnstt_peak_clients gauge
|
||||
dnstt_peak_clients{domain="tunnel.example.com"} 2
|
||||
# HELP dnstt_queries_total Total DNSTT DNS queries observed.
|
||||
# TYPE dnstt_queries_total counter
|
||||
dnstt_queries_total{domain="tunnel.example.com"} 2
|
||||
# HELP dnstt_sessions_total Total unique DNSTT client sessions observed.
|
||||
# TYPE dnstt_sessions_total counter
|
||||
dnstt_sessions_total{domain="tunnel.example.com"} 2
|
||||
`
|
||||
|
||||
if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
60
internal/dnstt/packet.go
Normal file
60
internal/dnstt/packet.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet.
|
||||
func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
|
||||
if len(data) < 20 || data[0]>>4 != 4 {
|
||||
return
|
||||
}
|
||||
|
||||
ihl := int(data[0]&0x0f) * 4
|
||||
if ihl < 20 || len(data) < ihl {
|
||||
return
|
||||
}
|
||||
if data[9] != 17 {
|
||||
return
|
||||
}
|
||||
|
||||
udpData := data[ihl:]
|
||||
if len(udpData) < 8 {
|
||||
return
|
||||
}
|
||||
|
||||
srcPort := binary.BigEndian.Uint16(udpData[0:2])
|
||||
dstPort := binary.BigEndian.Uint16(udpData[2:4])
|
||||
udpLen := int(binary.BigEndian.Uint16(udpData[4:6]))
|
||||
dnsPayload := udpData[8:]
|
||||
if len(dnsPayload) < 12 {
|
||||
return
|
||||
}
|
||||
|
||||
if int(dstPort) == port {
|
||||
domain, clientID := ExtractQuery(dnsPayload, collector.Domains())
|
||||
if domain != "" {
|
||||
collector.RecordQuery(domain, clientID, udpLen)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if int(srcPort) == port {
|
||||
domain := extractQueryDomain(dnsPayload)
|
||||
if domain != "" {
|
||||
collector.RecordResponse(domain, udpLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractQueryDomain(dns []byte) string {
|
||||
if len(dns) < 12 {
|
||||
return ""
|
||||
}
|
||||
labels := parseDNSLabels(dns[12:])
|
||||
if len(labels) == 0 {
|
||||
return ""
|
||||
}
|
||||
return normalizeDomain(strings.Join(labels, "."))
|
||||
}
|
||||
82
internal/dnstt/parser.go
Normal file
82
internal/dnstt/parser.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"encoding/base32"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const clientIDLen = 8
|
||||
|
||||
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
|
||||
// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query.
|
||||
func ExtractQuery(dns []byte, domains []string) (string, string) {
|
||||
if len(dns) < 12 {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
labels := parseDNSLabels(dns[12:])
|
||||
if len(labels) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
fullDomain := strings.ToLower(strings.Join(labels, "."))
|
||||
for _, domain := range domains {
|
||||
domain = normalizeDomain(domain)
|
||||
if fullDomain == domain {
|
||||
return domain, ""
|
||||
}
|
||||
|
||||
suffix := "." + domain
|
||||
if !strings.HasSuffix(fullDomain, suffix) {
|
||||
continue
|
||||
}
|
||||
|
||||
tunnelLabels := strings.Count(domain, ".") + 1
|
||||
if len(labels) <= tunnelLabels {
|
||||
continue
|
||||
}
|
||||
|
||||
prefixLabels := labels[:len(labels)-tunnelLabels]
|
||||
encoded := strings.ToUpper(strings.Join(prefixLabels, ""))
|
||||
decoded := make([]byte, base32Encoding.DecodedLen(len(encoded)))
|
||||
n, err := base32Encoding.Decode(decoded, []byte(encoded))
|
||||
if err != nil || n < clientIDLen {
|
||||
return domain, ""
|
||||
}
|
||||
|
||||
return domain, fmt.Sprintf("%x", decoded[:clientIDLen])
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func parseDNSLabels(data []byte) []string {
|
||||
var labels []string
|
||||
offset := 0
|
||||
|
||||
for offset < len(data) {
|
||||
labelLen := int(data[offset])
|
||||
if labelLen == 0 {
|
||||
break
|
||||
}
|
||||
if labelLen&0xc0 == 0xc0 {
|
||||
break
|
||||
}
|
||||
|
||||
offset++
|
||||
if offset+labelLen > len(data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
labels = append(labels, string(data[offset:offset+labelLen]))
|
||||
offset += labelLen
|
||||
}
|
||||
|
||||
return labels
|
||||
}
|
||||
|
||||
func normalizeDomain(domain string) string {
|
||||
return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
|
||||
}
|
||||
93
internal/dnstt/parser_test.go
Normal file
93
internal/dnstt/parser_test.go
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
package dnstt
|
||||
|
||||
import (
|
||||
"encoding/base32"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractQueryDecodesDNSTTClientID(t *testing.T) {
|
||||
clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
|
||||
encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb))
|
||||
packet := dnsQuery(encoded, "tunnel.example.com")
|
||||
|
||||
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||
|
||||
if domain != "tunnel.example.com" {
|
||||
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
||||
}
|
||||
if gotClientID != "0102030405060708" {
|
||||
t.Fatalf("client ID = %q, want 0102030405060708", gotClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) {
|
||||
packet := dnsQuery("", "tunnel.example.com")
|
||||
|
||||
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||
|
||||
if domain != "tunnel.example.com" {
|
||||
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
||||
}
|
||||
if gotClientID != "" {
|
||||
t.Fatalf("client ID = %q, want empty", gotClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) {
|
||||
packet := dnsQuery("abcd", "other.example.com")
|
||||
|
||||
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||
|
||||
if domain != "" || gotClientID != "" {
|
||||
t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) {
|
||||
packet := dnsQuery("not-base32-!", "tunnel.example.com")
|
||||
|
||||
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||
|
||||
if domain != "tunnel.example.com" {
|
||||
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
||||
}
|
||||
if gotClientID != "" {
|
||||
t.Fatalf("client ID = %q, want empty", gotClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func dnsQuery(prefix string, domain string) []byte {
|
||||
packet := []byte{
|
||||
0x12, 0x34, // transaction ID
|
||||
0x01, 0x00, // flags
|
||||
0x00, 0x01, // questions
|
||||
0x00, 0x00, // answers
|
||||
0x00, 0x00, // authority
|
||||
0x00, 0x00, // additional
|
||||
}
|
||||
|
||||
labels := splitDomain(domain)
|
||||
if prefix != "" {
|
||||
labels = append([]string{prefix}, labels...)
|
||||
}
|
||||
for _, label := range labels {
|
||||
packet = append(packet, byte(len(label)))
|
||||
packet = append(packet, label...)
|
||||
}
|
||||
packet = append(packet, 0x00) // root label
|
||||
packet = append(packet, 0x00, 0x10) // TXT
|
||||
packet = append(packet, 0x00, 0x01) // IN
|
||||
return packet
|
||||
}
|
||||
|
||||
func splitDomain(domain string) []string {
|
||||
var labels []string
|
||||
start := 0
|
||||
for i := 0; i <= len(domain); i++ {
|
||||
if i == len(domain) || domain[i] == '.' {
|
||||
labels = append(labels, domain[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return labels
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue