Created December 23, 2017 08:56
-
-
Save kuyagic/583642b62ac29bf8104892b7378cea3a to your computer and use it in GitHub Desktop.
Clean DNS server
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# coding=utf-8 | |
import binascii | |
import ipaddress | |
import time | |
import os | |
import hjson | |
import requests | |
import sys | |
import threading | |
import argparse | |
import traceback | |
import signal | |
import socketserver | |
from dnslib import * | |
import logging | |
import logging.handlers | |
''' Server-side PHP proxy for Google's DNS-over-HTTPS JSON API.
Host this script and set "endpoint" in ggdns.json (or pass -e) to its URL;
it forwards the request's query string unchanged to dns.google.com/resolve.
<?php
$a = $_SERVER['QUERY_STRING'];
$url = 'https://dns.google.com/resolve?'.$a;
$curl = curl_init();
curl_setopt_array($curl, array(
    CURLOPT_RETURNTRANSFER => 1,
    CURLOPT_URL => $url
));
$resp = curl_exec($curl);
curl_close($curl);
header('Content-Type: application/json');
echo $resp;
?>
'''
def _stdout(msg): | |
log_string = '%s - %s\n' % ('GGDNS', msg) | |
if os.name != 'nt': | |
global logger | |
logger.debug(log_string) | |
else: | |
sys.stdout.write(log_string) | |
def sha1hash(ss):
    """Return the hex SHA-1 digest of ``ss`` encoded as UTF-8.

    Used to turn an arbitrary string into a fixed-length cache key.
    """
    from hashlib import sha1 as _sha1
    encoded = ss.encode()
    return _sha1(encoded).hexdigest()
def query_from_google_https_dns(host, query_type='A', edns='1.2.4.8', timeout=10):
    """Resolve ``host`` through the configured DNS-over-HTTPS JSON endpoint.

    The endpoint ``cfg['endpoint']`` is requested as
    ``<endpoint>?name=<host>&type=<query_type>&edns_client_subnet=<edns>``,
    e.g. ``(example.com, CNAME, 1.2.4.8)``. Decoded responses are cached in
    the module-level ``cache`` under a SHA-1 of (host, edns, query_type).

    :param host: name to resolve; anything whose ``str()`` is the name.
    :param query_type: record type mnemonic such as ``'A'`` or ``'CNAME'``.
    :param edns: address sent as ``edns_client_subnet``.
    :param timeout: seconds to wait for the HTTP endpoint before giving up.
    :returns: the decoded JSON response.
    :raises requests.RequestException: on connection failure, timeout or a
        non-2xx HTTP status.
    :raises ValueError: if the response body is not valid JSON.
    """
    # Cache first, so a hit never touches the network.
    cache_key = sha1hash('%s_%s_%s' % (host, edns, query_type))
    cached = cache.get_cache(cache_key)
    if cached is not None:
        _stdout('cache found %s' % cache_key)
        return cached

    if cfg['use_proxy']:
        proxies = {'http': cfg['proxy'], 'https': cfg['proxy']}
    else:
        proxies = {'http': None, 'https': None}

    # params= lets requests URL-encode every value. The name comes straight
    # from a client packet and must not be able to inject extra parameters.
    params = {
        'name': str(host),
        'type': query_type,
        'edns_client_subnet': edns,
    }
    resp = requests.get(cfg['endpoint'], params=params, proxies=proxies, timeout=timeout)
    # Never cache an HTTP error response.
    resp.raise_for_status()
    result = resp.json()
    cache.set_cache(cache_key, result)
    return result
def dns_response(data, edns):
    """Build the DNS reply for the raw query ``data`` via the DoH resolver.

    :param data: raw DNS query message (bytes).
    :param edns: client address forwarded upstream as ``edns_client_subnet``.
    :returns: the packed DNS reply (bytes).
    :raises: whatever ``DNSRecord.parse`` or ``query_from_google_https_dns``
        raise for an unparsable query or a failed upstream lookup.
    """
    request = DNSRecord.parse(data)
    # RD is echoed from the query (RFC 1035 section 4.1.1). AA is left clear:
    # this server forwards to a recursive resolver and is not authoritative.
    reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, rd=request.header.rd, ra=1),
                      q=request.q)
    query_name = request.q.qname
    qt = QTYPE[request.q.qtype]
    google_result = query_from_google_https_dns(query_name, qt, edns)

    # 'Status' in the JSON API response is the upstream DNS response code.
    # Pass it on, so NXDOMAIN/SERVFAIL reach the client instead of an empty
    # NOERROR.
    status = google_result.get('Status', 0)
    if isinstance(status, int) and 0 <= status <= 15:
        reply.header.rcode = status
    else:
        # Extended RCODEs do not fit the 4-bit header field.
        reply.header.rcode = RCODE.SERVFAIL

    first_result = ''
    for answer in google_result.get('Answer', []):
        try:
            rtype = QTYPE[answer['type']]
            if rtype == 'AAAA' and not cfg['query_ipv6']:
                continue
            zone = '%s %s %s %s' % (answer['name'], answer['TTL'], rtype, answer['data'])
            records = RR.fromZone(zone)
        except Exception as err:
            # A record dnslib cannot represent (unknown type, odd data) must
            # not discard the rest of the answer section.
            _stdout('skipping answer %r: %r' % (answer, err))
            continue
        reply.add_answer(*records)
        if not first_result:
            first_result = str(answer['data'])

    log = 'host=%s, edns=%s, type=%s, result=%s' % (str(query_name), edns, qt, first_result)
    _stdout(log)
    return reply.pack()
# Client addresses that are not publicly routable. Such a client's own
# address is useless as an EDNS client subnet, so the server's public (WAN)
# address is looked up and used instead.
_LAN_RANGES = tuple(ipaddress.ip_network(net) for net in (
    '0.0.0.0/8', '10.0.0.0/8', '100.64.0.0/10', '127.0.0.0/8',
    '169.254.0.0/16', '172.16.0.0/12', '192.168.0.0/16',
    '224.0.0.0/4', '240.0.0.0/4', '255.255.255.255/32',
    # IPv6: unspecified, loopback, unique-local, link-local, multicast
    '::/128', '::1/128', 'fc00::/7', 'fe80::/10', 'ff00::/8',
))


def get_wan_ip(request_ip, fallback='1.2.4.8', timeout=10):
    """Pick the address to send upstream as ``edns_client_subnet``.

    A public client address is returned unchanged. For a private or
    reserved client, the server's own public address is looked up once via
    whois.pconline.com.cn. The result is memoised in the global
    ``edns_endpoint`` and reused. If the lookup fails, ``fallback`` is
    returned and not memoised, so the next request retries.

    :param request_ip: client IP address as a string.
    :param fallback: address to use when the public address cannot be found.
    :param timeout: seconds to wait for the lookup service.
    :returns: an IP address string.
    :raises ValueError: if ``request_ip`` is not a valid IP address.
    """
    global edns_endpoint
    client = ipaddress.ip_address(request_ip)
    if not any(client in lan for lan in _LAN_RANGES):
        # Every public client is its own subnet hint. Never pin one client's
        # address for everybody else.
        return request_ip
    if edns_endpoint is not None:
        return edns_endpoint
    try:
        resp = requests.get('https://whois.pconline.com.cn/ipJson.jsp?json=true',
                            proxies={'http': None, 'https': None},
                            timeout=timeout)
        if resp.status_code != 200:
            return fallback
        wan_ip = resp.json().get('ip')
        # Validate before memoising; raises ValueError for junk or None.
        ipaddress.ip_address(wan_ip)
    except (requests.RequestException, ValueError):
        return fallback
    edns_endpoint = wan_ip
    return wan_ip
def init_config():
    """Load settings from ``ggdns.json`` in this script's directory.

    The file is parsed with hjson. Keys missing from it fall back to the
    built-in defaults, and keys not among the defaults are ignored. If the
    file does not exist, the defaults are written to it (best effort) and
    returned.

    :returns: dict with keys listen, enable_tcp, tcp_port, udp_port,
        use_proxy, cache_ttl, query_ipv6, proxy and endpoint.
    """
    default_config = {
        'listen': '',
        'enable_tcp': False,
        'tcp_port': 5333,
        'udp_port': 5333,
        'use_proxy': False,
        'cache_ttl': 86400,
        'query_ipv6': True,
        'proxy': 'socks5://127.0.0.1:1080',
        'endpoint': 'https://dns.google.com/resolve'
    }
    cfg_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ggdns.json')
    if not os.path.exists(cfg_file):
        try:
            with open(cfg_file, 'w', encoding='utf-8') as fp:
                hjson.dumpJSON(default_config, fp, indent=True, for_json=True)
        except OSError:
            # E.g. a read-only install directory: run on the defaults
            # instead of failing at start-up.
            pass
        return default_config
    # 'with' closes the handle (the original leaked both file objects).
    with open(cfg_file, encoding='utf-8') as fp:
        json_object = hjson.load(fp)
    return {key: json_object.get(key, value) for key, value in default_config.items()}
def init_log():
    """Create the 'GGDNS' logger that ``_stdout`` writes to on non-Windows hosts.

    Records go to the local syslog socket when one exists (``/dev/log`` on
    Linux, ``/var/run/syslog`` on macOS). Otherwise they go to standard
    error, so messages are not lost where no syslog socket is present
    (e.g. minimal containers).

    :returns: the configured ``logging.Logger``, or ``None`` on Windows.
    """
    if os.name == 'nt':
        return None
    my_logger = logging.getLogger('GGDNS')
    my_logger.setLevel(logging.DEBUG)
    # getLogger returns the same object every time. Attach a handler only
    # once, so repeated calls do not duplicate every log line.
    if not my_logger.handlers:
        syslog_socket = next(
            (path for path in ('/dev/log', '/var/run/syslog') if os.path.exists(path)),
            None)
        if syslog_socket is not None:
            handler = logging.handlers.SysLogHandler(address=syslog_socket)
        else:
            handler = logging.StreamHandler(sys.stderr)
        my_logger.addHandler(handler)
    my_logger.debug('Log init OK')
    return my_logger
def _str2bool(value):
    """Convert a command-line word such as 'true'/'false' or '1'/'0' to bool.

    With ``type=bool``, argparse would call ``bool('False')``, which is
    ``True`` for every non-empty string. That is why the boolean options
    use this converter instead.

    :raises argparse.ArgumentTypeError: for an unrecognised word; argparse
        reports it as a usage error.
    """
    word = str(value).strip().lower()
    if word in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if word in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('invalid boolean value: %r' % value)


def init_args(argv=None):
    """Parse the command-line options that override ggdns.json settings.

    :param argv: argument list to parse; ``None`` (default) means ``sys.argv[1:]``.
    :returns: ``argparse.Namespace`` with attributes ``l``, ``t``, ``p``,
        ``u``, ``x``, ``ttl``, ``aaaa`` and ``e``. Each is ``None`` when the
        option was not given.
    """
    par = argparse.ArgumentParser()
    par.add_argument('-l', help='DNS Server Bind Address', type=str, metavar='')
    par.add_argument('-t', help='Enable TCP Listen', type=_str2bool, metavar='')
    par.add_argument('-p', help='TCP listen Port', type=int, metavar='')
    par.add_argument('-u', help='UDP listen Port', type=int, metavar='')
    par.add_argument('-x', help='Proxy Server', type=str, metavar='')
    par.add_argument('-ttl', help='Cache TTL', type=int, metavar='')
    par.add_argument('-aaaa', help='Enable AAAA Record', type=_str2bool, metavar='')
    par.add_argument('-e', help='Remote Endpoint Uri', type=str, metavar='')
    return par.parse_args(argv)
def override_config_by_arg():
    """Rebind the global ``cfg`` with any command-line values from ``args``.

    An option left unset on the command line (``None``) keeps the value
    already in ``cfg``. Supplying ``-x`` also switches ``use_proxy`` on.
    Only the nine known keys are carried into the new dict.
    """
    global cfg
    cli_values = (
        ('listen', args.l),
        ('enable_tcp', args.t),
        ('tcp_port', args.p),
        ('udp_port', args.u),
        ('use_proxy', None if args.x is None else True),
        ('cache_ttl', args.ttl),
        ('query_ipv6', args.aaaa),
        ('proxy', args.x),
        ('endpoint', args.e),
    )
    merged = {}
    for key, override in cli_values:
        merged[key] = cfg[key] if override is None else override
    cfg = merged
class DictBasedCache:
    """In-memory key/value cache whose entries expire after a TTL in seconds.

    Each entry is stored as ``{'time': <expiry timestamp>, 'value': ...}``.
    A TTL of zero or less disables caching: lookups return ``None`` and
    writes are ignored.
    """
    _ttl = 60 * 60 * 24  # one day cache (class-level default)

    def __init__(self, ttl=86400):
        # Per-instance storage. A class-level dict would be shared (and
        # mutated) by every DictBasedCache instance.
        self._cache_items = dict()
        self._ttl = ttl
        if ttl <= 0:
            _stdout('Cache Disabled')

    def get_cache(self, key, update_entity=None):
        """Return the cached value for ``key``.

        If ``update_entity`` is given, it is used when the key is missing or
        expired: it is stored with a fresh TTL and returned. Without it, a
        missing key yields ``None``, and an expired key is removed and also
        yields ``None``.
        """
        if self._ttl <= 0:
            return None
        entry = self._cache_items.get(key)
        if entry is None:
            # Do not litter the cache with placeholder None entries.
            if update_entity is not None:
                self.set_cache(key, update_entity)
            return update_entity
        if entry.get('time') >= time.time():
            return entry.get('value', update_entity)
        _stdout('cache expired %s' % key)
        if update_entity is None:
            # Default argument: another thread may already have evicted it.
            self._cache_items.pop(key, None)
            return None
        # Replace the stale value (the old code kept it and only bumped the time).
        self.set_cache(key, update_entity)
        return update_entity

    def set_cache(self, key, update_entity, ttl=None):
        """Store ``update_entity`` under ``key`` for ``ttl`` seconds.

        ``ttl`` defaults to the cache-wide TTL.

        :returns: ``update_entity``, or ``None`` when caching is disabled.
        """
        if self._ttl <= 0:
            return
        i_ttl = self._ttl if ttl is None else ttl
        self._cache_items[key] = {'time': time.time() + i_ttl,
                                  'value': update_entity}
        return update_entity

    def purge_cache(self):
        """Drop every entry."""
        self._cache_items = dict()
class BaseRequestHandler(socketserver.BaseRequestHandler):
    """Transport-agnostic DNS request handler.

    Subclasses supply ``get_data`` (read one raw DNS query) and
    ``send_data`` (write one raw DNS reply). ``handle`` answers the query
    with ``dns_response``, using the EDNS address chosen by ``get_wan_ip``.
    """

    def get_data(self):
        """Return the raw DNS query bytes for this request."""
        raise NotImplementedError

    def send_data(self, data):
        """Send the raw DNS reply bytes ``data`` back to the client."""
        raise NotImplementedError

    def handle(self):
        data = None
        try:
            data = self.get_data()
            washed_ip = get_wan_ip(str(self.client_address[0]))
            query_dns_result = dns_response(data, washed_ip)
            self.send_data(query_dns_result)
        except Exception as err:
            # Per-request boundary: record the failure instead of silently
            # dropping it, then tell the client rather than let it time out.
            _stdout('request from %s failed: %r' % (self.client_address[0], err))
            self._send_servfail(data)

    def _send_servfail(self, data):
        """Best-effort SERVFAIL reply to the query in ``data``, if it parses."""
        if not data:
            return
        try:
            reply = DNSRecord.parse(data).reply()
            reply.header.rcode = RCODE.SERVFAIL
            self.send_data(reply.pack())
        except Exception:
            # Unparsable query or a dead socket: nothing more can be done.
            pass
class TCPRequestHandler(BaseRequestHandler):
    """DNS over TCP: each message is preceded by a 2-byte big-endian length
    (RFC 1035 section 4.2.2)."""

    def _recv_exact(self, count):
        """Read exactly ``count`` bytes from the connection.

        :raises ConnectionError: if the peer closes before ``count`` bytes arrive.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.request.recv(remaining)
            if not chunk:
                raise ConnectionError('Connection closed before full TCP packet was received')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def get_data(self):
        # Binary data: no .strip(), since whitespace-valued bytes are legal
        # in both the length prefix and the message. A single recv() may also
        # return only part of the message, so read exactly the announced size.
        size = int.from_bytes(self._recv_exact(2), 'big')
        return self._recv_exact(size)

    def send_data(self, data):
        # The length prefix is 16 bits; refuse rather than emit a corrupt one.
        if len(data) > 0xFFFF:
            raise ValueError('Too big TCP packet')
        return self.request.sendall(len(data).to_bytes(2, 'big') + data)
class UDPRequestHandler(BaseRequestHandler):
    """DNS over UDP: one datagram per query and one per reply."""

    def get_data(self):
        # For UDP servers self.request is (datagram, socket). The datagram is
        # returned unmodified: .strip() would remove header/payload bytes that
        # happen to equal ASCII whitespace (e.g. a query ID starting 0x0a),
        # making the packet unparsable.
        return self.request[0]

    def send_data(self, data):
        return self.request[1].sendto(data, self.client_address)
class GracefulKiller:
    """Turn SIGINT/SIGTERM into a flag the main loop can poll.

    Creating an instance installs the handlers. Afterwards ``kill_now``
    becomes ``True`` once either signal is received.
    """
    kill_now = False

    def __init__(self):
        for signum in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signum, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        """Signal handler: record that shutdown was requested."""
        self.kill_now = True
# Module-level start-up state, created on import (also when the file is
# imported rather than run as a script). Order matters:
#   * logger must exist before anything calls _stdout (DictBasedCache does
#     when the cache TTL is <= 0);
#   * override_config_by_arg() reads the globals args and cfg and rebinds cfg;
#   * the cache TTL is taken from the final, CLI-overridden cfg.
logger = init_log()
cfg = init_config()
args = init_args()  # parses sys.argv
override_config_by_arg()
cache = DictBasedCache(cfg['cache_ttl'])
# EDNS address memoised by get_wan_ip; None until get_wan_ip first sets it.
edns_endpoint = None
if __name__ == '__main__':
    # Run the UDP server (and the TCP server when enabled) on background
    # threads until SIGINT/SIGTERM, then stop them and release their sockets.
    killer = GracefulKiller()
    servers = [
        socketserver.ThreadingUDPServer((cfg['listen'], cfg['udp_port']), UDPRequestHandler)
    ]
    if cfg['enable_tcp']:
        servers.append(socketserver.ThreadingTCPServer((cfg['listen'], cfg['tcp_port']), TCPRequestHandler))
    for s in servers:
        # serve_forever starts one more thread for each request (ThreadingMixIn).
        thread = threading.Thread(target=s.serve_forever)
        thread.daemon = True  # exit the server thread when the main thread terminates
        thread.start()
    try:
        # Poll once a second; GracefulKiller sets kill_now on SIGINT/SIGTERM.
        while not killer.kill_now:
            time.sleep(1)
    except KeyboardInterrupt:
        pass
    finally:
        for s in servers:
            s.shutdown()
            # shutdown() only stops serve_forever; close the listening socket too.
            s.server_close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.