Skip to content

Instantly share code, notes, and snippets.

@jerrylususu
Last active August 28, 2024 15:57
Show Gist options
  • Save jerrylususu/43b07e5849d14d68a72e8f99ed3f4c80 to your computer and use it in GitHub Desktop.
Save jerrylususu/43b07e5849d14d68a72e8f99ed3f4c80 to your computer and use it in GitHub Desktop.
Custom TCP protocol client with a connection pool (generated by Claude)
// MyProtocol.h
#pragma once
#include <array>
#include <vector>
#include <string>
#include <memory>
#include <chrono>
#include <mutex>
#include <queue>
#include <cstring>
#include <cstdint>
#include <arpa/inet.h>
#include <sys/epoll.h>
// First four bytes of every frame header, sent in network byte order; a
// received header carrying any other value is rejected.
constexpr std::uint32_t MAGIC_NUMBER = 0x12345678u;
// Result of every Connection / ConnectionPool / MyProtocolClient operation.
// SUCCESS is zero; every other value identifies the failing step.
enum class ErrorCode {
    SUCCESS = 0,
    CONNECTION_FAILED = 1,         // socket/connect failed, or not connected
    SEND_FAILED = 2,               // writing the request frame failed
    RECEIVE_FAILED = 3,            // reading the response frame failed
    TIMEOUT = 4,                   // no response within the caller's timeout
    INVALID_MAGIC_NUMBER = 5,      // response header did not start with MAGIC_NUMBER
    NO_AVAILABLE_CONNECTIONS = 6,  // the pool had no idle connection
    EPOLL_ERROR = 7                // waiting for the response failed
};
// Outgoing message body; Connection::send prefixes it with the frame header.
struct Request {
    std::vector<std::uint8_t> payload;  // raw bytes sent after the 8-byte header
};
// Incoming message body; Connection::receive fills it after validating the header.
struct Response {
    std::vector<std::uint8_t> payload;  // raw bytes that followed the 8-byte header
};
// One TCP connection to host_:port_ that exchanges frames made of an 8-byte
// header (MAGIC_NUMBER + payload length) followed by the payload.
class Connection {
public:
    Connection(const std::string& host, int port);
    ~Connection();

    // The destructor closes sockfd_. A copy would close the same descriptor a
    // second time, possibly after the number was reused for another file.
    // Connections are shared through std::shared_ptr instead.
    Connection(const Connection&) = delete;
    Connection& operator=(const Connection&) = delete;

    ErrorCode connect();
    void disconnect();
    ErrorCode send(const Request& req);
    ErrorCode receive(Response& resp);
    int getSockfd() const { return sockfd_; }

private:
    int sockfd_ = -1;         // socket descriptor; -1 until a socket is created
    std::string host_;        // IPv4 address text (parsed with inet_pton in connect())
    int port_;
    bool connected_ = false;  // true only after a successful connect()
};
// Fixed set of Connections to one host:port, opened eagerly by the constructor.
// getConnection() and releaseConnection() both lock mutex_, so they may be
// called from several threads.
class ConnectionPool {
public:
    // Tries to open pool_size connections to host:port; any that fail to
    // connect are not added, so the pool may hold fewer.
    ConnectionPool(const std::string& host, int port, size_t pool_size);

    // Removes and returns an idle connection. On an empty pool, sets ec to
    // NO_AVAILABLE_CONNECTIONS and returns nullptr.
    std::shared_ptr<Connection> getConnection(ErrorCode& ec);

    // Puts a connection obtained from getConnection() back into the pool.
    void releaseConnection(std::shared_ptr<Connection> conn);

private:
    std::string host_;
    int port_;
    std::queue<std::shared_ptr<Connection>> pool_;  // idle connections, guarded by mutex_
    std::mutex mutex_;
};
// Client that sends one framed Request and waits for one Response, using a
// connection borrowed from its own ConnectionPool.
class MyProtocolClient {
public:
    MyProtocolClient(const std::string& host, int port, size_t pool_size = 5);
    ~MyProtocolClient();

    // The destructor closes epoll_fd_. Copying is already impossible, because
    // ConnectionPool holds a std::mutex; these lines only make that explicit.
    MyProtocolClient(const MyProtocolClient&) = delete;
    MyProtocolClient& operator=(const MyProtocolClient&) = delete;

    // Sends req and stores the reply in resp, waiting up to `timeout`.
    ErrorCode sendAndReceive(const Request& req, Response& resp,
                             std::chrono::milliseconds timeout = std::chrono::seconds(30));

private:
    ConnectionPool pool_;
    int epoll_fd_;  // from epoll_create1(); -1 if creation failed
};
// MyProtocol.cpp
#include "MyProtocol.h"

#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <climits>
#include <utility>
// Stores the endpoint only; no socket exists until connect() is called.
// The initializer list follows the member declaration order
// (sockfd_, host_, port_, connected_), which also silences -Wreorder.
Connection::Connection(const std::string& host, int port)
    : sockfd_(-1), host_(host), port_(port), connected_(false) {}
// Releases the socket through disconnect() when the Connection goes away.
Connection::~Connection()
{
    this->disconnect();
}
// Opens a TCP connection to host_:port_. host_ must be a dotted IPv4 address,
// because it is parsed with inet_pton(AF_INET).
// Any socket left from an earlier connect() is closed first. On failure no
// descriptor is leaked and the object stays disconnected.
// Returns SUCCESS or CONNECTION_FAILED.
ErrorCode Connection::connect() {
    disconnect();
    sockfd_ = -1;

    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(static_cast<uint16_t>(port_));
    // Parse the address before creating a socket, so a bad host cannot leak one.
    if (inet_pton(AF_INET, host_.c_str(), &server_addr.sin_addr) != 1) {
        return ErrorCode::CONNECTION_FAILED;
    }

    const int fd = ::socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        return ErrorCode::CONNECTION_FAILED;
    }
    if (::connect(fd, reinterpret_cast<const sockaddr*>(&server_addr), sizeof(server_addr)) < 0) {
        ::close(fd);  // previously leaked: connected_ stayed false, so disconnect() never closed it
        return ErrorCode::CONNECTION_FAILED;
    }

    sockfd_ = fd;
    connected_ = true;
    return ErrorCode::SUCCESS;
}
// Closes the socket if one is open and marks the connection as disconnected.
// The check is on sockfd_ rather than connected_, so a socket that was
// created but never finished connecting is also closed. sockfd_ is reset to
// -1, so getSockfd() never reports a closed (and possibly reused) descriptor.
// Calling this repeatedly is harmless.
void Connection::disconnect() {
    if (sockfd_ != -1) {
        ::close(sockfd_);
        sockfd_ = -1;
    }
    connected_ = false;
}
// Writes all `len` bytes from `data` to `fd`, looping over partial writes.
// - EINTR is retried.
// - EAGAIN/EWOULDBLOCK (a non-blocking socket whose send buffer is full)
//   waits for writability with poll() and no timeout.
// - MSG_NOSIGNAL makes a closed peer fail with EPIPE instead of raising SIGPIPE.
// Returns false on any other error.
static bool sendAll(int fd, const uint8_t* data, size_t len) {
    size_t sent = 0;
    while (sent < len) {
        const ssize_t n = ::send(fd, data + sent, len - sent, MSG_NOSIGNAL);
        if (n > 0) {
            sent += static_cast<size_t>(n);
        } else if (n < 0 && errno == EINTR) {
            continue;
        } else if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            pollfd pfd{};
            pfd.fd = fd;
            pfd.events = POLLOUT;
            if (::poll(&pfd, 1, -1) < 0 && errno != EINTR) {
                return false;
            }
        } else {
            return false;
        }
    }
    return true;
}

// Sends one frame: an 8-byte header (MAGIC_NUMBER, then the payload length,
// both 32-bit network byte order) followed by the payload bytes.
// Returns:
// - CONNECTION_FAILED when not connected;
// - SEND_FAILED when the payload does not fit the 32-bit length field, or
//   when a write fails.
// After a failed write the connection is closed, because a half-written
// frame would corrupt every later exchange on this stream.
ErrorCode Connection::send(const Request& req) {
    if (!connected_) {
        return ErrorCode::CONNECTION_FAILED;
    }
    if (req.payload.size() > UINT32_MAX) {
        return ErrorCode::SEND_FAILED;  // length would be silently truncated
    }

    std::array<uint8_t, 8> header;
    const uint32_t magic = htonl(MAGIC_NUMBER);
    const uint32_t length = htonl(static_cast<uint32_t>(req.payload.size()));
    std::memcpy(header.data(), &magic, sizeof(magic));
    std::memcpy(header.data() + sizeof(magic), &length, sizeof(length));

    if (!sendAll(sockfd_, header.data(), header.size()) ||
        !sendAll(sockfd_, req.payload.data(), req.payload.size())) {
        disconnect();
        return ErrorCode::SEND_FAILED;
    }
    return ErrorCode::SUCCESS;
}
// Reads exactly `len` bytes from `fd` into `data`.
// - EINTR is retried.
// - EAGAIN/EWOULDBLOCK (a non-blocking socket with no data yet) waits for
//   readability with poll() and no timeout.
// Returns false on EOF (peer closed) or any other error.
static bool recvAll(int fd, uint8_t* data, size_t len) {
    size_t got = 0;
    while (got < len) {
        const ssize_t n = ::recv(fd, data + got, len - got, 0);
        if (n > 0) {
            got += static_cast<size_t>(n);
        } else if (n == 0) {
            return false;
        } else if (errno == EINTR) {
            continue;
        } else if (errno == EAGAIN || errno == EWOULDBLOCK) {
            pollfd pfd{};
            pfd.fd = fd;
            pfd.events = POLLIN;
            if (::poll(&pfd, 1, -1) < 0 && errno != EINTR) {
                return false;
            }
        } else {
            return false;
        }
    }
    return true;
}

// Reads one frame into resp.payload. The 8-byte header holds MAGIC_NUMBER and
// the payload length, both in network byte order.
// A short read is retried until the whole header/payload has arrived.
// Returns:
// - CONNECTION_FAILED when not connected;
// - RECEIVE_FAILED on EOF or a read error;
// - INVALID_MAGIC_NUMBER when the header's first word is not MAGIC_NUMBER.
// On every failure the connection is closed, because the stream position no
// longer lines up with a frame boundary.
ErrorCode Connection::receive(Response& resp) {
    if (!connected_) {
        return ErrorCode::CONNECTION_FAILED;
    }

    std::array<uint8_t, 8> header;
    if (!recvAll(sockfd_, header.data(), header.size())) {
        disconnect();
        return ErrorCode::RECEIVE_FAILED;
    }

    uint32_t received_magic;
    uint32_t payload_length;
    std::memcpy(&received_magic, header.data(), sizeof(received_magic));
    std::memcpy(&payload_length, header.data() + sizeof(received_magic), sizeof(payload_length));
    received_magic = ntohl(received_magic);
    payload_length = ntohl(payload_length);

    if (received_magic != MAGIC_NUMBER) {
        disconnect();
        return ErrorCode::INVALID_MAGIC_NUMBER;
    }

    // NOTE(review): payload_length comes from the peer and is not bounded, so a
    // buggy or hostile server can force an allocation of up to 4 GiB here.
    // Consider a protocol-level maximum.
    resp.payload.resize(payload_length);
    if (!recvAll(sockfd_, resp.payload.data(), payload_length)) {
        disconnect();
        return ErrorCode::RECEIVE_FAILED;
    }
    return ErrorCode::SUCCESS;
}
// Eagerly opens pool_size connections to host:port. A connection whose
// connect() does not return SUCCESS is discarded, so the pool can start out
// with fewer than pool_size entries (possibly none).
ConnectionPool::ConnectionPool(const std::string& host, int port, size_t pool_size)
    : host_(host), port_(port) {
    for (size_t remaining = pool_size; remaining > 0; --remaining) {
        std::shared_ptr<Connection> candidate = std::make_shared<Connection>(host, port);
        if (candidate->connect() != ErrorCode::SUCCESS) {
            continue;
        }
        pool_.push(candidate);
    }
}
// Takes the oldest idle connection off the queue while holding mutex_.
// - Success: sets ec to SUCCESS and returns the connection.
// - Empty queue (everything checked out, or nothing could be opened): sets ec
//   to NO_AVAILABLE_CONNECTIONS and returns nullptr.
std::shared_ptr<Connection> ConnectionPool::getConnection(ErrorCode& ec) {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!pool_.empty()) {
        std::shared_ptr<Connection> taken = pool_.front();
        pool_.pop();
        ec = ErrorCode::SUCCESS;
        return taken;
    }
    ec = ErrorCode::NO_AVAILABLE_CONNECTIONS;
    return nullptr;
}
// Puts a connection obtained from getConnection() at the back of the queue.
// A null pointer is ignored. Queuing it would make a later getConnection()
// report SUCCESS while returning nullptr.
void ConnectionPool::releaseConnection(std::shared_ptr<Connection> conn) {
    if (!conn) {
        return;
    }
    std::lock_guard<std::mutex> lock(mutex_);
    // conn is a sink parameter: moving avoids an extra atomic refcount bump.
    pool_.push(std::move(conn));
}
// Builds the connection pool, then creates the epoll instance held in epoll_fd_.
// If epoll_create1() fails, epoll_fd_ is simply left at -1 and nothing is
// reported from here; the destructor skips closing it in that case.
MyProtocolClient::MyProtocolClient(const std::string& host, int port, size_t pool_size)
    : pool_(host, port, pool_size),
      epoll_fd_(epoll_create1(0)) {
}
// Closes the epoll descriptor, if the constructor managed to create one.
MyProtocolClient::~MyProtocolClient() {
    if (epoll_fd_ == -1) {
        return;
    }
    close(epoll_fd_);
}
// Waits with poll() until `fd` is readable, hung up, or in error, for at most
// `timeout`.
// - A negative timeout waits indefinitely (epoll_wait() treats -1 the same way).
// - Budgets longer than INT_MAX ms are capped, because poll() takes an int.
// - Interrupted waits (EINTR) are retried with whatever budget remains.
// Returns SUCCESS, TIMEOUT, or EPOLL_ERROR.
static ErrorCode waitReadable(int fd, std::chrono::milliseconds timeout) {
    using Clock = std::chrono::steady_clock;
    const bool forever = timeout.count() < 0;
    const std::chrono::milliseconds budget = std::min(timeout, std::chrono::milliseconds(INT_MAX));
    const Clock::time_point deadline = Clock::now() + budget;
    for (;;) {
        int wait_ms = -1;
        if (!forever) {
            const auto left = std::chrono::duration_cast<std::chrono::milliseconds>(deadline - Clock::now());
            wait_ms = static_cast<int>(std::max<long long>(0, left.count()));
        }
        pollfd pfd{};
        pfd.fd = fd;
        pfd.events = POLLIN;
        const int ready = ::poll(&pfd, 1, wait_ms);
        if (ready > 0) {
            return ErrorCode::SUCCESS;
        }
        if (ready == 0) {
            return ErrorCode::TIMEOUT;
        }
        if (errno != EINTR) {
            return ErrorCode::EPOLL_ERROR;
        }
    }
}

// Sends `req` on a pooled connection and reads one response frame into `resp`.
// `timeout` bounds the wait for the response to start arriving.
// Returns:
// - the pool's error when no connection is available;
// - otherwise the first failure from send, wait (TIMEOUT / EPOLL_ERROR), or
//   receive; SUCCESS if all three succeed.
//
// This function does not switch the socket to O_NONBLOCK, so send() and
// receive() transfer whole frames. NOTE(review): once the first bytes arrive,
// receive() can still block past `timeout` if the peer stalls mid-frame.
//
// The wait uses poll() on this one socket instead of the shared epoll_fd_.
// When several threads call epoll_wait() on one epoll set, a thread can be
// woken by another thread's socket and then read its own socket before any
// reply has arrived.
ErrorCode MyProtocolClient::sendAndReceive(const Request& req, Response& resp, std::chrono::milliseconds timeout) {
    ErrorCode ec;
    std::shared_ptr<Connection> conn = pool_.getConnection(ec);
    if (ec != ErrorCode::SUCCESS) {
        return ec;
    }

    ec = conn->send(req);
    if (ec == ErrorCode::SUCCESS) {
        ec = waitReadable(conn->getSockfd(), timeout);
    }
    if (ec == ErrorCode::SUCCESS) {
        ec = conn->receive(resp);
    }

    if (ec != ErrorCode::SUCCESS) {
        // After a failed or timed-out exchange the stream may hold a partial
        // frame or a late reply to this request; reusing it would hand that data
        // to the next caller. Replace the socket (best effort). If reconnecting
        // fails, the connection stays disconnected: its next send() reports
        // CONNECTION_FAILED, which lands here again and retries the reconnect.
        conn->disconnect();
        static_cast<void>(conn->connect());
    }
    pool_.releaseConnection(conn);
    return ec;
}
#include <iostream>
#include <vector>
#include <unordered_map>
#include <mutex>
#include <condition_variable>
#include <chrono>
#include <thread>
#include <atomic>
#include <memory>
#include <cstring>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <unistd.h>
// Error code enumeration. It is unscoped, so values convert implicitly to int
// (main() prints them directly).
enum ErrorCode {
    SUCCESS = 0,
    SOCKET_CREATE_ERROR = 1,
    CONNECT_ERROR = 2,
    SEND_ERROR = 3,
    RECEIVE_ERROR = 4,
    POOL_FULL_ERROR = 5,
    CONNECTION_TIMEOUT_ERROR = 6
};
// TCP connection class: one client socket to ip:port, plus the time it was
// last used (used by the pool for idle expiry).
class TCPConnection {
private:
    int sock;
    std::string ip;
    int port;
    std::chrono::steady_clock::time_point lastUsed;

    // Closes the socket if open and marks the connection as having none.
    void closeSocket() {
        if (sock != -1) {
            close(sock);
            sock = -1;
        }
    }

public:
    // Initializers follow declaration order (sock, ip, port).
    TCPConnection(const std::string& ip, int port) : sock(-1), ip(ip), port(port) {
        updateLastUsed();
    }

    ~TCPConnection() {
        closeSocket();
    }

    // Owns a socket descriptor. A copy would close it twice.
    TCPConnection(const TCPConnection&) = delete;
    TCPConnection& operator=(const TCPConnection&) = delete;

    // Connects to ip:port (ip must be a dotted IPv4 address). Any previous
    // socket is closed first, so reconnecting does not leak it.
    // Returns SUCCESS, SOCKET_CREATE_ERROR, or CONNECT_ERROR.
    ErrorCode connect() {
        closeSocket();

        struct sockaddr_in server {};  // zero-filled, including sin_zero
        server.sin_family = AF_INET;
        server.sin_port = htons(port);
        // inet_pton rejects malformed text. inet_addr would map it to
        // INADDR_NONE (255.255.255.255) and try to connect there.
        if (inet_pton(AF_INET, ip.c_str(), &server.sin_addr) != 1) {
            return CONNECT_ERROR;
        }

        sock = socket(AF_INET, SOCK_STREAM, 0);
        if (sock == -1) {
            return SOCKET_CREATE_ERROR;
        }
        if (::connect(sock, reinterpret_cast<struct sockaddr*>(&server), sizeof(server)) < 0) {
            closeSocket();
            return CONNECT_ERROR;
        }
        return SUCCESS;
    }

    // Sends all of requestData, then reads exactly expectedSize bytes into
    // responseData.
    // - If an earlier failure closed the socket, reconnects first. A
    //   connection handed back by the pool can therefore recover; the connect
    //   error is returned if that fails.
    // - Any send/receive failure closes the socket, because the stream may now
    //   hold a partial request or unread response bytes.
    // Returns SUCCESS, SEND_ERROR, RECEIVE_ERROR, or a connect() error.
    ErrorCode sendAndReceive(const std::vector<uint8_t>& requestData, std::vector<uint8_t>& responseData, size_t expectedSize) {
        if (sock == -1) {
            ErrorCode reconnect = connect();
            if (reconnect != SUCCESS) {
                return reconnect;
            }
        }

        // Send request: send() may write only part of the buffer, so loop.
        size_t totalSent = 0;
        while (totalSent < requestData.size()) {
            // MSG_NOSIGNAL: a closed peer yields EPIPE instead of SIGPIPE killing the process.
            ssize_t sent = ::send(sock, requestData.data() + totalSent,
                                  requestData.size() - totalSent, MSG_NOSIGNAL);
            if (sent <= 0) {
                closeSocket();
                return SEND_ERROR;
            }
            totalSent += static_cast<size_t>(sent);
        }

        // Receive response
        responseData.resize(expectedSize);
        size_t totalReceived = 0;
        while (totalReceived < expectedSize) {
            ssize_t received = recv(sock, responseData.data() + totalReceived,
                                    expectedSize - totalReceived, 0);
            if (received <= 0) {
                closeSocket();
                return RECEIVE_ERROR;
            }
            totalReceived += static_cast<size_t>(received);
        }
        return SUCCESS;
    }

    void updateLastUsed() {
        lastUsed = std::chrono::steady_clock::now();
    }

    // True when more than `timeout` (whole seconds) has passed since updateLastUsed().
    bool isExpired(const std::chrono::seconds& timeout) const {
        auto now = std::chrono::steady_clock::now();
        return std::chrono::duration_cast<std::chrono::seconds>(now - lastUsed) > timeout;
    }

    // Pool key for this endpoint, "ip:port".
    std::string getKey() const {
        return ip + ":" + std::to_string(port);
    }
};
// 线程安全的连接池
class TCPConnectionPool {
private:
struct PoolArea {
std::unordered_map<std::string, std::vector<std::shared_ptr<TCPConnection>>> connections;
std::mutex mtx;
};
PoolArea areas[2];
std::atomic<int> activeArea{0};
size_t maxConnectionsPerEndpoint;
std::chrono::seconds connectionTimeout;
std::atomic<bool> running;
std::thread cleanupThread;
void cleanupExpiredConnections() {
while (running) {
std::this_thread::sleep_for(std::chrono::seconds(30)); // 每30秒检查一次
int currentArea = activeArea.load();
int inactiveArea = 1 - currentArea;
// 复制活跃区域到非活跃区域
{
std::lock_guard<std::mutex> lockActive(areas[currentArea].mtx);
std::lock_guard<std::mutex> lockInactive(areas[inactiveArea].mtx);
areas[inactiveArea].connections = areas[currentArea].connections;
}
// 清理非活跃区域中的过期连接
{
std::lock_guard<std::mutex> lock(areas[inactiveArea].mtx);
for (auto it = areas[inactiveArea].connections.begin(); it != areas[inactiveArea].connections.end();) {
auto& connections = it->second;
connections.erase(
std::remove_if(connections.begin(), connections.end(),
[this](const std::shared_ptr<TCPConnection>& conn) {
return conn->isExpired(connectionTimeout);
}),
connections.end()
);
if (connections.empty()) {
it = areas[inactiveArea].connections.erase(it);
} else {
++it;
}
}
}
// 切换活跃区域
activeArea.store(inactiveArea);
}
}
public:
TCPConnectionPool(size_t maxPerEndpoint = 10, std::chrono::seconds timeout = std::chrono::seconds(60))
: maxConnectionsPerEndpoint(maxPerEndpoint), connectionTimeout(timeout), running(true) {
cleanupThread = std::thread(&TCPConnectionPool::cleanupExpiredConnections, this);
}
~TCPConnectionPool() {
running = false;
if (cleanupThread.joinable()) {
cleanupThread.join();
}
}
ErrorCode getConnection(const std::string& ip, int port, std::shared_ptr<TCPConnection>& conn) {
int currentArea = activeArea.load();
std::unique_lock<std::mutex> lock(areas[currentArea].mtx);
std::string key = ip + ":" + std::to_string(port);
auto& connections = areas[currentArea].connections[key];
// 尝试找到一个可用的连接
for (auto it = connections.begin(); it != connections.end(); ++it) {
if (!(*it)->isExpired(connectionTimeout)) {
conn = *it;
connections.erase(it);
conn->updateLastUsed();
return SUCCESS;
}
}
// 如果没有可用连接且未达到上限,创建新连接
if (connections.size() < maxConnectionsPerEndpoint) {
conn = std::make_shared<TCPConnection>(ip, port);
ErrorCode result = conn->connect();
if (result == SUCCESS) {
return SUCCESS;
}
return result;
}
return POOL_FULL_ERROR;
}
void returnConnection(std::shared_ptr<TCPConnection> conn) {
int currentArea = activeArea.load();
std::lock_guard<std::mutex> lock(areas[currentArea].mtx);
std::string key = conn->getKey();
if (areas[currentArea].connections[key].size() < maxConnectionsPerEndpoint) {
areas[currentArea].connections[key].push_back(conn);
}
}
// 单例接口
static TCPConnectionPool& getInstance() {
static TCPConnectionPool instance;
return instance;
}
};
// Custom protocol client: borrows connections from the shared
// TCPConnectionPool singleton.
class MyProtocolClient {
private:
    TCPConnectionPool& pool;

public:
    MyProtocolClient() : pool(TCPConnectionPool::getInstance()) {}

    // Sends requestData to ip:port and reads exactly expectedSize bytes into
    // responseData. Returns the pool's or the connection's error code, or
    // SUCCESS.
    ErrorCode sendRequestAndReceiveResponse(const std::string& ip, int port,
                                            const std::vector<uint8_t>& requestData,
                                            std::vector<uint8_t>& responseData,
                                            size_t expectedSize) {
        std::shared_ptr<TCPConnection> conn;
        ErrorCode result = pool.getConnection(ip, port, conn);
        if (result != SUCCESS) {
            return result;
        }
        result = conn->sendAndReceive(requestData, responseData, expectedSize);
        // Only return a connection that completed a full exchange. After an
        // error the stream may still hold unread response bytes, and the next
        // borrower would read them as its own reply. A dropped connection's
        // socket closes when `conn` goes out of scope.
        if (result == SUCCESS) {
            pool.returnConnection(conn);
        }
        return result;
    }
};
// Usage example: two threads each send 10 requests to 127.0.0.1:8080 through
// their own MyProtocolClient. Both clients share the pool singleton.
int main() {
    MyProtocolClient client1, client2;
    // Simulate a multi-threaded environment
    auto task = [](MyProtocolClient& client, const std::string& message) {
        const std::vector<uint8_t> requestData(message.begin(), message.end());
        for (int i = 0; i < 10; ++i) {
            std::vector<uint8_t> responseData;
            ErrorCode result = client.sendRequestAndReceiveResponse("127.0.0.1", 8080, requestData, responseData, 10);
            // Build the whole line first and write it with a single <<. Each
            // separate << call could otherwise be interleaved with the other
            // thread's output in the middle of a line.
            std::string line;
            if (result == SUCCESS) {
                std::string responseStr(responseData.begin(), responseData.end());
                line = "Request: " + message + ", Response: " + responseStr + "\n";
            } else {
                line = "Request-Response failed for: " + message + ", Error code: " + std::to_string(result) + "\n";
            }
            std::cout << line << std::flush;
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
    };
    std::thread t1(task, std::ref(client1), "Hello from client1");
    std::thread t2(task, std::ref(client2), "Hello from client2");
    t1.join();
    t2.join();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment