93 lines
2.5 KiB
Go
93 lines
2.5 KiB
Go
package dnstt
|
|
|
|
import (
|
|
"encoding/base32"
|
|
"testing"
|
|
)
|
|
|
|
func TestExtractQueryDecodesDNSTTClientID(t *testing.T) {
|
|
clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
|
|
encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb))
|
|
packet := dnsQuery(encoded, "tunnel.example.com")
|
|
|
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
|
|
|
if domain != "tunnel.example.com" {
|
|
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
|
}
|
|
if gotClientID != "0102030405060708" {
|
|
t.Fatalf("client ID = %q, want 0102030405060708", gotClientID)
|
|
}
|
|
}
|
|
|
|
func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) {
|
|
packet := dnsQuery("", "tunnel.example.com")
|
|
|
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
|
|
|
if domain != "tunnel.example.com" {
|
|
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
|
}
|
|
if gotClientID != "" {
|
|
t.Fatalf("client ID = %q, want empty", gotClientID)
|
|
}
|
|
}
|
|
|
|
func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) {
|
|
packet := dnsQuery("abcd", "other.example.com")
|
|
|
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
|
|
|
if domain != "" || gotClientID != "" {
|
|
t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID)
|
|
}
|
|
}
|
|
|
|
func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) {
|
|
packet := dnsQuery("not-base32-!", "tunnel.example.com")
|
|
|
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
|
|
|
if domain != "tunnel.example.com" {
|
|
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
|
}
|
|
if gotClientID != "" {
|
|
t.Fatalf("client ID = %q, want empty", gotClientID)
|
|
}
|
|
}
|
|
|
|
func dnsQuery(prefix string, domain string) []byte {
|
|
packet := []byte{
|
|
0x12, 0x34, // transaction ID
|
|
0x01, 0x00, // flags
|
|
0x00, 0x01, // questions
|
|
0x00, 0x00, // answers
|
|
0x00, 0x00, // authority
|
|
0x00, 0x00, // additional
|
|
}
|
|
|
|
labels := splitDomain(domain)
|
|
if prefix != "" {
|
|
labels = append([]string{prefix}, labels...)
|
|
}
|
|
for _, label := range labels {
|
|
packet = append(packet, byte(len(label)))
|
|
packet = append(packet, label...)
|
|
}
|
|
packet = append(packet, 0x00) // root label
|
|
packet = append(packet, 0x00, 0x10) // TXT
|
|
packet = append(packet, 0x00, 0x01) // IN
|
|
return packet
|
|
}
|
|
|
|
// splitDomain breaks domain at every '.' and returns the pieces in order,
// matching strings.Split(domain, "."): leading, trailing, or doubled dots
// produce empty strings, and an empty domain produces a single empty label.
func splitDomain(domain string) []string {
	var parts []string
	rest := domain
	for {
		dot := -1
		for j := 0; j < len(rest); j++ {
			if rest[j] == '.' {
				dot = j
				break
			}
		}
		if dot < 0 {
			// No separator left: whatever remains (possibly "") is the last label.
			return append(parts, rest)
		}
		parts = append(parts, rest[:dot])
		rest = rest[dot+1:]
	}
}
|