Last active
July 26, 2019 09:38
-
-
Save eahydra/5141947 to your computer and use it in GitHub Desktop.
A task thread pool based on Windows IOCP. Supports asynchronous tasks (closures) and timers.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "task_thread_pool.hpp"

#include <assert.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <mutex>
#include <set>
#include <tuple>
#include <utility>
#include <vector>

#include <Windows.h>
namespace base { | |
namespace detail { | |
// Mutex type used by the timer queue. It is recursive because
// timer_queue_t::kill_timer holds the lock while calling erase(), which locks again.
using lock_t = std::recursive_mutex;
// Scoped (RAII) lock over lock_t.
using guard_t = std::lock_guard<lock_t>;
struct iocp_context_t { | |
typedef std::function<void(DWORD, DWORD)> callback_t; | |
iocp_context_t() { | |
memset(&overlapped, 0, sizeof(overlapped)); | |
} | |
void handler(DWORD status, DWORD bytes_transfered, OVERLAPPED* over_lapped) { | |
if (callback != nullptr) { | |
callback(status, bytes_transfered); | |
} | |
iocp_context_t* context = reinterpret_cast<iocp_context_t *>(over_lapped); | |
delete context; | |
} | |
OVERLAPPED overlapped; | |
callback_t callback; | |
}; | |
class iocp_t { | |
public: | |
typedef std::function<unsigned int()> time_getter_t; | |
typedef std::function<void()> timeout_processor_t; | |
iocp_t() { } | |
~iocp_t() {} | |
bool start(int count, time_getter_t&& time_getter, timeout_processor_t&& timeout_processor) { | |
if (count <=0) { | |
SYSTEM_INFO system_info; | |
GetSystemInfo(&system_info); | |
count = system_info.dwNumberOfProcessors; | |
} | |
InterlockedExchange(&shutdown_, 0); | |
port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, count); | |
if (port_handle_ == nullptr) { | |
return false; | |
} | |
std::tuple<HANDLE, time_getter_t, timeout_processor_t>* param = new std::tuple<HANDLE, time_getter_t, timeout_processor_t>(port_handle_, time_getter, timeout_processor); | |
for (int i = 0; i < count; i++) { | |
HANDLE thread_handle = CreateThread(nullptr, 0, iocp_t::thread_proc, param, 0, nullptr); | |
if (thread_handle != nullptr) { | |
threads_.push_back(thread_handle); | |
} | |
} | |
return threads_.size() != 0; | |
} | |
void stop() { | |
if (shutdown_) | |
return; | |
InterlockedExchange(&shutdown_, 1); | |
for (size_t i = 0; i < threads_.size(); i++) { | |
PostQueuedCompletionStatus(port_handle_, 0, 0, nullptr); | |
} | |
WaitForMultipleObjects(threads_.size(), &threads_[0], FALSE, INFINITE); | |
std::for_each(threads_.begin(), threads_.end(), std::ptr_fun(CloseHandle)); | |
threads_.clear(); | |
CloseHandle(port_handle_); | |
port_handle_ = nullptr; | |
} | |
template <typename handler_t> | |
void post_task(handler_t&& handler) | |
{ | |
iocp_context_t* context = new iocp_context_t; | |
context->callback = [handler](DWORD, DWORD){ | |
handler(); | |
}; | |
if (!PostQueuedCompletionStatus(port_handle_, 0, 0, reinterpret_cast<OVERLAPPED *>(context))) { | |
delete context; | |
} | |
} | |
private: | |
static DWORD CALLBACK thread_proc(PVOID param) | |
{ | |
auto config = static_cast<std::tuple<HANDLE, time_getter_t, timeout_processor_t>*>(param); | |
while (true) { | |
DWORD number_bytes = 0; | |
ULONG_PTR completion_key = 0; | |
DWORD status = ERROR_SUCCESS; | |
OVERLAPPED* over_lapped = nullptr; | |
unsigned int wait_time = (std::get<1>(*config))(); | |
SetLastError(ERROR_SUCCESS); | |
BOOL result = GetQueuedCompletionStatus( | |
std::get<0>(*config), &number_bytes, &completion_key, &over_lapped, wait_time); | |
if (InterlockedCompareExchange(&shutdown_, 1, 1)) { | |
break; | |
} | |
status = GetLastError(); | |
if (!result && status == WAIT_TIMEOUT) { | |
(std::get<2>(*config))(); | |
continue; | |
} | |
if ((!result || over_lapped == nullptr) && | |
(status == ERROR_ABANDONED_WAIT_0 || | |
status == ERROR_INVALID_HANDLE)) | |
{ | |
assert(false); | |
break; | |
} | |
iocp_context_t* context = reinterpret_cast<iocp_context_t *>(over_lapped); | |
context->handler(status, number_bytes, over_lapped); | |
} | |
return 1; | |
} | |
private: | |
HANDLE port_handle_; | |
std::vector<HANDLE> threads_; | |
static long shutdown_; | |
}; | |
long iocp_t::shutdown_ = 0; | |
class timer_queue_t { | |
public: | |
struct timer_t | |
{ | |
typedef std::function<void(task_thread_pool_t::handle_t&)> on_timer_t; | |
unsigned long long sequence_id_; | |
unsigned int duration_;// 时钟间隔, 单位: 毫秒 | |
unsigned int start_; // 时钟开始时间, 单位: 毫秒 | |
on_timer_t on_timer_; // 时钟回调函数 | |
long is_running_; | |
template <typename handler_t> | |
timer_t(unsigned int duration, handler_t&& handler) | |
: duration_(duration), start_(GetTickCount()) | |
, on_timer_(std::forward<on_timer_t>(handler)) | |
, is_running_(0) {} | |
}; | |
typedef std::shared_ptr<timer_t> timer_ptr_t; | |
public: | |
timer_queue_t() : sequence_id_(0) { } | |
~timer_queue_t() { } | |
template <typename handler_t> | |
timer_ptr_t set_timer(handler_t&& handler, unsigned int duration) { | |
timer_ptr_t timer_ptr = std::make_shared<timer_t>(duration, std::move(handler)); | |
guard_t guard(lock_); | |
timer_ptr->sequence_id_ = sequence_id_++; | |
timers_.insert(timer_ptr); | |
return timer_ptr; | |
} | |
void kill_timer(timer_ptr_t& timer_ptr) { | |
guard_t guard(lock_); | |
erase(timer_ptr.get()); | |
timer_ptr.reset(); | |
} | |
unsigned int timeout_next() { | |
unsigned int current = GetTickCount(); | |
guard_t guard(lock_); | |
for (auto iter = timers_.begin(); iter != timers_.end(); ++iter) { | |
auto timer = *iter; | |
int ret_value = get_wait_time(timer->start_, timer->duration_, current); | |
if (ret_value > 0) { | |
return ret_value; | |
} | |
} | |
return INFINITE; | |
} | |
bool peek(timer_ptr_t& timer) { | |
unsigned int current = GetTickCount(); | |
guard_t guard(lock_); | |
if (timers_.empty()) { | |
return false; | |
} | |
auto first_timer = *timers_.begin(); | |
int ret_value = get_wait_time(first_timer->start_, first_timer->duration_, current); | |
if (ret_value <= 0) { | |
timer = first_timer; | |
timers_.erase(timers_.begin()); | |
first_timer->start_ += first_timer->duration_; | |
timers_.insert(first_timer); | |
return true; | |
} | |
return false; | |
} | |
void do_all_timer() { | |
for (;;) { | |
timer_ptr_t timer_ptr; | |
if (peek(timer_ptr) && timer_ptr != nullptr) { | |
// 防止多线程同时执行同一个timer_t | |
if (InterlockedCompareExchange(&timer_ptr->is_running_, 1, 0) == 0) { | |
if (timer_ptr->on_timer_ != nullptr) { | |
timer_ptr->on_timer_(std::static_pointer_cast<void>(timer_ptr)); | |
} | |
InterlockedExchange(&timer_ptr->is_running_ , 0); | |
} | |
} else { | |
break; | |
} | |
} | |
} | |
void clear() { | |
std::set<timer_ptr_t, less_timer_t> timers; | |
for (;;) | |
{ | |
guard_t guard(lock_); | |
timers.swap(timers_); | |
break; | |
} | |
timers.clear(); | |
} | |
private: | |
void erase(const timer_t* timer_ptr) { | |
guard_t guard(lock_); | |
for (auto iter = timers_.begin(); iter != timers_.end(); iter++) { | |
if ((*iter).get() == timer_ptr) { | |
timers_.erase(iter); | |
break; | |
} | |
} | |
} | |
private: | |
// 由于GetTickCount会溢出,需要校准 | |
static int get_wait_time( | |
unsigned int start, unsigned int duration, unsigned int current) { | |
unsigned int alarm = start + duration; | |
// 未溢出 | |
if (alarm >= start && current >= start) { | |
return (int)(alarm - current); | |
} else if (alarm >= start && current < start) { | |
// 当前时间溢出 | |
return (int)(alarm - (0x100000000 + current)); | |
} else if (alarm < start && current >= start) { | |
// 报警时间溢出 | |
return (int)((unsigned long long)start + (unsigned long long)duration - current); | |
} else /* if (alarm < start && current < start)*/ { | |
// 报警时间、当前时间都溢出 | |
return (int)(alarm - current); | |
} | |
} | |
struct less_timer_t : public std::binary_function<timer_ptr_t, timer_ptr_t, bool> { | |
bool operator()(const timer_ptr_t& left, const timer_ptr_t& right) const { | |
unsigned int current = GetTickCount(); | |
int left_correct = get_wait_time(left->start_, left->duration_, current); | |
int right_correct = get_wait_time(right->start_, right->duration_, current); | |
return ((left_correct < right_correct) || | |
(left_correct == right_correct && left->sequence_id_ < right->sequence_id_)); | |
} | |
}; | |
private: | |
lock_t lock_; | |
unsigned long long sequence_id_; | |
std::set<timer_ptr_t, less_timer_t> timers_; | |
}; | |
} // namespace detail | |
struct task_thread_pool_t::impl | |
{ | |
detail::iocp_t iocp_; | |
detail::timer_queue_t timer_queue_; | |
bool start(int num = 0) { | |
return iocp_.start( | |
num, | |
[&]()->unsigned int{ | |
return timer_queue_.timeout_next(); | |
}, | |
[&]() { | |
timer_queue_.do_all_timer(); | |
}); | |
} | |
void stop() { | |
iocp_.stop(); | |
timer_queue_.clear(); | |
} | |
template <typename handler_t> | |
void post_task(handler_t&& handler) | |
{ | |
iocp_.post_task(std::move(handler)); | |
} | |
template <typename handler_t> | |
task_thread_pool_t::handle_t set_timer(handler_t&& handler, unsigned int millisecond) | |
{ | |
auto handle = timer_queue_.set_timer(std::move(handler), millisecond); | |
post_task([](){ /* DO NOTHING, Just to process timer*/}); | |
return std::move(handle); | |
} | |
void kill_timer(handle_t handle) | |
{ | |
timer_queue_.kill_timer(std::static_pointer_cast<detail::timer_queue_t::timer_t>(handle)); | |
} | |
}; | |
// Creates the private implementation; worker threads are only created by start().
task_thread_pool_t::task_thread_pool_t()
{
    impl_ptr_.reset(new impl);
}

// Defined here, where `impl` is a complete type, so unique_ptr can destroy it.
task_thread_pool_t::~task_thread_pool_t() = default;
bool task_thread_pool_t::start(int num /* = 0 */) { | |
return impl_ptr_->start(num); | |
} | |
void task_thread_pool_t::stop() | |
{ | |
impl_ptr_->stop(); | |
} | |
void task_thread_pool_t::post_task(closure_t&& closure) | |
{ | |
impl_ptr_->post_task(std::move(closure)); | |
} | |
task_thread_pool_t::handle_t task_thread_pool_t::set_timer(unsigned int millisecond, on_timer_t&& closure) | |
{ | |
return impl_ptr_->set_timer(std::move(closure), millisecond); | |
} | |
void task_thread_pool_t::kill_timer(handle_t handle) | |
{ | |
impl_ptr_->kill_timer(handle); | |
} | |
} // namespace base |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef TASK_THREAD_POOL_HPP_
#define TASK_THREAD_POOL_HPP_

#include <functional>
#include <memory>

namespace base {

/// @brief Task thread pool: runs posted closures on worker threads and fires
/// periodic timers. The implementation is hidden behind a pimpl.
class task_thread_pool_t {
public:
    /// Opaque timer handle returned by set_timer() and accepted by kill_timer().
    using handle_t = std::shared_ptr<void>;
    /// Task to run once on a worker thread.
    using closure_t = std::function<void()>;
    /// Timer callback; receives the handle of the timer that fired.
    using on_timer_t = std::function<void(handle_t&)>;

    task_thread_pool_t();
    ~task_thread_pool_t();

    /// @brief Starts the worker threads.
    /// @param num number of workers; <= 0 selects the processor count.
    /// @return false if no worker thread could be started.
    bool start(int num = 0);
    /// @brief Stops the workers and discards all timers.
    void stop();
    /// @brief Queues @p closure to run once on a worker thread.
    void post_task(closure_t&& closure);
    /// @brief Registers a timer that calls @p closure every @p millisecond ms.
    handle_t set_timer(unsigned int millisecond, on_timer_t&& closure);
    /// @brief Cancels a timer previously returned by set_timer().
    void kill_timer(handle_t handle);

private:
    struct impl;
    std::unique_ptr<impl> impl_ptr_;
};

}  // namespace base

#endif  // TASK_THREAD_POOL_HPP_
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment