package dnstt

import (
	"encoding/base32"
	"encoding/hex"
	"strings"
)

const (
	// clientIDLen is the number of leading bytes of the decoded query
	// payload that form the session client ID.
	clientIDLen = 8

	// dnsHeaderLen is the size of the fixed DNS message header (RFC 1035
	// section 4.1.1); the question section starts right after it.
	dnsHeaderLen = 12

	// dnsMaxLabelLen and dnsMaxNameLen are the RFC 1035 section 2.3.4
	// limits, in octets, on a single label and on a whole wire-format name.
	dnsMaxLabelLen = 63
	dnsMaxNameLen  = 255
)

// base32Encoding is standard base32 without padding, used to decode the
// prefix labels of a tunnel query name.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)

// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query.
//
// dns is a raw DNS message. The first name after the 12-byte header is
// parsed and compared, label by label and ASCII case-insensitively, against
// each entry of domains in order; the first domain the name equals or falls
// under wins. The returned domain is the normalized form of that entry
// (trimmed, lowercased, without a trailing dot).
//
// The client ID is the lowercase hex encoding of the first clientIDLen bytes
// obtained by base32-decoding the concatenated labels in front of the
// domain. It is "" when the name is exactly the domain, or when those labels
// do not decode to at least clientIDLen bytes. Both results are "" when the
// message is too short, the name is malformed, or no domain matches.
func ExtractQuery(dns []byte, domains []string) (string, string) {
	if len(dns) < dnsHeaderLen {
		return "", ""
	}
	labels := parseDNSLabels(dns[dnsHeaderLen:])
	if len(labels) == 0 {
		return "", ""
	}
	for _, domain := range domains {
		domain = normalizeDomain(domain)
		if domain == "" {
			continue
		}
		tunnelLabels, ok := matchDomainSuffix(labels, domain)
		if !ok {
			continue
		}
		if tunnelLabels == len(labels) {
			return domain, ""
		}
		return domain, decodeClientID(labels[:len(labels)-tunnelLabels])
	}
	return "", ""
}

// matchDomainSuffix reports whether the trailing entries of labels spell
// domain (a dot-separated name), comparing label by label with ASCII case
// folding. On a match it also returns how many labels domain occupies.
// Comparing labels rather than a joined string keeps a raw label that
// contains a '.' byte from masquerading as several labels.
func matchDomainSuffix(labels []string, domain string) (int, bool) {
	n := 0
	rest := domain
	for {
		dot := strings.LastIndexByte(rest, '.')
		n++
		if n > len(labels) || !asciiEqualFold(labels[len(labels)-n], rest[dot+1:]) {
			return 0, false
		}
		if dot < 0 {
			return n, true
		}
		rest = rest[:dot]
	}
}

// decodeClientID concatenates prefix, base32-decodes it, and returns the
// first clientIDLen decoded bytes as lowercase hex, or "" if decoding fails
// or yields fewer than clientIDLen bytes.
func decodeClientID(prefix []string) string {
	size := 0
	for _, label := range prefix {
		size += len(label)
	}
	// Uppercase ASCII only: the base32 alphabet is ASCII, and Unicode case
	// mapping could turn a non-ASCII byte sequence into a valid letter.
	encoded := make([]byte, 0, size)
	for _, label := range prefix {
		for i := 0; i < len(label); i++ {
			c := label[i]
			if 'a' <= c && c <= 'z' {
				c -= 'a' - 'A'
			}
			encoded = append(encoded, c)
		}
	}
	decoded := make([]byte, base32Encoding.DecodedLen(len(encoded)))
	n, err := base32Encoding.Decode(decoded, encoded)
	if err != nil || n < clientIDLen {
		return ""
	}
	return hex.EncodeToString(decoded[:clientIDLen])
}

// asciiEqualFold reports whether a and b are equal when ASCII letters are
// compared case-insensitively; all other bytes must match exactly, as DNS
// name comparison requires (RFC 4343).
func asciiEqualFold(a, b string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := 0; i < len(a); i++ {
		ca, cb := a[i], b[i]
		if 'A' <= ca && ca <= 'Z' {
			ca += 'a' - 'A'
		}
		if 'A' <= cb && cb <= 'Z' {
			cb += 'a' - 'A'
		}
		if ca != cb {
			return false
		}
	}
	return true
}

// parseDNSLabels parses an uncompressed wire-format domain name (RFC 1035
// section 3.1) from the start of data and returns its labels in order.
//
// It returns nil for a malformed name: a label longer than 63 octets or of a
// non-ordinary label type, a name longer than 255 octets, or data that ends
// before the terminating zero-length label. The root name also yields nil.
func parseDNSLabels(data []byte) []string {
	var labels []string
	offset := 0
	for offset < len(data) {
		labelLen := int(data[offset])
		if labelLen == 0 {
			return labels
		}
		// Lengths above 63 have one of the top two bits set: 0xC0 is a
		// compression pointer, which this parser does not follow, and 0x40
		// and 0x80 are reserved types. Reject them rather than returning a
		// truncated or misread name.
		if labelLen > dnsMaxLabelLen {
			return nil
		}
		offset++
		// offset+labelLen+1 is the wire length so far, counting the
		// terminating zero octet the name still needs.
		if offset+labelLen > len(data) || offset+labelLen+1 > dnsMaxNameLen {
			return nil
		}
		labels = append(labels, string(data[offset:offset+labelLen]))
		offset += labelLen
	}
	// The data ended before the terminating zero-length label.
	return nil
}

// normalizeDomain trims surrounding whitespace, lowercases, and drops one
// trailing dot so a configured domain can be compared with query names.
func normalizeDomain(domain string) string {
	return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
}