Last active
October 16, 2024 04:20
-
-
Save stakach/4251450ba25ebf5a30004bd58c541ef8 to your computer and use it in GitHub Desktop.
crystal lang native DNS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require "socket" | |
# Bit flags selecting which record types `DNSResolver.resolve` asks for.
# Members combine with `|`, e.g. `Resolve::IPv4 | Resolve::IPv6`.
#
# `@[Flags]` numbers the members as successive powers of two, so the values
# are IPv4 = 1, IPv6 = 2 and ALPN = 4.
@[Flags]
enum Resolve
  IPv4 # query A records
  IPv6 # query AAAA records
  ALPN # query HTTPS records
end
module DNSResolver | |
# The 12-byte header that starts a DNS message: a 16-bit id, two bytes of
# flag bits and four 16-bit section counts.
#
# The defaults describe a query with recursion desired (`rd = 1`) carrying a
# single question (`qdcount = 1`).
class DNSHeader
  property id : UInt16

  # First flags byte, most significant bit first: qr, opcode, aa, tc, rd.
  property qr : UInt8 = 0
  property opcode : UInt8 = 0
  property aa : UInt8 = 0
  property tc : UInt8 = 0
  property rd : UInt8 = 1

  # Second flags byte: ra, z, rcode.
  property ra : UInt8 = 0
  property z : UInt8 = 0
  property rcode : UInt8 = 0

  # Entry counts for the question, answer, authority and additional sections.
  property qdcount : UInt16 = 1
  property ancount : UInt16 = 0
  property nscount : UInt16 = 0
  property arcount : UInt16 = 0

  def initialize(@id : UInt16)
  end

  # Serializes the header with every multi-byte field in big-endian order.
  def to_bytes : Bytes
    network_order = IO::ByteFormat::BigEndian
    flags_high = (qr << 7) | (opcode << 3) | (aa << 2) | (tc << 1) | rd
    flags_low = (ra << 7) | (z << 4) | rcode

    buffer = IO::Memory.new
    buffer.write_bytes(id, network_order)
    buffer.write_byte(flags_high)
    buffer.write_byte(flags_low)
    {qdcount, ancount, nscount, arcount}.each do |count|
      buffer.write_bytes(count, network_order)
    end
    buffer.to_slice
  end
end
# One entry of the DNS question section: the name being looked up together
# with the requested record type and class.
class DNSQuestion
  # Largest label the wire format can carry. The two high bits of the length
  # octet are reserved for compression pointers, leaving six bits of length.
  MAX_LABEL_BYTES = 63

  property qname : String
  property qtype : UInt16
  property qclass : UInt16 = 1 # IN class

  def initialize(@qname : String, @qtype : UInt16)
  end

  # Encodes the question as length-prefixed labels, a zero terminator, then
  # big-endian `qtype` and `qclass`.
  #
  # A single trailing dot (fully-qualified form) is accepted and encodes the
  # same as the name without it; an empty name or "." encodes the root.
  # Label lengths are measured in bytes, matching what is written.
  #
  # Raises ArgumentError if a label is empty or longer than 63 bytes.
  def to_bytes : Bytes
    io = IO::Memory.new

    # Drop the optional root dot so it does not produce an extra empty label.
    name = qname.rchop('.')
    unless name.empty?
      name.split('.').each do |label|
        length = label.bytesize
        if length == 0 || length > MAX_LABEL_BYTES
          raise ArgumentError.new("Invalid DNS label #{label.inspect} in #{qname.inspect}")
        end
        io.write_byte(length.to_u8)
        io.write(label.to_slice)
      end
    end

    io.write_byte(0_u8) # Null terminator for the domain name
    io.write_bytes(qtype, IO::ByteFormat::BigEndian)
    io.write_bytes(qclass, IO::ByteFormat::BigEndian)
    io.to_slice
  end
end
# One resource record from the answer, authority or additional section of a
# DNS message. `rdata` holds the raw record payload; `data` holds a decoded
# key/value view for the record types `parse_rdata` understands, else nil.
struct DNSResourceRecord
  # Upper bound on compression pointers followed while decoding one name, so
  # a message whose pointers form a cycle cannot recurse without end.
  MAX_POINTER_HOPS = 16

  property name : String
  property type : UInt16
  property class_code : UInt16
  property ttl : UInt32
  property rdlength : UInt16
  property rdata : Bytes
  property data : Hash(String, String)?

  def initialize(@name : String, @type : UInt16, @class_code : UInt16, @ttl : UInt32, @rdlength : UInt16, @rdata : Bytes, @data : Hash(String, String)? = nil)
  end

  # Reads one record starting at the current position of *io*.
  #
  # *message* is the complete DNS message; it is needed to follow name
  # compression pointers. Raises IO::EOFError when the message ends before
  # the record does.
  def self.read(io : IO::Memory, message : Bytes) : DNSResourceRecord
    name = read_labels(io, message)
    type = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
    class_code = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
    ttl = io.read_bytes(UInt32, IO::ByteFormat::BigEndian)
    rdlength = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
    rdata = Bytes.new(rdlength)
    # read_fully (not read) so a truncated payload raises instead of being
    # silently zero-padded.
    io.read_fully(rdata)
    data = parse_rdata(type, rdata, message)
    DNSResourceRecord.new(name, type, class_code, ttl, rdlength, rdata, data)
  end

  # Decodes *rdata* for the supported record types (1 = A, 28 = AAAA,
  # 41 = OPT, 65 = HTTPS). Returns nil for any other type, and for A/AAAA
  # payloads that are not exactly 4/16 bytes long.
  def self.parse_rdata(type : UInt16, rdata : Bytes, message : Bytes) : Hash(String, String)?
    case type
    when 1 # A record
      return nil unless rdata.size == 4
      ip_str = rdata.map(&.to_s).join(".")
      {"address" => ip_str}
    when 28 # AAAA record
      return nil unless rdata.size == 16
      # Eight 16-bit groups in hex, colon separated (no "::" shortening).
      ip_str = rdata.each_slice(2).map { |pair| ((pair[0].to_u16 << 8) | pair[1].to_u16).to_s(16) }.join(":")
      {"address" => ip_str}
    when 41 # OPT record (EDNS)
      parse_opt_rdata(rdata)
    when 65 # HTTPS record
      parse_svcb_rdata(rdata, message)
    else
      nil
    end
  end

  # Decodes an OPT payload as a list of (code, length, data) options and
  # returns them as "opt_<code>" => hex string.
  def self.parse_opt_rdata(rdata : Bytes) : Hash(String, String)
    options = Hash(String, String).new
    io = IO::Memory.new(rdata)
    while io.pos != io.size
      option_code = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
      option_length = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
      option_data = Bytes.new(option_length)
      io.read_fully(option_data)
      options["opt_#{option_code}"] = option_data.hexstring
    end
    options
  end

  # Decodes an HTTPS payload: priority, target name, then SvcParams.
  # Key 1 (alpn) becomes a comma-separated protocol list under "alpn"; every
  # other key is kept as "svcparam_<key>" => hex string.
  def self.parse_svcb_rdata(rdata : Bytes, message : Bytes) : Hash(String, String)
    io = IO::Memory.new(rdata)
    data = Hash(String, String).new
    priority = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
    data["priority"] = priority.to_s
    target_name = read_labels(io, message)
    data["target_name"] = target_name
    # Read SvcParams
    while io.pos != io.size
      key = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
      length = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
      value = Bytes.new(length)
      io.read_fully(value)
      case key
      when 1 # alpn
        alpn_protocols = [] of String
        alpn_io = IO::Memory.new(value)
        while alpn_io.pos != alpn_io.size
          alpn_length = alpn_io.read_byte || break
          alpn_value = Bytes.new(alpn_length)
          alpn_io.read_fully(alpn_value)
          alpn_protocols << String.new(alpn_value)
        end
        data["alpn"] = alpn_protocols.join(",")
      else
        data["svcparam_#{key}"] = value.hexstring
      end
    end
    data
  end

  # Reads a domain name at the current position of *io* and returns it as a
  # dot-joined string. A zero length octet or the end of *io* ends the name;
  # a compression pointer ends it after appending the name it points to.
  #
  # *hops* counts the pointers already followed for this name; callers
  # outside this struct should leave it at its default.
  #
  # Raises IO::EOFError on a truncated label or pointer, and IO::Error when
  # more than MAX_POINTER_HOPS pointers are chained.
  def self.read_labels(io : IO::Memory, message : Bytes, hops : Int32 = 0) : String
    labels = [] of String
    loop do
      length = io.read_byte
      break if length.nil?
      break if length == 0
      if length & 0xC0 == 0xC0
        # Pointer: 14-bit offset from the start of the message. Widen to
        # UInt16 before shifting; shifting a UInt8 left by 8 would discard
        # the high bits of the offset.
        low = io.read_byte || raise IO::EOFError.new
        pointer = ((length.to_u16 & 0x3F_u16) << 8) | low.to_u16
        labels << get_labels_from_pointer(pointer, message, hops + 1)
        break
      else
        slice = Bytes.new(length)
        io.read_fully(slice)
        labels << String.new(slice)
      end
    end
    labels.join(".")
  end

  # Decodes the name stored at byte offset *pointer* of *message*.
  # Raises IO::Error once *hops* exceeds MAX_POINTER_HOPS.
  def self.get_labels_from_pointer(pointer : UInt16, message : Bytes, hops : Int32 = 1) : String
    if hops > MAX_POINTER_HOPS
      raise IO::Error.new("DNS name compression pointer chain is too long")
    end
    io = IO::Memory.new(message)
    io.pos = pointer
    read_labels(io, message, hops)
  end

  # Writes a one-line summary. Defined on the IO overload so that both
  # `record.to_s` and `puts record` use it. Falls back to the rdata bytes in
  # hex when no decoded `data` is available.
  def to_s(io : IO) : Nil
    io << "Name: " << @name << ", Type: " << @type << ", Class: " << @class_code << ", TTL: " << @ttl << ", Data: "
    if decoded = @data
      io << decoded
    else
      io << "Raw Data: " << @rdata.hexstring
    end
  end
end
# A parsed DNS message: the header fields, the entries of each of the four
# sections, and the response code taken from the flags word.
struct DNSResponse
  # Header fields as they appear on the wire.
  property id : UInt16
  property flags : UInt16
  property qdcount : UInt16
  property ancount : UInt16
  property nscount : UInt16
  property arcount : UInt16

  # Section contents.
  property questions : Array(DNSQuestion)
  property answers : Array(DNSResourceRecord)
  property authorities : Array(DNSResourceRecord)
  property additionals : Array(DNSResourceRecord)

  # Response code supplied by the parser alongside `flags`.
  property rcode : UInt8

  def initialize(
    @id : UInt16,
    @flags : UInt16,
    @qdcount : UInt16,
    @ancount : UInt16,
    @nscount : UInt16,
    @arcount : UInt16,
    @questions : Array(DNSQuestion),
    @answers : Array(DNSResourceRecord),
    @authorities : Array(DNSResourceRecord),
    @additionals : Array(DNSResourceRecord),
    @rcode : UInt8
  )
  end
end
# Builds the wire bytes of a query asking for *qtype* records of *domain*:
# a header carrying *id* followed by exactly one question.
def self.build_dns_query(domain : String, qtype : UInt16, id : UInt16) : Bytes
  header = DNSHeader.new(id)
  header.qdcount = 1_u16 # one question follows the header

  packet = IO::Memory.new
  packet.write(header.to_bytes)
  packet.write(DNSQuestion.new(domain, qtype).to_bytes)
  packet.to_slice
end
# Addresses from the `nameserver` lines of /etc/resolv.conf, computed the
# first time the getter is called. Any error while reading the file yields an
# empty list so callers can fall back to other servers.
class_getter system_dns_servers : Array(String) do
  nameserver_line = /^\s*nameserver\s+([^\s]+)/
  found = [] of String
  File.each_line("/etc/resolv.conf") do |line|
    if match = nameserver_line.match(line)
      found << match[1]
    end
  end
  found
rescue
  [] of String
end
# Looks up *domain* over UDP and returns the answer records received.
#
# *flags* selects the record types to query: `Resolve::IPv4` asks for A (1),
# `Resolve::IPv6` for AAAA (28) and `Resolve::ALPN` for HTTPS (65). With no
# flag set nothing is sent and an empty array is returned.
#
# *dns_servers* are tried in order. When nil, the servers from
# `system_dns_servers` are used, or 1.1.1.1 and 8.8.8.8 if that list is
# empty. All questions are sent to a server at once and replies are matched
# to them by message id, with a 500 ms read timeout. The first server that
# produces any answer ends the lookup. Replies with a non-zero rcode count
# as answered but contribute no records.
#
# Raises the first IO::Error encountered if every server was tried, none
# returned an answer, and at least one failed. Returns an empty array if no
# server failed but none had answers.
def self.resolve(domain : String, flags : Resolve, dns_servers : Iterable(String)? = nil, port : Int32 = 53) : Array(DNSResourceRecord)
  answers = [] of DNSResourceRecord

  # `flags & Resolve::X` yields an enum value, which is always truthy, so
  # membership has to be tested with `includes?`.
  question_types = [] of UInt16
  question_types << 1_u16 if flags.includes?(Resolve::IPv4)  # A record
  question_types << 28_u16 if flags.includes?(Resolve::IPv6) # AAAA record
  question_types << 65_u16 if flags.includes?(Resolve::ALPN) # HTTPS record

  # Nothing to ask: return instead of waiting on a socket for replies that
  # will never come.
  return answers if question_types.empty?

  if dns_servers.nil?
    dns_servers = system_dns_servers
    if dns_servers.empty?
      dns_servers = {"1.1.1.1", "8.8.8.8"}
    end
  end

  exceptions = [] of Exception

  dns_servers.each do |dns_server|
    queries = {} of UInt16 => UInt16 # id => qtype, for questions still unanswered
    socket = UDPSocket.new
    begin
      socket.connect(dns_server, port)
      socket.read_timeout = 500.milliseconds

      # Send all queries
      question_types.each do |qtype|
        id = rand(UInt16::MAX)
        # Re-draw on collision so each outstanding question keeps its own id.
        while queries.has_key?(id)
          id = rand(UInt16::MAX)
        end
        queries[id] = qtype
        socket.send(build_dns_query(domain, qtype, id))
      end

      # Collect replies until every question has been answered.
      until queries.empty?
        response_data, _ = socket.receive(4096)
        dns_response = parse_dns_response(response_data.to_slice)

        # Ignore replies whose id matches no outstanding question.
        next unless queries.delete(dns_response.id)

        if dns_response.rcode == 0 # No error
          answers.concat(dns_response.answers)
        end
      end
    rescue ex : IO::Error
      # Includes read timeouts; remember the failure and move on to the
      # next server (or return what was already collected below).
      exceptions << ex
    ensure
      socket.close
    end

    # If we received any answers, return them
    return answers if answers.any?
  end

  # Every server was tried without an answer: surface the first failure, or
  # return the empty result when nothing failed.
  raise exceptions.first if exceptions.any?
  answers
end
# Decodes a complete DNS message into a DNSResponse.
#
# Reads the six 16-bit header fields, then `qdcount` questions, then
# `ancount`, `nscount` and `arcount` resource records, in that order.
# `rcode` is the low four bits of the flags word. Each parsed question keeps
# the class found in the message.
#
# Raises IO::EOFError if *response* ends before a fixed-width field does.
def self.parse_dns_response(response : Bytes) : DNSResponse
  io = IO::Memory.new(response)
  big_endian = IO::ByteFormat::BigEndian

  # Extracting the DNS header
  id = io.read_bytes(UInt16, big_endian)
  flags = io.read_bytes(UInt16, big_endian)
  qdcount = io.read_bytes(UInt16, big_endian)
  ancount = io.read_bytes(UInt16, big_endian)
  nscount = io.read_bytes(UInt16, big_endian)
  arcount = io.read_bytes(UInt16, big_endian)

  # Extracting rcode from flags
  rcode = (flags & 0x000F).to_u8

  # Reading the question section
  questions = Array(DNSQuestion).new
  qdcount.times do
    name = DNSResourceRecord.read_labels(io, response)
    type = io.read_bytes(UInt16, big_endian)
    question = DNSQuestion.new(name, type)
    # Keep the class from the message instead of the constructor default.
    question.qclass = io.read_bytes(UInt16, big_endian)
    questions << question
  end

  # The answer, authority and additional sections share one record format.
  answers = read_records(io, response, ancount)
  authorities = read_records(io, response, nscount)
  additionals = read_records(io, response, arcount)

  DNSResponse.new(id, flags, qdcount, ancount, nscount, arcount, questions, answers, authorities, additionals, rcode)
end

# Reads *count* consecutive resource records from *io*; *message* is the
# whole DNS message, needed to resolve compressed names.
private def self.read_records(io : IO::Memory, message : Bytes, count : UInt16) : Array(DNSResourceRecord)
  records = Array(DNSResourceRecord).new
  count.times do
    records << DNSResourceRecord.read(io, message)
  end
  records
end
end | |
# Usage Example: look up A, AAAA and HTTPS records for one host and print
# each record, or a notice when nothing came back.
begin
  records = DNSResolver.resolve("www.google.com", Resolve::IPv4 | Resolve::IPv6 | Resolve::ALPN)
  if records.empty?
    puts "No answers received."
  else
    records.each { |record| puts record }
  end
rescue ex : IO::Error
  puts "An I/O error occurred: #{ex.message}"
rescue ex : Exception
  puts "An error occurred: #{ex.message}"
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment