Skip to content

Instantly share code, notes, and snippets.

@szampardi
Last active October 28, 2021 23:36
Show Gist options
  • Save szampardi/9d5595ba7e38f80a72dbc43dbfec925b to your computer and use it in GitHub Desktop.
Save szampardi/9d5595ba7e38f80a72dbc43dbfec925b to your computer and use it in GitHub Desktop.
package main
import (
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"net"
"os"
"os/signal"
"sort"
"strconv"
"strings"
"syscall"
"time"
// https://github.com/miekg/exdns/blob/master/reflect/reflect.go
"github.com/miekg/dns"
log "github.com/szampardi/msg"
)
// Command-line flags. Further flags (-F, -l) are registered in init.
var (
	//verbose = flag.Bool("v", false, "verbose")

	// fqdn is the zone answered by this server (-fqdn); init normalizes it.
	fqdn = flag.String("fqdn", ".xns.name.", "FQDN")
	// laddr is a comma-separated list of listen addresses (-laddr).
	laddr = flag.String("laddr", ":53", "bind on")
	// logcolor enables colourised log output (-c).
	logcolor = flag.Bool("c", false, "colorize output")
)

// State derived from the flags and filled in by init.
var (
	// suffix is "." followed by the normalized fqdn; query names must end with it.
	suffix string
	// logfmt is the selected log format; overridden by -F.
	logfmt log.Format = log.Formats[log.PlainFormat]
	// loglvl is the selected log level; overridden by -l.
	loglvl log.Lvl = log.LNotice
	// logger is the process-wide logger built in init.
	logger log.Logger
)
// init registers the -F (log format) and -l (log level) flags, parses the
// command line, normalizes the served zone name and builds the logger.
//
// The zone name is stripped of leading dots and then made fully qualified,
// so "xns.name", "xns.name." and the flag default ".xns.name." all become
// "xns.name.". Without the stripping, a leading dot would yield a suffix of
// "..xns.name." that no valid query name ends with, and main would register
// its handler on a name that no query ever reaches.
func init() {
	flag.Func(
		"F",
		fmt.Sprintf("logging format (prefix) %v", logFmts()),
		func(value string) error {
			if v, ok := log.Formats[value]; ok {
				logfmt = v
				return nil
			}
			return fmt.Errorf("invalid format [%s] specified", value)
		},
	)
	flag.Func(
		"l",
		"log level",
		func(value string) error {
			i, err := strconv.Atoi(value)
			if err != nil {
				return err
			}
			// Assigned before validation; presumably IsValidLevel returns a
			// non-nil error for bad levels, which makes flag.Parse (ExitOnError)
			// terminate before loglvl is ever used.
			loglvl = log.Lvl(i)
			return log.IsValidLevel(i)
		},
	)
	flag.Parse()
	*fqdn = dns.Fqdn(strings.TrimLeft(*fqdn, "."))
	suffix = "." + *fqdn
	var err error
	// Logger writes to os.Stderr; a construction failure is unrecoverable at startup.
	logger, err = log.New(logfmt.Name, time.RFC3339, *fqdn, *logcolor, loglvl, os.Stderr)
	if err != nil {
		panic(err)
	}
}
// main registers handleDNS for the configured zone, starts one TCP and one
// UDP listener for every comma-separated address in -laddr, and then blocks
// until SIGINT or SIGTERM arrives. Returning from main ends the process; the
// listeners are not shut down gracefully.
func main() {
	dns.HandleFunc(*fqdn, handleDNS)
	for _, addr := range strings.Split(*laddr, ",") {
		// Allow "-laddr ':53, :5353'" — stray spaces would otherwise make the
		// listener fail to resolve its host.
		addr = strings.TrimSpace(addr)
		go serve(addr, "tcp")
		go serve(addr, "udp")
	}
	// signal.Notify never blocks when delivering, so the channel needs room for
	// one pending signal or a signal that arrives before the receive is lost.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	s := <-sig
	logger.Warningf("signal (%v) received, stopping", s)
}
// JSON shape of the single log line emitted for each handled request.
type (
	// dnsQu describes what the client asked.
	dnsQu struct {
		Type     uint16         `json:"query_type"` // QTYPE of the first question
		Text     string         `json:"query"`      // name of the first question
		Question []dns.Question `json:"question"`   // the request's whole question section
	}

	// dnsAn describes what the server produced.
	dnsAn struct {
		Answered bool     `json:"answered"`          // true when an address was resolved
		Content  []dns.RR `json:"content,omitempty"` // the address record(s) built for the reply
	}

	// dnsLog is one complete log entry.
	dnsLog struct {
		Client string `json:"client"` // "<network>://<remote address>"
		Query  dnsQu  `json:"query"`
		Answer dnsAn  `json:"answer"`
	}
)
// handleDNS answers A, AAAA and TXT queries for names under the served zone
// (suffix). The first label left after stripping suffix selects where the
// returned address comes from:
//
//   - "mirror":  the client's own IP, taken from w.RemoteAddr();
//   - "0x<hex>": the label is hex-decoded and the resulting text is parsed as
//     an IP literal (e.g. "0x312e322e332e34" decodes to "1.2.3.4");
//   - otherwise: the whole remaining name is parsed as an IP literal
//     (e.g. "1.2.3.4.xns.name." yields 1.2.3.4 for zone "xns.name.").
//
// The address becomes an A record (IPv4) or an AAAA record (IPv6). It goes in
// the answer section when its type matches the question. A TXT question is
// answered with a TXT record holding the client address and transport, and
// the address record is then added as additional data. A question whose
// address family does not match gets an empty answer, with the address
// record in the additional section. Unsupported types, names outside the
// zone and unparsable addresses are answered with REFUSED.
//
// Once a question is present, the reply is written and a JSON summary is
// logged from a deferred function, so both happen on every return path.
func handleDNS(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)
	if len(r.Question) == 0 {
		// Nothing to answer: reply FORMERR instead of staying silent, so the
		// client does not have to wait for a timeout.
		m.SetRcode(r, dns.RcodeFormatError)
		if err := w.WriteMsg(m); err != nil {
			logger.Errorf("error writing response: %s", err)
		}
		return
	}
	m.SetReply(r)
	q := r.Question[0]
	remote := w.RemoteAddr()
	t := dnsTXT([]string{remote.String(), remote.Network()})
	goodReq := false
	// RR holds the address record for the log line. It stays nil on refused
	// requests so that "content" is omitted instead of logged as [null].
	var RR []dns.RR
	defer func() {
		if err := w.WriteMsg(m); err != nil {
			logger.Errorf("error writing response: %s", err)
		}
		logentry := dnsLog{
			Client: fmt.Sprintf("%s://%s", remote.Network(), remote.String()),
			Query: dnsQu{
				Type:     q.Qtype,
				Text:     q.Name,
				Question: r.Question,
			},
			Answer: dnsAn{
				Answered: goodReq,
				Content:  RR,
			},
		}
		j, err := json.Marshal(logentry)
		if err != nil {
			logger.Errorf("error encoding log entry: %s", err)
			return
		}
		// The JSON is an argument, never the format string: query names are
		// client-controlled and may contain '%'.
		logger.Noticef("%s", j)
	}()
	// refuse logs why the request is rejected and turns m into a REFUSED reply.
	refuse := func(format string, args ...interface{}) {
		logger.Errorf(format, args...)
		m.SetRcode(r, dns.RcodeRefused)
	}

	switch q.Qtype {
	case dns.TypeA, dns.TypeAAAA, dns.TypeTXT:
	default:
		refuse("invalid request type %d", q.Qtype)
		return
	}
	if !strings.HasSuffix(q.Name, suffix) {
		refuse("request for invalid domain %q", q.Name)
		return
	}
	prompt := strings.TrimSuffix(q.Name, suffix)
	if prompt == "" {
		refuse("request too short")
		return
	}

	// candidate is the text the IP is parsed from; it is kept for error logs.
	var candidate string
	switch label := strings.SplitN(prompt, ".", 2)[0]; {
	case label == "mirror":
		host, _, err := net.SplitHostPort(remote.String())
		if err != nil {
			refuse("cannot parse client address %q: %s", remote.String(), err)
			return
		}
		candidate = host
	case strings.HasPrefix(label, "0x"):
		hexIP := strings.TrimPrefix(label, "0x")
		bip, err := hex.DecodeString(hexIP)
		if err != nil {
			refuse("error decoding hex %q: %s", hexIP, err)
			return
		}
		candidate = string(bip)
		logger.Debugf("decoded hex IP: %q", candidate)
	default:
		candidate = prompt
	}
	ip := net.ParseIP(candidate)
	if ip == nil {
		refuse("invalid IP %q", candidate)
		return
	}

	rr := addressRR(q.Name, ip)
	RR = []dns.RR{rr}
	switch {
	case q.Qtype == dns.TypeTXT:
		// Answer-section records must be owned by the queried name.
		t.Hdr.Name = q.Name
		m.Answer = append(m.Answer, t)
		m.Extra = append(m.Extra, rr)
	case q.Qtype == rr.Header().Rrtype:
		m.Answer = append(m.Answer, rr)
		m.Extra = append(m.Extra, t)
	default:
		// A asked but the address is IPv6, or AAAA asked but it is IPv4. An
		// answer RR must match QTYPE, so reply NODATA and still expose the
		// address as additional data.
		m.Extra = append(m.Extra, rr, t)
	}
	goodReq = true
}

// addressRR returns an A record when ip has an IPv4 form (ip.To4() != nil)
// and an AAAA record otherwise. Either record is owned by name.
func addressRR(name string, ip net.IP) dns.RR {
	if ip.To4() != nil {
		return dnsA(name, ip)
	}
	return dnsAAAA(name, ip)
}
// serve runs a blocking miekg/dns server on laddr over the given transport
// ("tcp" or "udp", as passed by main). Handler is left nil, so per the
// miekg/dns docs requests go to dns.DefaultServeMux, where main registers
// handleDNS. The function returns only when the server stops or cannot be
// set up. Failures are logged, not propagated.
//
// The transport parameter is named network (not net) so it no longer shadows
// the imported net package.
func serve(laddr, network string) {
	srv := &dns.Server{
		Addr: laddr,
		Net:  network,
	}
	logger.Noticef("starting %s listener on %s", network, laddr)
	if err := srv.ListenAndServe(); err != nil {
		logger.Errorf("failed to setup the %q server: %s", network, err)
	}
}
// dnsA wraps rip in an IN-class A record owned by name, with a TTL of 1 second.
func dnsA(name string, rip net.IP) *dns.A {
	hdr := dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 1}
	return &dns.A{Hdr: hdr, A: rip}
}
// dnsAAAA wraps rip in an IN-class AAAA record owned by name, with a TTL of 1 second.
func dnsAAAA(name string, rip net.IP) *dns.AAAA {
	hdr := dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 1}
	return &dns.AAAA{Hdr: hdr, AAAA: rip}
}
// dnsTXT builds an IN-class TXT record whose strings are s. It is owned by
// the served zone name (*fqdn) and has a TTL of 0.
func dnsTXT(s []string) *dns.TXT {
	rec := new(dns.TXT)
	rec.Hdr = dns.RR_Header{Name: *fqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
	rec.Txt = s
	return rec
}
// logFmts returns the names of the available log formats in ascending order.
// Any name containing "rfc" is left out. The result is shown in the -F flag
// usage text.
func logFmts() []string {
	var names []string
	for name := range log.Formats {
		if strings.Contains(name, "rfc") {
			continue
		}
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment