Last active
August 28, 2024 15:57
-
-
Save jerrylususu/43b07e5849d14d68a72e8f99ed3f4c80 to your computer and use it in GitHub Desktop.
Custom TCP protocol client with a connection pool, written by Claude.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// MyProtocol.h | |
#pragma once | |
#include <array> | |
#include <vector> | |
#include <string> | |
#include <memory> | |
#include <chrono> | |
#include <mutex> | |
#include <queue> | |
#include <cstring> | |
#include <cstdint> | |
#include <arpa/inet.h> | |
#include <sys/epoll.h> | |
// Every frame, sent or received, starts with this 32-bit value in network byte order.
constexpr uint32_t MAGIC_NUMBER{0x12345678};

// Result of every client operation; SUCCESS is 0 and all failures are non-zero.
enum class ErrorCode {
    SUCCESS = 0,
    CONNECTION_FAILED = 1,         // not connected, or socket()/address parsing/connect() failed
    SEND_FAILED = 2,               // writing the request frame failed
    RECEIVE_FAILED = 3,            // reading the reply frame failed or the peer closed early
    TIMEOUT = 4,                   // no reply arrived within the caller's timeout
    INVALID_MAGIC_NUMBER = 5,      // reply header did not start with MAGIC_NUMBER
    NO_AVAILABLE_CONNECTIONS = 6,  // the pool had no idle connection to lend
    EPOLL_ERROR = 7                // waiting for the reply failed
};

// Body of an outgoing frame; the 8-byte header is prepended when it is sent.
struct Request {
    std::vector<uint8_t> payload;
};

// Body of an incoming frame, with the 8-byte header already stripped.
struct Response {
    std::vector<uint8_t> payload;
};
class Connection { | |
public: | |
Connection(const std::string& host, int port); | |
~Connection(); | |
ErrorCode connect(); | |
void disconnect(); | |
ErrorCode send(const Request& req); | |
ErrorCode receive(Response& resp); | |
int getSockfd() const { return sockfd_; } | |
private: | |
int sockfd_; | |
std::string host_; | |
int port_; | |
bool connected_; | |
}; | |
class ConnectionPool { | |
public: | |
ConnectionPool(const std::string& host, int port, size_t pool_size); | |
std::shared_ptr<Connection> getConnection(ErrorCode& ec); | |
void releaseConnection(std::shared_ptr<Connection> conn); | |
private: | |
std::string host_; | |
int port_; | |
std::queue<std::shared_ptr<Connection>> pool_; | |
std::mutex mutex_; | |
}; | |
class MyProtocolClient { | |
public: | |
MyProtocolClient(const std::string& host, int port, size_t pool_size = 5); | |
~MyProtocolClient(); | |
ErrorCode sendAndReceive(const Request& req, Response& resp, std::chrono::milliseconds timeout = std::chrono::seconds(30)); | |
private: | |
ConnectionPool pool_; | |
int epoll_fd_; | |
}; | |
// MyProtocol.cpp | |
#include "MyProtocol.h"

#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cerrno>
#include <climits>
// Records the endpoint only; no socket exists until connect() is called.
// Initializers are listed in member declaration order.
Connection::Connection(const std::string& host, int port)
    : sockfd_(-1),
      host_(host),
      port_(port),
      connected_(false) {}

// Releases the socket through disconnect().
Connection::~Connection() {
    disconnect();
}
// Opens a TCP connection to host_:port_.
// host_ must be a dotted-quad IPv4 literal (inet_pton with AF_INET); host
// names are not resolved. Any socket from an earlier connect() is released
// first, and a socket whose connect fails is closed right away, so no
// descriptor is ever leaked.
// Returns SUCCESS, or CONNECTION_FAILED for a bad address/port, a socket()
// failure, or a failed ::connect().
ErrorCode Connection::connect() {
    disconnect();
    sockfd_ = -1;

    if (port_ < 0 || port_ > 65535) {
        return ErrorCode::CONNECTION_FAILED;
    }

    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(static_cast<uint16_t>(port_));
    // Parse the address before creating the socket, so a bad address cannot leak one.
    if (inet_pton(AF_INET, host_.c_str(), &server_addr.sin_addr) != 1) {
        return ErrorCode::CONNECTION_FAILED;
    }

    const int fd = ::socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        return ErrorCode::CONNECTION_FAILED;
    }
    if (::connect(fd, reinterpret_cast<const sockaddr*>(&server_addr), sizeof(server_addr)) < 0) {
        ::close(fd);
        return ErrorCode::CONNECTION_FAILED;
    }

    sockfd_ = fd;
    connected_ = true;
    return ErrorCode::SUCCESS;
}
void Connection::disconnect() { | |
if (connected_) { | |
close(sockfd_); | |
connected_ = false; | |
} | |
} | |
namespace {

// Writes all `len` bytes at `data` to `fd`; returns false on error.
// - Short writes are continued and EINTR is retried.
// - If the socket is non-blocking (EAGAIN/EWOULDBLOCK), the function waits
//   for writability, so it behaves like a blocking write either way.
// - MSG_NOSIGNAL makes a write to a vanished peer fail with EPIPE instead of
//   raising SIGPIPE, whose default action terminates the process.
// `extra_flags` is OR-ed into every send() call.
bool sendAll(int fd, const uint8_t* data, size_t len, int extra_flags) {
    while (len > 0) {
        const ssize_t n = ::send(fd, data, len, MSG_NOSIGNAL | extra_flags);
        if (n > 0) {
            data += n;
            len -= static_cast<size_t>(n);
            continue;
        }
        if (n < 0 && errno == EINTR) {
            continue;
        }
        if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            pollfd pfd{};
            pfd.fd = fd;
            pfd.events = POLLOUT;
            if (::poll(&pfd, 1, -1) < 0 && errno != EINTR) {
                return false;
            }
            continue;
        }
        return false;
    }
    return true;
}

}  // namespace

// Writes one request frame:
// [MAGIC_NUMBER, 4 bytes BE][payload length, 4 bytes BE][payload].
// Returns CONNECTION_FAILED if not connected. Returns SEND_FAILED if the
// payload does not fit the 32-bit length field or a write fails.
ErrorCode Connection::send(const Request& req) {
    if (!connected_) {
        return ErrorCode::CONNECTION_FAILED;
    }
    if (req.payload.size() > UINT32_MAX) {
        return ErrorCode::SEND_FAILED;  // the length would be silently truncated
    }

    std::array<uint8_t, 8> header;
    const uint32_t magic = htonl(MAGIC_NUMBER);
    const uint32_t length = htonl(static_cast<uint32_t>(req.payload.size()));
    std::memcpy(header.data(), &magic, 4);
    std::memcpy(header.data() + 4, &length, 4);

    // MSG_MORE tells the kernel the payload follows, so the header and the
    // payload leave together. Two small writes followed by a read can otherwise
    // stall on Nagle's algorithm plus delayed ACK.
    const bool has_payload = !req.payload.empty();
    if (!sendAll(sockfd_, header.data(), header.size(), has_payload ? MSG_MORE : 0)) {
        return ErrorCode::SEND_FAILED;
    }
    if (has_payload && !sendAll(sockfd_, req.payload.data(), req.payload.size(), 0)) {
        return ErrorCode::SEND_FAILED;
    }
    return ErrorCode::SUCCESS;
}
namespace {

// Reads exactly `len` bytes from `fd` into `data`. Returns false if the peer
// closes the connection first or an error occurs.
// Short reads are continued and EINTR is retried. If the socket is
// non-blocking (EAGAIN/EWOULDBLOCK), the function waits for readability, so
// it behaves like a blocking read either way.
bool recvAll(int fd, uint8_t* data, size_t len) {
    while (len > 0) {
        const ssize_t n = ::recv(fd, data, len, 0);
        if (n > 0) {
            data += n;
            len -= static_cast<size_t>(n);
            continue;
        }
        if (n < 0 && errno == EINTR) {
            continue;
        }
        if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            pollfd pfd{};
            pfd.fd = fd;
            pfd.events = POLLIN;
            if (::poll(&pfd, 1, -1) < 0 && errno != EINTR) {
                return false;
            }
            continue;
        }
        return false;  // n == 0: the peer closed early; otherwise a hard error
    }
    return true;
}

}  // namespace

// Reads one reply frame, [magic][length][payload], and stores the payload in
// resp.payload.
// Returns CONNECTION_FAILED if not connected, INVALID_MAGIC_NUMBER if the
// header does not start with MAGIC_NUMBER, and RECEIVE_FAILED on a read error
// or an early close.
ErrorCode Connection::receive(Response& resp) {
    if (!connected_) {
        return ErrorCode::CONNECTION_FAILED;
    }

    // A single recv() may return fewer than 8 bytes, so read the header in full.
    std::array<uint8_t, 8> header;
    if (!recvAll(sockfd_, header.data(), header.size())) {
        return ErrorCode::RECEIVE_FAILED;
    }

    uint32_t received_magic;
    uint32_t payload_length;
    std::memcpy(&received_magic, header.data(), 4);
    std::memcpy(&payload_length, header.data() + 4, 4);
    received_magic = ntohl(received_magic);
    payload_length = ntohl(payload_length);

    if (received_magic != MAGIC_NUMBER) {
        return ErrorCode::INVALID_MAGIC_NUMBER;
    }

    // NOTE(review): payload_length comes from the peer and is trusted as-is.
    // A corrupt or hostile value can force an allocation of up to 4 GiB here.
    resp.payload.resize(payload_length);
    if (!recvAll(sockfd_, resp.payload.data(), payload_length)) {
        return ErrorCode::RECEIVE_FAILED;
    }
    return ErrorCode::SUCCESS;
}
// Eagerly opens pool_size connections to host:port. A connection that fails
// to connect is simply left out, so the pool may start with fewer
// connections than requested, possibly none.
ConnectionPool::ConnectionPool(const std::string& host, int port, size_t pool_size)
    : host_(host), port_(port) {
    for (size_t attempt = 0; attempt != pool_size; ++attempt) {
        std::shared_ptr<Connection> candidate = std::make_shared<Connection>(host_, port_);
        if (candidate->connect() != ErrorCode::SUCCESS) {
            continue;
        }
        pool_.push(candidate);
    }
}
// Removes and returns the connection at the front of the idle queue.
// On success ec is SUCCESS. If the queue is empty, ec is
// NO_AVAILABLE_CONNECTIONS and nullptr is returned.
std::shared_ptr<Connection> ConnectionPool::getConnection(ErrorCode& ec) {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!pool_.empty()) {
        std::shared_ptr<Connection> next = pool_.front();
        pool_.pop();
        ec = ErrorCode::SUCCESS;
        return next;
    }
    ec = ErrorCode::NO_AVAILABLE_CONNECTIONS;
    return nullptr;
}
void ConnectionPool::releaseConnection(std::shared_ptr<Connection> conn) { | |
std::lock_guard<std::mutex> lock(mutex_); | |
pool_.push(conn); | |
} | |
// Builds the connection pool and creates the epoll instance.
// NOTE(review): if epoll_create1() fails, epoll_fd_ is left at -1 and the
// failure is not reported to the caller.
MyProtocolClient::MyProtocolClient(const std::string& host, int port, size_t pool_size)
    : pool_(host, port, pool_size),
      epoll_fd_(epoll_create1(0)) {}

// Closes the epoll instance, if one was created.
MyProtocolClient::~MyProtocolClient() {
    if (epoll_fd_ == -1) {
        return;
    }
    close(epoll_fd_);
}
namespace {

// Waits until `fd` becomes readable or `timeout` elapses, polling this one
// descriptor only.
// Returns:
// - SUCCESS when poll() reports the fd ready. A hang-up or error condition
//   also counts; the following read then reports it.
// - TIMEOUT when the time runs out.
// - EPOLL_ERROR if poll() fails. This is the existing error code for a
//   failed wait.
// EINTR is retried with the remaining time. A negative timeout waits
// indefinitely. Waits longer than INT_MAX ms are split into several poll() calls.
ErrorCode waitForReply(int fd, std::chrono::milliseconds timeout) {
    const auto start = std::chrono::steady_clock::now();
    pollfd pfd{};
    pfd.fd = fd;
    pfd.events = POLLIN;
    for (;;) {
        int wait_ms = -1;
        bool capped = false;
        if (timeout.count() >= 0) {
            const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - start);
            auto remaining = (timeout - elapsed).count();
            if (remaining < 0) {
                remaining = 0;
            }
            capped = remaining > INT_MAX;
            wait_ms = capped ? INT_MAX : static_cast<int>(remaining);
        }
        const int n = ::poll(&pfd, 1, wait_ms);
        if (n > 0) {
            return ErrorCode::SUCCESS;
        }
        if (n == 0) {
            if (!capped) {
                return ErrorCode::TIMEOUT;
            }
            continue;
        }
        if (errno != EINTR) {
            return ErrorCode::EPOLL_ERROR;
        }
    }
}

}  // namespace

// Sends `req` on a pooled connection and reads one reply frame into `resp`.
//
// `timeout` bounds the wait for the reply to start arriving. Once it is
// readable, receive() reads the rest of the frame. The socket is left in
// blocking mode.
//
// Readiness is checked with poll() on this connection's descriptor alone. With
// a wait shared between threads, one caller could consume the readiness event
// of another caller's socket.
//
// After any failure the position in the byte stream is unknown. For example,
// after a TIMEOUT the late reply would otherwise be read as the answer to the
// *next* request. So the connection is disconnected before it goes back to the
// pool. A connection that is down is re-established on its next use: send()
// reports CONNECTION_FAILED only when the connection is not connected.
//
// Returns SUCCESS or the first error encountered (NO_AVAILABLE_CONNECTIONS,
// CONNECTION_FAILED, SEND_FAILED, TIMEOUT, EPOLL_ERROR, RECEIVE_FAILED,
// INVALID_MAGIC_NUMBER).
ErrorCode MyProtocolClient::sendAndReceive(const Request& req, Response& resp, std::chrono::milliseconds timeout) {
    ErrorCode ec = ErrorCode::SUCCESS;
    std::shared_ptr<Connection> conn = pool_.getConnection(ec);
    if (ec != ErrorCode::SUCCESS) {
        return ec;
    }

    // Hands the connection back to the pool, dropping its socket unless the
    // exchange completed cleanly.
    auto finish = [&](ErrorCode result) {
        if (result != ErrorCode::SUCCESS) {
            conn->disconnect();
        }
        pool_.releaseConnection(conn);
        return result;
    };

    ec = conn->send(req);
    if (ec == ErrorCode::CONNECTION_FAILED) {
        // Dropped by an earlier failure: reconnect once and retry the send.
        ec = conn->connect();
        if (ec == ErrorCode::SUCCESS) {
            ec = conn->send(req);
        }
    }
    if (ec != ErrorCode::SUCCESS) {
        return finish(ec);
    }

    // Read the descriptor only now, because a reconnect replaces it.
    ec = waitForReply(conn->getSockfd(), timeout);
    if (ec != ErrorCode::SUCCESS) {
        return finish(ec);
    }

    return finish(conn->receive(resp));
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <vector> | |
#include <unordered_map> | |
#include <mutex> | |
#include <condition_variable> | |
#include <chrono> | |
#include <thread> | |
#include <atomic> | |
#include <memory> | |
#include <cstring> | |
#include <arpa/inet.h> | |
#include <sys/socket.h> | |
#include <unistd.h> | |
// Error codes returned by the client and the connection pool.
// Deliberately an unscoped enum: values convert to int when printed.
enum ErrorCode {
    SUCCESS = 0,
    SOCKET_CREATE_ERROR = 1,      // socket() failed
    CONNECT_ERROR = 2,            // connect() to the endpoint failed
    SEND_ERROR = 3,               // send() failed
    RECEIVE_ERROR = 4,            // recv() failed or the peer closed early
    POOL_FULL_ERROR = 5,          // per-endpoint connection limit reached
    CONNECTION_TIMEOUT_ERROR = 6  // NOTE(review): not returned anywhere in this program
};
// TCP连接类 | |
class TCPConnection { | |
private: | |
int sock; | |
std::string ip; | |
int port; | |
std::chrono::steady_clock::time_point lastUsed; | |
public: | |
TCPConnection(const std::string& ip, int port) : ip(ip), port(port), sock(-1) { | |
updateLastUsed(); | |
} | |
~TCPConnection() { | |
if (sock != -1) { | |
close(sock); | |
} | |
} | |
ErrorCode connect() { | |
sock = socket(AF_INET, SOCK_STREAM, 0); | |
if (sock == -1) { | |
return SOCKET_CREATE_ERROR; | |
} | |
struct sockaddr_in server; | |
server.sin_addr.s_addr = inet_addr(ip.c_str()); | |
server.sin_family = AF_INET; | |
server.sin_port = htons(port); | |
if (::connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { | |
close(sock); | |
sock = -1; | |
return CONNECT_ERROR; | |
} | |
return SUCCESS; | |
} | |
ErrorCode sendAndReceive(const std::vector<uint8_t>& requestData, std::vector<uint8_t>& responseData, size_t expectedSize) { | |
// Send request | |
if (::send(sock, requestData.data(), requestData.size(), 0) < 0) { | |
return SEND_ERROR; | |
} | |
// Receive response | |
responseData.resize(expectedSize); | |
size_t totalReceived = 0; | |
while (totalReceived < expectedSize) { | |
int received = recv(sock, responseData.data() + totalReceived, expectedSize - totalReceived, 0); | |
if (received <= 0) { | |
return RECEIVE_ERROR; | |
} | |
totalReceived += received; | |
} | |
return SUCCESS; | |
} | |
void updateLastUsed() { | |
lastUsed = std::chrono::steady_clock::now(); | |
} | |
bool isExpired(const std::chrono::seconds& timeout) const { | |
auto now = std::chrono::steady_clock::now(); | |
return std::chrono::duration_cast<std::chrono::seconds>(now - lastUsed) > timeout; | |
} | |
std::string getKey() const { | |
return ip + ":" + std::to_string(port); | |
} | |
}; | |
// 线程安全的连接池 | |
class TCPConnectionPool { | |
private: | |
struct PoolArea { | |
std::unordered_map<std::string, std::vector<std::shared_ptr<TCPConnection>>> connections; | |
std::mutex mtx; | |
}; | |
PoolArea areas[2]; | |
std::atomic<int> activeArea{0}; | |
size_t maxConnectionsPerEndpoint; | |
std::chrono::seconds connectionTimeout; | |
std::atomic<bool> running; | |
std::thread cleanupThread; | |
void cleanupExpiredConnections() { | |
while (running) { | |
std::this_thread::sleep_for(std::chrono::seconds(30)); // 每30秒检查一次 | |
int currentArea = activeArea.load(); | |
int inactiveArea = 1 - currentArea; | |
// 复制活跃区域到非活跃区域 | |
{ | |
std::lock_guard<std::mutex> lockActive(areas[currentArea].mtx); | |
std::lock_guard<std::mutex> lockInactive(areas[inactiveArea].mtx); | |
areas[inactiveArea].connections = areas[currentArea].connections; | |
} | |
// 清理非活跃区域中的过期连接 | |
{ | |
std::lock_guard<std::mutex> lock(areas[inactiveArea].mtx); | |
for (auto it = areas[inactiveArea].connections.begin(); it != areas[inactiveArea].connections.end();) { | |
auto& connections = it->second; | |
connections.erase( | |
std::remove_if(connections.begin(), connections.end(), | |
[this](const std::shared_ptr<TCPConnection>& conn) { | |
return conn->isExpired(connectionTimeout); | |
}), | |
connections.end() | |
); | |
if (connections.empty()) { | |
it = areas[inactiveArea].connections.erase(it); | |
} else { | |
++it; | |
} | |
} | |
} | |
// 切换活跃区域 | |
activeArea.store(inactiveArea); | |
} | |
} | |
public: | |
TCPConnectionPool(size_t maxPerEndpoint = 10, std::chrono::seconds timeout = std::chrono::seconds(60)) | |
: maxConnectionsPerEndpoint(maxPerEndpoint), connectionTimeout(timeout), running(true) { | |
cleanupThread = std::thread(&TCPConnectionPool::cleanupExpiredConnections, this); | |
} | |
~TCPConnectionPool() { | |
running = false; | |
if (cleanupThread.joinable()) { | |
cleanupThread.join(); | |
} | |
} | |
ErrorCode getConnection(const std::string& ip, int port, std::shared_ptr<TCPConnection>& conn) { | |
int currentArea = activeArea.load(); | |
std::unique_lock<std::mutex> lock(areas[currentArea].mtx); | |
std::string key = ip + ":" + std::to_string(port); | |
auto& connections = areas[currentArea].connections[key]; | |
// 尝试找到一个可用的连接 | |
for (auto it = connections.begin(); it != connections.end(); ++it) { | |
if (!(*it)->isExpired(connectionTimeout)) { | |
conn = *it; | |
connections.erase(it); | |
conn->updateLastUsed(); | |
return SUCCESS; | |
} | |
} | |
// 如果没有可用连接且未达到上限,创建新连接 | |
if (connections.size() < maxConnectionsPerEndpoint) { | |
conn = std::make_shared<TCPConnection>(ip, port); | |
ErrorCode result = conn->connect(); | |
if (result == SUCCESS) { | |
return SUCCESS; | |
} | |
return result; | |
} | |
return POOL_FULL_ERROR; | |
} | |
void returnConnection(std::shared_ptr<TCPConnection> conn) { | |
int currentArea = activeArea.load(); | |
std::lock_guard<std::mutex> lock(areas[currentArea].mtx); | |
std::string key = conn->getKey(); | |
if (areas[currentArea].connections[key].size() < maxConnectionsPerEndpoint) { | |
areas[currentArea].connections[key].push_back(conn); | |
} | |
} | |
// 单例接口 | |
static TCPConnectionPool& getInstance() { | |
static TCPConnectionPool instance; | |
return instance; | |
} | |
}; | |
// 自定义协议客户端 | |
class MyProtocolClient { | |
private: | |
TCPConnectionPool& pool; | |
public: | |
MyProtocolClient() : pool(TCPConnectionPool::getInstance()) {} | |
ErrorCode sendRequestAndReceiveResponse(const std::string& ip, int port, | |
const std::vector<uint8_t>& requestData, | |
std::vector<uint8_t>& responseData, | |
size_t expectedSize) { | |
std::shared_ptr<TCPConnection> conn; | |
ErrorCode result = pool.getConnection(ip, port, conn); | |
if (result != SUCCESS) { | |
return result; | |
} | |
result = conn->sendAndReceive(requestData, responseData, expectedSize); | |
pool.returnConnection(conn); | |
return result; | |
} | |
}; | |
// 使用示例 | |
int main() { | |
MyProtocolClient client1, client2; | |
// 模拟多线程环境 | |
auto task = [](MyProtocolClient& client, const std::string& message) { | |
for (int i = 0; i < 10; ++i) { | |
std::vector<uint8_t> requestData(message.begin(), message.end()); | |
std::vector<uint8_t> responseData; | |
ErrorCode result = client.sendRequestAndReceiveResponse("127.0.0.1", 8080, requestData, responseData, 10); | |
if (result == SUCCESS) { | |
std::string responseStr(responseData.begin(), responseData.end()); | |
std::cout << "Request: " << message << ", Response: " << responseStr << std::endl; | |
} else { | |
std::cout << "Request-Response failed for: " << message << ", Error code: " << result << std::endl; | |
} | |
std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |
} | |
}; | |
std::thread t1(task, std::ref(client1), "Hello from client1"); | |
std::thread t2(task, std::ref(client2), "Hello from client2"); | |
t1.join(); | |
t2.join(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment