Skip to content

Instantly share code, notes, and snippets.

@eahydra
Last active July 26, 2019 09:38
Show Gist options
  • Save eahydra/5141947 to your computer and use it in GitHub Desktop.
Save eahydra/5141947 to your computer and use it in GitHub Desktop.
A task thread pool based on Windows IOCP. Supports asynchronous tasks with closures and timers.
#include "task_thread_pool.hpp"
#include <algorithm>
#include <functional>
#include <memory>
#include <mutex>
#include <set>
#include <tuple>
#include <vector>
#include <Windows.h>
#include <assert.h>
namespace base {
namespace detail {
// Lock type shared by the detail classes. It is recursive because
// timer_queue_t::kill_timer() holds the lock while calling erase(), which
// acquires it again.
using lock_t = std::recursive_mutex;
// Scoped guard over lock_t.
using guard_t = std::lock_guard<lock_t>;
struct iocp_context_t {
typedef std::function<void(DWORD, DWORD)> callback_t;
iocp_context_t() {
memset(&overlapped, 0, sizeof(overlapped));
}
void handler(DWORD status, DWORD bytes_transfered, OVERLAPPED* over_lapped) {
if (callback != nullptr) {
callback(status, bytes_transfered);
}
iocp_context_t* context = reinterpret_cast<iocp_context_t *>(over_lapped);
delete context;
}
OVERLAPPED overlapped;
callback_t callback;
};
class iocp_t {
public:
typedef std::function<unsigned int()> time_getter_t;
typedef std::function<void()> timeout_processor_t;
iocp_t() { }
~iocp_t() {}
bool start(int count, time_getter_t&& time_getter, timeout_processor_t&& timeout_processor) {
if (count <=0) {
SYSTEM_INFO system_info;
GetSystemInfo(&system_info);
count = system_info.dwNumberOfProcessors;
}
InterlockedExchange(&shutdown_, 0);
port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, count);
if (port_handle_ == nullptr) {
return false;
}
std::tuple<HANDLE, time_getter_t, timeout_processor_t>* param = new std::tuple<HANDLE, time_getter_t, timeout_processor_t>(port_handle_, time_getter, timeout_processor);
for (int i = 0; i < count; i++) {
HANDLE thread_handle = CreateThread(nullptr, 0, iocp_t::thread_proc, param, 0, nullptr);
if (thread_handle != nullptr) {
threads_.push_back(thread_handle);
}
}
return threads_.size() != 0;
}
void stop() {
if (shutdown_)
return;
InterlockedExchange(&shutdown_, 1);
for (size_t i = 0; i < threads_.size(); i++) {
PostQueuedCompletionStatus(port_handle_, 0, 0, nullptr);
}
WaitForMultipleObjects(threads_.size(), &threads_[0], FALSE, INFINITE);
std::for_each(threads_.begin(), threads_.end(), std::ptr_fun(CloseHandle));
threads_.clear();
CloseHandle(port_handle_);
port_handle_ = nullptr;
}
template <typename handler_t>
void post_task(handler_t&& handler)
{
iocp_context_t* context = new iocp_context_t;
context->callback = [handler](DWORD, DWORD){
handler();
};
if (!PostQueuedCompletionStatus(port_handle_, 0, 0, reinterpret_cast<OVERLAPPED *>(context))) {
delete context;
}
}
private:
static DWORD CALLBACK thread_proc(PVOID param)
{
auto config = static_cast<std::tuple<HANDLE, time_getter_t, timeout_processor_t>*>(param);
while (true) {
DWORD number_bytes = 0;
ULONG_PTR completion_key = 0;
DWORD status = ERROR_SUCCESS;
OVERLAPPED* over_lapped = nullptr;
unsigned int wait_time = (std::get<1>(*config))();
SetLastError(ERROR_SUCCESS);
BOOL result = GetQueuedCompletionStatus(
std::get<0>(*config), &number_bytes, &completion_key, &over_lapped, wait_time);
if (InterlockedCompareExchange(&shutdown_, 1, 1)) {
break;
}
status = GetLastError();
if (!result && status == WAIT_TIMEOUT) {
(std::get<2>(*config))();
continue;
}
if ((!result || over_lapped == nullptr) &&
(status == ERROR_ABANDONED_WAIT_0 ||
status == ERROR_INVALID_HANDLE))
{
assert(false);
break;
}
iocp_context_t* context = reinterpret_cast<iocp_context_t *>(over_lapped);
context->handler(status, number_bytes, over_lapped);
}
return 1;
}
private:
HANDLE port_handle_;
std::vector<HANDLE> threads_;
static long shutdown_;
};
long iocp_t::shutdown_ = 0;
class timer_queue_t {
public:
struct timer_t
{
typedef std::function<void(task_thread_pool_t::handle_t&)> on_timer_t;
unsigned long long sequence_id_;
unsigned int duration_;// 时钟间隔, 单位: 毫秒
unsigned int start_; // 时钟开始时间, 单位: 毫秒
on_timer_t on_timer_; // 时钟回调函数
long is_running_;
template <typename handler_t>
timer_t(unsigned int duration, handler_t&& handler)
: duration_(duration), start_(GetTickCount())
, on_timer_(std::forward<on_timer_t>(handler))
, is_running_(0) {}
};
typedef std::shared_ptr<timer_t> timer_ptr_t;
public:
timer_queue_t() : sequence_id_(0) { }
~timer_queue_t() { }
template <typename handler_t>
timer_ptr_t set_timer(handler_t&& handler, unsigned int duration) {
timer_ptr_t timer_ptr = std::make_shared<timer_t>(duration, std::move(handler));
guard_t guard(lock_);
timer_ptr->sequence_id_ = sequence_id_++;
timers_.insert(timer_ptr);
return timer_ptr;
}
void kill_timer(timer_ptr_t& timer_ptr) {
guard_t guard(lock_);
erase(timer_ptr.get());
timer_ptr.reset();
}
unsigned int timeout_next() {
unsigned int current = GetTickCount();
guard_t guard(lock_);
for (auto iter = timers_.begin(); iter != timers_.end(); ++iter) {
auto timer = *iter;
int ret_value = get_wait_time(timer->start_, timer->duration_, current);
if (ret_value > 0) {
return ret_value;
}
}
return INFINITE;
}
bool peek(timer_ptr_t& timer) {
unsigned int current = GetTickCount();
guard_t guard(lock_);
if (timers_.empty()) {
return false;
}
auto first_timer = *timers_.begin();
int ret_value = get_wait_time(first_timer->start_, first_timer->duration_, current);
if (ret_value <= 0) {
timer = first_timer;
timers_.erase(timers_.begin());
first_timer->start_ += first_timer->duration_;
timers_.insert(first_timer);
return true;
}
return false;
}
void do_all_timer() {
for (;;) {
timer_ptr_t timer_ptr;
if (peek(timer_ptr) && timer_ptr != nullptr) {
// 防止多线程同时执行同一个timer_t
if (InterlockedCompareExchange(&timer_ptr->is_running_, 1, 0) == 0) {
if (timer_ptr->on_timer_ != nullptr) {
timer_ptr->on_timer_(std::static_pointer_cast<void>(timer_ptr));
}
InterlockedExchange(&timer_ptr->is_running_ , 0);
}
} else {
break;
}
}
}
void clear() {
std::set<timer_ptr_t, less_timer_t> timers;
for (;;)
{
guard_t guard(lock_);
timers.swap(timers_);
break;
}
timers.clear();
}
private:
void erase(const timer_t* timer_ptr) {
guard_t guard(lock_);
for (auto iter = timers_.begin(); iter != timers_.end(); iter++) {
if ((*iter).get() == timer_ptr) {
timers_.erase(iter);
break;
}
}
}
private:
// 由于GetTickCount会溢出,需要校准
static int get_wait_time(
unsigned int start, unsigned int duration, unsigned int current) {
unsigned int alarm = start + duration;
// 未溢出
if (alarm >= start && current >= start) {
return (int)(alarm - current);
} else if (alarm >= start && current < start) {
// 当前时间溢出
return (int)(alarm - (0x100000000 + current));
} else if (alarm < start && current >= start) {
// 报警时间溢出
return (int)((unsigned long long)start + (unsigned long long)duration - current);
} else /* if (alarm < start && current < start)*/ {
// 报警时间、当前时间都溢出
return (int)(alarm - current);
}
}
struct less_timer_t : public std::binary_function<timer_ptr_t, timer_ptr_t, bool> {
bool operator()(const timer_ptr_t& left, const timer_ptr_t& right) const {
unsigned int current = GetTickCount();
int left_correct = get_wait_time(left->start_, left->duration_, current);
int right_correct = get_wait_time(right->start_, right->duration_, current);
return ((left_correct < right_correct) ||
(left_correct == right_correct && left->sequence_id_ < right->sequence_id_));
}
};
private:
lock_t lock_;
unsigned long long sequence_id_;
std::set<timer_ptr_t, less_timer_t> timers_;
};
} // namespace detail
// Pimpl for task_thread_pool_t: an IOCP worker pool whose idle waits are
// bounded by the timer queue, so timers fire on the same workers.
struct task_thread_pool_t::impl
{
    detail::iocp_t iocp_;
    detail::timer_queue_t timer_queue_;
    bool started_;  // True between a successful start() and the next stop().

    impl() : started_(false) {}

    // Stop the workers before timer_queue_ (which they use) is destroyed.
    // NOTE(review): destroying the pool from one of its own tasks would deadlock here.
    ~impl() { stop(); }

    bool start(int num = 0) {
        const bool ok = iocp_.start(
            num,
            [this]() -> unsigned int {
                return timer_queue_.timeout_next();
            },
            [this]() {
                timer_queue_.do_all_timer();
            });
        if (ok) {
            started_ = true;
        }
        return ok;
    }

    // Stops the workers (only if they were started) and drops every timer.
    void stop() {
        if (started_) {
            iocp_.stop();
            started_ = false;
        }
        timer_queue_.clear();
    }

    template <typename handler_t>
    void post_task(handler_t&& handler)
    {
        iocp_.post_task(std::forward<handler_t>(handler));
    }

    template <typename handler_t>
    task_thread_pool_t::handle_t set_timer(handler_t&& handler, unsigned int millisecond)
    {
        auto handle = timer_queue_.set_timer(std::forward<handler_t>(handler), millisecond);
        // Wake one worker so it recomputes its wait time with the new timer included.
        post_task([]() { /* DO NOTHING, Just to process timer*/ });
        // Converting shared_ptr<timer_t> -> shared_ptr<void>: move avoids a refcount bump.
        return std::move(handle);
    }

    void kill_timer(handle_t handle)
    {
        // timer_queue_t::kill_timer takes a non-const lvalue reference, so the
        // cast result needs a name (binding a temporary is an MSVC-only extension).
        detail::timer_queue_t::timer_ptr_t timer_ptr =
            std::static_pointer_cast<detail::timer_queue_t::timer_t>(handle);
        timer_queue_.kill_timer(timer_ptr);
    }
};
task_thread_pool_t::task_thread_pool_t()
: impl_ptr_(new impl)
{
}
// Defined out of line, where impl is a complete type, so that
// std::unique_ptr<impl> can destroy it.
task_thread_pool_t::~task_thread_pool_t() = default;
// Forwards to impl::start; num <= 0 requests one worker per processor.
bool task_thread_pool_t::start(int num /* = 0 */)
{
    const bool started = impl_ptr_->start(num);
    return started;
}
// Forwards to impl::stop, which stops the workers and clears the timer queue.
void task_thread_pool_t::stop() { impl_ptr_->stop(); }
// Queues `closure` to run once on a worker thread.
void task_thread_pool_t::post_task(closure_t&& closure)
{
    impl& pool = *impl_ptr_;
    pool.post_task(std::move(closure));
}
// Registers `closure` as a timer with a period of `millisecond` ms and returns
// the handle that kill_timer() accepts. Note that impl takes the arguments in
// (handler, period) order.
task_thread_pool_t::handle_t task_thread_pool_t::set_timer(unsigned int millisecond, on_timer_t&& closure)
{
    impl& pool = *impl_ptr_;
    return pool.set_timer(std::move(closure), millisecond);
}
// Cancels the timer identified by `handle`, as returned by set_timer().
void task_thread_pool_t::kill_timer(handle_t handle)
{
    impl& pool = *impl_ptr_;
    pool.kill_timer(handle);
}
} // namespace base
#ifndef TASK_THREAD_POOL_HPP_
#define TASK_THREAD_POOL_HPP_
#include <memory>
#include <functional>
namespace base {
/// Worker-thread pool that runs posted closures and periodic timers.
///
/// Non-copyable: it owns its implementation through a std::unique_ptr.
class task_thread_pool_t {
public:
    /// Opaque timer handle returned by set_timer() and accepted by kill_timer().
    using handle_t = std::shared_ptr<void>;
    /// A unit of work for post_task().
    using closure_t = std::function<void()>;
    /// Timer callback; receives the handle of the timer that fired.
    using on_timer_t = std::function<void(handle_t&)>;

    task_thread_pool_t();
    ~task_thread_pool_t();

    task_thread_pool_t(const task_thread_pool_t&) = delete;
    task_thread_pool_t& operator=(const task_thread_pool_t&) = delete;

    /// Starts the worker threads; num <= 0 requests one per processor.
    /// @return false if the pool could not be started.
    bool start(int num = 0);
    /// Stops the worker threads and discards all timers.
    void stop();
    /// Queues `closure` to run once on a worker thread.
    void post_task(closure_t&& closure);
    /// Runs `closure` every `millisecond` ms on a worker thread until kill_timer().
    handle_t set_timer(unsigned int millisecond, on_timer_t&& closure);
    /// Cancels a timer created by set_timer().
    void kill_timer(handle_t handle);

private:
    struct impl;
    std::unique_ptr<impl> impl_ptr_;
};
} // namespace base
#endif // TASK_THREAD_POOL_HPP_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment