Last active: January 28, 2020 13:00
-
-
Save riga/9542c6b0f2e97d9763eb59b4fb93e718 to your computer and use it in GitHub Desktop.
Custom CMSSW TF ThreadPools
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Custom TensorFlow ThreadPool implementations for CMSSW
#include <atomic>
#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tbb/task_arena.h"
#include "tbb/task_group.h"
// thread pool that schedules all tasks within the caller thread | |
class NoThreadPoolImpl : public tensorflow::thread::ThreadPoolInterface { | |
public: | |
explicit NoThreadPoolImpl() : numScheduleCalled_(0) {} | |
void Schedule(std::function<void()> fn) override { | |
numScheduleCalled_ += 1; | |
fn(); | |
} | |
void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); } | |
void Cancel() override {} | |
int NumThreads() const override { return 1; } | |
int CurrentThreadId() const override { return -1; } | |
int GetNumScheduleCalled() { return numScheduleCalled_; } | |
private: | |
int numScheduleCalled_; | |
}; | |
// TBB thread pool | |
class TBBThreadPoolImpl : public tensorflow::thread::ThreadPoolInterface { | |
public: | |
// TODO: determine number of threads used by the tbb scheduler created in cmsRun | |
explicit TBBThreadPoolImpl(int nThreads = -1) : nThreads_(nThreads), numScheduleCalled_(0) {} | |
void Schedule(std::function<void()> fn) override { | |
numScheduleCalled_ += 1; | |
// use a task arena to avoid having unrelated tasks start | |
// running on this thread, which could potentially start deadlocks | |
tbb::task_arena taskArena; | |
tbb::task_group taskGroup; | |
// we are required to always call wait before destructor | |
auto doneWithTaskGroup = [&taskArena, &taskGroup](void*) { | |
taskArena.execute([&taskGroup]() { taskGroup.wait(); }); | |
}; | |
std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup)> taskGuard(&taskGroup, doneWithTaskGroup); | |
// schedule the task | |
taskArena.execute([&taskGroup, &fn] { taskGroup.run(fn); }); | |
// reset the task guard which will call wait | |
taskGuard.reset(); | |
} | |
void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); } | |
void Cancel() override {} | |
int NumThreads() const override { return nThreads_; } | |
// TODO: return a logical thread index between 0 and nThreads_ - 1, see | |
// https://gitlab.com/libeigen/eigen/blob/febd09dcc02c1429cd4abd3ddb3ed5108fcd8339/unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h#L39 | |
int CurrentThreadId() const override { return -1; } | |
int GetNumScheduleCalled() { return numScheduleCalled_; } | |
private: | |
int nThreads_; | |
int numScheduleCalled_; | |
}; |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.