Skip to content

Instantly share code, notes, and snippets.

@liweitianux
Last active October 28, 2024 02:42
Show Gist options
  • Save liweitianux/fcdce02c6d9a0df1eca67766f4089bb2 to your computer and use it in GitHub Desktop.
Save liweitianux/fcdce02c6d9a0df1eca67766f4089bb2 to your computer and use it in GitHub Desktop.
Simple DNS query inspector (UDP+TCP+DoT+DoH)
// SPDX-License-Identifier: MIT
//
// Simple DNS server that inspects the query and reflects it back.
//
// Weitian LI
// 2024-10-23
//
package main
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"time"
"github.com/miekg/dns"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
// Limits, timeouts and DoH protocol constants shared by the listeners.
const (
// Size of the buffer allocated for each UDP datagram in serveUDP; the
// payload passed on to the handler is at most this many bytes.
maxQuerySize = 1024 // bytes
// Read deadline for a query on a stream connection (TCP/DoT, see handleTCP).
readTimeout = 5 * time.Second
// Write deadline for a response on a stream connection (TCP/DoT).
writeTimeout = 5 * time.Second
// Upper bound on the explicit TLS handshake of a DoT connection.
handshakeTimeout = 5 * time.Second
// HTTP path on which DoH requests are served.
dohURI = "/dns-query"
// Media type required for DoH POST bodies and set on DoH responses (RFC 8484).
dohContentType = "application/dns-message"
)
// main parses the command-line flags, starts one listener goroutine per
// enabled transport and then blocks forever.
//
// UDP is always served. TCP, DoT and DoH are served only when their port
// flag is positive. DoT requires both -dot-cert and -dot-key. DoH is served
// over TLS when both -doh-cert and -doh-key are given, and as cleartext HTTP
// (HTTP/1.1 and h2c) when neither is given.
func main() {
	addr := flag.String("addr", "0.0.0.0", "listen address")
	port := flag.Int("port", 53, "UDP port")
	tcpPort := flag.Int("tcp-port", 53, "TCP port")
	dotPort := flag.Int("dot-port", 0, "DoT (DNS-over-TLS) port")
	dotCert := flag.String("dot-cert", "", "TLS certificate for DoT")
	dotKey := flag.String("dot-key", "", "TLS private key for DoT")
	dohPort := flag.Int("doh-port", 0, "DoH (DNS-over-HTTPS) port")
	dohCert := flag.String("doh-cert", "", "TLS certificate for DoH")
	dohKey := flag.String("doh-key", "", "TLS private key for DoH")
	flag.Parse()

	go serveUDP(*addr, *port)

	if *tcpPort > 0 {
		go serveTCP(*addr, *tcpPort)
	}

	if *dotPort > 0 {
		if *dotCert == "" || *dotKey == "" {
			log.Fatalf("ERROR: -dot-cert and -dot-key required")
		}
		config, err := newTLSConfig(*dotCert, *dotKey)
		if err != nil {
			log.Fatalf("ERROR: failed to load DoT cert/key: %v", err)
		}
		go serveDoT(*addr, *dotPort, config)
	}

	if *dohPort > 0 {
		// Giving only one half of the pair is almost certainly a mistake;
		// don't silently fall back to cleartext.
		if (*dohCert == "") != (*dohKey == "") {
			log.Fatalf("ERROR: -doh-cert and -doh-key must be given together")
		}
		// A nil config makes serveDoH serve cleartext HTTP.
		var config *tls.Config
		// Decide on TLS from the DoH flags themselves, not the DoT ones.
		if *dohCert != "" && *dohKey != "" {
			var err error
			config, err = newTLSConfig(*dohCert, *dohKey)
			if err != nil {
				log.Fatalf("ERROR: failed to load DoH cert/key: %v", err)
			}
		}
		go serveDoH(*addr, *dohPort, config)
	}

	// Block forever; all listeners run in their own goroutines.
	select {}
}

// newTLSConfig loads the PEM certificate/key pair from certFile and keyFile
// and returns a server TLS config that logs the remote address and the
// requested ServerName (SNI) of every incoming ClientHello.
func newTLSConfig(certFile, keyFile string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		GetConfigForClient: func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
			log.Printf("TLS connection from %s with ServerName=[%s]",
				chi.Conn.RemoteAddr(), chi.ServerName)
			// A nil config means "keep using the original config".
			return nil, nil
		},
	}, nil
}
// serveUDP listens for DNS queries on UDP addr:port and answers each
// datagram in its own goroutine via handleUDP. It exits the process if the
// socket cannot be opened; otherwise it returns only once the socket has
// been closed.
func serveUDP(addr string, port int) {
	// JoinHostPort brackets IPv6 literals (e.g. "[::]:53"); "%s:%d"
	// formatting would produce the unparsable ":::53".
	laddr := net.JoinHostPort(addr, fmt.Sprint(port))
	pc, err := net.ListenPacket("udp", laddr)
	if err != nil {
		log.Fatal(err)
	}
	defer pc.Close()
	log.Printf("UDP serving at: %s", laddr)

	for {
		// A fresh buffer per datagram: the handler goroutine keeps using
		// (and modifying) its slice while the next ReadFrom runs.
		buf := make([]byte, maxQuerySize)
		n, raddr, err := pc.ReadFrom(buf)
		if err != nil {
			// A closed socket fails every read; stop instead of spinning.
			if errors.Is(err, net.ErrClosed) {
				return
			}
			log.Printf("ERROR: ReadFrom() failed: %v", err)
			continue
		}
		log.Printf("UDP query from: %s", raddr)
		go handleUDP(pc, raddr, buf[:n])
	}
}
// handleUDP answers a single UDP query. It passes the datagram to
// handleQuery and sends the reply, if any, back to addr on pc. Send
// failures are logged rather than silently dropped.
func handleUDP(pc net.PacketConn, addr net.Addr, buf []byte) {
	resp := handleQuery(buf)
	if len(resp) == 0 {
		// Empty or unparsable query: nothing to send back.
		return
	}
	if _, err := pc.WriteTo(resp, addr); err != nil {
		log.Printf("ERROR: failed to write response to %s: %v", addr, err)
	}
}
// serveTCP listens for DNS queries on TCP addr:port and serves each accepted
// connection in its own goroutine via handleTCP. It exits the process if the
// listener cannot be opened; otherwise it returns only once the listener has
// been closed.
func serveTCP(addr string, port int) {
	// JoinHostPort brackets IPv6 literals (e.g. "[::]:53"); "%s:%d"
	// formatting would produce the unparsable ":::53".
	laddr := net.JoinHostPort(addr, fmt.Sprint(port))
	l, err := net.Listen("tcp", laddr)
	if err != nil {
		log.Fatal(err)
	}
	defer l.Close()
	log.Printf("TCP serving at: %s", laddr)

	for {
		conn, err := l.Accept()
		if err != nil {
			// A closed listener fails every Accept; stop instead of spinning.
			if errors.Is(err, net.ErrClosed) {
				return
			}
			log.Printf("ERROR: Accept() failed: %v", err)
			continue
		}
		log.Printf("TCP query from: %s", conn.RemoteAddr())
		go handleTCP(conn)
	}
}
// handleTCP serves one DNS query on a stream connection (plain TCP, or DoT
// after the TLS handshake) and then closes the connection.
//
// Messages are framed with a 2-byte big-endian length prefix (RFC 1035
// §4.2.2). Only the first query on the connection is answered. Reading is
// bounded by readTimeout and writing by writeTimeout.
func handleTCP(c net.Conn) {
	defer c.Close()
	c.SetReadDeadline(time.Now().Add(readTimeout))

	// Read the 2-byte message length.
	lbuf := make([]byte, 2)
	if _, err := io.ReadFull(c, lbuf); err != nil {
		// io.EOF: the peer closed without sending anything. Either way there
		// is no valid length, so we must stop here instead of carrying on
		// with a zeroed buffer.
		if !errors.Is(err, io.EOF) {
			log.Printf("ERROR: failed to read packet length: %v", err)
		}
		return
	}

	// Read the message itself. An EOF here means the message was truncated;
	// parsing the zero-filled buffer could otherwise "succeed" and be echoed.
	l := int(lbuf[0])<<8 | int(lbuf[1])
	mbuf := make([]byte, l)
	if _, err := io.ReadFull(c, mbuf); err != nil {
		log.Printf("ERROR: failed to read packet content: %v", err)
		return
	}

	resp := handleQuery(mbuf)
	if len(resp) == 0 {
		return
	}

	// Frame the response with its own 2-byte length prefix.
	out := make([]byte, 0, 2+len(resp))
	out = append(out, byte(len(resp)>>8), byte(len(resp)))
	out = append(out, resp...)

	c.SetWriteDeadline(time.Now().Add(writeTimeout))
	// Per the io.Writer contract a short write always comes with a non-nil
	// error, so checking err alone is sufficient.
	if _, err := c.Write(out); err != nil {
		log.Printf("ERROR: failed to write response: %v", err)
	}
}
// serveDoT listens for DNS-over-TLS on addr:port using config. Each accepted
// connection is handed to handleDoT in its own goroutine, so a slow or
// stalled TLS handshake cannot block other clients. It exits the process if
// the listener cannot be opened; otherwise it returns only once the listener
// has been closed.
func serveDoT(addr string, port int, config *tls.Config) {
	// JoinHostPort brackets IPv6 literals (e.g. "[::]:853"); "%s:%d"
	// formatting would produce an unparsable address.
	laddr := net.JoinHostPort(addr, fmt.Sprint(port))
	l, err := tls.Listen("tcp", laddr, config)
	if err != nil {
		log.Fatal(err)
	}
	defer l.Close()
	log.Printf("DoT serving at: %s", laddr)

	for {
		conn, err := l.Accept()
		if err != nil {
			// A closed listener fails every Accept; stop instead of spinning.
			if errors.Is(err, net.ErrClosed) {
				return
			}
			log.Printf("ERROR: Accept() failed: %v", err)
			continue
		}
		log.Printf("DoT query from: %s", conn.RemoteAddr())
		// The listener from tls.Listen yields *tls.Conn values.
		go handleDoT(conn.(*tls.Conn))
	}
}

