-
-
Save kokizzu/5f4178b7674de9e3eea25933b9c29cfa to your computer and use it in GitHub Desktop.
A DNS-over-HTTPS forwarding resolver in under 300 lines of clean Go code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"context" | |
"errors" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"net/http" | |
"os" | |
"os/signal" | |
"sync" | |
"syscall" | |
"time" | |
"github.com/miekg/dns" | |
) | |
func main() { | |
ctx, cancel := context.WithCancel(context.Background()) | |
defer cancel() | |
server, err := NewServer(ctx, Cloudflare()) | |
if err != nil { | |
log.Println(err) | |
return | |
} | |
stopped := make(chan struct{}) | |
go server.Run(ctx, stopped) | |
signals := make(chan os.Signal, 1) | |
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) | |
select { | |
case <-stopped: | |
case signal := <-signals: | |
log.Println("OS signal received, shutting down: ", signal) | |
cancel() | |
select { | |
case <-stopped: // wait for the server to be stopped | |
case <-signals: | |
// exit on second signal | |
return | |
} | |
} | |
} | |
// Provider describes an upstream DNS over HTTPS service.
type Provider struct {
	// plainDNS is the host:port of a plain DNS resolver, used at startup to
	// resolve fqdn.
	plainDNS string
	// fqdn is the fully qualified domain name of the DoH host.
	fqdn string
	// url is the DoH endpoint that receives the POST requests.
	url string
}

// Cloudflare returns the Provider settings for Cloudflare's public resolver.
func Cloudflare() Provider {
	var p Provider
	p.plainDNS = "1.1.1.1:53"
	p.fqdn = "cloudflare-dns.com."
	p.url = "https://cloudflare-dns.com/dns-query"
	return p
}
type Server interface { | |
Run(ctx context.Context, stopped chan<- struct{}) | |
} | |
type server struct { | |
dnsServer dns.Server | |
} | |
var (
	// ErrResolveDOHFqdn is wrapped by NewServer when neither the IPv4 nor the
	// IPv6 address of the provider's DoH host could be obtained.
	ErrResolveDOHFqdn = errors.New("cannot resolve DNS over HTTPS fqdn")
)
func NewServer(ctx context.Context, provider Provider) (s Server, err error) { | |
dnsClient := new(dns.Client) | |
message := new(dns.Msg) | |
message.SetQuestion(provider.fqdn, dns.TypeA) | |
hostIPv4, _, errIPv4 := dnsClient.Exchange(message, provider.plainDNS) | |
if err != nil { | |
log.Printf("cannot obtain IPv4 address for %s: %s\n", provider.fqdn, errIPv4) | |
} else { | |
log.Printf("resolved %s to IPv4 %s", provider.fqdn, hostIPv4.Answer[0].(*dns.A).A) | |
} | |
message.SetQuestion(provider.fqdn, dns.TypeAAAA) | |
hostIPv6, _, errIPv6 := dnsClient.Exchange(message, provider.plainDNS) | |
if err != nil { | |
log.Printf("cannot obtain IPv6 address for %s: %s\n", provider.fqdn, errIPv6) | |
} else { | |
log.Printf("resolved %s to IPv6 %s", provider.fqdn, hostIPv6.Answer[0].(*dns.AAAA).AAAA) | |
} | |
if errIPv4 != nil && errIPv6 != nil { | |
return nil, fmt.Errorf("%w: %s", ErrResolveDOHFqdn, provider.fqdn) | |
} | |
const httpTimeout = 5 * time.Second | |
return &server{ | |
dnsServer: dns.Server{ | |
Addr: ":53", | |
Net: "udp", | |
Handler: newDNSHandler(ctx, httpTimeout, provider, hostIPv4, hostIPv6), | |
}, | |
}, nil | |
} | |
func (s *server) Run(ctx context.Context, stopped chan<- struct{}) { | |
defer close(stopped) | |
go func() { // shutdown goroutine | |
<-ctx.Done() | |
const graceTime = 100 * time.Millisecond | |
ctx, cancel := context.WithTimeout(context.Background(), graceTime) | |
defer cancel() | |
if err := s.dnsServer.ShutdownContext(ctx); err != nil { | |
log.Println("DNS server shutdown error: ", err) | |
} | |
}() | |
log.Println("DNS server listening on :53") | |
if err := s.dnsServer.ListenAndServe(); err != nil { | |
log.Println("DNS server crashed: ", err) | |
} | |
log.Println("DNS server stopped") | |
} | |
func newDNSHandler(ctx context.Context, httpTimeout time.Duration, | |
provider Provider, hostIPv4, hostIPv6 *dns.Msg) dns.Handler { | |
client := &http.Client{ | |
Timeout: httpTimeout, | |
} | |
httpBufferPool := &sync.Pool{ | |
New: func() interface{} { | |
return bytes.NewBuffer(nil) | |
}, | |
} | |
const udpPacketMaxSize = 512 | |
udpBufferPool := &sync.Pool{ | |
New: func() interface{} { | |
return make([]byte, udpPacketMaxSize) | |
}, | |
} | |
return &dnsHandler{ | |
ctx: ctx, | |
provider: provider, | |
hostIPv4: hostIPv4, | |
hostIPv6: hostIPv6, | |
client: client, | |
httpBufferPool: httpBufferPool, | |
udpBufferPool: udpBufferPool, | |
} | |
} | |
type dnsHandler struct { | |
ctx context.Context | |
provider Provider | |
hostIPv4 *dns.Msg | |
hostIPv6 *dns.Msg | |
client *http.Client | |
httpBufferPool *sync.Pool | |
udpBufferPool *sync.Pool | |
} | |
func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |
if h.askingForProviderIP(r) { | |
// Answer back with IP address obtained at start | |
h.replyWithProviderIP(w, r) | |
return | |
} | |
buffer := h.udpBufferPool.Get().([]byte) | |
// no need to reset buffer as wire takes care of cutting it down | |
wire, err := r.PackBuffer(buffer) | |
if err != nil { | |
log.Printf("cannot pack message to wire format: %s\n", err) | |
dns.HandleFailed(w, r) | |
return | |
} | |
// It's fine to copy the slice headers as long as we keep | |
// the underlying array of bytes. | |
h.udpBufferPool.Put(buffer) //nolint:staticcheck | |
respWire, err := h.requestHTTP(h.ctx, wire) | |
if err != nil { | |
log.Printf("HTTP request failed: %s\n", err) | |
dns.HandleFailed(w, r) | |
return | |
} | |
message := new(dns.Msg) | |
if err := message.Unpack(respWire); err != nil { | |
log.Printf("cannot unpack message from wireformat: %s\n", err) | |
dns.HandleFailed(w, r) | |
return | |
} | |
message.SetReply(r) | |
if err := w.WriteMsg(message); err != nil { | |
log.Println("write dns message error: ", err) | |
} | |
} | |
func (h *dnsHandler) askingForProviderIP(r *dns.Msg) bool { | |
return len(r.Question) > 0 && r.Question[0].Name == h.provider.fqdn && | |
(r.Question[0].Qtype == dns.TypeA || r.Question[0].Qtype == dns.TypeAAAA) | |
} | |
func (h *dnsHandler) replyWithProviderIP(w dns.ResponseWriter, r *dns.Msg) { | |
host := h.hostIPv4 | |
if r.Question[0].Qtype == dns.TypeAAAA { | |
host = h.hostIPv6 | |
} | |
host.SetReply(r) | |
if err := w.WriteMsg(host); err != nil { | |
log.Println("write dns message error: ", err) | |
} | |
} | |
// ErrHTTPStatus is wrapped by requestHTTP when the DoH provider replies with
// an HTTP status other than 200 OK.
var ErrHTTPStatus = errors.New("bad HTTP status")
func (h *dnsHandler) requestHTTP(ctx context.Context, wire []byte) (respWire []byte, err error) { | |
buffer := h.httpBufferPool.Get().(*bytes.Buffer) | |
buffer.Reset() | |
defer h.httpBufferPool.Put(buffer) | |
_, err = buffer.Write(wire) | |
if err != nil { | |
return nil, err | |
} | |
request, err := http.NewRequestWithContext(ctx, http.MethodPost, h.provider.url, buffer) | |
if err != nil { | |
return nil, err | |
} | |
const contentTypeUDPWireFormat = "application/dns-udpwireformat" | |
request.Header.Set("Content-Type", contentTypeUDPWireFormat) | |
response, err := h.client.Do(request) | |
if err != nil { | |
return nil, err | |
} | |
defer response.Body.Close() | |
if response.StatusCode != http.StatusOK { | |
return nil, fmt.Errorf("%w: %s", ErrHTTPStatus, response.Status) | |
} | |
respWire, err = ioutil.ReadAll(response.Body) | |
if err != nil { | |
return nil, err | |
} | |
if err := response.Body.Close(); err != nil { | |
return nil, err | |
} | |
return respWire, nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment