package dnstt

import (
	"encoding/base32"
	"strings"
	"testing"
)

// TestExtractQueryDecodesDNSTTClientID checks that a base32 (unpadded) prefix
// label carrying an 8-byte client ID plus extra payload bytes yields the
// registered domain and the hex form of those 8 bytes.
func TestExtractQueryDecodesDNSTTClientID(t *testing.T) {
	clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
	// clientID has len == cap, so append allocates and clientID is not modified.
	encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb))
	packet := dnsQuery(encoded, "tunnel.example.com")

	domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
	if domain != "tunnel.example.com" {
		t.Fatalf("domain = %q, want tunnel.example.com", domain)
	}
	if gotClientID != "0102030405060708" {
		t.Fatalf("client ID = %q, want 0102030405060708", gotClientID)
	}
}

// TestExtractQueryMatchesBareDomainWithoutClientID checks that a query for the
// registered domain itself (no prefix label) matches with an empty client ID.
func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) {
	packet := dnsQuery("", "tunnel.example.com")

	domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
	if domain != "tunnel.example.com" {
		t.Fatalf("domain = %q, want tunnel.example.com", domain)
	}
	if gotClientID != "" {
		t.Fatalf("client ID = %q, want empty", gotClientID)
	}
}

// TestExtractQueryIgnoresUnregisteredDomain checks that a query under a domain
// not in the registered list yields neither a domain nor a client ID.
func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) {
	packet := dnsQuery("abcd", "other.example.com")

	domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
	if domain != "" || gotClientID != "" {
		t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID)
	}
}

// TestExtractQueryReturnsDomainWhenPrefixIsNotClientID checks that a prefix
// label that is not valid base32 still matches the domain but yields no
// client ID.
func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) {
	packet := dnsQuery("not-base32-!", "tunnel.example.com")

	domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
	if domain != "tunnel.example.com" {
		t.Fatalf("domain = %q, want tunnel.example.com", domain)
	}
	if gotClientID != "" {
		t.Fatalf("client ID = %q, want empty", gotClientID)
	}
}

// TestDNSQueryHelperWireFormat pins the bytes produced by the dnsQuery helper,
// so the ExtractQuery tests are known to be fed well-formed packets. It also
// checks that a trailing dot on the domain does not change the encoding.
func TestDNSQueryHelperWireFormat(t *testing.T) {
	want := []byte{
		0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x02, 'a', 'b',
		0x01, 'x',
		0x03, 'c', 'o', 'm',
		0x00,
		0x00, 0x10,
		0x00, 0x01,
	}
	for _, domain := range []string{"x.com", "x.com."} {
		if got := dnsQuery("ab", domain); string(got) != string(want) {
			t.Errorf("dnsQuery(%q, %q) = % x, want % x", "ab", domain, got, want)
		}
	}
}

// dnsQuery builds a minimal DNS query message: a fixed 12-byte header followed
// by a single question for prefix.domain with QTYPE TXT and QCLASS IN. When
// prefix is empty, the question is for domain alone. The prefix is written
// verbatim as one label.
//
// Empty labels are skipped. In wire format a zero-length label is the root
// terminator, so emitting one mid-name would end the name early and corrupt
// QTYPE/QCLASS. As a result, "tunnel.example.com." encodes the same as
// "tunnel.example.com".
//
// dnsQuery panics if any label is longer than 63 bytes, the wire-format
// maximum. Otherwise the single length byte would silently wrap and produce a
// malformed packet.
func dnsQuery(prefix, domain string) []byte {
	const maxLabelLen = 63

	packet := []byte{
		0x12, 0x34, // transaction ID
		0x01, 0x00, // flags: standard query, recursion desired
		0x00, 0x01, // questions
		0x00, 0x00, // answers
		0x00, 0x00, // authority
		0x00, 0x00, // additional
	}

	labels := splitDomain(domain)
	if prefix != "" {
		labels = append([]string{prefix}, labels...)
	}
	for _, label := range labels {
		if label == "" {
			continue
		}
		if len(label) > maxLabelLen {
			panic("dnsQuery: DNS label longer than 63 bytes: " + label)
		}
		packet = append(packet, byte(len(label)))
		packet = append(packet, label...)
	}
	packet = append(packet, 0x00)       // root label
	packet = append(packet, 0x00, 0x10) // QTYPE = TXT
	packet = append(packet, 0x00, 0x01) // QCLASS = IN
	return packet
}

// splitDomain splits a dotted domain name into its labels. It has
// strings.Split semantics: "" yields [""], and a leading or trailing dot
// yields an empty label at that end.
func splitDomain(domain string) []string {
	return strings.Split(domain, ".")
}