// handleDoT completes the TLS handshake on conn, bounded by handshakeTimeout,
// prints the negotiated connection parameters and then serves the DNS query
// with handleTCP. conn is closed in every case.
func handleDoT(conn *tls.Conn) {
	// Handshake explicitly so the connection state can be printed before any
	// DNS data is read. cancel is called right away rather than deferred so
	// the context's resources are released as soon as the handshake is done.
	ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout)
	err := conn.HandshakeContext(ctx)
	cancel()
	if err != nil {
		log.Printf("ERROR: handshake failed: %v", err)
		conn.Close()
		return
	}

	cs := conn.ConnectionState()
	fmt.Printf("TLS: Version=%s, CipherSuite=%s, ServerName=%s, ALPN=%s\n",
		tls.VersionName(cs.Version), tls.CipherSuiteName(cs.CipherSuite),
		cs.ServerName, cs.NegotiatedProtocol)

	handleTCP(conn)
}
// serveDoH serves DNS-over-HTTPS on addr:port at path dohURI.
//
// With a non-nil config it serves HTTPS, with certificates taken from
// config. With a nil config it serves cleartext HTTP, supporting both
// HTTP/1.1 and h2c. It exits the process if the server cannot be set up or
// stops with an error (e.g. the port is already in use), matching the other
// listeners.
func serveDoH(addr string, port int, config *tls.Config) {
	mux := http.NewServeMux()
	mux.HandleFunc(dohURI, handleDoH)
	server := &http.Server{
		// JoinHostPort brackets IPv6 literals, which also yields a valid URL
		// host in the log lines below.
		Addr:    net.JoinHostPort(addr, fmt.Sprint(port)),
		Handler: mux,
		// Bound how long a client may take to send request headers so idle
		// half-open requests cannot pile up.
		ReadHeaderTimeout: readTimeout,
	}

	var err error
	if config != nil {
		log.Printf("DoH serving at: https://%s%s", server.Addr, dohURI)
		server.TLSConfig = config
		// Certificates come from TLSConfig, so no cert/key files are passed.
		err = server.ListenAndServeTLS("", "")
	} else {
		log.Printf("DoH serving at: http://%s%s", server.Addr, dohURI)
		// NOTE: "dig +http-plain" will use H2C (i.e., HTTP/2 cleartext)
		// prior knowledge to perform the request, so need to support
		// H2C here. In this case, both "prior knowledge" and
		// "upgrade" modes are supported.
		// Credit: https://medium.com/@thrawn01/http-2-cleartext-h2c-client-example-in-go-8167c7a4181e
		h2s := &http2.Server{}
		if cerr := http2.ConfigureServer(server, h2s); cerr != nil {
			log.Fatal(cerr)
		}
		server.Handler = h2c.NewHandler(server.Handler, h2s)
		err = server.ListenAndServe()
	}

	// ListenAndServe* only return on failure; don't let DoH vanish silently.
	if err != nil && !errors.Is(err, http.ErrServerClosed) {
		log.Fatal(err)
	}
}
// handleDoH serves one DNS-over-HTTPS request.
//
// A GET request carries the query base64url-encoded, without padding, in the
// "dns" URL parameter. A POST request carries the raw DNS message as a body
// of type dohContentType.
//
// After printing the TLS parameters (if any), the request line and the
// headers, the query is reflected back via handleQuery with Content-Type
// dohContentType.
//
// Error responses:
//   - 400: malformed requests and unparsable queries.
//   - 413: POST bodies larger than any DNS message can be.
//   - 405: any other method.
func handleDoH(w http.ResponseWriter, r *http.Request) {
	// A DNS message is at most 65535 bytes (its length must fit the 16-bit
	// TCP length prefix), so anything larger can never be a valid query.
	const maxMessageSize = 65535

	log.Printf("DoH query from: %s", r.RemoteAddr)

	var query []byte
	switch r.Method {
	case http.MethodGet:
		v := r.FormValue("dns")
		if v == "" {
			http.Error(w, "400 bad request: dns missing", http.StatusBadRequest)
			return
		}
		b, err := base64.RawURLEncoding.DecodeString(v)
		if err != nil || len(b) == 0 {
			http.Error(w, "400 bad request: dns invalid", http.StatusBadRequest)
			return
		}
		query = b
	case http.MethodPost:
		if r.Header.Get("Content-Type") != dohContentType {
			http.Error(w, "400 bad request: content-type invalid", http.StatusBadRequest)
			return
		}
		// Cap the upload: reading the raw body would buffer however much
		// an (untrusted) client chooses to send.
		body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxMessageSize))
		var tooLarge *http.MaxBytesError
		if errors.As(err, &tooLarge) {
			http.Error(w, "413 request entity too large", http.StatusRequestEntityTooLarge)
			return
		}
		if err != nil || len(body) == 0 {
			http.Error(w, "400 bad request: body", http.StatusBadRequest)
			return
		}
		query = body
	default:
		http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
		return
	}

	if cs := r.TLS; cs != nil {
		fmt.Printf("TLS: Version=%s, CipherSuite=%s, ServerName=%s, ALPN=%s\n",
			tls.VersionName(cs.Version), tls.CipherSuiteName(cs.CipherSuite),
			cs.ServerName, cs.NegotiatedProtocol)
	}
	fmt.Printf("Request: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
	fmt.Printf("%s %s %s\n", r.Method, r.RequestURI, r.Proto)
	fmt.Printf("Host: %s\n", r.Host)
	for k, vlist := range r.Header {
		for _, v := range vlist {
			fmt.Printf("%s: %s\n", k, v)
		}
	}
	fmt.Printf("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

	resp := handleQuery(query)
	if len(resp) == 0 {
		http.Error(w, "400 bad request: response", http.StatusBadRequest)
		return
	}
	w.Header().Set("Content-Type", dohContentType)
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(resp); err != nil {
		log.Printf("ERROR: failed to write DoH response: %v", err)
	}
}
// handleQuery parses query as a DNS message, prints it to stdout and builds
// the reply by reflecting the query: the very same bytes with the QR
// (response) flag set.
//
// The input slice is modified in place and returned as the reply. For an
// empty or unparsable query it logs the problem and returns nil, so callers
// send nothing back.
func handleQuery(query []byte) []byte {
	switch {
	case len(query) == 0:
		log.Printf("WARN: empty query")
		return nil
	}

	var parsed dns.Msg
	if unpackErr := parsed.Unpack(query); unpackErr != nil {
		// Arbitrary bytes may arrive here; only well-formed DNS gets a reply.
		log.Printf("ERROR: failed to parse query: %v", unpackErr)
		return nil
	}

	fmt.Printf("Query message: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
	fmt.Printf("%s", parsed.String())
	fmt.Printf("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

	// Byte 2 is the first flags byte of the DNS header; its top bit is QR.
	// A successful Unpack guarantees the header is present.
	query[2] |= 0x80
	return query
}
@liweitianux
Copy link
Author

Example DoH output:

2024/10/28 10:36:10 TLS connection from 127.0.0.1:55181 with ServerName=[localhost]
2024/10/28 10:36:10 DoH query from: 127.0.0.1:55181
TLS: Version=TLS 1.3, CipherSuite=TLS_AES_128_GCM_SHA256, ServerName=localhost, ALPN=h2
Request: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
POST /dns-query HTTP/2.0
Host: localhost
Content-Length: 128
Accept: application/dns-message
Content-Type: application/dns-message
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Query message: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
;; opcode: QUERY, status: NOERROR, id: 0
;; flags: rd ad; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 1

;; OPT PSEUDOSECTION:
; EDNS: version 0; flags:; udp: 1232
; SUBNET: 1.2.3.4/32/0
; PADDING: 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

;; QUESTION SECTION:
;www.example.com.	IN	 A
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Client command:
kdig @localhost -p 8053 www.example.com +subnet=1.2.3.4/24 +https

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment