// Skip to content
//
// Instantly share code, notes, and snippets.
//
// @jweinst1
// Last active August 13, 2025 05:49
// Show Gist options
//   • Save jweinst1/3457a162b00faea27f45cadd129f7ac3 to your computer and use it in GitHub Desktop.
// A replicated Queue in C++
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
//--------system headers -------//
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
//--------C++ standard library -------//
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <ctime>
#include <limits>
#include <map>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
/**
 * Produces a pseudo-random 64-bit operation id in the closed range
 * [0x4, 0xe435f1243e].
 *
 * The Mersenne Twister engine is thread_local and is seeded from
 * std::random_device the first time each thread calls this function.
 */
static uint64_t operation_id_gen() {
    constexpr uint64_t kLowestId = 0x4;
    constexpr uint64_t kHighestId = 0xe435f1243e;
    thread_local std::mt19937 generator{std::random_device{}()};
    std::uniform_int_distribution<uint64_t> idRange(kLowestId, kHighestId);
    return idRange(generator);
}
/**
 * Returns a heap-allocated copy of the NUL-terminated string `src`.
 *
 * @param src String to copy; may be nullptr.
 * @return A malloc()-allocated copy that the caller must release with free(),
 *         or nullptr when `src` is nullptr or the allocation fails.
 */
static char* str_dupl(const char* src) {
    if (src == nullptr) {
        return nullptr;
    }
    const size_t bytes = strlen(src) + 1;  // include the terminating NUL
    char* copy = static_cast<char*>(malloc(bytes));
    if (copy == nullptr) {
        return nullptr;  // out of memory: report it instead of writing through a null pointer
    }
    memcpy(copy, src, bytes);
    return copy;
}
/**
 * Debug helper: prints every byte of `vec` to stdout as an unsigned decimal
 * followed by a space, then terminates the line with "|".
 */
static void debugByteVector(const std::vector<unsigned char>& vec) {
    for (size_t idx = 0; idx < vec.size(); ++idx) {
        printf("%u ", vec[idx]);
    }
    printf("|\n");
}
/**
 * Returns the current value of errno and then clears errno to 0.
 */
static int getAndResetErrNo() {
    const int observed = errno;
    errno = 0;
    return observed;
}
/**
 * Clears errno to 0 so that a later errno check does not observe a value left
 * behind by an earlier call.
 */
static void resetErrNo() {
errno = 0;
}
/**
 * Reports whether errno currently holds EAGAIN or EWOULDBLOCK.
 *
 * errno is only read, never modified.
 *
 * @return 1 when errno is EAGAIN or EWOULDBLOCK, otherwise 0.
 */
static int errNoIsWouldBlock() {
    const int code = errno;
    if (code == EAGAIN) {
        return 1;
    }
    return code == EWOULDBLOCK ? 1 : 0;
}
/**
 * Prints `lastAction` together with the current errno to stderr and then
 * terminates the process with exit status 2.
 *
 * @param lastAction Description of the operation that failed.
 */
static void exitAndErrorNo(const char* lastAction) {
    const int code = errno;
    fprintf(stderr, "Got unexpected lastAction=%s, errno=%d\n", lastAction, code);
    exit(2);
}
/**
 * Size in bytes of sockaddr_un::sun_path, i.e. the storage available for a
 * unix-domain socket path including its terminating NUL.
 */
static constexpr size_t getMaxSizeOfUnixPath() {
    return sizeof(sockaddr_un::sun_path);
}
/**
 * Reports whether stat() succeeds for `path`.
 *
 * @return true when stat() returns 0, false for any stat() failure.
 */
static const inline bool pathExists(const char* path) {
    struct stat info;
    const int rc = stat(path, &info);
    return rc == 0;
}
/**
 * Sets or clears O_NONBLOCK on a file descriptor.
 *
 * NOTE: despite its name, passing `blocking == true` ENABLES non-blocking mode
 * (O_NONBLOCK is set) and `false` clears it. Every caller in this file passes
 * true to obtain a non-blocking descriptor, so the meaning is kept as is.
 *
 * @param sockfd   Descriptor whose status flags are changed.
 * @param blocking true to set O_NONBLOCK, false to clear it.
 * @return true on success; false when the flags cannot be read or written.
 */
static bool set_non_blocking(int sockfd, bool blocking) {
    const int current = fcntl(sockfd, F_GETFL, 0);
    if (current == -1) {
        // Without this check the -1 error value would be used as a flag word.
        return false;
    }
    const int wanted = blocking ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
    return fcntl(sockfd, F_SETFL, wanted) == 0;
}
/**
 * Creates a listening AF_UNIX stream socket bound to `path`.
 *
 * Anything that already exists at `path` is removed first so that bind() does
 * not fail on a stale socket file. The listen backlog is 5.
 *
 * @param path Filesystem path for the socket; must fit in sockaddr_un::sun_path
 *             including the terminating NUL.
 * @return The listening descriptor, or std::nullopt when the path is null or
 *         too long or when socket(), bind() or listen() fails. On a
 *         bind()/listen() failure the socket is closed and errno still holds
 *         the error of the failed call.
 */
static std::optional<int> create_server_socket(const char* path) {
    if (path == nullptr || strlen(path) > getMaxSizeOfUnixPath() - 1) {
        return std::nullopt;
    }
    if (pathExists(path)) {
        remove(path);
    }
    struct sockaddr_un unix_addr;
    memset(&unix_addr, 0, sizeof(struct sockaddr_un));
    unix_addr.sun_family = AF_UNIX;
    strncpy(unix_addr.sun_path, path, getMaxSizeOfUnixPath() - 1);

    const int sfd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (sfd == -1) {
        return std::nullopt;
    }
    const bool ready =
        bind(sfd, reinterpret_cast<struct sockaddr*>(&unix_addr), sizeof(unix_addr)) != -1 &&
        listen(sfd, 5) != -1;
    if (!ready) {
        // Do not leak the descriptor; keep errno intact for the caller's report.
        const int savedErrno = errno;
        close(sfd);
        errno = savedErrno;
        return std::nullopt;
    }
    return std::make_optional<int>(sfd);
}
/**
 * Creates an AF_UNIX stream socket and connects it to the server at `path`.
 *
 * @param path Filesystem path of the server socket; must fit in
 *             sockaddr_un::sun_path including the terminating NUL.
 * @return The connected descriptor, or std::nullopt when the path is null or
 *         too long or when socket() or connect() fails. On a connect() failure
 *         the socket is closed and errno still holds the connect() error.
 */
static std::optional<int> create_client_socket(const char* path) {
    // Validate before creating the socket so the early return cannot leak it.
    if (path == nullptr || strlen(path) > getMaxSizeOfUnixPath() - 1) {
        return std::nullopt;
    }
    struct sockaddr_un unix_addr;
    memset(&unix_addr, 0, sizeof(struct sockaddr_un));
    unix_addr.sun_family = AF_UNIX;
    strncpy(unix_addr.sun_path, path, getMaxSizeOfUnixPath() - 1);

    const int cfd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (cfd == -1) {
        return std::nullopt;
    }
    if (connect(cfd, reinterpret_cast<struct sockaddr*>(&unix_addr), sizeof(unix_addr)) == -1) {
        const int savedErrno = errno;
        close(cfd);
        errno = savedErrno;
        return std::nullopt;
    }
    return std::make_optional<int>(cfd);
}
/**
 * FIFO queue of strings stored in a ring buffer inside a memory-mapped file.
 *
 * File layout: a Meta header (head, tail, capacity) followed by `capacity`
 * bytes of ring storage. Each record is a uint32_t length in native byte order
 * followed by that many bytes; records may wrap around the end of the ring.
 *
 * The ring never fills completely: push() always leaves at least one byte
 * free, because head == tail is the representation of the empty queue.
 *
 * Ownership: there is deliberately no destructor. The mapping and the file
 * descriptor are released only by cleanup(), and copies of an object share the
 * same mapping and descriptor, so cleanup() must be called on exactly one of
 * them once the others are no longer used.
 */
class DiskBackedStringQueue {
public:
    /**
     * Opens (creating if necessary) the queue file and maps it into memory.
     *
     * A file whose size differs from sizeof(Meta) + capacityBytes is resized
     * and re-initialised as an empty queue.
     *
     * @param path          Path of the backing file.
     * @param capacityBytes Size of the ring storage in bytes.
     * @throws std::runtime_error when the file cannot be opened, inspected,
     *         resized or mapped, or when an existing file of the expected size
     *         records a different capacity.
     */
    DiskBackedStringQueue(const std::string& path, size_t capacityBytes)
        : fd_(-1), size_(0), data_(nullptr), meta_(nullptr), buffer_(nullptr), _path(path)
    {
        const size_t totalSize = sizeof(Meta) + capacityBytes;
        bool newFile = false;

        fd_ = ::open(_path.c_str(), O_RDWR | O_CREAT, 0666);
        if (fd_ == -1) {
            throw std::runtime_error("Failed to open file=" + _path);
        }

        struct stat st;
        if (fstat(fd_, &st) == -1) {
            ::close(fd_);
            throw std::runtime_error("fstat failed");
        }

        if (static_cast<size_t>(st.st_size) != totalSize) {
            if (ftruncate(fd_, static_cast<off_t>(totalSize)) == -1) {
                ::close(fd_);
                throw std::runtime_error("Failed to resize file");
            }
            newFile = true;
        }

        void* mapped = ::mmap(nullptr, totalSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
        if (mapped == MAP_FAILED) {
            ::close(fd_);
            throw std::runtime_error("Failed to mmap file");
        }

        data_ = static_cast<char*>(mapped);
        size_ = totalSize;
        meta_ = reinterpret_cast<Meta*>(data_);
        buffer_ = data_ + sizeof(Meta);

        if (newFile) {
            meta_->head = 0;
            meta_->tail = 0;
            meta_->capacity = capacityBytes;
        } else if (meta_->capacity != capacityBytes) {
            cleanup();
            throw std::runtime_error("Capacity mismatch with existing file");
        }
    }

    /** Unlinks the backing file; the mapping (if any) is not touched. */
    void deleteFile() {
        unlink(_path.c_str());
    }

    /**
     * Appends `str` to the queue.
     *
     * @return false when the string does not fit: either it is longer than a
     *         uint32_t length prefix can describe, or storing it would leave
     *         no free byte in the ring.
     */
    bool push(const std::string& str) {
        if (str.size() > std::numeric_limits<uint32_t>::max()) {
            return false;  // the length prefix cannot represent this size
        }
        const uint32_t len = static_cast<uint32_t>(str.size());
        const size_t needed = sizeof(len) + len;
        // Strictly less than the free space: filling the ring completely would
        // make tail == head, which is indistinguishable from an empty queue and
        // would silently discard everything stored.
        if (needed >= freeSpace()) {
            return false;
        }
        size_t pos = copyIn(meta_->tail, reinterpret_cast<const char*>(&len), sizeof(len));
        pos = copyIn(pos, str.data(), len);
        meta_->tail = pos;
        return true;
    }

    /**
     * Removes the oldest string and stores it in `out`.
     *
     * @return false, leaving `out` and the queue untouched, when the queue is
     *         empty or when the stored record is inconsistent (its length
     *         prefix claims more bytes than the queue holds).
     */
    bool pop(std::string& out) {
        const size_t used = usedSpace();
        if (used < sizeof(uint32_t)) {
            return false;  // empty, or not even a complete length prefix
        }
        uint32_t len = 0;
        const size_t afterLen = copyOut(reinterpret_cast<char*>(&len), meta_->head, sizeof(len));
        if (len > used - sizeof(len)) {
            return false;  // corrupt record: never read past the stored bytes
        }
        out.resize(len);
        // head is only advanced once the whole record has been validated.
        meta_->head = copyOut(out.data(), afterLen, len);
        return true;
    }

    /** True when at least one byte is stored. */
    bool hasItems() const {
        return usedSpace() > 0;
    }

    /** True when head and tail coincide, i.e. nothing is stored. */
    bool empty() const {
        return meta_->head == meta_->tail;
    }

    /** Capacity minus usedSpace(); push() needs strictly more than a record's size. */
    size_t freeSpace() const {
        return meta_->capacity - usedSpace();
    }

    /** Number of ring bytes currently occupied by records (prefixes included). */
    size_t usedSpace() const {
        if (meta_->tail >= meta_->head) {
            return meta_->tail - meta_->head;
        }
        return meta_->capacity - (meta_->head - meta_->tail);
    }

    /**
     * Flushes and unmaps the file and closes the descriptor. Safe to call more
     * than once on the same object; the queue must not be used afterwards.
     */
    void cleanup() {
        if (data_) {
            ::msync(data_, size_, MS_SYNC);
            ::munmap(data_, size_);
            data_ = nullptr;
            meta_ = nullptr;    // both pointed into the mapping that is now gone
            buffer_ = nullptr;
        }
        if (fd_ != -1) {
            ::close(fd_);
            fd_ = -1;
        }
    }

private:
    /** Header stored at the start of the mapped file. */
    struct Meta {
        size_t head;      // offset of the oldest stored byte
        size_t tail;      // offset at which the next byte is written
        size_t capacity;  // size of the ring storage in bytes
    };

    /**
     * Copies `len` bytes into the ring starting at offset `at`, wrapping at the
     * end of the storage. Returns the offset following the written bytes.
     * Callers guarantee len < capacity.
     */
    size_t copyIn(size_t at, const char* src, size_t len) {
        if (len == 0) {
            return at;
        }
        const size_t cap = meta_->capacity;
        const size_t firstPart = std::min(len, cap - at);
        memcpy(buffer_ + at, src, firstPart);
        if (len > firstPart) {
            memcpy(buffer_, src + firstPart, len - firstPart);
        }
        return (at + len) % cap;
    }

    /**
     * Copies `len` bytes out of the ring starting at offset `from`, wrapping at
     * the end of the storage. Returns the offset following the bytes read.
     * Callers guarantee len <= usedSpace().
     */
    size_t copyOut(char* dest, size_t from, size_t len) const {
        if (len == 0) {
            return from;
        }
        const size_t cap = meta_->capacity;
        const size_t firstPart = std::min(len, cap - from);
        memcpy(dest, buffer_ + from, firstPart);
        if (len > firstPart) {
            memcpy(dest + firstPart, buffer_, len - firstPart);
        }
        return (from + len) % cap;
    }

    int fd_;
    size_t size_;
    char* data_;
    Meta* meta_;
    char* buffer_;
    std::string _path;
};
// Type tags of the byte protocol. Every encoded item starts with one of these
// tag bytes; the payload that follows depends on the tag (see buildFromBytes
// and ReqBuilder). Multi-byte integers are copied with memcpy, i.e. they are
// in the byte order of the machine that wrote them.
// Tag followed by a single payload byte.
static constexpr unsigned char PROT_U8 = 1;
// Tag followed by a 4-byte uint32_t.
static constexpr unsigned char PROT_U32 = 2;
// Tag followed by an 8-byte uint64_t.
static constexpr unsigned char PROT_U64 = 3;
// Tag followed by a uint32_t byte count and then that many string bytes.
static constexpr unsigned char PROT_STR = 4;
// Tag followed by a uint32_t element count, then for each element a uint32_t
// byte count and that many string bytes.
static constexpr unsigned char PROT_STR_LST = 5;
/**
* Variant type holding one decoded item of the byte protocol; the
* alternatives correspond, in order, to PROT_U8, PROT_U32, PROT_U64,
* PROT_STR and PROT_STR_LST.
* */
typedef std::variant<unsigned char, uint32_t, uint64_t, std::string, std::vector<std::string>> ReqItem;
/**
* Builds vector of request parts from raw bytes
* */
static void buildFromBytes(std::vector<ReqItem>& vec, const unsigned char* bytes, size_t size) {
size_t i = 0;
while (i < size) {
ReqItem elem;
if (bytes[i] == PROT_U8) {
i += 1;
unsigned char b = bytes[i];
elem = b;
i += 1;
} else if (bytes[i] == PROT_U32) {
i += 1;
uint32_t val = 0;
memcpy(&val, bytes + i, sizeof(val));
elem = val;
i += sizeof(val);
} else if (bytes[i] == PROT_U64) {
i += 1;
uint64_t val = 0;
memcpy(&val, bytes + i, sizeof(val));
elem = val;
i += sizeof(val);
} else if (bytes[i] == PROT_STR) {
i += 1;
uint32_t strSize = 0;
memcpy(&strSize, bytes + i, sizeof(strSize));
i += sizeof(strSize);
std::string strObj;
strObj.resize(strSize);
memcpy(strObj.data(), bytes + i, strSize);
elem = strObj;
i += strSize;
} else if (bytes[i] == PROT_STR_LST) {
i += 1;
uint32_t strLstSize = 0;
memcpy(&strLstSize, bytes + i, sizeof(strLstSize));
i += sizeof(strLstSize);
std::vector<std::string> strLst;
for (size_t j = 0; j < strLstSize; ++j) {
uint32_t strSize = 0;
memcpy(&strSize, bytes + i, sizeof(strSize));
i += sizeof(strSize);
std::string strObj;
strObj.resize(strSize);
memcpy(strObj.data(), bytes + i, strSize);
i += strSize;
strLst.push_back(strObj);
}
elem = strLst;
} else {
fprintf(stderr, "Unexpected byte code during serialization, %u\n", bytes[i]);
exit(2);
}
vec.push_back(elem);
}
}
/**
 * Serialises C++ values into a request frame of the byte protocol.
 *
 * The buffer starts with four zero bytes reserved for the frame size; each
 * push* call appends a PROT_* tag followed by the value. putSize() must be
 * called after the last push to fill in the reserved size field.
 * */
class ReqBuilder {
public:
    /** Stores the body size (total bytes minus the 4-byte size field) at the front. */
    void putSize() {
        const uint32_t bodyBytes = _req.size() - sizeof(uint32_t);
        memcpy(_req.data(), &bodyBytes, sizeof(bodyBytes));
    }

    /** Appends a PROT_U8 item. */
    void pushU8(unsigned char byte) {
        _req.push_back(PROT_U8);
        _req.push_back(byte);
    }

    /** Appends a PROT_U32 item in native byte order. */
    void pushU32(uint32_t num) {
        _req.push_back(PROT_U32);
        appendRaw(&num, sizeof(num));
    }

    /** Appends a PROT_U64 item in native byte order. */
    void pushU64(uint64_t num) {
        _req.push_back(PROT_U64);
        appendRaw(&num, sizeof(num));
    }

    /** Appends a PROT_STR item: uint32_t byte count, then the bytes. */
    void pushStr(const std::string& stringObj) {
        _req.push_back(PROT_STR);
        const uint32_t byteCount = stringObj.size();
        appendRaw(&byteCount, sizeof(byteCount));
        appendRaw(stringObj.data(), byteCount);
    }

    /** Appends a PROT_STR_LST item: uint32_t element count, then each string as count + bytes. */
    void pushStrList(const std::vector<std::string>& stringLst) {
        _req.push_back(PROT_STR_LST);
        const size_t start = _req.size();
        const uint32_t elementCount = stringLst.size();
        // Grow once for the whole list, then fill it through a moving cursor.
        _req.resize(start + sizeof(elementCount) + calcSizeOfStrLst(stringLst));
        size_t cursor = start;
        const auto put = [this, &cursor](const void* src, size_t count) {
            memcpy(_req.data() + cursor, src, count);
            cursor += count;
        };
        put(&elementCount, sizeof(elementCount));
        for (const std::string& element : stringLst) {
            const uint32_t byteCount = element.size();
            put(&byteCount, sizeof(byteCount));
            put(element.data(), byteCount);
        }
    }

    /** Whole frame, size field included. */
    const std::vector<unsigned char>& getReq() const { return _req; }

    /** Pointer to the body, i.e. just past the size field. */
    const unsigned char* getReqData() const {
        return _req.data() + sizeof(uint32_t);
    }

    /** Number of body bytes. */
    size_t getReqSize() const { return _req.size() - sizeof(uint32_t); }

    /** Pointer to the start of the frame (the size field). */
    const unsigned char* getTotalReqData() const {
        return _req.data();
    }

    /** Number of bytes in the frame, size field included. */
    size_t getTotalReqSize() const {
        return _req.size();
    }

private:
    /** Extends the buffer by `count` bytes copied from `src`. */
    void appendRaw(const void* src, size_t count) {
        const size_t at = _req.size();
        _req.resize(at + count);
        memcpy(_req.data() + at, src, count);
    }

    /** Encoded size of the list elements: each string's bytes plus its 4-byte count. */
    size_t calcSizeOfStrLst(const std::vector<std::string>& stringLst) {
        size_t bytes = 0;
        for (const std::string& element : stringLst) {
            bytes += sizeof(uint32_t) + element.size();
        }
        return bytes;
    }

    std::vector<unsigned char> _req = {0, 0, 0, 0};
};
/**
* Represents a request received from a socket, after decoding.
* */
struct ClusterRequest {
// Name of the member the request was read from; set by
// ClusterNode::readRequestFromMember, empty for requests read from a
// connection that has not been registered as a member yet.
std::optional<std::string> sender;
// Descriptor the request arrived on; set by ClusterNode::collectRequests for
// requests read from freshly accepted connections.
std::optional<int> conn;
// Decoded items; ClusterNode::processRequests expects item 0 to be the
// OPER_* operation code.
std::vector<ReqItem> req;
};
/**
* Connection state ClusterNode keeps for one other node of the cluster.
* */
struct Member {
// Socket descriptor connected to the member; -1 until assigned.
int fd = -1;
// Requests whose write reported would-block; ClusterNode::checkAndRetryWrites
// tries to send them again.
std::vector<ReqBuilder> pendingToBeSent;
};
/**
* Progress of a dequeue operation.
* */
enum class DequeueState {
// Default state; a proxied dequeue stays here until the response arrives.
eInProgress,
// The queue had no item to hand out.
eWasEmpty,
// An item was dequeued; DequeueResult::job holds it.
eSuccess,
// Treated as terminal by ClusterNode::getAndPossiblyClearDequeue.
// NOTE(review): no visible code assigns this state — confirm intended use.
eFailed
};
/**
* Outcome of ClusterNode::dequeue / ClusterNode::proxyDequeue.
* */
struct DequeueResult {
DequeueState state = DequeueState::eInProgress;
// Operation id of a proxied dequeue, used to look the result up later with
// ClusterNode::getAndPossiblyClearDequeue; not set for local dequeues.
std::optional<uint64_t> id;
// The dequeued job, present once state is eSuccess.
std::optional<std::string> job;
};
/**
* Progress of an enqueue operation.
* */
enum class EnqueueState {
// Default state; used for enqueues that were forwarded to the captain.
eInProgress,
// The local disk queue rejected the job (ClusterNode::enqueue on the captain).
eWasFull,
// The job was stored in the captain's queue and replication was issued.
eSuccess
};
/**
* Outcome of ClusterNode::enqueue / ClusterNode::proxyEnqueue.
* */
struct EnqueueResult {
EnqueueState state = EnqueueState::eInProgress;
// Operation id associated with the enqueue, when one has been assigned.
std::optional<uint64_t> id;
};
/**
* Counters exposed through ClusterNode::getMetrics.
* NOTE(review): in the visible code only blockedWrites is ever incremented;
* the other counters stay at 0 — confirm whether they are meant to be wired up.
* */
struct OperationMetrics {
size_t proxyEnqueue = 0;
size_t proxyDequeue = 0;
size_t replEnqueue = 0;
size_t replDequeue = 0;
// Number of socket writes that reported EAGAIN/EWOULDBLOCK.
size_t blockedWrites = 0;
};
// Operation codes; each request starts with one of these as a PROT_U8 item.
// The items listed after each code are what ClusterNode puts into the request.
// Handshake on a new connection: string = path (name) of the connecting node.
static constexpr unsigned char OPER_CONNECT = 1;
// Member -> captain: u64 operation id, string job.
static constexpr unsigned char OPER_ENQUEUE = 2;
// Captain -> member: u64 operation id of the enqueue being answered.
static constexpr unsigned char OPER_ENQUEUE_RESP = 3;
// Member -> captain: u64 operation id.
static constexpr unsigned char OPER_DEQUEUE = 4;
// Captain -> member: u64 operation id, u8 flag (1 = item follows, 0 = queue
// was empty), and the string job when the flag is 1.
static constexpr unsigned char OPER_DEQUEUE_RESP = 5;
// Captain -> members: u64 replication id, string job to push.
static constexpr unsigned char OPER_REPL_ENQUEUE = 6;
// Captain -> members: u64 replication id; the receiver pops one item.
static constexpr unsigned char OPER_REPL_DEQUEUE = 7;
class ClusterNode {
public:
using ElectionMap = std::unordered_map<std::string, std::string>;
using DequeueMap = std::unordered_map<uint64_t, DequeueResult>;
explicit ClusterNode(const char* path, size_t queueSize){
_path = path;
initializeQueue(queueSize);
}
explicit ClusterNode(const std::string& path, size_t queueSize){
_path = path;
initializeQueue(queueSize);
}
~ClusterNode() {
if (_removeQueueFile) {
_diskQueue.value().cleanup();
_diskQueue.value().deleteFile();
}
}
void setRemoveQueueFile(bool state) {
_removeQueueFile = state;
}
void initializeQueue(size_t maxSize) {
_diskQueue.emplace(_path + ".dat", maxSize);
}
void closeConnections() {
close(_server);
for (const auto& [ key, conn ] : _members) {
close(conn.fd);
}
}
EnqueueResult proxyEnqueue(const std::string& job, uint64_t id) {
assert(!_isCaptain);
auto found = _members.find(_currentCaptain);
assert(found != _members.end());
EnqueueResult r;
r.state = EnqueueState::eInProgress;
ReqBuilder build;
build.pushU8(OPER_ENQUEUE);
build.pushU64(id);
build.pushStr(job);
build.putSize();
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
_metrics.blockedWrites++;
found->second.pendingToBeSent.push_back(build);
}
_pendingEnqueues.insert(id);
return r;
}
EnqueueResult enqueue(const std::string& job) {
if (!_isCaptain) {
// create an id for this job
uint64_t jobId = operation_id_gen();
return proxyEnqueue(job, jobId);
}
EnqueueResult r;
if(!_diskQueue->push(job)) {
r.state = EnqueueState::eWasFull;
return r;
}
replicateToAllMembers(std::make_optional<std::string>(job));
r.state = EnqueueState::eSuccess;
return r;
}
DequeueResult proxyDequeue(uint64_t id) {
assert(!_isCaptain);
auto found = _members.find(_currentCaptain);
assert(found != _members.end());
ReqBuilder build;
build.pushU8(OPER_DEQUEUE);
build.pushU64(id);
build.putSize();
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
_metrics.blockedWrites++;
found->second.pendingToBeSent.push_back(build);
}
_pendingDequeues[id] = DequeueResult{};
DequeueResult r;
r.id = std::make_optional<uint64_t>(id);
return r;
}
DequeueResult dequeue() {
if (!_isCaptain) {
uint64_t jobId = operation_id_gen();
return proxyDequeue(jobId);
}
std::string popped;
DequeueResult r;
if (!_diskQueue->pop(popped)) {
DequeueResult r;
r.state = DequeueState::eWasEmpty;
return r;
}
r.job = std::make_optional<std::string>(popped);
r.state = DequeueState::eSuccess;
replicateToAllMembers(std::nullopt);
return r;
}
std::optional<std::string> getAndPossiblyClearDequeue(uint64_t id) {
auto found = _pendingDequeues.find(id);
if (found == _pendingDequeues.end()) {
return std::nullopt;
}
if (found->second.state == DequeueState::eSuccess) {
const std::string dequeuedJob = found->second.job.value();
_pendingDequeues.erase(id);
return dequeuedJob;
} else if (found->second.state == DequeueState::eFailed ||
found->second.state == DequeueState::eWasEmpty) {
_pendingDequeues.erase(id);
}
return std::nullopt;
}
bool connectWith(const std::string& target) {
if (_members.find(target) != _members.end() || target == _path) {
return false;
}
std::optional<int> remote = create_client_socket(target.c_str());
if (!remote.has_value()) {
fprintf(stderr, "Failed to connect to %s, errno=%d", target.c_str(), getAndResetErrNo());
return false;
}
Member m;
m.fd = remote.value();
assert(set_non_blocking(m.fd, true));
_members[target] = m;
ReqBuilder build;
build.pushU8(OPER_CONNECT);
build.pushStr(_path);
build.putSize();
write(m.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
fprintf(stderr, "Got unexpected would block when sending connect with to %s\n", target.c_str());
return false;
}
return true;
}
void checkAndRetryWrites() {
for (auto& [ key, conn ] : _members) {
if (conn.pendingToBeSent.size() > 0) {
for (auto it = conn.pendingToBeSent.begin(); it != conn.pendingToBeSent.end(); ) {
write(conn.fd, it->getTotalReqData(), it->getTotalReqSize());
if (errNoIsWouldBlock()) {
// failed write
_metrics.blockedWrites++;
++it;
} else {
it = conn.pendingToBeSent.erase(it);
}
}
}
}
}
uint64_t getAndIncrementReplId() { return ++_lastRepId; }
void replicateToAllMembers(const std::optional<std::string>& job) {
if (!_isCaptain) return;
ReqBuilder build;
if (job.has_value()) {
build.pushU8(OPER_REPL_ENQUEUE);
build.pushU64(getAndIncrementReplId());
build.pushStr(job.value());
} else {
build.pushU8(OPER_REPL_DEQUEUE);
build.pushU64(getAndIncrementReplId());
}
build.putSize();
// Send to all members
for (auto& [ key, conn ] : _members) {
write(conn.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
++_metrics.blockedWrites;
conn.pendingToBeSent.push_back(build);
}
}
}
static std::vector<ClusterNode> createCluster(const std::vector<std::string>& clusterMembers,
const std::string& firstCaptain,
size_t queueSize = 50000) {
std::vector<ClusterNode> nodes;
for (const auto& memb: clusterMembers) {
ClusterNode n(memb, queueSize);
if (memb == firstCaptain) {
n._isCaptain = true;
}
n._currentCaptain = firstCaptain;
nodes.push_back(n);
}
for (auto& node: nodes) {
node.start();
}
/// now connect
for (size_t i = 0; i < nodes.size() - 1; ++i)
{
for (size_t j = i + 1; j < nodes.size(); ++j)
{
nodes[i].connectWith(nodes[j]._path);
nodes[j].doWork();
}
}
return nodes;
}
void start() {
std::optional<int> fd = create_server_socket(_path.c_str());
if (fd.has_value()) {
_server = *fd;
assert(set_non_blocking(_server, true));
_isStarted = true;
} else {
fprintf(stderr, "Cannot create server socket, errno=%d\n", getAndResetErrNo());
exit(2);
}
}
const std::string& getPath() const { return _path; }
int getServerFd() const { return _server; }
void collectRequests() {
assert(_isStarted);
struct sockaddr_un remote;
unsigned int sock_len = 0;
int incoming = accept(_server, (struct sockaddr*)&remote, &sock_len);
// todo make loop
if( incoming == -1 ) {
if(errNoIsWouldBlock()) {
resetErrNo();
} else {
exitAndErrorNo("Could not bind socket");
}
} else {
set_non_blocking(incoming, true);
_pending.push_back(incoming);
}
std::vector<int> toKeep;
// manual polling for now
for (size_t i = 0; i < _pending.size(); ++i) {
std::vector<unsigned char> bytes;
uint32_t req_size = 0;
read(_pending[i], &req_size, sizeof(req_size));
if (errNoIsWouldBlock()) {
resetErrNo();
toKeep.push_back(_pending[i]);
continue;
}
bytes.resize(req_size);
read(_pending[i], bytes.data(), req_size);
if (errNoIsWouldBlock()) {
exitAndErrorNo("Unexpected lack of body of request on pending");
}
std::vector<ReqItem> reqItems;
buildFromBytes(reqItems, bytes.data(), bytes.size());
ClusterRequest req;
req.conn = std::make_optional<int>(_pending[i]);
req.req = reqItems;
_requests.push_back(req);
}
_pending = toKeep;
for (const auto& [ key, conn ] : _members) {
while (readRequestFromMember(key, conn)) {
// todo some timing , but why not keep reading?
}
}
}
bool readRequestFromMember(const std::string& name, const Member& mem) {
std::vector<unsigned char> bytes;
uint32_t req_size = 0;
read(mem.fd, &req_size, sizeof(req_size));
if (errNoIsWouldBlock()) {
resetErrNo();
return false;
}
bytes.resize(req_size);
read(mem.fd, bytes.data(), req_size);
if (errNoIsWouldBlock()) {
exitAndErrorNo("Unexpected lack of body of request on formed member");
}
std::vector<ReqItem> reqItems;
buildFromBytes(reqItems, bytes.data(), bytes.size());
ClusterRequest req;
req.sender = std::make_optional<std::string>(name);
req.req = reqItems;
_requests.push_back(req);
return true;
}
void listMembersToVec(std::vector<std::string>& membs) {
for (const auto& [ key, conn ] : _members) {
membs.push_back(key);
}
}
void processConnectRequest(const ClusterRequest& req) {
const std::string sender = std::get<std::string>(req.req[1]);
if (!req.conn.has_value()) {
fprintf(stderr, "Unable to get socket fd for member=%s\n", sender.c_str());
exit(3);
}
auto found = _members.find(sender);
if (found != _members.end()) {
printf("Got reconnect for member=%s\n", sender.c_str());
close(found->second.fd);
found->second.fd = req.conn.value();
return;
}
Member m;
m.fd = req.conn.value();
_members[sender] = m;
}
void processEnqueueRespRequest(const ClusterRequest& req) {
if (_isCaptain) {
//error
return;
}
if (std::get_if<uint64_t>(&req.req[1]) == nullptr) {
//error bad request!!
return;
}
uint64_t id = std::get<uint64_t>(req.req[1]);
if (_pendingEnqueues.count(id) < 1) {
// error bad request!!
return;
}
_pendingEnqueues.erase(id);
return;
}
void processEnqueueRequest(const ClusterRequest& req) {
if (!_isCaptain) {
//error
return;
}
if (!req.sender.has_value()) {
fprintf(stderr, "Unexpected enqueue from unknown member!");
return;
}
auto found = _members.find(req.sender.value());
if (found == _members.end()) {
fprintf(stderr, "Enqueue from unknown member=%s\n", req.sender.value().c_str());
return;
}
std::optional<std::string> job = std::make_optional<std::string>(std::get<std::string>(req.req[2]));
_diskQueue->push(job.value());
ReqBuilder build;
build.pushU8(OPER_ENQUEUE_RESP);
build.pushU64(std::get<uint64_t>(req.req[1]));
build.putSize();
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
resetErrNo();
found->second.pendingToBeSent.push_back(build);
++_metrics.blockedWrites;
return;
}
replicateToAllMembers(job);
}
void processDequeueRespRequest(const ClusterRequest& req) {
if (_isCaptain) {
//error
return;
}
if (std::get_if<uint64_t>(&req.req[1]) == nullptr) {
//error bad request!!
return;
}
uint64_t id = std::get<uint64_t>(req.req[1]);
auto found = _pendingDequeues.find(id);
if (found == _pendingDequeues.end()) {
fprintf(stderr, "Unknown id=%llu for Dequeue\n", id);
return;
}
const unsigned char result = std::get<unsigned char>(req.req[2]);
if (result) {
found->second.state = DequeueState::eSuccess;
found->second.job = std::get<std::string>(req.req[3]);
} else {
found->second.state = DequeueState::eWasEmpty;
}
return;
}
void processDequeueRequest(const ClusterRequest& req) {
if (!_isCaptain) {
//error
return;
}
if (!req.sender.has_value()) {
fprintf(stderr, "Unexpected dequeue from unknown member!");
return;
}
bool doReplication = false;
auto found = _members.find(req.sender.value());
if (found == _members.end()) {
fprintf(stderr, "Dequeue from unknown member=%s\n", req.sender.value().c_str());
return;
}
ReqBuilder build;
build.pushU8(OPER_DEQUEUE_RESP);
build.pushU64(std::get<uint64_t>(req.req[1]));
std::string popped;
if (_diskQueue->pop(popped)) {
build.pushU8(1); // Queue had item
build.pushStr(popped);
doReplication = true;
} else {
build.pushU8(0); // Queue was empty.
}
build.putSize();
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
resetErrNo();
++_metrics.blockedWrites;
found->second.pendingToBeSent.push_back(build);
return;
}
if (doReplication) {
replicateToAllMembers(std::nullopt);
}
}
void processReplDequeueRequest(const ClusterRequest& req) {
if (!req.sender.has_value()) {
fprintf(stderr, "Got replication from unknown member!\n");
return;
}
if (req.sender.value() != _currentCaptain) {
fprintf(stderr, "Got replication from non-captain member=%s\n", req.sender.value().c_str());
return;
}
_lastRepId = std::get<uint64_t>(req.req[1]);
std::string popped;
(void)_diskQueue->pop(popped);
//printf("Processed replication for %llu from %s\n", _lastRepId,
//req.sender.value().c_str());
}
void processReplEnqueueRequest(const ClusterRequest& req) {
if (!req.sender.has_value()) {
fprintf(stderr, "Got replication from unknown member!\n");
return;
}
if (req.sender.value() != _currentCaptain) {
fprintf(stderr, "Got replication from non-captain member=%s\n", req.sender.value().c_str());
return;
}
_lastRepId = std::get<uint64_t>(req.req[1]);
const std::string gotJob = std::get<std::string>(req.req[2]);
_diskQueue->push(gotJob);
//printf("Processed replication for %llu from %s\n", _lastRepId,
//req.sender.value().c_str());
}
void processRequests() {
for (const auto& req: _requests) {
// todo error handle
const unsigned char opCode = std::get<unsigned char>(req.req[0]);
switch (opCode) {
case OPER_CONNECT:
processConnectRequest(req);
break;
case OPER_ENQUEUE:
processEnqueueRequest(req);
break;
case OPER_ENQUEUE_RESP:
processEnqueueRespRequest(req);
break;
case OPER_DEQUEUE:
processDequeueRequest(req);
break;
case OPER_DEQUEUE_RESP:
processDequeueRespRequest(req);
break;
case OPER_REPL_ENQUEUE:
processReplEnqueueRequest(req);
break;
case OPER_REPL_DEQUEUE:
processReplDequeueRequest(req);
break;
default:
fprintf(stderr, "Unknown request code %u\n", opCode);
exit(3);
}
}
_requests.clear();
}
void doWork() {
collectRequests();
processRequests();
checkAndRetryWrites();
// check state
}
const OperationMetrics& getMetrics() const { return _metrics; }
private:
friend class ClusterNodeTests;
bool _isStarted = false;
int _server = -1;
bool _isCaptain = false;
bool _removeQueueFile = false;
uint64_t _lastRepId = 0;
std::string _path;
std::string _currentCaptain;
std::vector<int> _pending;
std::unordered_set<uint64_t> _pendingEnqueues;
DequeueMap _pendingDequeues;
std::vector<ClusterRequest> _requests;
std::unordered_map<std::string, Member> _members;
std::optional<ElectionMap> _electionMap;
std::optional<DiskBackedStringQueue> _diskQueue;
OperationMetrics _metrics;
};
//------- tests ---------
// Running totals of failed and passed CHECKIT conditions.
static unsigned _failures = 0;
static unsigned _test_passes = 0;
/**
 * Records the outcome of one test condition: a true condition counts as a
 * pass; a false one is reported on stderr with its source text and line and
 * counts as a failure.
 */
static void check_cond(int cond, const char* condstr, unsigned line) {
    if (cond) {
        ++_test_passes;
        return;
    }
    fprintf(stderr, "Failed cond '%s' at line %u\n", condstr, line);
    ++_failures;
}
// Checks a condition, passing its source text and the current line to check_cond.
#define CHECKIT(cnd) check_cond(cnd, #cnd, __LINE__)
/**
 * Round-trip test: encodes a u8, a string and a string list with ReqBuilder,
 * decodes the body with buildFromBytes and checks the number and the types of
 * the decoded items (the values themselves are not compared).
 */
static void test_reqBuilder() {
ReqBuilder b;
b.pushU8(3);
std::string foo = "foo";
std::vector<std::string> foos = {"abc", "def"};
b.pushStr(foo);
b.pushStrList(foos);
b.putSize();
std::vector<ReqItem> elems;
// Decode only the body; getReqData()/getReqSize() skip the 4-byte size field.
buildFromBytes(elems, b.getReqData(), b.getReqSize());
CHECKIT(elems.size() == 3);
CHECKIT(std::get_if<unsigned char>(&elems[0]) != nullptr);
CHECKIT(std::get_if<std::string>(&elems[1]) != nullptr);
CHECKIT(std::get_if<std::vector<std::string>>(&elems[2]) != nullptr);
}
/**
 * Test suite for ClusterNode and DiskBackedStringQueue. It is a class because
 * ClusterNode declares it a friend, which lets the tests inspect the node's
 * private state.
 */
class ClusterNodeTests {
public:
static void testCreateCluster();
static void testEnqueueProxy();
static void testDequeueProxy();
static void testDiskQueue();
};
/**
 * Creates a three-node cluster with "foo" as captain and checks that every
 * node ends up with the two other nodes as members.
 */
void ClusterNodeTests::testCreateCluster() {
std::vector<std::string> membs = {"foo", "bar", "sam"};
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]);
// Have each node delete its queue file when the vector is destroyed.
for (auto& memb : cluster) {
memb.setRemoveQueueFile(true);
}
CHECKIT(cluster[0]._members.size() == 2);
CHECKIT(cluster[1]._members.size() == 2);
CHECKIT(cluster[2]._members.size() == 2);
CHECKIT(cluster[0]._isCaptain == true);
CHECKIT(cluster[0]._members.find(cluster[1]._path) != cluster[0]._members.end());
CHECKIT(cluster[0]._members.find(cluster[2]._path) != cluster[0]._members.end());
CHECKIT(cluster[1]._members.find(cluster[0]._path) != cluster[1]._members.end());
CHECKIT(cluster[1]._members.find(cluster[2]._path) != cluster[1]._members.end());
CHECKIT(cluster[2]._members.find(cluster[0]._path) != cluster[2]._members.end());
CHECKIT(cluster[2]._members.find(cluster[1]._path) != cluster[2]._members.end());
}
/**
 * Checks enqueueing on the captain and through a member: the captain's queue
 * usage, the replication id seen by the members, and the bookkeeping of the
 * pending enqueue on the proxying member.
 */
void ClusterNodeTests::testEnqueueProxy() {
printf("Enqueue data starting\n");
std::vector<std::string> membs = {"foo", "bar", "sam"};
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]);
for (auto& memb : cluster) {
memb.setRemoveQueueFile(true);
}
const std::string job = "job";
CHECKIT(cluster[0]._isCaptain == true);
CHECKIT(cluster[0]._diskQueue->usedSpace() == 0);
CHECKIT(cluster[0].enqueue(job).state == EnqueueState::eSuccess);
// 7 bytes = 4-byte length prefix + the 3 bytes of "job".
CHECKIT(cluster[0]._diskQueue->usedSpace() == 7);
CHECKIT(cluster[0]._diskQueue->hasItems());
// Replication check
cluster[1].doWork();
cluster[2].doWork();
CHECKIT(cluster[0]._lastRepId == 1);
CHECKIT(cluster[1]._lastRepId == cluster[0]._lastRepId);
CHECKIT(cluster[2]._lastRepId == cluster[0]._lastRepId);
// Enqueue through a member: the request is forwarded to the captain.
EnqueueResult jobid = cluster[1].enqueue(job);
CHECKIT(jobid.id.has_value());
CHECKIT(cluster[1]._pendingEnqueues.count(jobid.id.value()) == 1);
// The captain processes the forwarded enqueue ...
cluster[0].doWork();
CHECKIT(cluster[0]._diskQueue->usedSpace() == 14);
// ... and the members pick up the response and the replication.
cluster[1].doWork();
cluster[2].doWork();
CHECKIT(cluster[1]._pendingEnqueues.size() == 0);
CHECKIT(cluster[1]._lastRepId == cluster[0]._lastRepId);
CHECKIT(cluster[2]._lastRepId == cluster[0]._lastRepId);
}
/**
 * Checks a dequeue issued on a member: the request is forwarded to the
 * captain, the captain's queue is drained, and the member obtains the job
 * once it has processed the captain's response.
 */
void ClusterNodeTests::testDequeueProxy() {
std::vector<std::string> membs = {"foo", "bar", "sam"};
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]);
for (auto& memb : cluster) {
memb.setRemoveQueueFile(true);
}
const std::string job = "job";
CHECKIT(cluster[0]._isCaptain == true);
CHECKIT(cluster[0].enqueue(job).state == EnqueueState::eSuccess);
CHECKIT(cluster[0]._diskQueue->usedSpace() == 7);
CHECKIT(cluster[0]._diskQueue->hasItems());
DequeueResult res = cluster[1].dequeue();
CHECKIT(res.state == DequeueState::eInProgress);
CHECKIT(res.id.has_value());
cluster[0].doWork();
CHECKIT(cluster[0]._diskQueue->usedSpace() == 0);
// The member has not read the response yet, so no job is available.
CHECKIT(cluster[1].getAndPossiblyClearDequeue(res.id.value()) == std::nullopt);
cluster[1].doWork();
CHECKIT(cluster[1].getAndPossiblyClearDequeue(res.id.value()).value() == job);
}
/**
 * Checks DiskBackedStringQueue on its own: push/pop on a fresh file, and that
 * an item pushed before cleanup() is still there when the file is reopened.
 */
void ClusterNodeTests::testDiskQueue() {
static constexpr size_t maxSize = 500;
static const std::string jobSample = "{a job}";
static const std::string qPath = "qpath.dat";
{
DiskBackedStringQueue q(qPath, maxSize);
CHECKIT(q.usedSpace() == 0);
CHECKIT(q.push(jobSample));
std::string result;
CHECKIT(q.pop(result));
CHECKIT(result == jobSample);
CHECKIT(!q.pop(result));
CHECKIT(!q.hasItems());
// Leave one item behind for the second half of the test.
CHECKIT(q.push(jobSample));
q.cleanup();
}
{
// Reopen the same file with the same capacity.
DiskBackedStringQueue q(qPath, maxSize);
std::string result;
CHECKIT(q.hasItems());
CHECKIT(q.pop(result));
CHECKIT(result == jobSample);
q.cleanup();
q.deleteFile();
}
}
/**
 * Throughput experiments that are run with the "perf" command-line mode.
 */
class PerformanceTests {
public:
static void oneProducerTwoConsumer();
};
void PerformanceTests::oneProducerTwoConsumer() {
puts("Starting Perf Tests!");
std::vector<std::string> membs = {"foo", "bar", "sam"};
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]);
for (auto& memb : cluster) {
memb.setRemoveQueueFile(true);
}
std::atomic<bool> keepWorking(false);
auto start = std::chrono::high_resolution_clock::now();
std::thread captainThread = std::thread([&](){
const std::string myJob = "{the job to produce}";
size_t jobsQueued = 0;
while(!keepWorking.load());
while(keepWorking) {
cluster[0].enqueue(myJob);
cluster[0].doWork();
++jobsQueued;
//std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
printf("The captain enqueued %zu jobs\n", jobsQueued);
printf("The captain had %zu blocked writes\n", cluster[0].getMetrics().blockedWrites);
});
std::thread member1Thread = std::thread([&](){
std::vector<uint64_t> deqRequests;
size_t jobsGot = 0;
while(!keepWorking.load());
while (keepWorking.load()) {
for (auto it = deqRequests.begin(); it != deqRequests.end(); ) {
std::optional<std::string> got = cluster[1].getAndPossiblyClearDequeue(*it);
if (got.has_value()) {
it = deqRequests.erase(it);
++jobsGot;
} else {
++it;
}
}
DequeueResult r = cluster[1].dequeue();
if (r.state == DequeueState::eInProgress) {
deqRequests.push_back(r.id.value());
}
cluster[1].doWork();
}
printf("The member got %zu jobs\n", jobsGot);
printf("The member had %zu blocked writes\n", cluster[1].getMetrics().blockedWrites);
});
std::thread member2Thread = std::thread([&](){
std::vector<uint64_t> deqRequests;
size_t jobsGot = 0;
while(!keepWorking.load());
while (keepWorking.load()) {
for (auto it = deqRequests.begin(); it != deqRequests.end(); ) {
std::optional<std::string> got = cluster[2].getAndPossiblyClearDequeue(*it);
if (got.has_value()) {
it = deqRequests.erase(it);
++jobsGot;
} else {
++it;
}
}
DequeueResult r = cluster[2].dequeue();
if (r.state == DequeueState::eInProgress) {
deqRequests.push_back(r.id.value());
}
cluster[2].doWork();
}
printf("The member got %zu jobs\n", jobsGot);
printf("The member had %zu blocked writes\n", cluster[2].getMetrics().blockedWrites);
});
keepWorking.store(true);
std::this_thread::sleep_for(std::chrono::seconds(1));
keepWorking.store(false);
captainThread.join();
member1Thread.join();
member2Thread.join();
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
printf("Test took %lld microseconds\n", duration.count());
}
/**
 * Entry point. The single command-line argument selects the mode:
 *   tests - run the unit tests; the exit status is 1 when any check failed
 *   perf  - run the performance experiment
 * Any other invocation prints a usage hint and exits with status 0.
 */
int main(int argc, char const *argv[])
{
    srand(static_cast<unsigned>(time(nullptr)));
    const bool hasMode = argc == 2;
    if (hasMode && strcmp(argv[1], "tests") == 0) {
        test_reqBuilder();
        ClusterNodeTests::testCreateCluster();
        ClusterNodeTests::testDiskQueue();
        ClusterNodeTests::testEnqueueProxy();
        ClusterNodeTests::testDequeueProxy();
        if (_failures > 0) {
            fprintf(stderr, "Total failures %u\n", _failures);
        }
        printf("Total tests passed: %u\n", _test_passes);
        // Report failures through the exit status so scripts and CI notice them.
        return _failures > 0 ? 1 : 0;
    }
    if (hasMode && strcmp(argv[1], "perf") == 0) {
        PerformanceTests::oneProducerTwoConsumer();
        return 0;
    }
    fprintf(stderr, "Usage: %s tests|perf\n", argc > 0 ? argv[0] : "program");
    return 0;
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment