Skip to content

Instantly share code, notes, and snippets.

@takumin
Last active November 1, 2017 14:39
Show Gist options
  • Save takumin/1c1eb6eb4380f20d1b2ce5b6e1e88e34 to your computer and use it in GitHub Desktop.
Save takumin/1c1eb6eb4380f20d1b2ce5b6e1e88e34 to your computer and use it in GitHub Desktop.
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"runtime"
"strings"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
)
var (
	// Runtime configuration. main fills these from command-line flags, or
	// from an environment variable named after the flag in upper case.
	listen     string
	dnssec     bool
	maxWorkers int
	maxQueues  int

	// dispatch is the worker pool that performs the upstream lookups.
	dispatch *dispatcher

	// client is shared by every worker so upstream connections are reused.
	client *http.Client = &http.Client{
		// Overall per-request deadline (dial + TLS + headers + body). Without
		// it, a stalled upstream pins a worker forever. With only
		// runtime.NumCPU() workers by default, a handful of such hangs stall
		// the whole resolver. Stub resolvers give up on a query long before
		// this fires, so a slower answer would be useless anyway.
		Timeout: 10 * time.Second,
		Transport: &http.Transport{
			DialContext: (&net.Dialer{
				Timeout:   300 * time.Second,
				KeepAlive: 300 * time.Second,
				DualStack: true,
			}).DialContext,
			Proxy:        http.ProxyFromEnvironment,
			MaxIdleConns: 10,
			// Every request goes to the same host. The transport default of 2
			// idle connections per host would force constant reconnects (and
			// TLS handshakes) as soon as more than two workers are busy.
			MaxIdleConnsPerHost:   10,
			IdleConnTimeout:       300 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}
)
// dnsResponseJSON is the top-level JSON object decoded from the upstream
// resolver in requestDnsOverHttps. Each field maps onto the key named in its
// json tag. NOTE(review): Status is presumably the DNS RCODE and the boolean
// fields the header flags of the same name; confirm against the upstream
// API documentation.
type dnsResponseJSON struct {
	Status     uint32                    `json:"Status"`
	TC         bool                      `json:"TC"`
	RD         bool                      `json:"RD"`
	RA         bool                      `json:"RA"`
	AD         bool                      `json:"AD"`
	CD         bool                      `json:"CD"`
	Question   []dnsResponseQuestionJSON `json:"Question"`
	Answer     []dnsResponseAnswerJSON   `json:"Answer"`
	Authority  []dnsResponseAnswerJSON   `json:"Authority"`
	Additional []dnsResponseAnswerJSON   `json:"Additional"`
	Subnet     string                    `json:"edns_client_subnet"`
	Comment    string                    `json:"Comment"`
}

// dnsResponseQuestionJSON is one entry of the upstream "Question" array:
// the queried name and its numeric record type.
type dnsResponseQuestionJSON struct {
	Name string `json:"name"`
	Type uint16 `json:"type"`
}

// dnsResponseAnswerJSON is one resource record from the upstream
// "Answer", "Authority" or "Additional" arrays. Data holds the record's
// RDATA as text, which requestDnsOverHttps splices into a zone-file line
// for dns.NewRR.
type dnsResponseAnswerJSON struct {
	Name string `json:"name"`
	Type uint16 `json:"type"`
	TTL  uint32 `json:"TTL"`
	Data string `json:"data"`
}

// response pairs an incoming query with the writer its reply must be sent
// to, so a worker goroutine can answer it after the handler has returned.
type response struct {
	writer dns.ResponseWriter
	msg    *dns.Msg
}

// job is one unit of work for the dispatcher: a worker calls proc(ctx).
type job struct {
	proc func(context.Context)
	ctx  context.Context
}

// dispatcher is a fixed-size worker pool fed through a buffered channel.
type dispatcher struct {
	queue  chan *job          // pending jobs; capacity set in newDispatcher
	wg     sync.WaitGroup     // counts running worker goroutines
	ctx    context.Context    // base context handed to jobs submitted via Add
	cancel context.CancelFunc // cancels ctx; called by Stop/StopImmediately
}
// newDispatcher returns an idle dispatcher. Its queue can buffer up to
// maxQueues pending jobs, and its base context is cancelled by Stop or
// StopImmediately. Call Start to launch the workers.
func newDispatcher() *dispatcher {
	d := &dispatcher{queue: make(chan *job, maxQueues)}
	d.ctx, d.cancel = context.WithCancel(context.Background())
	return d
}
// Add queues proc to be run by a worker with the dispatcher's base context.
// It blocks while the queue is full. Because Stop and StopImmediately close
// the queue channel, calling Add afterwards panics.
func (d *dispatcher) Add(proc func(context.Context)) {
	d.AddWithContext(proc, d.ctx)
}
// AddWithContext queues proc to be run by a worker with the caller-supplied
// ctx instead of the dispatcher's base context. It blocks while the queue is
// full and panics once the queue has been closed by Stop/StopImmediately.
func (d *dispatcher) AddWithContext(proc func(context.Context), ctx context.Context) {
	j := &job{ctx: ctx, proc: proc}
	d.queue <- j
}
// Context exposes the dispatcher's base context, which is cancelled when the
// dispatcher stops. Contexts derived from it therefore see that cancellation.
func (d *dispatcher) Context() context.Context {
	base := d.ctx
	return base
}
// Start launches maxWorkers goroutines. Each one repeatedly takes a job off
// the queue and runs it with the job's own context. Workers exit once the
// queue channel is closed and drained.
func (d *dispatcher) Start() {
	d.wg.Add(maxWorkers)
	for n := maxWorkers; n > 0; n-- {
		go d.work()
	}
}

// work is the loop executed by a single worker goroutine.
func (d *dispatcher) work() {
	defer d.wg.Done()
	for {
		j, ok := <-d.queue
		if !ok {
			return
		}
		j.proc(j.ctx)
	}
}
// Stop shuts the pool down gracefully. It closes the queue, so no more jobs
// may be submitted, and lets the workers run everything still queued. Only
// after all workers have exited does it cancel the base context.
func (d *dispatcher) Stop() {
	defer d.cancel()
	close(d.queue)
	d.wg.Wait()
}
// StopImmediately cancels the base context before closing the queue. Jobs
// whose context derives from it can therefore notice the shutdown early.
// Jobs still in the queue are nonetheless dequeued and invoked before the
// workers exit. It returns once every worker has finished.
func (d *dispatcher) StopImmediately() {
	d.cancel()
	defer d.wg.Wait()
	close(d.queue)
}
// main wires the proxy together. It reads the configuration, starts the
// worker pool, and serves DNS over both UDP and TCP on -listen. On SIGINT or
// SIGTERM it shuts the listeners down and then drains the pool.
//
// Every flag can also be set through an environment variable named after the
// flag in upper case (LISTEN, DNSSEC, MAX_WORKERS, MAX_QUEUES). flag.Parse
// runs after the environment pass, so an explicit command-line flag wins.
func main() {
	flag.StringVar(&listen, "listen", ":5553", "listen address")
	flag.BoolVar(&dnssec, "dnssec", true, "enable DNSSEC")
	flag.IntVar(&maxWorkers, "max_workers", runtime.NumCPU(), "Max Workers")
	flag.IntVar(&maxQueues, "max_queues", 10000, "Max Queues")
	flag.VisitAll(func(f *flag.Flag) {
		name := strings.ToUpper(f.Name)
		if s := os.Getenv(name); s != "" {
			// Reject malformed values instead of silently keeping the default.
			if err := f.Value.Set(s); err != nil {
				log.Fatalf("invalid value %q for %s: %v", s, name, err)
			}
		}
	})
	flag.Parse()

	// Start passes maxWorkers to WaitGroup.Add, where a negative count
	// panics. Zero workers means nothing drains the queue, so every query
	// would block forever. A negative channel capacity panics in make.
	if maxWorkers < 1 {
		log.Fatalf("max_workers must be at least 1, got %d", maxWorkers)
	}
	if maxQueues < 0 {
		log.Fatalf("max_queues must not be negative, got %d", maxQueues)
	}

	dispatch = newDispatcher()
	dispatch.Start()
	dns.HandleFunc(".", handleDnsRequest)

	// stopping is closed just before the listeners are shut down. A serve
	// goroutine can then tell an intentional shutdown from a real failure.
	stopping := make(chan struct{})
	servers := []*dns.Server{
		{Addr: listen, Net: "udp"},
		{Addr: listen, Net: "tcp"},
	}
	for _, srv := range servers {
		go func(srv *dns.Server) {
			if err := srv.ListenAndServe(); err != nil {
				select {
				case <-stopping:
				default:
					log.Fatal("Failed to set "+srv.Net+" listener\n", err.Error())
				}
			}
		}(srv)
	}

	// signal.Notify never blocks when delivering. With an unbuffered channel,
	// a signal arriving while nobody is receiving would be dropped.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	s := <-sig
	// Printf, not Fatalf: Fatalf exits immediately and would skip the cleanup below.
	log.Printf("Signal (%d) received, stopping\n", s)

	close(stopping)
	// Stop accepting queries before closing the dispatch queue. The handler
	// sends on that queue, and a send after dispatch.Stop closes it panics.
	// NOTE(review): whether Shutdown also waits for handlers already running
	// depends on the miekg/dns version in use; confirm.
	for _, srv := range servers {
		if err := srv.Shutdown(); err != nil {
			log.Print("Failed to shutdown "+srv.Net+" listener\n", err.Error())
		}
	}
	dispatch.Stop()
}
// contextKey is the type of the context keys owned by this file. A private
// type, unlike a bare string, cannot collide with keys set by other packages.
type contextKey string

// responseKey is the context key under which handleDnsRequest stores the
// *response that requestDnsOverHttps must answer.
const responseKey contextKey = "response"

// handleDnsRequest is registered for "." and so receives every query. It
// performs no lookup itself. It queues requestDnsOverHttps on the dispatcher
// (blocking while the queue is full) and returns. The worker writes the reply.
func handleDnsRequest(w dns.ResponseWriter, m *dns.Msg) {
	ctx := context.WithValue(dispatch.Context(), responseKey, &response{
		writer: w,
		msg:    m,
	})
	dispatch.AddWithContext(requestDnsOverHttps, ctx)
}

// requestDnsOverHttps is the job run by a dispatcher worker. It resolves the
// query stored in ctx under responseKey and always writes exactly one reply,
// so the client never has to wait for its own timeout.
func requestDnsOverHttps(ctx context.Context) {
	r := ctx.Value(responseKey).(*response)
	if err := r.writer.WriteMsg(resolveQuery(ctx, r.msg)); err != nil {
		log.Print("Failed to write response\n", err.Error())
	}
}

// resolveQuery answers the first question of query via the JSON endpoint at
// https://dns.google.com/resolve. It returns:
//   - FORMERR for a query with no question;
//   - SERVFAIL when the upstream cannot be reached or its answer cannot be
//     translated;
//   - otherwise a reply carrying the upstream status, RA/AD flags, answer
//     records and authority records.
func resolveQuery(ctx context.Context, query *dns.Msg) *dns.Msg {
	// The query arrives straight off the network. Indexing Question[0]
	// without this check lets one empty query panic a worker, killing the process.
	if len(query.Question) == 0 {
		return new(dns.Msg).SetRcode(query, dns.RcodeFormatError)
	}
	q := query.Question[0]

	req, err := http.NewRequest(http.MethodGet, "https://dns.google.com/resolve", nil)
	if err != nil {
		// A single failed request must not take down the whole server.
		log.Print("Failed to http new request\n", err.Error())
		return servfailReply(query)
	}
	// Bind the upstream call to the job context. Cancelling the dispatcher
	// (StopImmediately) then aborts lookups that are still in flight.
	req = req.WithContext(ctx)

	val := url.Values{}
	val.Add("name", q.Name)
	val.Add("type", fmt.Sprint(q.Qtype))
	if dnssec {
		val.Add("cd", "0")
	} else {
		val.Add("cd", "1")
	}
	val.Add("edns_client_subnet", "0.0.0.0/0")
	val.Add("random_padding", "")
	req.URL.RawQuery = val.Encode()

	res, err := client.Do(req)
	if err != nil {
		log.Print("Failed to get response\n", err.Error())
		return servfailReply(query)
	}
	defer res.Body.Close()

	// A non-200 reply carries no usable answer. Forwarding it as an empty
	// NOERROR would tell the client the records do not exist.
	if res.StatusCode != http.StatusOK {
		log.Print("Unexpected upstream status\n", res.Status)
		return servfailReply(query)
	}

	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		log.Print("Failed to read response\n", err.Error())
		return servfailReply(query)
	}
	var dat dnsResponseJSON
	if err := json.Unmarshal(body, &dat); err != nil {
		log.Print("Failed to unmarshal\n", err.Error())
		return servfailReply(query)
	}

	answer, err := toRRs(dat.Answer)
	if err != nil {
		log.Print("Failed to new response\n", err.Error())
		return servfailReply(query)
	}

	msg := new(dns.Msg)
	msg.SetReply(query)
	// Propagate the upstream RCODE (e.g. NXDOMAIN) instead of answering
	// NOERROR unconditionally. Values above 15 are extended RCODEs, which
	// need an EDNS OPT record this proxy never adds, so they map to SERVFAIL.
	if dat.Status <= 0xF {
		msg.Rcode = int(dat.Status)
	} else {
		msg.Rcode = dns.RcodeServerFailure
	}
	msg.RecursionAvailable = dat.RA
	// Only pass AD through when validation was requested upstream (cd=0).
	msg.AuthenticatedData = dnssec && dat.AD
	msg.Answer = answer
	// The authority section holds the SOA that lets resolvers cache negative
	// answers. It is best-effort: if a record does not parse, the error is
	// logged and the section is omitted rather than failing the reply.
	if ns, err := toRRs(dat.Authority); err != nil {
		log.Print("Failed to parse authority section\n", err.Error())
	} else {
		msg.Ns = ns
	}
	return msg
}

// toRRs converts upstream JSON records into dns.RR values. It renders each
// record as a zone-file line and parses it with dns.NewRR, stopping at the
// first record that fails to parse.
func toRRs(records []dnsResponseAnswerJSON) ([]dns.RR, error) {
	rrs := make([]dns.RR, 0, len(records))
	for _, v := range records {
		rr, err := dns.NewRR(v.Name + " " + fmt.Sprint(v.TTL) + " IN " + dns.TypeToString[v.Type] + " " + v.Data)
		if err != nil {
			return nil, err
		}
		// NewRR yields (nil, nil) for input that is empty or only a comment.
		// A nil RR would break packing, so skip it.
		if rr != nil {
			rrs = append(rrs, rr)
		}
	}
	return rrs, nil
}

// servfailReply builds a SERVFAIL reply to query.
func servfailReply(query *dns.Msg) *dns.Msg {
	return new(dns.Msg).SetRcode(query, dns.RcodeServerFailure)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment