dnstt_exporter/internal/dnstt/parser_test.go

94 lines
2.5 KiB
Go
Raw Normal View History

2026-05-05 13:43:02 +02:00
package dnstt
import (
"encoding/base32"
"testing"
)
func TestExtractQueryDecodesDNSTTClientID(t *testing.T) {
clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb))
packet := dnsQuery(encoded, "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "0102030405060708" {
t.Fatalf("client ID = %q, want 0102030405060708", gotClientID)
}
}
// TestExtractQueryMatchesBareDomainWithoutClientID sends a query for the
// registered domain itself (no leading label) and expects the domain to be
// matched while the client ID stays empty.
func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) {
	domains := []string{"tunnel.example.com"}
	query := dnsQuery("", domains[0])

	domain, clientID := ExtractQuery(query, domains)
	if got := domain; got != "tunnel.example.com" {
		t.Fatalf("domain = %q, want tunnel.example.com", got)
	}
	if clientID != "" {
		t.Fatalf("client ID = %q, want empty", clientID)
	}
}
func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) {
packet := dnsQuery("abcd", "other.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "" || gotClientID != "" {
t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID)
}
}
func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) {
packet := dnsQuery("not-base32-!", "tunnel.example.com")
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
if domain != "tunnel.example.com" {
t.Fatalf("domain = %q, want tunnel.example.com", domain)
}
if gotClientID != "" {
t.Fatalf("client ID = %q, want empty", gotClientID)
}
}
// dnsQuery builds a minimal DNS query packet (12-byte header plus a single
// TXT/IN question) for feeding ExtractQuery in tests. The question name is
// prefix, as one leading label when non-empty, followed by the dot-separated
// labels of domain.
//
// Zero-length labels (from an empty domain, a trailing dot, or consecutive
// dots) are skipped: on the wire a zero length octet is the root terminator,
// so emitting one mid-name would end QNAME early and make the following bytes
// be misread as QTYPE/QCLASS. A label longer than 63 bytes (the RFC 1035
// limit) cannot be encoded in a single length octet without colliding with
// the reserved/compression-pointer bit patterns or wrapping, so the helper
// panics rather than silently producing a malformed packet.
func dnsQuery(prefix string, domain string) []byte {
	packet := []byte{
		0x12, 0x34, // transaction ID
		0x01, 0x00, // flags: standard query, recursion desired
		0x00, 0x01, // QDCOUNT: one question
		0x00, 0x00, // ANCOUNT
		0x00, 0x00, // NSCOUNT
		0x00, 0x00, // ARCOUNT
	}
	// An empty prefix becomes an empty label here and is dropped below.
	labels := append([]string{prefix}, splitDomain(domain)...)
	for _, label := range labels {
		if label == "" {
			continue
		}
		if len(label) > 63 {
			panic("dnsQuery: label longer than 63 bytes: " + label)
		}
		packet = append(packet, byte(len(label)))
		packet = append(packet, label...)
	}
	packet = append(packet, 0x00)       // root label terminates QNAME
	packet = append(packet, 0x00, 0x10) // QTYPE: TXT
	packet = append(packet, 0x00, 0x01) // QCLASS: IN
	return packet
}
// splitDomain splits domain on '.' into its labels, keeping empty labels
// (e.g. "a." yields ["a", ""]). An empty domain yields a single empty label.
// The result always has at least one element.
func splitDomain(domain string) []string {
	var parts []string
	rest := domain
	for {
		dot := -1
		for j := 0; j < len(rest); j++ {
			if rest[j] == '.' {
				dot = j
				break
			}
		}
		if dot < 0 {
			return append(parts, rest)
		}
		parts = append(parts, rest[:dot])
		rest = rest[dot+1:]
	}
}