Last active
November 1, 2017 14:39
-
-
Save takumin/1c1eb6eb4380f20d1b2ce5b6e1e88e34 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"encoding/json" | |
"flag" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"net" | |
"net/http" | |
"net/url" | |
"os" | |
"os/signal" | |
"runtime" | |
"strings" | |
"sync" | |
"syscall" | |
"time" | |
"github.com/miekg/dns" | |
) | |
var ( | |
listen string | |
dnssec bool | |
maxWorkers int | |
maxQueues int | |
dispatch *dispatcher | |
client *http.Client = &http.Client{ | |
Transport: &http.Transport{ | |
DialContext: (&net.Dialer{ | |
Timeout: 300 * time.Second, | |
KeepAlive: 300 * time.Second, | |
DualStack: true, | |
}).DialContext, | |
Proxy: http.ProxyFromEnvironment, | |
MaxIdleConns: 10, | |
IdleConnTimeout: 300 * time.Second, | |
TLSHandshakeTimeout: 10 * time.Second, | |
ExpectContinueTimeout: 1 * time.Second, | |
}, | |
} | |
) | |
type ( | |
dnsResponseJSON struct { | |
Status uint32 `json:"Status"` | |
TC bool `json:"TC"` | |
RD bool `json:"RD"` | |
RA bool `json:"RA"` | |
AD bool `json:"AD"` | |
CD bool `json:"CD"` | |
Question []dnsResponseQuestionJSON `json:"Question"` | |
Answer []dnsResponseAnswerJSON `json:"Answer"` | |
Authority []dnsResponseAnswerJSON `json:"Authority"` | |
Additional []dnsResponseAnswerJSON `json:"Additional"` | |
Subnet string `json:"edns_client_subnet"` | |
Comment string `json:"Comment"` | |
} | |
dnsResponseQuestionJSON struct { | |
Name string `json:"name"` | |
Type uint16 `json:"type"` | |
} | |
dnsResponseAnswerJSON struct { | |
Name string `json:"name"` | |
Type uint16 `json:"type"` | |
TTL uint32 `json:"TTL"` | |
Data string `json:"data"` | |
} | |
response struct { | |
writer dns.ResponseWriter | |
msg *dns.Msg | |
} | |
job struct { | |
proc func(context.Context) | |
ctx context.Context | |
} | |
dispatcher struct { | |
queue chan *job | |
wg sync.WaitGroup | |
ctx context.Context | |
cancel context.CancelFunc | |
} | |
) | |
func newDispatcher() *dispatcher { | |
ctx, cancel := context.WithCancel(context.Background()) | |
d := &dispatcher{ | |
queue: make(chan *job, maxQueues), | |
ctx: ctx, | |
cancel: cancel, | |
} | |
return d | |
} | |
func (d *dispatcher) Add(proc func(context.Context)) { | |
d.queue <- &job{proc: proc, ctx: d.ctx} | |
} | |
func (d *dispatcher) AddWithContext(proc func(context.Context), ctx context.Context) { | |
d.queue <- &job{proc: proc, ctx: ctx} | |
} | |
func (d *dispatcher) Context() context.Context { | |
return d.ctx | |
} | |
func (d *dispatcher) Start() { | |
d.wg.Add(maxWorkers) | |
for i := 0; i < maxWorkers; i++ { | |
go func() { | |
defer d.wg.Done() | |
for j := range d.queue { | |
j.proc(j.ctx) | |
} | |
}() | |
} | |
} | |
func (d *dispatcher) Stop() { | |
close(d.queue) | |
d.wg.Wait() | |
d.cancel() | |
} | |
func (d *dispatcher) StopImmediately() { | |
d.cancel() | |
close(d.queue) | |
d.wg.Wait() | |
} | |
func main() { | |
flag.StringVar(&listen, "listen", ":5553", "listen address") | |
flag.BoolVar(&dnssec, "dnssec", true, "enable DNSSEC") | |
flag.IntVar(&maxWorkers, "max_workers", runtime.NumCPU(), "Max Workers") | |
flag.IntVar(&maxQueues, "max_queues", 10000, "Max Queues") | |
flag.VisitAll(func(f *flag.Flag) { | |
if s := os.Getenv(strings.ToUpper(f.Name)); s != "" { | |
f.Value.Set(s) | |
} | |
}) | |
flag.Parse() | |
dispatch = newDispatcher() | |
dispatch.Start() | |
dns.HandleFunc(".", handleDnsRequest) | |
go func() { | |
udp := &dns.Server{Addr: listen, Net: "udp"} | |
err := udp.ListenAndServe() | |
if err != nil { | |
log.Fatal("Failed to set udp listener\n", err.Error()) | |
} | |
}() | |
go func() { | |
tcp := &dns.Server{Addr: listen, Net: "tcp"} | |
err := tcp.ListenAndServe() | |
if err != nil { | |
log.Fatal("Failed to set tcp listener\n", err.Error()) | |
} | |
}() | |
sig := make(chan os.Signal) | |
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) | |
for { | |
select { | |
case s := <-sig: | |
log.Fatalf("Signal (%d) received, stopping\n", s) | |
dispatch.Stop() | |
} | |
} | |
} | |
func handleDnsRequest(w dns.ResponseWriter, m *dns.Msg) { | |
dispatch.AddWithContext(requestDnsOverHttps, context.WithValue(dispatch.Context(), "response", &response{ | |
writer: w, | |
msg: m, | |
})) | |
} | |
func requestDnsOverHttps(ctx context.Context) { | |
response := ctx.Value("response").(*response) | |
req, err := http.NewRequest(http.MethodGet, "https://dns.google.com/resolve", nil) | |
if err != nil { | |
log.Fatal("Failed to http new request\n", err.Error()) | |
} | |
val := url.Values{} | |
val.Add("name", response.msg.Question[0].Name) | |
val.Add("type", fmt.Sprint(response.msg.Question[0].Qtype)) | |
if dnssec { | |
val.Add("cd", "0") | |
} else { | |
val.Add("cd", "1") | |
} | |
val.Add("edns_client_subnet", "0.0.0.0/0") | |
val.Add("random_padding", "") | |
req.URL.RawQuery = val.Encode() | |
res, err := client.Do(req) | |
if err != nil { | |
log.Print("Failed to get response\n", err.Error()) | |
return | |
} | |
defer res.Body.Close() | |
msg := new(dns.Msg) | |
msg.SetReply(response.msg) | |
if res.StatusCode == 200 { | |
body, err := ioutil.ReadAll(res.Body) | |
if err != nil { | |
log.Print("Failed to read response\n", err.Error()) | |
return | |
} | |
var dat dnsResponseJSON | |
err = json.Unmarshal(body, &dat) | |
if err != nil { | |
log.Print("Failed to unmarshal\n", err.Error()) | |
return | |
} | |
rrs := make([]dns.RR, 0, len(dat.Answer)) | |
for _, v := range dat.Answer { | |
rr, err := dns.NewRR(v.Name + " " + fmt.Sprint(v.TTL) + " IN " + dns.TypeToString[v.Type] + " " + v.Data) | |
if err != nil { | |
log.Print("Failed to new response\n", err.Error()) | |
return | |
} | |
rrs = append(rrs, rr) | |
} | |
msg.Answer = rrs | |
} | |
response.writer.WriteMsg(msg) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment