Last active July 15, 2017 23:26
-
-
Save bartvm/f87965f902a17c3a9e80b5bfafa3fc97 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp | |
index f1e09b0e..d78c03e8 100644 | |
--- a/torch/csrc/autograd/engine.cpp | |
+++ b/torch/csrc/autograd/engine.cpp | |
@@ -110,7 +110,7 @@ Engine::~Engine() = default; | |
auto Engine::thread_main(std::shared_ptr<ReadyQueue> queue, int device) -> void { | |
THInferNumThreads(); | |
AutoGPU guard(device); | |
- while (1) { | |
+ while (!exit.back().load()) { | |
FunctionTask task = queue->pop_back(); | |
if (!task.base->has_error.load()) { | |
try { | |
@@ -124,6 +124,7 @@ auto Engine::thread_main(std::shared_ptr<ReadyQueue> queue, int device) -> void | |
task.base->not_done.notify_all(); | |
} | |
} | |
+ exit.pop_back(); | |
} | |
auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void { | |
@@ -299,7 +300,7 @@ auto Engine::execute(const function_list& input_roots, | |
variable_list& inputs, | |
bool keep_graph, | |
const callback_map& callbacks) -> void { | |
- std::call_once(start_threads_flag, &Engine::start_threads, this); | |
+ start_threads(); | |
// Callbacks are only valid for the duration of this run and should always be cleared | |
ClearCallbacks _cb_guard(post_callbacks, post_callbacks_lock); | |
@@ -351,6 +352,8 @@ auto Engine::execute(const function_list& input_roots, | |
post_callbacks[i](); | |
cb_lock.lock(); | |
} | |
+ | |
+ exit.back().store(true); | |
} | |
void Engine::queue_callback(std::function<void()> callback) { | |
@@ -371,6 +374,7 @@ auto Engine::start_threads() -> void { | |
num_devices = 0; | |
} | |
#endif | |
+ exit.emplace_back(false); | |
int num_threads = num_devices + 1; | |
ready_queues = std::vector<std::shared_ptr<ReadyQueue>>(num_threads); | |
for (int i = 0; i < num_threads; ++i) { | |
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h | |
index a0308f7d..1722dd92 100644 | |
--- a/torch/csrc/autograd/engine.h | |
+++ b/torch/csrc/autograd/engine.h | |
@@ -4,6 +4,7 @@ | |
// to "root" variables (variables created by the user with requires_grad=True). | |
#include <Python.h> | |
+#include <atomic> | |
#include <deque> | |
#include <memory> | |
#include <unordered_map> | |
@@ -55,8 +56,8 @@ protected: | |
virtual void thread_main(std::shared_ptr<ReadyQueue> queue, int device); | |
virtual void thread_on_exception(FunctionTask& task, std::exception& e); | |
- std::once_flag start_threads_flag; | |
std::vector<std::shared_ptr<ReadyQueue>> ready_queues; | |
+ std::deque<std::atomic_bool> exit; | |
std::vector<std::function<void()>> post_callbacks; | |
std::mutex post_callbacks_lock; | |
}; |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.