dnstt_exporter/internal/dnstt/parser.go

83 lines
1.7 KiB
Go
Raw Normal View History

2026-05-05 13:43:02 +02:00
package dnstt
import (
"encoding/base32"
"fmt"
"strings"
)
// clientIDLen is the number of leading decoded payload bytes that
// ExtractQuery reports as the session client ID.
const clientIDLen = 8

// base32Encoding is the RFC 4648 standard base32 alphabet with padding
// disabled. ExtractQuery upper-cases the query-name prefix before decoding
// it with this encoding.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query.
//
// dns is a raw DNS message. The question name is read starting right after
// the 12-byte header. Each entry of domains is normalized with
// normalizeDomain and matched case-insensitively, label by label, against the
// end of the query name. When several configured domains match, the longest
// (most specific) one wins. Overlapping entries such as "example.com" and
// "t.example.com" therefore attribute queries to the right tunnel regardless
// of their order in domains.
//
// Results:
//   - ("", "") if the message is too short, the name cannot be parsed, or no
//     configured domain matches.
//   - (domain, "") if the name equals domain exactly, or if the labels in
//     front of it do not base32-decode to at least clientIDLen bytes.
//   - (domain, clientID) otherwise. clientID is the first clientIDLen decoded
//     bytes formatted as lowercase hex.
func ExtractQuery(dns []byte, domains []string) (string, string) {
	if len(dns) < 12 {
		return "", ""
	}
	labels := parseDNSLabels(dns[12:])
	if len(labels) == 0 {
		return "", ""
	}

	// Choose the most specific configured domain that is a label-wise suffix
	// of the query name. Comparing labels rather than the dot-joined string
	// keeps a literal '.' byte inside a wire-format label from faking a match.
	bestDomain := ""
	bestLabels := 0
	for _, domain := range domains {
		domain = normalizeDomain(domain)
		if domain == "" {
			continue
		}
		domainLabels := strings.Split(domain, ".")
		if len(domainLabels) <= bestLabels || !hasLabelSuffix(labels, domainLabels) {
			continue
		}
		bestDomain, bestLabels = domain, len(domainLabels)
	}
	if bestDomain == "" {
		return "", ""
	}

	prefixLabels := labels[:len(labels)-bestLabels]
	if len(prefixLabels) == 0 {
		// The query is for the tunnel domain itself and carries no payload.
		return bestDomain, ""
	}
	// The payload labels are concatenated and decoded as unpadded base32.
	// They are upper-cased because base32Encoding uses the uppercase alphabet.
	encoded := strings.ToUpper(strings.Join(prefixLabels, ""))
	decoded := make([]byte, base32Encoding.DecodedLen(len(encoded)))
	n, err := base32Encoding.Decode(decoded, []byte(encoded))
	if err != nil || n < clientIDLen {
		return bestDomain, ""
	}
	return bestDomain, fmt.Sprintf("%x", decoded[:clientIDLen])
}

// hasLabelSuffix reports whether the last len(suffix) entries of labels equal
// suffix, comparing each label case-insensitively.
func hasLabelSuffix(labels, suffix []string) bool {
	if len(suffix) > len(labels) {
		return false
	}
	tail := labels[len(labels)-len(suffix):]
	for i := range suffix {
		if !strings.EqualFold(tail[i], suffix[i]) {
			return false
		}
	}
	return true
}
// parseDNSLabels decodes the uncompressed domain name at the start of data
// into its labels, in wire order. ExtractQuery passes the bytes that follow
// the 12-byte DNS header, i.e. the start of the question section.
//
// It returns nil if the name is the root (no labels), is truncated, lacks its
// terminating zero-length label, or uses anything other than a plain
// length-prefixed label. That includes compression pointers (0xC0) and the
// reserved 0x40/0x80 label types. A partial name is never returned: callers
// match on the name's trailing labels, which are exactly the part a truncated
// or compressed name would be missing.
func parseDNSLabels(data []byte) []string {
	// RFC 1035: a plain label is at most 63 octets. Length bytes above that
	// have their top bits set and encode other label types.
	const maxLabelLen = 63

	var labels []string
	for offset := 0; offset < len(data); {
		labelLen := int(data[offset])
		offset++
		if labelLen == 0 {
			return labels
		}
		if labelLen > maxLabelLen {
			return nil
		}
		end := offset + labelLen
		if end > len(data) {
			return nil
		}
		labels = append(labels, string(data[offset:end]))
		offset = end
	}
	// Ran out of data before the terminating zero-length label.
	return nil
}
// normalizeDomain canonicalizes a configured domain for comparison. It
// removes surrounding whitespace, lowercases the text, and drops a single
// trailing dot (the fully-qualified root marker).
func normalizeDomain(domain string) string {
	trimmed := strings.TrimSpace(domain)
	lowered := strings.ToLower(trimmed)
	return strings.TrimSuffix(lowered, ".")
}