Skip to content

Instantly share code, notes, and snippets.

@kokizzu
Forked from qdm12/main.go
Created March 18, 2021 18:46
Show Gist options
  • Save kokizzu/5f4178b7674de9e3eea25933b9c29cfa to your computer and use it in GitHub Desktop.
Save kokizzu/5f4178b7674de9e3eea25933b9c29cfa to your computer and use it in GitHub Desktop.
DNS over HTTPS server resolver under 300 lines of clean Go code
package main
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
)
// main runs a DNS over HTTPS forwarding server backed by Cloudflare until the
// server stops by itself or an OS termination signal arrives. On the first
// signal the server context is canceled and main waits for the server to
// stop; a second signal makes main return without waiting any longer.
func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	srv, err := NewServer(ctx, Cloudflare())
	if err != nil {
		log.Println(err)
		return
	}

	done := make(chan struct{})
	go srv.Run(ctx, done)

	osSignals := make(chan os.Signal, 1)
	signal.Notify(osSignals, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)

	var sig os.Signal
	select {
	case <-done:
		return
	case sig = <-osSignals:
	}

	log.Println("OS signal received, shutting down: ", sig)
	cancel()

	// Wait for the server to report it has stopped, unless a second signal
	// asks us to exit right away.
	select {
	case <-done:
	case <-osSignals:
	}
}
// Provider describes a DNS over HTTPS upstream.
//
// plainDNS is the host:port of a plain DNS server that NewServer queries to
// resolve fqdn; fqdn is the provider hostname in fully qualified form (with
// the trailing dot); url is the endpoint DNS queries are POSTed to.
type Provider struct {
	plainDNS  string
	fqdn, url string
}

// Cloudflare returns the Provider settings for Cloudflare's public resolver:
// plain DNS at 1.1.1.1:53 and DNS over HTTPS at
// https://cloudflare-dns.com/dns-query.
func Cloudflare() Provider {
	var p Provider
	p.plainDNS = "1.1.1.1:53"
	p.fqdn = "cloudflare-dns.com."
	p.url = "https://cloudflare-dns.com/dns-query"
	return p
}
type (
	// Server is a runnable DNS server. Run blocks while serving; the
	// implementation returned by NewServer closes stopped when Run returns
	// and shuts down once ctx is canceled.
	Server interface {
		Run(ctx context.Context, stopped chan<- struct{})
	}

	// server implements Server on top of a github.com/miekg/dns server.
	server struct {
		dnsServer dns.Server
	}
)
// ErrResolveDOHFqdn is returned by NewServer when the DNS over HTTPS provider
// hostname resolves to neither an IPv4 nor an IPv6 address through the
// provider's plain DNS server.
var ErrResolveDOHFqdn = errors.New("cannot resolve DNS over HTTPS fqdn")

// errNoAddressRecord reports a plain DNS answer that contains no record of
// the requested address type (for example a name without AAAA records).
var errNoAddressRecord = errors.New("no address record in answer")

// NewServer resolves provider.fqdn over plain DNS (A and AAAA queries sent to
// provider.plainDNS) and returns a Server listening on UDP :53 whose handler
// forwards queries to the provider with a 5 second HTTP timeout. The plain
// DNS responses are handed to the handler for queries about the provider
// hostname itself.
//
// A failure for a single address family is only logged, and that family is
// stored as an empty message (no answer records). NewServer fails with
// ErrResolveDOHFqdn only when neither family could be resolved.
func NewServer(ctx context.Context, provider Provider) (s Server, err error) {
	dnsClient := new(dns.Client)

	hostIPv4, ipv4, errIPv4 := resolveProviderHost(dnsClient, provider, dns.TypeA)
	if errIPv4 != nil {
		log.Printf("cannot obtain IPv4 address for %s: %s\n", provider.fqdn, errIPv4)
	} else {
		log.Printf("resolved %s to IPv4 %s", provider.fqdn, ipv4)
	}

	hostIPv6, ipv6, errIPv6 := resolveProviderHost(dnsClient, provider, dns.TypeAAAA)
	if errIPv6 != nil {
		log.Printf("cannot obtain IPv6 address for %s: %s\n", provider.fqdn, errIPv6)
	} else {
		log.Printf("resolved %s to IPv6 %s", provider.fqdn, ipv6)
	}

	if errIPv4 != nil && errIPv6 != nil {
		return nil, fmt.Errorf("%w: %s", ErrResolveDOHFqdn, provider.fqdn)
	}

	// The handler answers provider-hostname queries by calling SetReply on
	// the stored message, so a family that failed to resolve must not be
	// stored as nil; an empty message produces a reply without answers.
	if hostIPv4 == nil {
		hostIPv4 = new(dns.Msg)
	}
	if hostIPv6 == nil {
		hostIPv6 = new(dns.Msg)
	}

	const httpTimeout = 5 * time.Second
	return &server{
		dnsServer: dns.Server{
			Addr:    ":53",
			Net:     "udp",
			Handler: newDNSHandler(ctx, httpTimeout, provider, hostIPv4, hostIPv6),
		},
	}, nil
}

// resolveProviderHost sends a qtype (dns.TypeA or dns.TypeAAAA) query for
// provider.fqdn to provider.plainDNS. It returns the full response and the
// textual form of the first answer record of the requested type. It returns
// an error if the exchange fails or the answer holds no such record; in that
// case the returned message is nil.
func resolveProviderHost(client *dns.Client, provider Provider, qtype uint16) (*dns.Msg, string, error) {
	query := new(dns.Msg)
	query.SetQuestion(provider.fqdn, qtype)
	response, _, err := client.Exchange(query, provider.plainDNS)
	if err != nil {
		return nil, "", err
	}
	// Scan all records instead of blindly asserting Answer[0]: the answer can
	// be empty (NXDOMAIN/NODATA) or start with another record type.
	for _, rr := range response.Answer {
		switch record := rr.(type) {
		case *dns.A:
			if qtype == dns.TypeA {
				return response, record.A.String(), nil
			}
		case *dns.AAAA:
			if qtype == dns.TypeAAAA {
				return response, record.AAAA.String(), nil
			}
		}
	}
	return nil, "", fmt.Errorf("%w of type %s (rcode %s)",
		errNoAddressRecord, dns.TypeToString[qtype], dns.RcodeToString[response.Rcode])
}
// Run serves DNS requests until the listener returns, closing stopped on the
// way out. Canceling ctx calls ShutdownContext with a 100ms grace period so
// that ListenAndServe below can return.
func (s *server) Run(ctx context.Context, stopped chan<- struct{}) {
	defer close(stopped)

	shutdownOnCancel := func() {
		<-ctx.Done()
		const graceTime = 100 * time.Millisecond
		graceCtx, release := context.WithTimeout(context.Background(), graceTime)
		defer release()
		if err := s.dnsServer.ShutdownContext(graceCtx); err != nil {
			log.Println("DNS server shutdown error: ", err)
		}
	}
	go shutdownOnCancel()

	log.Println("DNS server listening on :53")
	if err := s.dnsServer.ListenAndServe(); err != nil {
		log.Println("DNS server crashed: ", err)
	}
	log.Println("DNS server stopped")
}
// newDNSHandler builds the *dnsHandler used as the DNS server's handler.
// httpTimeout becomes the HTTP client's timeout, and ctx is stored for the
// handler's outgoing HTTP requests. hostIPv4 and hostIPv6 are the plain DNS
// responses for the provider hostname. Two pools of scratch buffers are set
// up: *bytes.Buffer values for HTTP bodies, and 512-byte slices for packing
// DNS messages.
func newDNSHandler(ctx context.Context, httpTimeout time.Duration,
	provider Provider, hostIPv4, hostIPv6 *dns.Msg) dns.Handler {
	const udpPacketMaxSize = 512

	handler := &dnsHandler{
		ctx:      ctx,
		provider: provider,
		hostIPv4: hostIPv4,
		hostIPv6: hostIPv6,
		client:   &http.Client{Timeout: httpTimeout},
	}
	handler.httpBufferPool = &sync.Pool{
		New: func() interface{} { return bytes.NewBuffer(nil) },
	}
	handler.udpBufferPool = &sync.Pool{
		New: func() interface{} { return make([]byte, udpPacketMaxSize) },
	}
	return handler
}
// dnsHandler is the dns.Handler built by newDNSHandler. It holds the upstream
// provider, the startup plain DNS responses for the provider hostname, an
// HTTP client and pools of reusable scratch buffers.
type dnsHandler struct {
	// ctx is used for outgoing HTTP requests made with client.
	ctx    context.Context
	client *http.Client

	// provider is the upstream; hostIPv4/hostIPv6 are its plain DNS A and
	// AAAA responses as passed to newDNSHandler.
	provider Provider
	hostIPv4 *dns.Msg
	hostIPv6 *dns.Msg

	// Reusable scratch buffers (HTTP bodies and wire-format packing).
	httpBufferPool *sync.Pool
	udpBufferPool  *sync.Pool
}
// ServeDNS answers one DNS query.
//
// A/AAAA queries for the provider's own hostname (see askingForProviderIP)
// are answered from the responses resolved at startup. Every other query is
// packed to wire format and sent to the provider over HTTPS. The upstream
// answer is then relayed to the client with the client's ID and question and
// the upstream response code. Any failure is reported with dns.HandleFailed.
func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if h.askingForProviderIP(r) {
		// Answer back with IP address obtained at start
		h.replyWithProviderIP(w, r)
		return
	}

	buffer := h.udpBufferPool.Get().([]byte)
	// wire (below) may share buffer's backing array, so buffer may only go
	// back to the pool once wire is no longer used, i.e. when ServeDNS
	// returns. Returning it earlier lets a concurrent call pack another query
	// into the same bytes while this one is still being sent.
	// Deferring also returns the buffer on the error paths.
	defer h.udpBufferPool.Put(buffer) //nolint:staticcheck

	// No need to reset buffer: PackBuffer returns a slice covering only the
	// bytes it wrote, and allocates a new one if buffer is too small.
	wire, err := r.PackBuffer(buffer)
	if err != nil {
		log.Printf("cannot pack message to wire format: %s\n", err)
		dns.HandleFailed(w, r)
		return
	}

	respWire, err := h.requestHTTP(h.ctx, wire)
	if err != nil {
		log.Printf("HTTP request failed: %s\n", err)
		dns.HandleFailed(w, r)
		return
	}

	message := new(dns.Msg)
	if err := message.Unpack(respWire); err != nil {
		log.Printf("cannot unpack message from wireformat: %s\n", err)
		dns.HandleFailed(w, r)
		return
	}

	// SetReply copies the client's ID, opcode, RD/CD bits and question onto
	// the upstream answer, but it also forces Rcode to NOERROR. Restore the
	// upstream Rcode so NXDOMAIN, SERVFAIL, REFUSED, ... reach the client.
	upstreamRcode := message.Rcode
	message.SetReply(r)
	message.Rcode = upstreamRcode

	if err := w.WriteMsg(message); err != nil {
		log.Println("write dns message error: ", err)
	}
}
// askingForProviderIP reports whether r asks for an A or AAAA record of the
// provider's own hostname. Only the first question is considered.
//
// The name is compared case-insensitively, because DNS names are
// case-insensitive (RFC 4343). A query such as "CloudFlare-DNS.com." (e.g.
// from resolvers that randomize query-name case) must still be answered
// locally.
func (h *dnsHandler) askingForProviderIP(r *dns.Msg) bool {
	if len(r.Question) == 0 {
		return false
	}
	question := r.Question[0]
	switch question.Qtype {
	case dns.TypeA, dns.TypeAAAA:
		return bytes.EqualFold([]byte(question.Name), []byte(h.provider.fqdn))
	default:
		return false
	}
}
// replyWithProviderIP answers r, an A or AAAA query for the provider hostname,
// with the matching plain DNS response held by the handler.
//
// The stored message is shared by every call on this handler, and the DNS
// server may run calls concurrently. It is therefore copied before SetReply
// writes the request's ID and question into it. Mutating it in place would be
// a data race and could hand a client another query's ID. If no message is
// stored for the requested type, the query is failed with dns.HandleFailed
// instead of dereferencing nil.
func (h *dnsHandler) replyWithProviderIP(w dns.ResponseWriter, r *dns.Msg) {
	host := h.hostIPv4
	if r.Question[0].Qtype == dns.TypeAAAA {
		host = h.hostIPv6
	}
	if host == nil {
		dns.HandleFailed(w, r)
		return
	}
	reply := host.Copy()
	reply.SetReply(r)
	if err := w.WriteMsg(reply); err != nil {
		log.Println("write dns message error: ", err)
	}
}
var (
	// ErrHTTPStatus is returned by requestHTTP when the DNS over HTTPS
	// endpoint answers with a status other than 200 OK.
	ErrHTTPStatus = errors.New("bad HTTP status")
)

// requestHTTP POSTs the wire-format DNS query to the provider's DNS over
// HTTPS endpoint. It returns the wire-format DNS response body.
//
// The query is sent as Content-Type "application/dns-message", and that type
// is also requested via Accept, as specified by RFC 8484. The earlier
// "application/dns-udpwireformat" type is a pre-standard draft type.
//
// The request body is a private copy of wire rather than a buffer from
// h.httpBufferPool. An http.RoundTripper may close the body asynchronously,
// even after Do returns. Recycling the buffer on return could therefore hand
// bytes still owned by an in-flight request to another query. The caller may
// reuse wire as soon as requestHTTP returns.
func (h *dnsHandler) requestHTTP(ctx context.Context, wire []byte) (respWire []byte, err error) {
	body := bytes.NewReader(append([]byte(nil), wire...))
	request, err := http.NewRequestWithContext(ctx, http.MethodPost, h.provider.url, body)
	if err != nil {
		return nil, err
	}
	const contentTypeDNSMessage = "application/dns-message"
	request.Header.Set("Content-Type", contentTypeDNSMessage)
	request.Header.Set("Accept", contentTypeDNSMessage)

	response, err := h.client.Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("%w: %s", ErrHTTPStatus, response.Status)
	}

	respWire, err = ioutil.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}
	if err := response.Body.Close(); err != nil {
		return nil, err
	}
	return respWire, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment