initial working version
This commit is contained in:
parent
4d8b83cbb6
commit
8318f9fe70
15 changed files with 917 additions and 4 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,3 +1,3 @@
|
||||||
.direnv
|
.direnv
|
||||||
result
|
result
|
||||||
|
/dnstt_exporter
|
||||||
|
|
|
||||||
36
README.md
36
README.md
|
|
@ -1,3 +1,39 @@
|
||||||
# dnstt_exporter
|
# dnstt_exporter
|
||||||
|
|
||||||
Prometheus exporter for DNSTT client/session metrics.
|
Prometheus exporter for DNSTT client/session metrics.
|
||||||
|
|
||||||
|
`dnstt_exporter` observes DNSTT DNS traffic on a local Linux host and exports
|
||||||
|
aggregate Prometheus metrics. It does not proxy, terminate, or configure DNSTT;
|
||||||
|
it passively decodes DNSTT session IDs from DNS query names.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sudo dnstt_exporter \
|
||||||
|
-dnstt.domain tunnel.example.com \
|
||||||
|
-dnstt.port 53 \
|
||||||
|
-web.listen-address :9713
|
||||||
|
```
|
||||||
|
|
||||||
|
The exporter needs permission to open an `AF_PACKET` raw socket. Run it as root
|
||||||
|
or grant the binary `CAP_NET_RAW`.
|
||||||
|
|
||||||
|
Metrics are served at `http://127.0.0.1:9713/metrics` by default.
|
||||||
|
|
||||||
|
## Metrics
|
||||||
|
|
||||||
|
All DNSTT metrics use a `domain` label:
|
||||||
|
|
||||||
|
- `dnstt_active_clients`
|
||||||
|
- `dnstt_peak_clients`
|
||||||
|
- `dnstt_queries_total`
|
||||||
|
- `dnstt_bytes_in_total`
|
||||||
|
- `dnstt_bytes_out_total`
|
||||||
|
- `dnstt_sessions_total`
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go test ./...
|
||||||
|
go build ./cmd/dnstt_exporter
|
||||||
|
```
|
||||||
|
|
|
||||||
104
cmd/dnstt_exporter/main.go
Normal file
104
cmd/dnstt_exporter/main.go
Normal file
|
|
@ -0,0 +1,104 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"guardianproject.dev/bypass-censorship/dnstt_exporter/internal/dnstt"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type domainList []string
|
||||||
|
|
||||||
|
func (d *domainList) String() string {
|
||||||
|
return strings.Join(*d, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *domainList) Set(value string) error {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return errors.New("domain cannot be empty")
|
||||||
|
}
|
||||||
|
*d = append(*d, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var domains domainList
|
||||||
|
var (
|
||||||
|
listenAddress = flag.String("web.listen-address", ":9713", "Address on which to expose metrics and web interface.")
|
||||||
|
metricsPath = flag.String("web.telemetry-path", "/metrics", "Path under which to expose metrics.")
|
||||||
|
dnsPort = flag.Int("dnstt.port", 53, "UDP port where DNSTT DNS traffic is observed.")
|
||||||
|
)
|
||||||
|
flag.Var(&domains, "dnstt.domain", "DNSTT tunnel domain to observe. Repeat for multiple domains.")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if len(domains) == 0 {
|
||||||
|
log.Fatal("at least one -dnstt.domain is required")
|
||||||
|
}
|
||||||
|
if *dnsPort < 1 || *dnsPort > 65535 {
|
||||||
|
log.Fatalf("-dnstt.port must be between 1 and 65535, got %d", *dnsPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
collector := dnstt.NewCollector(domains)
|
||||||
|
fd, err := dnstt.OpenRawSocket()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to open raw socket: %v", err)
|
||||||
|
}
|
||||||
|
defer dnstt.CloseRawSocket(fd)
|
||||||
|
|
||||||
|
stop := make(chan struct{})
|
||||||
|
go dnstt.CaptureLoop(fd, *dnsPort, collector, stop)
|
||||||
|
|
||||||
|
registry := prometheus.NewRegistry()
|
||||||
|
registry.MustRegister(dnstt.NewExporter(collector))
|
||||||
|
registry.MustRegister(collectors.NewGoCollector())
|
||||||
|
registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle(*metricsPath, promhttp.HandlerFor(registry, promhttp.HandlerOpts{}))
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, `<html>
|
||||||
|
<head><title>DNSTT Exporter</title></head>
|
||||||
|
<body>
|
||||||
|
<h1>DNSTT Exporter</h1>
|
||||||
|
<p><a href="%s">Metrics</a></p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`, *metricsPath)
|
||||||
|
})
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
server := &http.Server{Addr: *listenAddress, Handler: mux}
|
||||||
|
go func() {
|
||||||
|
log.Printf("listening on %s", *listenAddress)
|
||||||
|
errCh <- server.ListenAndServe()
|
||||||
|
}()
|
||||||
|
|
||||||
|
signalCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case sig := <-signalCh:
|
||||||
|
log.Printf("received %s, shutting down", sig)
|
||||||
|
close(stop)
|
||||||
|
case err := <-errCh:
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Fatalf("server error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
24
flake.nix
24
flake.nix
|
|
@ -11,9 +11,27 @@
|
||||||
forAllSystems = fn: nixpkgs.lib.genAttrs systems (system: fn nixpkgs.legacyPackages.${system});
|
forAllSystems = fn: nixpkgs.lib.genAttrs systems (system: fn nixpkgs.legacyPackages.${system});
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
#packages = forAllSystems (pkgs: {
|
packages = forAllSystems (pkgs: {
|
||||||
# default = pkgs.callPackage ./package.nix { };
|
default = pkgs.buildGoModule {
|
||||||
#});
|
pname = "dnstt_exporter";
|
||||||
|
version = "0.1.0";
|
||||||
|
src = ./.;
|
||||||
|
vendorHash = "sha256-+olXxVW9VcvapOuJGas70RIfJgMzUv6mndzJi6Apw+s=";
|
||||||
|
|
||||||
|
subPackages = [ "cmd/dnstt_exporter" ];
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
checks = forAllSystems (pkgs: {
|
||||||
|
tests = self.packages.${pkgs.stdenv.hostPlatform.system}.default.overrideAttrs (_: {
|
||||||
|
doCheck = true;
|
||||||
|
checkPhase = ''
|
||||||
|
runHook preCheck
|
||||||
|
go test ./...
|
||||||
|
runHook postCheck
|
||||||
|
'';
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
#checks = forAllSystems (
|
#checks = forAllSystems (
|
||||||
# pkgs:
|
# pkgs:
|
||||||
|
|
|
||||||
19
go.mod
Normal file
19
go.mod
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
module guardianproject.dev/bypass-censorship/dnstt_exporter
|
||||||
|
|
||||||
|
go 1.25
|
||||||
|
|
||||||
|
require github.com/prometheus/client_golang v1.23.2
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
|
github.com/prometheus/common v0.66.1 // indirect
|
||||||
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
|
golang.org/x/sys v0.35.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.8 // indirect
|
||||||
|
)
|
||||||
46
go.sum
Normal file
46
go.sum
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||||
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||||
|
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||||
|
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||||
|
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||||
|
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||||
|
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||||
|
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||||
|
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||||
|
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||||
|
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||||
|
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
50
internal/dnstt/capture_linux.go
Normal file
50
internal/dnstt/capture_linux.go
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenRawSocket creates an AF_PACKET raw socket for sniffing IPv4 packets.
|
||||||
|
func OpenRawSocket() (int, error) {
|
||||||
|
fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_IP)))
|
||||||
|
if err != nil {
|
||||||
|
return -1, fmt.Errorf("open raw socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tv := syscall.Timeval{Sec: 0, Usec: 500000}
|
||||||
|
_ = syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &tv)
|
||||||
|
return fd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptureLoop records matching packets until stop is closed.
|
||||||
|
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, err := syscall.Recvfrom(fd, buf, 0)
|
||||||
|
if err != nil {
|
||||||
|
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK || err == syscall.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ProcessIPv4Packet(buf[:n], port, collector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseRawSocket closes a socket opened by OpenRawSocket.
|
||||||
|
func CloseRawSocket(fd int) error {
|
||||||
|
return syscall.Close(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func htons(v uint16) uint16 {
|
||||||
|
return (v << 8) | (v >> 8)
|
||||||
|
}
|
||||||
20
internal/dnstt/capture_unsupported.go
Normal file
20
internal/dnstt/capture_unsupported.go
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// OpenRawSocket is only implemented on Linux.
|
||||||
|
func OpenRawSocket() (int, error) {
|
||||||
|
return -1, fmt.Errorf("raw packet capture is only supported on Linux")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptureLoop is only implemented on Linux.
|
||||||
|
func CaptureLoop(fd int, port int, collector *Collector, stop <-chan struct{}) {
|
||||||
|
<-stop
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseRawSocket is only implemented on Linux.
|
||||||
|
func CloseRawSocket(fd int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
199
internal/dnstt/collector.go
Normal file
199
internal/dnstt/collector.go
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientTimeout is how long since the last query before a session is considered inactive.
|
||||||
|
const ClientTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// Collector accumulates observed DNSTT DNS traffic.
|
||||||
|
type Collector struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
domains []string
|
||||||
|
tunnels map[string]*tunnelState
|
||||||
|
now func() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type tunnelState struct {
|
||||||
|
queries uint64
|
||||||
|
bytesIn uint64
|
||||||
|
bytesOut uint64
|
||||||
|
peakClients int
|
||||||
|
clients map[string]*clientState
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientState struct {
|
||||||
|
firstSeen time.Time
|
||||||
|
lastSeen time.Time
|
||||||
|
queries uint64
|
||||||
|
bytesIn uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot is a point-in-time copy of observed traffic.
|
||||||
|
type Snapshot struct {
|
||||||
|
Tunnels map[string]TunnelSnapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
// TunnelSnapshot is a point-in-time copy of one DNSTT domain.
|
||||||
|
type TunnelSnapshot struct {
|
||||||
|
Domain string
|
||||||
|
ActiveClients int
|
||||||
|
PeakClients int
|
||||||
|
TotalSessions int
|
||||||
|
TotalQueries uint64
|
||||||
|
BytesIn uint64
|
||||||
|
BytesOut uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// CollectorOption configures a Collector.
|
||||||
|
type CollectorOption func(*Collector)
|
||||||
|
|
||||||
|
// WithNow overrides the clock. It is intended for deterministic tests.
|
||||||
|
func WithNow(now func() time.Time) CollectorOption {
|
||||||
|
return func(c *Collector) {
|
||||||
|
c.now = now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCollector creates a collector for the provided DNSTT domains.
|
||||||
|
func NewCollector(domains []string, opts ...CollectorOption) *Collector {
|
||||||
|
c := &Collector{
|
||||||
|
tunnels: make(map[string]*tunnelState),
|
||||||
|
now: time.Now,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
domain = normalizeDomain(domain)
|
||||||
|
if domain == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := c.tunnels[domain]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.domains = append(c.domains, domain)
|
||||||
|
c.tunnels[domain] = &tunnelState{
|
||||||
|
clients: make(map[string]*clientState),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Domains returns the normalized DNSTT domains this collector recognizes.
|
||||||
|
func (c *Collector) Domains() []string {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
domains := make([]string, len(c.domains))
|
||||||
|
copy(domains, c.domains)
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordQuery records an observed DNSTT DNS query.
|
||||||
|
func (c *Collector) RecordQuery(domain string, clientID string, size int) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
tunnel, ok := c.findTunnelLocked(domain)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnel.queries++
|
||||||
|
if size > 0 {
|
||||||
|
tunnel.bytesIn += uint64(size)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := c.now()
|
||||||
|
client, exists := tunnel.clients[clientID]
|
||||||
|
if !exists {
|
||||||
|
client = &clientState{firstSeen: now}
|
||||||
|
tunnel.clients[clientID] = client
|
||||||
|
}
|
||||||
|
client.lastSeen = now
|
||||||
|
client.queries++
|
||||||
|
if size > 0 {
|
||||||
|
client.bytesIn += uint64(size)
|
||||||
|
}
|
||||||
|
|
||||||
|
active := activeClients(tunnel, now)
|
||||||
|
if active > tunnel.peakClients {
|
||||||
|
tunnel.peakClients = active
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordResponse records an observed DNSTT DNS response.
|
||||||
|
func (c *Collector) RecordResponse(domain string, size int) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
tunnel, ok := c.findTunnelLocked(domain)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if size > 0 {
|
||||||
|
tunnel.bytesOut += uint64(size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot returns a stable copy of all current metrics.
|
||||||
|
func (c *Collector) Snapshot() Snapshot {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
now := c.now()
|
||||||
|
snapshot := Snapshot{Tunnels: make(map[string]TunnelSnapshot, len(c.tunnels))}
|
||||||
|
for _, domain := range c.domains {
|
||||||
|
tunnel := c.tunnels[domain]
|
||||||
|
active := activeClients(tunnel, now)
|
||||||
|
if active > tunnel.peakClients {
|
||||||
|
tunnel.peakClients = active
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshot.Tunnels[domain] = TunnelSnapshot{
|
||||||
|
Domain: domain,
|
||||||
|
ActiveClients: active,
|
||||||
|
PeakClients: tunnel.peakClients,
|
||||||
|
TotalSessions: len(tunnel.clients),
|
||||||
|
TotalQueries: tunnel.queries,
|
||||||
|
BytesIn: tunnel.bytesIn,
|
||||||
|
BytesOut: tunnel.bytesOut,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Collector) findTunnelLocked(queryDomain string) (*tunnelState, bool) {
|
||||||
|
queryDomain = normalizeDomain(queryDomain)
|
||||||
|
if tunnel, ok := c.tunnels[queryDomain]; ok {
|
||||||
|
return tunnel, true
|
||||||
|
}
|
||||||
|
for _, domain := range c.domains {
|
||||||
|
if strings.HasSuffix(queryDomain, "."+domain) {
|
||||||
|
return c.tunnels[domain], true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func activeClients(tunnel *tunnelState, now time.Time) int {
|
||||||
|
active := 0
|
||||||
|
for _, client := range tunnel.clients {
|
||||||
|
if now.Sub(client.lastSeen) < ClientTimeout {
|
||||||
|
active++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return active
|
||||||
|
}
|
||||||
61
internal/dnstt/collector_test.go
Normal file
61
internal/dnstt/collector_test.go
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCollectorTracksActiveAndPeakClients(t *testing.T) {
|
||||||
|
now := time.Unix(1000, 0)
|
||||||
|
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
|
||||||
|
|
||||||
|
c.RecordQuery("tunnel.example.com", "client-a", 120)
|
||||||
|
c.RecordQuery("tunnel.example.com", "client-b", 80)
|
||||||
|
c.RecordResponse("tunnel.example.com", 200)
|
||||||
|
|
||||||
|
snapshot := c.Snapshot()
|
||||||
|
tunnel := snapshot.Tunnels["tunnel.example.com"]
|
||||||
|
if tunnel.ActiveClients != 2 {
|
||||||
|
t.Fatalf("active clients = %d, want 2", tunnel.ActiveClients)
|
||||||
|
}
|
||||||
|
if tunnel.PeakClients != 2 {
|
||||||
|
t.Fatalf("peak clients = %d, want 2", tunnel.PeakClients)
|
||||||
|
}
|
||||||
|
if tunnel.TotalSessions != 2 {
|
||||||
|
t.Fatalf("total sessions = %d, want 2", tunnel.TotalSessions)
|
||||||
|
}
|
||||||
|
if tunnel.TotalQueries != 2 {
|
||||||
|
t.Fatalf("queries = %d, want 2", tunnel.TotalQueries)
|
||||||
|
}
|
||||||
|
if tunnel.BytesIn != 200 {
|
||||||
|
t.Fatalf("bytes in = %d, want 200", tunnel.BytesIn)
|
||||||
|
}
|
||||||
|
if tunnel.BytesOut != 200 {
|
||||||
|
t.Fatalf("bytes out = %d, want 200", tunnel.BytesOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
now = now.Add(ClientTimeout + time.Second)
|
||||||
|
snapshot = c.Snapshot()
|
||||||
|
tunnel = snapshot.Tunnels["tunnel.example.com"]
|
||||||
|
if tunnel.ActiveClients != 0 {
|
||||||
|
t.Fatalf("active clients after timeout = %d, want 0", tunnel.ActiveClients)
|
||||||
|
}
|
||||||
|
if tunnel.PeakClients != 2 {
|
||||||
|
t.Fatalf("peak clients after timeout = %d, want 2", tunnel.PeakClients)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectorMatchesSubdomainsToRegisteredTunnel(t *testing.T) {
|
||||||
|
now := time.Unix(1000, 0)
|
||||||
|
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
|
||||||
|
|
||||||
|
c.RecordQuery("abcd.tunnel.example.com", "client-a", 120)
|
||||||
|
|
||||||
|
tunnel := c.Snapshot().Tunnels["tunnel.example.com"]
|
||||||
|
if tunnel.TotalQueries != 1 {
|
||||||
|
t.Fatalf("queries = %d, want 1", tunnel.TotalQueries)
|
||||||
|
}
|
||||||
|
if tunnel.ActiveClients != 1 {
|
||||||
|
t.Fatalf("active clients = %d, want 1", tunnel.ActiveClients)
|
||||||
|
}
|
||||||
|
}
|
||||||
83
internal/dnstt/exporter.go
Normal file
83
internal/dnstt/exporter.go
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import "github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
|
const namespace = "dnstt"
|
||||||
|
|
||||||
|
// Exporter exposes aggregate DNSTT traffic metrics from a Collector.
|
||||||
|
type Exporter struct {
|
||||||
|
collector *Collector
|
||||||
|
|
||||||
|
activeClients *prometheus.Desc
|
||||||
|
peakClients *prometheus.Desc
|
||||||
|
queries *prometheus.Desc
|
||||||
|
bytesIn *prometheus.Desc
|
||||||
|
bytesOut *prometheus.Desc
|
||||||
|
sessions *prometheus.Desc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewExporter creates a Prometheus collector for DNSTT metrics.
|
||||||
|
func NewExporter(collector *Collector) *Exporter {
|
||||||
|
labels := []string{"domain"}
|
||||||
|
return &Exporter{
|
||||||
|
collector: collector,
|
||||||
|
activeClients: prometheus.NewDesc(
|
||||||
|
prometheus.BuildFQName(namespace, "", "active_clients"),
|
||||||
|
"Number of DNSTT client sessions observed within the active timeout window.",
|
||||||
|
labels,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
peakClients: prometheus.NewDesc(
|
||||||
|
prometheus.BuildFQName(namespace, "", "peak_clients"),
|
||||||
|
"Maximum concurrent active DNSTT client sessions observed.",
|
||||||
|
labels,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
queries: prometheus.NewDesc(
|
||||||
|
prometheus.BuildFQName(namespace, "", "queries_total"),
|
||||||
|
"Total DNSTT DNS queries observed.",
|
||||||
|
labels,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
bytesIn: prometheus.NewDesc(
|
||||||
|
prometheus.BuildFQName(namespace, "", "bytes_in_total"),
|
||||||
|
"Total bytes observed in DNSTT DNS queries.",
|
||||||
|
labels,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
bytesOut: prometheus.NewDesc(
|
||||||
|
prometheus.BuildFQName(namespace, "", "bytes_out_total"),
|
||||||
|
"Total bytes observed in DNSTT DNS responses.",
|
||||||
|
labels,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
sessions: prometheus.NewDesc(
|
||||||
|
prometheus.BuildFQName(namespace, "", "sessions_total"),
|
||||||
|
"Total unique DNSTT client sessions observed.",
|
||||||
|
labels,
|
||||||
|
nil,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describe sends metric descriptors to Prometheus.
|
||||||
|
func (e *Exporter) Describe(ch chan<- *prometheus.Desc) {
|
||||||
|
ch <- e.activeClients
|
||||||
|
ch <- e.peakClients
|
||||||
|
ch <- e.queries
|
||||||
|
ch <- e.bytesIn
|
||||||
|
ch <- e.bytesOut
|
||||||
|
ch <- e.sessions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect sends current metric values to Prometheus.
|
||||||
|
func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
|
||||||
|
for _, tunnel := range e.collector.Snapshot().Tunnels {
|
||||||
|
ch <- prometheus.MustNewConstMetric(e.activeClients, prometheus.GaugeValue, float64(tunnel.ActiveClients), tunnel.Domain)
|
||||||
|
ch <- prometheus.MustNewConstMetric(e.peakClients, prometheus.GaugeValue, float64(tunnel.PeakClients), tunnel.Domain)
|
||||||
|
ch <- prometheus.MustNewConstMetric(e.queries, prometheus.CounterValue, float64(tunnel.TotalQueries), tunnel.Domain)
|
||||||
|
ch <- prometheus.MustNewConstMetric(e.bytesIn, prometheus.CounterValue, float64(tunnel.BytesIn), tunnel.Domain)
|
||||||
|
ch <- prometheus.MustNewConstMetric(e.bytesOut, prometheus.CounterValue, float64(tunnel.BytesOut), tunnel.Domain)
|
||||||
|
ch <- prometheus.MustNewConstMetric(e.sessions, prometheus.CounterValue, float64(tunnel.TotalSessions), tunnel.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
42
internal/dnstt/exporter_test.go
Normal file
42
internal/dnstt/exporter_test.go
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExporterCollectsAggregateDNSTTMetrics(t *testing.T) {
|
||||||
|
now := time.Unix(1000, 0)
|
||||||
|
c := NewCollector([]string{"tunnel.example.com"}, WithNow(func() time.Time { return now }))
|
||||||
|
c.RecordQuery("tunnel.example.com", "client-a", 100)
|
||||||
|
c.RecordQuery("tunnel.example.com", "client-b", 300)
|
||||||
|
c.RecordResponse("tunnel.example.com", 250)
|
||||||
|
|
||||||
|
expected := `
|
||||||
|
# HELP dnstt_active_clients Number of DNSTT client sessions observed within the active timeout window.
|
||||||
|
# TYPE dnstt_active_clients gauge
|
||||||
|
dnstt_active_clients{domain="tunnel.example.com"} 2
|
||||||
|
# HELP dnstt_bytes_in_total Total bytes observed in DNSTT DNS queries.
|
||||||
|
# TYPE dnstt_bytes_in_total counter
|
||||||
|
dnstt_bytes_in_total{domain="tunnel.example.com"} 400
|
||||||
|
# HELP dnstt_bytes_out_total Total bytes observed in DNSTT DNS responses.
|
||||||
|
# TYPE dnstt_bytes_out_total counter
|
||||||
|
dnstt_bytes_out_total{domain="tunnel.example.com"} 250
|
||||||
|
# HELP dnstt_peak_clients Maximum concurrent active DNSTT client sessions observed.
|
||||||
|
# TYPE dnstt_peak_clients gauge
|
||||||
|
dnstt_peak_clients{domain="tunnel.example.com"} 2
|
||||||
|
# HELP dnstt_queries_total Total DNSTT DNS queries observed.
|
||||||
|
# TYPE dnstt_queries_total counter
|
||||||
|
dnstt_queries_total{domain="tunnel.example.com"} 2
|
||||||
|
# HELP dnstt_sessions_total Total unique DNSTT client sessions observed.
|
||||||
|
# TYPE dnstt_sessions_total counter
|
||||||
|
dnstt_sessions_total{domain="tunnel.example.com"} 2
|
||||||
|
`
|
||||||
|
|
||||||
|
if err := testutil.CollectAndCompare(NewExporter(c), strings.NewReader(expected)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
60
internal/dnstt/packet.go
Normal file
60
internal/dnstt/packet.go
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet.
|
||||||
|
func ProcessIPv4Packet(data []byte, port int, collector *Collector) {
|
||||||
|
if len(data) < 20 || data[0]>>4 != 4 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ihl := int(data[0]&0x0f) * 4
|
||||||
|
if ihl < 20 || len(data) < ihl {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data[9] != 17 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
udpData := data[ihl:]
|
||||||
|
if len(udpData) < 8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srcPort := binary.BigEndian.Uint16(udpData[0:2])
|
||||||
|
dstPort := binary.BigEndian.Uint16(udpData[2:4])
|
||||||
|
udpLen := int(binary.BigEndian.Uint16(udpData[4:6]))
|
||||||
|
dnsPayload := udpData[8:]
|
||||||
|
if len(dnsPayload) < 12 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(dstPort) == port {
|
||||||
|
domain, clientID := ExtractQuery(dnsPayload, collector.Domains())
|
||||||
|
if domain != "" {
|
||||||
|
collector.RecordQuery(domain, clientID, udpLen)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(srcPort) == port {
|
||||||
|
domain := extractQueryDomain(dnsPayload)
|
||||||
|
if domain != "" {
|
||||||
|
collector.RecordResponse(domain, udpLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractQueryDomain(dns []byte) string {
|
||||||
|
if len(dns) < 12 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
labels := parseDNSLabels(dns[12:])
|
||||||
|
if len(labels) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return normalizeDomain(strings.Join(labels, "."))
|
||||||
|
}
|
||||||
82
internal/dnstt/parser.go
Normal file
82
internal/dnstt/parser.go
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base32"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const clientIDLen = 8
|
||||||
|
|
||||||
|
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||||
|
|
||||||
|
// ExtractQuery returns the configured DNSTT domain and session client ID from a DNS query.
|
||||||
|
func ExtractQuery(dns []byte, domains []string) (string, string) {
|
||||||
|
if len(dns) < 12 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
labels := parseDNSLabels(dns[12:])
|
||||||
|
if len(labels) == 0 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
fullDomain := strings.ToLower(strings.Join(labels, "."))
|
||||||
|
for _, domain := range domains {
|
||||||
|
domain = normalizeDomain(domain)
|
||||||
|
if fullDomain == domain {
|
||||||
|
return domain, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
suffix := "." + domain
|
||||||
|
if !strings.HasSuffix(fullDomain, suffix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelLabels := strings.Count(domain, ".") + 1
|
||||||
|
if len(labels) <= tunnelLabels {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixLabels := labels[:len(labels)-tunnelLabels]
|
||||||
|
encoded := strings.ToUpper(strings.Join(prefixLabels, ""))
|
||||||
|
decoded := make([]byte, base32Encoding.DecodedLen(len(encoded)))
|
||||||
|
n, err := base32Encoding.Decode(decoded, []byte(encoded))
|
||||||
|
if err != nil || n < clientIDLen {
|
||||||
|
return domain, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return domain, fmt.Sprintf("%x", decoded[:clientIDLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDNSLabels(data []byte) []string {
|
||||||
|
var labels []string
|
||||||
|
offset := 0
|
||||||
|
|
||||||
|
for offset < len(data) {
|
||||||
|
labelLen := int(data[offset])
|
||||||
|
if labelLen == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if labelLen&0xc0 == 0xc0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
offset++
|
||||||
|
if offset+labelLen > len(data) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
labels = append(labels, string(data[offset:offset+labelLen]))
|
||||||
|
offset += labelLen
|
||||||
|
}
|
||||||
|
|
||||||
|
return labels
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeDomain(domain string) string {
|
||||||
|
return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
|
||||||
|
}
|
||||||
93
internal/dnstt/parser_test.go
Normal file
93
internal/dnstt/parser_test.go
Normal file
|
|
@ -0,0 +1,93 @@
|
||||||
|
package dnstt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base32"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractQueryDecodesDNSTTClientID(t *testing.T) {
|
||||||
|
clientID := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
|
||||||
|
encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(append(clientID, 0xaa, 0xbb))
|
||||||
|
packet := dnsQuery(encoded, "tunnel.example.com")
|
||||||
|
|
||||||
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||||
|
|
||||||
|
if domain != "tunnel.example.com" {
|
||||||
|
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
||||||
|
}
|
||||||
|
if gotClientID != "0102030405060708" {
|
||||||
|
t.Fatalf("client ID = %q, want 0102030405060708", gotClientID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractQueryMatchesBareDomainWithoutClientID(t *testing.T) {
|
||||||
|
packet := dnsQuery("", "tunnel.example.com")
|
||||||
|
|
||||||
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||||
|
|
||||||
|
if domain != "tunnel.example.com" {
|
||||||
|
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
||||||
|
}
|
||||||
|
if gotClientID != "" {
|
||||||
|
t.Fatalf("client ID = %q, want empty", gotClientID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractQueryIgnoresUnregisteredDomain(t *testing.T) {
|
||||||
|
packet := dnsQuery("abcd", "other.example.com")
|
||||||
|
|
||||||
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||||
|
|
||||||
|
if domain != "" || gotClientID != "" {
|
||||||
|
t.Fatalf("got domain=%q clientID=%q, want both empty", domain, gotClientID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractQueryReturnsDomainWhenPrefixIsNotClientID(t *testing.T) {
|
||||||
|
packet := dnsQuery("not-base32-!", "tunnel.example.com")
|
||||||
|
|
||||||
|
domain, gotClientID := ExtractQuery(packet, []string{"tunnel.example.com"})
|
||||||
|
|
||||||
|
if domain != "tunnel.example.com" {
|
||||||
|
t.Fatalf("domain = %q, want tunnel.example.com", domain)
|
||||||
|
}
|
||||||
|
if gotClientID != "" {
|
||||||
|
t.Fatalf("client ID = %q, want empty", gotClientID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dnsQuery(prefix string, domain string) []byte {
|
||||||
|
packet := []byte{
|
||||||
|
0x12, 0x34, // transaction ID
|
||||||
|
0x01, 0x00, // flags
|
||||||
|
0x00, 0x01, // questions
|
||||||
|
0x00, 0x00, // answers
|
||||||
|
0x00, 0x00, // authority
|
||||||
|
0x00, 0x00, // additional
|
||||||
|
}
|
||||||
|
|
||||||
|
labels := splitDomain(domain)
|
||||||
|
if prefix != "" {
|
||||||
|
labels = append([]string{prefix}, labels...)
|
||||||
|
}
|
||||||
|
for _, label := range labels {
|
||||||
|
packet = append(packet, byte(len(label)))
|
||||||
|
packet = append(packet, label...)
|
||||||
|
}
|
||||||
|
packet = append(packet, 0x00) // root label
|
||||||
|
packet = append(packet, 0x00, 0x10) // TXT
|
||||||
|
packet = append(packet, 0x00, 0x01) // IN
|
||||||
|
return packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitDomain(domain string) []string {
|
||||||
|
var labels []string
|
||||||
|
start := 0
|
||||||
|
for i := 0; i <= len(domain); i++ {
|
||||||
|
if i == len(domain) || domain[i] == '.' {
|
||||||
|
labels = append(labels, domain[start:i])
|
||||||
|
start = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return labels
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue