Skip to content

Instantly share code, notes, and snippets.

@riga
Last active January 28, 2020 13:00
Show Gist options
  • Save riga/9542c6b0f2e97d9763eb59b4fb93e718 to your computer and use it in GitHub Desktop.
Save riga/9542c6b0f2e97d9763eb59b4fb93e718 to your computer and use it in GitHub Desktop.
Custom CMSSW TF ThreadPools
// custom TensorFlow ThreadPool implementations for CMSSW
#include <atomic>
#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tbb/task_arena.h"
#include "tbb/task_group.h"
// thread pool that schedules all tasks within the caller thread
class NoThreadPoolImpl : public tensorflow::thread::ThreadPoolInterface {
public:
explicit NoThreadPoolImpl() : numScheduleCalled_(0) {}
void Schedule(std::function<void()> fn) override {
numScheduleCalled_ += 1;
fn();
}
void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); }
void Cancel() override {}
int NumThreads() const override { return 1; }
int CurrentThreadId() const override { return -1; }
int GetNumScheduleCalled() { return numScheduleCalled_; }
private:
int numScheduleCalled_;
};
// TBB thread pool
class TBBThreadPoolImpl : public tensorflow::thread::ThreadPoolInterface {
public:
// TODO: determine number of threads used by the tbb scheduler created in cmsRun
explicit TBBThreadPoolImpl(int nThreads = -1) : nThreads_(nThreads), numScheduleCalled_(0) {}
void Schedule(std::function<void()> fn) override {
numScheduleCalled_ += 1;
// use a task arena to avoid having unrelated tasks start
// running on this thread, which could potentially start deadlocks
tbb::task_arena taskArena;
tbb::task_group taskGroup;
// we are required to always call wait before destructor
auto doneWithTaskGroup = [&taskArena, &taskGroup](void*) {
taskArena.execute([&taskGroup]() { taskGroup.wait(); });
};
std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup)> taskGuard(&taskGroup, doneWithTaskGroup);
// schedule the task
taskArena.execute([&taskGroup, &fn] { taskGroup.run(fn); });
// reset the task guard which will call wait
taskGuard.reset();
}
void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); }
void Cancel() override {}
int NumThreads() const override { return nThreads_; }
// TODO: return a logical thread index between 0 and nThreads_ - 1, see
// https://gitlab.com/libeigen/eigen/blob/febd09dcc02c1429cd4abd3ddb3ed5108fcd8339/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h#L39
int CurrentThreadId() const override { return -1; }
int GetNumScheduleCalled() { return numScheduleCalled_; }
private:
int nThreads_;
int numScheduleCalled_;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment