Skip to content

Instantly share code, notes, and snippets.

@iKunalChhabra
Created September 4, 2025 12:44
Show Gist options
  • Save iKunalChhabra/8974f835af9aa1954d6b30b907cbdf45 to your computer and use it in GitHub Desktop.
Save iKunalChhabra/8974f835af9aa1954d6b30b907cbdf45 to your computer and use it in GitHub Desktop.
A simple DNS Server in C++
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <ranges>
#include <stdexcept>
#include <string>
#include <vector>
// The second 16-bit word of the DNS message header (RFC 1035 section 4.1.1), one bit-field per flag.
// The in-memory order of bit-fields is implementation-defined, so this struct is never copied
// onto the wire as-is; DNSServer::flagsToInt16 / parseHeaderSection pack and unpack it field by field.
struct DNSHeaderFlags
{
    // Query/Response Indicator - 1 for a response, 0 for a query.
    uint16_t QR : 1;
    // Operation Code - kind of query (0 = standard QUERY; other values are not standard queries).
    uint16_t OPCODE : 4;
    // Authoritative Answer - set in a response when the responding server is authoritative for the queried name.
    uint16_t AA : 1;
    // Truncation - 1 when the message was cut short because it exceeded the transport's size limit
    // (512 bytes for plain UDP); the client is then expected to retry over TCP.
    uint16_t TC : 1;
    // Recursion Desired - set by the querier when it wants the server to resolve the query recursively.
    uint16_t RD : 1;
    // Recursion Available - set by a server in a response to indicate that it supports recursion.
    uint16_t RA : 1;
    // Z - reserved by RFC 1035 (must be zero); DNSSEC later assigned two of these bits (AD and CD).
    uint16_t Z : 3;
    // Response Code - 0 NOERROR, 1 FORMERR, 2 SERVFAIL, 3 NXDOMAIN, 4 NOTIMP, 5 REFUSED.
    uint16_t RCODE : 4;
};
// Fields of the 12-byte DNS message header (RFC 1035 section 4.1.1), held in host byte order.
// Conversion to and from network byte order happens when the header is written or parsed.
// Every field has a default member initializer so a default-constructed header is all zeros
// rather than indeterminate. Aggregate and designated initialization behave as before.
struct DNSHeader
{
    // Packet Identifier - chosen by the querier; a response must carry the same ID.
    uint16_t ID = 0;
    // Flags word (QR, OPCODE, AA, TC, RD, RA, Z, RCODE).
    DNSHeaderFlags FLAGS{};
    // Question Count - number of entries in the Question section.
    uint16_t QDCOUNT = 0;
    // Answer Record Count - number of resource records in the Answer section.
    uint16_t ANCOUNT = 0;
    // Authority Record Count - number of resource records in the Authority section.
    uint16_t NSCOUNT = 0;
    // Additional Record Count - number of resource records in the Additional section.
    uint16_t ARCOUNT = 0;
};
// A header together with raw Question and Answer section bytes.
// NOTE(review): this type is not referenced by the server code in this file; confirm it is
// still needed before building on it.
struct DNSResponse
{
    // Zero-initialized so a default-constructed response has a defined header.
    DNSHeader header{};
    // Encoded Question section bytes (presumably wire format; verify against any user).
    std::vector<uint8_t> question;
    // Encoded Answer section bytes (presumably wire format; verify against any user).
    std::vector<uint8_t> answers;
};
// One decoded entry of the Question section.
struct DNSQuestion
{
    // Name split into labels, e.g. {"www", "example", "com"} (no trailing empty root label).
    std::vector<std::string> labels{};
    // QTYPE as read from the wire (1 = A record); zero until assigned.
    uint16_t qtype{0};
    // QCLASS as read from the wire (1 = IN); zero until assigned.
    uint16_t qclass{0};
};
class DNSServer
{
int udpSocket = -1;
int upstreamSocket = -1;
const int port = 2053;
sockaddr_in resolverAddr{};
static sockaddr_in parseResolver(const std::string& addr) {
const auto pos = addr.find(':');
if (pos == std::string::npos) throw std::invalid_argument("resolver must be ip:port");
const std::string ip = addr.substr(0, pos);
const int port = std::stoi(addr.substr(pos + 1));
sockaddr_in sa{};
sa.sin_family = AF_INET;
sa.sin_port = htons(static_cast<uint16_t>(port));
if (inet_pton(AF_INET, ip.c_str(), &sa.sin_addr) != 1) {
throw std::invalid_argument("invalid resolver ip");
}
return sa;
}
public:
~DNSServer()
{
if (udpSocket != -1) close(udpSocket);
if (upstreamSocket != -1) close(upstreamSocket);
}
explicit DNSServer(const std::string& resolver) : resolverAddr(parseResolver(resolver)) {
createSocket();
setSocketOptions();
bindSocket();
upstreamSocket = socket(AF_INET, SOCK_DGRAM, 0);
if (upstreamSocket == -1) {
throw std::runtime_error("Upstream socket creation failed");
}
constexpr timeval tv{.tv_sec = 2, .tv_usec = 0};
setsockopt(upstreamSocket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
}
void createSocket()
{
udpSocket = socket(AF_INET, SOCK_DGRAM, 0);
if (udpSocket == -1)
{
throw std::runtime_error("Socket creation failed");
}
}
void setSocketOptions() const
{
constexpr int reuse = 1;
if (setsockopt(udpSocket, SOL_SOCKET, SO_REUSEPORT, &reuse, sizeof(reuse)) < 0)
{
throw std::runtime_error("Socket options setting failed");
}
}
void bindSocket() const
{
sockaddr_in serv_addr = {
.sin_family = AF_INET,
.sin_port = htons(port),
.sin_addr = {htonl(INADDR_ANY)},
};
if (bind(udpSocket, reinterpret_cast<struct sockaddr*>(&serv_addr), sizeof(serv_addr)) != 0)
{
throw std::runtime_error("Socket bind failed");
}
}
static uint16_t flagsToInt16(const DNSHeaderFlags& inFlags)
{
uint16_t flags = 0;
flags |= (inFlags.QR & 0x1) << 15;
flags |= (inFlags.OPCODE & 0xF) << 11;
flags |= (inFlags.AA & 0x1) << 10;
flags |= (inFlags.TC & 0x1) << 9;
flags |= (inFlags.RD & 0x1) << 8;
flags |= (inFlags.RA & 0x1) << 7;
flags |= (inFlags.Z & 0x7) << 4;
flags |= (inFlags.RCODE & 0xF);
return flags;
}
static void push16(std::vector<uint8_t>& data, const uint16_t v) {
data.push_back(static_cast<uint8_t>((v >> 8) & 0xFF)); // high byte
data.push_back(static_cast<uint8_t>(v & 0xFF)); // low byte
}
static void push32(std::vector<uint8_t>& data, uint32_t v) {
data.push_back(static_cast<uint8_t>((v >> 24) & 0xFF));
data.push_back(static_cast<uint8_t>((v >> 16) & 0xFF));
data.push_back(static_cast<uint8_t>((v >> 8) & 0xFF));
data.push_back(static_cast<uint8_t>(v & 0xFF));
}
static void writeHeader(std::vector<uint8_t>& out, const DNSHeader& h) {
push16(out, h.ID);
push16(out, flagsToInt16(h.FLAGS));
push16(out, h.QDCOUNT);
push16(out, h.ANCOUNT);
push16(out, h.NSCOUNT);
push16(out, h.ARCOUNT);
}
static std::vector<uint8_t> extractAnswersFromUpstream(const std::vector<uint8_t>& resp) {
if (resp.size() < 12) return {};
auto rd16 = [](const uint8_t* p){ return (static_cast<uint16_t>(p[0])<<8) | p[1]; };
const uint8_t* buf = resp.data();
const uint16_t qd = rd16(buf + 4);
// Walk past QDCOUNT questions (QNAME + 4 bytes)
size_t off = 12;
for (int i = 0; i < qd; ++i) {
while (off < resp.size() && resp[off] != 0) {
const uint8_t len = resp[off];
// handle compression if any (11xxxxxx)
if ((len & 0xC0) == 0xC0) { off += 2; break; }
off += 1 + len;
}
if (off < resp.size()) off += 1; // zero terminator
off += 4; // QTYPE+QCLASS
}
// Remaining bytes are the answer+authority+additional (tester guarantees only answers)
std::vector<uint8_t> answers;
if (off <= resp.size()) answers.insert(answers.end(), resp.begin()+off, resp.end());
return answers;
}
[[nodiscard]] std::vector<uint8_t> forwardSingleQuestion(const DNSHeader& origH, const DNSQuestion& q) const {
// Build a minimal query with one question
DNSHeaderFlags f = origH.FLAGS; // preserve RD, OPCODE, etc.
f.QR = 0; f.RCODE = 0; // ensure it's a query
const DNSHeader h{
.ID = origH.ID,
.FLAGS = f,
.QDCOUNT = 1,
.ANCOUNT = 0,
.NSCOUNT = 0,
.ARCOUNT = 0
};
std::vector<uint8_t> packet;
writeHeader(packet, h);
auto qb = buildQuestionSection(q);
packet.insert(packet.end(), qb.begin(), qb.end());
// Send it to resolver
if (sendto(upstreamSocket, packet.data(), packet.size(), 0,
reinterpret_cast<const sockaddr*>(&resolverAddr), sizeof(resolverAddr)) < 0) {
throw std::runtime_error("failed to send to resolver");
}
// Receive response
uint8_t buf[512];
sockaddr_in from{}; socklen_t fromLen = sizeof(from);
const ssize_t n = recvfrom(upstreamSocket, buf, sizeof(buf), 0,
reinterpret_cast<sockaddr*>(&from), &fromLen);
if (n < 0) throw std::runtime_error("resolver recv timeout/failure");
return {buf, buf + n};
}
static std::vector<uint8_t> parse_ip(const std::string& value, const std::string& delimiter)
{
std::vector<uint8_t> parts;
for (auto v : std::views::split(value, delimiter))
{
std::string t(v.begin(), v.end());
const int num = std::stoi(t); // Convert string -> integer
if (num < 0 || num > 255) // Validate byte range
throw std::out_of_range("Invalid IP segment: " + t);
parts.push_back(static_cast<uint8_t>(num));
}
return parts;
}
static std::vector<uint8_t> buildAnswerSection(const DNSQuestion& q)
{
std::vector<uint8_t> answers;
// Name
for (const auto& label : q.labels)
{
answers.push_back(static_cast<uint8_t>(label.size()));
for (auto& c : label) answers.push_back(static_cast<uint8_t>(c));
}
answers.push_back(0x00); // end of Name
// Type
push16(answers, q.qtype);
// Class
push16(answers, q.qclass);
// TTL
push32(answers, 60);
// Length
push16(answers, 4);
// Data
const std::string ip = "8.8.8.8";
auto ip_parts = parse_ip(ip, ".");
for (const auto& v : ip_parts) answers.push_back(v);
return answers;
}
static std::vector<uint8_t> buildQuestionSection(const DNSQuestion& q)
{
std::vector<uint8_t> questions;
// QNAME
for (const auto& label : q.labels)
{
questions.push_back(static_cast<uint8_t>(label.size()));
for (auto& c : label) questions.push_back(static_cast<uint8_t>(c));
}
questions.push_back(0x00); // end of QNAME
// QTYPE (0x0001 = A record)
push16(questions, q.qtype);
// QCLASS (0x0001 = IN / Internet)
push16(questions, q.qclass);
return questions;
}
static std::vector<uint8_t> buildResponse(const DNSHeader& h, const std::vector<DNSQuestion>& qs, const DNSServer* self)
{
// If the opcode is not a standard QUERY (0), respond with NOTIMP (RCODE=4)
if (h.FLAGS.OPCODE != 0) {
DNSHeaderFlags flags = h.FLAGS;
flags.QR = 1; // this is a response
flags.RCODE = 4; // NOTIMP
DNSHeader outH{
.ID = h.ID,
.FLAGS = flags,
.QDCOUNT = static_cast<uint16_t>(qs.size()),
.ANCOUNT = 0,
.NSCOUNT = 0,
.ARCOUNT = 0
};
std::vector<uint8_t> response;
response.reserve(512);
writeHeader(response, outH);
// Echo back the original questions (uncompressed)
for (const auto& q : qs) {
auto qb = buildQuestionSection(q);
response.insert(response.end(), qb.begin(), qb.end());
}
return response;
}
// Collect answers by forwarding each question individually
std::vector<uint8_t> mergedAnswers;
uint16_t totalAN = 0;
for (const auto& q : qs) {
auto upstreamResp = self->forwardSingleQuestion(h, q);
// Count answers from upstream header
auto rd16 = [](const uint8_t* p){ return (static_cast<uint16_t>(p[0])<<8) | p[1]; };
const uint16_t an = rd16(upstreamResp.data() + 6); // ANCOUNT at bytes 6..7
totalAN = static_cast<uint16_t>(totalAN + an);
auto ansBytes = extractAnswersFromUpstream(upstreamResp);
mergedAnswers.insert(mergedAnswers.end(), ansBytes.begin(), ansBytes.end());
}
// Prepare header: reply with same ID, copy flags (QR=1)
DNSHeaderFlags flags = h.FLAGS;
flags.QR = 1; flags.RCODE = 0; // assume success (upstream guarantees answer)
DNSHeader outH{
.ID = h.ID,
.FLAGS = flags,
.QDCOUNT = static_cast<uint16_t>(qs.size()),
.ANCOUNT = totalAN,
.NSCOUNT = 0,
.ARCOUNT = 0
};
std::vector<uint8_t> response;
response.reserve(512);
writeHeader(response, outH);
// Append original questions (uncompressed)
for (const auto& q : qs) {
auto qb = buildQuestionSection(q);
response.insert(response.end(), qb.begin(), qb.end());
}
// Append merged answers
response.insert(response.end(), mergedAnswers.begin(), mergedAnswers.end());
return response;
}
static std::vector<DNSQuestion> parseQuestion(const uint8_t* buffer, const DNSHeader h)
{
auto read_u16 = [](const uint8_t* p) -> uint16_t {
return (static_cast<uint16_t>(p[0]) << 8) | p[1];
};
std::vector<DNSQuestion> questions;
size_t offset = 12; // start of Question section
for (int qi = 0; qi < h.QDCOUNT; ++qi) {
DNSQuestion q{};
std::vector<std::string> labels;
// Decode QNAME (with compression)
size_t idx = offset;
bool jumped = false; // whether we've followed a pointer
size_t name_end = 0; // where QNAME ends in the main stream
while (true) {
const uint8_t len = buffer[idx];
// pointer: 11xxxxxx xxxxxxxx
if ((len & 0xC0) == 0xC0) {
const uint16_t ptr = ((len & 0x3F) << 8) | buffer[idx + 1];
if (!jumped) {
name_end = idx + 2; // QNAME ends here in the main stream
jumped = true;
}
idx = ptr; // jump to the pointed name
continue;
}
// end of name
if (len == 0) {
if (!jumped) name_end = idx + 1; // account for the zero byte
idx += 1;
break;
}
// normal label
idx += 1;
labels.emplace_back(reinterpret_cast<const char*>(buffer + idx), len);
idx += len;
}
q.labels = std::move(labels);
// Read QTYPE/QCLASS at the end of the main-stream name bytes
const uint16_t qtype = read_u16(buffer + name_end + 0);
const uint16_t qclass = read_u16(buffer + name_end + 2);
q.qtype = qtype;
q.qclass = qclass;
// Advance the overall offset for the *next* question
offset = name_end + 4;
questions.push_back(std::move(q));
}
return questions;
}
static DNSHeader parseHeaderSection(uint8_t* buffer)
{
auto parsed = reinterpret_cast<uint16_t*>(buffer);
const uint16_t ID = ntohs(*(parsed+0));
const uint16_t rawFlag = ntohs(*(parsed+1));
DNSHeaderFlags FLAGS{};
FLAGS.QR = (rawFlag >> 15) & 0x1;
FLAGS.OPCODE = (rawFlag >> 11) & 0xF;
FLAGS.AA = (rawFlag >> 10) & 0x1;
FLAGS.TC = (rawFlag >> 9) & 0x1;
FLAGS.RD = (rawFlag >> 8) & 0x1;
FLAGS.RA = (rawFlag >> 7) & 0x1;
FLAGS.Z = (rawFlag >> 4) & 0x7;
FLAGS.RCODE = (rawFlag >> 0) & 0xF;
const uint16_t QDCOUNT = ntohs(*(parsed+2));
const uint16_t ANCOUNT = ntohs(*(parsed+3));
const uint16_t NSCOUNT = ntohs(*(parsed+4));
const uint16_t ARCOUNT = ntohs(*(parsed+5));
const DNSHeader h = {
ID, FLAGS, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT
};
return h;
}
void handleRequest() const
{
// define buffer
uint8_t buffer[512];
// define client
sockaddr_in clientAddress{};
socklen_t clientAddrLen = sizeof(clientAddress);
// read request
const ssize_t bytesRead = recvfrom(udpSocket, buffer, sizeof(buffer), 0,
reinterpret_cast<struct sockaddr*>(&clientAddress), &clientAddrLen);
const DNSHeader h = parseHeaderSection(buffer);
const std::vector<DNSQuestion> qs = parseQuestion(buffer, h);
if (bytesRead == -1)
{
std::cout << "No data received\n";
}
std::cout << "Received " << bytesRead << " bytes" << std::endl;
const auto response = buildResponse(h, qs, this);
if (sendto(udpSocket, response.data(), response.size(), 0, reinterpret_cast<sockaddr*>(&clientAddress),
sizeof(clientAddress)) == -1)
{
perror("Failed to send response");
}
}
[[noreturn]] void startServer() const
{
while (true)
{
handleRequest();
}
}
};
// Entry point.
// Usage: [--resolver ip:port] (default 1.1.1.1:53).
// Returns 1 when the arguments are invalid or the server cannot be started; otherwise it
// serves forever.
int main(const int argc, char** argv)
{
    std::string resolver = "1.1.1.1:53"; // default for local testing
    for (int i = 1; i < argc; ++i) {
        if (std::string(argv[i]) == "--resolver") {
            if (i + 1 >= argc) {
                std::cerr << "--resolver requires an ip:port argument\n";
                return 1;
            }
            resolver = argv[++i]; // consume the value so it is not re-examined as a flag
        }
    }
    std::cout << "Starting DNS server (forwarder) using resolver " << resolver << "\n";
    try {
        // The constructor throws for a malformed resolver or when the sockets cannot be set up.
        const DNSServer server(resolver);
        server.startServer();
    } catch (const std::exception& e) {
        std::cerr << "Fatal: " << e.what() << '\n';
        return 1;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment