82 lines
1.7 KiB
Go
82 lines
1.7 KiB
Go
package dnstt
|
|
|
|
import (
|
|
"encoding/base32"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// clientIDLen is the number of leading bytes of the base32-decoded query
// payload that are reported as the client ID. Payloads that decode to fewer
// bytes yield no client ID.
const clientIDLen = 8
|
|
|
|
// base32Encoding is the standard base32 alphabet with padding disabled. It is
// used to decode the data labels that precede the configured domain in a
// query name.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
|
|
|
|
// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query.
|
|
func ExtractQuery(dns []byte, domains []string) (string, string) {
|
|
if len(dns) < 12 {
|
|
return "", ""
|
|
}
|
|
|
|
labels := parseDNSLabels(dns[12:])
|
|
if len(labels) == 0 {
|
|
return "", ""
|
|
}
|
|
|
|
fullDomain := strings.ToLower(strings.Join(labels, "."))
|
|
for _, domain := range domains {
|
|
domain = normalizeDomain(domain)
|
|
if fullDomain == domain {
|
|
return domain, ""
|
|
}
|
|
|
|
suffix := "." + domain
|
|
if !strings.HasSuffix(fullDomain, suffix) {
|
|
continue
|
|
}
|
|
|
|
tunnelLabels := strings.Count(domain, ".") + 1
|
|
if len(labels) <= tunnelLabels {
|
|
continue
|
|
}
|
|
|
|
prefixLabels := labels[:len(labels)-tunnelLabels]
|
|
encoded := strings.ToUpper(strings.Join(prefixLabels, ""))
|
|
decoded := make([]byte, base32Encoding.DecodedLen(len(encoded)))
|
|
n, err := base32Encoding.Decode(decoded, []byte(encoded))
|
|
if err != nil || n < clientIDLen {
|
|
return domain, ""
|
|
}
|
|
|
|
return domain, fmt.Sprintf("%x", decoded[:clientIDLen])
|
|
}
|
|
|
|
return "", ""
|
|
}
|
|
|
|
// parseDNSLabels decodes the labels of a wire-format DNS name that starts at
// the beginning of data.
//
// Parsing stops at the root label (a zero length byte), at a compression
// pointer (a length byte with both top bits set; pointers are not followed),
// or when data is exhausted, and the labels collected so far are returned.
//
// It returns nil when the name is malformed: a label that runs past the end of
// data, or a length byte above 63. Such bytes use the reserved 0x40/0x80 label
// types and must not be read as 64-191 byte labels.
func parseDNSLabels(data []byte) []string {
	const (
		maxLabelLen = 63   // largest label length a single length byte may carry
		pointerBits = 0xc0 // both top bits set marks a compression pointer
	)

	var labels []string
	for offset := 0; offset < len(data); {
		labelLen := int(data[offset])
		switch {
		case labelLen == 0:
			// Root label: the name is complete.
			return labels
		case labelLen&pointerBits == pointerBits:
			// Compression pointer: keep what was read, do not follow it.
			return labels
		case labelLen > maxLabelLen:
			// Reserved label type, not a valid length.
			return nil
		}

		offset++
		end := offset + labelLen
		if end > len(data) {
			// Label claims more bytes than are available.
			return nil
		}

		labels = append(labels, string(data[offset:end]))
		offset = end
	}

	return labels
}
|
|
|
|
// normalizeDomain puts a configured domain into the form used for matching:
// surrounding whitespace removed, lowercased, and one trailing dot dropped.
func normalizeDomain(domain string) string {
	trimmed := strings.TrimSpace(domain)
	lowered := strings.ToLower(trimmed)
	return strings.TrimSuffix(lowered, ".")
}
|