Created
September 4, 2025 12:44
-
-
Save iKunalChhabra/8974f835af9aa1954d6b30b907cbdf45 to your computer and use it in GitHub Desktop.
A simple DNS Server in C++
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <ranges>
#include <stdexcept>
#include <string>
#include <vector>
// Flag bits from the second 16-bit word of the DNS header (RFC 1035 section 4.1.1).
// The in-memory layout of bit-fields is implementation-defined, so this struct must
// never be copied to or from the wire directly; always pack/unpack it field by field.
struct DNSHeaderFlags
{
    // Query/Response (QR) - 0 for a query packet, 1 for a response packet.
    uint16_t QR : 1;
    // Operation code (OPCODE) - kind of query: 0 = standard query (QUERY),
    // 1 = inverse query (IQUERY, obsolete), 2 = server status request (STATUS).
    uint16_t OPCODE : 4;
    // Authoritative Answer (AA) - set in a response when the responding server is
    // authoritative for the queried domain.
    uint16_t AA : 1;
    // Truncation (TC) - set when the message was truncated because it exceeded the size
    // permitted on the transport (512 bytes for plain UDP); the client should then retry
    // the query over TCP.
    uint16_t TC : 1;
    // Recursion Desired (RD) - set by the querier to ask the server to resolve the query
    // recursively; copied into the response.
    uint16_t RD : 1;
    // Recursion Available (RA) - set by the server in a response to indicate that it
    // supports recursive queries.
    uint16_t RA : 1;
    // Reserved (Z) - must be zero per RFC 1035; DNSSEC (RFC 2535, updated by RFC 4035)
    // later assigned two of these bits as AD and CD.
    uint16_t Z : 3;
    // Response code (RCODE) - 0 NOERROR, 1 FORMERR, 2 SERVFAIL, 3 NXDOMAIN, 4 NOTIMP,
    // 5 REFUSED.
    uint16_t RCODE : 4;
};
| struct DNSHeader | |
| { | |
| // Packet Identifier - A random ID assigned to query packets. Response packets must reply with the same ID. | |
| uint16_t ID; | |
| // flags | |
| DNSHeaderFlags FLAGS; | |
| // Question Count - Number of questions in the Question section. | |
| uint16_t QDCOUNT; | |
| // Answer Record Count - Number of records in the Answer section. | |
| uint16_t ANCOUNT; | |
| // Authority Record Count - Number of records in the Authority section. | |
| uint16_t NSCOUNT; | |
| // Additional Record Count - Number of records in the Additional section. | |
| uint16_t ARCOUNT; | |
| }; | |
| struct DNSResponse | |
| { | |
| DNSHeader header; | |
| std::vector<uint8_t> question; | |
| std::vector<uint8_t> answers; | |
| }; | |
// A single Question-section entry, with its name already split into labels.
struct DNSQuestion
{
    std::vector<std::string> labels;  // name labels in order, e.g. {"example", "com"}; root label omitted
    uint16_t qtype{0};                // QTYPE, e.g. 1 for an A record
    uint16_t qclass{0};               // QCLASS, e.g. 1 for IN (Internet)
};
| class DNSServer | |
| { | |
| int udpSocket = -1; | |
| int upstreamSocket = -1; | |
| const int port = 2053; | |
| sockaddr_in resolverAddr{}; | |
| static sockaddr_in parseResolver(const std::string& addr) { | |
| const auto pos = addr.find(':'); | |
| if (pos == std::string::npos) throw std::invalid_argument("resolver must be ip:port"); | |
| const std::string ip = addr.substr(0, pos); | |
| const int port = std::stoi(addr.substr(pos + 1)); | |
| sockaddr_in sa{}; | |
| sa.sin_family = AF_INET; | |
| sa.sin_port = htons(static_cast<uint16_t>(port)); | |
| if (inet_pton(AF_INET, ip.c_str(), &sa.sin_addr) != 1) { | |
| throw std::invalid_argument("invalid resolver ip"); | |
| } | |
| return sa; | |
| } | |
| public: | |
| ~DNSServer() | |
| { | |
| if (udpSocket != -1) close(udpSocket); | |
| if (upstreamSocket != -1) close(upstreamSocket); | |
| } | |
| explicit DNSServer(const std::string& resolver) : resolverAddr(parseResolver(resolver)) { | |
| createSocket(); | |
| setSocketOptions(); | |
| bindSocket(); | |
| upstreamSocket = socket(AF_INET, SOCK_DGRAM, 0); | |
| if (upstreamSocket == -1) { | |
| throw std::runtime_error("Upstream socket creation failed"); | |
| } | |
| constexpr timeval tv{.tv_sec = 2, .tv_usec = 0}; | |
| setsockopt(upstreamSocket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); | |
| } | |
| void createSocket() | |
| { | |
| udpSocket = socket(AF_INET, SOCK_DGRAM, 0); | |
| if (udpSocket == -1) | |
| { | |
| throw std::runtime_error("Socket creation failed"); | |
| } | |
| } | |
| void setSocketOptions() const | |
| { | |
| constexpr int reuse = 1; | |
| if (setsockopt(udpSocket, SOL_SOCKET, SO_REUSEPORT, &reuse, sizeof(reuse)) < 0) | |
| { | |
| throw std::runtime_error("Socket options setting failed"); | |
| } | |
| } | |
| void bindSocket() const | |
| { | |
| sockaddr_in serv_addr = { | |
| .sin_family = AF_INET, | |
| .sin_port = htons(port), | |
| .sin_addr = {htonl(INADDR_ANY)}, | |
| }; | |
| if (bind(udpSocket, reinterpret_cast<struct sockaddr*>(&serv_addr), sizeof(serv_addr)) != 0) | |
| { | |
| throw std::runtime_error("Socket bind failed"); | |
| } | |
| } | |
| static uint16_t flagsToInt16(const DNSHeaderFlags& inFlags) | |
| { | |
| uint16_t flags = 0; | |
| flags |= (inFlags.QR & 0x1) << 15; | |
| flags |= (inFlags.OPCODE & 0xF) << 11; | |
| flags |= (inFlags.AA & 0x1) << 10; | |
| flags |= (inFlags.TC & 0x1) << 9; | |
| flags |= (inFlags.RD & 0x1) << 8; | |
| flags |= (inFlags.RA & 0x1) << 7; | |
| flags |= (inFlags.Z & 0x7) << 4; | |
| flags |= (inFlags.RCODE & 0xF); | |
| return flags; | |
| } | |
| static void push16(std::vector<uint8_t>& data, const uint16_t v) { | |
| data.push_back(static_cast<uint8_t>((v >> 8) & 0xFF)); // high byte | |
| data.push_back(static_cast<uint8_t>(v & 0xFF)); // low byte | |
| } | |
| static void push32(std::vector<uint8_t>& data, uint32_t v) { | |
| data.push_back(static_cast<uint8_t>((v >> 24) & 0xFF)); | |
| data.push_back(static_cast<uint8_t>((v >> 16) & 0xFF)); | |
| data.push_back(static_cast<uint8_t>((v >> 8) & 0xFF)); | |
| data.push_back(static_cast<uint8_t>(v & 0xFF)); | |
| } | |
| static void writeHeader(std::vector<uint8_t>& out, const DNSHeader& h) { | |
| push16(out, h.ID); | |
| push16(out, flagsToInt16(h.FLAGS)); | |
| push16(out, h.QDCOUNT); | |
| push16(out, h.ANCOUNT); | |
| push16(out, h.NSCOUNT); | |
| push16(out, h.ARCOUNT); | |
| } | |
| static std::vector<uint8_t> extractAnswersFromUpstream(const std::vector<uint8_t>& resp) { | |
| if (resp.size() < 12) return {}; | |
| auto rd16 = [](const uint8_t* p){ return (static_cast<uint16_t>(p[0])<<8) | p[1]; }; | |
| const uint8_t* buf = resp.data(); | |
| const uint16_t qd = rd16(buf + 4); | |
| // Walk past QDCOUNT questions (QNAME + 4 bytes) | |
| size_t off = 12; | |
| for (int i = 0; i < qd; ++i) { | |
| while (off < resp.size() && resp[off] != 0) { | |
| const uint8_t len = resp[off]; | |
| // handle compression if any (11xxxxxx) | |
| if ((len & 0xC0) == 0xC0) { off += 2; break; } | |
| off += 1 + len; | |
| } | |
| if (off < resp.size()) off += 1; // zero terminator | |
| off += 4; // QTYPE+QCLASS | |
| } | |
| // Remaining bytes are the answer+authority+additional (tester guarantees only answers) | |
| std::vector<uint8_t> answers; | |
| if (off <= resp.size()) answers.insert(answers.end(), resp.begin()+off, resp.end()); | |
| return answers; | |
| } | |
| [[nodiscard]] std::vector<uint8_t> forwardSingleQuestion(const DNSHeader& origH, const DNSQuestion& q) const { | |
| // Build a minimal query with one question | |
| DNSHeaderFlags f = origH.FLAGS; // preserve RD, OPCODE, etc. | |
| f.QR = 0; f.RCODE = 0; // ensure it's a query | |
| const DNSHeader h{ | |
| .ID = origH.ID, | |
| .FLAGS = f, | |
| .QDCOUNT = 1, | |
| .ANCOUNT = 0, | |
| .NSCOUNT = 0, | |
| .ARCOUNT = 0 | |
| }; | |
| std::vector<uint8_t> packet; | |
| writeHeader(packet, h); | |
| auto qb = buildQuestionSection(q); | |
| packet.insert(packet.end(), qb.begin(), qb.end()); | |
| // Send it to resolver | |
| if (sendto(upstreamSocket, packet.data(), packet.size(), 0, | |
| reinterpret_cast<const sockaddr*>(&resolverAddr), sizeof(resolverAddr)) < 0) { | |
| throw std::runtime_error("failed to send to resolver"); | |
| } | |
| // Receive response | |
| uint8_t buf[512]; | |
| sockaddr_in from{}; socklen_t fromLen = sizeof(from); | |
| const ssize_t n = recvfrom(upstreamSocket, buf, sizeof(buf), 0, | |
| reinterpret_cast<sockaddr*>(&from), &fromLen); | |
| if (n < 0) throw std::runtime_error("resolver recv timeout/failure"); | |
| return {buf, buf + n}; | |
| } | |
| static std::vector<uint8_t> parse_ip(const std::string& value, const std::string& delimiter) | |
| { | |
| std::vector<uint8_t> parts; | |
| for (auto v : std::views::split(value, delimiter)) | |
| { | |
| std::string t(v.begin(), v.end()); | |
| const int num = std::stoi(t); // Convert string -> integer | |
| if (num < 0 || num > 255) // Validate byte range | |
| throw std::out_of_range("Invalid IP segment: " + t); | |
| parts.push_back(static_cast<uint8_t>(num)); | |
| } | |
| return parts; | |
| } | |
| static std::vector<uint8_t> buildAnswerSection(const DNSQuestion& q) | |
| { | |
| std::vector<uint8_t> answers; | |
| // Name | |
| for (const auto& label : q.labels) | |
| { | |
| answers.push_back(static_cast<uint8_t>(label.size())); | |
| for (auto& c : label) answers.push_back(static_cast<uint8_t>(c)); | |
| } | |
| answers.push_back(0x00); // end of Name | |
| // Type | |
| push16(answers, q.qtype); | |
| // Class | |
| push16(answers, q.qclass); | |
| // TTL | |
| push32(answers, 60); | |
| // Length | |
| push16(answers, 4); | |
| // Data | |
| const std::string ip = "8.8.8.8"; | |
| auto ip_parts = parse_ip(ip, "."); | |
| for (const auto& v : ip_parts) answers.push_back(v); | |
| return answers; | |
| } | |
| static std::vector<uint8_t> buildQuestionSection(const DNSQuestion& q) | |
| { | |
| std::vector<uint8_t> questions; | |
| // QNAME | |
| for (const auto& label : q.labels) | |
| { | |
| questions.push_back(static_cast<uint8_t>(label.size())); | |
| for (auto& c : label) questions.push_back(static_cast<uint8_t>(c)); | |
| } | |
| questions.push_back(0x00); // end of QNAME | |
| // QTYPE (0x0001 = A record) | |
| push16(questions, q.qtype); | |
| // QCLASS (0x0001 = IN / Internet) | |
| push16(questions, q.qclass); | |
| return questions; | |
| } | |
| static std::vector<uint8_t> buildResponse(const DNSHeader& h, const std::vector<DNSQuestion>& qs, const DNSServer* self) | |
| { | |
| // If the opcode is not a standard QUERY (0), respond with NOTIMP (RCODE=4) | |
| if (h.FLAGS.OPCODE != 0) { | |
| DNSHeaderFlags flags = h.FLAGS; | |
| flags.QR = 1; // this is a response | |
| flags.RCODE = 4; // NOTIMP | |
| DNSHeader outH{ | |
| .ID = h.ID, | |
| .FLAGS = flags, | |
| .QDCOUNT = static_cast<uint16_t>(qs.size()), | |
| .ANCOUNT = 0, | |
| .NSCOUNT = 0, | |
| .ARCOUNT = 0 | |
| }; | |
| std::vector<uint8_t> response; | |
| response.reserve(512); | |
| writeHeader(response, outH); | |
| // Echo back the original questions (uncompressed) | |
| for (const auto& q : qs) { | |
| auto qb = buildQuestionSection(q); | |
| response.insert(response.end(), qb.begin(), qb.end()); | |
| } | |
| return response; | |
| } | |
| // Collect answers by forwarding each question individually | |
| std::vector<uint8_t> mergedAnswers; | |
| uint16_t totalAN = 0; | |
| for (const auto& q : qs) { | |
| auto upstreamResp = self->forwardSingleQuestion(h, q); | |
| // Count answers from upstream header | |
| auto rd16 = [](const uint8_t* p){ return (static_cast<uint16_t>(p[0])<<8) | p[1]; }; | |
| const uint16_t an = rd16(upstreamResp.data() + 6); // ANCOUNT at bytes 6..7 | |
| totalAN = static_cast<uint16_t>(totalAN + an); | |
| auto ansBytes = extractAnswersFromUpstream(upstreamResp); | |
| mergedAnswers.insert(mergedAnswers.end(), ansBytes.begin(), ansBytes.end()); | |
| } | |
| // Prepare header: reply with same ID, copy flags (QR=1) | |
| DNSHeaderFlags flags = h.FLAGS; | |
| flags.QR = 1; flags.RCODE = 0; // assume success (upstream guarantees answer) | |
| DNSHeader outH{ | |
| .ID = h.ID, | |
| .FLAGS = flags, | |
| .QDCOUNT = static_cast<uint16_t>(qs.size()), | |
| .ANCOUNT = totalAN, | |
| .NSCOUNT = 0, | |
| .ARCOUNT = 0 | |
| }; | |
| std::vector<uint8_t> response; | |
| response.reserve(512); | |
| writeHeader(response, outH); | |
| // Append original questions (uncompressed) | |
| for (const auto& q : qs) { | |
| auto qb = buildQuestionSection(q); | |
| response.insert(response.end(), qb.begin(), qb.end()); | |
| } | |
| // Append merged answers | |
| response.insert(response.end(), mergedAnswers.begin(), mergedAnswers.end()); | |
| return response; | |
| } | |
| static std::vector<DNSQuestion> parseQuestion(const uint8_t* buffer, const DNSHeader h) | |
| { | |
| auto read_u16 = [](const uint8_t* p) -> uint16_t { | |
| return (static_cast<uint16_t>(p[0]) << 8) | p[1]; | |
| }; | |
| std::vector<DNSQuestion> questions; | |
| size_t offset = 12; // start of Question section | |
| for (int qi = 0; qi < h.QDCOUNT; ++qi) { | |
| DNSQuestion q{}; | |
| std::vector<std::string> labels; | |
| // Decode QNAME (with compression) | |
| size_t idx = offset; | |
| bool jumped = false; // whether we've followed a pointer | |
| size_t name_end = 0; // where QNAME ends in the main stream | |
| while (true) { | |
| const uint8_t len = buffer[idx]; | |
| // pointer: 11xxxxxx xxxxxxxx | |
| if ((len & 0xC0) == 0xC0) { | |
| const uint16_t ptr = ((len & 0x3F) << 8) | buffer[idx + 1]; | |
| if (!jumped) { | |
| name_end = idx + 2; // QNAME ends here in the main stream | |
| jumped = true; | |
| } | |
| idx = ptr; // jump to the pointed name | |
| continue; | |
| } | |
| // end of name | |
| if (len == 0) { | |
| if (!jumped) name_end = idx + 1; // account for the zero byte | |
| idx += 1; | |
| break; | |
| } | |
| // normal label | |
| idx += 1; | |
| labels.emplace_back(reinterpret_cast<const char*>(buffer + idx), len); | |
| idx += len; | |
| } | |
| q.labels = std::move(labels); | |
| // Read QTYPE/QCLASS at the end of the main-stream name bytes | |
| const uint16_t qtype = read_u16(buffer + name_end + 0); | |
| const uint16_t qclass = read_u16(buffer + name_end + 2); | |
| q.qtype = qtype; | |
| q.qclass = qclass; | |
| // Advance the overall offset for the *next* question | |
| offset = name_end + 4; | |
| questions.push_back(std::move(q)); | |
| } | |
| return questions; | |
| } | |
| static DNSHeader parseHeaderSection(uint8_t* buffer) | |
| { | |
| auto parsed = reinterpret_cast<uint16_t*>(buffer); | |
| const uint16_t ID = ntohs(*(parsed+0)); | |
| const uint16_t rawFlag = ntohs(*(parsed+1)); | |
| DNSHeaderFlags FLAGS{}; | |
| FLAGS.QR = (rawFlag >> 15) & 0x1; | |
| FLAGS.OPCODE = (rawFlag >> 11) & 0xF; | |
| FLAGS.AA = (rawFlag >> 10) & 0x1; | |
| FLAGS.TC = (rawFlag >> 9) & 0x1; | |
| FLAGS.RD = (rawFlag >> 8) & 0x1; | |
| FLAGS.RA = (rawFlag >> 7) & 0x1; | |
| FLAGS.Z = (rawFlag >> 4) & 0x7; | |
| FLAGS.RCODE = (rawFlag >> 0) & 0xF; | |
| const uint16_t QDCOUNT = ntohs(*(parsed+2)); | |
| const uint16_t ANCOUNT = ntohs(*(parsed+3)); | |
| const uint16_t NSCOUNT = ntohs(*(parsed+4)); | |
| const uint16_t ARCOUNT = ntohs(*(parsed+5)); | |
| const DNSHeader h = { | |
| ID, FLAGS, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT | |
| }; | |
| return h; | |
| } | |
| void handleRequest() const | |
| { | |
| // define buffer | |
| uint8_t buffer[512]; | |
| // define client | |
| sockaddr_in clientAddress{}; | |
| socklen_t clientAddrLen = sizeof(clientAddress); | |
| // read request | |
| const ssize_t bytesRead = recvfrom(udpSocket, buffer, sizeof(buffer), 0, | |
| reinterpret_cast<struct sockaddr*>(&clientAddress), &clientAddrLen); | |
| const DNSHeader h = parseHeaderSection(buffer); | |
| const std::vector<DNSQuestion> qs = parseQuestion(buffer, h); | |
| if (bytesRead == -1) | |
| { | |
| std::cout << "No data received\n"; | |
| } | |
| std::cout << "Received " << bytesRead << " bytes" << std::endl; | |
| const auto response = buildResponse(h, qs, this); | |
| if (sendto(udpSocket, response.data(), response.size(), 0, reinterpret_cast<sockaddr*>(&clientAddress), | |
| sizeof(clientAddress)) == -1) | |
| { | |
| perror("Failed to send response"); | |
| } | |
| } | |
| [[noreturn]] void startServer() const | |
| { | |
| while (true) | |
| { | |
| handleRequest(); | |
| } | |
| } | |
| }; | |
| int main(const int argc, char** argv) | |
| { | |
| std::string resolver = "1.1.1.1:53"; // default for local testing | |
| for (int i = 1; i + 1 < argc; ++i) { | |
| if (std::string(argv[i]) == "--resolver") { | |
| resolver = argv[i+1]; | |
| } | |
| } | |
| std::cout << "Starting DNS server (forwarder) using resolver " << resolver << "\n"; | |
| const DNSServer server(resolver); | |
| server.startServer(); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment