Last active
October 28, 2024 02:42
-
-
Save liweitianux/fcdce02c6d9a0df1eca67766f4089bb2 to your computer and use it in GitHub Desktop.
Simple DNS query inspector (UDP+TCP+DoT+DoH)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: MIT
//
// Simple DNS server that inspects the query and reflects it back.
//
// Weitian LI
// 2024-10-23
//
package main | |
import ( | |
"context" | |
"crypto/tls" | |
"encoding/base64" | |
"errors" | |
"flag" | |
"fmt" | |
"io" | |
"log" | |
"net" | |
"net/http" | |
"time" | |
"github.com/miekg/dns" | |
"golang.org/x/net/http2" | |
"golang.org/x/net/http2/h2c" | |
) | |
// Transport limits shared by the UDP, TCP and DoT code paths.
const (
	// maxQuerySize is the size, in bytes, of the buffer used to receive
	// one UDP query.
	maxQuerySize = 1024

	// Per-connection deadlines for reading a query, writing the
	// response and completing a TLS handshake.
	readTimeout      = 5 * time.Second
	writeTimeout     = 5 * time.Second
	handshakeTimeout = 5 * time.Second
)

// DoH (DNS-over-HTTPS) endpoint settings.
const (
	// dohURI is the URL path on which DoH requests are served.
	dohURI = "/dns-query"

	// dohContentType is the media type required on DoH POST requests
	// and set on DoH responses.
	dohContentType = "application/dns-message"
)
func main() { | |
addr := flag.String("addr", "0.0.0.0", "listen address") | |
port := flag.Int("port", 53, "UDP port") | |
tcpPort := flag.Int("tcp-port", 53, "TCP port") | |
dotPort := flag.Int("dot-port", 0, "DoT (DNS-over-TLS) port") | |
dotCert := flag.String("dot-cert", "", "TLS certificate for DoT") | |
dotKey := flag.String("dot-key", "", "TLS private key for DoT") | |
dohPort := flag.Int("doh-port", 0, "DoH (DNS-over-HTTPS) port") | |
dohCert := flag.String("doh-cert", "", "TLS certificate for DoH") | |
dohKey := flag.String("doh-key", "", "TLS private key for DoH") | |
flag.Parse() | |
go serveUDP(*addr, *port) | |
if *tcpPort > 0 { | |
go serveTCP(*addr, *tcpPort) | |
} | |
if *dotPort > 0 { | |
if *dotCert == "" || *dotKey == "" { | |
log.Fatalf("ERROR: -dot-cert and -dot-key required") | |
} | |
cert, err := tls.LoadX509KeyPair(*dotCert, *dotKey) | |
if err != nil { | |
log.Fatalf("ERROR: failed to load DoT cert/key: %v", err) | |
} | |
config := &tls.Config{ | |
Certificates: []tls.Certificate{cert}, | |
GetConfigForClient: func(chi *tls.ClientHelloInfo) (*tls.Config, error) { | |
log.Printf("TLS connection from %s with ServerName=[%s]", | |
chi.Conn.RemoteAddr(), chi.ServerName) | |
return nil, nil | |
}, | |
} | |
go serveDoT(*addr, *dotPort, config) | |
} | |
if *dohPort > 0 { | |
var config *tls.Config | |
if *dotCert != "" && *dotKey != "" { | |
cert, err := tls.LoadX509KeyPair(*dohCert, *dohKey) | |
if err != nil { | |
log.Fatalf("ERROR: failed to load DoH cert/key: %v", err) | |
} | |
config = &tls.Config{ | |
Certificates: []tls.Certificate{cert}, | |
GetConfigForClient: func(chi *tls.ClientHelloInfo) (*tls.Config, error) { | |
log.Printf("TLS connection from %s with ServerName=[%s]", | |
chi.Conn.RemoteAddr(), chi.ServerName) | |
return nil, nil | |
}, | |
} | |
} | |
go serveDoH(*addr, *dohPort, config) | |
} | |
select {} | |
} | |
func serveUDP(addr string, port int) { | |
laddr := fmt.Sprintf("%s:%d", addr, port) | |
pc, err := net.ListenPacket("udp", laddr) | |
if err != nil { | |
log.Fatal(err) | |
} | |
defer pc.Close() | |
log.Printf("UDP serving at: %s", laddr) | |
for { | |
buf := make([]byte, maxQuerySize) | |
n, addr, err := pc.ReadFrom(buf) | |
if err != nil { | |
log.Printf("ERROR: ReadFrom() failed: %v", err) | |
continue | |
} | |
log.Printf("UDP query from: %s", addr) | |
go handleUDP(pc, addr, buf[:n]) | |
} | |
} | |
func handleUDP(pc net.PacketConn, addr net.Addr, buf []byte) { | |
resp := handleQuery(buf) | |
if len(resp) > 0 { | |
pc.WriteTo(resp, addr) | |
} | |
} | |
func serveTCP(addr string, port int) { | |
laddr := fmt.Sprintf("%s:%d", addr, port) | |
l, err := net.Listen("tcp", laddr) | |
if err != nil { | |
log.Fatal(err) | |
} | |
defer l.Close() | |
log.Printf("TCP serving at: %s", laddr) | |
for { | |
conn, err := l.Accept() | |
if err != nil { | |
log.Printf("ERROR: Accept() failed: %v", err) | |
continue | |
} | |
log.Printf("TCP query from: %s", conn.RemoteAddr()) | |
go handleTCP(conn) | |
} | |
} | |
func handleTCP(c net.Conn) { | |
defer c.Close() | |
c.SetReadDeadline(time.Now().Add(readTimeout)) | |
// read packet length | |
lbuf := make([]byte, 2) | |
_, err := io.ReadFull(c, lbuf) | |
if err != nil && !errors.Is(err, io.EOF) { | |
log.Printf("ERROR: failed to read packet length: %v", err) | |
return | |
} | |
// read packet content | |
l := int(lbuf[0])<<8 | int(lbuf[1]) | |
mbuf := make([]byte, l) | |
_, err = io.ReadFull(c, mbuf) | |
if err != nil && !errors.Is(err, io.EOF) { | |
log.Printf("ERROR: failed to read packet content: %v", err) | |
return | |
} | |
resp := handleQuery(mbuf) | |
if len(resp) == 0 { | |
return | |
} | |
l = len(resp) | |
lbuf = []byte{byte(l >> 8), byte(l)} | |
buf := append(lbuf, resp...) | |
c.SetWriteDeadline(time.Now().Add(writeTimeout)) | |
n, err := c.Write(buf) | |
if err != nil { | |
log.Printf("ERROR: failed to write response: %v", err) | |
} else if n != len(buf) { | |
log.Printf("WARN: response write incomplete: n=%d, l=%d", n, len(buf)) | |
} | |
} | |
func serveDoT(addr string, port int, config *tls.Config) { | |
laddr := fmt.Sprintf("%s:%d", addr, port) | |
l, err := tls.Listen("tcp", laddr, config) | |
if err != nil { | |
log.Fatal(err) | |
} | |
defer l.Close() | |
log.Printf("DoT serving at: %s", laddr) | |
for { | |
conn, err := l.Accept() | |
if err != nil { | |
log.Printf("ERROR: Accept() failed: %v", err) | |
continue | |
} | |
log.Printf("DoT query from: %s", conn.RemoteAddr()) | |
// NOTE: Manually call handshake so that we can get the | |
// connection state. | |
tconn := conn.(*tls.Conn) | |
ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout) | |
defer cancel() | |
if err := tconn.HandshakeContext(ctx); err != nil { | |
log.Printf("ERROR: handshake failed: %v", err) | |
continue | |
} | |
cs := tconn.ConnectionState() | |
fmt.Printf("TLS: Version=%s, CipherSuite=%s, ServerName=%s, ALPN=%s\n", | |
tls.VersionName(cs.Version), tls.CipherSuiteName(cs.CipherSuite), | |
cs.ServerName, cs.NegotiatedProtocol) | |
go handleTCP(conn) | |
} | |
} | |
func serveDoH(addr string, port int, config *tls.Config) { | |
mux := http.NewServeMux() | |
mux.HandleFunc(dohURI, handleDoH) | |
server := &http.Server{ | |
Addr: fmt.Sprintf("%s:%d", addr, port), | |
Handler: mux, | |
} | |
if config != nil { | |
log.Printf("DoH serving at: https://%s%s", server.Addr, dohURI) | |
server.TLSConfig = config | |
server.ListenAndServeTLS("", "") | |
} else { | |
log.Printf("DoH serving at: http://%s%s", server.Addr, dohURI) | |
// NOTE: "dig +http-plain" will use H2C (i.e., HTTP/2 cleartext) | |
// prior knowledge to perform the request, so need to support | |
// H2C here. In this case, both "prior knowledge" and | |
// "upgrade" modes are supported. | |
// Credit: https://medium.com/@thrawn01/http-2-cleartext-h2c-client-example-in-go-8167c7a4181e | |
h2s := &http2.Server{} | |
if err := http2.ConfigureServer(server, h2s); err != nil { | |
log.Fatal(err) | |
} | |
server.Handler = h2c.NewHandler(server.Handler, h2s) | |
server.ListenAndServe() | |
} | |
} | |
func handleDoH(w http.ResponseWriter, r *http.Request) { | |
log.Printf("DoH query from: %s", r.RemoteAddr) | |
var query []byte | |
switch r.Method { | |
case http.MethodGet: | |
v := r.FormValue("dns") | |
if v == "" { | |
http.Error(w, "400 bad request: dns missing", http.StatusBadRequest) | |
return | |
} | |
b, err := base64.RawURLEncoding.DecodeString(v) | |
if err != nil || len(b) == 0 { | |
http.Error(w, "400 bad request: dns invalid", http.StatusBadRequest) | |
return | |
} | |
query = b | |
case http.MethodPost: | |
if r.Header.Get("Content-Type") != dohContentType { | |
http.Error(w, "400 bad request: content-type invalid", http.StatusBadRequest) | |
return | |
} | |
body, err := io.ReadAll(r.Body) | |
if err != nil || len(body) == 0 { | |
http.Error(w, "400 bad request: body", http.StatusBadRequest) | |
return | |
} | |
query = body | |
default: | |
http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed) | |
return | |
} | |
if cs := r.TLS; cs != nil { | |
fmt.Printf("TLS: Version=%s, CipherSuite=%s, ServerName=%s, ALPN=%s\n", | |
tls.VersionName(cs.Version), tls.CipherSuiteName(cs.CipherSuite), | |
cs.ServerName, cs.NegotiatedProtocol) | |
} | |
fmt.Printf("Request: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") | |
fmt.Printf("%s %s %s\n", r.Method, r.RequestURI, r.Proto) | |
fmt.Printf("Host: %s\n", r.Host) | |
for k, vlist := range r.Header { | |
for _, v := range vlist { | |
fmt.Printf("%s: %s\n", k, v) | |
} | |
} | |
fmt.Printf("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") | |
resp := handleQuery(query) | |
if len(resp) == 0 { | |
http.Error(w, "400 bad request: response", http.StatusBadRequest) | |
return | |
} | |
w.Header().Set("Content-Type", dohContentType) | |
w.WriteHeader(http.StatusOK) | |
w.Write(resp) | |
} | |
func handleQuery(query []byte) []byte { | |
if len(query) == 0 { | |
log.Printf("WARN: empty query") | |
return nil | |
} | |
msg := dns.Msg{} | |
if err := msg.Unpack(query); err != nil { | |
log.Printf("ERROR: failed to parse query: %v", err) | |
// The input can be invalid DNS, so don't try to reply. | |
return nil | |
} | |
fmt.Printf("Query message: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") | |
fmt.Printf("%s", msg.String()) | |
fmt.Printf("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n") | |
// Reflect the query | |
resp := query | |
resp[2] |= 0x80 // Set QR bit | |
return resp | |
} |
Example DoH (H2C) output:
2024/10/24 13:19:33 DoH query from: 127.0.0.1:40985
Request: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
POST /dns-query HTTP/2.0
Host: 127.0.0.1:8053
Content-Type: application/dns-message
Accept: application/dns-message
Content-Length: 67
Cache-Control: no-cache, no-store, must-revalidate
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Query message: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
;; opcode: QUERY, status: NOERROR, id: 57428
;; flags: rd ad; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 1
;; OPT PSEUDOSECTION:
; EDNS: version 0; flags:; udp: 1232
; SUBNET: 1.2.3.0/24/0
; COOKIE: 64a6ebaff54c8cb5
;; QUESTION SECTION:
;www.example.com. IN A
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Client command:
dig @127.0.0.1 -p 8053 www.example.com +subnet=1.2.3.4/24 +http-plain
Example DoH output:
2024/10/28 10:36:10 TLS connection from 127.0.0.1:55181 with ServerName=[localhost]
2024/10/28 10:36:10 DoH query from: 127.0.0.1:55181
TLS: Version=TLS 1.3, CipherSuite=TLS_AES_128_GCM_SHA256, ServerName=localhost, ALPN=h2
Request: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
POST /dns-query HTTP/2.0
Host: localhost
Content-Length: 128
Accept: application/dns-message
Content-Type: application/dns-message
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Query message: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
;; opcode: QUERY, status: NOERROR, id: 0
;; flags: rd ad; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 1
;; OPT PSEUDOSECTION:
; EDNS: version 0; flags:; udp: 1232
; SUBNET: 1.2.3.4/32/0
; PADDING: 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
;; QUESTION SECTION:
;www.example.com. IN A
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Client command:
kdig @localhost -p 8053 www.example.com +subnet=1.2.3.4/24 +https
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example DoT output:
Client command:
kdig @localhost -p 8853 www.example.com +subnet=1.2.3.4/24 +tls