Last active
August 13, 2025 05:49
-
-
Save jweinst1/3457a162b00faea27f45cadd129f7ac3 to your computer and use it in GitHub Desktop.
A replicated Queue in C++
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
//--------system headers -------//
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
//--------C++ standard library -------//
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <limits>
#include <map>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
/**
 * Produces a pseudo-random operation id in the closed range [0x4, 0xe435f1243e].
 * Each thread owns a Mersenne Twister engine seeded once from std::random_device,
 * so concurrent callers never share engine state.
 */
static uint64_t operation_id_gen() {
    thread_local std::mt19937 rng{std::random_device{}()};
    return std::uniform_int_distribution<uint64_t>{0x4, 0xe435f1243e}(rng);
}
/**
 * Duplicates a NUL-terminated C string into a new malloc'd buffer (like strdup).
 * The caller owns the result and must release it with free().
 * @param src string to copy; may be nullptr.
 * @return the copy, or nullptr if `src` is nullptr or the allocation fails.
 */
static char* str_dupl(const char* src) {
    if (src == nullptr) {
        return nullptr;
    }
    const size_t srcSize = strlen(src) + 1;  // include the terminating NUL
    char* copy = static_cast<char*>(malloc(srcSize));
    if (copy == nullptr) {
        return nullptr;  // never memcpy into a failed allocation
    }
    memcpy(copy, src, srcSize);
    return copy;
}
/**
 * Debug helper: writes every byte of `vec` to stdout as an unsigned decimal
 * followed by a space, then terminates the line with "|\n".
 */
static void debugByteVector(const std::vector<unsigned char>& vec) {
    for (size_t idx = 0; idx < vec.size(); ++idx) {
        printf("%u ", static_cast<unsigned>(vec[idx]));
    }
    fputs("|\n", stdout);
}
/// Returns the current value of errno and clears errno to 0.
static int getAndResetErrNo() {
    const int lastError = errno;
    errno = 0;
    return lastError;
}
/// Clears errno to 0 so a later check does not see a stale error code.
static void resetErrNo() { errno = 0; }
/**
 * Reports whether errno says a non-blocking operation would have blocked.
 * @return 1 for EAGAIN or EWOULDBLOCK, otherwise 0. errno is left unchanged.
 */
static int errNoIsWouldBlock() {
    switch (errno) {
    case EAGAIN:
#if EWOULDBLOCK != EAGAIN
    case EWOULDBLOCK:
#endif
        return 1;
    default:
        return 0;
    }
}
/**
 * Fatal-error helper: prints `lastAction` and the current errno to stderr,
 * then terminates the process with exit status 2. Never returns.
 */
static void exitAndErrorNo(const char* lastAction) {
    const int failure = errno;
    fprintf(stderr, "Got unexpected lastAction=%s, errno=%d\n", lastAction, failure);
    exit(2);
}
/**
 * Size in bytes of sockaddr_un::sun_path. Callers subtract one from it to
 * leave room for the terminating NUL of the socket path.
 */
static constexpr size_t getMaxSizeOfUnixPath() {
    return sizeof(sockaddr_un::sun_path);
}
/**
 * Returns true when stat() succeeds on `path` (any file type; symlinks are
 * followed). Any stat() failure, including permission errors, yields false.
 * The previous `const` on the returned bool had no effect and only produced
 * -Wignored-qualifiers warnings, so it has been dropped.
 */
static inline bool pathExists(const char* path) {
    struct stat sbuf{};
    return ::stat(path, &sbuf) == 0;
}
/**
 * Sets or clears O_NONBLOCK on a file descriptor.
 * @param sockfd   descriptor to modify.
 * @param blocking despite its name, `true` ENABLES non-blocking mode
 *                 (sets O_NONBLOCK) and `false` restores blocking mode.
 * @return true on success; false if either fcntl() call fails (errno is set).
 */
static bool set_non_blocking(int sockfd, bool blocking) {
    const int flags = fcntl(sockfd, F_GETFL, 0);
    if (flags == -1) {
        // Never OR flag bits into the -1 error value and write that back.
        return false;
    }
    const int wanted = blocking ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK);
    return fcntl(sockfd, F_SETFL, wanted) != -1;
}
static std::optional<int> create_server_socket(const char* path) { | |
int sfd = -1; | |
struct sockaddr_un unix_addr; | |
if (strlen(path) > getMaxSizeOfUnixPath() - 1) { | |
return std::nullopt; | |
} | |
if (pathExists(path)) { | |
remove(path); | |
} | |
memset(&unix_addr, 0, sizeof(struct sockaddr_un)); | |
unix_addr.sun_family = AF_UNIX; | |
strncpy(unix_addr.sun_path, path, getMaxSizeOfUnixPath() - 1); | |
sfd = socket(AF_UNIX, SOCK_STREAM, 0); | |
if (sfd == -1) { | |
return std::nullopt; | |
} | |
if (bind(sfd, (struct sockaddr *) &unix_addr, sizeof(unix_addr)) == -1) { | |
return std::nullopt; | |
} | |
if (listen(sfd, 5) == -1) { | |
return std::nullopt; | |
} | |
return std::make_optional<int>(sfd); | |
} | |
/**
 * Connects a new AF_UNIX stream socket to the server listening at `path`.
 * @return the connected descriptor, or std::nullopt if `path` does not fit in
 *         sockaddr_un::sun_path (errno = ENAMETOOLONG) or socket/connect
 *         fails. On failure errno holds the failing call's error and no
 *         descriptor is leaked.
 */
static std::optional<int> create_client_socket(const char* path) {
    // Validate first so that rejecting the path cannot leak a descriptor.
    if (strlen(path) > getMaxSizeOfUnixPath() - 1) {
        errno = ENAMETOOLONG;
        return std::nullopt;
    }
    struct sockaddr_un unix_addr;
    memset(&unix_addr, 0, sizeof(unix_addr));
    unix_addr.sun_family = AF_UNIX;
    strncpy(unix_addr.sun_path, path, getMaxSizeOfUnixPath() - 1);

    const int cfd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (cfd == -1) {
        return std::nullopt;
    }
    if (connect(cfd, reinterpret_cast<struct sockaddr*>(&unix_addr), sizeof(unix_addr)) == -1) {
        const int savedErrno = errno;  // callers report errno; close() may clobber it
        close(cfd);
        errno = savedErrno;
        return std::nullopt;
    }
    return cfd;
}
/**
 * Fixed-capacity FIFO of strings kept in a memory-mapped file, so queued
 * items survive a restart of the process.
 *
 * File layout: a Meta header {head, tail, capacity} followed by a ring buffer
 * of `capacity` bytes. Each item is a native-endian uint32_t length followed
 * by that many bytes, and items may wrap around the end of the ring.
 * One ring byte is always kept free, so head == tail can only mean "empty".
 * The largest storable total (length prefixes included) is therefore
 * capacity - 1 bytes.
 *
 * Ownership:
 * - Resources are released only by cleanup(); there is no destructor.
 * - Copies alias the same mapping and descriptor (ClusterNode relies on being
 *   copyable), so exactly one copy should call cleanup().
 * - There is no internal locking.
 */
class DiskBackedStringQueue {
public:
    /**
     * Opens, or creates, the backing file at `path`.
     * A file whose size differs from sizeof(Meta) + capacityBytes is resized
     * and reset to an empty queue, discarding its previous contents. This
     * includes a newly created, empty file.
     * @throws std::runtime_error if open/fstat/ftruncate/mmap fails, if an
     *         existing file records a different capacity, or if its stored
     *         head/tail offsets lie outside the ring.
     */
    DiskBackedStringQueue(const std::string& path, size_t capacityBytes)
        : _path(path)
    {
        const size_t totalSize = sizeof(Meta) + capacityBytes;
        fd_ = ::open(_path.c_str(), O_RDWR | O_CREAT, 0666);
        if (fd_ == -1) {
            throw std::runtime_error("Failed to open file=" + _path);
        }
        struct stat st;
        if (fstat(fd_, &st) == -1) {
            ::close(fd_);
            fd_ = -1;
            throw std::runtime_error("fstat failed");
        }
        bool newFile = false;
        if (static_cast<size_t>(st.st_size) != totalSize) {
            if (ftruncate(fd_, static_cast<off_t>(totalSize)) == -1) {
                ::close(fd_);
                fd_ = -1;
                throw std::runtime_error("Failed to resize file");
            }
            newFile = true;
        }
        void* mapped = ::mmap(nullptr, totalSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
        if (mapped == MAP_FAILED) {
            ::close(fd_);
            fd_ = -1;
            throw std::runtime_error("Failed to mmap file");
        }
        data_ = static_cast<char*>(mapped);
        size_ = totalSize;
        meta_ = reinterpret_cast<Meta*>(data_);
        buffer_ = data_ + sizeof(Meta);
        if (newFile) {
            meta_->head = 0;
            meta_->tail = 0;
            meta_->capacity = capacityBytes;
        } else if (meta_->capacity != capacityBytes) {
            cleanup();
            throw std::runtime_error("Capacity mismatch with existing file");
        } else if (!offsetInRing(meta_->head) || !offsetInRing(meta_->tail)) {
            // A corrupted header would otherwise send reads/writes out of bounds.
            cleanup();
            throw std::runtime_error("Corrupt queue metadata in file=" + _path);
        }
    }

    /// Removes the backing file from the filesystem; an existing mapping stays valid.
    void deleteFile() {
        unlink(_path.c_str());
    }

    /**
     * Appends `str` to the tail of the queue.
     * @return false, leaving the queue unchanged, if the item does not fit
     *         (one byte of the ring is always left unused) or if `str` is too
     *         long for the 32-bit length prefix.
     */
    bool push(const std::string& str) {
        if (str.size() > std::numeric_limits<uint32_t>::max()) {
            return false;  // would be silently truncated by the 32-bit prefix
        }
        const uint32_t len = static_cast<uint32_t>(str.size());
        const size_t needed = sizeof(len) + len;
        // Must stay strictly below the free space: filling the ring completely
        // would make tail == head, which is indistinguishable from empty.
        if (needed >= freeSpace()) {
            return false;
        }
        writeBytes(reinterpret_cast<const char*>(&len), sizeof(len));
        writeBytes(str.data(), len);
        return true;
    }

    /**
     * Removes the oldest item and stores it in `out`.
     * @return false, leaving `out` untouched, if the queue is empty.
     */
    bool pop(std::string& out) {
        if (empty()) {
            return false;
        }
        uint32_t len = 0;
        readBytes(reinterpret_cast<char*>(&len), sizeof(len));
        out.resize(len);
        readBytes(&out[0], len);
        return true;
    }

    /// True when at least one byte of item data is stored.
    bool hasItems() const {
        return usedSpace() > 0;
    }

    /// True when no items are stored (head == tail).
    bool empty() const {
        return meta_->head == meta_->tail;
    }

    /// Bytes of the ring not currently in use, including the one reserved byte.
    size_t freeSpace() const {
        return meta_->capacity - usedSpace();
    }

    /// Bytes of the ring occupied by stored items, length prefixes included.
    size_t usedSpace() const {
        if (meta_->tail >= meta_->head) {
            return meta_->tail - meta_->head;
        }
        return meta_->capacity - (meta_->head - meta_->tail);
    }

    /// Flushes and unmaps the file, then closes the descriptor. Safe to call repeatedly.
    void cleanup() {
        if (data_) {
            ::msync(data_, size_, MS_SYNC);
            ::munmap(data_, size_);
            data_ = nullptr;
        }
        if (fd_ != -1) {
            ::close(fd_);
            fd_ = -1;
        }
    }

private:
    /// Header persisted at offset 0 of the file; head/tail are offsets into the ring.
    struct Meta {
        size_t head;
        size_t tail;
        size_t capacity;
    };

    /// True if `offset` is a legal head/tail position for the current capacity.
    bool offsetInRing(size_t offset) const {
        return meta_->capacity == 0 ? offset == 0 : offset < meta_->capacity;
    }

    /// Copies `len` bytes to the tail, wrapping at the end of the ring. The caller checks space.
    void writeBytes(const char* src, size_t len) {
        const size_t tail = meta_->tail;
        const size_t cap = meta_->capacity;
        const size_t firstPart = std::min(len, cap - tail);
        memcpy(buffer_ + tail, src, firstPart);
        if (len > firstPart) {
            memcpy(buffer_, src + firstPart, len - firstPart);
        }
        meta_->tail = (tail + len) % cap;
    }

    /// Copies `len` bytes from the head, wrapping at the end of the ring, and advances head.
    void readBytes(char* dest, size_t len) {
        const size_t head = meta_->head;
        const size_t cap = meta_->capacity;
        const size_t firstPart = std::min(len, cap - head);
        memcpy(dest, buffer_ + head, firstPart);
        if (len > firstPart) {
            memcpy(dest + firstPart, buffer_, len - firstPart);
        }
        meta_->head = (head + len) % cap;
    }

    int fd_ = -1;
    size_t size_ = 0;          // length of the whole mapping
    char* data_ = nullptr;     // start of the mapping (Meta header)
    Meta* meta_ = nullptr;
    char* buffer_ = nullptr;   // start of the ring, right after Meta
    std::string _path;
};
// Type tags that precede every item in the byte protocol (see ReqBuilder / buildFromBytes).
static constexpr unsigned char PROT_U8{1};       // followed by one raw byte
static constexpr unsigned char PROT_U32{2};      // followed by a native-endian uint32_t
static constexpr unsigned char PROT_U64{3};      // followed by a native-endian uint64_t
static constexpr unsigned char PROT_STR{4};      // uint32_t length, then that many bytes
static constexpr unsigned char PROT_STR_LST{5};  // uint32_t count, then `count` length-prefixed strings
/**
 * One decoded item of the byte protocol; the alternative order mirrors the
 * PROT_* tags above (tag N maps to alternative N-1).
 * */
using ReqItem = std::variant<unsigned char, uint32_t, uint64_t, std::string, std::vector<std::string>>;
/** | |
* Builds vector of request parts from raw bytes | |
* */ | |
static void buildFromBytes(std::vector<ReqItem>& vec, const unsigned char* bytes, size_t size) { | |
size_t i = 0; | |
while (i < size) { | |
ReqItem elem; | |
if (bytes[i] == PROT_U8) { | |
i += 1; | |
unsigned char b = bytes[i]; | |
elem = b; | |
i += 1; | |
} else if (bytes[i] == PROT_U32) { | |
i += 1; | |
uint32_t val = 0; | |
memcpy(&val, bytes + i, sizeof(val)); | |
elem = val; | |
i += sizeof(val); | |
} else if (bytes[i] == PROT_U64) { | |
i += 1; | |
uint64_t val = 0; | |
memcpy(&val, bytes + i, sizeof(val)); | |
elem = val; | |
i += sizeof(val); | |
} else if (bytes[i] == PROT_STR) { | |
i += 1; | |
uint32_t strSize = 0; | |
memcpy(&strSize, bytes + i, sizeof(strSize)); | |
i += sizeof(strSize); | |
std::string strObj; | |
strObj.resize(strSize); | |
memcpy(strObj.data(), bytes + i, strSize); | |
elem = strObj; | |
i += strSize; | |
} else if (bytes[i] == PROT_STR_LST) { | |
i += 1; | |
uint32_t strLstSize = 0; | |
memcpy(&strLstSize, bytes + i, sizeof(strLstSize)); | |
i += sizeof(strLstSize); | |
std::vector<std::string> strLst; | |
for (size_t j = 0; j < strLstSize; ++j) { | |
uint32_t strSize = 0; | |
memcpy(&strSize, bytes + i, sizeof(strSize)); | |
i += sizeof(strSize); | |
std::string strObj; | |
strObj.resize(strSize); | |
memcpy(strObj.data(), bytes + i, strSize); | |
i += strSize; | |
strLst.push_back(strObj); | |
} | |
elem = strLst; | |
} else { | |
fprintf(stderr, "Unexpected byte code during serialization, %u\n", bytes[i]); | |
exit(2); | |
} | |
vec.push_back(elem); | |
} | |
} | |
/** | |
* Class that builds C++ types into raw bytes according to the byte protocol | |
* */ | |
class ReqBuilder { | |
public: | |
void putSize() { | |
uint32_t sizeOfReq = _req.size() - sizeof(uint32_t); | |
memcpy(_req.data(), &sizeOfReq, sizeof(sizeOfReq)); | |
} | |
void pushU8(unsigned char byte) { | |
_req.push_back(PROT_U8); | |
_req.push_back(byte); | |
} | |
void pushU32(uint32_t num) { | |
_req.push_back(PROT_U32); | |
size_t oldSize = _req.size(); | |
_req.resize(oldSize + sizeof(num)); | |
memcpy(_req.data() + oldSize, &num, sizeof(num)); | |
} | |
void pushU64(uint64_t num) { | |
_req.push_back(PROT_U64); | |
size_t oldSize = _req.size(); | |
_req.resize(oldSize + sizeof(num)); | |
memcpy(_req.data() + oldSize, &num, sizeof(num)); | |
} | |
void pushStr(const std::string& stringObj) { | |
_req.push_back(PROT_STR); | |
uint32_t strSize = stringObj.size(); | |
size_t oldSize = _req.size(); | |
_req.resize(oldSize + strSize + sizeof(strSize)); | |
memcpy(_req.data() + oldSize, &strSize, sizeof(strSize)); | |
memcpy(_req.data() + oldSize + sizeof(strSize), stringObj.data(), strSize); | |
} | |
void pushStrList(const std::vector<std::string>& stringLst) { | |
_req.push_back(PROT_STR_LST); | |
size_t oldSize = _req.size(); | |
uint32_t strLstSize = stringLst.size(); | |
size_t totalSize = calcSizeOfStrLst(stringLst); | |
_req.resize(oldSize + sizeof(strLstSize) + totalSize); | |
memcpy(_req.data() + oldSize, &strLstSize, sizeof(strLstSize)); | |
size_t writePoint = oldSize + sizeof(strLstSize); | |
for (const auto& obj : stringLst) { | |
uint32_t strSize = obj.size(); | |
memcpy(_req.data() + writePoint, &strSize, sizeof(strSize)); | |
writePoint += sizeof(strSize); | |
memcpy(_req.data() + writePoint, obj.data(), strSize); | |
writePoint += strSize; | |
} | |
} | |
const std::vector<unsigned char>& getReq() const { return _req; } | |
const unsigned char* getReqData() const { | |
return _req.data() + sizeof(uint32_t); | |
} | |
size_t getReqSize() const { return _req.size() - sizeof(uint32_t); } | |
const unsigned char* getTotalReqData() const { | |
return _req.data(); | |
} | |
size_t getTotalReqSize() const { | |
return _req.size(); | |
} | |
private: | |
size_t calcSizeOfStrLst(const std::vector<std::string>& stringLst) { | |
size_t total = 0; | |
for (const auto& obj : stringLst) { | |
total += obj.size(); | |
total += sizeof(uint32_t); // for size marker | |
} | |
return total; | |
} | |
std::vector<unsigned char> _req = {0, 0, 0, 0}; | |
}; | |
/** | |
* Represents a request sent across a socket | |
* */ | |
struct ClusterRequest { | |
std::optional<std::string> sender; | |
std::optional<int> conn; | |
std::vector<ReqItem> req; | |
}; | |
struct Member { | |
int fd = -1; | |
std::vector<ReqBuilder> pendingToBeSent; | |
}; | |
/// Progress/outcome of a dequeue; eInProgress is the default before any outcome is known.
enum class DequeueState {
    eInProgress,
    eWasEmpty,   // the captain's queue had nothing to return
    eSuccess,    // `job` carries the dequeued item
    eFailed
};

/// Result of a dequeue: `id` identifies a proxied request, `job` holds the item on success.
struct DequeueResult {
    DequeueState state{DequeueState::eInProgress};
    std::optional<uint64_t> id{};
    std::optional<std::string> job{};
};
/// Progress/outcome of an enqueue; eInProgress is the default (e.g. while proxied to the captain).
enum class EnqueueState {
    eInProgress,
    eWasFull,    // the local queue rejected the job
    eSuccess
};

/// Result of an enqueue; `id` identifies a proxied request.
struct EnqueueResult {
    EnqueueState state{EnqueueState::eInProgress};
    std::optional<uint64_t> id{};
};
/// Activity counters exposed through ClusterNode::getMetrics(); all start at zero.
struct OperationMetrics {
    size_t proxyEnqueue{0};
    size_t proxyDequeue{0};
    size_t replEnqueue{0};
    size_t replDequeue{0};
    size_t blockedWrites{0};  // socket writes that would have blocked
};
// Operation codes: every cluster frame starts with one of these as a PROT_U8 item.
static constexpr unsigned char OPER_CONNECT{1};       // a new connection announces its node path
static constexpr unsigned char OPER_ENQUEUE{2};       // non-captain asks the captain to enqueue
static constexpr unsigned char OPER_ENQUEUE_RESP{3};  // captain answers a proxied enqueue
static constexpr unsigned char OPER_DEQUEUE{4};       // non-captain asks the captain to dequeue
static constexpr unsigned char OPER_DEQUEUE_RESP{5};  // captain answers a proxied dequeue
static constexpr unsigned char OPER_REPL_ENQUEUE{6};  // captain replicates an enqueue to members
static constexpr unsigned char OPER_REPL_DEQUEUE{7};  // captain replicates a dequeue to members
class ClusterNode { | |
public: | |
using ElectionMap = std::unordered_map<std::string, std::string>; | |
using DequeueMap = std::unordered_map<uint64_t, DequeueResult>; | |
explicit ClusterNode(const char* path, size_t queueSize){ | |
_path = path; | |
initializeQueue(queueSize); | |
} | |
explicit ClusterNode(const std::string& path, size_t queueSize){ | |
_path = path; | |
initializeQueue(queueSize); | |
} | |
~ClusterNode() { | |
if (_removeQueueFile) { | |
_diskQueue.value().cleanup(); | |
_diskQueue.value().deleteFile(); | |
} | |
} | |
void setRemoveQueueFile(bool state) { | |
_removeQueueFile = state; | |
} | |
void initializeQueue(size_t maxSize) { | |
_diskQueue.emplace(_path + ".dat", maxSize); | |
} | |
void closeConnections() { | |
close(_server); | |
for (const auto& [ key, conn ] : _members) { | |
close(conn.fd); | |
} | |
} | |
EnqueueResult proxyEnqueue(const std::string& job, uint64_t id) { | |
assert(!_isCaptain); | |
auto found = _members.find(_currentCaptain); | |
assert(found != _members.end()); | |
EnqueueResult r; | |
r.state = EnqueueState::eInProgress; | |
ReqBuilder build; | |
build.pushU8(OPER_ENQUEUE); | |
build.pushU64(id); | |
build.pushStr(job); | |
build.putSize(); | |
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize()); | |
if (errNoIsWouldBlock()) { | |
_metrics.blockedWrites++; | |
found->second.pendingToBeSent.push_back(build); | |
} | |
_pendingEnqueues.insert(id); | |
return r; | |
} | |
EnqueueResult enqueue(const std::string& job) { | |
if (!_isCaptain) { | |
// create an id for this job | |
uint64_t jobId = operation_id_gen(); | |
return proxyEnqueue(job, jobId); | |
} | |
EnqueueResult r; | |
if(!_diskQueue->push(job)) { | |
r.state = EnqueueState::eWasFull; | |
return r; | |
} | |
replicateToAllMembers(std::make_optional<std::string>(job)); | |
r.state = EnqueueState::eSuccess; | |
return r; | |
} | |
DequeueResult proxyDequeue(uint64_t id) { | |
assert(!_isCaptain); | |
auto found = _members.find(_currentCaptain); | |
assert(found != _members.end()); | |
ReqBuilder build; | |
build.pushU8(OPER_DEQUEUE); | |
build.pushU64(id); | |
build.putSize(); | |
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize()); | |
if (errNoIsWouldBlock()) { | |
_metrics.blockedWrites++; | |
found->second.pendingToBeSent.push_back(build); | |
} | |
_pendingDequeues[id] = DequeueResult{}; | |
DequeueResult r; | |
r.id = std::make_optional<uint64_t>(id); | |
return r; | |
} | |
DequeueResult dequeue() { | |
if (!_isCaptain) { | |
uint64_t jobId = operation_id_gen(); | |
return proxyDequeue(jobId); | |
} | |
std::string popped; | |
DequeueResult r; | |
if (!_diskQueue->pop(popped)) { | |
DequeueResult r; | |
r.state = DequeueState::eWasEmpty; | |
return r; | |
} | |
r.job = std::make_optional<std::string>(popped); | |
r.state = DequeueState::eSuccess; | |
replicateToAllMembers(std::nullopt); | |
return r; | |
} | |
std::optional<std::string> getAndPossiblyClearDequeue(uint64_t id) { | |
auto found = _pendingDequeues.find(id); | |
if (found == _pendingDequeues.end()) { | |
return std::nullopt; | |
} | |
if (found->second.state == DequeueState::eSuccess) { | |
const std::string dequeuedJob = found->second.job.value(); | |
_pendingDequeues.erase(id); | |
return dequeuedJob; | |
} else if (found->second.state == DequeueState::eFailed || | |
found->second.state == DequeueState::eWasEmpty) { | |
_pendingDequeues.erase(id); | |
} | |
return std::nullopt; | |
} | |
bool connectWith(const std::string& target) { | |
if (_members.find(target) != _members.end() || target == _path) { | |
return false; | |
} | |
std::optional<int> remote = create_client_socket(target.c_str()); | |
if (!remote.has_value()) { | |
fprintf(stderr, "Failed to connect to %s, errno=%d", target.c_str(), getAndResetErrNo()); | |
return false; | |
} | |
Member m; | |
m.fd = remote.value(); | |
assert(set_non_blocking(m.fd, true)); | |
_members[target] = m; | |
ReqBuilder build; | |
build.pushU8(OPER_CONNECT); | |
build.pushStr(_path); | |
build.putSize(); | |
write(m.fd, build.getTotalReqData(), build.getTotalReqSize()); | |
if (errNoIsWouldBlock()) { | |
fprintf(stderr, "Got unexpected would block when sending connect with to %s\n", target.c_str()); | |
return false; | |
} | |
return true; | |
} | |
void checkAndRetryWrites() { | |
for (auto& [ key, conn ] : _members) { | |
if (conn.pendingToBeSent.size() > 0) { | |
for (auto it = conn.pendingToBeSent.begin(); it != conn.pendingToBeSent.end(); ) { | |
write(conn.fd, it->getTotalReqData(), it->getTotalReqSize()); | |
if (errNoIsWouldBlock()) { | |
// failed write | |
_metrics.blockedWrites++; | |
++it; | |
} else { | |
it = conn.pendingToBeSent.erase(it); | |
} | |
} | |
} | |
} | |
} | |
uint64_t getAndIncrementReplId() { return ++_lastRepId; } | |
void replicateToAllMembers(const std::optional<std::string>& job) { | |
if (!_isCaptain) return; | |
ReqBuilder build; | |
if (job.has_value()) { | |
build.pushU8(OPER_REPL_ENQUEUE); | |
build.pushU64(getAndIncrementReplId()); | |
build.pushStr(job.value()); | |
} else { | |
build.pushU8(OPER_REPL_DEQUEUE); | |
build.pushU64(getAndIncrementReplId()); | |
} | |
build.putSize(); | |
// Send to all members | |
for (auto& [ key, conn ] : _members) { | |
write(conn.fd, build.getTotalReqData(), build.getTotalReqSize()); | |
if (errNoIsWouldBlock()) { | |
++_metrics.blockedWrites; | |
conn.pendingToBeSent.push_back(build); | |
} | |
} | |
} | |
static std::vector<ClusterNode> createCluster(const std::vector<std::string>& clusterMembers, | |
const std::string& firstCaptain, | |
size_t queueSize = 50000) { | |
std::vector<ClusterNode> nodes; | |
for (const auto& memb: clusterMembers) { | |
ClusterNode n(memb, queueSize); | |
if (memb == firstCaptain) { | |
n._isCaptain = true; | |
} | |
n._currentCaptain = firstCaptain; | |
nodes.push_back(n); | |
} | |
for (auto& node: nodes) { | |
node.start(); | |
} | |
/// now connect | |
for (size_t i = 0; i < nodes.size() - 1; ++i) | |
{ | |
for (size_t j = i + 1; j < nodes.size(); ++j) | |
{ | |
nodes[i].connectWith(nodes[j]._path); | |
nodes[j].doWork(); | |
} | |
} | |
return nodes; | |
} | |
void start() { | |
std::optional<int> fd = create_server_socket(_path.c_str()); | |
if (fd.has_value()) { | |
_server = *fd; | |
assert(set_non_blocking(_server, true)); | |
_isStarted = true; | |
} else { | |
fprintf(stderr, "Cannot create server socket, errno=%d\n", getAndResetErrNo()); | |
exit(2); | |
} | |
} | |
const std::string& getPath() const { return _path; } | |
int getServerFd() const { return _server; } | |
void collectRequests() { | |
assert(_isStarted); | |
struct sockaddr_un remote; | |
unsigned int sock_len = 0; | |
int incoming = accept(_server, (struct sockaddr*)&remote, &sock_len); | |
// todo make loop | |
if( incoming == -1 ) { | |
if(errNoIsWouldBlock()) { | |
resetErrNo(); | |
} else { | |
exitAndErrorNo("Could not bind socket"); | |
} | |
} else { | |
set_non_blocking(incoming, true); | |
_pending.push_back(incoming); | |
} | |
std::vector<int> toKeep; | |
// manual polling for now | |
for (size_t i = 0; i < _pending.size(); ++i) { | |
std::vector<unsigned char> bytes; | |
uint32_t req_size = 0; | |
read(_pending[i], &req_size, sizeof(req_size)); | |
if (errNoIsWouldBlock()) { | |
resetErrNo(); | |
toKeep.push_back(_pending[i]); | |
continue; | |
} | |
bytes.resize(req_size); | |
read(_pending[i], bytes.data(), req_size); | |
if (errNoIsWouldBlock()) { | |
exitAndErrorNo("Unexpected lack of body of request on pending"); | |
} | |
std::vector<ReqItem> reqItems; | |
buildFromBytes(reqItems, bytes.data(), bytes.size()); | |
ClusterRequest req; | |
req.conn = std::make_optional<int>(_pending[i]); | |
req.req = reqItems; | |
_requests.push_back(req); | |
} | |
_pending = toKeep; | |
for (const auto& [ key, conn ] : _members) { | |
while (readRequestFromMember(key, conn)) { | |
// todo some timing , but why not keep reading? | |
} | |
} | |
} | |
bool readRequestFromMember(const std::string& name, const Member& mem) { | |
std::vector<unsigned char> bytes; | |
uint32_t req_size = 0; | |
read(mem.fd, &req_size, sizeof(req_size)); | |
if (errNoIsWouldBlock()) { | |
resetErrNo(); | |
return false; | |
} | |
bytes.resize(req_size); | |
read(mem.fd, bytes.data(), req_size); | |
if (errNoIsWouldBlock()) { | |
exitAndErrorNo("Unexpected lack of body of request on formed member"); | |
} | |
std::vector<ReqItem> reqItems; | |
buildFromBytes(reqItems, bytes.data(), bytes.size()); | |
ClusterRequest req; | |
req.sender = std::make_optional<std::string>(name); | |
req.req = reqItems; | |
_requests.push_back(req); | |
return true; | |
} | |
void listMembersToVec(std::vector<std::string>& membs) { | |
for (const auto& [ key, conn ] : _members) { | |
membs.push_back(key); | |
} | |
} | |
void processConnectRequest(const ClusterRequest& req) { | |
const std::string sender = std::get<std::string>(req.req[1]); | |
if (!req.conn.has_value()) { | |
fprintf(stderr, "Unable to get socket fd for member=%s\n", sender.c_str()); | |
exit(3); | |
} | |
auto found = _members.find(sender); | |
if (found != _members.end()) { | |
printf("Got reconnect for member=%s\n", sender.c_str()); | |
close(found->second.fd); | |
found->second.fd = req.conn.value(); | |
return; | |
} | |
Member m; | |
m.fd = req.conn.value(); | |
_members[sender] = m; | |
} | |
void processEnqueueRespRequest(const ClusterRequest& req) { | |
if (_isCaptain) { | |
//error | |
return; | |
} | |
if (std::get_if<uint64_t>(&req.req[1]) == nullptr) { | |
//error bad request!! | |
return; | |
} | |
uint64_t id = std::get<uint64_t>(req.req[1]); | |
if (_pendingEnqueues.count(id) < 1) { | |
// error bad request!! | |
return; | |
} | |
_pendingEnqueues.erase(id); | |
return; | |
} | |
void processEnqueueRequest(const ClusterRequest& req) { | |
if (!_isCaptain) { | |
//error | |
return; | |
} | |
if (!req.sender.has_value()) { | |
fprintf(stderr, "Unexpected enqueue from unknown member!"); | |
return; | |
} | |
auto found = _members.find(req.sender.value()); | |
if (found == _members.end()) { | |
fprintf(stderr, "Enqueue from unknown member=%s\n", req.sender.value().c_str()); | |
return; | |
} | |
std::optional<std::string> job = std::make_optional<std::string>(std::get<std::string>(req.req[2])); | |
_diskQueue->push(job.value()); | |
ReqBuilder build; | |
build.pushU8(OPER_ENQUEUE_RESP); | |
build.pushU64(std::get<uint64_t>(req.req[1])); | |
build.putSize(); | |
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize()); | |
if (errNoIsWouldBlock()) { | |
resetErrNo(); | |
found->second.pendingToBeSent.push_back(build); | |
++_metrics.blockedWrites; | |
return; | |
} | |
replicateToAllMembers(job); | |
} | |
void processDequeueRespRequest(const ClusterRequest& req) { | |
if (_isCaptain) { | |
//error | |
return; | |
} | |
if (std::get_if<uint64_t>(&req.req[1]) == nullptr) { | |
//error bad request!! | |
return; | |
} | |
uint64_t id = std::get<uint64_t>(req.req[1]); | |
auto found = _pendingDequeues.find(id); | |
if (found == _pendingDequeues.end()) { | |
fprintf(stderr, "Unknown id=%llu for Dequeue\n", id); | |
return; | |
} | |
const unsigned char result = std::get<unsigned char>(req.req[2]); | |
if (result) { | |
found->second.state = DequeueState::eSuccess; | |
found->second.job = std::get<std::string>(req.req[3]); | |
} else { | |
found->second.state = DequeueState::eWasEmpty; | |
} | |
return; | |
} | |
void processDequeueRequest(const ClusterRequest& req) { | |
if (!_isCaptain) { | |
//error | |
return; | |
} | |
if (!req.sender.has_value()) { | |
fprintf(stderr, "Unexpected dequeue from unknown member!"); | |
return; | |
} | |
bool doReplication = false; | |
auto found = _members.find(req.sender.value()); | |
if (found == _members.end()) { | |
fprintf(stderr, "Dequeue from unknown member=%s\n", req.sender.value().c_str()); | |
return; | |
} | |
ReqBuilder build; | |
build.pushU8(OPER_DEQUEUE_RESP); | |
build.pushU64(std::get<uint64_t>(req.req[1])); | |
std::string popped; | |
if (_diskQueue->pop(popped)) { | |
build.pushU8(1); // Queue had item | |
build.pushStr(popped); | |
doReplication = true; | |
} else { | |
build.pushU8(0); // Queue was empty. | |
} | |
build.putSize(); | |
write(found->second.fd, build.getTotalReqData(), build.getTotalReqSize()); | |
if (errNoIsWouldBlock()) { | |
resetErrNo(); | |
++_metrics.blockedWrites; | |
found->second.pendingToBeSent.push_back(build); | |
return; | |
} | |
if (doReplication) { | |
replicateToAllMembers(std::nullopt); | |
} | |
} | |
void processReplDequeueRequest(const ClusterRequest& req) { | |
if (!req.sender.has_value()) { | |
fprintf(stderr, "Got replication from unknown member!\n"); | |
return; | |
} | |
if (req.sender.value() != _currentCaptain) { | |
fprintf(stderr, "Got replication from non-captain member=%s\n", req.sender.value().c_str()); | |
return; | |
} | |
_lastRepId = std::get<uint64_t>(req.req[1]); | |
std::string popped; | |
(void)_diskQueue->pop(popped); | |
//printf("Processed replication for %llu from %s\n", _lastRepId, | |
//req.sender.value().c_str()); | |
} | |
void processReplEnqueueRequest(const ClusterRequest& req) { | |
if (!req.sender.has_value()) { | |
fprintf(stderr, "Got replication from unknown member!\n"); | |
return; | |
} | |
if (req.sender.value() != _currentCaptain) { | |
fprintf(stderr, "Got replication from non-captain member=%s\n", req.sender.value().c_str()); | |
return; | |
} | |
_lastRepId = std::get<uint64_t>(req.req[1]); | |
const std::string gotJob = std::get<std::string>(req.req[2]); | |
_diskQueue->push(gotJob); | |
//printf("Processed replication for %llu from %s\n", _lastRepId, | |
//req.sender.value().c_str()); | |
} | |
void processRequests() { | |
for (const auto& req: _requests) { | |
// todo error handle | |
const unsigned char opCode = std::get<unsigned char>(req.req[0]); | |
switch (opCode) { | |
case OPER_CONNECT: | |
processConnectRequest(req); | |
break; | |
case OPER_ENQUEUE: | |
processEnqueueRequest(req); | |
break; | |
case OPER_ENQUEUE_RESP: | |
processEnqueueRespRequest(req); | |
break; | |
case OPER_DEQUEUE: | |
processDequeueRequest(req); | |
break; | |
case OPER_DEQUEUE_RESP: | |
processDequeueRespRequest(req); | |
break; | |
case OPER_REPL_ENQUEUE: | |
processReplEnqueueRequest(req); | |
break; | |
case OPER_REPL_DEQUEUE: | |
processReplDequeueRequest(req); | |
break; | |
default: | |
fprintf(stderr, "Unknown request code %u\n", opCode); | |
exit(3); | |
} | |
} | |
_requests.clear(); | |
} | |
void doWork() { | |
collectRequests(); | |
processRequests(); | |
checkAndRetryWrites(); | |
// check state | |
} | |
const OperationMetrics& getMetrics() const { return _metrics; } | |
private: | |
friend class ClusterNodeTests; | |
bool _isStarted = false; | |
int _server = -1; | |
bool _isCaptain = false; | |
bool _removeQueueFile = false; | |
uint64_t _lastRepId = 0; | |
std::string _path; | |
std::string _currentCaptain; | |
std::vector<int> _pending; | |
std::unordered_set<uint64_t> _pendingEnqueues; | |
DequeueMap _pendingDequeues; | |
std::vector<ClusterRequest> _requests; | |
std::unordered_map<std::string, Member> _members; | |
std::optional<ElectionMap> _electionMap; | |
std::optional<DiskBackedStringQueue> _diskQueue; | |
OperationMetrics _metrics; | |
}; | |
//------- tests --------- | |
// Running tallies for the minimal test harness below.
static unsigned _failures = 0;
static unsigned _test_passes = 0;

/**
 * Records a single test check. A non-zero `cond` counts as a pass. Otherwise
 * the expression text and line number go to stderr and the check counts as a
 * failure.
 */
static void check_cond(int cond, const char* condstr, unsigned line) {
    if (cond) {
        ++_test_passes;
        return;
    }
    fprintf(stderr, "Failed cond '%s' at line %u\n", condstr, line);
    ++_failures;
}

/// Evaluates `cnd` and records the outcome together with its source text and line.
#define CHECKIT(cnd) check_cond((cnd), #cnd, __LINE__)
static void test_reqBuilder() { | |
ReqBuilder b; | |
b.pushU8(3); | |
std::string foo = "foo"; | |
std::vector<std::string> foos = {"abc", "def"}; | |
b.pushStr(foo); | |
b.pushStrList(foos); | |
b.putSize(); | |
std::vector<ReqItem> elems; | |
buildFromBytes(elems, b.getReqData(), b.getReqSize()); | |
CHECKIT(elems.size() == 3); | |
CHECKIT(std::get_if<unsigned char>(&elems[0]) != nullptr); | |
CHECKIT(std::get_if<std::string>(&elems[1]) != nullptr); | |
CHECKIT(std::get_if<std::vector<std::string>>(&elems[2]) != nullptr); | |
} | |
/**
 * White-box tests for ClusterNode and DiskBackedStringQueue. Declared a
 * friend of ClusterNode so the tests can inspect its private state.
 */
class ClusterNodeTests {
public:
    static void testDiskQueue();      // persistence of the mmap-backed queue
    static void testCreateCluster();  // full-mesh connection setup
    static void testEnqueueProxy();   // proxied enqueue + replication
    static void testDequeueProxy();   // proxied dequeue round trip
};
void ClusterNodeTests::testCreateCluster() { | |
std::vector<std::string> membs = {"foo", "bar", "sam"}; | |
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]); | |
for (auto& memb : cluster) { | |
memb.setRemoveQueueFile(true); | |
} | |
CHECKIT(cluster[0]._members.size() == 2); | |
CHECKIT(cluster[1]._members.size() == 2); | |
CHECKIT(cluster[2]._members.size() == 2); | |
CHECKIT(cluster[0]._isCaptain == true); | |
CHECKIT(cluster[0]._members.find(cluster[1]._path) != cluster[0]._members.end()); | |
CHECKIT(cluster[0]._members.find(cluster[2]._path) != cluster[0]._members.end()); | |
CHECKIT(cluster[1]._members.find(cluster[0]._path) != cluster[1]._members.end()); | |
CHECKIT(cluster[1]._members.find(cluster[2]._path) != cluster[1]._members.end()); | |
CHECKIT(cluster[2]._members.find(cluster[0]._path) != cluster[2]._members.end()); | |
CHECKIT(cluster[2]._members.find(cluster[1]._path) != cluster[2]._members.end()); | |
} | |
void ClusterNodeTests::testEnqueueProxy() { | |
printf("Enqueue data starting\n"); | |
std::vector<std::string> membs = {"foo", "bar", "sam"}; | |
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]); | |
for (auto& memb : cluster) { | |
memb.setRemoveQueueFile(true); | |
} | |
const std::string job = "job"; | |
CHECKIT(cluster[0]._isCaptain == true); | |
CHECKIT(cluster[0]._diskQueue->usedSpace() == 0); | |
CHECKIT(cluster[0].enqueue(job).state == EnqueueState::eSuccess); | |
CHECKIT(cluster[0]._diskQueue->usedSpace() == 7); | |
CHECKIT(cluster[0]._diskQueue->hasItems()); | |
// Replication check | |
cluster[1].doWork(); | |
cluster[2].doWork(); | |
CHECKIT(cluster[0]._lastRepId == 1); | |
CHECKIT(cluster[1]._lastRepId == cluster[0]._lastRepId); | |
CHECKIT(cluster[2]._lastRepId == cluster[0]._lastRepId); | |
EnqueueResult jobid = cluster[1].enqueue(job); | |
CHECKIT(jobid.id.has_value()); | |
CHECKIT(cluster[1]._pendingEnqueues.count(jobid.id.value()) == 1); | |
cluster[0].doWork(); | |
CHECKIT(cluster[0]._diskQueue->usedSpace() == 14); | |
cluster[1].doWork(); | |
cluster[2].doWork(); | |
CHECKIT(cluster[1]._pendingEnqueues.size() == 0); | |
CHECKIT(cluster[1]._lastRepId == cluster[0]._lastRepId); | |
CHECKIT(cluster[2]._lastRepId == cluster[0]._lastRepId); | |
} | |
void ClusterNodeTests::testDequeueProxy() { | |
std::vector<std::string> membs = {"foo", "bar", "sam"}; | |
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]); | |
for (auto& memb : cluster) { | |
memb.setRemoveQueueFile(true); | |
} | |
const std::string job = "job"; | |
CHECKIT(cluster[0]._isCaptain == true); | |
CHECKIT(cluster[0].enqueue(job).state == EnqueueState::eSuccess); | |
CHECKIT(cluster[0]._diskQueue->usedSpace() == 7); | |
CHECKIT(cluster[0]._diskQueue->hasItems()); | |
DequeueResult res = cluster[1].dequeue(); | |
CHECKIT(res.state == DequeueState::eInProgress); | |
CHECKIT(res.id.has_value()); | |
cluster[0].doWork(); | |
CHECKIT(cluster[0]._diskQueue->usedSpace() == 0); | |
CHECKIT(cluster[1].getAndPossiblyClearDequeue(res.id.value()) == std::nullopt); | |
cluster[1].doWork(); | |
CHECKIT(cluster[1].getAndPossiblyClearDequeue(res.id.value()).value() == job); | |
} | |
void ClusterNodeTests::testDiskQueue() { | |
static constexpr size_t maxSize = 500; | |
static const std::string jobSample = "{a job}"; | |
static const std::string qPath = "qpath.dat"; | |
{ | |
DiskBackedStringQueue q(qPath, maxSize); | |
CHECKIT(q.usedSpace() == 0); | |
CHECKIT(q.push(jobSample)); | |
std::string result; | |
CHECKIT(q.pop(result)); | |
CHECKIT(result == jobSample); | |
CHECKIT(!q.pop(result)); | |
CHECKIT(!q.hasItems()); | |
CHECKIT(q.push(jobSample)); | |
q.cleanup(); | |
} | |
{ | |
DiskBackedStringQueue q(qPath, maxSize); | |
std::string result; | |
CHECKIT(q.hasItems()); | |
CHECKIT(q.pop(result)); | |
CHECKIT(result == jobSample); | |
q.cleanup(); | |
q.deleteFile(); | |
} | |
} | |
// Namespace-like holder for the throughput benchmarks run via "perf".
struct PerformanceTests {
    // One captain producing while two members consume via proxied dequeues.
    static void oneProducerTwoConsumer();
};
void PerformanceTests::oneProducerTwoConsumer() { | |
puts("Starting Perf Tests!"); | |
std::vector<std::string> membs = {"foo", "bar", "sam"}; | |
std::vector<ClusterNode> cluster = ClusterNode::createCluster(membs, membs[0]); | |
for (auto& memb : cluster) { | |
memb.setRemoveQueueFile(true); | |
} | |
std::atomic<bool> keepWorking(false); | |
auto start = std::chrono::high_resolution_clock::now(); | |
std::thread captainThread = std::thread([&](){ | |
const std::string myJob = "{the job to produce}"; | |
size_t jobsQueued = 0; | |
while(!keepWorking.load()); | |
while(keepWorking) { | |
cluster[0].enqueue(myJob); | |
cluster[0].doWork(); | |
++jobsQueued; | |
//std::this_thread::sleep_for(std::chrono::milliseconds(10)); | |
} | |
printf("The captain enqueued %zu jobs\n", jobsQueued); | |
printf("The captain had %zu blocked writes\n", cluster[0].getMetrics().blockedWrites); | |
}); | |
std::thread member1Thread = std::thread([&](){ | |
std::vector<uint64_t> deqRequests; | |
size_t jobsGot = 0; | |
while(!keepWorking.load()); | |
while (keepWorking.load()) { | |
for (auto it = deqRequests.begin(); it != deqRequests.end(); ) { | |
std::optional<std::string> got = cluster[1].getAndPossiblyClearDequeue(*it); | |
if (got.has_value()) { | |
it = deqRequests.erase(it); | |
++jobsGot; | |
} else { | |
++it; | |
} | |
} | |
DequeueResult r = cluster[1].dequeue(); | |
if (r.state == DequeueState::eInProgress) { | |
deqRequests.push_back(r.id.value()); | |
} | |
cluster[1].doWork(); | |
} | |
printf("The member got %zu jobs\n", jobsGot); | |
printf("The member had %zu blocked writes\n", cluster[1].getMetrics().blockedWrites); | |
}); | |
std::thread member2Thread = std::thread([&](){ | |
std::vector<uint64_t> deqRequests; | |
size_t jobsGot = 0; | |
while(!keepWorking.load()); | |
while (keepWorking.load()) { | |
for (auto it = deqRequests.begin(); it != deqRequests.end(); ) { | |
std::optional<std::string> got = cluster[2].getAndPossiblyClearDequeue(*it); | |
if (got.has_value()) { | |
it = deqRequests.erase(it); | |
++jobsGot; | |
} else { | |
++it; | |
} | |
} | |
DequeueResult r = cluster[2].dequeue(); | |
if (r.state == DequeueState::eInProgress) { | |
deqRequests.push_back(r.id.value()); | |
} | |
cluster[2].doWork(); | |
} | |
printf("The member got %zu jobs\n", jobsGot); | |
printf("The member had %zu blocked writes\n", cluster[2].getMetrics().blockedWrites); | |
}); | |
keepWorking.store(true); | |
std::this_thread::sleep_for(std::chrono::seconds(1)); | |
keepWorking.store(false); | |
captainThread.join(); | |
member1Thread.join(); | |
member2Thread.join(); | |
auto end = std::chrono::high_resolution_clock::now(); | |
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start); | |
printf("Test took %lld microseconds\n", duration.count()); | |
} | |
int main(int argc, char const *argv[]) | |
{ | |
srand(time(nullptr)); | |
if (argc == 2 && strcmp(argv[1], "tests") == 0) { | |
test_reqBuilder(); | |
ClusterNodeTests::testCreateCluster(); | |
ClusterNodeTests::testDiskQueue(); | |
ClusterNodeTests::testEnqueueProxy(); | |
ClusterNodeTests::testDequeueProxy(); | |
if (_failures > 0) { | |
fprintf(stderr, "Total failures %u\n", _failures); | |
} | |
printf("Total tests passed: %u\n", _test_passes); | |
} else if (argc == 2 && strcmp(argv[1], "perf") == 0) { | |
PerformanceTests::oneProducerTwoConsumer(); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment