package dnstt import ( "encoding/binary" "net/netip" "strings" ) // ProcessPacket records DNSTT DNS traffic from a raw IP packet. func ProcessPacket(data []byte, port int, collector *Collector) { if len(data) == 0 { return } switch data[0] >> 4 { case 4: ProcessIPv4Packet(data, port, collector) case 6: ProcessIPv6Packet(data, port, collector) } } // ProcessIPv4Packet records DNSTT DNS traffic from a raw IPv4 packet. func ProcessIPv4Packet(data []byte, port int, collector *Collector) { if len(data) < 20 || data[0]>>4 != 4 { return } ihl := int(data[0]&0x0f) * 4 if ihl < 20 || len(data) < ihl { return } if data[9] != 17 { return } srcIP := netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}) dstIP := netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}) udpData := data[ihl:] processUDPPayload(udpData, srcIP, dstIP, port, collector) } // ProcessIPv6Packet records DNSTT DNS traffic from a raw IPv6 packet. func ProcessIPv6Packet(data []byte, port int, collector *Collector) { if len(data) < 40 || data[0]>>4 != 6 { return } if data[6] != 17 { return } var srcBytes [16]byte copy(srcBytes[:], data[8:24]) srcIP := netip.AddrFrom16(srcBytes) var dstBytes [16]byte copy(dstBytes[:], data[24:40]) dstIP := netip.AddrFrom16(dstBytes) payloadLen := int(binary.BigEndian.Uint16(data[4:6])) if payloadLen <= 0 || 40+payloadLen > len(data) { return } processUDPPayload(data[40:40+payloadLen], srcIP, dstIP, port, collector) } func processUDPPayload(udpData []byte, srcIP netip.Addr, dstIP netip.Addr, port int, collector *Collector) { if len(udpData) < 8 { return } srcPort := binary.BigEndian.Uint16(udpData[0:2]) dstPort := binary.BigEndian.Uint16(udpData[2:4]) udpLen := int(binary.BigEndian.Uint16(udpData[4:6])) dnsPayload := udpData[8:] if len(dnsPayload) < 12 { return } if int(dstPort) == port { domain, clientID := ExtractQuery(dnsPayload, collector.Domains()) if domain != "" { collector.RecordQueryFrom(domain, clientID, srcIP, udpLen) } return } if int(srcPort) == port { domain := extractQueryDomain(dnsPayload) if domain != "" { collector.RecordResponseFrom(domain, dstIP, udpLen) } } } func extractQueryDomain(dns []byte) string { if len(dns) < 12 { return "" } labels := parseDNSLabels(dns[12:]) if len(labels) == 0 { return "" } return normalizeDomain(strings.Join(labels, ".")) }