#![feature(fnbox)]

use std::time::{Duration, Instant};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{Builder, JoinHandle};
use std::marker::PhantomData;
use std::boxed::FnBox;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::fmt::Write;
use std::sync::mpsc::channel;
use std::thread;
use std::env;

// Number of tasks a worker runs between `on_tick` calls.
pub const DEFAULT_TASKS_PER_TICK: usize = 10000;
// Initial capacity of the task queue.
const DEFAULT_QUEUE_CAPACITY: usize = 1000;
// Default number of worker threads.
const DEFAULT_THREAD_COUNT: usize = 1;
// How long an idle worker waits for new tasks before ticking once and then
// blocking indefinitely.
const NAP_SECS: u64 = 1;
// If the queue's capacity has grown past this, it is shrunk back to the
// default once drained.
const QUEUE_MAX_CAPACITY: usize = 8 * DEFAULT_QUEUE_CAPACITY;
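
/// Per-thread state owned by each worker. `on_task_started`/`on_task_finished`
/// wrap every task, and `on_tick` runs after every `tasks_per_tick` tasks or
/// after a worker has been idle for a while.
///
/// A minimal sketch of a custom context (illustrative, not part of the
/// original gist):
///
/// ```ignore
/// struct CountingContext {
///     finished: u64,
/// }
///
/// impl Context for CountingContext {
///     fn on_task_finished(&mut self) {
///         self.finished += 1;
///     }
///     fn on_tick(&mut self) {
///         println!("tasks finished so far: {}", self.finished);
///     }
/// }
/// ```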
pub trait Context: Send {
    fn on_task_started(&mut self) {}
    fn on_task_finished(&mut self) {}
    fn on_tick(&mut self) {}
}

#[derive(Default)]
pub struct DefaultContext;

impl Context for DefaultContext {}

pub trait ContextFactory<Ctx: Context> {
    fn create(&self) -> Ctx;
}

pub struct DefaultContextFactory;

impl<C: Context + Default> ContextFactory<C> for DefaultContextFactory {
    fn create(&self) -> C {
        C::default()
    }
}

/// A unit of work: a boxed `FnBox` closure that receives the worker's context.
pub struct Task<C> {
    task: Box<FnBox(&mut C) + Send>,
}

impl<C: Context> Task<C> {
    fn new<F>(job: F) -> Task<C>
    where
        for<'r> F: FnOnce(&'r mut C) + Send + 'static,
    {
        Task {
            task: Box::new(job),
        }
    }
}

// First-in-first-out task queue.
pub struct FifoQueue<C> {
    queue: VecDeque<Task<C>>,
}

impl<C: Context> FifoQueue<C> {
    fn new() -> FifoQueue<C> {
        FifoQueue {
            queue: VecDeque::with_capacity(DEFAULT_QUEUE_CAPACITY),
        }
    }

    fn push(&mut self, task: Task<C>) {
        self.queue.push_back(task);
    }

    fn pop(&mut self) -> Option<Task<C>> {
        let task = self.queue.pop_front();
        // If the queue grew well past its default capacity and is now empty,
        // release the excess memory by allocating a fresh queue.
        if self.queue.is_empty() && self.queue.capacity() > QUEUE_MAX_CAPACITY {
            self.queue = VecDeque::with_capacity(DEFAULT_QUEUE_CAPACITY);
        }
        task
    }
}

pub struct ThreadPoolBuilder<C, F> {
    name: String,
    thread_count: usize,
    tasks_per_tick: usize,
    stack_size: Option<usize>,
    factory: F,
    _ctx: PhantomData<C>,
}

impl<C: Context + Default + 'static> ThreadPoolBuilder<C, DefaultContextFactory> {
    pub fn with_default_factory(name: String) -> ThreadPoolBuilder<C, DefaultContextFactory> {
        ThreadPoolBuilder::new(name, DefaultContextFactory)
    }
}

impl<C: Context + 'static, F: ContextFactory<C>> ThreadPoolBuilder<C, F> {
    pub fn new(name: String, factory: F) -> ThreadPoolBuilder<C, F> {
        ThreadPoolBuilder {
            name: name,
            thread_count: DEFAULT_THREAD_COUNT,
            tasks_per_tick: DEFAULT_TASKS_PER_TICK,
            stack_size: None,
            factory: factory,
            _ctx: PhantomData,
        }
    }

    pub fn thread_count(mut self, count: usize) -> ThreadPoolBuilder<C, F> {
        self.thread_count = count;
        self
    }

    pub fn tasks_per_tick(mut self, count: usize) -> ThreadPoolBuilder<C, F> {
        self.tasks_per_tick = count;
        self
    }

    pub fn stack_size(mut self, size: usize) -> ThreadPoolBuilder<C, F> {
        self.stack_size = Some(size);
        self
    }

    pub fn build(self) -> ThreadPool<C> {
        ThreadPool::new(
            self.name,
            self.thread_count,
            self.tasks_per_tick,
            self.stack_size,
            self.factory,
        )
    }
}

struct ScheduleState<Ctx> {
    queue: FifoQueue<Ctx>,
    stopped: bool,
}

/// `ThreadPool` is used to execute tasks in parallel.
/// Each task is pushed into a shared FIFO queue, and when a thread
/// is ready to process a task, it pops the next one from that queue.
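///
/// A minimal usage sketch (illustrative, not part of the original gist):
///
/// ```ignore
/// let mut pool = ThreadPoolBuilder::with_default_factory("example".to_string())
///     .thread_count(4)
///     .build();
/// pool.execute(|_: &mut DefaultContext| println!("hello from the pool"));
/// pool.stop().unwrap();
/// ```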
pub struct ThreadPool<Ctx> {
    state: Arc<(Mutex<ScheduleState<Ctx>>, Condvar)>,
    threads: Vec<JoinHandle<()>>,
    task_count: Arc<AtomicUsize>,
}

impl<Ctx> ThreadPool<Ctx>
where
    Ctx: Context + 'static,
{
    fn new<C: ContextFactory<Ctx>>(
        name: String,
        num_threads: usize,
        tasks_per_tick: usize,
        stack_size: Option<usize>,
        f: C,
    ) -> ThreadPool<Ctx> {
        assert!(num_threads >= 1);
        let state = ScheduleState {
            queue: FifoQueue::new(),
            stopped: false,
        };
        let state = Arc::new((Mutex::new(state), Condvar::new()));
        let mut threads = Vec::with_capacity(num_threads);
        let task_count = Arc::new(AtomicUsize::new(0));
        // Threadpool threads
        for _ in 0..num_threads {
            let state = state.clone();
            let task_num = task_count.clone();
            let ctx = f.create();
            let mut tb = Builder::new().name(name.clone());
            if let Some(stack_size) = stack_size {
                tb = tb.stack_size(stack_size);
            }
            let thread = tb.spawn(move || {
                let mut worker = Worker::new(state, task_num, tasks_per_tick, ctx);
                worker.run();
            }).unwrap();
            threads.push(thread);
        }
        ThreadPool {
            state: state,
            threads: threads,
            task_count: task_count,
        }
    }

    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce(&mut Ctx) + Send + 'static,
        Ctx: Context,
    {
        let task = Task::new(job);
        let &(ref lock, ref cvar) = &*self.state;
        {
            let mut state = lock.lock().unwrap();
            if state.stopped {
                return;
            }
            state.queue.push(task);
            cvar.notify_one();
        }
        self.task_count.fetch_add(1, AtomicOrdering::SeqCst);
    }

    #[inline]
    pub fn get_task_count(&self) -> usize {
        self.task_count.load(AtomicOrdering::SeqCst)
    }

    pub fn stop(&mut self) -> Result<(), String> {
        let &(ref lock, ref cvar) = &*self.state;
        {
            let mut state = lock.lock().unwrap();
            state.stopped = true;
            cvar.notify_all();
        }
        let mut err_msg = String::new();
        for t in self.threads.drain(..) {
            if let Err(e) = t.join() {
                write!(&mut err_msg, "Failed to join thread with err: {:?};", e).unwrap();
            }
        }
        if !err_msg.is_empty() {
            return Err(err_msg);
        }
        Ok(())
    }
}

// Each thread has a worker.
struct Worker<C> {
    state: Arc<(Mutex<ScheduleState<C>>, Condvar)>,
    task_count: Arc<AtomicUsize>,
    tasks_per_tick: usize,
    task_counter: usize,
    ctx: C,
}

impl<C> Worker<C>
where
    C: Context,
{
    fn new(
        state: Arc<(Mutex<ScheduleState<C>>, Condvar)>,
        task_count: Arc<AtomicUsize>,
        tasks_per_tick: usize,
        ctx: C,
    ) -> Worker<C> {
        Worker {
            state: state,
            task_count: task_count,
            tasks_per_tick: tasks_per_tick,
            task_counter: 0,
            ctx: ctx,
        }
    }

    fn next_task(&mut self) -> Option<Task<C>> {
        let &(ref lock, ref cvar) = &*self.state;
        let mut state = lock.lock().unwrap();
        let mut timeout = Some(Duration::from_secs(NAP_SECS));
        loop {
            if state.stopped {
                return None;
            }
            match state.queue.pop() {
                Some(t) => {
                    self.task_counter += 1;
                    return Some(t);
                }
                None => {
                    state = match timeout {
                        // First wait with a timeout, so an idle worker still ticks.
                        Some(t) => cvar.wait_timeout(state, t).unwrap().0,
                        // The nap expired and the queue is still empty: tick once,
                        // then block until a task is pushed or the pool stops.
                        None => {
                            self.task_counter = 0;
                            self.ctx.on_tick();
                            cvar.wait(state).unwrap()
                        }
                    };
                    timeout = None;
                }
            }
        }
    }

    fn run(&mut self) {
        loop {
            let task = match self.next_task() {
                None => return,
                Some(t) => t,
            };
            self.ctx.on_task_started();
            (task.task).call_box((&mut self.ctx,));
            self.ctx.on_task_finished();
            self.task_count.fetch_sub(1, AtomicOrdering::SeqCst);
            // Tick after every `tasks_per_tick` completed tasks.
            if self.task_counter == self.tasks_per_tick {
                self.task_counter = 0;
                self.ctx.on_tick();
            }
        }
    }
}

fn main() {
    let mut thread_count: usize = 8;
    let args: Vec<String> = env::args().collect();
    if args.len() == 2 {
        thread_count = args[1].parse().unwrap();
    }
    println!("Using {} threads", thread_count);

    let mut task_pool = ThreadPoolBuilder::with_default_factory("test".to_string())
        .thread_count(thread_count)
        .build();

    // Reporter thread: counts completed tasks and prints the throughput
    // roughly once per second.
    let (tx, rx) = channel();
    thread::spawn(move || {
        let mut total_counts = 0;
        let mut t = Instant::now();
        let interval = Duration::from_secs(1);
        while rx.recv().is_ok() {
            total_counts += 1;
            if t.elapsed() >= interval {
                t = Instant::now();
                println!("{} QPS", total_counts);
                total_counts = 0;
            }
        }
    });

    // Flood the pool with trivial tasks for 10 seconds, then shut it down.
    let t = Instant::now();
    let run_time = Duration::from_secs(10);
    loop {
        let tx1 = tx.clone();
        task_pool.execute(move |_: &mut DefaultContext| {
            tx1.send(0).unwrap();
        });
        if t.elapsed() >= run_time {
            break;
        }
    }
    task_pool.stop().unwrap();
}