diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h | |
index f04b472e..7ff9a39f 100644 | |
--- a/torch/csrc/autograd/autograd.h | |
+++ b/torch/csrc/autograd/autograd.h | |
@@ -2,12 +2,14 @@ | |
#define THP_AUTOGRAD_H | |
PyObject * THPAutograd_initExtension(PyObject *_unused); | |
-bool THPAutograd_initFunctions(PyObject* module); | |
+void THPAutograd_initFunctions(); | |
namespace torch { namespace autograd { | |
void initAutogradClosureBindings(PyObject* module); | |
+PyMethodDef* python_functions(); | |
+ | |
}} | |
#include "torch/csrc/autograd/python_function.h" | |
diff --git a/torch/csrc/autograd/edge.h b/torch/csrc/autograd/edge.h | |
new file mode 100644 | |
index 00000000..e6b34911 | |
--- /dev/null | |
+++ b/torch/csrc/autograd/edge.h | |
@@ -0,0 +1,56 @@ | |
+#pragma once | |
+ | |
+#include <cstdint> | |
+#include <functional> | |
+#include <memory> | |
+ | |
+#include "torch/csrc/utils/hash.h" | |
+ | |
+namespace torch { namespace autograd { | |
+ | |
+struct Function; | |
+ | |
+/// Represents a particular input of a function. | |
+struct Edge { | |
+ Edge() noexcept : function(nullptr), input_nr(0) {} | |
+ | |
+ Edge(std::shared_ptr<Function> function_, uint32_t input_nr_) noexcept | |
+ : function(std::move(function_)), input_nr(input_nr_) {} | |
+ | |
+ /// Convenience method to test if an edge is valid. | |
+ bool is_valid() const noexcept { | |
+ return function != nullptr; | |
+ } | |
+ | |
+ // Required for use in associative containers. | |
+ bool operator==(const Edge& other) const noexcept { | |
+ return this->function == other.function && this->input_nr == other.input_nr; | |
+ } | |
+ | |
+ bool operator!=(const Edge& other) const noexcept { | |
+ return !(*this == other); | |
+ } | |
+ | |
+ /// The function this `Edge` points to. | |
+ std::shared_ptr<Function> function; | |
+ | |
+ /// The identifier of a particular input to the function. | |
+ uint32_t input_nr; | |
+}; | |
+}} // namespace torch::autograd | |
+ | |
+// The idiomatic way of enabling use of a custom type as the key of hash | |
+// containers in C++11. This specialization removes the need to pass | |
+// a custom hasher to std::unordered_{map, set}. | |
+// See http://en.cppreference.com/w/cpp/utility/hash for more information. | |
+namespace std { | |
+template <> | |
+struct hash<torch::autograd::Edge> { | |
+ // These type aliases are required by the standard. | |
+ using argument_type = torch::autograd::Edge; | |
+ using result_type = size_t; | |
+ result_type operator()(const argument_type& edge) const noexcept { | |
+ return torch::get_hash(edge.function, edge.input_nr); | |
+ } | |
+}; | |
+} // namespace std | |
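The `std::hash` specialization above is the standard C++11 idiom for making a user-defined type usable as a key in unordered containers without passing an explicit hasher. A minimal, self-contained sketch of the same idiom, using a simplified `ToyEdge` and a hand-rolled hash combiner in place of `torch::get_hash` (both hypothetical, for illustration only):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_set>

// Simplified stand-in for torch::autograd::Edge (illustration only).
struct ToyEdge {
  std::shared_ptr<int> function;  // stands in for std::shared_ptr<Function>
  uint32_t input_nr;
  bool operator==(const ToyEdge& other) const noexcept {
    return function == other.function && input_nr == other.input_nr;
  }
};

namespace std {
template <>
struct hash<ToyEdge> {
  size_t operator()(const ToyEdge& edge) const noexcept {
    // Hand-rolled combiner standing in for torch::get_hash.
    const size_t h1 = std::hash<std::shared_ptr<int>>()(edge.function);
    const size_t h2 = std::hash<uint32_t>()(edge.input_nr);
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};
} // namespace std

int main() {
  auto fn = std::make_shared<int>(0);
  // No custom hasher is passed: std::hash<ToyEdge> is found automatically.
  std::unordered_set<ToyEdge> edges{{fn, 0}, {fn, 1}, {fn, 0}};
  std::cout << edges.size() << "\n";  // prints 2; the duplicate (fn, 0) collapses
  return 0;
}

With the specialization in scope, `std::unordered_set<ToyEdge>` works out of the box; the real `Edge` gains the same property from the code above.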
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp | |
index 756f56e4..369ca601 100644 | |
--- a/torch/csrc/autograd/engine.cpp | |
+++ b/torch/csrc/autograd/engine.cpp | |
@@ -1,5 +1,9 @@ | |
#include "torch/csrc/autograd/engine.h" | |
+ | |
+#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/functions/basic_ops.h" | |
+#include "torch/csrc/autograd/grad_mode.h" | |
+#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/utils/auto_gpu.h" | |
#include <atomic> | |
@@ -7,6 +11,7 @@ | |
#include <cstdint> | |
#include <functional> | |
#include <iostream> | |
+#include <memory> | |
#include <mutex> | |
#include <set> | |
#include <string> | |
@@ -14,6 +19,7 @@ | |
#include <unordered_set> | |
#include <typeinfo> | |
#include <sstream> | |
+#include <queue> | |
#include <TH/TH.h> | |
#ifdef WITH_CUDA | |
@@ -48,13 +54,19 @@ struct FunctionTask { | |
, inputs(std::move(inputs)) {} | |
}; | |
+struct CompareFunctionTaskTime { | |
+ bool operator()(FunctionTask const & t1, FunctionTask const & t2) { | |
+ return t1.fn->sequence_nr() < t2.fn->sequence_nr(); | |
+ } | |
+}; | |
+ | |
struct ReadyQueue { | |
- std::deque<FunctionTask> queue; | |
+ std::priority_queue<FunctionTask, std::vector<FunctionTask>, CompareFunctionTaskTime> heap; | |
std::condition_variable not_empty; | |
std::mutex mutex; | |
- void push_front(FunctionTask item); | |
- FunctionTask pop_back(); | |
+ void push(FunctionTask item); | |
+ FunctionTask pop(); | |
}; | |
struct GraphTask { | |
@@ -64,47 +76,66 @@ struct GraphTask { | |
std::atomic_bool has_error; | |
std::atomic<uint64_t> outstanding_tasks; | |
bool keep_graph; | |
- bool has_any_work; | |
+ bool grad_mode; | |
std::mutex mutex; | |
// Notified when a task finishes executing. Check outstanding_tasks to see | |
// if all tasks are done. | |
std::condition_variable not_done; | |
- const Engine::pre_callback_map& pre_callbacks; | |
- const Engine::post_callback_map& post_callbacks; | |
std::unordered_map<Function*, InputBuffer> not_ready; | |
std::unordered_map<Function*, int> dependencies; | |
+ struct ExecInfo { | |
+ struct Capture { | |
+ Capture(int input_idx, int output_idx) : input_idx(input_idx), output_idx(output_idx) {} | |
+ int input_idx; // within Function inputs | |
+ int output_idx; // within the output vector of a GraphTask | |
+ }; | |
+ | |
+ bool should_execute() const { | |
+ return needed || captures; | |
+ } | |
+ | |
+ bool needed = false; | |
+ std::unique_ptr<std::vector<Capture>> captures; | |
+ }; | |
+ // Exec info has somewhat complicated semantics. If it's empty, the task is | |
+ // run in "default" mode, which means that all next_edges we encounter should | |
+ // get executed. If it's not empty, only functions that have an entry with | |
+ // needed == true should be executed. | |
+ std::unordered_map<Function*, ExecInfo> exec_info; | |
+ std::vector<Variable> captured_vars; | |
+ | |
+ void init_to_execute(Function& graph_root, const edge_list& captures); | |
+ | |
int owner; | |
- GraphTask(bool keep_graph, const Engine::pre_callback_map& pre_callbacks, const Engine::post_callback_map& post_callbacks) | |
+ GraphTask(bool keep_graph, bool grad_mode) | |
: exception() | |
, has_error(false) | |
, outstanding_tasks(0) | |
, keep_graph(keep_graph) | |
- , has_any_work(false) | |
+ , grad_mode(grad_mode) | |
, mutex() | |
, not_done() | |
- , pre_callbacks(pre_callbacks) | |
- , post_callbacks(post_callbacks) | |
, not_ready() | |
, dependencies() | |
, owner(NO_DEVICE) {} | |
}; | |
-auto ReadyQueue::push_front(FunctionTask item) -> void { | |
+auto ReadyQueue::push(FunctionTask item) -> void { | |
{ | |
std::lock_guard<std::mutex> lock(mutex); | |
++item.base->outstanding_tasks; | |
- queue.push_front(std::move(item)); | |
+ heap.push(std::move(item)); | |
} | |
not_empty.notify_one(); | |
} | |
-auto ReadyQueue::pop_back() -> FunctionTask { | |
+auto ReadyQueue::pop() -> FunctionTask { | |
std::unique_lock<std::mutex> lock(mutex); | |
- not_empty.wait(lock, [this]{ return !queue.empty(); }); | |
- auto task = std::move(queue.back()); queue.pop_back(); | |
+ not_empty.wait(lock, [this]{ return !heap.empty(); }); | |
+ auto task = std::move(const_cast<FunctionTask&>(heap.top())); heap.pop(); | |
return task; | |
} | |
@@ -138,8 +169,9 @@ auto Engine::thread_init(int device) -> void { | |
auto Engine::thread_main(GraphTask *graph_task) -> void { | |
auto queue = ready_queues[worker_device + 1]; | |
while (!graph_task || graph_task->outstanding_tasks > 0) { | |
- FunctionTask task = queue->pop_back(); | |
+ FunctionTask task = queue->pop(); | |
if (task.fn && !task.base->has_error.load()) { | |
+ GradMode::set_enabled(task.base->grad_mode); | |
try { | |
evaluate_function(task); | |
} catch (std::exception& e) { | |
@@ -166,7 +198,7 @@ auto Engine::thread_main(GraphTask *graph_task) -> void { | |
if (--task.base->outstanding_tasks == 0) { | |
// Synchronize outstanding_tasks with queue mutex | |
std::atomic_thread_fence(std::memory_order_release); | |
- ready_queue(base_owner).push_front(FunctionTask(task.base, nullptr, InputBuffer(0))); | |
+ ready_queue(base_owner).push(FunctionTask(task.base, nullptr, InputBuffer(0))); | |
} | |
} | |
} | |
@@ -182,14 +214,14 @@ auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void | |
} | |
static variable_list call_pre_hooks(Function& fn, variable_list inputs) { | |
- for (auto& hook : fn.pre_hooks) { | |
+ for (const auto& hook : fn.pre_hooks()) { | |
inputs = (*hook)(inputs); | |
} | |
return inputs; | |
} | |
static variable_list call_post_hooks(Function& fn, variable_list outputs, variable_list inputs) { | |
- for (auto& hook : fn.post_hooks) { | |
+ for (const auto& hook : fn.post_hooks()) { | |
outputs = (*hook)(outputs, inputs); | |
} | |
return outputs; | |
@@ -199,61 +231,57 @@ static variable_list call_function(FunctionTask& task) { | |
auto& fn = *task.fn; | |
auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(task.inputs))); | |
- auto& pre_callbacks = task.base->pre_callbacks; | |
- for (auto it_p = pre_callbacks.equal_range(&fn); it_p.first != it_p.second; ++it_p.first) { | |
- auto& callback = it_p.first->second; | |
- if (!callback(&fn, inputs)) return variable_list(fn.next_functions.size()); | |
+ if(!task.base->keep_graph) { | |
+ fn.will_release_variables(); | |
} | |
- | |
auto outputs = fn(inputs); | |
- auto& post_callbacks = task.base->post_callbacks; | |
- for (auto it_p = post_callbacks.equal_range(&fn); it_p.first != it_p.second; ++it_p.first) { | |
- auto& callback = it_p.first->second; | |
- if (!callback(&fn, inputs, outputs)) return variable_list(fn.next_functions.size()); | |
- } | |
- | |
return call_post_hooks(fn, std::move(outputs), std::move(inputs)); | |
} | |
auto Engine::evaluate_function(FunctionTask& task) -> void { | |
+ // If exec_info is not empty, we have to instrument the execution | |
+ auto & exec_info = task.base->exec_info; | |
+ if (!exec_info.empty()) { | |
+ auto & fn_info = exec_info.at(task.fn.get()); | |
+ if (auto *capture_vec = fn_info.captures.get()) { | |
+ std::lock_guard<std::mutex> lock(task.base->mutex); | |
+ for (auto capture : *capture_vec) { | |
+ task.base->captured_vars[capture.output_idx] = task.inputs[capture.input_idx]; | |
+ } | |
+ } | |
+ if (!fn_info.needed) return; | |
+ } | |
+ | |
auto outputs = call_function(task); | |
auto& fn = *task.fn; | |
if (!task.base->keep_graph) { | |
- fn.releaseVariables(); | |
+ fn.release_variables(); | |
} | |
- if (outputs.size() != fn.next_functions.size()) { | |
+ if (outputs.size() != fn.num_outputs()) { | |
std::stringstream ss; | |
ss << "Function '" << fn.name() << "' returned an invalid number of outputs - expected "; | |
- ss << fn.next_functions.size() << ", but got " << outputs.size(); | |
+ ss << fn.num_outputs() << ", but got " << outputs.size(); | |
throw std::runtime_error(ss.str()); | |
} | |
int num_outputs = outputs.size(); | |
+ if (num_outputs == 0) return; // Don't even acquire the mutex | |
+ std::lock_guard<std::mutex> lock(task.base->mutex); | |
for (int i = 0; i < num_outputs; ++i) { | |
auto& output = outputs[i]; | |
- auto& next_fn = fn.next_functions[i].first; | |
- int input_nr = fn.next_functions[i].second; | |
- | |
- if (!next_fn) { | |
- continue; | |
- } | |
+ const auto& next = fn.next_edge(i); | |
- // Stochastic functions are placed in the ready queue by | |
- // compute_dependencies, so we have to skip them here. | |
- if (next_fn->is_stochastic || !next_fn->is_executable) { | |
- continue; | |
- } | |
+ if (!next.is_valid()) continue; | |
- std::lock_guard<std::mutex> lock(task.base->mutex); | |
// Check if the next function is ready to be computed | |
bool is_ready = false; | |
auto& dependencies = task.base->dependencies; | |
- auto it = dependencies.find(next_fn.get()); | |
+ auto it = dependencies.find(next.function.get()); | |
if (it == dependencies.end()) { | |
- auto name = next_fn->name(); | |
+ auto name = next.function->name(); | |
throw std::runtime_error(std::string("dependency not found for ") + name); | |
} else if (--it->second == 0) { | |
dependencies.erase(it); | |
@@ -261,72 +289,53 @@ auto Engine::evaluate_function(FunctionTask& task) -> void { | |
} | |
auto& not_ready = task.base->not_ready; | |
- auto not_ready_it = not_ready.find(next_fn.get()); | |
+ auto not_ready_it = not_ready.find(next.function.get()); | |
if (not_ready_it == not_ready.end()) { | |
+ // Skip functions that aren't supposed to be executed | |
+ if (!exec_info.empty()) { | |
+ auto it = exec_info.find(next.function.get()); | |
+ if (it == exec_info.end() || !it->second.should_execute()) { | |
+ continue; | |
+ } | |
+ } | |
// No buffers have been allocated for the function | |
- InputBuffer input_buffer(next_fn->num_inputs); | |
- input_buffer.add(input_nr, std::move(output)); | |
+ InputBuffer input_buffer(next.function->num_inputs()); | |
+ input_buffer.add(next.input_nr, std::move(output)); | |
if (is_ready) { | |
auto& queue = ready_queue(input_buffer.device()); | |
- queue.push_front(FunctionTask(task.base, next_fn, std::move(input_buffer))); | |
+ queue.push(FunctionTask(task.base, next.function, std::move(input_buffer))); | |
} else { | |
- not_ready.emplace(next_fn.get(), std::move(input_buffer)); | |
+ not_ready.emplace(next.function.get(), std::move(input_buffer)); | |
} | |
} else { | |
// The function already has a buffer | |
auto &input_buffer = not_ready_it->second; | |
- input_buffer.add(input_nr, std::move(output)); | |
+ input_buffer.add(next.input_nr, std::move(output)); | |
if (is_ready) { | |
auto& queue = ready_queue(input_buffer.device()); | |
- queue.push_front(FunctionTask(task.base, next_fn, std::move(input_buffer))); | |
+ queue.push(FunctionTask(task.base, next.function, std::move(input_buffer))); | |
not_ready.erase(not_ready_it); | |
} | |
} | |
} | |
} | |
-/** Finds all stochastic functions and appends them to the queue */ | |
-auto Engine::find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task) -> void { | |
- std::unordered_set<Function*> seen {graph_root}; | |
- function_queue search_queue {graph_root}; | |
- while (search_queue.size() > 0) { | |
- auto fn = search_queue.back(); search_queue.pop_back(); | |
- for (auto& next_fn_pair : fn->next_functions) { | |
- auto& next_fn = next_fn_pair.first; | |
- Function* next_ptr = next_fn.get(); | |
- if (!next_ptr) continue; | |
- if (next_ptr->is_stochastic && next_ptr->is_executable && seen.count(next_ptr) == 0) { | |
- ready_queue(-1).push_front(FunctionTask(&task, next_fn, InputBuffer(0))); | |
- queue.push_back(next_ptr); | |
- task.has_any_work = true; | |
- } | |
- if (seen.count(next_ptr) == 0) { | |
- seen.insert(next_ptr); | |
- search_queue.push_back(next_ptr); | |
- } | |
- } | |
- } | |
-} | |
- | |
-/** Computes the number of dependencies for each function which requires grad */ | |
-auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void { | |
+/* Computes the number of dependencies for each function which requires grad */ | |
+auto Engine::compute_dependencies(Function* root, GraphTask& task) -> void { | |
// Just to make sure that they will never be added to the queue again | |
- std::unordered_set<Function*> seen(queue.begin(), queue.end()); | |
+ std::unordered_set<Function*> seen; | |
+ std::vector<Function*> queue { root }; | |
// Queue contains all nodes that will start propagating gradients. | |
// We no longer have to expand functions that don't require grad. | |
auto& dependencies = task.dependencies; | |
while (queue.size() > 0) { | |
- auto fn = std::move(queue.back()); queue.pop_back(); | |
- for (auto& next_fn_pair : fn->next_functions) { | |
- Function* next_ptr = next_fn_pair.first.get(); | |
- if (!next_ptr) continue; | |
- if (!next_ptr->is_executable) continue; | |
- if (next_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already | |
- dependencies[next_ptr] += 1; | |
- if (seen.count(next_ptr) == 0) { | |
- seen.insert(next_ptr); | |
- queue.push_back(next_ptr); | |
+ auto fn = queue.back(); queue.pop_back(); | |
+ for (const auto& edge : fn->next_edges()) { | |
+ if (auto next_ptr = edge.function.get()) { | |
+ dependencies[next_ptr] += 1; | |
+ const bool was_inserted = seen.insert(next_ptr).second; | |
+ if (was_inserted) queue.push_back(next_ptr); | |
} | |
} | |
} | |
@@ -348,40 +357,25 @@ struct ClearCallbacks { | |
std::mutex& callbacks_lock; | |
}; | |
-auto Engine::execute(const function_list& input_roots, | |
+auto Engine::execute(const edge_list& input_roots, | |
const variable_list& inputs, | |
bool keep_graph, | |
- const pre_callback_map& pre_callbacks, | |
- const post_callback_map& post_callbacks) -> void { | |
+ bool create_graph, | |
+ const edge_list& outputs) -> variable_list { | |
std::call_once(start_threads_flag, &Engine::start_threads, this); | |
// Callbacks are only valid for the duration of this run and should always be cleared | |
ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock); | |
- GraphTask graph_task(keep_graph, pre_callbacks, post_callbacks); | |
- | |
+ GraphTask graph_task(keep_graph, create_graph); | |
std::unique_lock<std::mutex> lock(graph_task.mutex); | |
+ // Now compute the dependencies for all executable functions and queue the root | |
auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs); | |
- function_queue roots; | |
- for (auto entry : input_roots) { | |
- if (entry.first->is_executable) { | |
- graph_task.has_any_work = true; | |
- roots.push_back(graph_root.get()); | |
- ready_queue(-1).push_front(FunctionTask(&graph_task, graph_root, InputBuffer(0))); | |
- break; | |
- } | |
+ compute_dependencies(graph_root.get(), graph_task); | |
+ if (!outputs.empty()) { | |
+ graph_task.init_to_execute(*graph_root, outputs); | |
} | |
- | |
- // Search the graph and find all stochastic functions. Append them to the queue. | |
- find_stochastic_functions(roots, graph_root.get(), graph_task); | |
- | |
- if (!graph_task.has_any_work) { | |
- throw std::runtime_error( | |
- "there are no graph nodes that require computing gradients"); | |
- } | |
- | |
- // Now compute the dependencies for all executable functions | |
- compute_dependencies(std::move(roots), graph_task); | |
+ ready_queue(-1).push(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0))); | |
// Not a worker | |
if (worker_device == NO_DEVICE) { | |
@@ -413,6 +407,8 @@ auto Engine::execute(const function_list& input_roots, | |
final_callbacks[i](); | |
cb_lock.lock(); | |
} | |
+ | |
+ return graph_task.captured_vars; | |
} | |
void Engine::queue_callback(std::function<void()> callback) { | |
@@ -444,4 +440,69 @@ auto Engine::start_threads() -> void { | |
} | |
} | |
+void GraphTask::init_to_execute(Function& graph_root, const edge_list& outputs) { | |
+ exec_info[&graph_root].needed = true; | |
+ | |
+ int output_idx = 0; | |
+ for (auto & output_edge : outputs) { | |
+ Function *output = output_edge.function.get(); | |
+ auto & info = exec_info[output]; | |
+ if (!info.captures) | |
+ info.captures.reset(new std::vector<ExecInfo::Capture>()); | |
+ info.captures->emplace_back(output_edge.input_nr, output_idx++); | |
+ } | |
+ captured_vars.resize(output_idx); | |
+ | |
+ // NB: this is an uglier version (recursion replaced with iteration) of the following code: | |
+ // is_needed = {} | |
+ // def compute_is_needed(fn): | |
+ // if fn not in is_needed: | |
+ // is_needed[fn] = any(compute_is_needed(next_edge) | |
+ // for next_edge in fn.next_edges) | |
+ // return is_needed[fn] | |
+ struct Frame { | |
+ Frame (Function *fn) : fn(fn), next_next_fn(0) {} | |
+ Function *fn; | |
+ std::size_t next_next_fn; | |
+ | |
+ Function* get_next_fn() { | |
+ const auto & next = fn->next_edges(); | |
+ auto num_next = next.size(); | |
+ while (next_next_fn < num_next) { | |
+ auto fn = next[next_next_fn++].function.get(); | |
+ if (fn) return fn; | |
+ } | |
+ return nullptr; | |
+ } | |
+ }; | |
+ std::vector<Frame> stack; | |
+ std::unordered_set<Function*> seen; | |
+ for (const auto & input : graph_root.next_edges()) { | |
+ if (seen.count(input.function.get()) > 0) continue; | |
+ stack.emplace_back(input.function.get()); | |
+ while (!stack.empty()) { | |
+ auto &frame = stack.back(); | |
+ if (Function *next_fn = frame.get_next_fn()) { | |
+ if (/* bool unseen = */ seen.emplace(next_fn).second) { | |
+ stack.emplace_back(next_fn); | |
+ continue; // recurse | |
+ } | |
+ } else { | |
+ // NB: if we were using real recursion we could have saved some lookups | |
+ // using a return value from the recursive call. It would make this manually unrolled | |
+ // version a lot more complicated, so I skipped that. | |
+ const auto & next_edges = frame.fn->next_edges(); | |
+ const bool needed = std::any_of( | |
+ next_edges.begin(), next_edges.end(), [&](const Edge& edge) { | |
+ auto it = exec_info.find(edge.function.get()); | |
+ return it != exec_info.end() && it->second.should_execute(); | |
+ }); | |
+ exec_info[frame.fn].needed = needed; | |
+ stack.pop_back(); | |
+ } | |
+ } | |
+ } | |
+} | |
+ | |
+ | |
}} // namespace torch::autograd | |
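The `ReadyQueue` change above replaces the FIFO `deque` with a `std::priority_queue` ordered by `sequence_nr()`, so among the ready tasks the one whose function was created last is popped first (the `<` comparator makes the queue a max-heap). A minimal, self-contained sketch of that ordering, with a hypothetical `ToyTask` standing in for `FunctionTask`:

#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

// Hypothetical stand-in for FunctionTask: only the sequence number matters here.
struct ToyTask {
  uint64_t sequence_nr;
  const char* name;
};

// Same shape as CompareFunctionTaskTime: order tasks by sequence number.
struct CompareBySequenceNr {
  bool operator()(const ToyTask& a, const ToyTask& b) const {
    return a.sequence_nr < b.sequence_nr;
  }
};

int main() {
  // A priority_queue with a '<' comparator is a max-heap, so top() is the
  // task with the highest (most recently assigned) sequence number.
  std::priority_queue<ToyTask, std::vector<ToyTask>, CompareBySequenceNr> heap;
  heap.push({0, "AccumulateGrad for w"});
  heap.push({2, "MulBackward"});
  heap.push({1, "AddBackward"});

  while (!heap.empty()) {
    std::cout << heap.top().sequence_nr << " " << heap.top().name << "\n";
    heap.pop();
  }
  // Prints 2, then 1, then 0: later-created functions are scheduled first.
  return 0;
}

Since sequence numbers increase with construction order during the forward pass, popping the highest number first roughly walks the backward graph in reverse order of the forward operations.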
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h | |
index ff1ae45c..eac677a3 100644 | |
--- a/torch/csrc/autograd/engine.h | |
+++ b/torch/csrc/autograd/engine.h | |
@@ -27,32 +27,21 @@ struct Engine { | |
virtual ~Engine(); | |
using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>; | |
- using function_queue = std::vector<Function*>; | |
using dependencies_type = std::unordered_map<Function*, int>; | |
- using pre_callback_type = std::function<bool (Function*, variable_list&)>; | |
- using pre_callback_map = std::unordered_multimap<Function*, pre_callback_type>; | |
- using post_callback_type = std::function<bool (Function*, variable_list&, variable_list&)>; | |
- using post_callback_map = std::unordered_multimap<Function*, post_callback_type>; | |
- | |
// Given a list of (Function, input number) pairs computes the value of the graph | |
- // by following next_function references. | |
- virtual void execute( | |
- const function_list& roots, | |
+ // by following next_edge references. | |
+ virtual variable_list execute( | |
+ const edge_list& roots, | |
const variable_list& inputs, | |
bool keep_graph, | |
- const pre_callback_map& pre_callbacks = pre_callback_map(), | |
- const post_callback_map& post_callbacks = post_callback_map()); | |
+ bool create_graph, | |
+ const edge_list& outputs = {}); | |
void queue_callback(std::function<void()> callback); | |
protected: | |
- function_queue find_roots( | |
- const function_list& roots, | |
- variable_list& inputs, | |
- GraphTask& task); | |
- void find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task); | |
- void compute_dependencies(function_queue queue, GraphTask& task); | |
+ void compute_dependencies(Function* root, GraphTask& task); | |
void evaluate_function(FunctionTask& task); | |
ReadyQueue& ready_queue(int device); | |
void start_threads(); | |
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp | |
index b22c3d51..42af461f 100644 | |
--- a/torch/csrc/autograd/function.cpp | |
+++ b/torch/csrc/autograd/function.cpp | |
@@ -1,51 +1,24 @@ | |
-#include "function.h" | |
+#include <Python.h> | |
-#include <string> | |
+#include "torch/csrc/autograd/function.h" | |
-#include "variable.h" | |
-#include "torch/csrc/jit/ir.h" | |
#include "torch/csrc/autograd/functions/special.h" | |
+#include "torch/csrc/autograd/variable.h" | |
+#include "torch/csrc/jit/ir.h" | |
-namespace torch { namespace autograd { | |
+#include <ATen/ATen.h> | |
-template<typename T> | |
-auto makeFlags(const T &inputs) -> FunctionFlags { | |
- int num_inputs = inputs.size(); | |
- FunctionFlags f; | |
- f.is_executable = false; | |
- f.is_volatile = false; | |
- f.next_functions.resize(num_inputs); | |
- { | |
- int i = 0; | |
- for (auto it = inputs.begin(); it != inputs.end(); ++it, ++i) { | |
- auto& var = *it; | |
- if (var.defined()) { | |
- f.is_executable |= var.requires_grad(); | |
- f.is_volatile |= var.is_volatile(); | |
- if (var.grad_fn()) { | |
- f.next_functions[i] = std::make_pair<>(var.grad_fn(), var.output_nr()); | |
- } else { | |
- f.next_functions[i] = std::make_pair<>(var.grad_accumulator(), 0); | |
- } | |
- } | |
- } | |
- } | |
- f.is_executable &= !f.is_volatile; | |
- return f; | |
-} | |
- | |
-auto Function::flags(const variable_list& inputs) -> FunctionFlags { | |
- return makeFlags(inputs); | |
-} | |
+#include <algorithm> | |
+#include <cstdint> | |
+#include <memory> | |
+#include <stdexcept> | |
+#include <string> | |
+#include <utility> | |
+#include <vector> | |
-auto Function::flags(const std::initializer_list<Variable>& inputs) -> FunctionFlags { | |
- return makeFlags(inputs); | |
-} | |
+namespace torch { namespace autograd { | |
-auto Function::flags(at::TensorList inputs) -> FunctionFlags { | |
- // TODO: Eliminate the intermediate vector allocation | |
- return makeFlags(variable_list(inputs.begin(), inputs.end())); | |
-} | |
+thread_local uint64_t Function::next_sequence_nr_ = 0; | |
auto Function::name() -> std::string { | |
return std::string(typeid(*this).name()); | |
@@ -54,7 +27,7 @@ auto Function::name() -> std::string { | |
// This function is analogous to make_trace which operates on PythonOp, but this | |
// function instead works for C++ implemented autograd Functions, which don't | |
// actually have any backing Python class. We still need to trace them! | |
-variable_list Function::tracedApply(variable_list inputs) { | |
+variable_list Function::traced_apply(variable_list inputs) { | |
using namespace torch::jit; | |
// Traceable Functions are completely transparent to the JIT. | |
if (is_traceable()) { | |
@@ -65,7 +38,14 @@ variable_list Function::tracedApply(variable_list inputs) { | |
// Insert a CppOp in the trace. | |
auto& graph = state->graph; | |
- auto* this_node = graph->createCppOp(getSharedPtr()); | |
+ std::vector<VariableFlags> var_flags; | |
+ for(auto & input: inputs) { | |
+ var_flags.push_back(VariableFlags::of(input)); | |
+ } | |
+ auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags)); | |
+ this_node->setSourceLocation(std::make_shared<StringSourceLocation>( | |
+ jit::tracer::getPythonInterpreterStackTrace() | |
+ )); | |
for (auto& input: inputs) { | |
this_node->addInput(tracer::getValueTrace(state, input)); | |
} | |
@@ -80,7 +60,7 @@ variable_list Function::tracedApply(variable_list inputs) { | |
int num_outputs = outputs.size(); | |
for (int i = 0; i < num_outputs; ++i) { | |
auto& output = outputs[i]; | |
- Node* sel = graph->appendNode(graph->createSelect(this_node, i)); | |
+ auto sel = this_node->addOutput(); | |
// TODO: At the moment, C++ does not track shared storage. It | |
// should. Update this when that happens. | |
if (output.defined()) { | |
@@ -97,29 +77,30 @@ variable_list Function::tracedApply(variable_list inputs) { | |
// There's no point in wrapping functions in Eval, if we know they already are | |
// part of another Eval subgraph. This is both a small optimization, and | |
// it allows us to not implement saved_variables() in many functions. | |
- bool should_trace_backward = tracing_state->in_eval_subgraph; | |
+ const bool should_trace_backward = tracing_state_->in_eval_subgraph; | |
if (!should_trace_backward) { | |
auto saved_vars = saved_variables(); | |
if (!saved_vars) | |
- throw std::runtime_error(std::string("saved_variables() needed but not implemented in ") + name()); | |
+ throw std::runtime_error("saved_variables() needed but not implemented in " + name()); | |
variable_list bw_subgraph_inputs(inputs); | |
for (auto& saved_var : *saved_vars) { | |
- bw_subgraph_inputs.emplace_back(saved_var.unpack(getSharedPtr())); | |
+ bw_subgraph_inputs.emplace_back(saved_var.unpack(get_shared_ptr())); | |
} | |
tracer::nontraceableBackwardSubgraph(bw_subgraph_inputs, outputs); | |
} | |
bool has_backwards_eval = !should_trace_backward || this_eval; | |
if (has_backwards_eval) | |
- setUpContextEdge(this_node, num_outputs, inputs, outputs); | |
+ set_up_context_edge(this_node, inputs, outputs); | |
} | |
return outputs; | |
} | |
-void Function::setUpContextEdge(jit::Node* node, int ctx_output_nr, | |
- const variable_list& inputs, const variable_list& outputs) { | |
- jit::Graph* graph = node->owningGraph(); | |
- jit::Node* ctx_select = graph->appendNode(graph->createSelect(node, ctx_output_nr)); | |
- ctx_select->setType(std::make_shared<jit::HandleType>()); | |
+void Function::set_up_context_edge( | |
+ jit::Node* this_node, | |
+ const variable_list& inputs, | |
+ const variable_list& outputs) { | |
+ auto ctx_select = this_node->addOutput(); | |
+ ctx_select->setType(jit::HandleType::get()); | |
auto backward_eval = Eval::getBackwardEval(inputs, outputs); | |
if (backward_eval) | |
backward_eval->forward_ctx_select = ctx_select; | |
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h | |
index a7efd3eb..ed6955c5 100644 | |
--- a/torch/csrc/autograd/function.h | |
+++ b/torch/csrc/autograd/function.h | |
@@ -1,213 +1,374 @@ | |
#pragma once | |
-// Function is an abstract class that represents a single operation from one or | |
-// more variables to one more or variables. | |
-// | |
-// Subclasses may represent "forward" or "backward" operations (i.e functions | |
-// and their derivatives). Some functions may be used as both. | |
- | |
+#include "torch/csrc/assertions.h" | |
+#include "torch/csrc/autograd/edge.h" | |
+#include "torch/csrc/autograd/grad_mode.h" | |
+#include "torch/csrc/autograd/profiler.h" | |
#include "torch/csrc/autograd/saved_variable.h" | |
+#include "torch/csrc/autograd/variable.h" | |
+#include "torch/csrc/jit/tracer.h" | |
#include "torch/csrc/utils/auto_unique_ptr.h" | |
#include "torch/csrc/utils/python_stub.h" | |
-#include "torch/csrc/autograd/function_hook.h" | |
-#include "torch/csrc/autograd/profiler.h" | |
-#include "torch/csrc/jit/tracer.h" | |
+#include "torch/csrc/utils/variadic.h" | |
#include <ATen/ATen.h> | |
+#include <algorithm> | |
+#include <cstdint> | |
+#include <initializer_list> | |
#include <memory> | |
+#include <string> | |
+#include <utility> | |
#include <vector> | |
namespace torch { namespace autograd { | |
-struct Function; | |
-struct Variable; | |
+struct Edge; | |
+struct FunctionPostHook; | |
+struct FunctionPreHook; | |
using tensor_list = std::vector<at::Tensor>; | |
using variable_list = std::vector<Variable>; | |
-using edge_type = std::pair<std::shared_ptr<Function>, int>; | |
-using function_list = std::vector<edge_type>; | |
+using edge_list = std::vector<Edge>; | |
using saved_variable_list = std::vector<SavedVariable>; | |
- | |
-struct edge_hasher { | |
- std::size_t operator()(const edge_type& edge) const { | |
-#define HASH_IDX(idx) std::hash<std::tuple_element<idx, edge_type>::type>()(std::get<idx>(edge)) | |
- // TODO: that's probably a bad hash function, but whatever | |
- return HASH_IDX(0) ^ HASH_IDX(1); | |
- } | |
-}; | |
- | |
-// State used to create "backward" functions | |
-struct FunctionFlags { | |
- // Roughly speaking, is_executable corresponds to requires_grad. | |
- // See http://pytorch.org/docs/notes/autograd.html for more details: | |
- // both is_executable and is_volatile specify whether or not backwards | |
- // gradient computation will be performed for a function, but they differ in | |
- // their precedence. | |
- bool is_executable = false; | |
- bool is_volatile = false; | |
- // What functions take the output of this function as input. | |
- // There is one function per output of this function. | |
- function_list next_functions; | |
-}; | |
- | |
+using IndexRange = std::pair<size_t, size_t>; | |
+ | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// Function | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// A `Function` is an abstract class that represents an operation taking zero | |
+/// or more input `Variable`s and producing zero or more output `Variable`s. All | |
+/// functions in PyTorch's autograd machinery derive from this class and | |
+/// override its `apply` method. Instances of such subclasses will then be | |
+/// invocable via the call operator. | |
+/// | |
+/// Functions in the Autograd Graph | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// When viewing the autograd system as a graph, `Function`s are the vertices or | |
+/// nodes, connected to each other via (directed) `Edge`s, which themselves are | |
+/// represented via (`Function`, input_nr) pairs. `Variable`s are the outputs to | |
+/// and inputs of `Function`s, and travel between these edges during execution | |
+/// of the graph. When two or more `Edge`s (from different sources) point at the | |
+/// same input to a `Function`, the values produced along all of these edges are | |
+/// implicitly summed prior to being forwarded to the target `Function`. | |
+/// | |
+/// Hierarchy | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// Subclasses usually represent differentiable functions as well as their | |
+/// gradient operators. Note, however, that due to the very general definition | |
+/// of a `Function` taking *zero* or more inputs and producing *zero* or more | |
+/// outputs, uses of `Function`s are flexible and extend beyond purely | |
+/// mathematical operations. For example, the `AccumulateGrad` function is a | |
+/// *sink*: it takes one input, but produces no outputs, instead accumulating | |
+/// the input as a side effect. At the other extreme, the `GraphRoot` function | |
+/// receives no inputs from other functions, but produces multiple outputs. | |
+/// | |
+/// Interface | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// The most important method on `Function` is the call operator, which takes in | |
+/// a list of variables and produces a list of variables. The precise size of | |
+/// these lists can be determined with `num_inputs()` and `num_outputs()`. | |
+/// `Function`s are stitched together via their `next_edge` interface, which lets | |
+/// you manipulate the set of outgoing edges of a `Function`. You can add an | |
+/// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and | |
+/// iterate over them via the `next_edges()` method. Other methods exist for | |
+/// integration with the JIT and other parts of PyTorch. Every `Function` has a | |
+/// *sequence number* that increases monotonically in the order of `Function` | |
+/// construction. It can be retrieved via the `sequence_nr()` method. Note that | |
+/// this sequence number is *thread local*. This means that when `Function`s | |
+/// `A`, `B` and `C` are created consecutively in the same thread, their | |
+/// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B` | |
+/// are created in one thread and `C` is created in a new thread, there are *no | |
+/// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`. | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
struct Function : std::enable_shared_from_this<Function> { | |
- Function() | |
- : num_inputs(0) | |
- , next_functions() | |
- , is_executable(false) | |
- , is_stochastic(false) | |
- , pre_hooks() | |
- , post_hooks() | |
- , pyobj(nullptr) | |
- {} | |
- | |
- Function(FunctionFlags&& flags) | |
- : num_inputs(0) | |
- , next_functions(std::move(flags.next_functions)) | |
- , is_executable(flags.is_executable) | |
- , is_stochastic(false) | |
- , pre_hooks() | |
- , post_hooks() | |
- , pyobj(nullptr) | |
- {} | |
- | |
+ public: | |
+ /// Construct a new `Function` with `num_inputs` inputs and the given | |
+ /// `next_edges`. | |
+ explicit Function( | |
+ uint32_t num_inputs = 0, | |
+ edge_list&& next_edges = edge_list()) | |
+ : sequence_nr_(next_sequence_nr_++), | |
+ num_inputs_(num_inputs), | |
+ next_edges_(std::move(next_edges)) {} | |
+ | |
+ /// Functions are neither copyable nor moveable. | |
Function(const Function& other) = delete; | |
Function(Function&& other) = delete; | |
- virtual ~Function() {} | |
- | |
- // Implements the operation | |
- // NOTE: Don't call this function directly. Use apply_fn or operator() instead. | |
- virtual variable_list apply(const variable_list& inputs) = 0; | |
- variable_list tracedApply(variable_list inputs); | |
+ Function& operator=(const Function& other) = delete; | |
+ Function& operator=(Function&& other) = delete; | |
+ virtual ~Function() = default; | |
+ /// Evaluates the function on the given inputs and returns the result of the | |
+ /// function call. | |
variable_list operator()(const variable_list& inputs) { | |
profiler::RecordFunction rec(this); | |
- if (jit::tracer::isTracing(inputs)) { | |
- return tracedApply(inputs); | |
+ if (jit::tracer::isTracingVar(inputs)) { | |
+ return traced_apply(inputs); | |
} | |
return apply(inputs); | |
} | |
- // PyFunctions are not managed by shared_ptrs by default, but are bound to the | |
- // lifetime of their Python object instead. | |
- virtual std::shared_ptr<Function> getSharedPtr() { | |
- return shared_from_this(); | |
- }; | |
+ // Graph Connectivity API | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
- // Computes is_executable, is_volatile, and next_functions from a list | |
- // of input variables | |
- static FunctionFlags flags(const variable_list& inputs); | |
- static FunctionFlags flags(const std::initializer_list<Variable>& inputs); | |
- static FunctionFlags flags(at::TensorList inputs); | |
+ // Inputs | |
- // Releases saved variables if the operation won't be reused | |
- virtual inline void releaseVariables() {} | |
+ /// Increments the number of inputs of the function and returns the previous | |
+ /// value. | |
+ uint32_t bump_inputs() noexcept { | |
+ return num_inputs_++; | |
+ } | |
- // Function name for debugging | |
- virtual std::string name(); | |
+ void set_num_inputs(uint32_t num_inputs) noexcept { | |
+ num_inputs_ = num_inputs; | |
+ } | |
- inline bool should_compute_output(int i) const { | |
- auto& fn = next_functions[i].first; | |
- return fn && fn->is_executable; | |
+ uint32_t num_inputs() const noexcept { | |
+ return num_inputs_; | |
} | |
- inline bool should_compute_any_outputs() const { | |
- for (size_t i = 0; i < next_functions.size(); ++i) { | |
- if (should_compute_output((int)i)) { | |
- return true; | |
- } | |
- } | |
- return false; | |
+ // Outputs ("Next Edges") | |
+ | |
+ const Edge& next_edge(size_t index) const noexcept { | |
+ return next_edges_[index]; | |
+ } | |
+ | |
+ void set_next_edge(size_t index, Edge edge) { | |
+ next_edges_[index] = std::move(edge); | |
+ } | |
+ | |
+ void add_next_edge(Edge edge) { | |
+ next_edges_.push_back(std::move(edge)); | |
+ } | |
+ | |
+ void set_next_edges(edge_list&& next_edges) { | |
+ next_edges_ = std::move(next_edges); | |
+ } | |
+ | |
+ const edge_list& next_edges() const noexcept { | |
+ return next_edges_; | |
+ } | |
+ | |
+ edge_list& next_edges() noexcept { | |
+ return next_edges_; | |
+ } | |
+ | |
+ uint32_t num_outputs() const noexcept { | |
+ return next_edges_.size(); | |
+ } | |
+ | |
+ // Miscellaneous Methods | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ /// The sequence number of this `Function`. | |
+ uint64_t sequence_nr() const noexcept { | |
+ return sequence_nr_; | |
} | |
- inline bool should_compute_output(std::initializer_list<int> idxs) const { | |
- return std::any_of(idxs.begin(), idxs.end(), [this](int i) { | |
- return should_compute_output(i); | |
+ /// Returns a shared pointer to `this`. `PyFunction`s are not managed by | |
+ /// `shared_ptr`s by default, but are bound to the lifetime of their Python | |
+ /// object instead. | |
+ virtual std::shared_ptr<Function> get_shared_ptr() { | |
+ return shared_from_this(); | |
+ } | |
+ | |
+ /// Returns the name of the dynamic type of the function, for debugging. | |
+ virtual std::string name(); | |
+ | |
+ /// Returns true if the particular output edge is active, and that particular | |
+ /// output of this function should be computed. | |
+ bool should_compute_output(size_t output_edge_index) const { | |
+ TORCH_ASSERTM(output_edge_index < num_outputs(), "Index out of range"); | |
+ return next_edges_[output_edge_index].is_valid(); | |
+ } | |
+ | |
+ /// Returns true if any of the output edges in any of the ranges are active. | |
+ bool should_compute_output(std::initializer_list<IndexRange> idxs) const { | |
+ return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { | |
+ for (auto i = range.first; i < range.second; i++) { | |
+ if (should_compute_output(i)) | |
+ return true; | |
+ } | |
+ return false; | |
}); | |
} | |
- inline void set_flags(FunctionFlags&& flags) { | |
- is_executable = flags.is_executable; | |
- next_functions = std::move(flags.next_functions); | |
+ jit::tracer::FunctionTracingState& tracing_state() noexcept { | |
+ // Dereferencing will create the `TracingState` if the pointer is empty. | |
+ return *tracing_state_; | |
} | |
- // An op is traceable if all operations happening within apply() are performed | |
- // on autograd Variables (i.e. apply mostly instantiates and applies other functions). | |
- virtual inline bool is_traceable() { return false; }; | |
+ /// Returns the `PyObject` stored for this `Function` (for Python | |
+ /// interaction). | |
+ PyObject* pyobj() const noexcept { | |
+ return pyobj_; | |
+ } | |
- // An op is said to pass state transparently to backward, if the state consists | |
- // only of (Saved)Variables and only non-variable objects that parametrize the | |
- // operation in some way that defines the graph structure AND the backward function | |
- // is traceable. In particular, parametrization MUST NOT depend on the data | |
- // of any Variable. | |
- // TODO: it might be possible to handle cases where backward is non-traceable | |
- // but state passing could be considered transparent. This will probably depend | |
- // on saved_variable_list being mutable. | |
- // NOTE: this value matters only if is_traceable() returns false. | |
- virtual inline bool passes_state_transparently() { return false; }; | |
+ /// Sets the `PyObject` stored for this `Function` (for Python interaction). | |
+ void set_pyobj(PyObject* pyobj) noexcept { | |
+ pyobj_ = pyobj; | |
+ } | |
- // Let's the JIT find inputs to apply that are not present explicitly in arguments. | |
- // Required only for functions that are not traceable, don't pass state to | |
- // backward transparently, and are not backwards closures of functions that don't | |
- // pass the state transparently. Which means that hopefully they will hardly ever | |
- // need to be implemented :) | |
- virtual inline std::unique_ptr<saved_variable_list> saved_variables() { return nullptr; } | |
+ /// Create a context edge for the JIT. | |
+ static void set_up_context_edge( | |
+ jit::Node* this_node, | |
+ const variable_list& inputs, | |
+ const variable_list& outputs); | |
- static void setUpContextEdge(jit::Node* this_node, int ctx_output_nr, | |
- const variable_list& inputs, const variable_list& outputs); | |
+ // Hook API | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
- int num_inputs; | |
- function_list next_functions; | |
- bool is_executable; | |
- bool is_stochastic; | |
- std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks; | |
- std::vector<std::shared_ptr<FunctionPostHook>> post_hooks; | |
+ void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) { | |
+ post_hooks_.push_back(std::move(post_hook)); | |
+ } | |
- PyObject *pyobj; // weak reference | |
+ const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const | |
+ noexcept { | |
+ return post_hooks_; | |
+ } | |
- auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state; | |
-}; | |
+ std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept { | |
+ return post_hooks_; | |
+ } | |
-// Actually what is a ForwardFunction here applies to all functions that are | |
-// applied only in forward OR are backward closures that don't save any Variables. | |
-// I chose this name, because the second situation is quite rare. | |
-template<bool transparent_state = false> | |
-struct ForwardFunction : public Function { | |
- using Function::Function; | |
+ void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) { | |
+ pre_hooks_.push_back(std::move(pre_hook)); | |
+ } | |
- virtual inline std::unique_ptr<saved_variable_list> saved_variables() final { | |
- return std::unique_ptr<saved_variable_list>(new saved_variable_list()); | |
+ const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() const | |
+ noexcept { | |
+ return pre_hooks_; | |
} | |
- virtual inline bool is_traceable() final { return false; }; | |
+ std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept { | |
+ return pre_hooks_; | |
+ } | |
- virtual inline bool passes_state_transparently() final { return transparent_state; }; | |
-}; | |
+ // Customization Points for Subclasses | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
-// See Function::is_traceable() for definition. | |
-struct TraceableFunction : public Function { | |
- using Function::Function; | |
+ /// Releases saved variables if the operation won't be reused. | |
+ virtual void release_variables() {} | |
- virtual inline bool is_traceable() final { return true; }; | |
-}; | |
+ /// Called before an apply if `release_variables()` is going to be called. | |
+ /// Allows larger ops like `InterpreterAutogradFunction` to incrementally | |
+ /// release variables as they run. | |
+ virtual void will_release_variables() {} | |
-template<typename T> | |
-struct apply_fn { | |
- template<typename... Args> | |
- apply_fn(Args&& ...args) | |
- : fn_(std::make_shared<T>(std::forward<Args>(args)...)) {} | |
+ /// Returns true if this function is traceable. An op is traceable if all | |
+ /// operations happening within `apply()` are performed on autograd | |
+ /// `Variables` (i.e. apply mostly instantiates and applies other functions). | |
+ virtual bool is_traceable() { | |
+ return false; | |
+ } | |
- Variable operator()(const variable_list& inputs) { | |
- return (*fn_)(inputs)[0]; | |
+ /// A `Function` is said to pass state transparently to backward, if the | |
+ /// state consists only of (Saved)Variables and only non-variable objects | |
+ /// that parameterize the operation in some way that defines the graph | |
+ /// structure AND the backward function is traceable. In particular, | |
+ /// parametrization MUST NOT depend on the data of any `Variable`. | |
+ /// TODO: it might be possible to handle cases where backward is | |
+ /// non-traceable but state passing could be considered transparent. This | |
+ /// will probably depend on saved_variable_list being mutable. | |
+ /// NOTE: this value matters only if is_traceable() returns false. | |
+ virtual bool passes_state_transparently() { | |
+ return false; | |
} | |
- template<typename... Args> | |
- Variable operator()(Args&& ...inputs) { | |
- return (*fn_)(variable_list{inputs...})[0]; | |
+ /// Returns `Variable`s saved by this `Function`. | |
+ /// This lets the JIT find inputs to apply that are not present explicitly | |
+ /// in arguments. Required only for functions that are not traceable, don't | |
+ /// pass state to backward transparently, and are not backwards closures of | |
+ /// functions that don't pass the state transparently. Which means that | |
+ /// hopefully they will hardly ever need to be implemented :) | |
+ virtual std::unique_ptr<saved_variable_list> saved_variables() { | |
+ return nullptr; | |
} | |
- std::shared_ptr<T> fn_; | |
+ protected: | |
+ /// Monotonically incrementing (thread local!) counter to supply sequence | |
+ /// numbers. | |
+ static thread_local uint64_t next_sequence_nr_; | |
+ | |
+ /// Performs the `Function`'s actual operation. | |
+ virtual variable_list apply(const variable_list& inputs) = 0; | |
+ | |
+ /// Calls `apply()`, but instruments it with tracing machinery. | |
+ variable_list traced_apply(variable_list inputs); | |
+ | |
+ // Since `Function`s are neither copyable nor moveable, we can have const | |
+ // fields. | |
+ const uint64_t sequence_nr_; | |
+ | |
+ uint32_t num_inputs_; | |
+ edge_list next_edges_; | |
+ PyObject* pyobj_ = nullptr; // weak reference | |
+ std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_; | |
+ std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_; | |
+ auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state_; | |
}; | |
+/// See Function::is_traceable() for definition. | |
+struct TraceableFunction : public Function { | |
+ using Function::Function; | |
+ bool is_traceable() final override { | |
+ return true; | |
+ } | |
+}; | |
+ | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+// Associated Free Functions | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+namespace detail { | |
+// Implementation of `collect_next_edges` (see below). | |
+struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> { | |
+ edge_list next_edges; | |
+ using IterArgs<MakeNextFunctionList>::operator(); | |
+ void operator()(const Variable& variable) { | |
+ if (variable.defined()) { | |
+ next_edges.push_back(variable.gradient_edge()); | |
+ } else { | |
+ next_edges.emplace_back(); | |
+ } | |
+ } | |
+}; | |
+} // namespace detail | |
+ | |
+/// Create an `Edge` between the given `variable` and the `function`, which is | |
+/// assumed to be the gradient function of this variable (i.e. the function | |
+/// through which this variable is backpropagated during the backward pass). | |
+/// This sets the `grad_fn` property of the `variable`. This function assumes | |
+/// that the `Variable` is a new input to the gradient function and its | |
+/// `input_nr` thus equal to `function->num_inputs()`. Additionally, it | |
+/// increments the `Function`'s number of inputs by one. Approximately | |
+/// equivalent to `variable.set_gradient_edge(function, | |
+/// function->bump_inputs())`. If you don't want the `Function`'s `num_inputs` | |
+/// to be incremented, use `set_gradient_edge` directly. | |
+inline void create_gradient_edge( | |
+ Variable& variable, | |
+ std::shared_ptr<Function> function) { | |
+ // Copy before move. | |
+ const auto input_nr = function->bump_inputs(); | |
+ variable.set_gradient_edge({std::move(function), input_nr}); | |
+} | |
+ | |
+/// Return true if any of the variables in the list require a gradient. | |
+inline bool any_variable_requires_grad(const variable_list& variables) { | |
+ return std::any_of( | |
+ variables.begin(), variables.end(), [](const Variable& variable) { | |
+ return variable.defined() && variable.requires_grad(); | |
+ }); | |
+} | |
+ | |
+/// Return the next edges of all the given variables, or tuples of variables. | |
+template <typename... Variables> | |
+edge_list collect_next_edges(Variables&&... variables) { | |
+ if (!GradMode::is_enabled()) | |
+ return {}; | |
+ detail::MakeNextFunctionList make; | |
+ make.apply(std::forward<Variables>(variables)...); | |
+ return std::move(make.next_edges); | |
+} | |
}} // namespace torch::autograd | |
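The header comment above stresses that `sequence_nr()` comes from a thread-local, monotonically increasing counter, so ordering guarantees only hold among `Function`s created on the same thread. A minimal, self-contained sketch of that pattern, with a hypothetical `ToyNode` in place of `Function`:

#include <cstdint>
#include <iostream>
#include <thread>

// Minimal sketch of the thread-local sequence numbering described above.
// ToyNode is a hypothetical stand-in for Function; only the counter matters.
struct ToyNode {
  ToyNode() : sequence_nr_(next_sequence_nr_++) {}
  uint64_t sequence_nr() const { return sequence_nr_; }

 private:
  static thread_local uint64_t next_sequence_nr_;
  const uint64_t sequence_nr_;
};

thread_local uint64_t ToyNode::next_sequence_nr_ = 0;

int main() {
  // Within one thread, construction order fixes the ordering: a < b < c.
  ToyNode a, b, c;
  std::cout << a.sequence_nr() << " " << b.sequence_nr() << " "
            << c.sequence_nr() << "\n";  // 0 1 2

  // A node created on another thread starts from that thread's own counter,
  // so its number carries no ordering guarantee relative to a, b, or c.
  std::thread t([] {
    ToyNode d;
    std::cout << "other thread: " << d.sequence_nr() << "\n";  // 0 again
  });
  t.join();
  return 0;
}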
diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h | |
index 4c195917..03c52fea 100644 | |
--- a/torch/csrc/autograd/function_hook.h | |
+++ b/torch/csrc/autograd/function_hook.h | |
@@ -1,6 +1,5 @@ | |
#pragma once | |
-#include <memory> | |
#include <vector> | |
// A hook that's called on gradients | |
diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp | |
index 9c1ff87e..d8c9457c 100644 | |
--- a/torch/csrc/autograd/functions/accumulate_grad.cpp | |
+++ b/torch/csrc/autograd/functions/accumulate_grad.cpp | |
@@ -1,5 +1,8 @@ | |
-#include "accumulate_grad.h" | |
+#include <Python.h> | |
+#include "torch/csrc/autograd/functions/accumulate_grad.h" | |
+ | |
+#include "torch/csrc/autograd/grad_mode.h" | |
#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/autograd/functions/basic_ops.h" | |
#include "torch/csrc/autograd/functions/tensor.h" | |
@@ -11,10 +14,7 @@ using at::Tensor; | |
namespace torch { namespace autograd { | |
AccumulateGrad::AccumulateGrad(Variable variable_) | |
- : variable(std::move(variable_)) { | |
- num_inputs = 1; | |
- is_executable = 1; | |
-} | |
+ : Function(/*num_inputs=*/1), variable(std::move(variable_)) {} | |
auto AccumulateGrad::apply(const variable_list& grads) -> variable_list { | |
// XXX: this method is not thread-safe! | |
@@ -24,8 +24,6 @@ auto AccumulateGrad::apply(const variable_list& grads) -> variable_list { | |
return {}; | |
if (variable.grad_fn()) | |
throw std::logic_error("leaf variable has been moved into the graph interior"); | |
- if (variable.current_version() != 0) | |
- throw std::runtime_error("leaf variable was used in an inplace operation"); | |
if (!variable.requires_grad()) | |
return {}; | |
@@ -34,31 +32,22 @@ auto AccumulateGrad::apply(const variable_list& grads) -> variable_list { | |
new_grad = (*hook)({new_grad})[0]; | |
} | |
- // TODO: Currently if var.grad is volatile and new-grad is non-volatile we | |
- // accumulate in-place. We should reconsider this and perhaps add the | |
- // gradients out-of-place. | |
- | |
auto& grad = variable.grad(); | |
if (!grad.defined()) { | |
- grad = apply_fn<Clone>()(new_grad); | |
- } else if (grad.is_volatile()) { | |
+ variable.grad() = new_grad.clone(); | |
+ } else if (!GradMode::is_enabled()) { | |
// This case is not strictly necessary, but it makes the first-order only case | |
// slightly more efficient and, what's more important, more predictable for | |
// the users. Thanks to this case we can avoid changing the grad tensor, | |
// a thing never promised and documented, but used in some hacks seen | |
// on the internet. | |
- AutoGPU guard(grad); | |
- if (grad.type().isSparse() && !new_grad.type().isSparse()) { | |
+ if (grad.type().is_sparse() && !new_grad.type().is_sparse()) { | |
grad.data() = new_grad.data() + grad.data(); | |
} else { | |
grad.data() += new_grad.data(); | |
} | |
} else { | |
- // If grad is non-volatile, it should stay like that | |
- if (new_grad.is_volatile()) { | |
- new_grad = make_variable(new_grad.data()); | |
- } | |
- variable.grad() = apply_fn<Add>()(grad, new_grad); | |
+ variable.grad() = grad + new_grad; | |
} | |
return variable_list(); | |
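The rewritten `AccumulateGrad::apply` above has three branches: the first gradient seen is cloned into `variable.grad()`, later gradients are added in place when grad mode is off, and added out of place (so the addition itself is recorded, enabling double backward) when grad mode is on. The toy sketch below mirrors only that control flow; `ToyVariable`, `accumulate_grad`, and the scalar "tensor" are hypothetical stand-ins, not the real API:

#include <iostream>
#include <optional>

// Toy stand-ins: a "tensor" is just a double here, and grad_mode mimics
// GradMode::is_enabled(). This only illustrates the branch structure of
// AccumulateGrad::apply, not the real tensor semantics.
struct ToyVariable {
  std::optional<double> grad;  // unset until the first gradient arrives
};

void accumulate_grad(ToyVariable& variable, double new_grad, bool grad_mode) {
  if (!variable.grad.has_value()) {
    // First gradient: "clone" it into the variable.
    variable.grad = new_grad;
  } else if (!grad_mode) {
    // First-order only: accumulate in place into the existing buffer.
    *variable.grad += new_grad;
  } else {
    // Grad mode enabled: build the sum out of place so that, in the real
    // engine, the addition itself would be recorded for double backward.
    variable.grad = *variable.grad + new_grad;
  }
}

int main() {
  ToyVariable w;
  accumulate_grad(w, 0.5, /*grad_mode=*/false);  // grad = 0.5
  accumulate_grad(w, 1.5, /*grad_mode=*/false);  // grad = 2.0 (in place)
  accumulate_grad(w, 1.0, /*grad_mode=*/true);   // grad = 3.0 (out of place)
  std::cout << *w.grad << "\n";
  return 0;
}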
diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h | |
index ff6dd0b8..4a3e3bd2 100644 | |
--- a/torch/csrc/autograd/functions/accumulate_grad.h | |
+++ b/torch/csrc/autograd/functions/accumulate_grad.h | |
@@ -6,11 +6,11 @@ | |
namespace torch { namespace autograd { | |
struct AccumulateGrad : public Function { | |
- AccumulateGrad(Variable variable); | |
+ explicit AccumulateGrad(Variable variable); | |
virtual variable_list apply(const variable_list& inputs) override; | |
Variable variable; | |
}; | |
-}} | |
+}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp | |
index 2b2a7eb8..813706ad 100644 | |
--- a/torch/csrc/autograd/functions/basic_ops.cpp | |
+++ b/torch/csrc/autograd/functions/basic_ops.cpp | |
@@ -1,9 +1,15 @@ | |
-#include "basic_ops.h" | |
+#include "torch/csrc/autograd/functions/basic_ops.h" | |
+#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/autograd/functions/utils.h" | |
#include "torch/csrc/utils/auto_gpu.h" | |
+#include <ATen/ATen.h> | |
+ | |
+#include <memory> | |
+#include <utility> | |
+ | |
namespace torch { namespace autograd { | |
auto Error::apply(const variable_list& grad_outputs) -> variable_list { | |
@@ -15,51 +21,11 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list { | |
outputs.reserve(inputs.size()); | |
for (auto& var : inputs) { | |
// FIXME: share version counters | |
- outputs.emplace_back(var.defined() ? var.data() : Tensor()); | |
- } | |
- return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) { | |
- return std::make_shared<Error>(msg, std::move(f)); | |
- }); | |
-}; | |
- | |
-auto Add::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("Add", inputs, 2); | |
- auto& input1 = inputs[0].data(); | |
- auto& input2 = inputs[1].data(); | |
- AutoGPU guard(input1); | |
- | |
- at::Tensor output; | |
- if (input1.type().isSparse()) { | |
- output = input2 + input1; | |
- } else { | |
- output = input1 + input2; | |
+ outputs.emplace_back(var.defined() ? var.data() : at::Tensor()); | |
} | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<AddBackward_Deprecated>(std::move(f)); | |
- }); | |
-}; | |
- | |
-auto AddBackward_Deprecated::apply(const variable_list& grad_outputs) -> variable_list { | |
- check_input_variables("AddBackward_Deprecated", grad_outputs, 1); | |
- return {grad_outputs[0], grad_outputs[0]}; | |
-}; | |
- | |
-auto Mul::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("Mul", inputs, 2); | |
- AutoGPU guard(inputs[0]); | |
- auto& input1 = inputs[0].data(); | |
- auto& input2 = inputs[1].data(); | |
- | |
- auto output = input1 * input2; | |
- | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<MulBackward>(std::move(f)); | |
+ return wrap_outputs(inputs, std::move(outputs), [&](edge_list&& next_edges) { | |
+ return std::make_shared<Error>(msg, std::move(next_edges)); | |
}); | |
}; | |
-auto MulBackward::apply(const variable_list& grad_outputs) -> variable_list { | |
- check_input_variables("MulBackward", grad_outputs, 1); | |
- throw std::runtime_error("MulBackward::apply not implemented"); | |
-}; | |
- | |
}} // namespace torch::autograd | |
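
Besides dropping the hand-written Add/Mul functions and their deprecated backwards (these ops now come from the generated derivatives), the notable change here is the wrap_outputs callback: it no longer receives FunctionFlags but an edge_list of next edges, from which the backward node is constructed directly. A rough sketch of that contract, assuming wrap_outputs collects one Edge per input; the helper names and signatures below are illustrative, not the real torch::autograd ones:

    #include <cstdint>
    #include <memory>
    #include <utility>
    #include <vector>

    // Simplified stand-ins; the real Edge/Function live in torch::autograd.
    struct Node;
    struct Edge { std::shared_ptr<Node> fn; uint32_t input_nr; };
    using edge_list = std::vector<Edge>;

    struct Node { edge_list next_edges; };

    struct Var {
      std::shared_ptr<Node> grad_fn;  // producer of this variable, if any
      uint32_t output_nr = 0;
    };

    // Sketch of the new wrap_outputs contract: gather an edge per input,
    // then let the caller construct the backward node from those edges.
    template <typename MakeFn>
    std::vector<Var> wrap_outputs_sketch(const std::vector<Var>& inputs,
                                         std::vector<Var> outputs,
                                         MakeFn make_fn) {
      edge_list next_edges;
      for (const auto& in : inputs)
        next_edges.push_back(Edge{in.grad_fn, in.output_nr});
      std::shared_ptr<Node> grad_fn = make_fn(std::move(next_edges));
      for (auto& out : outputs)
        out.grad_fn = grad_fn;  // every output points at the same backward node
      return outputs;
    }

    int main() {
      std::vector<Var> ins(2), outs(1);
      auto res = wrap_outputs_sketch(ins, outs, [](edge_list&& edges) {
        auto n = std::make_shared<Node>();
        n->next_edges = std::move(edges);
        return n;
      });
      return res[0].grad_fn->next_edges.size() == 2 ? 0 : 1;
    }
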
diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h | |
index 0c165d53..a37934cb 100644 | |
--- a/torch/csrc/autograd/functions/basic_ops.h | |
+++ b/torch/csrc/autograd/functions/basic_ops.h | |
@@ -1,18 +1,20 @@ | |
#pragma once | |
#include <Python.h> | |
-#include <memory> | |
-#include <string> | |
#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/autograd/symbolic.h" | |
+#include <memory> | |
+#include <string> | |
+#include <vector> | |
+ | |
namespace torch { namespace autograd { | |
struct Error : public Function { | |
- Error(std::string msg, FunctionFlags&& flags) | |
- : Function(std::move(flags)) | |
+ Error(std::string msg, edge_list&& next_edges) | |
+ : Function(/*num_inputs=*/0, std::move(next_edges)) | |
, msg(std::move(msg)) {} | |
Error(std::string msg) | |
@@ -34,11 +36,9 @@ struct DelayedError : public Function { | |
}; | |
struct GraphRoot : public Function { | |
- GraphRoot(function_list functions, variable_list inputs) | |
- : outputs(std::move(inputs)) { | |
- next_functions = std::move(functions); | |
- is_executable = true; | |
- }; | |
+ GraphRoot(edge_list functions, variable_list inputs) | |
+ : Function(/*num_inputs=*/0, std::move(functions)), | |
+ outputs(std::move(inputs)) {} | |
virtual variable_list apply(const variable_list& inputs) { | |
return outputs; | |
@@ -47,33 +47,4 @@ struct GraphRoot : public Function { | |
variable_list outputs; | |
}; | |
-struct Add : public ForwardFunction<true>, public HasSymbolic { | |
- Add() {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override; | |
-}; | |
- | |
- | |
-struct AddBackward_Deprecated : public Function { | |
- AddBackward_Deprecated(FunctionFlags&& flags) | |
- : Function(std::move(flags)) {} | |
- | |
- virtual variable_list apply(const variable_list& gradOutputs) override; | |
- virtual bool is_traceable() override { return true; } | |
-}; | |
- | |
-struct Mul : public ForwardFunction<> { | |
- Mul() {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
-}; | |
- | |
-struct MulBackward : public Function { | |
- MulBackward(FunctionFlags&& flags) | |
- : Function(std::move(flags)) {} | |
- | |
- virtual variable_list apply(const variable_list& gradOutputs) override; | |
-}; | |
- | |
}} | |
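
Error and GraphRoot now hand their state straight to the Function(num_inputs, next_edges) base constructor instead of default-constructing the base and then assigning next_functions and is_executable in the constructor body; the Add/Mul/MulBackward declarations are removed entirely. The pattern being adopted, shown on generic types rather than the real Function hierarchy:

    #include <utility>
    #include <vector>

    struct Base {
      // New pattern: everything the base needs is passed up and set exactly once.
      Base(int num_inputs, std::vector<int> next)
          : num_inputs(num_inputs), next(std::move(next)) {}
      int num_inputs;
      std::vector<int> next;
    };

    // The derived class never touches the base's members in its body,
    // unlike the old GraphRoot, which overwrote next_functions/is_executable after the fact.
    struct Root : Base {
      Root(std::vector<int> edges, std::vector<double> outputs)
          : Base(/*num_inputs=*/0, std::move(edges)), outputs(std::move(outputs)) {}
      std::vector<double> outputs;
    };

    int main() {
      Root r({1, 2, 3}, {0.5});
      return r.next.size() == 3 ? 0 : 1;
    }
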
diff --git a/torch/csrc/autograd/functions/batch_normalization.cpp b/torch/csrc/autograd/functions/batch_normalization.cpp | |
deleted file mode 100644 | |
index 13d83881..00000000 | |
--- a/torch/csrc/autograd/functions/batch_normalization.cpp | |
+++ /dev/null | |
@@ -1,256 +0,0 @@ | |
-#include "batch_normalization.h" | |
- | |
-#include "torch/csrc/autograd/python_function.h" | |
-#include "torch/csrc/autograd/python_variable.h" | |
-#include "torch/csrc/autograd/variable.h" | |
-#include "torch/csrc/autograd/functions/utils.h" | |
-#include "torch/csrc/autograd/functions/basic_ops.h" | |
-#include "torch/csrc/utils/auto_gil.h" | |
-#include "torch/csrc/utils/auto_gpu.h" | |
-#include "torch/csrc/DynamicTypes.h" | |
-#include "torch/csrc/Exceptions.h" | |
-#include <sstream> | |
- | |
-#ifdef WITH_CUDNN | |
-#include "torch/csrc/cudnn/BatchNorm.h" | |
-#include "torch/csrc/cudnn/Handles.h" | |
-#include "torch/csrc/cudnn/Types.h" | |
-extern THCState* state; | |
-#endif | |
- | |
-namespace { | |
- void check_dims_match_num_input_features(const std::string& arg_name, long expected, long actual){ | |
- if (actual != expected){ | |
- std::stringstream ss; | |
- ss << arg_name << " should contain " << expected << " elements not " << actual ; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- } | |
-} | |
- | |
-namespace torch { namespace autograd { | |
- | |
-#ifndef CUDNN_BN_MIN_EPSILON | |
-#define CUDNN_BN_MIN_EPSILON 0 | |
-#endif | |
- | |
-auto BatchNormForward::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("BatchNorm", inputs, 3, 1); | |
- | |
- AutoGPU guard(inputs[0]); | |
- auto& input = inputs[0]; | |
- auto& weight = inputs[1]; | |
- auto& bias = inputs[2]; | |
- | |
- auto num_features = input.sizes()[1]; | |
- check_dims_match_num_input_features("running_mean", num_features, running_mean.numel()); | |
- check_dims_match_num_input_features("running_var", num_features, running_var.numel()); | |
- if (weight.defined()) { | |
- check_dims_match_num_input_features("weight", num_features, weight.numel()); | |
- } | |
- if (bias.defined()) { | |
- check_dims_match_num_input_features("bias", num_features, bias.numel()); | |
- } | |
- | |
- bool use_cudnn = false; | |
-#ifdef WITH_CUDNN | |
- use_cudnn = (input.type().isCuda() | |
- && (input.type().scalarType() != at::kHalf | |
- || weight.type().scalarType() == at::kFloat) | |
- && weight.defined() && bias.defined() | |
- && input.size(0) <= 131070 | |
- && cudnn_enabled && CUDNN_VERSION >= 5110L); | |
-#endif | |
- | |
- auto input_data = input.data(); | |
- Tensor output; | |
- auto save_mean = running_mean.type().tensor(running_mean.sizes()); | |
- auto save_std = running_var.type().tensor(running_var.sizes()); | |
- | |
- if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) { | |
-#ifdef WITH_CUDNN | |
- output = input_data.type().tensor(input.sizes()); | |
- torch::cudnn::cudnn_batch_norm_forward( | |
- state, | |
- torch::cudnn::getCudnnHandle(), | |
- torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)input.unsafeGetTH(false), | |
- (THVoidTensor*)output.unsafeGetTH(false), | |
- (THVoidTensor*)weight.unsafeGetTH(false), | |
- (THVoidTensor*)bias.unsafeGetTH(false), | |
- (THVoidTensor*)running_mean.unsafeGetTH(false), | |
- (THVoidTensor*)running_var.unsafeGetTH(false), | |
- (THVoidTensor*)save_mean.unsafeGetTH(false), | |
- (THVoidTensor*)save_std.unsafeGetTH(false), | |
- training, | |
- momentum, | |
- eps); | |
-#endif | |
- } else { | |
- output = at::batch_norm_forward( | |
- input_data, weight.opt_data(), bias.opt_data(), | |
- running_mean, running_var, training, momentum, eps, | |
- save_mean, save_std); | |
- } | |
- | |
- auto outputs = as_tensor_list(std::move(output)); | |
- return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) { | |
- return std::make_shared<BatchNormBackward>( | |
- f, *this, std::move(save_mean), std::move(save_std), | |
- input, weight, bias); | |
- }); | |
-}; | |
- | |
-auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_list { | |
- check_input_variables("BatchNormBackward", grad_outputs, 1); | |
- auto input_var = this->input.unpack(); | |
- auto weight_var = this->weight.unpack(); | |
- auto bias_var = this->bias.unpack(); | |
- | |
- auto input = input_var.data(); | |
- auto weight = weight_var.opt_data(); | |
- auto bias = bias_var.opt_data(); | |
- | |
- AutoGPU guard(input); | |
- | |
- bool use_cudnn = false; | |
-#ifdef WITH_CUDNN | |
- use_cudnn = (input.type().backend() == at::kCUDA | |
- && (input.type().scalarType() != at::kHalf | |
- || weight.type().scalarType() == at::kFloat) | |
- && weight.defined() && bias.defined() && training | |
- && input.size(0) <= 131070 | |
- && cudnn_enabled && CUDNN_VERSION >= 5110L); | |
-#endif | |
- | |
- at::Tensor grad_input; | |
- at::Tensor grad_weight; | |
- at::Tensor grad_bias; | |
- | |
- auto grad_output = grad_outputs[0].data().contiguous(); | |
- | |
- if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) { | |
-#ifdef WITH_CUDNN | |
- grad_input = input.type().tensor(input.sizes()); | |
- grad_weight = weight.type().tensor(weight.sizes()); | |
- grad_bias = bias.type().tensor(bias.sizes()); | |
- torch::cudnn::cudnn_batch_norm_backward( | |
- state, | |
- torch::cudnn::getCudnnHandle(), | |
- torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)input.unsafeGetTH(false), | |
- (THVoidTensor*)grad_output.unsafeGetTH(false), | |
- (THVoidTensor*)grad_input.unsafeGetTH(false), | |
- (THVoidTensor*)grad_weight.unsafeGetTH(false), | |
- (THVoidTensor*)grad_bias.unsafeGetTH(false), | |
- (THVoidTensor*)weight.unsafeGetTH(false), | |
- (THVoidTensor*)running_mean.unsafeGetTH(false), | |
- (THVoidTensor*)running_var.unsafeGetTH(false), | |
- (THVoidTensor*)save_mean.unsafeGetTH(false), | |
- (THVoidTensor*)save_std.unsafeGetTH(false), | |
- training, | |
- eps); | |
-#endif | |
- } else { | |
- std::array<bool, 3> mask = { | |
- should_compute_output(0), | |
- should_compute_output(1), | |
- should_compute_output(2), | |
- }; | |
- std::tie(grad_input, grad_weight, grad_bias) = at::batch_norm_backward( | |
- grad_output, input, weight, running_mean, running_var, | |
- training, eps, save_mean, save_std, | |
- mask); | |
- } | |
- | |
- // Add saved variables used out of the pure autograd to inputs | |
- variable_list all_inputs(grad_outputs); | |
- all_inputs.push_back(input_var); | |
- if (weight.defined()) { | |
- all_inputs.push_back(weight_var); | |
- } | |
- auto outputs = as_tensor_list(std::move(grad_input), | |
- std::move(grad_weight), | |
- std::move(grad_bias)); | |
- return wrap_outputs(all_inputs, std::move(outputs), [&](FunctionFlags f) { | |
- return std::make_shared<BatchNormBackwardBackward>( | |
- f, *this, save_mean, save_std, | |
- input_var, weight_var, | |
- grad_outputs[0]); | |
- }); | |
-}; | |
- | |
-auto BatchNormBackward::releaseVariables() -> void { | |
- input.data.reset(); | |
- weight.data.reset(); | |
- bias.data.reset(); | |
-} | |
- | |
-Variable getReturnTupleVar(PyObject *p, Py_ssize_t pos) { | |
- PyObject *item = PyTuple_GET_ITEM(p, pos); | |
- if (item != Py_None) { | |
- return ((THPVariable*)item)->cdata; | |
- } | |
- return Variable(); | |
-} | |
- | |
-auto BatchNormBackwardBackward::apply(const variable_list& grad_grad_inputs) -> variable_list { | |
- check_input_variables("BatchNormBackwardBackward", grad_grad_inputs, 3, 0); | |
- auto ggI = grad_grad_inputs[0]; | |
- auto ggW = grad_grad_inputs[1]; | |
- auto ggb = grad_grad_inputs[2]; | |
- | |
- auto input_var = input.unpack(); | |
- AutoGPU guard(input_var); | |
- | |
- auto weight_var = weight.unpack(); | |
- auto gO_var = grad_output.unpack(); | |
- | |
- auto input = input_var.data(); | |
- AutoGIL gil; | |
- | |
- THPObjectPtr input_pvar(THPVariable_Wrap(input_var)); | |
- THPObjectPtr weight_pvar(THPVariable_Wrap(weight_var)); | |
- | |
- THPObjectPtr ggi_pvar(THPVariable_Wrap(ggI)); | |
- THPObjectPtr ggW_pvar(THPVariable_Wrap(ggW)); | |
- THPObjectPtr ggb_pvar(THPVariable_Wrap(ggb)); | |
- THPObjectPtr gO_pvar(THPVariable_Wrap(gO_var)); | |
- THPObjectPtr eps_py(PyFloat_FromDouble(eps)); | |
- THPObjectPtr save_mean_py(createPyObject(save_mean)); | |
- THPObjectPtr save_std_py(createPyObject(save_std)); | |
- THPObjectPtr running_mean_py(createPyObject(running_mean)); | |
- THPObjectPtr running_var_py(createPyObject(running_var)); | |
- PyObject *training_pyo = training ? Py_True : Py_False; | |
- | |
- THPObjectPtr args(PyTuple_Pack(12, input_pvar.get(), weight_pvar.get(), | |
- ggi_pvar.get(), ggW_pvar.get(), ggb_pvar.get(), | |
- gO_pvar.get(), eps_py.get(), | |
- save_mean_py.get(), save_std_py.get(), | |
- running_mean_py.get(), running_var_py.get(), | |
- training_pyo)); | |
- THPObjectPtr r(PyObject_CallObject(THPBatchNormBackwardBackwardFunction, args.get())); | |
- if (!r) throw python_error(); | |
- if (!PyTuple_Check(r.get())) { | |
- throw std::runtime_error("expected PyTuple return from BatchNormBackwardBackward"); | |
- } | |
- | |
- auto gI_var = getReturnTupleVar(r, 0); | |
- auto gG_var = getReturnTupleVar(r, 1); | |
- auto ggO_var = getReturnTupleVar(r, 2); | |
- | |
- if (weight_var.defined()) { | |
- return {ggO_var, gI_var, gG_var}; | |
- } else { | |
- return {ggO_var, gI_var}; | |
- } | |
-}; | |
- | |
-auto BatchNormBackwardBackward::releaseVariables() -> void { | |
- input.data.reset(); | |
- weight.data.reset(); | |
- grad_output.data.reset(); | |
-} | |
- | |
- | |
-}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/functions/batch_normalization.h b/torch/csrc/autograd/functions/batch_normalization.h | |
deleted file mode 100644 | |
index aafb5d68..00000000 | |
--- a/torch/csrc/autograd/functions/batch_normalization.h | |
+++ /dev/null | |
@@ -1,93 +0,0 @@ | |
-#pragma once | |
- | |
-#include <Python.h> | |
-#include <memory> | |
-#include <ATen/ATen.h> | |
- | |
-#include "torch/csrc/autograd/function.h" | |
-#include "torch/csrc/autograd/variable.h" | |
-#include "torch/csrc/autograd/symbolic.h" | |
-#include "torch/csrc/autograd/saved_variable.h" | |
- | |
-namespace torch { namespace autograd { | |
- | |
-struct BatchNormParams { | |
- at::Tensor running_mean; | |
- at::Tensor running_var; | |
- bool training; | |
- double momentum; | |
- double eps; | |
- bool cudnn_enabled; | |
-}; | |
- | |
-struct BatchNormForward : public ForwardFunction<>, public BatchNormParams, public HasSymbolic { | |
- BatchNormForward(BatchNormParams params) | |
- : BatchNormParams(std::move(params)) {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override; | |
-}; | |
- | |
-struct BatchNormBackward : public Function, public BatchNormParams { | |
- BatchNormBackward( | |
- FunctionFlags flags, | |
- BatchNormParams params, | |
- at::Tensor save_mean, | |
- at::Tensor save_std, | |
- Variable input, | |
- Variable weight, | |
- Variable bias) | |
- : Function(std::move(flags)) | |
- , BatchNormParams(std::move(params)) { | |
- if (is_executable) { | |
- this->save_mean = std::move(save_mean); | |
- this->save_std = std::move(save_std); | |
- this->input = SavedVariable(input, this); | |
- this->weight = SavedVariable(weight, this); | |
- this->bias = SavedVariable(bias, this); | |
- } | |
- } | |
- | |
- virtual variable_list apply(const variable_list& gradOutputs) override; | |
- | |
- virtual void releaseVariables() override; | |
- | |
- at::Tensor save_mean; | |
- at::Tensor save_std; | |
- SavedVariable input; | |
- SavedVariable weight; | |
- SavedVariable bias; | |
-}; | |
- | |
-struct BatchNormBackwardBackward : public Function, public BatchNormParams { | |
- BatchNormBackwardBackward( | |
- FunctionFlags flags, | |
- BatchNormParams params, | |
- at::Tensor save_mean, | |
- at::Tensor save_std, | |
- Variable input, | |
- Variable weight, | |
- Variable grad_output) | |
- : Function(std::move(flags)) | |
- , BatchNormParams(std::move(params)) { | |
- if (is_executable) { | |
- this->save_mean = std::move(save_mean); | |
- this->save_std = std::move(save_std); | |
- this->input = SavedVariable(input, this); | |
- this->weight = SavedVariable(weight, this); | |
- this->grad_output = SavedVariable(grad_output, this); | |
- } | |
- } | |
- | |
- virtual variable_list apply(const variable_list& grad_grad_inputs) override; | |
- | |
- virtual void releaseVariables() override; | |
- | |
- at::Tensor save_mean; | |
- at::Tensor save_std; | |
- SavedVariable input; | |
- SavedVariable weight; | |
- SavedVariable grad_output; | |
-}; | |
- | |
-}} | |
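
The hand-written BatchNorm forward, backward, and backward-backward functions are deleted together with their header; batch normalization now reaches cuDNN/THNN through ATen and the generated VariableType derivatives rather than a bespoke autograd Function. Roughly, the deleted manual dispatch collapses into a single ATen call; the signature below is an assumption about the ATen API of this era and is not taken from this diff:

    #include <ATen/ATen.h>

    // Rough equivalent of what the deleted BatchNormForward::apply computed,
    // expressed directly against ATen (argument order assumed, not verified
    // against this exact revision).
    at::Tensor batch_norm_sketch(const at::Tensor& input,
                                 const at::Tensor& weight,
                                 const at::Tensor& bias,
                                 const at::Tensor& running_mean,
                                 const at::Tensor& running_var,
                                 bool training, double momentum, double eps,
                                 bool cudnn_enabled) {
      return at::batch_norm(input, weight, bias, running_mean, running_var,
                            training, momentum, eps, cudnn_enabled);
    }
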
diff --git a/torch/csrc/autograd/functions/convolution.cpp b/torch/csrc/autograd/functions/convolution.cpp | |
deleted file mode 100644 | |
index 7c36242d..00000000 | |
--- a/torch/csrc/autograd/functions/convolution.cpp | |
+++ /dev/null | |
@@ -1,918 +0,0 @@ | |
-#include "convolution.h" | |
- | |
-#include <sstream> | |
- | |
-#include "torch/csrc/autograd/variable.h" | |
-#include "torch/csrc/autograd/functions/utils.h" | |
-#include "torch/csrc/autograd/functions/basic_ops.h" | |
-#include "torch/csrc/autograd/functions/tensor.h" | |
-#include "torch/csrc/utils/auto_gpu.h" | |
- | |
-#include <ATen/ATen.h> | |
- | |
-#ifdef WITH_CUDNN | |
-#include "torch/csrc/cudnn/Conv.h" | |
-#include "torch/csrc/cudnn/Handles.h" | |
-#include "torch/csrc/cudnn/Types.h" | |
-extern THCState* state; | |
-using namespace torch::cudnn; | |
-#endif | |
- | |
-#ifdef WITH_NNPACK | |
-#include "torch/csrc/nnpack/NNPACK.h" | |
-#endif | |
- | |
-using torch::cudnn::Convolution; | |
-using at::Tensor; | |
-using tensor_pair = std::pair<at::Tensor, at::Tensor>; | |
- | |
-namespace torch { namespace autograd { | |
- | |
-// Forward function definition and utility functions | |
- | |
-static at::Tensor compute_output( | |
- at::Tensor& input, at::Tensor& weight, at::Tensor& bias, at::Tensor& columns, at::Tensor& ones, | |
- const ConvForward& params); | |
- | |
-static std::tuple<Tensor, Tensor, Tensor> compute_backward( | |
- at::Tensor& input, at::Tensor& grad_output, at::Tensor& weight, at::Tensor& columns, at::Tensor& ones, | |
- const ConvBackward& params, std::array<bool, 3> output_mask); | |
- | |
-auto ConvParams::is_strided() const -> bool { | |
- bool is_strided = false; | |
- for (int s : stride) { | |
- is_strided |= (s != 1); | |
- } | |
- return is_strided; | |
-} | |
- | |
-auto ConvParams::is_dilated() const -> bool { | |
- bool is_dilated = false; | |
- for (int d : dilation) { | |
- is_dilated |= (d != 1); | |
- } | |
- return is_dilated; | |
-} | |
- | |
-auto ConvParams::is_padded() const -> bool { | |
- bool is_padded = false; | |
- for (int p : padding) { | |
- is_padded |= (p != 0); | |
- } | |
- return is_padded; | |
-} | |
- | |
-auto ConvParams::is_output_padding_neg() const -> bool { | |
- bool is_non_neg = false; | |
- for (int p : output_padding) { | |
- is_non_neg |= (p < 0); | |
- } | |
- return is_non_neg; | |
-} | |
- | |
-auto ConvParams::is_output_padding_big() const -> bool { | |
- bool is_big = false; | |
- for (size_t i = 0; i < output_padding.size(); i++) { | |
- is_big |= (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i]); | |
- } | |
- return is_big; | |
-} | |
- | |
-auto ConvParams::is_padding_neg() const -> bool { | |
- bool is_non_neg = false; | |
- for (int p : padding) { | |
- is_non_neg |= (p < 0); | |
- } | |
- return is_non_neg; | |
-} | |
- | |
- | |
-auto ConvParams::view1d_as_2d() -> void { | |
- if (stride.size() == 1) { | |
- stride.insert(stride.begin(), 1); | |
- padding.insert(padding.begin(), 0); | |
- dilation.insert(dilation.begin(), 1); | |
- output_padding.insert(output_padding.begin(), 0); | |
- } | |
-} | |
- | |
-auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool { | |
-#ifdef WITH_CUDNN | |
- if (!input.type().isCuda() || !cudnn_enabled) { | |
- return false; | |
- } | |
- if (deterministic && is_dilated()) { | |
- // cudnn doesn't support deterministic dilated convolution fully yet | |
- return false; | |
- } | |
- if (is_dilated()) { | |
- cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state); | |
- // NOTE: extra parenthesis around numbers disable clang warnings about dead code | |
- return ((CUDNN_VERSION >= (6021)) || (CUDNN_VERSION >= (6000) && prop->major >= 5)) && !is_output_padding_big(); | |
- } | |
- return !is_output_padding_big(); | |
-#endif | |
- return false; | |
-} | |
- | |
-auto ConvParams::use_nnpack(const at::Tensor& input) const -> bool { | |
-#ifdef WITH_NNPACK | |
- return input.type().ID() == at::TypeID::CPUFloat && // only on CPU Float Tensors | |
- !is_strided() && // doesn't support strides | |
- !is_dilated() && // or dilation | |
- !transposed && // or transposed tensors | |
- input.ndimension() == 4 && // must be in NCHW format | |
- input.size(0) >= 16; // ensure large enough batch size to ensure perf, tuneable | |
-#endif | |
- return false; | |
-} | |
- | |
-// We currently only have depthwise support for the case where groups == | |
-// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of | |
-// a depthwise multiplier) | |
-auto ConvParams::is_depthwise( | |
- const at::Tensor& input, const at::Tensor& weight, int groups) const -> bool { | |
- return input.type().isCuda() && | |
- !transposed && | |
- input.ndimension() == 4 && | |
- input.size(1) == groups && | |
- groups > 1 && // no point if there is only a single group | |
- weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels | |
-} | |
- | |
-std::string ConvForward::name() { return "ConvForward"; } | |
- | |
-auto ConvForward::output_size(at::Tensor& input, at::Tensor& weight) const -> std::vector<int64_t> { | |
- auto in_size = input.sizes(); | |
- auto weight_size = weight.sizes(); | |
- auto dim = input.ndimension(); | |
- | |
- std::vector<int64_t> output_size(dim); | |
- output_size[0] = in_size[0]; | |
- output_size[1] = transposed ? weight_size[1] * groups : weight_size[0]; | |
- for (int d = 2; d < dim; ++d) { | |
- int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1; | |
- if (transposed) { | |
- output_size[d] = (in_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) + | |
- kernel + output_padding[d - 2]; | |
- } else { | |
- output_size[d] = (in_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1; | |
- } | |
- } | |
- return output_size; | |
-} | |
- | |
-static auto view4d(const at::Tensor& tensor) -> at::Tensor { | |
- if (tensor.ndimension() != 3) throw std::runtime_error("expected 3D tensor"); | |
- return tensor.unsqueeze(2); | |
-} | |
- | |
-static auto view3d(const at::Tensor& tensor) -> at::Tensor { | |
- if (tensor.ndimension() != 4) throw std::runtime_error("expected 4D tensor"); | |
- return tensor.squeeze(2); | |
-} | |
- | |
-static void check_input_shape_forward(const at::Tensor& input, | |
- const at::Tensor& weight, const at::Tensor& bias, | |
- int64_t groups, bool transposed) { | |
- int k = input.ndimension(); | |
- | |
- if (weight.ndimension() != k) { | |
- std::stringstream ss; | |
- ss << "Expected " << k << "-dimensional input for " << k | |
- << "-dimensional weight " << weight.sizes() << ", but got input of size " | |
- << input.sizes() << " instead"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- if (weight.size(0) < groups) { | |
- std::stringstream ss; | |
- ss << "Given groups=" << groups << ", expected weight to be at least " | |
- << groups << " at dimension 0, but got weight of size " << weight.sizes() | |
- << " instead"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- | |
- if (!transposed) { | |
- if (input.size(1) != (weight.size(1) * groups)) { | |
- std::stringstream ss; | |
- ss << "Given groups=" << groups << ", weight" << weight.sizes() | |
- << ", so expected input" << input.sizes() << " to have " | |
- << (weight.size(1) * groups) << " channels, but got " << input.size(1) | |
- << " channels instead"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- if (bias.defined() && (bias.ndimension() != 1 || bias.size(0) != weight.size(0))) { | |
- std::stringstream ss; | |
- ss << "Given weight of size " << weight.sizes() | |
- << ", expected bias to be 1-dimensional with " << weight.size(0) << " elements" | |
- << ", but got bias of size " << bias.sizes() << " instead"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- } else { // transposed | |
- if (input.size(1) != weight.size(0)) { | |
- std::stringstream ss; | |
- ss << "Given transposed=" << transposed << ", weight" << weight.sizes() | |
- << ", so expected input" << input.sizes() << " to have " | |
- << weight.size(0) << " channels, but got " << input.size(1) | |
- << " channels instead"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- if (bias.defined() && (bias.ndimension() != 1 || bias.size(0) != weight.size(1) * groups)) { | |
- std::stringstream ss; | |
- ss << "Given transposed=" << transposed << ", weight of size " << weight.sizes() | |
- << ", expected bias to be 1-dimensional with " << weight.size(1) * groups << " elements" | |
- << ", but got bias of size " << bias.sizes() << " instead"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- } | |
-} | |
- | |
-static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) { | |
- if (!tensor.defined()) { | |
- return at::Tensor(); | |
- } | |
- int64_t n = tensor.sizes()[dim] / groups; | |
- return tensor.narrow(dim, n * g, n).contiguous(); | |
-} | |
- | |
-static Variable subvariable(const Variable& var, int dim, int groups, int g) { | |
- int64_t n = var.sizes()[dim] / groups; | |
- auto result = apply_fn<Narrow>(dim, n * g, n)(var); | |
- return result; | |
-} | |
- | |
-static std::vector<int64_t> vecToInt64(const std::vector<int>& src) { | |
- std::vector<int64_t> res(src.size()); | |
- for (size_t i = 0; i < src.size(); i++) { | |
- res[i] = static_cast<int64_t>(src[i]); | |
- } | |
- return res; | |
-} | |
- | |
-static at::Tensor cat(const tensor_list& tensors, int dim) { | |
- int num_inputs = tensors.size(); | |
- if (num_inputs == 0) { | |
- return at::Tensor(); | |
- } | |
- | |
- auto output = tensors[0].type().tensor(); | |
- at::cat_out(output, tensors, dim); | |
- return output; | |
-} | |
- | |
-// ConvForward implementation | |
- | |
-auto ConvForward::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("ConvNd", inputs, 3, 2); | |
- if (is_padding_neg()) throw std::runtime_error("negative padding is not supported"); | |
- if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported"); | |
- | |
- AutoGPU guard(inputs[0]); | |
- | |
- auto input = inputs[0].data().contiguous(); | |
- auto weight = inputs[1].data(); | |
- auto bias = inputs[2].opt_data(); | |
- | |
- check_input_shape_forward(input, weight, bias, groups, transposed); | |
- | |
- int k = input.ndimension(); | |
- | |
- if (k == 3) { | |
- view1d_as_2d(); | |
- input = view4d(input); | |
- weight = view4d(weight); | |
- } | |
- | |
- auto output = input.type().tensor(); | |
- tensor_list columns(groups); | |
- tensor_list ones(groups); | |
- std::unique_ptr<Convolution> convolution; | |
- | |
- if (is_depthwise(input, weight, groups)) { | |
- /* output.resize_(output_size(input, weight)); */ | |
- | |
- auto kernel_size = weight.sizes().slice(2); | |
- auto stride = vecToInt64(this->stride); | |
- auto padding = vecToInt64(this->padding); | |
- auto dilation = vecToInt64(this->dilation); | |
- | |
- output = at::conv_depthwise2d_forward(input, weight, kernel_size, bias, stride, padding, dilation); | |
- } else if (use_cudnn(input)) { | |
-#ifdef WITH_CUDNN | |
- if (input.type().ID() != weight.type().ID()){ | |
- std::stringstream ss; | |
- ss << "Input type (" << input.toString() << ") and weight type (" << weight.toString() << ") should be the same"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- if (bias.defined() && input.type().ID() != bias.type().ID()){ | |
- std::stringstream ss; | |
- ss << "Input type (" << input.toString() << ") and bias type (" << bias.toString() << ") should be the same"; | |
- throw std::runtime_error(ss.str()); | |
- } | |
- | |
- output = input.type().tensor(); | |
- output.resize_(output_size(input, weight)); | |
- if (transposed) { | |
- convolution.reset(cudnn_convolution_transpose_full_forward( | |
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false), | |
- bias.defined() ? (THVoidTensor*)bias.unsafeGetTH(false) : nullptr, (THVoidTensor*)output.unsafeGetTH(false), | |
- padding, stride, dilation, groups, benchmark, deterministic)); | |
- } else { | |
- convolution.reset(cudnn_convolution_full_forward( | |
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false), | |
- bias.defined() ? (THVoidTensor*)bias.unsafeGetTH(false) : nullptr, (THVoidTensor*)output.unsafeGetTH(false), | |
- padding, stride, dilation, groups, benchmark, deterministic)); | |
- } | |
-#endif | |
- } else { | |
- for (int g = 0; g < groups; ++g) { | |
- columns[g] = input.type().tensor(); | |
- ones[g] = input.type().tensor(); | |
- } | |
- if (groups == 1) { | |
- output = compute_output( | |
- input, weight, bias, | |
- columns[0], ones[0], *this); | |
- } else { | |
- tensor_list outputs(groups); | |
- for (int g = 0; g < groups; ++g) { | |
- auto input_g = subtensor(input, 1, groups, g); | |
- auto weight_g = subtensor(weight, 0, groups, g); | |
- auto bias_g = subtensor(bias, 0, groups, g); | |
- outputs[g] = compute_output( | |
- input_g, weight_g, bias_g, | |
- columns[g], ones[g], *this); | |
- } | |
- output = cat(outputs, 1); | |
- } | |
- } | |
- | |
- if (k == 3) { | |
- output = view3d(output); | |
- } | |
- | |
- auto outputs = as_tensor_list(std::move(output)); | |
- return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) { | |
- return std::make_shared<ConvBackward>( | |
- f, *this, | |
- inputs[0], inputs[1], inputs[2], | |
- std::move(columns), std::move(ones), std::move(convolution)); | |
- }); | |
-}; | |
- | |
-// For Convolution strategies that don't implicitly handle grad_bias, we add a helper | |
-// function here to perform it using simple Tensor operators | |
-static at::Tensor compute_grad_bias(const at::Tensor& grad_output) { | |
- // grad_output is in N, C, H, W, we re-shape and reduce over spatial dims and batches | |
- return grad_output.contiguous().view({grad_output.size(0), grad_output.size(1), -1}).sum(0).sum(1); | |
-} | |
- | |
-// ConvBackward implementation | |
- | |
-auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list { | |
- check_input_variables("ConvNdBackward", grad_outputs, 1); | |
- if (is_padding_neg()) throw std::runtime_error("negative padding is not supported"); | |
- if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported"); | |
- | |
- auto input_var = input_.unpack(); | |
- auto weight_var = weight_.unpack(); | |
- auto bias_var = bias_.unpack(); | |
- | |
- auto input = input_var.data(); | |
- auto weight = weight_var.data(); | |
- | |
- AutoGPU guard(input); | |
- | |
- auto bias = bias_var.defined() ? bias_var.data() : Tensor(); | |
- | |
- input = input.contiguous(); | |
- auto grad_output = grad_outputs[0].data().contiguous(); | |
- | |
- int k = input.ndimension(); | |
- if (k == 3) { | |
- input = view4d(input); | |
- weight = view4d(weight); | |
- grad_output = view4d(grad_output); | |
- } | |
- | |
- | |
- bool use_depthwise = this->is_depthwise(input, weight, groups); | |
- bool use_cudnn = this->use_cudnn(input); | |
- | |
- at::Tensor grad_input; | |
- at::Tensor grad_weight; | |
- at::Tensor grad_bias; | |
- | |
- std::array<bool, 3> output_mask = { | |
- should_compute_output(0), | |
- should_compute_output(1), | |
- should_compute_output(2) && bias.defined(), | |
- }; | |
- | |
- if (use_depthwise) { | |
- if (output_mask[0] || output_mask[1]) { | |
- auto kernel_size = weight.sizes().slice(2); | |
- auto stride = vecToInt64(this->stride); | |
- auto padding = vecToInt64(this->padding); | |
- auto dilation = vecToInt64(this->dilation); | |
- | |
- std::tie(grad_input, grad_weight) = at::conv_depthwise2d_backward( | |
- grad_output, input, weight, kernel_size, stride, padding, dilation, | |
- {output_mask[0], output_mask[1]}); | |
- } | |
- | |
- // THCUNN implementation does not handle bias, so we do it ourselves | |
- if (output_mask[2]) { | |
- grad_bias = compute_grad_bias(grad_output); | |
- } | |
- } else if (use_cudnn) { | |
-#ifdef WITH_CUDNN | |
- if (output_mask[0]) { | |
- grad_input = input.type().tensor(); | |
- grad_input.resize_as_(input); | |
- if (transposed) { | |
- // ConvTranspose uses the same kernels as regular convolution | |
- // but swaps forward and backward calls | |
- cudnn_convolution_forward( | |
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false), (THVoidTensor*)grad_input.unsafeGetTH(false), | |
- convolution.get(), benchmark, deterministic); | |
- } else { | |
- cudnn_convolution_backward_data( | |
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)grad_input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false), | |
- convolution.get(), benchmark, deterministic); | |
- } | |
- } | |
- if (output_mask[1] || output_mask[2]) { | |
- grad_weight = weight.type().tensor(); | |
- grad_weight.resize_as_(weight); | |
- cudnn_convolution_backward_filter( | |
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)grad_weight.unsafeGetTH(false), | |
- convolution.get(), benchmark, deterministic); | |
- | |
- if (output_mask[2]) { | |
- grad_bias = bias.type().tensor(); | |
- grad_bias.resize_as_(bias); | |
- cudnn_convolution_backward_bias( | |
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input), | |
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)grad_bias.unsafeGetTH(false), | |
- convolution.get()); | |
- } | |
- } | |
-#endif | |
- } else if (groups == 1) { | |
- std::tie(grad_input, grad_weight, grad_bias) = compute_backward( | |
- input, grad_output, weight, columns[0], ones[0], | |
- *this, output_mask); | |
- } else { | |
- tensor_list grad_inputs(groups); | |
- tensor_list grad_weights(groups); | |
- tensor_list grad_biases(groups); | |
- for (int g = 0; g < groups; ++g) { | |
- auto input_g = subtensor(input, 1, groups, g); | |
- auto grad_output_g = subtensor(grad_output, 1, groups, g); | |
- auto weight_g = subtensor(weight, 0, groups, g); | |
- std::tie(grad_inputs[g], grad_weights[g], grad_biases[g]) = compute_backward( | |
- input_g, grad_output_g, weight_g, columns[g], ones[g], | |
- *this, output_mask); | |
- } | |
- if (output_mask[0]) { | |
- grad_input = cat(grad_inputs, 1); | |
- } | |
- if (output_mask[1]) { | |
- grad_weight = cat(grad_weights, 0); | |
- } | |
- if (output_mask[2]) { | |
- grad_bias = cat(grad_biases, 0); | |
- } | |
- } | |
- | |
- if (k == 3) { | |
- if (grad_input.defined()) { | |
- grad_input = view3d(grad_input); | |
- } | |
- if (grad_weight.defined()) { | |
- grad_weight = view3d(grad_weight); | |
- } | |
- } | |
- | |
- // Add saved variables used out of the pure autograd to inputs | |
- variable_list all_inputs(grad_outputs); | |
- all_inputs.push_back(input_var); | |
- all_inputs.push_back(weight_var); | |
- | |
- auto outputs = as_tensor_list(std::move(grad_input), | |
- std::move(grad_weight), | |
- std::move(grad_bias)); | |
- return wrap_outputs(all_inputs, std::move(outputs), [&](FunctionFlags f) { | |
- return std::make_shared<ConvBackwardBackward>( | |
- f, *this, | |
- input_var, weight_var, | |
- bias_var, grad_outputs[0]); | |
- }); | |
-}; | |
- | |
-auto ConvBackward::releaseVariables() -> void { | |
- input_.data.reset(); | |
- weight_.data.reset(); | |
- bias_.data.reset(); | |
-} | |
- | |
- | |
-// ConvBackwardBackward implementation | |
- | |
-auto ConvBackwardBackward::apply(const variable_list& grad_grad_inputs) -> variable_list { | |
- check_input_variables("ConvNdBackwardBackward", grad_grad_inputs, 3, 0); | |
- | |
- auto ggI = grad_grad_inputs[0]; | |
- auto ggW = grad_grad_inputs[1]; | |
- auto ggb = grad_grad_inputs[2]; | |
- | |
- auto gO = grad_output_.unpack(); | |
- auto weight = weight_.unpack(); | |
- auto input = input_.unpack(); | |
- | |
- AutoGPU guard(input.data()); | |
- | |
- // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb | |
- Variable ggO; | |
- if (ggI.defined()) { | |
- if (weight.type().isCuda()) { | |
- weight = apply_fn<Contiguous>()(weight); | |
- } | |
- ggO = apply_fn<ConvForward>(*this)(ggI, weight, Variable()); | |
- } | |
- | |
- if (ggW.defined()) { | |
- if (ggW.type().isCuda()) { | |
- ggW = apply_fn<Contiguous>()(ggW); | |
- } | |
- auto ggW_term = apply_fn<ConvForward>(*this)(input, ggW, Variable()); | |
- if (ggO.defined()) { | |
- ggO = apply_fn<Add>()(ggO, ggW_term); | |
- } else { | |
- ggO = ggW_term; | |
- } | |
- } | |
- | |
- if (ggb.defined()) { | |
- // View as (1, ggb.size(0), 1, 1...) | |
- | |
- // Expand | |
- std::vector<int64_t> new_size(gO.ndimension(), 1); | |
- new_size[1] = ggb.sizes()[0]; | |
- auto ggb_contiguous = apply_fn<Contiguous>()(ggb); | |
- auto ggb_view = apply_fn<View>(new_size)(ggb_contiguous); | |
- | |
- // Expand | |
- auto ggb_expanded = apply_fn<Expand>(gO.sizes())(ggb_view); | |
- | |
- if (ggO.defined()) { | |
- ggO = apply_fn<Add>()(ggO, ggb_expanded); | |
- } else { | |
- ggO = ggb_expanded; | |
- } | |
- } | |
- | |
- // Compute gW = conv(ggI, gO) | |
- Variable gW; | |
- if (ggI.defined()) { | |
- // Modified params with correct padding | |
- ConvParams gw_conv_params(*this); | |
- | |
- // Disable groups as they are handled separately | |
- auto groups = gw_conv_params.groups; | |
- gw_conv_params.groups = 1; | |
- std::swap(gw_conv_params.dilation, gw_conv_params.stride); | |
- | |
- // Transpose gO and ggI to accumulate over batch | |
- auto gOt = apply_fn<Transpose>(0, 1)(gO); | |
- auto ggIt = apply_fn<Transpose>(0, 1)(ggI); | |
- | |
- Variable gWt; | |
- // Compute conv | |
- if (groups == 1) { | |
- if (gOt.type().isCuda()) { | |
- gOt = apply_fn<Contiguous>()(gOt); | |
- } | |
- | |
- // Compute conv | |
- if (transposed) { | |
- gw_conv_params.transposed = false; | |
- gWt = apply_fn<ConvForward>(gw_conv_params)(gOt, ggIt, Variable()); | |
- } else { | |
- gWt = apply_fn<ConvForward>(gw_conv_params)(ggIt, gOt, Variable()); | |
- } | |
- } else { | |
- variable_list gWt_list(groups); | |
- for (int g = 0; g < groups; ++g) { | |
- auto ggIt_g = subvariable(ggIt, 0, groups, g); | |
- auto gOt_g = subvariable(gOt, 0, groups, g); | |
- if (gOt_g.type().isCuda()) { | |
- gOt_g = apply_fn<Contiguous>()(gOt_g); | |
- } | |
- | |
- // Compute conv | |
- if (transposed) { | |
- gw_conv_params.transposed = false; | |
- gWt_list[g] = apply_fn<ConvForward>(gw_conv_params)(gOt_g, ggIt_g, Variable()); | |
- } else { | |
- gWt_list[g] = apply_fn<ConvForward>(gw_conv_params)(ggIt_g, gOt_g, Variable()); | |
- } | |
- } | |
- | |
- gWt = apply_fn<Cat>(1)(gWt_list); | |
- } | |
- | |
- // Transpose gW to match chan_in and chan_out | |
- gW = apply_fn<Transpose>(0, 1)(gWt); | |
- | |
- // narrow gW to only relevant portion | |
- // we do it this way instead of narrowing the input itself because | |
- // the ConvForward kernels don't support asymmetric padding. | |
- auto gW_size = gW.sizes(); | |
- auto w_size = weight.sizes(); | |
- for (size_t i = 2; i < gW_size.size(); ++i) { | |
- if (gW_size[i] > w_size[i]) { | |
- gW = apply_fn<Narrow>(i, 0, w_size[i])(gW); | |
- gW_size = gW.sizes(); | |
- } | |
- } | |
- } | |
- | |
- // Compute gI = convT(ggW, gO.t()) if !transposed | |
- // gI = conv(go, ggw) if transposed | |
- Variable gI; | |
- if (ggW.defined()) { | |
- ConvParams gi_conv_params(*this); | |
- gi_conv_params.transposed = !transposed; | |
- | |
- if (transposed) { | |
- if (gO.type().isCuda()) { | |
- gO = apply_fn<Contiguous>()(gO); | |
- } | |
- gI = apply_fn<ConvForward>(gi_conv_params)(gO, ggW, Variable()); | |
- | |
- // narrow gI to only relevant portion | |
- // we do it this way because negative output_padding is not supported | |
- // TODO: figure out if we can narrow gO and save some compute, | |
- // rather than narrowing the computed gI | |
- auto gI_size = gI.sizes(); | |
- auto i_size = input.sizes(); | |
- for (size_t i = 2; i < gI_size.size(); ++i) { | |
- if (gI_size[i] > i_size[i]) { | |
- gI = apply_fn<Narrow>(i, 0, i_size[i])(gI); | |
- gI_size = gI.sizes(); | |
- } | |
- } | |
- } else { | |
- auto groups = gi_conv_params.groups; | |
- gi_conv_params.groups = 1; | |
- // swap stride and dilation | |
- std::swap(gi_conv_params.dilation, gi_conv_params.stride); | |
- | |
- auto ggWt = apply_fn<Transpose>(0, 1)(ggW); | |
- auto gOt = apply_fn<Transpose>(0, 1)(gO); | |
- | |
- // calculate output_padding | |
- auto kernel_size = weight.sizes().slice(2); | |
- auto input_shape = input.sizes().slice(2); | |
- auto grad_output_shape = gO.sizes().slice(2); | |
- | |
- if (kernel_size.size() == 1) { | |
- auto expected_input_shape = (kernel_size[0] - 1) * gi_conv_params.stride[1] | |
- - 2 * gi_conv_params.padding[1] | |
- + (gi_conv_params.dilation[1] * (grad_output_shape[0] - 1) + 1); | |
- if (expected_input_shape != input_shape[0]) { | |
- gi_conv_params.output_padding[1] = input_shape[0] - expected_input_shape; | |
- } | |
- } else { | |
- for(size_t i = 0; i < kernel_size.size(); ++i) { | |
- // Check if whole input has been used or not | |
- auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.stride[i] | |
- - 2 * gi_conv_params.padding[i] | |
- + (gi_conv_params.dilation[i] * (grad_output_shape[i] - 1) + 1); | |
- if (expected_input_shape != input_shape[i]) { | |
- gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape; | |
- } | |
- } | |
- } | |
- | |
- Variable gIt; | |
- if (groups == 1) { | |
- if (gOt.type().isCuda()) { | |
- gOt = apply_fn<Contiguous>()(gOt); | |
- } | |
- | |
- gIt = apply_fn<ConvForward>(gi_conv_params)(ggWt, gOt, Variable()); | |
- } else { | |
- variable_list gIt_list(groups); | |
- for (int g = 0; g < groups; ++g) { | |
- auto ggWt_g = subvariable(ggWt, 1, groups, g); | |
- auto gOt_g = subvariable(gOt, 0, groups, g); | |
- if (gOt_g.type().isCuda()) { | |
- gOt_g = apply_fn<Contiguous>()(gOt_g); | |
- } | |
- | |
- gIt_list[g] = apply_fn<ConvForward>(gi_conv_params)(ggWt_g, gOt_g, Variable()); | |
- } | |
- | |
- gIt = apply_fn<Cat>(0)(gIt_list); | |
- } | |
- | |
- gI = apply_fn<Transpose>(0, 1)(gIt); | |
- } | |
- } | |
- | |
- if (should_compute_output(0) && !ggO.defined()) ggO = at::zeros_like(gO); | |
- if (should_compute_output(1) && !gI.defined()) gI = at::zeros_like(input); | |
- if (should_compute_output(2) && !gW.defined()) gW = at::zeros_like(weight); | |
- bool is_volatile = std::any_of(grad_grad_inputs.begin(), grad_grad_inputs.end(), [](const Variable& v){ | |
- return v.defined() && v.is_volatile(); | |
- }); | |
- auto results = variable_list({ggO, gI, gW}); | |
- for (auto& result : results) { | |
- result.is_volatile() |= is_volatile; | |
- } | |
- return results; | |
-} | |
- | |
-auto ConvBackwardBackward::releaseVariables() -> void { | |
- input_.data.reset(); | |
- weight_.data.reset(); | |
- bias_.data.reset(); | |
- grad_output_.data.reset(); | |
-} | |
- | |
-// Forward and backward functions for Tensor | |
- | |
-static at::Tensor compute_output( | |
- at::Tensor& input, at::Tensor& weight, at::Tensor& bias, | |
- at::Tensor& columns, at::Tensor& ones, | |
- const ConvForward& params) { | |
- | |
- auto dim = input.ndimension(); | |
- auto dilated = params.is_dilated(); | |
- auto kernel_size = weight.sizes().slice(2); | |
- auto stride = vecToInt64(params.stride); | |
- auto padding = vecToInt64(params.padding); | |
- auto dilation = vecToInt64(params.dilation); | |
- auto output_padding = vecToInt64(params.output_padding); | |
- | |
- if (params.transposed) { | |
- if (dim == 4) { | |
- return at::conv_transpose2d_forward( | |
- input, weight, kernel_size, bias, | |
- stride, padding, output_padding, dilation, | |
- columns, ones); | |
- } else if (dim == 5) { | |
- return at::conv_transpose3d_forward( | |
- input, weight, bias, | |
- stride, padding, output_padding, dilation, | |
- columns, ones); | |
- } | |
- } else { /* Not transposed */ | |
- if (dim == 4) { | |
- if (dilated) { | |
- return at::conv_dilated2d_forward( | |
- input, weight, kernel_size, bias, | |
- stride, padding, dilation, | |
- columns, ones); | |
- } else { /* dim == 4, non-dilated */ | |
- if (params.use_nnpack(input)) { | |
-#ifdef WITH_NNPACK | |
- // THNN functions handle resizing the output Tensor themselves, | |
- // but NNPACK expects the Tensors to be in the appropriate shape | |
- // already, so we resize here | |
- auto output = input.type().tensor(params.output_size(input, weight)); | |
- nnpack::SpatialConvolution_updateOutput( | |
- input, output, weight, bias, | |
- kernel_size[1], kernel_size[0], | |
- params.padding[1], params.padding[0]); | |
- return output; | |
-#endif | |
- } else { | |
- /* CPU implementation has specialized MM kernels | |
- for non-dilated case here */ | |
- return at::conv2d_forward( | |
- input, weight, kernel_size, bias, | |
- stride, padding, | |
- columns, ones); | |
- } | |
- } | |
- } else if (dim == 5 && (input.type().isCuda() || dilated)) { | |
- return at::conv_dilated3d_forward( | |
- input, weight, kernel_size, bias, | |
- stride, padding, dilation, | |
- columns, ones); | |
- } else if (dim == 5) { /* dim == 5, CPU, non-dilated */ | |
- /* CPU implementation has specialized MM kernels | |
- for non-dilated case here */ | |
- return at::conv3d_forward( | |
- input, weight, kernel_size, bias, | |
- stride, padding, | |
- columns); | |
- } | |
- } | |
- | |
- throw std::runtime_error("unsupported ConvNd parameters"); | |
-} | |
- | |
-static std::tuple<Tensor, Tensor, Tensor> compute_backward( | |
- at::Tensor& input, at::Tensor& grad_output, at::Tensor& weight, | |
- at::Tensor& columns, at::Tensor& ones, | |
- const ConvBackward& params, | |
- std::array<bool, 3> output_mask) { | |
- | |
- auto kernel_size = weight.sizes().slice(2); | |
- auto stride = vecToInt64(params.stride); | |
- auto padding = vecToInt64(params.padding); | |
- auto dilation = vecToInt64(params.dilation); | |
- auto output_padding = vecToInt64(params.output_padding); | |
- | |
- auto dim = input.ndimension(); | |
- auto dilated = params.is_dilated(); | |
- | |
- if (params.transposed) { | |
- if (dim == 4) { | |
- return at::conv_transpose2d_backward( | |
- grad_output, input, weight, kernel_size, | |
- stride, padding, output_padding, dilation, | |
- columns, ones, output_mask); | |
- } else if (dim == 5) { | |
- return at::conv_transpose3d_backward( | |
- grad_output, input, weight, | |
- stride, padding, output_padding, dilation, | |
- columns, ones, output_mask); | |
- } | |
- } else { /* Not transposed */ | |
- if (dim == 4) { | |
- if (dilated) { | |
- return at::conv_dilated2d_backward( | |
- grad_output, input, weight, kernel_size, | |
- stride, padding, dilation, | |
- columns, ones, output_mask); | |
- } else { | |
- if (params.use_nnpack(input)) { | |
-#ifdef WITH_NNPACK | |
- Tensor grad_input; | |
- Tensor grad_weight; | |
- Tensor grad_bias; | |
- | |
- if (output_mask[0]) { | |
- grad_input = input.type().tensor(input.sizes()); | |
- nnpack::SpatialConvolution_updateGradInput( | |
- input, grad_output, grad_input, weight, | |
- kernel_size[1], kernel_size[0], | |
- params.padding[1], params.padding[0]); | |
- } | |
- | |
- // NNPACK does not have a bias gradient calculation, so we split | |
- // into two calls here if necessary | |
- if (output_mask[1]) { | |
- grad_weight = weight.type().tensor(weight.sizes()); | |
- grad_weight.zero_(); | |
- nnpack::SpatialConvolution_accGradWeight( | |
- input, grad_output, grad_weight, | |
- kernel_size[1], kernel_size[0], | |
- params.padding[1], params.padding[0]); | |
- } | |
- | |
- if (output_mask[2]) { | |
- grad_bias = compute_grad_bias(grad_output); | |
- } | |
- | |
- return std::make_tuple(grad_input, grad_weight, grad_bias); | |
-#endif | |
- } else { | |
- /* CPU implementation has specialized MM kernels | |
- for non-dilated case here */ | |
- return at::conv2d_backward( | |
- grad_output, input, weight, kernel_size, | |
- stride, padding, | |
- columns, ones, output_mask); | |
- } | |
- } | |
- } else if (dim == 5 && (input.type().isCuda() || dilated)) { | |
- return at::conv_dilated3d_backward( | |
- grad_output, input, weight, kernel_size, | |
- stride, padding, dilation, | |
- columns, ones, output_mask); | |
- } else if (dim == 5) { /* dim == 5, CPU, non-dilated */ | |
- /* CPU implementation has specialized MM kernels | |
- for non-dilated case here */ | |
- return at::conv3d_backward( | |
- grad_output, input, weight, kernel_size, | |
- stride, padding, | |
- columns, ones, output_mask); | |
- } | |
- } | |
- | |
- throw std::runtime_error("unsupported ConvNdBackward parameters"); | |
-} | |
- | |
-}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/functions/convolution.h b/torch/csrc/autograd/functions/convolution.h | |
deleted file mode 100644 | |
index f8ba1bcb..00000000 | |
--- a/torch/csrc/autograd/functions/convolution.h | |
+++ /dev/null | |
@@ -1,123 +0,0 @@ | |
-#pragma once | |
- | |
-#include <Python.h> | |
-#include <ATen/ATen.h> | |
-#include <memory> | |
-#include <vector> | |
-#include <iostream> | |
- | |
-#include "torch/csrc/autograd/function.h" | |
-#include "torch/csrc/autograd/variable.h" | |
-#include "torch/csrc/autograd/symbolic.h" | |
-#include "torch/csrc/autograd/saved_variable.h" | |
- | |
-#ifdef WITH_CUDNN | |
-#include "torch/csrc/cudnn/Conv.h" | |
-#else | |
-namespace torch { namespace cudnn { | |
-struct Convolution {}; | |
-}} | |
-#endif | |
- | |
-namespace torch { namespace autograd { | |
- | |
-struct ConvParams { | |
- std::vector<int> stride; | |
- std::vector<int> padding; | |
- std::vector<int> dilation; | |
- bool transposed; | |
- std::vector<int> output_padding; | |
- int groups; | |
- bool benchmark; | |
- bool deterministic; | |
- bool cudnn_enabled; | |
- | |
- bool is_strided() const; | |
- bool is_dilated() const; | |
- bool is_padded() const; | |
- bool is_output_padding_neg() const; | |
- bool is_output_padding_big() const; | |
- bool is_padding_neg() const; | |
- void view1d_as_2d(); | |
- bool use_cudnn(const at::Tensor& input) const; | |
- bool use_nnpack(const at::Tensor& input) const; | |
- bool is_depthwise(const at::Tensor& input, const at::Tensor& weight, int groups) const; | |
-}; | |
- | |
-struct ConvForward : public ForwardFunction<>, public ConvParams, public HasSymbolic { | |
- explicit ConvForward(ConvParams params) : ConvParams(std::move(params)) {} | |
- | |
- virtual std::string name() override; | |
- virtual variable_list apply(const variable_list& inputs) override; | |
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override; | |
- | |
- std::vector<int64_t> output_size(at::Tensor& input, at::Tensor& weight) const; | |
-}; | |
- | |
-struct ConvBackward : public Function, public ConvParams { | |
- ConvBackward( | |
- FunctionFlags flags, | |
- ConvParams params, | |
- Variable input, | |
- Variable weight, | |
- Variable bias, | |
- tensor_list columns, | |
- tensor_list ones, | |
- std::unique_ptr<torch::cudnn::Convolution> convolution) | |
- : Function(std::move(flags)) | |
- , ConvParams(std::move(params)) | |
- , convolution(std::move(convolution)) { | |
- if (is_executable) { | |
- this->input_ = SavedVariable(input, this); | |
- this->weight_ = SavedVariable(weight, this); | |
- if (bias.defined()) { | |
- this->bias_ = SavedVariable(bias, this); | |
- } | |
- this->columns = std::move(columns); | |
- this->ones = std::move(ones); | |
- } | |
- } | |
- | |
- virtual variable_list apply(const variable_list& gradOutputs) override; | |
- | |
- virtual void releaseVariables() override; | |
- | |
- SavedVariable input_; | |
- SavedVariable weight_; | |
- SavedVariable bias_; | |
- tensor_list columns; | |
- tensor_list ones; | |
- std::unique_ptr<torch::cudnn::Convolution> convolution; | |
-}; | |
- | |
-struct ConvBackwardBackward : public Function, public ConvParams { | |
- ConvBackwardBackward( | |
- FunctionFlags flags, | |
- ConvParams params, | |
- Variable input, | |
- Variable weight, | |
- Variable bias, | |
- Variable grad_output) | |
- : Function(std::move(flags)) | |
- , ConvParams(std::move(params)) { | |
- if (is_executable) { | |
- this->input_ = SavedVariable(input, this); | |
- this->weight_ = SavedVariable(weight, this); | |
- if (bias.defined()) { | |
- this->bias_ = SavedVariable(bias, this); | |
- } | |
- this->grad_output_ = SavedVariable(grad_output, this); | |
- } | |
- } | |
- | |
- virtual variable_list apply(const variable_list& grad_grad_inputs) override; | |
- | |
- virtual void releaseVariables() override; | |
- | |
- SavedVariable input_; | |
- SavedVariable weight_; | |
- SavedVariable bias_; | |
- SavedVariable grad_output_; | |
-}; | |
- | |
-}} // namespace torch::autograd | |
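
The roughly 900-line hand-rolled convolution implementation (cuDNN, NNPACK, depthwise, and THNN dispatch, plus the grouped and double-backward handling) is removed along with its header, for the same reason as batch normalization: convolution is now routed through ATen, with autograd provided by generated code. A hedged sketch of the ATen-level call that subsumes the deleted dispatch; the argument order is an assumption about the ATen API, not taken from this diff:

    #include <ATen/ATen.h>
    #include <vector>

    // Rough ATen-level equivalent of the deleted ConvForward dispatch: one
    // at::convolution call covers the transposed/dilated/grouped cases that
    // the removed code handled by hand.
    at::Tensor conv_sketch(const at::Tensor& input, const at::Tensor& weight,
                           const at::Tensor& bias,
                           std::vector<int64_t> stride,
                           std::vector<int64_t> padding,
                           std::vector<int64_t> dilation, bool transposed,
                           std::vector<int64_t> output_padding, int64_t groups) {
      return at::convolution(input, weight, bias, stride, padding, dilation,
                             transposed, output_padding, groups);
    }
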
diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp | |
index 3f5da160..fe5c52c0 100644 | |
--- a/torch/csrc/autograd/functions/init.cpp | |
+++ b/torch/csrc/autograd/functions/init.cpp | |
@@ -1,10 +1,9 @@ | |
-#include "batch_normalization.h" | |
-#include "convolution.h" | |
+#include "Python.h" | |
#include "accumulate_grad.h" | |
#include "basic_ops.h" | |
#include "tensor.h" | |
#include "special.h" | |
-#include "jit_closure.h" | |
+#include "torch/csrc/jit/interpreter_autograd_function.h" | |
#include "torch/csrc/autograd/functions/pybind.h" | |
#include "torch/csrc/autograd/python_cpp_function.h" | |
#include "torch/csrc/autograd/generated/python_functions.h" | |
@@ -15,41 +14,6 @@ | |
using namespace torch::autograd; | |
using torch::TupleParser; | |
-struct BatchNormCtor { | |
- BatchNormForward* operator()(PyObject* args) { | |
- BatchNormParams params; | |
- | |
- TupleParser parser(args, 6); | |
- parser.parse(params.running_mean, "running_mean"); | |
- parser.parse(params.running_var, "running_var"); | |
- parser.parse(params.training, "training"); | |
- parser.parse(params.momentum, "momentum"); | |
- parser.parse(params.eps, "eps"); | |
- parser.parse(params.cudnn_enabled, "cudnn_enabled"); | |
- | |
- return new BatchNormForward(std::move(params)); | |
- } | |
-}; | |
- | |
-struct ConvCtor { | |
- ConvForward* operator()(PyObject* args) { | |
- ConvParams params; | |
- | |
- TupleParser parser(args, 9); | |
- parser.parse(params.stride, "stride"); | |
- parser.parse(params.padding, "padding"); | |
- parser.parse(params.dilation, "dilation"); | |
- parser.parse(params.transposed, "transposed"); | |
- parser.parse(params.output_padding, "output_padding"); | |
- parser.parse(params.groups, "groups"); | |
- parser.parse(params.benchmark, "benchmark"); | |
- parser.parse(params.deterministic, "deterministic"); | |
- parser.parse(params.cudnn_enabled, "cudnn_enabled"); | |
- | |
- return new ConvForward(std::move(params)); | |
- } | |
-}; | |
- | |
struct DelayedErrorCtor { | |
DelayedError* operator()(PyObject* args) { | |
std::string msg; | |
@@ -69,7 +33,7 @@ struct NoCtor { | |
template<typename C, typename T> | |
static void addClass(PyObject* module, PyTypeObject& type, const char* name, | |
- PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL) | |
+ PyGetSetDef* function_properties=nullptr, PyMethodDef* function_methods=nullptr) | |
{ | |
createForwardFunctionPyTypeObject<T>(type, name, function_properties, function_methods); | |
Py_INCREF(&type); | |
@@ -86,7 +50,7 @@ PyObject* getTupleAttr(PyObject* obj, void* _unused) | |
auto& arr = ((T*)(self->cdata.get()))->*ptr; | |
auto num_elems = arr.size(); | |
THPObjectPtr py_tuple(PyTuple_New(num_elems)); | |
- if (!py_tuple) return NULL; | |
+ if (!py_tuple) return nullptr; | |
for (size_t i = 0; i < num_elems; ++i) { | |
PyTuple_SET_ITEM(py_tuple.get(), i, Convert(arr[i])); | |
} | |
@@ -105,125 +69,6 @@ PyObject* getValueAttr(PyObject* obj, void* _unused) | |
END_HANDLE_TH_ERRORS | |
} | |
-template<typename T, typename ParamsT, at::Tensor ParamsT::*ptr> | |
-PyObject* getTensorAttr(PyObject* obj, void* _unused) | |
-{ | |
- HANDLE_TH_ERRORS | |
- THPCppFunction* self = (THPCppFunction*)obj; | |
- auto& val = ((T*)(self->cdata.get()))->*ptr; | |
- THPObjectPtr py_tensor; | |
- if (!val.defined()) { | |
- Py_INCREF(Py_None); | |
- py_tensor = Py_None; | |
- } else { | |
- py_tensor = torch::createPyObject(val); | |
- } | |
- return py_tensor.release(); | |
- END_HANDLE_TH_ERRORS | |
-} | |
- | |
-static struct PyGetSetDef batch_norm_forward_properties[] = { | |
- THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {(char*)"running_mean", (getter)getTensorAttr<BatchNormForward, BatchNormParams, | |
- &BatchNormParams::running_mean>, NULL, NULL, NULL}, | |
- {(char*)"running_var", (getter)getTensorAttr<BatchNormForward, BatchNormParams, | |
- &BatchNormParams::running_var>, NULL, NULL, NULL}, | |
- {(char*)"training", (getter)getValueAttr<BatchNormForward, bool, BatchNormParams, | |
- &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"momentum", (getter)getValueAttr<BatchNormForward, double, BatchNormParams, | |
- &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL}, | |
- {(char*)"eps", (getter)getValueAttr<BatchNormForward, double, BatchNormParams, | |
- &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL}, | |
- {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormForward, bool, BatchNormParams, | |
- &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {NULL} | |
-}; | |
- | |
-static struct PyGetSetDef batch_norm_backward_properties[] = { | |
- THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {(char*)"running_mean", (getter)getTensorAttr<BatchNormBackward, BatchNormParams, | |
- &BatchNormParams::running_mean>, NULL, NULL, NULL}, | |
- {(char*)"running_var", (getter)getTensorAttr<BatchNormBackward, BatchNormParams, | |
- &BatchNormParams::running_var>, NULL, NULL, NULL}, | |
- {(char*)"training", (getter)getValueAttr<BatchNormBackward, bool, BatchNormParams, | |
- &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"momentum", (getter)getValueAttr<BatchNormBackward, double, BatchNormParams, | |
- &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL}, | |
- {(char*)"eps", (getter)getValueAttr<BatchNormBackward, double, BatchNormParams, | |
- &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL}, | |
- {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormBackward, bool, BatchNormParams, | |
- &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {NULL} | |
-}; | |
- | |
-static struct PyGetSetDef batch_norm_backward_backward_properties[] = { | |
- THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {(char*)"running_mean", (getter)getTensorAttr<BatchNormBackwardBackward, BatchNormParams, | |
- &BatchNormParams::running_mean>, NULL, NULL, NULL}, | |
- {(char*)"running_var", (getter)getTensorAttr<BatchNormBackwardBackward, BatchNormParams, | |
- &BatchNormParams::running_var>, NULL, NULL, NULL}, | |
- {(char*)"training", (getter)getValueAttr<BatchNormBackwardBackward, bool, BatchNormParams, | |
- &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"momentum", (getter)getValueAttr<BatchNormBackwardBackward, double, BatchNormParams, | |
- &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL}, | |
- {(char*)"eps", (getter)getValueAttr<BatchNormBackwardBackward, double, BatchNormParams, | |
- &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL}, | |
- {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormBackwardBackward, bool, BatchNormParams, | |
- &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {NULL} | |
-}; | |
- | |
-static struct PyGetSetDef conv_forward_properties[] = { | |
- THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {(char*)"stride", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams, | |
- &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"padding", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams, | |
- &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"dilation", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams, | |
- &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"transposed", (getter)getValueAttr<ConvForward, bool, ConvParams, | |
- &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"output_padding", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams, | |
- &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"groups", (getter)getValueAttr<ConvForward, int, ConvParams, | |
- &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {NULL} | |
-}; | |
- | |
-static struct PyGetSetDef conv_backward_properties[] = { | |
- THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {(char*)"stride", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams, | |
- &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"padding", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams, | |
- &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"dilation", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams, | |
- &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"transposed", (getter)getValueAttr<ConvBackward, bool, ConvParams, | |
- &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"output_padding", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams, | |
- &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"groups", (getter)getValueAttr<ConvBackward, int, ConvParams, | |
- &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {NULL} | |
-}; | |
- | |
-static struct PyGetSetDef conv_backward_backward_properties[] = { | |
- THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {(char*)"stride", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams, | |
- &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"padding", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams, | |
- &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"dilation", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams, | |
- &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"transposed", (getter)getValueAttr<ConvBackwardBackward, bool, ConvParams, | |
- &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"output_padding", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams, | |
- &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {(char*)"groups", (getter)getValueAttr<ConvBackwardBackward, int, ConvParams, | |
- &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL}, | |
- {NULL} | |
-}; | |
- | |
static PyObject* accumulateGradVar(PyObject *_self, void* _unused) | |
{ | |
THPCppFunction* self = (THPCppFunction*)_self; | |
@@ -233,79 +78,57 @@ static PyObject* accumulateGradVar(PyObject *_self, void* _unused) | |
static struct PyGetSetDef accumulate_grad_properties[] = { | |
THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {(char*)"variable", accumulateGradVar, NULL, NULL, NULL}, | |
- {NULL} | |
+ {(char*)"variable", accumulateGradVar, nullptr, nullptr, nullptr}, | |
+ {nullptr} | |
}; | |
-bool THPAutograd_initFunctions(PyObject* _unused) | |
+void THPAutograd_initFunctions() | |
{ | |
THPObjectPtr module(PyModule_New("torch._C._functions")); | |
- if (!module) return false; | |
- | |
- static PyTypeObject BatchNormClass, BatchNormBackwardClass, BatchNormBackwardBackwardClass; | |
- addClass<BatchNormForward, BatchNormCtor>(module, BatchNormClass, "BatchNorm", batch_norm_forward_properties); | |
- addClass<BatchNormBackward, NoCtor>(module, BatchNormBackwardClass, "BatchNormBackward", batch_norm_backward_properties); | |
- addClass<BatchNormBackwardBackward, NoCtor>(module, BatchNormBackwardBackwardClass, "BatchNormBackwardBackward", batch_norm_backward_backward_properties); | |
- | |
- static PyTypeObject ConvClass, ConvBackwardClass, ConvBackwardBackwardClass; | |
- addClass<ConvForward, ConvCtor>(module, ConvClass, "ConvNd", conv_forward_properties); | |
- addClass<ConvBackward, NoCtor>(module, ConvBackwardClass, "ConvNdBackward", conv_backward_properties); | |
- addClass<ConvBackwardBackward, NoCtor>(module, ConvBackwardBackwardClass, "ConvNdBackwardBackward", conv_backward_backward_properties); | |
+ if (!module) throw python_error(); | |
static PyTypeObject AccumulateGradClass; | |
addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties); | |
- static PyTypeObject AddClass, AddBackwardClass; | |
- addClass<Add, NoCtor>(module, AddClass, "Add"); | |
- addClass<AddBackward_Deprecated, NoCtor>(module, AddBackwardClass, "AddBackward_Deprecated"); | |
- | |
static PyTypeObject ErrorClass; | |
addClass<Error, NoCtor>(module, ErrorClass, "Error"); | |
static PyTypeObject DelayedErrorClass; | |
addClass<DelayedError, DelayedErrorCtor>(module, DelayedErrorClass, "DelayedError"); | |
- static PyTypeObject CloneClass; | |
- addClass<Clone, NoCtor>(module, CloneClass, "Clone"); | |
- static PyTypeObject ContiguousClass; | |
- addClass<Contiguous, NoCtor>(module, ContiguousClass, "Contiguous"); | |
- static PyTypeObject IdentityClass; | |
- addClass<Identity, NoCtor>(module, IdentityClass, "Identity"); | |
- static PyTypeObject TransposeClass; | |
- addClass<Transpose, NoCtor>(module, TransposeClass, "Transpose"); | |
- static PyTypeObject ViewClass; | |
- addClass<View, NoCtor>(module, ViewClass, "View"); | |
- static PyTypeObject ExpandClass; | |
- addClass<Expand, NoCtor>(module, ExpandClass, "Expand"); | |
- static PyTypeObject NarrowClass; | |
- addClass<Narrow, NoCtor>(module, NarrowClass, "Narrow"); | |
- static PyTypeObject CatClass; | |
- addClass<Cat, NoCtor>(module, CatClass, "Cat"); | |
- | |
static PyTypeObject EvalClass; | |
addClass<Eval, NoCtor>(module, EvalClass, "Eval"); | |
- static PyTypeObject AutogradClosureClass; | |
- addClass<AutogradClosure, NoCtor>(module, AutogradClosureClass, "AutogradClosure"); | |
+ static PyTypeObject InterpreterAutogradClass; | |
+ addClass<torch::jit::InterpreterAutogradFunction, NoCtor>(module, InterpreterAutogradClass, "InterpreterAutogradFunction"); | |
+ | |
+ static PyTypeObject CopyBackwardsClass; | |
+ addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards"); | |
+ | |
+ static PyTypeObject CopySlicesClass; | |
+ addClass<CopySlices, NoCtor>(module, CopySlicesClass, "CopySlices"); | |
generated::initialize_autogenerated_functions(); | |
- THPObjectPtr parent(PyImport_ImportModule("torch._C")); | |
- if (!parent) return false; | |
- PyModule_AddObject(parent.get(), "_functions", module.release()); | |
- return true; | |
+ auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C")); | |
+ if (!c_module) throw python_error(); | |
+ | |
+ Py_INCREF(module); | |
+ if (PyModule_AddObject(c_module, "_functions", module) < 0) { | |
+ throw python_error(); | |
+ } | |
} | |
namespace torch { namespace autograd { | |
void initAutogradClosureBindings(PyObject* module) { | |
auto m = py::handle(module).cast<py::module>(); | |
- py::class_<AutogradClosureFactory,std::shared_ptr<AutogradClosureFactory>>(m, "AutogradClosureFactory") | |
- .def("__call__", &AutogradClosureFactory::construct) | |
+ py::class_<jit::InterpreterFunctionFactory,std::shared_ptr<jit::InterpreterFunctionFactory>>(m, "InterpreterFunctionFactory") | |
+ .def("__call__", &jit::InterpreterFunctionFactory::construct_function) | |
; | |
- m.def("_jit_createAutogradClosure", [](jit::tracer::TracingState* tracing_state) { | |
- return std::make_shared<AutogradClosureFactory>(tracing_state); | |
+ m.def("_jit_createInterpreterFactory", [](jit::tracer::TracingState* tracing_state) { | |
+ return std::make_shared<jit::InterpreterFunctionFactory>(tracing_state); | |
}); | |
} | |
diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp | |
deleted file mode 100644 | |
index 33cff95a..00000000 | |
--- a/torch/csrc/autograd/functions/jit_closure.cpp | |
+++ /dev/null | |
@@ -1,835 +0,0 @@ | |
-#include "torch/csrc/autograd/functions/jit_closure.h" | |
- | |
-#include "torch/csrc/Exceptions.h" | |
-#include "torch/csrc/utils/auto_gil.h" | |
-#include "torch/csrc/utils/functional.h" | |
-#include "torch/csrc/autograd/engine.h" | |
-#include "torch/csrc/autograd/functions/special.h" | |
-#include "torch/csrc/autograd/functions/basic_ops.h" | |
-#include "torch/csrc/autograd/functions/tensor.h" | |
-#include "torch/csrc/autograd/functions/utils.h" | |
-#include "torch/csrc/autograd/python_engine.h" | |
-#include "torch/csrc/autograd/python_variable.h" | |
-#include "torch/csrc/autograd/python_function.h" | |
-#include "torch/csrc/jit/generated/aten_dispatch.h" | |
-#ifdef WITH_CUDA | |
-#include "torch/csrc/jit/fusion_compiler.h" | |
-#endif | |
-namespace torch { namespace autograd { | |
- | |
-using namespace torch::jit; | |
-using namespace torch::jit::tracer; | |
- | |
-// Used when an output has multiple uses (there's only one entry | |
-// in next_functions per output). | |
-struct Replicate : public Function { | |
- Replicate() { | |
- is_executable = true; | |
- num_inputs = 1; | |
- } | |
- | |
- virtual variable_list apply(const variable_list& inputs) { | |
- return variable_list(next_functions.size(), inputs[0]); | |
- } | |
-}; | |
- | |
-// This class is never put in the autograd graph: see InputPlaceholder | |
-// and EvalPlaceholder. | |
-struct Placeholder : public Function { | |
- virtual variable_list apply(const variable_list& inputs) { | |
- return inputs; | |
- } | |
-}; | |
- | |
-// Used for inputs of previous stages | |
-struct PrevStageInput : public Replicate {}; | |
-// Used for inputs to the closure (execution roots) | |
-struct InputPlaceholder : public Placeholder {}; | |
-// Used to mark places that will have to apply Evals from previous stages. | |
-// | |
-// Why do we need this? Let us recall the raison d'etre of Eval nodes: they | |
-// exist so that when we compute a backwards autograd closure on the fly | |
-// while executing forwards, we can use exactly that closure when backwards | |
-// executes. Morally, this closure is simply an input to the backwards | |
-// computation, and in the compiler IR representation, it's represented | |
-// precisely this way (with opaque Handle nodes.) | |
-// | |
-// However, the *autograd* execution model only accounts for Variable | |
-// input/outputs, which a Handle is not! "So why not add non-Variable inputs | |
-// to autograd"? Perhaps this could be made to work, but it is a bit awkward: | |
-// it would mean adding an entirely new type of input to the execution model. | |
-// Autograd is not intended to be a general purpose programming language | |
-// runtime, so on balance, we decided to consider solutions which don't involve | |
-// adding new types of inputs to autograd, instead passing the closures "out | |
-// of band". | |
-// | |
-// By what mechanism, then, can we actually pass the closure? Here is the idea. | |
-// Instead of actually inserting an "Eval" node, we instead insert an | |
-// EvalPlaceholder, which doesn't know anything about evaluating a closure. | |
-// Then, at the time when we want to partially apply the actual closure | |
-// (computed from the forwards pass), we stick a pre-callback on the EvalPlaceholder | |
-// that takes the inputs, does the actual Eval, and then passes on the outputs | |
-// (which the EvalPlaceholder subsequently passes through.) | |
-// | |
-// Remember that callbacks are NOT stored on a Function object itself: they are | |
-// registered per AutogradClosure (of which there may be multiple per | |
-// graph). So we can't do something like mutate an | |
-// Eval Function to give it the autograd closure to run inside its main body: | |
-// that violates the invariant that autograd graphs are immutable! (In other | |
-// words, the same EvalPlaceholder may be participating in multiple engine | |
-// executions.) You truly must somehow associate these closures with the graph as | |
-// a whole, and there must be a mechanism to ferry this data to the Function | |
-// itself. A callback is just the ticket. | |
-struct EvalPlaceholder : public Placeholder {}; | |
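
To make the out-of-band mechanism described above concrete, here is a minimal, hypothetical sketch (simplified stand-in types, not the deleted API): callbacks live in the closure instance and are keyed by the node they fire on, so the shared, immutable autograd graph is never mutated.

#include <functional>
#include <unordered_map>
#include <vector>

struct NodeSketch;                          // stands in for an autograd Function
using ValuesSketch = std::vector<double>;   // stands in for variable_list

struct ClosureCallbacksSketch {
  using PreHook = std::function<ValuesSketch(const ValuesSketch& inputs)>;

  // Attach the forward-computed closure to a placeholder node of the shared graph.
  // The map lives in the closure instance, so the graph itself stays immutable.
  void attach(NodeSketch* placeholder, PreHook run_closure) {
    pre_hooks.emplace(placeholder, std::move(run_closure));
  }

  // The (sketched) engine consults this before running `node`; an EvalPlaceholder
  // then just passes the hook's outputs through.
  ValuesSketch maybe_run(NodeSketch* node, const ValuesSketch& inputs) const {
    auto it = pre_hooks.find(node);
    return it == pre_hooks.end() ? inputs : it->second(inputs);
  }

  std::unordered_map<NodeSketch*, PreHook> pre_hooks;
};
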
- | |
-// Used for the graph output. Execution should be stopped by a callback before apply(). | |
-struct Output : public Function { | |
- Output(int ninputs) { | |
- is_executable = true; | |
- num_inputs = ninputs; | |
- } | |
- | |
- virtual variable_list apply(const variable_list& inputs) { | |
- throw std::runtime_error("Output::apply called"); | |
- } | |
-}; | |
- | |
-struct SimpleEval : public Function { | |
- SimpleEval(const std::shared_ptr<Function>& fn) | |
- : fn(fn) {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override { | |
- return fn->apply(inputs); | |
- } | |
- | |
- std::shared_ptr<Function> fn; | |
-}; | |
- | |
-struct EmitNull : public Function { | |
- EmitNull() { | |
- is_executable = true; | |
- num_inputs = 0; | |
- } | |
- | |
- virtual variable_list apply(const variable_list& inputs) { | |
- return {Variable()}; | |
- }; | |
-}; | |
- | |
-struct LambdaFunction : public Function { | |
- LambdaFunction(const jit::TensorOp& op) | |
- : LambdaFunction(op.num_inputs, op.op) { | |
- this->name_ = op.name; | |
- } | |
- | |
- LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn) | |
- : fn_(fn) { | |
- this->is_executable = true; | |
- this->num_inputs = num_inputs; | |
- } | |
- | |
- virtual std::string name() override { | |
- return name_.size() == 0 ? "LambdaFunction" : name_; | |
- } | |
- | |
- virtual variable_list apply(const variable_list& inputs) override { | |
- return fn_(inputs); | |
- } | |
- | |
- std::string name_; | |
- std::function<variable_list(const variable_list&)> fn_; | |
-}; | |
- | |
-// Wraps a PythonOp and dispatches calls to Functions implemented in Python | |
-struct PythonCall : public Function { | |
- PythonCall(PythonOp *op) | |
- : cconv(op->cconv) | |
- , scalar_args() { | |
- | |
- Py_INCREF(op->pyobj.get()); | |
- pyobj = op->pyobj.get(); | |
- | |
- scalar_args.reserve(op->scalar_args.size()); | |
- for (auto& arg_ptr : op->scalar_args) { | |
- Py_INCREF(arg_ptr.get()); | |
- scalar_args.emplace_back(arg_ptr.get()); | |
- } | |
- } | |
- | |
- virtual variable_list apply(const variable_list& inputs) { | |
- AutoGIL gil; | |
- | |
- THPObjectPtr apply_fn {PyObject_GetAttrString(pyobj, "apply")}; | |
- if (!apply_fn) throw python_error(); | |
- | |
- THPObjectPtr py_inputs { packInputs(inputs) }; | |
- THPObjectPtr result { PyObject_Call(apply_fn.get(), py_inputs.get(), NULL) }; | |
- if (!result) throw python_error(); | |
- return unpackOutputs(result); | |
- } | |
- | |
- THPObjectPtr packInputs(const variable_list& inputs) { | |
- THPObjectPtr py_inputs { PyTuple_New(cconv.size()) }; | |
- if (!py_inputs) throw python_error(); | |
- | |
- auto var_it = inputs.begin(); | |
- auto scalar_it = scalar_args.begin(); | |
- int input_nr = 0; | |
- | |
- for (auto arg_type : cconv) { | |
- PyObject *obj = nullptr; | |
- if (arg_type == 's') { | |
- if (scalar_it == scalar_args.end()) | |
- throw std::runtime_error("expected too many scalar args"); | |
- obj = (scalar_it++)->get(); | |
- Py_INCREF(obj); | |
- } else if (arg_type == 't') { | |
- if (var_it == inputs.end()) | |
- throw std::runtime_error("expected too many inputs"); | |
- obj = THPVariable_Wrap(*(var_it++)); | |
- } else { | |
- throw std::runtime_error("unexpected calling convention"); | |
- } | |
- PyTuple_SET_ITEM(py_inputs.get(), input_nr++, obj); | |
- } | |
- | |
- return py_inputs; | |
- } | |
- | |
- variable_list unpackOutputs(THPObjectPtr& result) { | |
- variable_list var_result; | |
- | |
- ensure_tuple(result); | |
- auto num_outputs = PyTuple_GET_SIZE(result.get()); | |
- for (int i = 0; i < num_outputs; ++i) { | |
- PyObject *output = PyTuple_GET_ITEM(result.get(), i); | |
- if (!THPVariable_Check(output)) | |
- throw std::runtime_error("Function.apply returned a non-Variable output"); | |
- THPVariable *var = (THPVariable*)output; | |
- var_result.emplace_back(var->cdata); | |
- } | |
- | |
- return var_result; | |
- } | |
- | |
- THPObjectPtr pyobj; | |
- std::string cconv; | |
- std::vector<THPObjectPtr> scalar_args; | |
-}; | |
- | |
-// Note [Handling nullary functions in the autograd engine] | |
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
-// Today, the autograd engine cannot handle nullary functions, because | |
-// it assumes that every non-input function has at least one input. | |
-// This fits nicely with the scheduling strategy, which schedules a | |
-// function for execution when all of its inputs are ready. Unfortunately, | |
-// constants are nullary. | |
-// | |
-// Instead, we use a little hack. Rather than creating an extra root | |
-// for every constant, we add a single new root, ConstantFactory, which | |
-// when run triggers all of the actual constant functions, WrapConstant, | |
-// which actually contribute a constant. Furthermore, we use a single | |
-// null input to ensure that the next_function index has a valid offset. | |
-// | |
-// One possible alternative to represent this might be to special case constants | |
-// in the execution engine, as a separate vector of roots. But the current | |
-// strategy seems to work fine and isn't too difficult to construct a trace | |
-// for. | |
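
As a rough illustration of the hack this note describes (hypothetical, simplified types; not the deleted classes themselves): a single factory root receives one dummy input so the engine can schedule it, and its only effect is to trigger every constant-producing node hanging off it.

#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

struct NullaryHackSketch {
  // One emitter per WrapConstant-style node attached to the factory root.
  std::vector<std::function<void()>> constant_emitters;

  // The engine passes exactly one dummy (undefined) input, mirroring the
  // "single null input" convention mentioned above.
  void run(std::size_t num_inputs) {
    if (num_inputs != 1)
      throw std::logic_error("factory root expects a single dummy input");
    for (auto& emit : constant_emitters) emit();
  }
};
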
- | |
-struct WrapConstant : public Function { | |
- WrapConstant(at::Tensor value) | |
- : value(std::move(value)) { | |
- is_executable = true; | |
- num_inputs = 1; | |
- } | |
- | |
- virtual variable_list apply(const variable_list& inputs) { | |
- if (inputs.size() != 1 || inputs[0].defined()) | |
- throw std::logic_error("WrapConstant nodes should only receive a single NULL input"); | |
- AutoGPU guard(value); | |
- return {make_variable(value.clone())}; | |
- } | |
- | |
- at::Tensor value; | |
-}; | |
- | |
-// See Note [Handling nullary functions in the autograd engine] | |
-struct ConstantFactory : public Function { | |
- ConstantFactory() { | |
- is_executable = true; | |
- num_inputs = 1; | |
- } | |
- | |
- virtual variable_list apply(const variable_list& inputs) { | |
- if (inputs.size() != 1 || inputs[0].defined()) | |
- throw std::logic_error("ConstantFactory nodes should only receive a single NULL input"); | |
- return variable_list(next_functions.size()); | |
- } | |
-}; | |
- | |
-#ifdef WITH_CUDA | |
-struct FusionGroupFunction : public Function { | |
- FusionGroupFunction(const std::shared_ptr<CompiledFusionFunction> & function) | |
- : function(function) {} | |
- virtual variable_list apply(const variable_list& inputs) { | |
- //TODO: handle the case where inputs do not match the device function was | |
- // compiled for | |
- std::vector<at::Tensor> data; | |
- for(auto & input : inputs) | |
- data.push_back(input.data()); | |
- AutoGPU guard(data.back()); | |
- std::vector<at::Tensor> outputs; | |
- outputs.reserve(function->outputDescriptors().size()); | |
- for(auto & od : function->outputDescriptors()) { | |
- outputs.push_back(at::CUDA(od.scalar_type).tensor()); | |
- } | |
- function->launch(data, outputs); | |
- return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) { | |
- return std::make_shared<torch::autograd::Error>("FusionGroupFunction is not differentiable", std::move(f)); | |
- }); | |
- } | |
-private: | |
- std::shared_ptr<CompiledFusionFunction> function; | |
-}; | |
-#endif | |
- | |
-//////////////////////////////////////////////////////////////////////////////// | |
-//////////////////////////////////////////////////////////////////////////////// | |
- | |
-// A helper struct that precomputes and caches information regarding cross-stage | |
-// dependencies and state passing. | |
-// | |
-// Example: | |
-// graph (%1, | |
-// %2, | |
-// ------ stage 1 ------ | |
-// %9, | |
-// ------ stage 2 ------ | |
-// %31, | |
-// %32) { | |
-// %3.0, %3.1 = MulConstant(2)(%2) | |
-// %6.0, %6.1 = Mul()(%3.0, %1) | |
-// ---------------- stage 1 ---------------- | |
-// %10.0, %10.1, %10.2 = Eval(%9, %6.1) | |
-// %23.0, %23.1 = Eval(%10.0, %3.1) | |
-// ---------------- stage 2 ---------------- | |
-// %33.0, %33.1 = Eval(%32, %23.1) | |
-// %44.0, %44.1, %44.2, %44.3 = Eval(%33.0, %31, %10.2) | |
-// %78.0, %78.1 = Eval(%44.1, %3.1) | |
-// return (%6.0, %10.1, %23.0, %44.0, %44.2, %78.0); | |
-// } | |
-// | |
-// Then: | |
-// | |
-// graph->stage() = 2 | |
-// stage_begins = [%3, %10, %33, %0] (0 = return node) | |
-// stage_inputs = [ | |
-// [%1, %2], | |
-// [%9], | |
-// [%31, %32] | |
-// ] | |
-// stage_outputs = [ | |
-// [%6.0], | |
-// [%10.1, %23], | |
-// [%44.0, %44.2, %78.0] | |
-// ] | |
-// prev_stage_inputs = [ | |
-// [], # Always empty! | |
-// [%6.1, %3.1], | |
-// [%23.1, %10.2, %3.1] | |
-// ] | |
-// cur_stage_captures = [ | |
-// [%6.1, %3.1], | |
-// [%23.1, %10.2], | |
-// [] # Always empty! | |
-// ] | |
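
The worked example above boils down to one rule, sketched below with hypothetical, simplified types (plain ids instead of IR nodes): a value consumed in a later stage must be recorded as a previous-stage input of every stage between its producer and its consumer, and captured in the stage that produced it.

#include <cstddef>
#include <unordered_set>
#include <vector>

struct ValueSketch { std::size_t stage; int id; };

struct CrossStageSketch {
  std::vector<std::unordered_set<int>> prev_stage_inputs;   // indexed by stage
  std::vector<std::unordered_set<int>> cur_stage_captures;  // indexed by stage

  explicit CrossStageSketch(std::size_t num_stages)
      : prev_stage_inputs(num_stages), cur_stage_captures(num_stages) {}

  void record_use(std::size_t user_stage, const ValueSketch& input) {
    if (input.stage >= user_stage) return;             // same-stage use: nothing to do
    for (auto s = user_stage; s > input.stage; --s)    // needed by all intermediate stages
      prev_stage_inputs[s].insert(input.id);
    cur_stage_captures[input.stage].insert(input.id);  // captured where it was produced
  }
};
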
-struct CrossStageStateDesc { | |
- CrossStageStateDesc(Graph* graph) | |
- // FYI: graph->stage() == the last stage we have traced | |
- // (e.g., forwards+backwards = 1) | |
- : stage_inputs(graph->stage() + 1) | |
- , stage_outputs(graph->stage() + 1) | |
- , prev_stage_inputs(graph->stage() + 1) | |
- , cur_stage_captures(graph->stage() + 1) { | |
- | |
- std::size_t current_stage = -1; | |
- for (auto node : graph->nodes()) { | |
- // Look for stage boundaries | |
- if (node->stage() != current_stage) { | |
- JIT_ASSERT(node->stage() == current_stage + 1); | |
- current_stage = node->stage(); | |
- stage_begins.push_back(node); | |
- } | |
- // Look for things we need to save | |
- for (auto input : node->inputs()) { | |
- if (input->stage() != current_stage) { | |
- JIT_ASSERT(input->stage() < current_stage); | |
- // We need to save it in all intermediate stages too | |
- for (auto i = current_stage; i > input->stage(); --i) { | |
- prev_stage_inputs[i].insert(input); | |
- } | |
- cur_stage_captures[input->stage()].insert(input); | |
- } | |
- } | |
- } | |
- | |
- // It's convenient to pretend the output is one more stage - we can always | |
- // use the iterators for stages i and i+1 as loop boundaries | |
- stage_begins.push_back(graph->return_node()); | |
- | |
- // Scatter inputs and outputs across stage buckets | |
- for (auto input : graph->inputs()) | |
- stage_inputs[input->stage()].push_back(input); | |
- for (auto output : graph->outputs()) | |
- stage_outputs[output->stage()].push_back(output); | |
- | |
- JIT_ASSERT(prev_stage_inputs.front().empty()); | |
- JIT_ASSERT(cur_stage_captures.back().empty()); | |
- } | |
- | |
- // For each stage, the first Node in Graph's topological sort which | |
- // is a member of this stage. In general, the stages of nodes in | |
- // a graph will look like this: | |
- // | |
- // 000000011111112222222E (E is the Return node) | |
- // ^ ^ ^ ^ | |
- // | |
- // We have pointers to the caret'ed nodes. | |
- std::vector<Node*> stage_begins; | |
- std::vector<std::vector<Node*>> stage_inputs; | |
- std::vector<std::vector<Node*>> stage_outputs; | |
- // A set of all Nodes from the previous stage that pass anything (both Variables | |
- // and handles) to the current stage. | |
- std::vector<std::unordered_set<Node*>> prev_stage_inputs; | |
- // A set of all Nodes from this stage that need their values to be captured | |
- // for future stages (applies to both Variables and handles). | |
- std::vector<std::unordered_set<Node*>> cur_stage_captures; | |
-}; | |
- | |
-// Creates a graph for a given stage and stores information necessary to construct | |
-// an AutogradClosure with it | |
-struct StageClosure { | |
- using node_fn_map_type = std::unordered_map<Node*, std::shared_ptr<Function>>; | |
- | |
- StageClosure(TracingState *state, const CrossStageStateDesc& xstate, std::size_t stage) | |
- : var_flags(state->var_flags.at(stage)) | |
- , const_factory(std::make_shared<ConstantFactory>()) { | |
- auto graph = state->graph.get(); | |
- node_fn_map_type node_map; | |
- // This map caches PrevStageInputs for a given node, so that you don't | |
- // create multiple PrevStageInput for the same node. | |
- node_fn_map_type prev_stage_input_map; | |
- | |
- // Prepare the output node and compute an offset within the return node's inputs | |
- // where nodes from this stage appear. | |
- output = std::make_shared<Output>(xstate.stage_outputs[stage].size()); | |
- node_map[graph->return_node()] = output; | |
- std::size_t output_offset = 0; | |
- for (std::size_t i = 0; i < stage; ++i) | |
- output_offset += xstate.stage_outputs[i].size(); | |
- | |
- // Builds up a closure for node. It assumes that it has been called | |
- // for all nodes that use outputs of node, which is why we iterate | |
- // in reverse topological order. | |
- auto add_node = [&](Node *node) { | |
- JIT_ASSERT(node->stage() == stage); | |
- | |
- // Get function object | |
- auto fn = getFunction(node); | |
- if (!fn) return; // This node is a no-op | |
- | |
- // Initialize function fields | |
- fn->is_executable = true; | |
- if (fn->num_inputs == 0) { | |
- fn->num_inputs = node->inputs().size(); | |
- } | |
- fillNextFunctions(node, fn, node_map, output_offset, stage); | |
- | |
- registerPrevStageInputs(node, fn, prev_stage_input_map); | |
- node_map[node] = fn; | |
- }; | |
- | |
- for (auto it = std::next(xstate.stage_begins[stage+1]->reverseIterator()), | |
- end = std::next(xstate.stage_begins[stage]->reverseIterator()); it != end; ++it) { | |
- add_node(*it); | |
- } | |
- for (auto node : xstate.stage_inputs[stage]) { | |
- add_node(node); | |
- } | |
- | |
- // Prepare inputs. | |
- for (Node *input : xstate.stage_inputs[stage]) { | |
- roots.emplace_back(node_map.at(input), 0); | |
- } | |
- for (auto & entry : prev_stage_input_map) { | |
- roots.emplace_back(entry.second, 0); | |
- prev_stage_variables.emplace_back(entry.first->unique()); | |
- } | |
- // NOTE: prev_stage_handles have already been filled in by add_node | |
- JIT_ASSERT(prev_stage_variables.size() + prev_stage_handles.size() == xstate.prev_stage_inputs[stage].size()); | |
- | |
- // Prepare a list of values / handles to capture | |
- for (auto captured_node : xstate.cur_stage_captures[stage]) { | |
- if (captured_node->kind() == kSelect) { | |
- auto & fn = node_map.at(captured_node->input()); | |
- if (captured_node->type()->kind() == TypeKind::TensorType) { | |
- captured_variables.emplace_back(fn.get(), captured_node->i(kOffset), captured_node->unique()); | |
- } else { | |
- JIT_ASSERT(captured_node->type()->kind() == TypeKind::HandleType); | |
- captured_handles.emplace(fn.get(), captured_node->unique()); | |
- } | |
- } else { | |
- JIT_ASSERT(captured_node->type()->kind() == TypeKind::TensorType); | |
- auto & fn = node_map.at(captured_node); | |
- captured_variables.emplace_back(fn.get(), 0, captured_node->unique()); | |
- } | |
- } | |
- | |
- roots.emplace_back(const_factory, 0); | |
- | |
- findCopiedNextFunctions(state, stage); | |
- } | |
- | |
- // Returns a function implementing functionality of a given node, | |
- // or nullptr if it's a no-op for autograd. | |
- std::shared_ptr<Function> getFunction(Node *node) { | |
- IR_IFM(node, PythonOp) | |
- return std::make_shared<PythonCall>(value); | |
- IR_ELSEIFM(CppOp) | |
- if (dynamic_cast<Eval*>(value->fn.get())) { | |
- auto fn = std::make_shared<EvalPlaceholder>(); | |
- | |
- // All Eval nodes take context edges as an input, and we need to register | |
- // all such places | |
- auto inputs = value->inputs(); | |
- JIT_ASSERT(inputs.size() > 0); | |
- auto handle_input = inputs[inputs.size() - 1]; | |
- JIT_ASSERT(handle_input->type()->kind() == TypeKind::HandleType); | |
- prev_stage_handles.emplace_back(fn.get(), handle_input->unique()); | |
- | |
- fn->num_inputs = node->inputs().size() - 1; | |
- return fn; | |
- } else { | |
- return std::make_shared<SimpleEval>(value->fn); | |
- } | |
- IR_ELSEIF(Select) | |
- // No-op. Selects are handled by their inputs. | |
- return nullptr; | |
- IR_ELSEIF(FusionGroup) | |
-#ifdef WITH_CUDA | |
- // TODO: make this more robust - handle device and contiguity changes! | |
- auto fusion_fn = sharedFusionCompiler().getOrCompile(*value->g(kSubgraph)); | |
- return std::make_shared<FusionGroupFunction>(std::move(fusion_fn)); | |
-#else | |
- throw std::runtime_error("don't know how to execute FusionGroups without CUDA"); | |
-#endif | |
- IR_ELSEIF(Param) | |
- auto fn = std::make_shared<InputPlaceholder>(); | |
- fn->num_inputs = 1; | |
- return fn; | |
- IR_ELSEIF(Constant) | |
- auto fn = std::make_shared<torch::autograd::WrapConstant>(value->t(kvalue)); | |
- const_factory->next_functions.emplace_back(fn, 0); | |
- fn->num_inputs = 1; | |
- return fn; | |
- IR_ELSEIF(Undefined) | |
- return std::make_shared<EmitNull>(); | |
- IR_ELSEIF(Transpose) // NOTE: Transpose in ONNX is Permute in Torch | |
- auto permutation = value->is(kperm); | |
- if (permutation != std::vector<int64_t>({1, 0})) | |
- throw std::runtime_error("Transpose isn't fully supported in closure compiler"); | |
- return std::make_shared<LambdaFunction>(1, [](const variable_list& vars) -> variable_list { | |
- return {make_variable(vars[0].data().transpose(1, 0), vars[0].requires_grad())}; | |
- }); | |
- IR_ELSEIF(Reshape) | |
- auto shape = value->is(kshape); | |
- return std::make_shared<LambdaFunction>(1, [shape](const variable_list& vars) -> variable_list { | |
- return {make_variable(vars[0].data().view(shape), vars[0].requires_grad())}; | |
- }); | |
- IR_ELSEIF(Gemm) | |
- auto beta = value->f(kbeta); | |
- auto alpha = value->f(kalpha); | |
- return std::make_shared<LambdaFunction>(3, [beta, alpha](const variable_list& vars) -> variable_list { | |
- return {vars[2].addmm(vars[0], vars[1], beta, alpha)}; | |
- }); | |
- IR_ELSE() | |
- return std::make_shared<LambdaFunction>(getTensorOp(node)); | |
- IR_END() | |
- } | |
- | |
- // Fill in the next_functions of the Function we just allocated | |
- void fillNextFunctions(Node *node, const std::shared_ptr<Function>& fn, node_fn_map_type& node_map, int output_offset, std::size_t stage) { | |
- auto graph = node->owningGraph(); | |
- // Gather uses of each output | |
- std::vector<std::reference_wrapper<const use_list>> output_uses_refs; | |
- if (node->hasMultipleOutputs()) { | |
- // Each use is a single Select node corresponding to an output | |
- for (auto& use : node->uses()) { | |
- if (use.user->isHandle()) continue; | |
- auto& select_uses = use.user->uses(); | |
- output_uses_refs.emplace_back(select_uses); | |
- } | |
- } else { | |
- output_uses_refs.emplace_back(node->uses()); | |
- } | |
- | |
- // Fill next_functions accordingly to uses of each output | |
- // There's some fiddling required for fixing the offset of uses for return node, so it's | |
- // better to keep this logic in a lambda. | |
- auto append_use = [&node_map, graph, output_offset](const std::shared_ptr<Function>& fn, Use& use) { | |
- int offset = use.offset; | |
- if (use.user == graph->return_node()) offset -= output_offset; | |
- fn->next_functions.emplace_back(node_map.at(use.user), offset); | |
- }; | |
- for (auto& output_uses_ref : output_uses_refs) { | |
- // Filter out uses from future stages (except for output!) | |
- auto output_uses = filter(output_uses_ref.get(), [stage, graph](const Use& use) { | |
- return use.user->stage() == stage || use.user == graph->return_node(); | |
- }); | |
- // Optimize out unnecessary Replicate nodes | |
- if (output_uses.size() == 1) { | |
- append_use(fn, output_uses[0]); | |
- // If an output was used more than once, we need to insert a Replicate node | |
- // because there's only a single entry for an output in next_functions | |
- } else { | |
- auto replicate = std::make_shared<Replicate>(); | |
- for (auto& use : output_uses) append_use(replicate, use); | |
- fn->next_functions.emplace_back(std::move(replicate), 0); | |
- } | |
- } | |
- } | |
- | |
- // Possibly create PrevStageInputs for any uses of nodes from previous | |
- // stages, and fill in their next_functions with our use. | |
- void registerPrevStageInputs(Node *node, const std::shared_ptr<Function>& fn, | |
- node_fn_map_type& prev_stage_input_map) { | |
- const auto& inputs = node->inputs(); | |
- for (std::size_t i = 0; i < inputs.size(); ++i) { | |
- auto input_node = inputs[i]; | |
- if (input_node->type()->kind() == TypeKind::HandleType) continue; | |
- JIT_ASSERT(input_node->type()->kind() == TypeKind::TensorType); | |
- if (input_node->stage() < node->stage()) { | |
- auto info = prev_stage_input_map.emplace(input_node, nullptr); | |
- auto & input_fn_ptr = info.first->second; | |
- // Create a node if insertion took place | |
- if (info.second) input_fn_ptr.reset(new PrevStageInput()); | |
- input_fn_ptr->next_functions.emplace_back(fn, i); | |
- } | |
- } | |
- } | |
- | |
- // If this stage produces gradients of any of previous stage inputs, | |
- // it needs to include them in its next_functions. However, we do not | |
- // necessarily keep them as SavedVariables, so it's not straightforward | |
- // to use wrap_outputs for this purpose. Here, we find all next_functions | |
- // from the previous stage that will need to be copied as next_functions | |
- // of this stage (the copy is done explicitly in lambda constructor given to | |
- // wrap_outputs). | |
- // NOTE: we depend on the Eval input ordering here (i.e. inherited/prev stage | |
- // outputs come after this stage inputs and remain sorted). | |
- void findCopiedNextFunctions(TracingState *state, std::size_t stage) { | |
- if (stage == 0) return; | |
- auto & current_outputs = state->output_edges[stage]; | |
- auto & prev_outputs = state->output_edges[stage - 1]; | |
- for (auto & output : current_outputs) { | |
- auto prev_it = std::find(prev_outputs.begin(), prev_outputs.end(), output); | |
- if (prev_it == prev_outputs.end()) continue; | |
- copied_next_fns.push_back(std::distance(prev_outputs.begin(), prev_it)); | |
- } | |
- } | |
- | |
- // Roots for a call to the engine. The list contains functions in this order: | |
- // [ apply input roots | prev stage input roots | constant factory ] | |
- function_list roots; | |
- std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>> var_flags; | |
- | |
- // Output node | |
- std::shared_ptr<Function> output; | |
- std::shared_ptr<ConstantFactory> const_factory; | |
- | |
- std::vector<int> copied_next_fns; | |
- | |
- // These will be used by each instantiation of AutogradClosure to register hooks. | |
- std::vector<int> prev_stage_variables; // unique | |
- std::vector<std::pair<Function*, int>> prev_stage_handles; // (placeholder, unique) | |
- // After the function is run, take either a Variable or a backward handle, and | |
- // put it in the environment under 'unique'. | |
- std::vector<std::tuple<Function*, int, int>> captured_variables; // (function, output_nr, unique) | |
- std::unordered_map<Function*, int> captured_handles; // (function, unique) | |
-}; | |
- | |
-// Computes and stores an array of StageClosures for each stage in the graph | |
-struct MultiStageClosure { | |
- MultiStageClosure(TracingState* state) { | |
- auto graph = state->graph.get(); | |
- CrossStageStateDesc xstate {graph}; | |
- auto num_stages = graph->stage() + 1; | |
- for (std::size_t i = 0; i < num_stages; ++i) { | |
- stages.emplace_back(state, xstate, i); | |
- } | |
- } | |
- | |
- std::vector<StageClosure> stages; | |
-}; | |
- | |
-AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc) | |
- : AutogradClosure(desc, 0) {} | |
- | |
-// TODO: there's a lot of processing involved in creating a new AutogradClosure instance, | |
-// so it might be worthwhile to keep a pool of unused instances (or at least their attrs) | |
-// for all stages. We can't save saved_vars and saved_handles, but all callbacks | |
-// can be made reusable. | |
-AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage) | |
- : desc(desc) | |
- , stage(stage) { | |
- auto & stage_desc = desc->stages[stage]; | |
- | |
- // Callbacks to capture Variables for backward closure | |
- for (auto & entry : stage_desc.captured_variables) { | |
- auto & fn = std::get<0>(entry); | |
- auto output_offset = std::get<1>(entry); | |
- auto saved_idx = std::get<2>(entry); | |
- post_callbacks.emplace(fn, [this, saved_idx, output_offset](Function* fn, variable_list& inputs, variable_list& outputs) { | |
- std::lock_guard<std::mutex> lock(this->capture_mutex); | |
- this->captured_vars[saved_idx] = outputs[output_offset].data(); | |
- return true; | |
- }); | |
- } | |
- | |
- // Callbacks to capture handles (backward subgraphs) for backward closure | |
- for (auto & entry : stage_desc.captured_handles) { | |
- auto & fn = entry.first; | |
- auto saved_idx = entry.second; | |
- // Evals already wrap their backwards and they will be handled in the | |
- // next loop if needed | |
- if (dynamic_cast<EvalPlaceholder*>(fn)) continue; | |
- // Otherwise we have to wrap the backwards in a handle ourselves | |
- post_callbacks.emplace(fn, [this, saved_idx](Function* fn, variable_list& inputs, variable_list& outputs) { | |
- auto eval_fn = std::make_shared<Eval>(); | |
- eval_fn->replaceSubgraph(inputs, outputs); | |
- std::lock_guard<std::mutex> lock(this->capture_mutex); | |
- this->captured_handles[saved_idx] = std::move(eval_fn); | |
- return true; | |
- }); | |
- } | |
- | |
- // Callbacks that run closures received from forward and optionally capture | |
- // them for next stages | |
- for (auto & entry : stage_desc.prev_stage_handles) { | |
- int unique = entry.second; | |
- // Check if we need to capture the handle for next stage too | |
- auto it = stage_desc.captured_handles.find(entry.first); | |
- int saved_idx = it == stage_desc.captured_handles.end() ? -1 : it->second; | |
- post_callbacks.emplace(entry.first, [this, unique, saved_idx](Function* fn, variable_list& inputs, variable_list& outputs) { | |
- outputs = (*this->saved_handles.at(unique))(inputs); | |
- if (saved_idx != -1) { | |
- auto eval_fn = Eval::getBackwardEval(inputs, outputs); | |
- std::lock_guard<std::mutex> lock(this->capture_mutex); | |
- this->captured_handles[saved_idx] = std::move(eval_fn); | |
- } | |
- return true; | |
- }); | |
- } | |
- | |
- // A callback to capture the output | |
- pre_callbacks.emplace(stage_desc.output.get(), [this](Function*, variable_list& inputs) { | |
- std::lock_guard<std::mutex> lock(this->capture_mutex); | |
- this->outputs.reserve(inputs.size()); | |
- for (auto & input : inputs) | |
- this->outputs.emplace_back(input.opt_data()); | |
- return false; // Stop execution | |
- }); | |
-} | |
- | |
-variable_list AutogradClosure::apply(const variable_list& inputs) { | |
- auto& stage_closure = desc->stages[stage]; | |
- | |
- // Validate inputs | |
- auto num_inputs = inputs.size(); | |
- if (num_inputs != stage_closure.var_flags.first.size()) | |
- throw std::runtime_error("AutogradClosure received an incorrect number of inputs"); | |
- for (std::size_t i = 0; i < num_inputs; ++i) { | |
- auto & flags = stage_closure.var_flags.first[i]; | |
- if (!flags.verify(inputs[i])) | |
- throw std::runtime_error("AutogradClosure received inputs with different flags"); | |
- } | |
- | |
- // TODO: we could run all this with volatile variables, but we need to somehow capture handles | |
- // we should enable requires_grad only for the parts that need it | |
- auto input_leaves = fmap(inputs, [](const Variable& v) { | |
- return v.defined() ? make_variable(v.data(), v.requires_grad(), v.is_volatile()) : Variable(); | |
- }); | |
- for (auto unique : desc->stages[stage].prev_stage_variables) | |
- input_leaves.emplace_back(make_variable(saved_vars.at(unique), true, false)); | |
- input_leaves.emplace_back(Variable()); // for ConstantFactory | |
- | |
- auto& engine = python::PythonEngine::getDefaultEngine(); | |
- engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks); | |
- | |
- // Create the backward function lazily | |
- auto make_grad_fn = [this]() -> std::shared_ptr<Function> { | |
- if (this->stage == this->desc->stages.size() - 1) { | |
- std::string msg = "JIT closure compiled only for "; | |
- msg += std::to_string(this->stage); | |
- msg += " backwards"; | |
- return std::make_shared<Error>(std::move(msg)); | |
- } | |
- auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1)); | |
- // TODO: don't make a full copy of saved_* - copy only the things that bw needs | |
- bw_fn->saved_vars = this->saved_vars; | |
- bw_fn->saved_vars.insert(std::make_move_iterator(this->captured_vars.begin()), | |
- std::make_move_iterator(this->captured_vars.end())); | |
- bw_fn->saved_handles = this->saved_handles; | |
- bw_fn->saved_handles.insert(std::make_move_iterator(this->captured_handles.begin()), | |
- std::make_move_iterator(this->captured_handles.end())); | |
- // Patch next_functions to include the previous stage's next_functions | |
- for (auto copied_idx : this->desc->stages[this->stage + 1].copied_next_fns) { | |
- bw_fn->next_functions.push_back(this->next_functions[copied_idx]); | |
- } | |
- // This is needed because of copied functions (even if all inputs of this stage | |
- // didn't require grad, a copied function can), and is always valid because we assert | |
- // that flags are the same as when we compiled the closure (and the tracing Eval | |
- // was run, so it must have been executable). | |
- bw_fn->is_executable = true; | |
- return bw_fn; | |
- }; | |
- | |
- // See Note [Null-edge pruning] | |
- variable_list result; | |
- auto num_outputs = outputs.size(); | |
- std::shared_ptr<Function> grad_fn; | |
- JIT_ASSERT(outputs.size() == stage_closure.var_flags.second.size()); | |
- for (std::size_t i = 0; i < num_outputs; ++i) { | |
- auto & flags = stage_closure.var_flags.second[i]; | |
- if (flags.requires_grad) { | |
- if (!grad_fn) grad_fn = make_grad_fn(); | |
- result.push_back(make_variable(outputs[i], grad_fn)); | |
- } else { | |
- result.push_back(make_variable(outputs[i], flags.requires_grad, flags.is_volatile)); | |
- } | |
- } | |
- | |
- // If we created grad_fn for any of the outputs, we also need to fill in next_functions | |
- if (grad_fn) { | |
- for (auto & input : inputs) { | |
- if (!input.requires_grad()) continue; | |
- grad_fn->next_functions.emplace_back( | |
- input.grad_fn() ? input.grad_fn() : input.grad_accumulator(), | |
- input.output_nr()); | |
- } | |
- } | |
- | |
- captured_vars.clear(); | |
- captured_handles.clear(); | |
- outputs.clear(); | |
- return result; | |
-} | |
- | |
-AutogradClosureFactory::AutogradClosureFactory(TracingState *state) | |
- : desc(std::make_shared<MultiStageClosure>(state)) {} | |
- | |
-std::shared_ptr<Function> AutogradClosureFactory::construct() { | |
- return std::make_shared<AutogradClosure>(desc); | |
-} | |
- | |
-}} | |
diff --git a/torch/csrc/autograd/functions/jit_closure.h b/torch/csrc/autograd/functions/jit_closure.h | |
deleted file mode 100644 | |
index 6d905e63..00000000 | |
--- a/torch/csrc/autograd/functions/jit_closure.h | |
+++ /dev/null | |
@@ -1,50 +0,0 @@ | |
-#pragma once | |
- | |
-#include <Python.h> | |
-#include <memory> | |
-#include <unordered_map> | |
- | |
-#include "torch/csrc/jit/ir.h" | |
-#include "torch/csrc/jit/tracer_state.h" | |
-#include "torch/csrc/autograd/engine.h" | |
-#include "torch/csrc/autograd/function.h" | |
-#include "torch/csrc/autograd/variable.h" | |
- | |
-namespace torch { namespace autograd { | |
- | |
-struct MultiStageClosure; | |
- | |
-struct AutogradClosureFactory { | |
- AutogradClosureFactory(torch::jit::tracer::TracingState *graph); | |
- | |
- std::shared_ptr<Function> construct(); | |
- | |
- std::shared_ptr<MultiStageClosure> desc; | |
-}; | |
- | |
-struct AutogradClosure : public Function { | |
- AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc); | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
- | |
-private: | |
- AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage); | |
- | |
- variable_list rewrapInputs(const variable_list& inputs); | |
- | |
- std::shared_ptr<MultiStageClosure> desc; | |
- std::size_t stage; | |
- | |
- std::unordered_map<int, at::Tensor> saved_vars; | |
- std::unordered_map<int, std::shared_ptr<Function>> saved_handles; | |
- | |
- Engine::pre_callback_map pre_callbacks; | |
- Engine::post_callback_map post_callbacks; | |
- | |
- std::unordered_map<int, at::Tensor> captured_vars; | |
- std::unordered_map<int, std::shared_ptr<Function>> captured_handles; | |
- tensor_list outputs; | |
- std::mutex capture_mutex; | |
-}; | |
- | |
-}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/functions/onnx/basic_ops.cpp b/torch/csrc/autograd/functions/onnx/basic_ops.cpp | |
deleted file mode 100644 | |
index 5a89a1fa..00000000 | |
--- a/torch/csrc/autograd/functions/onnx/basic_ops.cpp | |
+++ /dev/null | |
@@ -1,12 +0,0 @@ | |
-#include "torch/csrc/autograd/functions/basic_ops.h" | |
- | |
-namespace torch { namespace autograd { | |
- | |
-jit::node_list Add::symbolic(SymbolicContext* ctx, jit::node_list inputs) { | |
- auto & g = ctx->graph; | |
- auto node = g->create(jit::kAdd, inputs); | |
- g->appendNode(node); | |
- return {node}; | |
-} | |
- | |
-}} | |
diff --git a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp b/torch/csrc/autograd/functions/onnx/batch_normalization.cpp | |
deleted file mode 100644 | |
index 08f675c4..00000000 | |
--- a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp | |
+++ /dev/null | |
@@ -1,38 +0,0 @@ | |
-#include "torch/csrc/autograd/functions/batch_normalization.h" | |
-#include <sstream> | |
- | |
-#include "torch/csrc/jit/ir.h" | |
- | |
-namespace torch { | |
-namespace autograd { | |
- | |
-jit::node_list BatchNormForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) { | |
- auto & g = ctx->graph; | |
- // X, Scale, Bias | |
- auto bn = g->appendNode(g->create(jit::kBatchNormalization, {inputs.at(0),inputs.at(1),inputs.at(2)})); | |
- bn->addInput(jit::tracer::getBufferTrace(*ctx->buffer_map, running_mean)); | |
- bn->addInput(jit::tracer::getBufferTrace(*ctx->buffer_map, running_var)); | |
- bn->i_(jit::kis_test, !this->training); | |
- bn->f_(jit::kepsilon, eps); | |
- //bn->s_(jit::korder, "NCHW"); | |
- bn->f_(jit::kmomentum, 1 - momentum); | |
- | |
- auto orig_output = g->appendNode(g->createSelect(bn, 0)); | |
- | |
- if(this->training) { | |
- g->appendNode(g->createSelect(bn, 1)->setType(bn->input(3)->type())); | |
- g->appendNode(g->createSelect(bn, 2)->setType(bn->input(4)->type())); | |
- // dummy output | |
- for(int i = 3; i < 5; i++) { | |
- g->appendNode(g->createSelect(bn, i)->setDebugName("batch_norm_dead_output")); | |
- } | |
- } | |
- bn->is_(jit::kconsumed_inputs,{0,0,0,1,1}); | |
- | |
- ctx->batch_norm_count++; | |
- return {orig_output}; | |
-} | |
- | |
- | |
-} // torch::autograd | |
-} // torch | |
diff --git a/torch/csrc/autograd/functions/onnx/convolution.cpp b/torch/csrc/autograd/functions/onnx/convolution.cpp | |
deleted file mode 100644 | |
index a8270fc8..00000000 | |
--- a/torch/csrc/autograd/functions/onnx/convolution.cpp | |
+++ /dev/null | |
@@ -1,78 +0,0 @@ | |
-#include "torch/csrc/autograd/functions/convolution.h" | |
- | |
-namespace torch { namespace autograd { | |
- | |
-// Note [Caffe2ConvTranspose] | |
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
-// ConvTranspose in Caffe2 is a bit silly: bias is mandatory. But ONNX | |
-// has removed bias input from official ConvTranspose. How can the Caffe2 | |
-// backend do the translation? It can't! It's impossible! So as a temporary | |
-// hack while we wait for Caffe2 to make bias optional, we are using a | |
-// Caffe2ConvTranspose experimental ONNX op which has a mandatory bias. | |
-// PyTorch has no trouble making the zero-filled tensor. | |
-// | |
-// For code simplicity, even if PyTorch was given a bias tensor, it is NOT | |
-// passed here; it's done as an external addition. This is less efficient | |
-// but this code should be temporary anyway. | |
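
A hedged sketch of the temporary workaround this note describes, reusing the old-style ATen factory call that appears a few lines below (the helper name is illustrative, not part of the diff):

#include <ATen/ATen.h>

// For the experimental Caffe2ConvTranspose op a bias input is mandatory, so a
// zero-filled tensor with one entry per output channel is synthesized; the code
// below sizes it from the weight's second dimension.
static at::Tensor make_mandatory_zero_bias(const at::Tensor& weight) {
  return at::CPU(at::kFloat).zeros({weight.size(1)});
}
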
- | |
-jit::node_list ConvForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) { | |
- auto & g = ctx->graph; | |
- // See Note [Caffe2ConvTranspose] | |
- auto n = g->create(!transposed ? jit::kConv : jit::kConvTranspose, | |
- {inputs.at(0), inputs.at(1)}); | |
- | |
- // Irritatingly, Caffe2 requires us to specify kernels, | |
- // but we don't actually have that information directly | |
- // recorded in ConvForward. So we have to reverse | |
- // engineer it from the input types... | |
- // TODO: dynamic_cast ew | |
- auto weight_type = inputs.at(1)->type()->cast<jit::TensorType>(); | |
- JIT_ASSERT(weight_type); | |
- auto weight_size = weight_type->sizes(); | |
- | |
- // See Note [Caffe2ConvTranspose] | |
- if(transposed) { | |
- n->addInput(g->appendNode(g->createConstant(at::CPU(at::kFloat).zeros({weight_size[1]})))); | |
- } | |
- | |
- g->appendNode(n); | |
- | |
- std::vector<int64_t> kernel_size(weight_size.begin() + 2, weight_size.end()); | |
- n->is_(jit::kkernel_shape, std::move(kernel_size)); | |
- std::vector<int64_t> kernel_stride(stride.begin(),stride.end()); | |
- n->is_(jit::kstrides, std::move(kernel_stride)); | |
- | |
- std::vector<int64_t> kernel_pads(padding.begin(),padding.end()); | |
- // NB: Caffe2 lets you specify top and bottom pads separately; | |
- // PyTorch assumes they are symmetric | |
- for (int p : padding) { | |
- kernel_pads.push_back(p); | |
- } | |
- n->is_(jit::kpads,std::move(kernel_pads)); | |
- | |
- std::vector<int64_t> kernel_dilations(dilation.begin(),dilation.end()); | |
- n->is_(jit::kdilations,std::move(kernel_dilations)); | |
- n->i_(jit::kgroup,groups); | |
- | |
- // Not in ONNX? | |
- // TODO: implement it once ConvTranspose in ONNX gets `adj` argument instead | |
- // of providing `output_shape` | |
- for (int p : output_padding) { | |
- JIT_EXPECTM(p == 0, "output padding is not supported."); | |
- } | |
- | |
- // ignore benchmark/cudnn_enabled | |
- | |
- if (inputs.at(2)->kind() != jit::kUndefined) { | |
- // TODO: Set type here based on RETURN type (not available atm) | |
- auto a_n = g->create(jit::kAdd, {g->appendNode(g->createSelect(n, 0)), inputs.at(2)}); | |
- a_n->i_(jit::kbroadcast, 1); | |
- a_n->i_(jit::kaxis, 1); | |
- g->appendNode(a_n); | |
- return {a_n}; | |
- } else { | |
- return {n}; | |
- } | |
-} | |
- | |
-}} | |
diff --git a/torch/csrc/autograd/functions/special.cpp b/torch/csrc/autograd/functions/special.cpp | |
index c0e75f69..f750c39f 100644 | |
--- a/torch/csrc/autograd/functions/special.cpp | |
+++ b/torch/csrc/autograd/functions/special.cpp | |
@@ -1,15 +1,37 @@ | |
#include "torch/csrc/autograd/functions/special.h" | |
+#include "torch/csrc/assertions.h" | |
#include "torch/csrc/autograd/python_engine.h" | |
+#include "torch/csrc/autograd/edge.h" | |
+#include "torch/csrc/autograd/function.h" | |
+#include "torch/csrc/autograd/edge.h" | |
+ | |
+#include <cstdint> | |
+#include <memory> | |
+#include <unordered_map> | |
+#include <unordered_set> | |
+#include <vector> | |
+#include <utility> // for swap | |
namespace torch { namespace autograd { | |
+// Used when an output has multiple uses (there's only one entry | |
+// in next_edges per output). | |
+struct Replicate : public Function { | |
+ Replicate() : Function(/*num_inputs=*/1) {} | |
+ | |
+ virtual variable_list apply(const variable_list& inputs) { | |
+ TORCH_ASSERT(inputs.size() == 1); | |
+ return variable_list(num_outputs(), inputs[0]); | |
+ } | |
+}; | |
+ | |
// Note [Null-edge pruning] | |
// Evals have a problem with null edges appearing in the graph, because there's | |
// no way to tell the identity of the input (i.e. each nullptr might have been | |
// a different input, all of them might have been a single input, etc.). | |
// However, null edges are generally quite useless, so we can safely prune them, | |
-// by removing them from next_functions of Eval node and never allocating | |
+// by removing them from next_edges of Eval node and never allocating | |
// placeholders for them. This is a bit annoying because backward subgraphs may | |
// have many fewer outputs than the forward graph had inputs, but I don't think there's | |
// a way around it. It's a tiny perf optimization too :) | |
@@ -38,11 +60,11 @@ auto Eval::getSubgraph(const variable_list& inputs, const variable_list& outputs | |
input_edges.reserve(inputs.size()); | |
for (auto & input : inputs) { | |
if (!input.defined()) continue; | |
- input_edges.emplace(input.grad_fn() ? input.grad_fn() : input.grad_accumulator(), input.output_nr()); | |
+ input_edges.emplace(input.gradient_edge()); | |
} | |
// This is used to stop the search in situation 2 and find the corresponding placeholders. | |
- std::unordered_map<edge_type, std::shared_ptr<EvalOutput>, edge_hasher> inherited_edges; | |
+ std::unordered_map<Edge, std::shared_ptr<EvalOutput>> inherited_edges; | |
inherited_edges.reserve(inherited_placeholders.size()); | |
for (auto & placeholder : inherited_placeholders) { | |
input_edges.emplace(placeholder->next_edge); | |
@@ -62,15 +84,14 @@ auto Eval::getSubgraph(const variable_list& inputs, const variable_list& outputs | |
while (!queue.empty()) { | |
auto fn = queue.back(); queue.pop_back(); | |
JIT_ASSERT(fn); | |
- fn->tracing_state->in_eval_subgraph = true; | |
- int num_edges = fn->next_functions.size(); | |
- for (int i = 0; i < num_edges; ++i) { | |
- auto & edge = fn->next_functions[i]; | |
- auto & next_fn = edge.first; | |
- if (!next_fn) continue; // See Note [Null-edge pruning] | |
+ fn->tracing_state().in_eval_subgraph = true; | |
+ const auto num_outputs = fn->num_outputs(); | |
+ for (size_t i = 0; i < num_outputs; ++i) { | |
+ const auto& edge = fn->next_edge(i); | |
+ if (!edge.function) continue; // See Note [Null-edge pruning] | |
// Edge belongs to subgraph boundary. Register that and don't search along it. | |
if (input_edges.count(edge) > 0) { | |
- subgraph.boundary.begins.emplace(fn->getSharedPtr(), i); | |
+ subgraph.boundary.begins.emplace(fn->get_shared_ptr(), i); | |
subgraph.boundary.ends.emplace(edge); | |
auto it = inherited_edges.find(edge); | |
// Situation 2. If that edge is actually pointing to an earlier stage subgraph, | |
@@ -81,14 +102,14 @@ auto Eval::getSubgraph(const variable_list& inputs, const variable_list& outputs | |
continue; | |
} | |
// Situation 1. If we end up in a placeholder, we need to inherit it. | |
- if (auto placeholder = std::dynamic_pointer_cast<EvalOutput>(next_fn)) { | |
+ if (auto placeholder = std::dynamic_pointer_cast<EvalOutput>(edge.function)) { | |
extra_placeholders.emplace(placeholder); | |
subgraph.boundary.ends.emplace(placeholder->next_edge); | |
continue; | |
} | |
- bool unseen = seen.emplace(next_fn.get()).second; | |
+ bool unseen = seen.emplace(edge.function.get()).second; | |
if (unseen) | |
- queue.emplace_back(next_fn.get()); | |
+ queue.emplace_back(edge.function.get()); | |
} | |
} | |
@@ -106,47 +127,46 @@ bool Eval::trySimpleEval(const variable_list& inputs, const variable_list& outpu | |
if (inherited_placeholders.size() != 0) return false; | |
auto& grad_fn = outputs[0].grad_fn(); | |
- if (static_cast<std::size_t>(grad_fn->num_inputs) >= max_outputs) return false; | |
- if (static_cast<std::size_t>(grad_fn->num_inputs) != outputs.size()) return false; | |
+ if (static_cast<std::size_t>(grad_fn->num_inputs()) >= max_outputs) return false; | |
+ if (static_cast<std::size_t>(grad_fn->num_inputs()) != outputs.size()) return false; | |
// Check that all outputs have the same grad_fn and cover all its inputs | |
bitset_type output_nrs = 0; | |
- bitset_type expected_bitset = ((1 << grad_fn->num_inputs) - 1); | |
+ bitset_type expected_bitset = ((1 << grad_fn->num_inputs()) - 1); | |
for (auto & output : outputs) { | |
if (output.grad_fn() != grad_fn) return false; | |
output_nrs |= (1 << output.output_nr()); | |
} | |
if (output_nrs != expected_bitset) return false; | |
- // Check that grad_fn->next_functions matches the inputs exactly | |
+ // Check that grad_fn's next_edges match the inputs exactly. | |
auto num_inputs = inputs.size(); | |
- auto& grad_next_fns = grad_fn->next_functions; | |
- if (num_inputs != grad_next_fns.size()) return false; | |
+ if (num_inputs != grad_fn->num_outputs()) return false; | |
for (std::size_t i = 0; i < num_inputs; ++i) { | |
- // Unfortunately, null edge pruning (see Note [Null-edge pruning]) applies to | |
- // autograd functions which would otherwise be eligible for the SimpleEval | |
- // optimization. This makes everything more complicated, so for now we just don't | |
- // attempt the optimization in this case. To fix it properly, | |
- // we'd need to filter grad_next_fns and outputs of apply in Eval::apply. | |
- // The check below tests if null edge pruning occurred. | |
- if (!inputs[i].defined() || !grad_next_fns[i].first) return false; | |
- const auto& input_grad = inputs[i].grad_fn() ? inputs[i].grad_fn() : inputs[i].grad_accumulator(); | |
- if (grad_next_fns[i].first != input_grad || grad_next_fns[i].second != inputs[i].output_nr()) return false; | |
+ const auto& next_grad_edge = grad_fn->next_edge(i); | |
+ // Unfortunately, null edge pruning (see Note [Null-edge pruning]) applies | |
+ // to autograd functions which would otherwise be eligible for the | |
+ // SimpleEval optimization. This makes everything more complicated, so for | |
+ // now we just don't attempt the optimization in this case. To fix it | |
+ // properly, we'd need to filter grad_fn's output edges and outputs of | |
+ // apply in Eval::apply. The check below tests if null edge pruning | |
+ // occurred. | |
+ if (!inputs[i].defined() || !next_grad_edge.is_valid()) return false; | |
+ if (next_grad_edge != inputs[i].gradient_edge()) return false; | |
} | |
// Success! We still need to set up placeholders for next stages and to drop | |
// references to the graph. | |
- std::swap(next_functions, grad_next_fns); | |
- grad_next_fns.reserve(num_inputs); | |
+ std::swap(next_edges_, grad_fn->next_edges()); | |
+ grad_fn->next_edges().reserve(num_inputs); | |
placeholders.reserve(num_inputs); | |
- for (std::size_t i = 0; i < num_inputs; ++i) { | |
- auto placeholder = std::make_shared<EvalOutput>(next_functions[i]); | |
- grad_next_fns.emplace_back(placeholder, 0); | |
+ for (const auto& input : next_edges_) { | |
+ auto placeholder = std::make_shared<EvalOutput>(input); | |
+ grad_fn->add_next_edge({placeholder, 0}); | |
placeholders.emplace_back(std::move(placeholder)); | |
} | |
- is_executable = grad_fn->is_executable; | |
simple_graph = grad_fn; | |
- grad_fn->tracing_state->in_eval_subgraph = true; | |
+ grad_fn->tracing_state().in_eval_subgraph = true; | |
return true; | |
} | |
@@ -160,12 +180,12 @@ variable_list Eval::filterRelevantOutputs(const variable_list& inputs, const var | |
ignored_grad_fns.reserve(inputs.size()); | |
for (auto& input : inputs) { | |
if (!input.defined()) continue; | |
- ignored_grad_fns.emplace(input.grad_fn(), input.output_nr()); | |
+ ignored_grad_fns.insert(input.gradient_edge()); | |
} | |
for (auto& output : outputs) { | |
if (!output.defined()) continue; | |
if (!output.grad_fn()) continue; | |
- if (ignored_grad_fns.count(std::make_pair(output.grad_fn(), output.output_nr())) > 0) continue; | |
+ if (ignored_grad_fns.count(output.gradient_edge()) > 0) continue; | |
relevant_outputs.emplace_back(output); | |
} | |
return relevant_outputs; | |
@@ -176,10 +196,7 @@ auto Eval::computeInputOrder(const variable_list& inputs, const placeholder_list | |
int idx = 0; | |
for (auto & input : inputs) { | |
if (!input.defined()) continue; | |
- input_order.emplace( | |
- std::make_pair(input.grad_fn() ? input.grad_fn() : input.grad_accumulator(), input.output_nr()), | |
- idx++ | |
- ); | |
+ input_order.emplace(input.gradient_edge(), idx++); | |
} | |
for (auto & placeholder : inherited_placeholders) | |
input_order.emplace(placeholder->next_edge, idx++); | |
@@ -200,12 +217,12 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou | |
if (!trySimpleEval(inputs, relevant_outputs, inherited_placeholders)) { | |
roots.reserve(relevant_outputs.size()); | |
for (auto & output : relevant_outputs) | |
- roots.emplace_back(output.grad_fn(), output.output_nr()); | |
+ roots.push_back(output.gradient_edge()); | |
auto subgraph = getSubgraph(inputs, relevant_outputs, inherited_placeholders); | |
// Prepare output placeholder nodes for each end. | |
- std::unordered_map<edge_type, std::shared_ptr<EvalOutput>, edge_hasher> ends_to_outputs; | |
+ std::unordered_map<Edge, std::shared_ptr<EvalOutput>> ends_to_outputs; | |
for (auto & placeholder : placeholders) { | |
ends_to_outputs[placeholder->next_edge] = placeholder; | |
} | |
@@ -218,20 +235,18 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou | |
// Replace begins with pointers to output nodes. | |
// This detaches the subgraph from the full backward graph. | |
- for (auto & begin : subgraph.boundary.begins) { | |
- auto & fn = begin.first; | |
- auto offset = begin.second; | |
- fn->next_functions[offset] = std::make_pair(ends_to_outputs.at(fn->next_functions[offset]), 0); | |
+ for (auto& begin : subgraph.boundary.begins) { | |
+ const auto& edge = begin.function->next_edge(begin.input_nr); | |
+ begin.function->set_next_edge( | |
+ begin.input_nr, Edge(ends_to_outputs.at(edge), 0)); | |
} | |
// Replace subgraph with this node. | |
- next_functions.insert(next_functions.begin(), subgraph.boundary.ends.begin(), subgraph.boundary.ends.end()); | |
- is_executable = std::any_of(relevant_outputs.begin(), relevant_outputs.end(), | |
- [](const Variable& var) { return var.requires_grad(); }); | |
+ next_edges_.insert(next_edges_.begin(), subgraph.boundary.ends.begin(), subgraph.boundary.ends.end()); | |
// Ensure placeholders and inputs are sorted in the same way. | |
edge_order input_order = computeInputOrder(inputs, inherited_placeholders); | |
- std::sort(next_functions.begin(), next_functions.end(), [&input_order](const edge_type &a, const edge_type &b) { | |
+ std::sort(next_edges_.begin(), next_edges_.end(), [&input_order](const Edge &a, const Edge &b) { | |
return input_order.at(a) < input_order.at(b); | |
}); | |
std::sort(placeholders.begin(), placeholders.end(), [&input_order](const std::shared_ptr<EvalOutput> &a, const std::shared_ptr<EvalOutput> &b) { | |
@@ -240,9 +255,30 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou | |
} | |
// Rebase outputs. | |
+ auto this_shared = shared_from_this(); | |
+ std::unordered_set<Variable*> repeated_outputs; | |
+ // NB: every output can be in 3 states: | |
+ // - unique so far - only the else of second if is taken | |
+ // - repeated first time - first if + first branch of second if | |
+ // - repeated many times - first branch of second if only | |
for (auto & output : relevant_outputs) { | |
- output.grad_fn() = shared_from_this(); | |
- output.get()->output_nr = num_inputs++; | |
+ // This output is already rebased. This happens when the same | |
+ // Variable has been returned multiple times and is repeated | |
+ // in this list. | |
+ if (output.grad_fn_unsafe() == this) { | |
+ auto replicate = std::make_shared<Replicate>(); | |
+ replicate->add_next_edge({this_shared, output.output_nr()}); | |
+ output.set_gradient_edge({std::move(replicate), 0}); | |
+ repeated_outputs.emplace(&output); | |
+ } | |
+ // NOTE: this check should be fairly cheap, and the set shouldn't | |
+ // perform any allocations until we actually see repeated outputs. | |
+ if (repeated_outputs.count(&output) > 0) { | |
+ auto & replicate = output.grad_fn(); | |
+ replicate->add_next_edge({this_shared, num_inputs_++}); | |
+ } else { | |
+ autograd::create_gradient_edge(output, this_shared); | |
+ } | |
} | |
return true; | |
@@ -253,11 +289,12 @@ variable_list Eval::apply(const variable_list& inputs) { | |
if (simple_graph) { | |
outputs = (*simple_graph)(inputs); | |
} else { | |
- std::mutex outputs_mutex; | |
- outputs.resize(placeholders.size()); | |
auto& engine = python::PythonEngine::getDefaultEngine(); | |
auto exec_data = filterRoots(inputs); | |
- engine.execute(exec_data.first, exec_data.second, true, getCallbacks(outputs, outputs_mutex)); | |
+ auto next_edges = fmap( | |
+ placeholders, | |
+ [](const std::shared_ptr<EvalOutput>& o) { return Edge(o, 0); }); | |
+ outputs = engine.execute(exec_data.first, exec_data.second, true, true, next_edges); | |
} | |
auto bw_eval = newEval(); | |
@@ -267,7 +304,7 @@ variable_list Eval::apply(const variable_list& inputs) { | |
// This node already does it (backward of non-traceable backward is implicitly non-traceable), | |
// and it passes more information (backward Eval may inherit placeholders) than | |
// Function::traced_apply has available. | |
- tracing_state->in_eval_subgraph = true; | |
+ tracing_state_->in_eval_subgraph = true; | |
return outputs; | |
} | |
@@ -275,9 +312,9 @@ variable_list Eval::apply(const variable_list& inputs) { | |
// TODO: once we clean up the stochastic function mess it should be possible to ignore | |
// nullptr inputs in the Engine (it implies that the Variables is 0, so the jacobian vector | |
// product will be all zero too). | |
-std::pair<function_list, variable_list> Eval::filterRoots(const variable_list& inputs) { | |
+std::pair<edge_list, variable_list> Eval::filterRoots(const variable_list& inputs) { | |
variable_list filtered_inputs; | |
- function_list filtered_roots; | |
+ edge_list filtered_roots; | |
auto num_inputs = inputs.size(); | |
if (roots.size() != num_inputs) | |
throw std::logic_error("inputs.size() != roots.size()"); | |
@@ -305,20 +342,4 @@ std::pair<function_list, variable_list> Eval::filterRoots(const variable_list& i | |
return std::make_pair(std::move(filtered_roots), std::move(filtered_inputs)); | |
} | |
-Engine::pre_callback_map Eval::getCallbacks(variable_list& outputs, std::mutex& outputs_mutex) { | |
- Engine::pre_callback_map callbacks; | |
- int num_outputs = placeholders.size(); | |
- for (int i = 0; i < num_outputs; ++i) { | |
- auto& output_fn = placeholders[i]; | |
- callbacks.emplace(output_fn.get(), [&outputs, &outputs_mutex, i](Function* _unused, variable_list& inputs) -> bool { | |
- if (inputs.size() != 1) | |
- throw std::logic_error("placeholder callback received too many inputs"); | |
- std::lock_guard<std::mutex> lock(outputs_mutex); | |
- outputs[i] = inputs[0]; | |
- return false; // Stop at output nodes | |
- }); | |
- } | |
- return callbacks; | |
-} | |
- | |
}} // namespace torch::autograd | |
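
The Replicate node introduced above exists to fan a single incoming gradient out to several consumers when the same Variable is returned more than once. The following is a self-contained sketch of that fan-out idea using plain doubles and invented names, not the real torch::autograd::Function API.

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone sketch (not the real autograd API): a node with a single input
// that hands a copy of that value to every registered consumer, mirroring the
// idea behind the Replicate node.
struct FanOut {
  std::size_t num_consumers = 0;

  std::vector<double> apply(const std::vector<double>& inputs) const {
    // Exactly one incoming value; every consumer receives a copy of it.
    return std::vector<double>(num_consumers, inputs.at(0));
  }
};

int main() {
  FanOut replicate;
  replicate.num_consumers = 3;  // e.g. the same output was used three times
  auto outs = replicate.apply({1.5});
  std::cout << outs.size() << " copies of " << outs[0] << '\n';  // 3 copies of 1.5
}
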
diff --git a/torch/csrc/autograd/functions/special.h b/torch/csrc/autograd/functions/special.h | |
index 4af9f251..7676083b 100644 | |
--- a/torch/csrc/autograd/functions/special.h | |
+++ b/torch/csrc/autograd/functions/special.h | |
@@ -1,37 +1,34 @@ | |
#pragma once | |
#include <Python.h> | |
-#include <memory> | |
-#include <string> | |
-#include <mutex> | |
#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/autograd/engine.h" | |
+#include <memory> | |
+#include <mutex> | |
+#include <string> | |
+#include <unordered_map> | |
+#include <unordered_set> | |
+#include <vector> | |
+ | |
namespace torch { namespace autograd { | |
struct EvalOutput : Function { | |
- EvalOutput(const edge_type& next_edge) | |
- : next_edge(next_edge) { | |
- num_inputs = 1; | |
- // It would be nice if we could inherit this from the function of next_edge, | |
- // but we want to always run this node to capture the output. This might | |
- // confuse some of the functions causing them to do unnecessary work. | |
- // TODO: it should be possible to improve this once we get rid of NULL Variables | |
- is_executable = true; | |
- } | |
+ explicit EvalOutput(const Edge& next_edge_) | |
+ : Function(/*num_inputs=*/1), next_edge(next_edge_) {} | |
virtual variable_list apply(const variable_list& inputs) override { | |
throw std::logic_error("EvalOutput::apply() called"); | |
} | |
- edge_type next_edge; | |
+ Edge next_edge; | |
}; | |
struct Eval : Function { | |
- using edge_set = std::unordered_set<edge_type, edge_hasher>; | |
- using edge_order = std::unordered_map<edge_type, int, edge_hasher>; | |
+ using edge_set = std::unordered_set<Edge>; | |
+ using edge_order = std::unordered_map<Edge, int>; | |
using placeholder_list = std::vector<std::shared_ptr<EvalOutput>>; | |
// This struct has only one member, but it's useful to e.g. add a set of all | |
@@ -40,12 +37,12 @@ struct Eval : Function { | |
struct Boundary { | |
// All nodes from within the subgraph that connect to the outside. | |
// These are the places that will need to be patched to point to placeholders. | |
- // Contains pairs of (fn, offset into next_functions). | |
+ // Contains pairs of (fn, offset into next_edges). | |
edge_set begins; | |
// All nodes that are not in the subgraph, but are in the union of | |
- // next_functions of all nodes from the subgraph. These are the places that | |
+ // next_edges of all nodes from the subgraph. These are the places that | |
// will be modeled by placeholders. | |
- // Contains pairs of (fn, input_nr) and is equivalent to next_functions | |
+ // Contains pairs of (fn, input_nr) and is equivalent to next_edges | |
// of an Eval node that will replace the subgraph. | |
edge_set ends; | |
}; | |
@@ -78,20 +75,19 @@ struct Eval : Function { | |
return std::make_shared<Eval>(); | |
} | |
- // Roots are empty if simple_graph is not NULL. | |
+ // Roots are empty if simple_graph is not nullptr. | |
// simple_graph is an optimization of first backward stage - in this case | |
// all Eval subgraphs contain only a single gradient function, and the | |
// graph search on creation + call to the engine in apply can be elided | |
- function_list roots; | |
+ edge_list roots; | |
std::shared_ptr<Function> simple_graph; | |
placeholder_list placeholders; | |
- jit::Node* forward_ctx_select = nullptr; | |
+ jit::Value* forward_ctx_select = nullptr; | |
bool traceable = false; | |
private: | |
- std::pair<function_list, variable_list> filterRoots(const variable_list& inputs); | |
- Engine::pre_callback_map getCallbacks(variable_list& outputs, std::mutex& outputs_mutex); | |
+ std::pair<edge_list, variable_list> filterRoots(const variable_list& inputs); | |
Subgraph getSubgraph( | |
const variable_list& inputs, | |
diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp | |
index 00086cd5..dfd8baf3 100644 | |
--- a/torch/csrc/autograd/functions/tensor.cpp | |
+++ b/torch/csrc/autograd/functions/tensor.cpp | |
@@ -1,117 +1,92 @@ | |
-#include "tensor.h" | |
+#include "torch/csrc/autograd/functions/tensor.h" | |
+#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/autograd/functions/basic_ops.h" | |
#include "torch/csrc/autograd/functions/utils.h" | |
+#include "torch/csrc/autograd/generated/Functions.h" | |
+#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/utils/auto_gpu.h" | |
-namespace torch { namespace autograd { | |
- | |
-auto Identity::apply(const variable_list& inputs) -> variable_list { | |
- return inputs; | |
-}; | |
- | |
-auto Clone::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("Clone", inputs, 1); | |
- auto& input = inputs[0].data(); | |
- AutoGPU guard(input); | |
- | |
- at::Tensor output = input.clone(); | |
- | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<Identity>(std::move(f)); | |
- }); | |
-}; | |
- | |
-auto Contiguous::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("Contiguous", inputs, 1); | |
- auto& input = inputs[0].data(); | |
- AutoGPU guard(input); | |
+#include <cstdint> | |
+#include <memory> | |
+#include <utility> | |
- at::Tensor output = input.contiguous(); | |
+namespace torch { namespace autograd { | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<Identity>(std::move(f)); | |
- }); | |
+auto CopyBackwards::apply(const variable_list& grads) -> variable_list { | |
+ check_input_variables("CopyBackwards", grads, 1); | |
+ auto& grad = grads[0]; | |
+ variable_list grad_inputs(2); | |
+ if (should_compute_output(0)) { | |
+ grad_inputs[0] = at::zeros_like(grad); | |
+ } | |
+ if (should_compute_output(1)) { | |
+ AutoGPU autoGPU(src_device); | |
+ if (grad.is_cuda() && grad.get_device() != src_device) { | |
+ grad_inputs[1] = src_type->copy(grad); | |
+ } else { | |
+ grad_inputs[1] = grad.toType(*src_type); | |
+ } | |
+ } | |
+ return grad_inputs; | |
}; | |
-auto Transpose::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("Transpose", inputs, 1); | |
- | |
- auto& input = inputs[0].data(); | |
- AutoGPU guard(input); | |
- | |
- at::Tensor output = input.transpose(dim1, dim2); | |
- | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<Transpose>(dim1, dim2); | |
- }); | |
-} | |
- | |
-auto View::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("View", inputs, 1); | |
- | |
- auto& input = inputs[0].data(); | |
- AutoGPU guard(input); | |
- | |
- at::Tensor output = input.view(size); | |
- | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<View>(input.sizes()); | |
- }); | |
-} | |
- | |
-auto Expand::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("Expand", inputs, 1); | |
- | |
- auto& input = inputs[0].data(); | |
- AutoGPU guard(input); | |
- | |
- at::Tensor output = input.expand(size); | |
- | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<Error>("Expand is not differentiable", std::move(f)); | |
- }); | |
+CopySlices::CopySlices( | |
+ const Variable& base_var, | |
+ at::TensorGeometry view_, | |
+ std::shared_ptr<Function> fn_) | |
+ : Function(/*num_inputs=*/1), | |
+ base(base_var), | |
+ view(std::move(view_)), | |
+ fn(std::move(fn_)) { | |
+ // Take the next_edges of fn as our own, except for index 0 which goes | |
+ // to base instead of the view. | |
+ const auto num_outputs = fn->num_outputs(); | |
+ next_edges_.reserve(num_outputs); | |
+ add_next_edge(base_var.gradient_edge()); | |
+ for (size_t i = 1; i < num_outputs; i++) { | |
+ add_next_edge(fn->next_edge(i)); | |
+ } | |
} | |
-auto Narrow::apply(const variable_list& inputs) -> variable_list { | |
- check_input_variables("Narrow", inputs, 1); | |
- | |
- auto& input = inputs[0].data(); | |
- AutoGPU guard(input); | |
- | |
- at::Tensor output = input.narrow(dim, start, size); | |
+auto CopySlices::apply(const variable_list& inputs) -> variable_list { | |
+ check_input_variables("CopySlices", inputs, 1); | |
+ auto& grad = inputs[0]; | |
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) { | |
- return std::make_shared<Error>("Narrow is not differentiable", std::move(f)); | |
- }); | |
-} | |
- | |
-auto Cat::apply(const variable_list& inputs) -> variable_list { | |
- int num_inputs = inputs.size(); | |
- if (num_inputs == 0) { | |
- throw std::runtime_error("Cat operation expect at least one argument."); | |
+ if (!fn) { | |
+ throw std::runtime_error(ERR_BACKWARD_TWICE); | |
} | |
- auto& input = inputs[0].data(); | |
- AutoGPU guard(input); | |
- | |
- std::vector<at::Tensor> tensors(num_inputs); | |
- for (int i = 0; i < num_inputs; ++i) { | |
- tensors[i] = inputs[i].data(); | |
+ auto result = grad.type().tensor(base.sizes(), base.strides()); | |
+ result.copy_(grad); | |
+ | |
+ auto offset = view.storage_offset() - base.storage_offset(); | |
+ auto grad_slice = result.as_strided(view.sizes(), view.strides(), offset); | |
+ | |
+ // TODO: We clone grad_slice because we modify it below and "fn" might save | |
+ // it for the backward of res. We might be able to avoid the clone() if | |
+ // double-backprop is disabled. | |
+ auto res = (*fn)({ grad_slice.clone() }); | |
+ | |
+ variable_list grad_inputs(num_outputs()); | |
+ for (size_t i = 0; i < res.size(); i++) { | |
+ if (should_compute_output(i)) { | |
+ TORCH_ASSERT(res[i].defined()); | |
+ if (i == 0) { | |
+ grad_slice.copy_(res[i]); | |
+ grad_inputs[i] = std::move(result); | |
+ } else { | |
+ grad_inputs[i] = std::move(res[i]); | |
+ } | |
+ } | |
} | |
- auto output = input.type().cat(tensors, dim); | |
- return wrap_outputs(inputs, as_tensor_list(output), [&](FunctionFlags f) { | |
- return std::make_shared<Error>("Cat is not differentiable", std::move(f)); | |
- }); | |
+ return grad_inputs; | |
} | |
-auto Chunk::apply(const variable_list& inputs) -> variable_list { | |
- auto outputs = chunk(inputs[0].data(), chunks, dim); | |
- return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) { | |
- return std::make_shared<Error>("Chunk is not differentiable", std::move(f)); | |
- }); | |
+void CopySlices::release_variables() { | |
+ fn = nullptr; | |
} | |
}} // namespace torch::autograd | |
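
CopySlices::apply rebuilds the slice inside a full-size gradient with as_strided, using the view's sizes and strides plus its storage offset relative to the base. The snippet below is a standalone illustration of that sizes/strides/offset arithmetic on a flat buffer; the struct name and the 3x4 example shape are invented for the sketch.

#include <cstdint>
#include <iostream>
#include <vector>

// A "view" into a flat buffer is fully described by sizes, strides, and a
// storage offset relative to the base tensor's storage.
struct ViewGeometry {
  std::vector<std::int64_t> sizes;
  std::vector<std::int64_t> strides;
  std::int64_t storage_offset;
};

// Apply fn to every element a 2-D view covers, writing back into the buffer.
template <typename Fn>
void for_each_in_view(std::vector<double>& buffer, const ViewGeometry& view, Fn fn) {
  for (std::int64_t i = 0; i < view.sizes[0]; ++i) {
    for (std::int64_t j = 0; j < view.sizes[1]; ++j) {
      std::int64_t flat = view.storage_offset + i * view.strides[0] + j * view.strides[1];
      buffer[flat] = fn(buffer[flat]);
    }
  }
}

int main() {
  std::vector<double> base(12, 1.0);               // a 3x4 row-major "tensor"
  ViewGeometry view{{2, 2}, {4, 1}, /*offset=*/5}; // the slice base[1:3, 1:3]
  for_each_in_view(base, view, [](double x) { return x * 10.0; });
  for (double x : base) std::cout << x << ' ';     // only the four viewed cells become 10
  std::cout << '\n';
}
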
diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h | |
index 50208182..0c1463f9 100644 | |
--- a/torch/csrc/autograd/functions/tensor.h | |
+++ b/torch/csrc/autograd/functions/tensor.h | |
@@ -1,91 +1,37 @@ | |
#pragma once | |
#include <Python.h> | |
-#include <memory> | |
#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/variable.h" | |
-namespace torch { namespace autograd { | |
- | |
-struct Identity : public TraceableFunction { | |
- using TraceableFunction::TraceableFunction; | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
-}; | |
- | |
-struct Clone : public ForwardFunction<> { | |
- Clone() {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
-}; | |
- | |
-struct Contiguous : public ForwardFunction<> { | |
- Contiguous() {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
-}; | |
- | |
-struct Transpose : public ForwardFunction<> { | |
- Transpose(int64_t dim1, int64_t dim2) | |
- : dim1(dim1) | |
- , dim2(dim2) {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
- | |
- int64_t dim1; | |
- int64_t dim2; | |
-}; | |
- | |
-struct View : public ForwardFunction<> { | |
- View(std::vector<int64_t> size) | |
- : size(size) {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
- | |
- std::vector<int64_t> size; | |
-}; | |
- | |
-struct Expand : public ForwardFunction<> { | |
- Expand(std::vector<int64_t> size) | |
- : size(size) {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
+#include "ATen/Type.h" | |
+#include <ATen/TensorGeometry.h> | |
- std::vector<int64_t> size; | |
-}; | |
- | |
-struct Narrow : public ForwardFunction<> { | |
- Narrow(int64_t dim, int64_t start, int64_t size) | |
- : dim(dim) | |
- , start(start) | |
- , size(size) {} | |
- | |
- virtual variable_list apply(const variable_list& inputs) override; | |
- | |
- int64_t dim; | |
- int64_t start; | |
- int64_t size; | |
-}; | |
+#include <cstdint> | |
+#include <memory> | |
-struct Cat : public ForwardFunction<> { | |
- Cat(int64_t dim) | |
- : dim(dim) {} | |
+namespace torch { namespace autograd { | |
+struct CopyBackwards : public Function { | |
virtual variable_list apply(const variable_list& inputs) override; | |
- int64_t dim; | |
+ at::Type *src_type; | |
+ int64_t src_device; | |
}; | |
-struct Chunk : public Function { | |
- Chunk(int64_t chunks, int64_t dim) | |
- : chunks(chunks), dim(dim) {} | |
+// Performs grad[idx] = fn(grad[idx]), but out-of-place. The slicing operation | |
+// grad[idx] is defined by the relative sizes, strides, and offset of base and | |
+// view. | |
+struct CopySlices : public Function { | |
+ CopySlices(const Variable& base, at::TensorGeometry view, std::shared_ptr<Function> fn); | |
- virtual variable_list apply(const variable_list& inputs) override; | |
+ virtual variable_list apply(const variable_list& grads) override; | |
+ virtual void release_variables() override; | |
-private: | |
- int64_t chunks; | |
- int64_t dim; | |
+ at::TensorGeometry base; | |
+ at::TensorGeometry view; | |
+ std::shared_ptr<Function> fn; | |
}; | |
}} | |
diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp | |
index 6b5c81ac..09939a10 100644 | |
--- a/torch/csrc/autograd/functions/utils.cpp | |
+++ b/torch/csrc/autograd/functions/utils.cpp | |
@@ -1,33 +1,35 @@ | |
#include "torch/csrc/autograd/functions/utils.h" | |
-#include "torch/csrc/utils/functional.h" | |
-#include "torch/csrc/jit/tracer.h" | |
+#include "torch/csrc/autograd/edge.h" | |
+#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/variable.h" | |
#include <sstream> | |
+#include <vector> | |
namespace torch { namespace autograd { | |
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, | |
function_constructor ctr) { | |
- auto flags = Function::flags(inputs); | |
variable_list result; | |
result.reserve(outputs.size()); | |
- if (flags.is_volatile) { | |
+ if (!any_variable_requires_grad(inputs)) { | |
for (auto& output : outputs) { | |
if (output.defined()) { | |
- result.emplace_back(make_variable(output, false, true)); | |
+ result.push_back(make_variable(output, /*requires_grad=*/false)); | |
} else { | |
result.emplace_back(); | |
} | |
} | |
} else { | |
- auto grad_fn = ctr(std::move(flags)); | |
+ auto grad_fn = ctr(collect_next_edges(inputs)); | |
for (auto& output : outputs) { | |
if (output.defined()) { | |
- result.emplace_back(make_variable(output, grad_fn)); | |
+ auto variable = autograd::make_variable(output, /*requires_grad=*/false); | |
+ autograd::create_gradient_edge(variable, grad_fn); | |
+ result.push_back(std::move(variable)); | |
} else { | |
- ++grad_fn->num_inputs; | |
+ grad_fn->bump_inputs(); | |
result.emplace_back(); | |
} | |
} | |
@@ -53,5 +55,4 @@ void check_input_variables(const char* name, const variable_list& inputs, int ar | |
} | |
} | |
} | |
- | |
-}} | |
+}} // namespace torch::autograd | |
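
The rewritten wrap_outputs keeps the old shape of the API: a caller-supplied factory builds the grad_fn, but only when at least one input requires grad; otherwise the outputs stay detached. A simplified sketch of that control flow with stand-in types (nothing below is the real torch API):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-in types: a shared "node" plays the role of the grad_fn.
struct Node { std::string name; };

struct Wrapped {
  double value;
  std::shared_ptr<Node> grad_fn;  // null when no grad is needed
};

using NodeFactory = std::function<std::shared_ptr<Node>()>;

std::vector<Wrapped> wrap(const std::vector<double>& outputs,
                          bool any_input_requires_grad,
                          const NodeFactory& make_node) {
  std::vector<Wrapped> result;
  result.reserve(outputs.size());
  std::shared_ptr<Node> node;
  if (any_input_requires_grad) {
    node = make_node();  // one node shared by every output
  }
  for (double out : outputs) {
    result.push_back({out, node});
  }
  return result;
}

int main() {
  auto wrapped = wrap({1.0, 2.0}, /*any_input_requires_grad=*/true,
                      [] { return std::make_shared<Node>(Node{"AddBackward"}); });
  std::cout << wrapped[0].grad_fn->name << " shared by " << wrapped.size() << " outputs\n";
}
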
diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h | |
index 622652ca..18fc8628 100644 | |
--- a/torch/csrc/autograd/functions/utils.h | |
+++ b/torch/csrc/autograd/functions/utils.h | |
@@ -1,44 +1,28 @@ | |
#pragma once | |
#include <Python.h> | |
-#include <functional> | |
-#include <memory> | |
-#include <array> | |
#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/variable.h" | |
-namespace torch { namespace autograd { | |
- | |
-using function_constructor = std::function<std::shared_ptr<Function>(FunctionFlags)>; | |
- | |
-template<typename ...Args> | |
-inline variable_list as_variable_list(Args&& ... args) { | |
- std::array<variable_list::value_type, sizeof...(args)> arr = { {std::move(args)...} }; | |
- return variable_list(std::make_move_iterator(arr.begin()), | |
- std::make_move_iterator(arr.end())); | |
-} | |
+#include <functional> | |
+#include <memory> | |
+#include <vector> | |
-template<typename ...Args> | |
-inline tensor_list as_tensor_list(Args&& ... args) { | |
- std::array<tensor_list::value_type, sizeof...(args)> arr = { {std::move(args)...} }; | |
- return tensor_list(std::make_move_iterator(arr.begin()), | |
- std::make_move_iterator(arr.end())); | |
-} | |
+namespace torch { namespace autograd { | |
+using function_constructor = std::function<std::shared_ptr<Function>(edge_list&&)>; | |
/** | |
- * Wraps the tensor outputs in variables, and if necessary (i.e., none of the | |
- * inputs are volatile), uses the function ctr and inputs to create a grad_fn | |
- * for each of them. | |
+ * Wraps the tensor outputs in variables and, if necessary, creates and | |
+ * sets their grad_fn. | |
*/ | |
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, | |
function_constructor ctr); | |
/** | |
* Checks that inputs contains exactly `args` items and that the first `required_args` | |
- * items are not NULL. If not specified, `required_args` defaults to `args`. | |
+ * items are not nullptr. If not specified, `required_args` defaults to `args`. | |
*/ | |
void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1); | |
- | |
}} | |
diff --git a/torch/csrc/autograd/grad_mode.cpp b/torch/csrc/autograd/grad_mode.cpp | |
new file mode 100644 | |
index 00000000..6409c697 | |
--- /dev/null | |
+++ b/torch/csrc/autograd/grad_mode.cpp | |
@@ -0,0 +1,7 @@ | |
+#include "grad_mode.h" | |
+ | |
+namespace torch { namespace autograd { | |
+ | |
+thread_local bool GradMode::_enabled = 1; | |
+ | |
+}} | |
diff --git a/torch/csrc/autograd/grad_mode.h b/torch/csrc/autograd/grad_mode.h | |
new file mode 100644 | |
index 00000000..98b04d7a | |
--- /dev/null | |
+++ b/torch/csrc/autograd/grad_mode.h | |
@@ -0,0 +1,26 @@ | |
+#pragma once | |
+ | |
+namespace torch { namespace autograd { | |
+ | |
+struct GradMode { | |
+ static bool is_enabled() { | |
+ return _enabled; | |
+ } | |
+ static void set_enabled(bool enabled) { | |
+ _enabled = enabled; | |
+ } | |
+private: | |
+ static thread_local bool _enabled; | |
+}; | |
+ | |
+struct AutoGradMode { | |
+ AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { | |
+ GradMode::set_enabled(enabled); | |
+ } | |
+ ~AutoGradMode() { | |
+ GradMode::set_enabled(prev_mode); | |
+ } | |
+ bool prev_mode; | |
+}; | |
+ | |
+}} | |
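
GradMode plus AutoGradMode is the familiar thread-local flag with an RAII guard that restores the previous value on scope exit, so guards nest cleanly and other threads are unaffected. A self-contained sketch of the same pattern with placeholder names:

#include <iostream>

// A thread-local flag with an RAII guard that restores the previous value.
struct Mode {
  static bool is_enabled() { return enabled_; }
  static void set_enabled(bool enabled) { enabled_ = enabled; }
 private:
  static thread_local bool enabled_;
};
thread_local bool Mode::enabled_ = true;

struct ModeGuard {
  explicit ModeGuard(bool enabled) : prev_(Mode::is_enabled()) {
    Mode::set_enabled(enabled);
  }
  ~ModeGuard() { Mode::set_enabled(prev_); }
 private:
  bool prev_;
};

int main() {
  {
    ModeGuard no_grad(false);      // disable for this scope only
    {
      ModeGuard grad_again(true);  // a nested guard re-enables
      std::cout << Mode::is_enabled() << '\n';  // 1
    }
    std::cout << Mode::is_enabled() << '\n';    // 0 (outer guard's value)
  }
  std::cout << Mode::is_enabled() << '\n';      // 1 (restored)
}
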
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp | |
index d0fedc4b..81b5958b 100644 | |
--- a/torch/csrc/autograd/init.cpp | |
+++ b/torch/csrc/autograd/init.cpp | |
@@ -1,46 +1,15 @@ | |
#include <Python.h> | |
#include "torch/csrc/utils/pybind.h" | |
+#include "torch/csrc/autograd/grad_mode.h" | |
#include "torch/csrc/autograd/profiler.h" | |
#include "THP.h" | |
-namespace pybind11 { namespace detail { | |
- | |
-template <> struct type_caster<torch::autograd::profiler::EventKind> { | |
-public: | |
- PYBIND11_TYPE_CASTER(torch::autograd::profiler::EventKind, _("torch::autograd::profiler::EventKind")); | |
- | |
- bool load(handle src, bool) { | |
- try { | |
- auto str = py::cast<std::string>(src); | |
- if (str == "push") { | |
- value = torch::autograd::profiler::EventKind::PushRange; | |
- } else if (str == "pop") { | |
- value = torch::autograd::profiler::EventKind::PopRange; | |
- } else if (str == "mark") { | |
- value = torch::autograd::profiler::EventKind::Mark; | |
- } else { | |
- return false; | |
- } | |
- } catch (std::exception& e) { | |
- return false; | |
- } | |
- return true; | |
- } | |
- static handle cast(torch::autograd::profiler::EventKind src, return_value_policy /* policy */, handle /* parent */) { | |
- switch (src) { | |
- case torch::autograd::profiler::EventKind::PushRange: | |
- return py::cast("push").release(); | |
- case torch::autograd::profiler::EventKind::PopRange: | |
- return py::cast("pop").release(); | |
- case torch::autograd::profiler::EventKind::Mark: | |
- return py::cast("mark").release(); | |
- } | |
- __builtin_unreachable(); | |
- } | |
-}; | |
- | |
-}} // namespace pybind11::detail | |
+#ifdef _MSC_VER | |
+#define ENSURE_UNREACHABLE __assume(0); | |
+#else | |
+#define ENSURE_UNREACHABLE __builtin_unreachable(); | |
+#endif | |
PyObject * THPAutograd_initExtension(PyObject *_unused) | |
{ | |
@@ -50,20 +19,75 @@ PyObject * THPAutograd_initExtension(PyObject *_unused) | |
THPVariableClass = PyMapping_GetItemString(autograd_dict,(char*)"Variable"); | |
THPFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"Function"); | |
- THPUtils_assert_PyImport("torch.nn._functions.thnn", thnn_functions); | |
- THPBatchNormBackwardBackwardFunction = PyObject_GetAttrString(thnn_functions,(char*)"batchnorm_double_backwards_fn"); | |
- | |
- THPStochasticFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"StochasticFunction"); | |
THPUtils_assert(THPVariableClass, "couldn't find Variable class in " | |
"torch.autograd module"); | |
THPUtils_assert(THPFunctionClass, "couldn't find Function class in " | |
"torch.autograd module"); | |
- THPUtils_assert(THPStochasticFunctionClass, "couldn't find " | |
- "StochasticFunction class in torch.autograd module"); | |
auto m = py::handle(autograd_module).cast<py::module>(); | |
+ | |
+ py::class_<torch::autograd::profiler::Event>(m,"ProfilerEvent") | |
+ .def("kind",&torch::autograd::profiler::Event::kind) | |
+ .def("name",&torch::autograd::profiler::Event::name) | |
+ .def("thread_id",&torch::autograd::profiler::Event::thread_id) | |
+ .def("device",&torch::autograd::profiler::Event::device) | |
+ .def("cpu_elapsed_us",&torch::autograd::profiler::Event::cpu_elapsed_us) | |
+ .def("cuda_elapsed_us",&torch::autograd::profiler::Event::cuda_elapsed_us) | |
+ .def("has_cuda",&torch::autograd::profiler::Event::has_cuda); | |
+ py::enum_<torch::autograd::profiler::ProfilerState>(m,"ProfilerState") | |
+ .value("Disabled", torch::autograd::profiler::ProfilerState::Disabled) | |
+ .value("CPU", torch::autograd::profiler::ProfilerState::CPU) | |
+ .value("CUDA", torch::autograd::profiler::ProfilerState::CUDA) | |
+ .value("NVTX", torch::autograd::profiler::ProfilerState::NVTX); | |
+ | |
m.def("_enable_profiler", torch::autograd::profiler::enableProfiler); | |
m.def("_disable_profiler", torch::autograd::profiler::disableProfiler); | |
+ m.def("_push_range", [](const char *name) { | |
+ using namespace torch::autograd::profiler; | |
+ if (state == ProfilerState::Disabled) return; | |
+ pushRange(name); | |
+ }); | |
+ m.def("_pop_range", []() { | |
+ using namespace torch::autograd::profiler; | |
+ if (state == ProfilerState::Disabled) return; | |
+ popRange(); | |
+ }); | |
+ | |
Py_RETURN_TRUE; | |
} | |
+ | |
+namespace torch { namespace autograd { | |
+ | |
+static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) { | |
+ HANDLE_TH_ERRORS | |
+ if (!PyBool_Check(arg)) { | |
+ at::runtime_error("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); | |
+ } | |
+ GradMode::set_enabled(arg == Py_True); | |
+ Py_RETURN_NONE; | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
+static PyObject * is_grad_enabled(PyObject* _unused, PyObject *arg) { | |
+ HANDLE_TH_ERRORS | |
+ if (GradMode::is_enabled()) { | |
+ Py_RETURN_TRUE; | |
+ } else { | |
+ Py_RETURN_FALSE; | |
+ } | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
+// autograd methods on torch._C | |
+static PyMethodDef methods[] = { | |
+ {"set_grad_enabled", (PyCFunction)set_grad_enabled, METH_O, nullptr}, | |
+ {"is_grad_enabled", (PyCFunction)is_grad_enabled, METH_NOARGS, nullptr}, | |
+ {nullptr, nullptr, 0, nullptr} | |
+}; | |
+ | |
+PyMethodDef* python_functions() { | |
+ return methods; | |
+} | |
+ | |
+}} // namespace torch::autograd | |
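
The new python_functions() hook works by handing back a static PyMethodDef table that the module setup code registers. As a rough standalone illustration of that CPython pattern (the module and function names below are placeholders, not the real torch._C entries):

#include <Python.h>

// A module-level flag toggled from Python, exposed through a static
// PyMethodDef table.
static bool g_flag = true;

static PyObject* set_flag(PyObject* /*self*/, PyObject* arg) {
  if (!PyBool_Check(arg)) {
    PyErr_SetString(PyExc_TypeError, "expected a bool");
    return nullptr;
  }
  g_flag = (arg == Py_True);
  Py_RETURN_NONE;
}

static PyObject* get_flag(PyObject* /*self*/, PyObject* /*unused*/) {
  if (g_flag) Py_RETURN_TRUE;
  Py_RETURN_FALSE;
}

static PyMethodDef methods[] = {
  {"set_flag", (PyCFunction)set_flag, METH_O, nullptr},
  {"get_flag", (PyCFunction)get_flag, METH_NOARGS, nullptr},
  {nullptr, nullptr, 0, nullptr}
};

static struct PyModuleDef module_def = {
  PyModuleDef_HEAD_INIT, "flag_demo", nullptr, -1, methods,
  nullptr, nullptr, nullptr, nullptr
};

PyMODINIT_FUNC PyInit_flag_demo(void) {
  return PyModule_Create(&module_def);
}
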
diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp | |
index 4b6d40e0..77946214 100644 | |
--- a/torch/csrc/autograd/input_buffer.cpp | |
+++ b/torch/csrc/autograd/input_buffer.cpp | |
@@ -6,43 +6,36 @@ | |
namespace torch { namespace autograd { | |
-InputBuffer::InputBuffer(size_t size) | |
- : buffer(size) | |
- {} | |
void InputBuffer::add(size_t pos, Variable var) { | |
TORCH_ASSERT(pos >= 0 && pos < buffer.size()); | |
if (!var.defined()) { | |
return; | |
} | |
- auto& item = buffer[pos]; | |
- if (!item.first.defined()) { | |
- auto current_version = var.current_version(); | |
- buffer[pos] = std::make_pair<>(std::move(var), current_version); | |
+ auto& old_var = buffer[pos]; | |
+ if (!old_var.defined()) { | |
+ buffer[pos] = std::move(var); | |
} else { | |
- auto result = apply_fn<Add>()(item.first, std::move(var)); | |
- buffer[pos] = std::make_pair<>(std::move(result), 0); | |
+ // ATen doesn't route sparse additions correctly... | |
+ if (old_var.type().is_sparse()) { | |
+ buffer[pos] = var + old_var; | |
+ } else { | |
+ buffer[pos] = old_var + var; | |
+ } | |
} | |
} | |
auto InputBuffer::device() const -> int { | |
- for (auto& pair : buffer) { | |
- if (pair.first.defined() && pair.first.type().isCuda()) { | |
- return pair.first.get_device(); | |
+ for (auto& var : buffer) { | |
+ if (var.defined() && var.type().is_cuda()) { | |
+ return var.get_device(); | |
} | |
} | |
return -1; | |
} | |
auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> { | |
- InputBuffer _buffer = std::move(g); | |
- auto& buffer = _buffer.buffer; | |
- int size = buffer.size(); | |
- std::vector<Variable> result; | |
- result.reserve(size); | |
- for (int i = 0; i != size; ++i) { | |
- result.emplace_back(buffer[i].first); | |
- } | |
+ std::vector<Variable> result = std::move(g.buffer); | |
return result; | |
} | |
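
InputBuffer::add now stores the first gradient that reaches a slot and adds later arrivals into it, instead of tracking saved versions. A minimal sketch of that accumulation logic with plain doubles standing in for Variables (the sparse-addition special case is omitted):

#include <cstddef>
#include <iostream>
#include <vector>

// Accumulate contributions by slot: the first value is stored, later values
// are added on top.
struct Buffer {
  explicit Buffer(std::size_t size) : slots_(size), filled_(size, false) {}

  void add(std::size_t pos, double grad) {
    if (!filled_[pos]) {
      slots_[pos] = grad;   // first contribution: just store it
      filled_[pos] = true;
    } else {
      slots_[pos] += grad;  // later contributions accumulate
    }
  }

  const std::vector<double>& slots() const { return slots_; }

 private:
  std::vector<double> slots_;
  std::vector<bool> filled_;
};

int main() {
  Buffer buffer(2);
  buffer.add(0, 1.0);
  buffer.add(0, 2.5);  // the same input is reached along two paths
  buffer.add(1, 4.0);
  std::cout << buffer.slots()[0] << ' ' << buffer.slots()[1] << '\n';  // 3.5 4
}
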
diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h | |
index bf3eaef9..2f9aba4f 100644 | |
--- a/torch/csrc/autograd/input_buffer.h | |
+++ b/torch/csrc/autograd/input_buffer.h | |
@@ -16,21 +16,24 @@ | |
namespace torch { namespace autograd { | |
struct InputBuffer { | |
- explicit InputBuffer(size_t size); | |
+ explicit InputBuffer(size_t size) | |
+ : buffer(size) {} | |
InputBuffer(const InputBuffer& other) = delete; | |
InputBuffer(InputBuffer&& other) = default; | |
+ InputBuffer& operator=(InputBuffer&& other) = default; | |
// Accumulates the variable at a specified index. | |
void add(size_t idx, Variable var); | |
int device() const; | |
+ Variable operator[](std::size_t pos) { return buffer[pos]; } | |
+ | |
// Returns the inputs as a list of variables. Destroys given InputBuffer. | |
static std::vector<Variable> variables(InputBuffer&& buffer); | |
private: | |
- // (Variable, version at save) | |
- std::vector<std::pair<Variable, int>> buffer; | |
+ std::vector<Variable> buffer; | |
}; | |
}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp | |
index 4b6f5429..13c2e21f 100644 | |
--- a/torch/csrc/autograd/profiler.cpp | |
+++ b/torch/csrc/autograd/profiler.cpp | |
@@ -1,40 +1,73 @@ | |
+#include "Python.h" | |
#include "torch/csrc/autograd/profiler.h" | |
#include "torch/csrc/autograd/function.h" | |
namespace torch { namespace autograd { namespace profiler { | |
-bool profiling = false; | |
-bool using_cuda; | |
+ProfilerState state = ProfilerState::Disabled; | |
+uint32_t next_thread_id = 0; | |
std::mutex all_event_lists_mutex; | |
std::list<std::shared_ptr<RangeEventList>> all_event_lists; | |
thread_local std::shared_ptr<RangeEventList> event_list; | |
+thread_local int32_t thread_id; | |
void RecordFunction::pushFunctionRange(Function* fn) { | |
pushRange(fn->name()); | |
} | |
-void enableProfiler(bool use_cuda) { | |
+#ifdef WITH_CUDA | |
+static void onEachDevice(std::function<void(int)> op) { | |
+ AutoGPU gpu_guard; | |
+ int count; | |
+ TORCH_CUDA_CHECK(cudaGetDeviceCount(&count)); | |
+ for(int i = 0; i < count; i++) { | |
+ gpu_guard.setDevice(i); | |
+ op(i); | |
+ } | |
+} | |
+#endif | |
+ | |
+void enableProfiler(ProfilerState new_state) { | |
+ TORCH_ASSERT(new_state != ProfilerState::Disabled); | |
#ifndef WITH_CUDA | |
- if (use_cuda) | |
- throw std::runtime_error("Can't use CUDA profiler - PyTorch was compiled without CUDA"); | |
+ if (new_state == ProfilerState::NVTX) | |
+ throw std::runtime_error("Can't use NVTX profiler - PyTorch was compiled without CUDA"); | |
#endif | |
- if (profiling) { | |
- if (use_cuda != using_cuda) | |
- throw std::runtime_error("can't change use_cuda flag while profiler is running"); | |
- return; | |
+ if (state != ProfilerState::Disabled && new_state != state) { | |
+ throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running"); | |
} | |
- profiling = true; | |
- using_cuda = use_cuda; | |
- mark("__start_profile"); | |
+ state = new_state; | |
+ | |
+#ifdef WITH_CUDA | |
+ if(state == ProfilerState::CUDA) { | |
+ // event recording appears to have some startup overhead, so we need to | |
+ // generate some dummy events first before recording synchronization events | |
+ for(int i = 0; i < 5; i++) { | |
+ onEachDevice([](int d) { | |
+ mark("__cuda_startup"); | |
+ cudaDeviceSynchronize(); | |
+ }); | |
+ } | |
+ | |
+ // cuda events must be on the same device, so we need a start event recorded | |
+ // for each gpu. we then use this event to synchronize time on the GPU | |
+ // with the CPU clock. | |
+ onEachDevice([](int d) { | |
+ mark("__cuda_start_event"); | |
+ }); | |
+ } | |
+#endif | |
+ mark("__start_profile", false); | |
} | |
thread_event_lists disableProfiler() { | |
- if (!profiling) { | |
+ if (state == ProfilerState::Disabled) { | |
throw std::runtime_error("can't disable profiler when it's not running"); | |
} | |
+ ProfilerState old_state = state; | |
mark("__stop_profile"); | |
- profiling = false; | |
- if (using_cuda) { | |
+ state = ProfilerState::Disabled; | |
+ if (old_state == ProfilerState::NVTX) { | |
return thread_event_lists(); | |
} else { | |
thread_event_lists result; | |
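
The profiler keeps a thread-local event list that is lazily created and registered once in a global, mutex-protected registry, so the hot recording path never takes a lock. A standalone sketch of that registration pattern with placeholder types:

#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Each thread owns a buffer it can append to without locking; the global
// registry only has to be locked once per thread, at registration time.
struct EventBuffer {
  std::vector<std::string> events;
};

std::mutex g_registry_mutex;
std::list<std::shared_ptr<EventBuffer>> g_registry;
thread_local std::shared_ptr<EventBuffer> t_buffer;

EventBuffer& local_buffer() {
  if (!t_buffer) {
    t_buffer = std::make_shared<EventBuffer>();
    std::lock_guard<std::mutex> guard(g_registry_mutex);
    g_registry.push_front(t_buffer);  // keeps the buffer alive after the thread exits
  }
  return *t_buffer;
}

int main() {
  auto work = [](const std::string& name) { local_buffer().events.push_back(name); };
  std::thread t1(work, "thread-1 event");
  std::thread t2(work, "thread-2 event");
  t1.join();
  t2.join();
  std::lock_guard<std::mutex> guard(g_registry_mutex);
  std::cout << "buffers registered: " << g_registry.size() << '\n';  // 2
}
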
diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h | |
index 2f2be145..d87e891a 100644 | |
--- a/torch/csrc/autograd/profiler.h | |
+++ b/torch/csrc/autograd/profiler.h | |
@@ -11,8 +11,14 @@ | |
#include <cstdint> | |
#include <string> | |
#include <list> | |
+#include <sstream> | |
#include <forward_list> | |
#include <tuple> | |
+#include "ATen/ATen.h" | |
+#include "torch/csrc/cuda/cuda_check.h" | |
+#ifdef WITH_CUDA | |
+#include <cuda_runtime.h> | |
+#endif | |
namespace torch { namespace autograd { | |
@@ -24,16 +30,95 @@ constexpr inline std::size_t ceilToMultiple(std::size_t a, std::size_t b) { | |
return ((a + b - 1) / b) * b; | |
} | |
+inline uint64_t getTime() { | |
+ using namespace std::chrono; | |
+ using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type; | |
+ return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count(); | |
+} | |
+ | |
enum class EventKind { | |
Mark, | |
PushRange, | |
PopRange | |
}; | |
-// NOTE: we don't need a flag saying if an event is a kernel, because it's | |
-// used only for the CPU-side perf recording. | |
-using Event = std::tuple<std::string, uint64_t, EventKind>; // (name, time, kind) | |
+struct Event { | |
+ Event(EventKind kind, std::string name, uint32_t thread_id, bool record_cuda) | |
+ : kind_(kind) | |
+ , name_(std::move(name)) | |
+ , thread_id_(thread_id) { | |
+#ifdef WITH_CUDA | |
+ if(record_cuda) { | |
+ TORCH_CUDA_CHECK(cudaGetDevice(&device_)); | |
+ TORCH_CUDA_CHECK(cudaEventCreate(&event)); | |
+ auto stream = at::globalContext().getCurrentCUDAStream(); | |
+ cpu_ns_ = getTime(); | |
+ TORCH_CUDA_CHECK(cudaEventRecord(event, stream)); | |
+ } else { | |
+ cpu_ns_ = getTime(); | |
+ } | |
+#else | |
+ cpu_ns_ = getTime(); | |
+#endif | |
+ } | |
+ std::string kind() const { | |
+ switch(kind_) { | |
+ case EventKind::Mark: return "mark"; | |
+ case EventKind::PushRange: return "push"; | |
+ case EventKind::PopRange: return "pop"; | |
+ } | |
+ throw std::runtime_error("unknown EventKind"); | |
+ } | |
+ const std::string & name() const { | |
+ return name_; | |
+ } | |
+ uint32_t thread_id() const { | |
+ return thread_id_; | |
+ } | |
+ double cpu_elapsed_us(const Event & e) { | |
+ return (e.cpu_ns_ - cpu_ns_)/(1000.0); | |
+ } | |
+ double cuda_elapsed_us(const Event & e) { | |
+#ifdef WITH_CUDA | |
+ if(!e.has_cuda() || !has_cuda()) { | |
+ throw std::logic_error("Events were not recorded for CUDA"); | |
+ } | |
+ if(e.device() != device()) { | |
+ throw std::logic_error("Events are not on the same device"); | |
+ } | |
+ TORCH_CUDA_CHECK(cudaEventSynchronize(event)); | |
+ TORCH_CUDA_CHECK(cudaEventSynchronize(e.event)); | |
+ float ms; | |
+ TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event, e.event)); | |
+ return ms*1000.0; | |
+#else | |
+ throw std::logic_error("CUDA not enabled"); | |
+#endif | |
+ } | |
+ bool has_cuda() const { | |
+#ifdef WITH_CUDA | |
+ return event != nullptr; | |
+#else | |
+ return false; | |
+#endif | |
+ } | |
+ int device() const { | |
+ return device_; | |
+ } | |
+private: | |
+ EventKind kind_; | |
+ std::string name_; | |
+ uint32_t thread_id_; | |
+ int64_t cpu_ns_; // signed to allow for negative intervals | |
+#ifdef WITH_CUDA | |
+ cudaEvent_t event = nullptr; | |
+#endif | |
+ int device_ = -1; | |
+}; | |
+// A linked list of fixed-size vectors, to avoid a std::vector resize | |
+// taking a large amount of time inside a profiling event. | |
struct RangeEventList { | |
constexpr static std::size_t MB = 1024 * 1024; | |
constexpr static std::size_t event_block_size = 16 * MB; | |
@@ -70,81 +155,85 @@ struct RangeEventList { | |
std::forward_list<block_type> blocks; | |
}; | |
-extern bool profiling; | |
-extern bool using_cuda; | |
+enum class ProfilerState { | |
+ Disabled, | |
+ CPU, // CPU-only profiling | |
+ CUDA, // CPU + CUDA events | |
+ NVTX, // only emit NVTX markers | |
+}; | |
+ | |
+extern ProfilerState state; | |
+extern uint32_t next_thread_id; | |
extern std::mutex all_event_lists_mutex; | |
extern std::list<std::shared_ptr<RangeEventList>> all_event_lists; | |
+ | |
extern thread_local std::shared_ptr<RangeEventList> event_list; | |
+extern thread_local int32_t thread_id; | |
inline RangeEventList& getEventList() { | |
if (!event_list) { | |
std::lock_guard<std::mutex> guard(all_event_lists_mutex); | |
event_list = std::make_shared<RangeEventList>(); | |
+ thread_id = next_thread_id++; | |
all_event_lists.emplace_front(event_list); | |
} | |
return *event_list; | |
} | |
-inline uint64_t getTime() { | |
- using namespace std::chrono; | |
- using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type; | |
- return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count(); | |
-} | |
- | |
-inline void mark(std::string name) { | |
- if (using_cuda) { | |
+inline void mark(std::string name, bool include_cuda = true) { | |
+ if (state == ProfilerState::NVTX) { | |
#ifdef WITH_CUDA | |
nvtxMarkA(name.c_str()); | |
#else | |
- throw std::logic_error("mark called with use_cuda=True, but compiled without CUDA"); | |
+ throw std::logic_error("mark called with NVTX tracing, but compiled without CUDA"); | |
#endif | |
} else { | |
- getEventList().record(std::move(name), getTime(), EventKind::Mark); | |
+ getEventList().record(EventKind::Mark, std::move(name), thread_id, include_cuda && state == ProfilerState::CUDA); | |
} | |
} | |
inline void pushRange(std::string name) { | |
- if (using_cuda) { | |
+ if (state == ProfilerState::NVTX) { | |
#ifdef WITH_CUDA | |
nvtxRangePushA(name.c_str()); | |
#else | |
- throw std::logic_error("pushRange called with use_cuda=True, but compiled without CUDA"); | |
+ throw std::logic_error("pushRange called with NVTX tracing, but compiled without CUDA"); | |
#endif | |
} else { | |
- getEventList().record(std::move(name), getTime(), EventKind::PushRange); | |
+ getEventList().record(EventKind::PushRange, std::move(name), thread_id, state == ProfilerState::CUDA); | |
} | |
} | |
inline void popRange() { | |
- if (using_cuda) { | |
+ if (state == ProfilerState::NVTX) { | |
#ifdef WITH_CUDA | |
nvtxRangePop(); | |
#else | |
- throw std::logic_error("popRange called with use_cuda=True, but compiled without CUDA"); | |
+ throw std::logic_error("popRange called with NVTX tracing, but compiled without CUDA"); | |
#endif | |
} else { | |
- getEventList().record(std::string(), getTime(), EventKind::PopRange); | |
+ getEventList().record(EventKind::PopRange, std::string(), thread_id, state == ProfilerState::CUDA); | |
} | |
} | |
struct RecordFunction { | |
explicit RecordFunction(Function *fn) { | |
- if (!profiling) return; | |
+ if (state == ProfilerState::Disabled) return; | |
pushFunctionRange(fn); | |
} | |
explicit RecordFunction(std::string name) { | |
- if (!profiling) return; | |
+ if (state == ProfilerState::Disabled) return; | |
pushRange(std::move(name)); | |
} | |
explicit RecordFunction(const char *name) { | |
- if (!profiling) return; | |
+ if (state == ProfilerState::Disabled) return; | |
pushRange(name); | |
} | |
~RecordFunction() { | |
- if (!profiling) return; | |
+ if (state == ProfilerState::Disabled) return; | |
popRange(); | |
} | |
@@ -155,7 +244,7 @@ struct RecordFunction { | |
using thread_event_lists = std::vector<std::vector<Event>>; | |
// NOTE: changing profiler modes is **NOT THREAD SAFE**. You should ensure that | |
// no autograd functions are being executed when these functions are used. | |
-void enableProfiler(bool use_cuda); | |
+void enableProfiler(ProfilerState state); | |
thread_event_lists disableProfiler(); | |
} // namespace profiler | |
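
RecordFunction is an RAII helper: the constructor pushes a range and the destructor pops it, so the matching pop is emitted even if the profiled scope unwinds via an exception, and getTime() prefers a monotonic clock for the timestamps. A hedged standalone sketch of the same idea, using std::chrono::steady_clock directly rather than the header's conditional clock selection:

#include <chrono>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// A push event is recorded on construction and the matching pop on
// destruction, each stamped with a monotonic clock.
struct RangeEvent {
  std::string name;
  bool is_push;
  std::uint64_t ns;
};

std::vector<RangeEvent> g_events;

std::uint64_t now_ns() {
  using namespace std::chrono;
  return duration_cast<nanoseconds>(steady_clock::now().time_since_epoch()).count();
}

struct ScopedRange {
  explicit ScopedRange(std::string name) : name_(std::move(name)) {
    g_events.push_back({name_, /*is_push=*/true, now_ns()});
  }
  ~ScopedRange() {
    g_events.push_back({name_, /*is_push=*/false, now_ns()});
  }
 private:
  std::string name_;
};

int main() {
  {
    ScopedRange range("my_op");
    // ... the profiled work would run here ...
  }
  std::cout << g_events.size() << " events, range took "
            << (g_events[1].ns - g_events[0].ns) << " ns\n";
}
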
diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp | |
index bc88c92d..7cc06082 100644 | |
--- a/torch/csrc/autograd/python_cpp_function.cpp | |
+++ b/torch/csrc/autograd/python_cpp_function.cpp | |
@@ -62,12 +62,12 @@ PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs) | |
int THPCppFunction_traverse(PyObject* self, visitproc visit, void *arg) | |
{ | |
auto& fn = *((THPCppFunction*)self)->cdata; | |
- for (auto& hook : fn.pre_hooks) { | |
+ for (const auto& hook : fn.pre_hooks()) { | |
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) { | |
Py_VISIT(pyhook->dict); | |
} | |
} | |
- for (auto& hook : fn.post_hooks) { | |
+ for (const auto& hook : fn.post_hooks()) { | |
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) { | |
Py_VISIT(pyhook->dict); | |
} | |
@@ -80,7 +80,7 @@ int THPCppFunction_clear(PyObject* self) | |
auto f = (THPCppFunction*)self; | |
// Remove the weak ref of the c++ object if it exist | |
if (f->cdata) { | |
- f->cdata->pyobj = nullptr; | |
+ f->cdata->set_pyobj(nullptr); | |
} | |
f->cdata.reset(); | |
return 0; | |
@@ -97,19 +97,18 @@ void THPCppFunction_dealloc(PyObject* self) | |
PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook) | |
{ | |
- auto& next_functions = self->cdata->next_functions; | |
- auto num_next = next_functions.size(); | |
+ const auto num_next = self->cdata->num_outputs(); | |
THPObjectPtr py_functions(PyTuple_New(num_next)); | |
- if (!py_functions) return NULL; | |
+ if (!py_functions) return nullptr; | |
for (size_t i = 0; i < num_next; ++i) { | |
- auto& c_tuple = next_functions[i]; | |
+ auto& c_tuple = self->cdata->next_edge(i); | |
THPObjectPtr tuple(PyTuple_New(2)); | |
- if (!tuple) return NULL; | |
- PyObject *py_fn = functionToPyObject(c_tuple.first); | |
- if (!py_fn) return NULL; | |
+ if (!tuple) return nullptr; | |
+ PyObject *py_fn = functionToPyObject(c_tuple.function); | |
+ if (!py_fn) return nullptr; | |
PyTuple_SET_ITEM(tuple.get(), 0, py_fn); | |
- PyObject *py_idx = PyLong_FromLong(c_tuple.second); | |
- if (!py_idx) return NULL; | |
+ PyObject *py_idx = PyLong_FromLong(c_tuple.input_nr); | |
+ if (!py_idx) return nullptr; | |
PyTuple_SET_ITEM(tuple.get(), 1, py_idx); | |
PyTuple_SET_ITEM(py_functions.get(), i, tuple.release()); | |
} | |
@@ -117,7 +116,7 @@ PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook) | |
} | |
PyObject* THPCppFunction_requires_grad(THPCppFunction* self) { | |
- return PyBool_FromLong(self->cdata->is_executable); | |
+ Py_RETURN_TRUE; | |
} | |
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) | |
@@ -127,8 +126,9 @@ PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) | |
} | |
auto var = (THPVariable*)_var; | |
auto& fn = *((THPCppFunction*)self)->cdata; | |
- fn.pre_hooks.push_back(std::make_shared<PyFunctionPreHook>( | |
- var->backward_hooks, var->cdata.output_nr())); | |
+ std::unique_ptr<FunctionPreHook> hook( | |
+ new PyFunctionPreHook(var->backward_hooks, var->cdata.output_nr())); | |
+ fn.add_pre_hook(std::move(hook)); | |
Py_RETURN_NONE; | |
} | |
@@ -141,12 +141,12 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) | |
static struct PyMethodDef default_methods[] = { | |
THP_FUNCTION_DEFAULT_METHODS, | |
- {NULL} | |
+ {nullptr} | |
}; | |
static struct PyGetSetDef default_properties[] = { | |
THP_FUNCTION_DEFAULT_PROPERTIES, | |
- {NULL} | |
+ {nullptr} | |
}; | |
PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name, | |
@@ -182,8 +182,8 @@ PyObject* functionToPyObject(std::shared_ptr<Function> cdata) | |
return obj; | |
} | |
- if (cdata->pyobj) { | |
- Py_INCREF(cdata->pyobj); | |
+ if (cdata->pyobj()) { | |
+ Py_INCREF(cdata->pyobj()); | |
} else { | |
auto& fn = *cdata; | |
auto it = cpp_function_types.find(std::type_index(typeid(fn))); | |
@@ -194,15 +194,15 @@ PyObject* functionToPyObject(std::shared_ptr<Function> cdata) | |
PyTypeObject* type = (PyTypeObject*)it->second.get(); | |
THPObjectPtr obj(type->tp_alloc(type, 0)); | |
- if (!obj) return NULL; | |
+ if (!obj) return nullptr; | |
THPCppFunction* f = (THPCppFunction*)obj.get(); | |
new (&f->cdata) std::shared_ptr<Function>(cdata); | |
// No INCREF here as we only have a weak reference | |
- cdata->pyobj = obj.release(); | |
+ cdata->set_pyobj(obj.release()); | |
} | |
- return cdata->pyobj; | |
+ return cdata->pyobj(); | |
} | |
void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) | |
@@ -214,7 +214,7 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) | |
PyObject* registerFunctionHook(Function& fn, PyObject* hook) | |
{ | |
PyObject* dict = Py_None; | |
- for (auto& hook : fn.post_hooks) { | |
+ for (const auto& hook : fn.post_hooks()) { | |
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) { | |
dict = pyhook->dict; | |
break; | |
@@ -222,13 +222,14 @@ PyObject* registerFunctionHook(Function& fn, PyObject* hook) | |
} | |
THPObjectPtr register_fn(PyObject_GetAttrString(THPFunctionClass, "_register_hook")); | |
- if (!register_fn) return NULL; | |
- THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, NULL)); | |
- if (!res) return NULL; | |
+ if (!register_fn) return nullptr; | |
+ THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr)); | |
+ if (!res) return nullptr; | |
if (dict == Py_None) { | |
dict = PyTuple_GET_ITEM(res.get(), 0); | |
- fn.post_hooks.push_back(std::make_shared<PyFunctionPostHook>(dict)); | |
+ std::unique_ptr<FunctionPostHook> hook(new PyFunctionPostHook(dict)); | |
+ fn.add_post_hook(std::move(hook)); | |
} | |
PyObject* handle = PyTuple_GET_ITEM(res.get(), 1); | |
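
The hunks above move hook registration behind `add_pre_hook`/`add_post_hook`, which take ownership through `std::unique_ptr`, and expose the stored hooks only through the `pre_hooks()`/`post_hooks()` accessors. Below is a minimal standalone sketch of that ownership-transfer pattern; `MockNode` and `MockHook` are illustrative stand-ins, not the real `torch::autograd` classes.

```cpp
// Standalone sketch (not the real torch::autograd API): a node that owns its
// hooks via std::unique_ptr and only hands out references through accessors.
#include <iostream>
#include <memory>
#include <vector>

struct MockHook {
  virtual ~MockHook() = default;
  virtual void operator()() const = 0;
};

struct PrintHook : MockHook {
  void operator()() const override { std::cout << "post hook fired\n"; }
};

class MockNode {
 public:
  // Ownership of the hook is transferred into the node.
  void add_post_hook(std::unique_ptr<MockHook> hook) {
    post_hooks_.push_back(std::move(hook));
  }
  const std::vector<std::unique_ptr<MockHook>>& post_hooks() const {
    return post_hooks_;
  }

 private:
  std::vector<std::unique_ptr<MockHook>> post_hooks_;
};

int main() {
  MockNode fn;
  std::unique_ptr<MockHook> hook(new PrintHook());
  fn.add_post_hook(std::move(hook));         // mirrors fn.add_post_hook(std::move(hook))
  for (const auto& h : fn.post_hooks()) {    // mirrors the const-ref traversal above
    (*h)();
  }
  return 0;
}
```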
diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h | |
index 35b9d252..aa518c1c 100644 | |
--- a/torch/csrc/autograd/python_cpp_function.h | |
+++ b/torch/csrc/autograd/python_cpp_function.h | |
@@ -19,24 +19,24 @@ template<typename Ctor> | |
PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) | |
{ | |
THPObjectPtr obj(type->tp_alloc(type, 0)); | |
- if (!obj) return NULL; | |
+ if (!obj) return nullptr; | |
THPCppFunction* f = (THPCppFunction*)obj.get(); | |
HANDLE_TH_ERRORS | |
new (&f->cdata) std::shared_ptr<Function>(Ctor()(args)); | |
END_HANDLE_TH_ERRORS | |
if (!f->cdata) { | |
- return NULL; | |
+ return nullptr; | |
} | |
return obj.release(); | |
} | |
#define THP_FUNCTION_DEFAULT_METHODS \ | |
- {(char*)"_register_hook_dict", (PyCFunction)THPCppFunction_register_hook_dict, METH_O, NULL}, \ | |
- {(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, NULL} | |
+ {(char*)"_register_hook_dict", (PyCFunction)THPCppFunction_register_hook_dict, METH_O, nullptr}, \ | |
+ {(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, nullptr} | |
#define THP_FUNCTION_DEFAULT_PROPERTIES \ | |
- {(char*)"next_functions", (getter)THPCppFunction_next_functions, NULL, NULL, NULL}, \ | |
- {(char*)"requires_grad", (getter)THPCppFunction_requires_grad, NULL, NULL, NULL} | |
+ {(char*)"next_functions", (getter)THPCppFunction_next_functions, nullptr, nullptr, nullptr}, \ | |
+ {(char*)"requires_grad", (getter)THPCppFunction_requires_grad, nullptr, nullptr, nullptr} | |
PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook); | |
PyObject* THPCppFunction_requires_grad(THPCppFunction* self); | |
@@ -50,7 +50,7 @@ PyObject* registerFunctionHook(Function& fn, PyObject* hook); | |
template<typename Ctor> | |
PyTypeObject* createForwardFunctionPyTypeObject(PyTypeObject& type, const char* name, | |
- PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL) | |
+ PyGetSetDef* function_properties=nullptr, PyMethodDef* function_methods=nullptr) | |
{ | |
type.tp_new = &CppFunction_pynew<Ctor>; | |
return _initFunctionPyTypeObject(type, name, function_properties, function_methods); | |
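
`CppFunction_pynew` placement-constructs the `shared_ptr<Function>` member inside memory that `tp_alloc` only zero-fills, so the matching teardown has to run the destructor by hand. A standalone sketch of that idiom, with zeroed heap memory standing in for `tp_alloc` (the `MockTHPObject` type is illustrative only):

```cpp
// Standalone sketch of the placement-new idiom used with tp_alloc'd memory.
// MockTHPObject stands in for THPCppFunction; calloc stands in for tp_alloc,
// which only zero-fills the object, so C++ members must be constructed by hand.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <new>

struct MockTHPObject {
  long fake_refcount;           // imagine PyObject_HEAD here
  std::shared_ptr<int> cdata;   // never constructed by the allocator
};

int main() {
  auto* obj = static_cast<MockTHPObject*>(std::calloc(1, sizeof(MockTHPObject)));
  if (!obj) return 1;

  // Equivalent of `new (&f->cdata) std::shared_ptr<Function>(...)`.
  new (&obj->cdata) std::shared_ptr<int>(std::make_shared<int>(42));
  std::cout << "cdata holds " << *obj->cdata << "\n";

  // The deallocator must invoke the destructor explicitly before the raw
  // memory is released; nothing else will.
  obj->cdata.~shared_ptr();
  std::free(obj);
  return 0;
}
```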
diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp | |
index e47eed4c..f5cce00b 100644 | |
--- a/torch/csrc/autograd/python_engine.cpp | |
+++ b/torch/csrc/autograd/python_engine.cpp | |
@@ -1,12 +1,18 @@ | |
#include "torch/csrc/autograd/python_engine.h" | |
-#include "torch/csrc/autograd/engine.h" | |
-#include "torch/csrc/autograd/python_function.h" | |
-#include "torch/csrc/THP.h" | |
#include "torch/csrc/DynamicTypes.h" | |
#include "torch/csrc/PtrWrapper.h" | |
+#include "torch/csrc/THP.h" | |
+#include "torch/csrc/autograd/engine.h" | |
+#include "torch/csrc/autograd/function.h" | |
+#include "torch/csrc/autograd/edge.h" | |
+#include "torch/csrc/autograd/python_function.h" | |
#include "torch/csrc/utils/auto_gil.h" | |
+#ifndef _WIN32 | |
+#include <pthread.h> | |
+#endif | |
+ | |
#include <unordered_set> | |
using namespace torch::autograd; | |
@@ -36,14 +42,14 @@ void PythonEngine::thread_on_exception(FunctionTask& task, std::exception& e) { | |
Engine::thread_on_exception(task, e); | |
} | |
-void PythonEngine::execute( | |
- const function_list& roots, | |
+variable_list PythonEngine::execute( | |
+ const edge_list& roots, | |
const variable_list& inputs, | |
bool keep_graph, | |
- const pre_callback_map& pre_callbacks, | |
- const post_callback_map& post_callbacks) { | |
+ bool create_graph, | |
+ const edge_list& outputs) { | |
try { | |
- Engine::execute(roots, inputs, keep_graph, pre_callbacks, post_callbacks); | |
+ return Engine::execute(roots, inputs, keep_graph, create_graph, outputs); | |
} catch (python_error& e) { | |
e.restore(); | |
throw; | |
@@ -56,77 +62,21 @@ PythonEngine& PythonEngine::getDefaultEngine() { | |
}}} // namespace torch::autograd::python | |
-PyObject *THPEngineClass = NULL; | |
- | |
-struct CallbackContext { | |
- std::string error; | |
- THPObjectPtr outputs; | |
- // Used to determine which callback arguments should be used to | |
- // fill outputs. | |
- // Function -> ([grad_nr, outputs_idx], is_leaf) | |
- std::unordered_map< | |
- std::shared_ptr<Function>, | |
- std::pair<std::vector<std::pair<int, int>>, bool>> output_map; | |
-}; | |
- | |
-void compute_partial_exec_callbacks(const function_list& roots, | |
- const CallbackContext& ctx, | |
- Engine::pre_callback_map& map, | |
- bool allow_unreachable) { | |
- // This callback is used to suppress the computation of a node | |
- // if it is not necessary. | |
- static Engine::pre_callback_type abort_callback( | |
- [](Function* fn, variable_list &vars) { return false; }); | |
- | |
- std::vector<Function*> queue; | |
- std::unordered_set<Function*> seen; // for the initial DFS | |
- std::unordered_set<Function*> needed; // functions to compute | |
- std::unordered_map<Function*, std::vector<Function*>> rev_graph; | |
- | |
- // Reverse the next_fn edges | |
- queue.reserve(roots.size()); | |
- for (auto& root : roots) { | |
- auto ptr = root.first.get(); | |
- bool unseen; | |
- std::tie(std::ignore, unseen) = seen.insert(ptr); | |
- if (unseen) queue.emplace_back(ptr); | |
- } | |
- while (!queue.empty()) { | |
- auto fn = queue.back(); queue.pop_back(); | |
- for (auto& next_fn_pair : fn->next_functions) { | |
- auto next_fn = next_fn_pair.first.get(); | |
- if (!next_fn) continue; | |
- rev_graph[next_fn].push_back(fn); | |
- if (seen.insert(next_fn).second) { | |
- queue.push_back(next_fn); | |
- } | |
- } | |
- } | |
- auto all_functions = std::move(seen); // this is cheap and improves readability | |
- | |
- // Find all functions we need to compute | |
- queue.clear(); | |
- for (auto input_info: ctx.output_map) { | |
- auto input = input_info.first.get(); | |
- auto rev_edges_it = rev_graph.find(input); | |
- if (!allow_unreachable && rev_edges_it == rev_graph.end()) | |
- throw std::runtime_error("differentiated input is unreachable"); | |
- queue.emplace_back(input); | |
- needed.insert(input); | |
- } | |
- while (!queue.empty()) { | |
- auto fn = queue.back(); queue.pop_back(); | |
- for (auto rev_next_fn : rev_graph[fn]) { | |
- if (needed.insert(rev_next_fn).second) { | |
- queue.push_back(rev_next_fn); | |
- } | |
- } | |
- } | |
- | |
- // Prevent expansion for functions in {all_vertices} \ {needed} | |
- for (auto fn : all_functions) { | |
- if (needed.count(fn) > 0) continue; | |
- map.emplace(fn, abort_callback); | |
+PyObject *THPEngineClass = nullptr; | |
+ | |
+static bool _reinitialize_engine = false; | |
+ | |
+static void _maybe_reinitialize_engine_after_fork() { | |
+ // This is "probably" thread-safe because the flag is set in a fork handler | |
+ // before any threads are created, and this function is only called with the | |
+ // GIL held. However, using fork + threads is playing with fire so this is | |
+ // more of a "best effort" thing. For example, if the fork occurs while the | |
+ // backwards threads hold a lock, we'll probably deadlock in the engine | |
+ // destructor. | |
+ if (_reinitialize_engine) { | |
+ engine.~PythonEngine(); | |
+ new (&engine) torch::autograd::python::PythonEngine(); | |
+ _reinitialize_engine = false; | |
} | |
} | |
@@ -134,17 +84,19 @@ void compute_partial_exec_callbacks(const function_list& roots, | |
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs) | |
{ | |
HANDLE_TH_ERRORS | |
- PyObject *variables = NULL; | |
- PyObject *grad_variables = NULL; | |
+ _maybe_reinitialize_engine_after_fork(); | |
+ PyObject *variables = nullptr; | |
+ PyObject *grad_variables = nullptr; | |
unsigned char keep_graph = 0; | |
- PyObject *inputs = NULL; | |
- unsigned char only_inputs = 0; | |
- unsigned char allow_unreachable = 0; | |
- const char *accepted_kwargs[] = {"variables", "grad_variables", | |
- "keep_graph", "inputs", "only_inputs", "allow_unreachable", NULL}; | |
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOb|Obb", (char**)accepted_kwargs, | |
- &variables, &grad_variables, &keep_graph, &inputs, &only_inputs, &allow_unreachable)) | |
- return NULL; | |
+ unsigned char create_graph = 0; | |
+ PyObject *inputs = nullptr; | |
+ const char *accepted_kwargs[] = { | |
+ "variables", "grad_variables", "keep_graph", "create_graph", "inputs", | |
+ nullptr | |
+ }; | |
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|O", (char**)accepted_kwargs, | |
+ &variables, &grad_variables, &keep_graph, &create_graph, &inputs)) | |
+ return nullptr; | |
THPUtils_assert(PyTuple_Check(variables), "variables argument is expected to " | |
"be a tuple, but got %s", THPUtils_typename(variables)); | |
@@ -156,32 +108,23 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar | |
THPUtils_assert(num_variables == num_gradients, "got %ld variables and %ld " | |
"gradients", num_variables, num_gradients); | |
- function_list roots(num_variables); | |
- variable_list grads(num_variables); | |
+ edge_list roots; | |
+ roots.reserve(num_variables); | |
+ variable_list grads; | |
+ grads.reserve(num_variables); | |
for (int i = 0; i < num_variables; i++) { | |
PyObject *_variable = PyTuple_GET_ITEM(variables, i); | |
THPUtils_assert(THPVariable_Check(_variable), "element %d of variables " | |
"tuple is not a Variable", i); | |
auto& variable = ((THPVariable*)_variable)->cdata; | |
- THPUtils_assert(!variable.is_volatile(), | |
- "element %d of variables tuple is volatile", i); | |
- // If grad_fn is NULL (as is the case for a leaf node), we instead | |
- // interpret the gradient function to be a grad accumulator, | |
- // which will accumulate its inputs into the grad property of the | |
- // variable. These nodes get suppressed in some situations, | |
- // see "suppress grad accumulation" below. Note that only variables which | |
- // have requires_grad=True can have grad accumulators. | |
- auto grad_fn = variable.grad_fn() ? variable.grad_fn() : variable.grad_accumulator(); | |
- int output_nr = variable.grad_fn() ? variable.output_nr() : 0; | |
- THPUtils_assert(!variable.is_volatile(), | |
- "element %d of variables tuple is volatile", i); | |
- THPUtils_assert(grad_fn, | |
+ auto gradient_edge = variable.gradient_edge(); | |
+ THPUtils_assert(gradient_edge.function, | |
"element %d of variables does not require grad and does not have a grad_fn", i); | |
- roots[i] = std::make_pair<>(std::move(grad_fn), output_nr); | |
+ roots.push_back(std::move(gradient_edge)); | |
PyObject *grad = PyTuple_GET_ITEM(grad_variables, i); | |
if (THPVariable_Check(grad)) { | |
- grads[i] = ((THPVariable*)grad)->cdata; | |
+ grads.push_back(((THPVariable*)grad)->cdata); | |
} else { | |
THPUtils_assert(grad == Py_None, | |
"element %d of gradients tuple is not a Variable or None", i); | |
@@ -190,74 +133,46 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar | |
} | |
} | |
- Engine::pre_callback_map callbacks; | |
- CallbackContext ctx; | |
- if (inputs != NULL) { | |
- THPUtils_assert(PyTuple_Check(inputs), "inputs argument has to be a tuple"); | |
+ edge_list output_edges; | |
+ if (inputs != nullptr) { | |
int num_inputs = PyTuple_GET_SIZE(inputs); | |
- ctx.outputs = PyTuple_New(num_inputs); | |
- if (!ctx.outputs) return NULL; | |
- // First, find all relevant functions and fill ctx.output_map | |
+ output_edges.reserve(num_inputs); | |
for (int i = 0; i < num_inputs; ++i) { | |
PyObject *input = PyTuple_GET_ITEM(inputs, i); | |
THPUtils_assert(THPVariable_Check(input), | |
"all inputs have to be Variables, but got %s", THPUtils_typename(input)); | |
THPVariable *input_var = (THPVariable*)input; | |
+ const auto output_nr = input_var->cdata.output_nr(); | |
auto grad_fn = input_var->cdata.grad_fn(); | |
- int output_nr = input_var->cdata.output_nr(); | |
- bool is_leaf = !grad_fn; | |
- if (is_leaf) { | |
- grad_fn = input_var->cdata.get()->grad_accumulator.lock(); | |
+ if (!grad_fn) { | |
+ grad_fn = input_var->cdata.try_get_grad_accumulator(); | |
} | |
THPUtils_assert(input_var->cdata.requires_grad(), | |
"One of the differentiated Variables does not require grad"); | |
- if (allow_unreachable && !grad_fn) continue; | |
- THPUtils_assert(grad_fn, | |
- "One of the differentiated Variables appears to not have been used in the graph"); | |
- THPUtils_assert(grad_fn->is_executable, | |
- "One of the differentiated Variables has a non-executable grad_fn. Submit a bug report."); | |
- auto& fn_info = ctx.output_map[grad_fn]; | |
- fn_info.first.emplace_back(output_nr, i); | |
- fn_info.second = is_leaf; | |
- } | |
- // Register callbacks that will gather the outputs | |
- for (auto& entry : ctx.output_map) { | |
- auto& fn_info = entry.second; | |
- callbacks.emplace(entry.first.get(), [&ctx, &fn_info](Function* _unused, variable_list& grads) { | |
- auto& saved_outputs = fn_info.first; | |
- bool is_leaf = fn_info.second; | |
- AutoGIL gil; | |
- for (auto& saved_out : saved_outputs) { | |
- PyTuple_SET_ITEM(ctx.outputs.get(), saved_out.second, | |
- THPVariable_Wrap(grads[saved_out.first])); | |
- } | |
- // Suppress grad accumulation. | |
- // If the variable is a leaf, the next function to execute | |
- // is a grad_accumulator. But when inputs != NULL, we should | |
- // NOT accumulate, so terminate execution. | |
- return !is_leaf; | |
- }); | |
- } | |
- // Disable execution for all unneeded functions | |
- if (only_inputs) { | |
- compute_partial_exec_callbacks(roots, ctx, callbacks, allow_unreachable); | |
+ if (!grad_fn) { | |
+ output_edges.emplace_back(); | |
+ } else { | |
+ THPUtils_assert(grad_fn, | |
+ "One of the differentiated Variables appears to not have been used in the graph"); | |
+ output_edges.emplace_back(grad_fn, output_nr); | |
+ } | |
} | |
} | |
+ variable_list outputs; | |
{ | |
AutoNoGIL no_gil; | |
- engine.execute(roots, grads, keep_graph, callbacks); | |
+ outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges); | |
} | |
- if (ctx.outputs) { | |
- for (int i = 0; i < PyTuple_GET_SIZE(inputs); i++) { | |
- // XXX: initializing tuples with NULL pointers might be a CPython | |
- // implementation detail | |
- if (PyTuple_GET_ITEM(ctx.outputs.get(), i)) continue; | |
- Py_INCREF(Py_None); | |
- PyTuple_SET_ITEM(ctx.outputs.get(), i, Py_None); | |
+ if (inputs != nullptr) { | |
+ int num_inputs = PyTuple_GET_SIZE(inputs); | |
+ THPObjectPtr py_outputs {PyTuple_New(num_inputs)}; | |
+ if (!py_outputs) return nullptr; | |
+ for (int i = 0; i < num_inputs; i++) { | |
+ PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i])); | |
} | |
- return ctx.outputs.release(); | |
+ return py_outputs.release(); | |
} else { | |
Py_RETURN_NONE; | |
} | |
@@ -265,14 +180,17 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar | |
} | |
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) { | |
+ HANDLE_TH_ERRORS | |
+ _maybe_reinitialize_engine_after_fork(); | |
std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); }); | |
Py_INCREF(_callback); | |
engine.queue_callback([callback]() { | |
AutoGIL gil; | |
- THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), NULL)}; | |
+ THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), nullptr)}; | |
if (!result) throw python_error(); | |
}); | |
Py_RETURN_NONE; | |
+ END_HANDLE_TH_ERRORS | |
} | |
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) | |
@@ -281,14 +199,14 @@ PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) | |
} | |
static struct PyMethodDef THPEngine_methods[] = { | |
- {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, NULL}, | |
- {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, NULL}, | |
- {NULL} | |
+ {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr}, | |
+ {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr}, | |
+ {nullptr} | |
}; | |
PyTypeObject THPEngineType = { | |
- PyVarObject_HEAD_INIT(NULL, 0) | |
+ PyVarObject_HEAD_INIT(nullptr, 0) | |
"torch._C._EngineBase", /* tp_name */ | |
sizeof(THPEngine), /* tp_basicsize */ | |
0, /* tp_itemsize */ | |
@@ -308,7 +226,7 @@ PyTypeObject THPEngineType = { | |
0, /* tp_setattro */ | |
0, /* tp_as_buffer */ | |
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ | |
- NULL, /* tp_doc */ | |
+ nullptr, /* tp_doc */ | |
0, /* tp_traverse */ | |
0, /* tp_clear */ | |
0, /* tp_richcompare */ | |
@@ -328,8 +246,17 @@ PyTypeObject THPEngineType = { | |
THPEngine_new /* tp_new */ | |
}; | |
+static void child_atfork() { | |
+ _reinitialize_engine = true; | |
+} | |
+ | |
bool THPEngine_initModule(PyObject *module) | |
{ | |
+#ifndef _WIN32 | |
+ if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) { | |
+ throw std::runtime_error("unable to set pthread_atfork handler"); | |
+ } | |
+#endif | |
if (PyType_Ready(&THPEngineType) < 0) | |
return false; | |
Py_INCREF(&THPEngineType); | |
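
The fork handling added above only flips a flag in the `pthread_atfork` child handler and rebuilds the static engine lazily, via an explicit destructor call followed by placement new. A self-contained POSIX sketch of that pattern follows; `MockEngine` is illustrative, and as the comment in the diff notes, this is best effort rather than fully fork-safe.

```cpp
// Standalone POSIX sketch of the atfork flag + rebuild-in-place pattern.
#include <pthread.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cstdio>
#include <new>

struct MockEngine {
  MockEngine() { std::printf("engine constructed in pid %d\n", (int)getpid()); }
  ~MockEngine() { std::printf("engine destroyed in pid %d\n", (int)getpid()); }
  void run() { std::printf("engine running in pid %d\n", (int)getpid()); }
};

static MockEngine engine;
static bool reinitialize_engine = false;

// Child-side atfork handler: do the minimum, just remember that the worker
// threads and locks inherited from the parent are unusable.
static void child_atfork() { reinitialize_engine = true; }

// Called lazily at the next entry point (with the GIL held, in the real code).
static void maybe_reinitialize_engine_after_fork() {
  if (!reinitialize_engine) return;
  engine.~MockEngine();        // discard state inherited from the parent
  new (&engine) MockEngine();  // rebuild in place
  reinitialize_engine = false;
}

int main() {
  pthread_atfork(nullptr, nullptr, child_atfork);
  engine.run();
  if (fork() == 0) {  // child
    maybe_reinitialize_engine_after_fork();
    engine.run();
    _exit(0);
  }
  wait(nullptr);      // parent
  engine.run();
  return 0;
}
```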
diff --git a/torch/csrc/autograd/python_engine.h b/torch/csrc/autograd/python_engine.h | |
index 14627cd6..dd9300c8 100644 | |
--- a/torch/csrc/autograd/python_engine.h | |
+++ b/torch/csrc/autograd/python_engine.h | |
@@ -1,6 +1,8 @@ | |
#pragma once | |
#include <Python.h> | |
+ | |
+#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/engine.h" | |
bool THPEngine_initModule(PyObject *module); | |
@@ -10,12 +12,12 @@ namespace torch { namespace autograd { namespace python { | |
struct PythonEngine : public Engine { | |
virtual void thread_init(int device) override; | |
virtual void thread_on_exception(FunctionTask& task, std::exception& e) override; | |
- virtual void execute( | |
- const function_list& roots, | |
+ virtual variable_list execute( | |
+ const edge_list& roots, | |
const variable_list& inputs, | |
bool keep_graph, | |
- const pre_callback_map& pre_callbacks = pre_callback_map(), | |
- const post_callback_map& post_callbacks = post_callback_map()) override; | |
+ bool create_graph, | |
+ const edge_list& outputs = {}) override; | |
static PythonEngine& getDefaultEngine(); | |
}; | |
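
In the new interface, `execute` returns the gradients as a `variable_list` and takes the output `edge_list` directly, and `PythonEngine::execute` (see python_engine.cpp above) simply delegates to `Engine::execute` while translating Python errors. A standalone sketch of that override-and-delegate shape, using mock aliases rather than the real `edge_list`/`variable_list`:

```cpp
// Standalone sketch with mock types; the aliases below are for illustration
// only, not the real torch::autograd definitions.
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using variable_list = std::vector<double>;  // mock gradients
using edge_list = std::vector<int>;         // mock graph edges

struct BaseEngine {
  virtual ~BaseEngine() = default;
  virtual variable_list execute(const edge_list& roots,
                                const variable_list& inputs,
                                bool keep_graph,
                                bool create_graph,
                                const edge_list& outputs) {
    (void)roots; (void)keep_graph; (void)create_graph; (void)outputs;
    return inputs;  // pretend the inputs are the computed gradients
  }
};

struct TranslatingEngine : BaseEngine {
  variable_list execute(const edge_list& roots,
                        const variable_list& inputs,
                        bool keep_graph,
                        bool create_graph,
                        const edge_list& outputs) override {
    try {
      return BaseEngine::execute(roots, inputs, keep_graph, create_graph, outputs);
    } catch (const std::exception& e) {
      // The real PythonEngine restores the python_error state before
      // rethrowing; here we just annotate and rethrow.
      throw std::runtime_error(std::string("backward failed: ") + e.what());
    }
  }
};

int main() {
  TranslatingEngine engine;
  auto grads = engine.execute(/*roots=*/{0}, /*inputs=*/{1.0, 2.0},
                              /*keep_graph=*/false, /*create_graph=*/false,
                              /*outputs=*/{});
  std::cout << "received " << grads.size() << " gradient(s)\n";
  return 0;
}
```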
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp | |
index 68a2ae2f..7d84562f 100644 | |
--- a/torch/csrc/autograd/python_function.cpp | |
+++ b/torch/csrc/autograd/python_function.cpp | |
@@ -8,30 +8,25 @@ | |
#include <ATen/ATen.h> | |
#include "THP.h" | |
+#include "torch/csrc/autograd/grad_mode.h" | |
#include "torch/csrc/autograd/functions/accumulate_grad.h" | |
#include "torch/csrc/autograd/functions/basic_ops.h" | |
#include "torch/csrc/autograd/functions/utils.h" | |
#include "torch/csrc/autograd/python_cpp_function.h" | |
#include "torch/csrc/autograd/python_hook.h" | |
-#include "torch/csrc/jit/tracer.h" | |
#include "torch/csrc/autograd/saved_variable.h" | |
+#include "torch/csrc/jit/tracer.h" | |
#include "torch/csrc/DynamicTypes.h" | |
#include "torch/csrc/utils/auto_gil.h" | |
#include "torch/csrc/utils/auto_gpu.h" | |
#include "torch/csrc/Exceptions.h" | |
-#ifdef WITH_CUDA | |
-#include "cuda/AutoGPU.h" | |
-#endif | |
- | |
using namespace torch; | |
using namespace torch::autograd; | |
using namespace torch::jit; | |
using at::Tensor; | |
-PyObject *THPFunctionClass = NULL; | |
-PyObject *THPStochasticFunctionClass = NULL; | |
-PyObject *THPBatchNormBackwardBackwardFunction = NULL; | |
+PyObject *THPFunctionClass = nullptr; | |
#define THPFunction_assert(condition, ...) \ | |
if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); } | |
@@ -43,7 +38,7 @@ VariableInfo::VariableInfo(const Variable& var) | |
, device(-1) | |
, size(var.sizes()) | |
, requires_grad(var.requires_grad()) { | |
- if (var.type().isCuda()) { | |
+ if (var.type().is_cuda()) { | |
device = var.get_device(); | |
} | |
} | |
@@ -60,15 +55,7 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list { | |
if (!pyInputs) throw python_error(); | |
for (size_t i = 0; i != inputs.size(); ++i) { | |
- PyObject* input; | |
- if (inputs[i].defined()) { | |
- input = createPyObject(inputs[i].data()); | |
- if (!input) throw python_error(); | |
- } else { | |
- input = Py_None; | |
- Py_INCREF(input); | |
- } | |
- PyTuple_SET_ITEM(pyInputs.get(), i, input); | |
+ PyTuple_SET_ITEM(pyInputs.get(), i, THPVariable_Wrap(inputs[i])); | |
} | |
THPObjectPtr r(PyObject_CallMethod( | |
@@ -80,13 +67,13 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list { | |
for (int i = 0; i != num_outputs; ++i) { | |
PyObject* obj = PyTuple_GET_ITEM(r.get(), i); | |
if (obj != Py_None) { | |
- if (!THPModule_isTensor(obj)) { | |
- std::string msg("expected Tensor (got '"); | |
+ if (!THPVariable_Check(obj)) { | |
+ std::string msg("expected Variable (got '"); | |
msg += THPUtils_typename(obj); | |
msg += "')'"; | |
throw std::runtime_error(msg); | |
} | |
- tensor_results[i] = createTensor(obj); | |
+ tensor_results[i] = ((THPVariable*)obj)->cdata.data(); | |
} | |
} | |
@@ -96,9 +83,13 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list { | |
// leads to unexpected error messages ("no nodes require computing gradients"), | |
// but I don't have a better idea. These functions would raise an error | |
// in backward anyway. | |
- return wrap_outputs(inputs, std::move(tensor_results), [this](FunctionFlags &&f) { | |
- return std::make_shared<Error>(name() + " is not differentiable twice", std::move(f)); | |
- }); | |
+ return wrap_outputs( | |
+ inputs, | |
+ std::move(tensor_results), | |
+ [this](edge_list&& next_edges) { | |
+ return std::make_shared<Error>( | |
+ name() + " is not differentiable twice", std::move(next_edges)); | |
+ }); | |
} | |
// NOTE: this function is written in a way that assumes it's only called for backward; | |
@@ -208,7 +199,7 @@ auto PyFunction::is_traceable() -> bool { | |
return traceable_py_bool == Py_True; | |
} | |
-auto PyFunction::releaseVariables() -> void { | |
+auto PyFunction::release_variables() -> void { | |
AutoGIL gil; | |
auto f = (THPFunction*) obj; | |
f->saved_variables.clear(); | |
@@ -226,7 +217,7 @@ auto PyFunction::name() -> std::string { | |
return name; | |
} | |
-auto PyFunction::getSharedPtr() -> std::shared_ptr<Function> { | |
+auto PyFunction::get_shared_ptr() -> std::shared_ptr<Function> { | |
return THPFunction_asFunction((THPFunction*)obj); | |
} | |
@@ -235,18 +226,17 @@ auto PyFunction::getSharedPtr() -> std::shared_ptr<Function> { | |
// Traverse and clear are required for supporting Python's GC cycle handling. | |
static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg) | |
{ | |
- for (auto& hook : self->cdata.pre_hooks) { | |
+ for (const auto& hook : self->cdata.pre_hooks()) { | |
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) { | |
Py_VISIT(pyhook->dict); | |
} | |
} | |
- for (auto& hook : self->cdata.post_hooks) { | |
+ for (const auto& hook : self->cdata.post_hooks()) { | |
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) { | |
Py_VISIT(pyhook->dict); | |
} | |
} | |
Py_VISIT(self->to_save); | |
- Py_VISIT(self->shared_pairs); | |
Py_VISIT(self->non_differentiable); | |
Py_VISIT(self->dirty_tensors); | |
return 0; | |
@@ -254,12 +244,11 @@ static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg) | |
static int THPFunction_clear(THPFunction *self) | |
{ | |
- self->cdata.num_inputs = 0; | |
+ self->cdata.set_num_inputs(0); | |
Py_CLEAR(self->needs_input_grad); | |
Py_CLEAR(self->to_save); | |
- Py_CLEAR(self->shared_pairs); | |
Py_CLEAR(self->non_differentiable); | |
Py_CLEAR(self->dirty_tensors); | |
@@ -268,10 +257,13 @@ static int THPFunction_clear(THPFunction *self) | |
self->saved_variables.clear(); | |
self->is_variable_input.clear(); | |
- // XXX: this will clear all hooks (not only Python ones) | |
- // I guess it's ok to leave it as is for now. | |
- auto pre_hooks = std::move(self->cdata.pre_hooks); | |
- auto post_hooks = std::move(self->cdata.post_hooks); | |
+ // Moving the hooks out makes sure to first disassociate them from the | |
+ // function, but without destroying any of them. They will get deleted when | |
+ // exiting this scope. This is important, because deleting Python objects can | |
+ // trigger deletion of other objects, and they can reference this function, | |
+ // seeing it in a half-deleted state. | |
+ auto pre_hooks = std::move(self->cdata.pre_hooks()); | |
+ auto post_hooks = std::move(self->cdata.post_hooks()); | |
return 0; | |
} | |
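
The comment added above explains why the hooks are moved into locals rather than cleared in place: destroying them can run Python code that still looks at this function. A standalone sketch of the same "detach now, destroy after the owner is consistent" pattern (mock types, not the real hook classes):

```cpp
// Standalone sketch: hooks are detached immediately but destroyed only when
// the locals go out of scope, after the owner has finished clearing itself.
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct MockFunction;

struct MockHook {
  explicit MockHook(const MockFunction* owner) : owner(owner) {}
  ~MockHook();  // inspects the owner, like a Python hook being deallocated
  const MockFunction* owner;
};

struct MockFunction {
  std::vector<std::unique_ptr<MockHook>> pre_hooks;
  bool cleared = false;

  void clear() {
    // Detach first; destruction happens at the end of this scope.
    auto local_hooks = std::move(pre_hooks);
    pre_hooks.clear();
    cleared = true;
  }  // <- hooks destroyed here, with the function already in a sane state
};

MockHook::~MockHook() {
  std::cout << "hook destroyed: owner cleared=" << owner->cleared
            << ", remaining hooks=" << owner->pre_hooks.size() << "\n";
}

int main() {
  MockFunction fn;
  fn.pre_hooks.push_back(std::unique_ptr<MockHook>(new MockHook(&fn)));
  fn.clear();  // prints cleared=1, remaining hooks=0
  return 0;
}
```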
@@ -291,7 +283,7 @@ static void THPFunction_dealloc(THPFunction* self) | |
PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) | |
{ | |
PyObject* obj = type->tp_alloc(type, 0); | |
- if (!obj) return NULL; | |
+ if (!obj) return nullptr; | |
// Python zero-initializes the object memory, so there's no need to initialize | |
// most fields | |
THPFunction* self = (THPFunction*)obj; | |
@@ -300,8 +292,7 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) | |
new (&self->input_info) std::vector<VariableInfo>(); | |
new (&self->saved_variables) std::vector<SavedVariable>(); | |
new (&self->is_variable_input) std::vector<bool>(); | |
- self->cdata.num_inputs = -1; | |
- self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass); | |
+ self->cdata.set_num_inputs(0); | |
return obj; | |
} | |
@@ -313,67 +304,37 @@ using t2var_type = std::unordered_map<PyObject *, THPVariable *>; | |
// Bump the counters of all recorded dirty input tensors, adding each of them | |
// into dirty_inputs. Also does some sanity checking. | |
-static void _mark_dirty(THPFunction *self, t2var_type &t2var, | |
- std::unordered_set<PyObject *> &dirty_inputs) | |
+static std::vector<PyObject*> _mark_dirty(THPFunction *self) | |
{ | |
// Increase versions of modified tensors | |
- if (!self->dirty_tensors) return; | |
+ std::vector<PyObject*> dirty_inputs; | |
+ if (!self->dirty_tensors) return dirty_inputs; | |
THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd " | |
"internal error: dirty_tensors attribute is expected to be a tuple " | |
"but is %s", THPUtils_typename(self->dirty_tensors)); | |
Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors); | |
for (int i = 0; i < num_dirty; i++) { | |
- PyObject *tensor = PyTuple_GET_ITEM(self->dirty_tensors, i); | |
- dirty_inputs.insert(tensor); | |
- THPVariable *variable; | |
- try { | |
- variable = t2var.at(tensor); | |
- } catch (std::out_of_range &e) { | |
- THPFunction_assert(THPModule_isTensor(tensor), "mark_dirty can " | |
- "only accept tensors, but argument %d is of type %s", i, | |
- THPUtils_typename(tensor)); | |
- THPFunction_assert(false, "mark_dirty only accepts input tensors, but " | |
- "argument %d isn't one", i); | |
- } | |
- auto& version_counter = variable->cdata.version_counter(); | |
- THPFunction_assert(version_counter.live_refs() == 1, | |
- "in-place operations can be only used on variables that don't share " | |
- "storage with any other variables, but detected that there are %d " | |
- "objects sharing it", | |
- version_counter.live_refs()); | |
- version_counter.increment(); | |
+ PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i); | |
+ THPFunction_assert(THPVariable_Check(obj), "mark_dirty can " | |
+ "only accept variables, but argument %d is of type %s", i, | |
+ THPUtils_typename(obj)); | |
+ | |
+ dirty_inputs.push_back(obj); | |
+ auto variable = (THPVariable*)obj; | |
+ variable->cdata.bump_version(); | |
} | |
// We're not going to ever need this so let's remove references now | |
- Py_DECREF(self->dirty_tensors); | |
- self->dirty_tensors = NULL; | |
+ Py_CLEAR(self->dirty_tensors); | |
+ return dirty_inputs; | |
} | |
-static void _transplant_var(VariableImpl& var, const std::shared_ptr<Function>& fn, int output_nr, bool is_volatile) | |
-{ | |
- if (is_volatile) { | |
- var.grad_fn = nullptr; | |
- var.requires_grad = false; | |
- var.is_volatile = true; | |
- var.output_nr = 0; | |
- } else { | |
- var.grad_fn = fn; | |
- var.requires_grad = fn->is_executable; | |
- var.is_volatile = is_volatile; | |
- var.output_nr = output_nr; | |
- } | |
- var.grad.reset(); | |
- var.hooks.clear(); | |
- if (auto grad_acc_fn = var.grad_accumulator.lock()) { | |
- auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get()); | |
- grad_acc->variable.reset(); | |
- } | |
-} | |
+static std::unordered_set<PyObject*> _parse_non_differentiable(THPFunction *self); | |
// Given a Python tuple of raw output tensors (raw_output), set each of | |
// the corresponding entries in a different Python tuple (outputs) with | |
// these tensors wrapped with variables. We save the gradient function (self) | |
-// to the variable if the output is not volatile (is_volatile). | |
+// to the variable if the output requires grad. | |
// | |
// There is a considerable amount of complexity to handle if the operation | |
// that produced these output tensors is inplace. A mapping of *input* | |
@@ -381,92 +342,93 @@ static void _transplant_var(VariableImpl& var, const std::shared_ptr<Function>& | |
// the set of dirty tensors (dirty_inputs) is used to figure out what to | |
// do in this case. After this method is run, t2var is extended with | |
// mappings for output tensors as well. | |
-static void _wrap_outputs(THPFunction *self, t2var_type &t2var, | |
- std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output, | |
- PyObject *outputs, bool is_volatile) | |
+static void _wrap_outputs(THPFunction *self, | |
+ PyObject* inputs_tuple, PyObject *raw_output, PyObject *outputs, bool is_executable) | |
{ | |
- auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self); | |
+ auto cdata = is_executable ? THPFunction_asFunction(self) : nullptr; | |
Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output); | |
- if (self->cdata.is_executable) { | |
+ if (is_executable) { | |
self->output_info.clear(); | |
self->output_info.reserve(num_outputs); | |
} | |
- for (int i = 0; i < num_outputs; i++) { | |
- PyObject *output = PyTuple_GET_ITEM(raw_output, i); | |
- THPVariable *output_var; | |
- auto it = t2var.find(output); | |
- if (it == t2var.end()) { | |
- // A completely new tensor - just wrap it and continue | |
- if (is_volatile) { | |
- output_var = (THPVariable*)THPVariable_NewVolatile(output); | |
- } else { | |
- output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata); | |
+ | |
+ std::unordered_set<PyObject*> inputs; | |
+ int num_inputs = PyTuple_GET_SIZE(inputs_tuple); | |
+ for (int i = 0; i < num_inputs; i++) { | |
+ inputs.emplace(PyTuple_GET_ITEM(inputs_tuple, i)); | |
+ } | |
+ | |
+ auto non_differentiable = _parse_non_differentiable(self); | |
+ auto dirty_inputs = _mark_dirty(self); | |
+ | |
+ auto as_variable = [&](PyObject* obj, int i) -> Variable { | |
+ if (THPVariable_Check(obj)) { | |
+ return ((THPVariable*)obj)->cdata; | |
+ } | |
+ throw TypeError("%s.forward: expected Variable (got %s) for return value %d", | |
+ Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i); | |
+ }; | |
+ | |
+ // Sets the grad_fn and output_nr of an output Variable. | |
+ auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified, | |
+ bool is_differentiable) { | |
+ if (!is_differentiable) { | |
+ if (!var.requires_grad()) return; | |
+ // NB: we don't support returning non-differentiable views that could require grad | |
+ // (this could happen if someone were to return an input to the function). | |
+ if (var.is_view()) { | |
+ throw std::runtime_error("Returning Variables sharing storage with other Variables " | |
+ "that require grad is not supported in Python functions. " | |
+ "Please submit a feature request if you hit this error."); | |
} | |
- } else { | |
- // If one of the outputs was also an input tensor it's a bit more complicated. | |
- THPVariable *input_var = it->second; | |
- auto& input_var_ = input_var->cdata; | |
- if (input_var_.grad_fn()) { | |
- Py_INCREF(input_var); | |
- output_var = input_var; | |
- // If it's not a leaf we want to move it in the graph so backprop | |
- // will be computed correctly, but only if it was modified. Otherwise | |
- // it's better to minimize the number of operations that mutate the graph. | |
- // grad_fn <- variable <- self ==> grad_fn <- self <- variable | |
- if (dirty_inputs.count(output) > 0) { | |
- _transplant_var(*input_var_.get(), cdata, i, is_volatile); | |
- } | |
- } else { | |
- // If the leaf Variable has been returned, we have to move it after the | |
- // current function to ensure the gradient is computed correctly. | |
- // There are two cases now: | |
- // 1. It has been modified in-place. If it didn't require_grad it's ok, | |
- // but if it does, then it's a clear error. | |
- // 2. It hasn't been modified. This means that it must have been | |
- // returned unchanged, and we can simply return a new Variable | |
- // referencing the same storage. | |
- if (dirty_inputs.count(output) > 0) { | |
- if (!input_var_.requires_grad()) { | |
- Py_INCREF(input_var); | |
- output_var = input_var; | |
- _transplant_var(*input_var_.get(), cdata, i, is_volatile); | |
- } else { // input_var_.requires_grad | |
- throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation."); | |
- } | |
- } else { | |
- // An input has been returned, but it wasn't modified. It's better | |
- // not to move the Variable, because there are some legitimate cases | |
- // where making it non-leaf would break stuff (e.g. broadcast). Also, | |
- // returning the input Variable is not a good option either, | |
- // because if someone registers hooks on it, they will fire with grads | |
- // from all usages, not only from usages of this output. This is why | |
- // we'll return a copy and join their version counters. This has | |
- // a side-effect of making in-place ops on any of these Variables an | |
- // immediate error, but it would be raised anyway once someone | |
- // calls backward. | |
- if (is_volatile) { | |
- output_var = (THPVariable*)THPVariable_NewVolatile(output); | |
- } else { | |
- output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata); | |
- } | |
- if (!output_var) throw python_error(); | |
- output_var->cdata.version_counter() = input_var->cdata.version_counter(); | |
- } | |
+ var.detach_(); | |
+ } else if (is_modified) { | |
+ if (var.is_leaf() && var.requires_grad()) { | |
+ throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation."); | |
} | |
+ // If the input was modified, transplant the grad_fn in the graph: | |
+ // grad_fn <- variable <- self ==> grad_fn <- self <- variable | |
+ var.reset_grad(); | |
+ var.clear_hooks(); | |
+ if (auto grad_acc_fn = var.try_get_grad_accumulator()) { | |
+ auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get()); | |
+ grad_acc->variable.reset(); | |
+ } | |
+ if (cdata) { | |
+ var.rebase_history({cdata, output_nr}); | |
+ } | |
+ } else if (is_input) { | |
+ // An input has been returned, but it wasn't modified. Return it as a view | |
+ // so that we can attach a new grad_fn to the Variable. | |
+ var = var.slice(); | |
+ var.set_gradient_edge({cdata, output_nr}); | |
+ } else if (cdata) { | |
+ var.set_gradient_edge({cdata, output_nr}); | |
} | |
- if (!output_var) throw python_error(); | |
+ }; | |
+ | |
+ for (int i = 0; i < num_outputs; i++) { | |
+ PyObject* obj = PyTuple_GET_ITEM(raw_output, i); | |
+ | |
+ bool is_input = inputs.count(obj) > 0; | |
+ bool is_modified = std::find(dirty_inputs.begin(), dirty_inputs.end(), obj) != dirty_inputs.end(); | |
+ bool is_differentiable = is_executable && non_differentiable.count(obj) == 0; | |
- if (self->cdata.is_executable) { | |
- self->output_info.emplace_back(output_var->cdata); | |
+ // Note that output Variables may be repeated. In that case, the last call | |
+ // to set_history wins. | |
+ auto var = as_variable(obj, i); | |
+ set_history(var, i, is_input, is_modified, is_differentiable); | |
+ | |
+ if (is_executable) { | |
+ self->output_info.emplace_back(var); | |
} | |
- t2var[output] = output_var; | |
- output_var->cdata.get()->output_nr = i; | |
- PyTuple_SET_ITEM(outputs, i, (PyObject*)output_var); | |
+ | |
+ PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(var)); | |
} | |
} | |
// Save any variables requested by to_save | |
-static void _save_variables(THPFunction* self, t2var_type &t2var) | |
+static void _save_variables(THPFunction* self) | |
{ | |
if (!self->to_save) return; | |
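
For reference, the branch structure of the `set_history` lambda introduced in the `_wrap_outputs` hunk above can be summarized with a toy model. Everything below is illustrative only; the real code calls `detach_`, `rebase_history`, and `set_gradient_edge` on `Variable` instead of writing a string.

```cpp
// Toy model of the per-output bookkeeping in the new _wrap_outputs.
#include <iostream>
#include <stdexcept>
#include <string>

struct MockVariable {
  bool requires_grad = false;
  bool is_view = false;
  bool is_leaf = true;
  std::string history = "none";
};

void set_history(MockVariable& var, int output_nr, bool is_input,
                 bool is_modified, bool is_differentiable) {
  if (!is_differentiable) {
    if (!var.requires_grad) return;
    if (var.is_view) {
      throw std::runtime_error("non-differentiable view outputs are not supported");
    }
    var.history = "detached";
  } else if (is_modified) {
    if (var.is_leaf && var.requires_grad) {
      throw std::runtime_error("leaf Variable modified in-place");
    }
    var.history = "rebased onto output " + std::to_string(output_nr);
  } else if (is_input) {
    var.history = "view of input, edge to output " + std::to_string(output_nr);
  } else {
    var.history = "edge to output " + std::to_string(output_nr);
  }
}

int main() {
  MockVariable fresh, returned_input;
  set_history(fresh, 0, /*is_input=*/false, /*is_modified=*/false, /*is_differentiable=*/true);
  set_history(returned_input, 1, /*is_input=*/true, /*is_modified=*/false, /*is_differentiable=*/true);
  std::cout << fresh.history << "\n" << returned_input.history << "\n";
  return 0;
}
```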
@@ -478,110 +440,54 @@ static void _save_variables(THPFunction* self, t2var_type &t2var) | |
self->saved_variables.reserve(num_saved); | |
auto cdata_ptr = &self->cdata; | |
for (int i = 0; i < num_saved; i++) { | |
- PyObject *tensor = PyTuple_GET_ITEM(self->to_save, i); | |
- if (tensor == Py_None) { | |
+ PyObject *obj = PyTuple_GET_ITEM(self->to_save, i); | |
+ if (obj == Py_None) { | |
self->saved_variables.emplace_back(); | |
continue; | |
+ } else if (THPVariable_Check(obj)) { | |
+ auto variable = (THPVariable*)obj; | |
+ bool is_output = variable->cdata.grad_fn().get() == cdata_ptr; | |
+ self->saved_variables.emplace_back(variable->cdata, is_output); | |
+ } else { | |
+ throw TypeError( | |
+ "save_for_backward can only save variables, but argument %d is of " | |
+ "type %s", i, Py_TYPE(obj)->tp_name); | |
} | |
- | |
- THPVariable *variable; | |
- try { | |
- variable = t2var.at(tensor); | |
- } catch(std::out_of_range &e) { | |
- THPFunction_assert(THPModule_isTensor(tensor), | |
- "save_for_backward can only save tensors, but argument %d is of " | |
- "type %s", i, THPUtils_typename(tensor)); | |
- THPFunction_assert(false, "save_for_backward can only save input or output " | |
- "tensors, but argument %d doesn't satisfy this condition", i); | |
- } | |
- | |
- self->saved_variables.emplace_back(variable->cdata, cdata_ptr); | |
} | |
// Free .to_save | |
- Py_DECREF(self->to_save); | |
- self->to_save = NULL; | |
-} | |
- | |
-// t2var maps input and output tensors to variables | |
-static void _join_version_counters(THPFunction *self, t2var_type &t2var) | |
-{ | |
- if (!self->shared_pairs) return; | |
- THPFunction_assert(PyTuple_Check(self->shared_pairs), "autograd internal " | |
- "error: shared_pairs attribute is expected to be a tuple but is %s", | |
- THPUtils_typename(self->shared_pairs)); | |
- Py_ssize_t num_shared = PyTuple_GET_SIZE(self->shared_pairs); | |
- for (int i = 0; i < num_shared; i++) { | |
- PyObject *shared_tuple = PyTuple_GET_ITEM(self->shared_pairs, i); | |
- THPFunction_assert(PyTuple_Check(shared_tuple), "mark_shared_storages " | |
- "accepts a number of pairs, but one of the arguments is of type %s", | |
- THPUtils_typename(shared_tuple)); | |
- THPFunction_assert(PyTuple_GET_SIZE(shared_tuple) == 2, | |
- "mark_shared_storages accepts pairs, but argument %d is a tuple of " | |
- "%d elements", i, PyTuple_GET_SIZE(shared_tuple)); | |
- | |
- // Now we're sure it's really a pair! | |
- THPVariable *v1, *v2; | |
- try { | |
- // NB: According to the documentation, v1 is an input tensor, and v2 | |
- // is an output tensor, but we don't actually check this | |
- v1 = t2var.at(PyTuple_GET_ITEM(shared_tuple, 0)); | |
- v2 = t2var.at(PyTuple_GET_ITEM(shared_tuple, 1)); | |
- } catch(std::out_of_range &e) { | |
- // One tuple items wasn't present in t2var, so there are two cases: | |
- // 1. it's not a tensor | |
- // 2. it's not an input nor an output | |
- PyObject *t1 = PyTuple_GET_ITEM(shared_tuple, 0); | |
- PyObject *t2 = PyTuple_GET_ITEM(shared_tuple, 1); | |
- THPFunction_assert(THPModule_isTensor(t1) && THPModule_isTensor(t2), | |
- "mark_shared_storages accepts pairs of tensors, but one of them " | |
- "contains %s and %s", THPUtils_typename(t1), THPUtils_typename(t2)); | |
- THPFunction_assert(false, "mark_shared_storages only accepts pairs of input " | |
- "and output tensors, but argument %d doesn't satify this " | |
- "condition", i); | |
- } | |
- v2->cdata.version_counter() = v1->cdata.version_counter(); | |
- } | |
- // Free .shared_pairs | |
- Py_DECREF(self->shared_pairs); | |
- self->shared_pairs = NULL; | |
+ Py_CLEAR(self->to_save); | |
} | |
// Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable) | |
-static void _mark_non_differentiable(THPFunction *self, t2var_type &t2var) | |
+static std::unordered_set<PyObject*> | |
+_parse_non_differentiable(THPFunction *self) | |
{ | |
- if (!self->non_differentiable) return; | |
+ std::unordered_set<PyObject*> set; | |
+ if (!self->non_differentiable) return set; | |
THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd " | |
"internal error: non_differentiable attribute is expected to be a " | |
"tuple but is %s", THPUtils_typename(self->non_differentiable)); | |
Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable); | |
+ set.reserve(num_nondiff); | |
for (int i = 0; i < num_nondiff; i++) { | |
PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i); | |
- THPVariable *var; | |
- try { | |
- var = t2var.at(t); | |
- THPFunction_assert(var->cdata.grad_fn().get() == &self->cdata, | |
- "mark_non_differentiable only accepts output tensors, but " | |
- "argument %d isn't an output", i); | |
- } catch (std::out_of_range &e) { | |
- THPFunction_assert(THPModule_isTensor(t), "mark_non_differentiable " | |
- "only accepts tensor arguments, but got %s", THPUtils_typename(t)); | |
- THPFunction_assert(false, "mark_non_differentiable only accepts function " | |
- "outputs"); | |
- } | |
- var->cdata.requires_grad() = false; | |
+ THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable " | |
+ "only accepts variable arguments, but got %s", THPUtils_typename(t)); | |
+ set.insert(t); | |
} | |
- Py_DECREF(self->non_differentiable); | |
- self->non_differentiable = NULL; | |
+ Py_CLEAR(self->non_differentiable); | |
+ return set; | |
} | |
struct UnpackedInput { | |
- THPObjectPtr tensor_input; | |
+ THPObjectPtr input_tuple; | |
variable_list input_vars; | |
}; | |
struct InputFlags { | |
- FunctionFlags flags; | |
+ bool is_executable = false; | |
+ edge_list next_edges; | |
THPObjectPtr needs_input_grad; | |
std::vector<bool> is_variable_input; | |
}; | |
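
`InputFlags` now carries an explicit `is_executable` flag and the pre-collected `next_edges` in place of the old `FunctionFlags`. A standalone sketch of what collecting those edges amounts to; `MockEdge`, `MockVariable`, `MockNode`, and the two helpers are illustrative stand-ins (in the real code, leaves that require grad get their grad accumulator as the edge target, and `GradMode::is_enabled()` also gates `is_executable`).

```cpp
// Standalone sketch: gather one (grad_fn, output_nr) edge per input and
// decide whether the backward node is worth constructing at all.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct MockNode { const char* name; };

struct MockEdge {
  std::shared_ptr<MockNode> function;  // node the gradient flows into
  uint32_t input_nr;                   // which input of that node
};
using mock_edge_list = std::vector<MockEdge>;

struct MockVariable {
  bool requires_grad;
  std::shared_ptr<MockNode> grad_fn;
  uint32_t output_nr;
};

mock_edge_list collect_next_edges(const std::vector<MockVariable>& inputs) {
  mock_edge_list edges;
  edges.reserve(inputs.size());
  for (const auto& v : inputs) {
    edges.push_back(MockEdge{v.grad_fn, v.output_nr});
  }
  return edges;
}

bool any_variable_requires_grad(const std::vector<MockVariable>& inputs) {
  for (const auto& v : inputs) {
    if (v.requires_grad) return true;
  }
  return false;
}

int main() {
  auto producer = std::make_shared<MockNode>(MockNode{"SomeBackward"});
  std::vector<MockVariable> inputs{
      {true, producer, 1},   // produced by another function
      {false, nullptr, 0}};  // leaf input that does not require grad

  bool is_executable = any_variable_requires_grad(inputs);  // grad-mode check omitted
  auto next_edges = collect_next_edges(inputs);
  std::cout << "is_executable=" << is_executable
            << ", next_edges=" << next_edges.size() << "\n";
  return 0;
}
```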
@@ -592,60 +498,53 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) { | |
InputFlags flags; | |
auto num_args = PyTuple_GET_SIZE(args); | |
- unpacked.tensor_input = PyTuple_New(num_args); | |
+ unpacked.input_tuple = PyTuple_New(num_args); | |
flags.needs_input_grad = PyTuple_New(num_args); | |
for (int i = 0; i < num_args; i++) { | |
PyObject *arg = PyTuple_GET_ITEM(args, i); | |
- PyObject *new_arg; | |
bool is_variable = THPVariable_Check(arg); | |
flags.is_variable_input.push_back(is_variable); | |
if (!is_variable) { | |
+ // TODO: remove this code path once Variable and Tensor are merged in Python | |
if (enforce_variables) { | |
THPUtils_setError("expected a Variable argument, but got %s", | |
THPUtils_typename(arg)); | |
throw python_error(); | |
} | |
- Py_INCREF(arg); | |
- new_arg = arg; | |
Py_INCREF(Py_False); | |
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False); | |
} else { | |
THPVariable* variable = (THPVariable*)arg; | |
- new_arg = THPVariable_get_data(variable); | |
unpacked.input_vars.push_back(variable->cdata); | |
PyObject* needs_grad = variable->cdata.requires_grad() ? Py_True : Py_False; | |
Py_INCREF(needs_grad); | |
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad); | |
} | |
- PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg); | |
+ Py_INCREF(arg); | |
+ PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg); | |
} | |
- flags.flags = Function::flags(unpacked.input_vars); | |
+ flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars); | |
+ flags.next_edges = collect_next_edges(unpacked.input_vars); | |
return std::make_pair(std::move(unpacked), std::move(flags)); | |
} | |
-static void _trace_create(PyObject* op_obj, THPFunction* bw_obj, | |
- PyObject *input_objects, PyObject *output_objects, | |
- const variable_list& input_vars, bool is_inplace) { | |
- if (!tracer::isTracing(input_vars)) | |
- return; | |
- | |
- if (!op_obj) { | |
+static void _assert_not_tracing(const char* name, const variable_list& input_vars) { | |
+ if (tracer::isTracingVar(input_vars)) { | |
std::ostringstream oss; | |
- oss << "Attempted to trace " << Py_TYPE(bw_obj)->tp_name; | |
+ oss << "Attempted to trace " << name; | |
oss << ", but tracing of legacy functions is not supported"; | |
throw std::runtime_error(oss.str()); | |
} | |
+} | |
- auto tracing_state = tracer::getTracingState(input_vars); | |
- bw_obj->is_traced = true; | |
- | |
- // Isolate C variable ptrs in a vector | |
- variable_list output_vars; | |
- for (int i = 0; i < PyTuple_GET_SIZE(output_objects); ++i) { | |
- THPVariable *var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i); | |
- output_vars.emplace_back(var->cdata); | |
+static jit::tracer::PreTraceInfo _trace_pre_record( | |
+ PyObject* op_obj, | |
+ PyObject *input_objects, | |
+ const variable_list& input_vars) { | |
+ if (!tracer::isTracingVar(input_vars)) { | |
+ return jit::tracer::PreTraceInfo(); | |
} | |
// Save scalar args and the calling convention | |
@@ -665,49 +564,37 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj, | |
} | |
} | |
- auto state_lock = tracing_state->lock(); | |
- | |
- // Note [getValueTrace can allocate nodes] | |
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
- // When an input variable is not traced, we create a constant instruction | |
- // to represent it. This means that you must invoke getValueTrace() BEFORE | |
- // actually constructing the function that takes these variables as inputs. | |
- // If we do it the other order, the graph will be in the wrong topological | |
- // order. | |
- | |
- // See Note [getValueTrace can allocate nodes] | |
- std::vector<Node*> value_traces; | |
- value_traces.reserve(input_vars.size()); | |
- for (auto& i : input_vars) | |
- value_traces.emplace_back(tracer::getValueTrace(tracing_state, i)); | |
+ Py_INCREF(op_obj); | |
+ auto pyobj = THPObjectPtr(op_obj); | |
+ return jit::tracer::preRecordPythonTrace( | |
+ std::move(pyobj), | |
+ std::move(arg_types), | |
+ input_vars, | |
+ std::move(scalar_args)); | |
+} | |
- // NB: this function is called only from THPFunction_apply, which is used only | |
- // when computing forward. All these functions are non-traceable by definition, | |
- // because they are implemented in terms of tensor operations. Hence, there's no | |
- // need for any conditionals in here and we can always create the node. | |
+static void _trace_post_record( | |
+ const jit::tracer::PreTraceInfo& trace_info, | |
+ PyObject* op_obj, | |
+ const variable_list& input_vars, | |
+ PyObject *output_objects, | |
+ bool is_inplace) { | |
+ if (!trace_info.state) { | |
+ return; | |
+ } | |
- // Construct the IR Node and its Selects | |
- Py_INCREF(op_obj); | |
- auto& graph = tracing_state->graph; | |
- auto this_expr = graph->appendNode(graph->createPythonOp( | |
- THPObjectPtr(op_obj), | |
- arg_types, | |
- false, // TODO: remove is_legacy | |
- std::move(scalar_args))); | |
- for (auto t : value_traces) | |
- this_expr->addInput(t); | |
- | |
- int num_outputs = output_vars.size(); | |
+ // Isolate C variable ptrs in a vector | |
+ int num_outputs = PyTuple_GET_SIZE(output_objects); | |
+ variable_list output_vars(num_outputs); | |
for (int i = 0; i < num_outputs; ++i) { | |
- auto& output = output_vars[i]; | |
- // NOTE: normally we don't add Select nodes when there's only a single | |
- // output, but Python nodes can't be optimized away, so we simplify the | |
- // code here. | |
- Node* sel = graph->appendNode(graph->createSelect(this_expr, i)); | |
- sel->inferTypeFrom(output.data()); | |
- tracer::setValueTrace(tracing_state, output, sel); | |
+ auto var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i); | |
+ output_vars[i] = var->cdata; | |
} | |
- this_expr->i_(kinplace, is_inplace); | |
+ | |
+ jit::tracer::postRecordTrace(trace_info, output_vars); | |
+ | |
+ auto state_lock = trace_info.state->lock(); | |
+ trace_info.n->i_(kinplace, is_inplace); | |
// See definition in function.cpp. | |
THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")}; | |
@@ -717,12 +604,15 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj, | |
// tracing_state->in_eval_subgraph (it's always false, because they are never part of backward | |
// subgraphs AND we don't even materialize the forward function). | |
if (!passes_state_transparently) { | |
+ // TODO: sgross and ezyang don't know if this is right | |
tracer::nontraceableBackwardSubgraph(input_vars, output_vars); | |
- Function::setUpContextEdge(this_expr, num_outputs, input_vars, output_vars); | |
+ Function::set_up_context_edge(trace_info.n, input_vars, output_vars); | |
} | |
} | |
-PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked, PyObject *inputs, THPObjectPtr&& raw_output, bool is_volatile) { | |
+PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked, | |
+ PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable, | |
+ const jit::tracer::PreTraceInfo& trace_info) { | |
bool unpack_output = ensure_tuple(raw_output); | |
auto num_outputs = PyTuple_GET_SIZE(raw_output.get()); | |
@@ -730,10 +620,10 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked | |
THPObjectPtr outputs(PyTuple_New(num_outputs)); | |
if (!outputs) throw python_error(); | |
- grad_fn->cdata.num_inputs = num_outputs; | |
+ grad_fn->cdata.set_num_inputs(num_outputs); | |
// Record type, device, and size information about inputs | |
- if (grad_fn->cdata.is_executable) { | |
+ if (is_executable) { | |
grad_fn->input_info.clear(); | |
grad_fn->input_info.reserve(unpacked.input_vars.size()); | |
for (auto& var : unpacked.input_vars) { | |
@@ -741,36 +631,22 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked | |
} | |
} | |
- // Initialize t2var map with input tensors | |
- t2var_type t2var; | |
- for (auto& c_var : unpacked.input_vars) { | |
- THPVariable* py_var = (THPVariable*)c_var.get()->pyobj; | |
- t2var.emplace(py_var->data, py_var); | |
- } | |
- | |
- std::unordered_set<PyObject *> dirty_inputs; | |
bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors); | |
- _mark_dirty(grad_fn, t2var, dirty_inputs); | |
- _wrap_outputs(grad_fn, t2var, dirty_inputs, raw_output, outputs, is_volatile); | |
- // At this point, t2var contains output tensors as well | |
- _join_version_counters(grad_fn, t2var); | |
- if (grad_fn->cdata.is_executable) { | |
- _mark_non_differentiable(grad_fn, t2var); | |
- } | |
- // NOTE: _trace_create has to run before _save_variables, because we need | |
+ _wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable); | |
+ // NOTE: _trace_post_record has to run before _save_variables, because we need | |
// to assign traces to outputs before we convert them to SavedVariables. | |
// On the other hand, it needs to go after _mark_non_differentiable, because | |
// it might be wrapping backwards in Evals, and _mark_non_differentiable uses | |
// grad_fn pointer equality for error checking. | |
- _trace_create(op_obj, grad_fn, inputs, outputs, unpacked.input_vars, is_inplace); | |
- if (grad_fn->cdata.is_executable) { | |
- _save_variables(grad_fn, t2var); | |
+ _trace_post_record(trace_info, op_obj, unpacked.input_vars, outputs, is_inplace); | |
+ if (is_executable) { | |
+ _save_variables(grad_fn); | |
} else { | |
// Remove unnecessary attributes | |
Py_XDECREF(grad_fn->to_save); | |
- grad_fn->to_save = NULL; | |
+ grad_fn->to_save = nullptr; | |
Py_XDECREF(grad_fn->non_differentiable); | |
- grad_fn->non_differentiable = NULL; | |
+ grad_fn->non_differentiable = nullptr; | |
} | |
// Unpack the output, unless .forward() returned a tuple | |
@@ -792,17 +668,25 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs) | |
auto info_pair = unpack_input<true>(_inputs); | |
auto& unpacked_input = info_pair.first; | |
auto& input_info = info_pair.second; | |
- bool is_volatile = input_info.flags.is_volatile; | |
- self->cdata.set_flags(std::move(input_info.flags)); | |
+ bool is_executable = input_info.is_executable; | |
+ self->cdata.set_next_edges(std::move(input_info.next_edges)); | |
self->needs_input_grad = input_info.needs_input_grad.release(); | |
+ // We don't support tracing in the legacy code path | |
+ _assert_not_tracing(Py_TYPE(self)->tp_name, unpacked_input.input_vars); | |
+ | |
// Now we're ready to call a forward (implemented in Python) | |
- THPObjectPtr forward_fn(PyObject_GetAttrString((PyObject*)self, "forward")); | |
- if (!forward_fn) return NULL; | |
- THPObjectPtr raw_output(PyObject_CallObject(forward_fn, unpacked_input.tensor_input)); | |
- if (!raw_output) return NULL; | |
+ THPObjectPtr raw_output; | |
+ { | |
+ AutoGradMode grad_mode(false); | |
+ THPObjectPtr forward_fn(PyObject_GetAttrString((PyObject*)self, "forward")); | |
+ if (!forward_fn) return nullptr; | |
+ raw_output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple); | |
+ if (!raw_output) return nullptr; | |
+ } | |
- return process_outputs(nullptr, self, unpacked_input, _inputs, std::move(raw_output), is_volatile); | |
+ return process_outputs(nullptr, self, unpacked_input, _inputs, std::move(raw_output), | |
+ is_executable, jit::tracer::PreTraceInfo()); | |
END_HANDLE_TH_ERRORS | |
} | |
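
Both `THPFunction_do_forward` (above) and `THPFunction_apply` (below) now call the user's `forward` under `AutoGradMode grad_mode(false)`, so the work done inside `forward` is not itself recorded. A standalone sketch of such a thread-local flag with an RAII guard; `MockGradMode` and `MockAutoGradMode` are illustrative, not the real `torch::autograd::GradMode`.

```cpp
// Standalone sketch of a thread-local grad-mode flag with an RAII guard that
// restores the previous value on scope exit.
#include <iostream>

struct MockGradMode {
  static bool is_enabled() { return enabled; }
  static void set_enabled(bool value) { enabled = value; }
 private:
  static thread_local bool enabled;
};
thread_local bool MockGradMode::enabled = true;

struct MockAutoGradMode {
  explicit MockAutoGradMode(bool enabled) : prev(MockGradMode::is_enabled()) {
    MockGradMode::set_enabled(enabled);
  }
  ~MockAutoGradMode() { MockGradMode::set_enabled(prev); }
  bool prev;
};

int main() {
  std::cout << "before forward: " << MockGradMode::is_enabled() << "\n";  // 1
  {
    MockAutoGradMode guard(false);  // mirrors `AutoGradMode grad_mode(false)`
    // ... the Python forward would run here, with recording disabled ...
    std::cout << "inside forward: " << MockGradMode::is_enabled() << "\n";  // 0
  }
  std::cout << "after forward: " << MockGradMode::is_enabled() << "\n";  // 1
  return 0;
}
```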
@@ -812,9 +696,9 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) | |
torch::autograd::profiler::RecordFunction record(((PyTypeObject*)cls)->tp_name); | |
THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls")); | |
- if (!backward_cls) return NULL; | |
- THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, NULL)); | |
- if (!ctx_obj) return NULL; | |
+ if (!backward_cls) return nullptr; | |
+ THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr)); | |
+ if (!ctx_obj) return nullptr; | |
THPFunction* ctx = (THPFunction*)ctx_obj.get(); | |
// Prepare inputs and allocate context (grad fn) | |
@@ -822,32 +706,41 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) | |
UnpackedInput& unpacked_input = info_pair.first; | |
InputFlags& input_info = info_pair.second; | |
+ // Record input nodes if tracing | |
+ auto trace_info = _trace_pre_record(cls, inputs, unpacked_input.input_vars); | |
+ if (trace_info.state) { | |
+ // TODO: ezyang suggests this is unused and can be removed | |
+ ctx->is_traced = true; | |
+ } | |
+ | |
// Initialize backward function (and ctx) | |
- bool is_volatile = input_info.flags.is_volatile; | |
- ctx->cdata.set_flags(std::move(input_info.flags)); | |
+ bool is_executable = input_info.is_executable; | |
+ ctx->cdata.set_next_edges(std::move(input_info.next_edges)); | |
ctx->needs_input_grad = input_info.needs_input_grad.release(); | |
ctx->is_variable_input = std::move(input_info.is_variable_input); | |
- // Prepend ctx to tensor_input, in preparation for static method call | |
+ // Prepend ctx to input_tuple, in preparation for static method call | |
auto num_args = PyTuple_GET_SIZE(inputs); | |
- THPObjectPtr ctx_tensor_input(PyTuple_New(num_args + 1)); | |
- PyTuple_SET_ITEM(ctx_tensor_input.get(), 0, ctx_obj.release()); | |
+ THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1)); | |
+ PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, ctx_obj.release()); | |
for (int i = 0; i < num_args; ++i) { | |
- PyObject *arg = PyTuple_GET_ITEM(unpacked_input.tensor_input.get(), i); | |
+ PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i); | |
Py_INCREF(arg); | |
- PyTuple_SET_ITEM(ctx_tensor_input.get(), i + 1, arg); | |
+ PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg); | |
} | |
// Call forward | |
- THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); | |
- if (!forward_fn) return NULL; | |
- THPObjectPtr tensor_outputs(PyObject_CallObject(forward_fn, ctx_tensor_input)); | |
- if (!tensor_outputs) return NULL; | |
- | |
- THPObjectPtr outputs {process_outputs(cls, ctx, unpacked_input, inputs, | |
- std::move(tensor_outputs), is_volatile)}; | |
+ THPObjectPtr tensor_outputs; | |
+ { | |
+ AutoGradMode grad_mode(false); | |
+ THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); | |
+ if (!forward_fn) return nullptr; | |
+ tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple); | |
+ if (!tensor_outputs) return nullptr; | |
+ } | |
- return outputs.release(); | |
+ return process_outputs(cls, ctx, unpacked_input, inputs, std::move(tensor_outputs), | |
+ is_executable, trace_info); | |
END_HANDLE_TH_ERRORS | |
} | |
@@ -856,51 +749,52 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) | |
// Backward | |
//////////////////////////////////////////////////////////////////////////////// | |
-static void _prepare_grad_output(THPFunction *self, THPObjectPtr& raw_grad_output) | |
+static void _prepare_grads(THPFunction *self, THPObjectPtr& raw_grads, bool is_grad_output) | |
{ | |
AutoGPU gpu_guard(-1); | |
- int num_grad_output = PyTuple_GET_SIZE(raw_grad_output.get()); | |
- // First, check if any of grad_outputs is None. If not, there's nothing to do | |
+ int num_grads = PyTuple_GET_SIZE(raw_grads.get()); | |
+ // First, check if any of the grads is None. If not, there's nothing to do | |
bool has_none = false; | |
- for (int i = 0; i < num_grad_output; i++) { | |
- has_none |= PyTuple_GET_ITEM(raw_grad_output.get(), i) == Py_None; | |
+ for (int i = 0; i < num_grads; i++) { | |
+ has_none |= PyTuple_GET_ITEM(raw_grads.get(), i) == Py_None; | |
} | |
if (!has_none) | |
return; | |
- THPObjectPtr grad_output; | |
- grad_output = PyTuple_New(num_grad_output); | |
- if (!grad_output) throw python_error(); | |
+ THPObjectPtr grads; | |
+ grads = PyTuple_New(num_grads); | |
+ if (!grads) throw python_error(); | |
// Look for Nones and replace them with new buffers | |
- auto& output_info = self->output_info; | |
- for (int i = 0; i < num_grad_output; i++) { | |
- PyObject *grad = PyTuple_GET_ITEM(raw_grad_output.get(), i); | |
+ auto& grads_info = is_grad_output ? self->output_info : self->input_info; | |
+ TORCH_ASSERT(grads_info.size() == (size_t)num_grads); | |
+ for (int i = 0; i < num_grads; i++) { | |
+ PyObject *grad = PyTuple_GET_ITEM(raw_grads.get(), i); | |
if (grad == Py_None) { | |
- grad = createPyObject(output_info[i].zeros(gpu_guard).data()); | |
+ grad = THPVariable_Wrap(grads_info[i].zeros(gpu_guard)); | |
if (!grad) throw python_error(); | |
} else { | |
Py_INCREF(grad); | |
} | |
- PyTuple_SET_ITEM(grad_output.get(), i, grad); | |
+ PyTuple_SET_ITEM(grads.get(), i, grad); | |
} | |
- raw_grad_output = grad_output.release(); | |
+ raw_grads = grads.release(); | |
} | |
static void _trim_grad_input(THPFunction *self, THPObjectPtr& grad_input) | |
{ | |
int num_grads = PyTuple_GET_SIZE(grad_input.get()); | |
- int num_next_fns = self->cdata.next_functions.size(); | |
- if (num_grads > num_next_fns) { | |
+ const int num_outputs = self->cdata.num_outputs(); | |
+ if (num_grads > num_outputs) { | |
// Check that all extra grads are none | |
bool all_none = true; | |
- for (int i = num_next_fns; i < num_grads; i++) { | |
+ for (int i = num_outputs; i < num_grads; i++) { | |
all_none = (PyTuple_GET_ITEM(grad_input.get(), i) == Py_None); | |
if (!all_none) break; | |
} | |
// If yes, slice the tuple | |
if (all_none) { | |
- num_grads = num_next_fns; | |
+ num_grads = num_outputs; | |
grad_input = PyTuple_GetSlice(grad_input.get(), 0, num_grads); | |
if (!grad_input) throw python_error(); | |
} | |
@@ -915,44 +809,46 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args) | |
PyObject *raw_grad_output = PyTuple_GET_ITEM(args, 0); | |
PyObject *retain_variables = PyTuple_GET_ITEM(args, 1); | |
if (!PyTuple_Check(raw_grad_output) || !PyBool_Check(retain_variables)) { | |
- THPUtils_invalidArguments(args, NULL, "_do_backward", 1, "(tuple, bool)"); | |
- return NULL; | |
+ THPUtils_invalidArguments(args, nullptr, "_do_backward", 1, "(tuple, bool)"); | |
+ return nullptr; | |
} | |
- THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs, | |
+ THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs(), | |
"%s got an invalid number of gradients (expected %d got %d)", | |
- THPUtils_typename(self), self->cdata.num_inputs, | |
+ THPUtils_typename(self), self->cdata.num_inputs(), | |
PyTuple_GET_SIZE(raw_grad_output)); | |
// Some of the output might have been unused, so we have to allocate | |
// zero-filled buffers instead | |
Py_INCREF(raw_grad_output); | |
THPObjectPtr grad_output(raw_grad_output); | |
- _prepare_grad_output(self, grad_output); | |
+ _prepare_grads(self, grad_output, true); | |
// self.backward(*grad_output) | |
THPObjectPtr backward_fn(PyObject_GetAttrString((PyObject*)self, "backward")); | |
THPUtils_assert(backward_fn.get(), "function %s doesn't implement a required " | |
"'backward' method", THPUtils_typename((PyObject*)self)); | |
THPObjectPtr grad_input(PyObject_CallObject(backward_fn, grad_output.get())); | |
- if (!grad_input) return NULL; | |
+ if (!grad_input) return nullptr; | |
ensure_tuple(grad_input); | |
// We allow functions to return more gradients than there were outputs, | |
// if and only if the additional ones are all None | |
_trim_grad_input(self, grad_input); | |
int num_grads = PyTuple_GET_SIZE(grad_input.get()); | |
- int num_next_fns = self->cdata.next_functions.size(); | |
- THPUtils_assert(num_grads == num_next_fns, "%s returned an invalid number of " | |
+ int num_outputs = self->cdata.num_outputs(); | |
+ THPUtils_assert(num_grads == num_outputs, "%s returned an invalid number of " | |
"gradient tensors (expected %d, but got %d)", THPUtils_typename(self), | |
- num_next_fns, num_grads); | |
+ num_outputs, num_grads); | |
+ // If any of the remaining grad_inputs are None, zero them. | |
+ _prepare_grads(self, grad_input, false); | |
return grad_input.release(); | |
} catch (python_error& e) { | |
- return NULL; | |
+ return nullptr; | |
} catch (std::exception& e) { | |
THPUtils_setError(e.what()); | |
- return NULL; | |
+ return nullptr; | |
} | |
} | |
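
On the backward side, missing gradients are now handled in both directions: None entries in grad_output are replaced with zero-filled buffers before calling the Python backward (_prepare_grads with is_grad_output set), extra trailing None gradients returned by backward are trimmed down to num_outputs() by _trim_grad_input, and any None entries left in grad_input are zero-filled as well. The following self-contained sketch restates those two rules over std::optional<double> instead of tensors; the helper names are hypothetical.

// Sketch only: zero-fill missing gradients and trim trailing missing extras.
#include <cassert>
#include <optional>
#include <vector>

using Grads = std::vector<std::optional<double>>;

// _prepare_grads analogue: replace every missing gradient with a zero buffer.
void zero_fill_missing(Grads& grads) {
  for (auto& g : grads) {
    if (!g) g = 0.0;
  }
}

// _trim_grad_input analogue: gradients beyond num_outputs are accepted only if
// they are all missing, in which case they are dropped.
void trim_extra(Grads& grads, size_t num_outputs) {
  if (grads.size() <= num_outputs) return;
  for (size_t i = num_outputs; i < grads.size(); ++i) {
    if (grads[i]) return;  // a real extra gradient is left for the size check
  }
  grads.resize(num_outputs);
}

int main() {
  Grads g{std::nullopt, 2.0, std::nullopt, std::nullopt};
  trim_extra(g, 2);      // trailing missing grads beyond the two outputs are dropped
  zero_fill_missing(g);  // the remaining missing grad becomes a zero buffer
  assert(g.size() == 2 && g[0] == 0.0 && g[1] == 2.0);
  return 0;
}
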
@@ -964,7 +860,9 @@ PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var) | |
{ | |
THPUtils_assert(THPVariable_Check(_var), "_register_hook_dict expected a variable"); | |
THPVariable *var = (THPVariable*)_var; | |
- self->cdata.pre_hooks.emplace_back(new PyFunctionPreHook(var->backward_hooks, var->cdata.output_nr())); | |
+ std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook( | |
+ var->backward_hooks, var->cdata.output_nr())); | |
+ self->cdata.add_pre_hook(std::move(hook)); | |
Py_RETURN_NONE; | |
} | |
@@ -985,7 +883,7 @@ static PyObject *unpack_saved_variables( | |
int num_saved = saved_variables.size(); | |
THPObjectPtr saved(PyTuple_New(num_saved)); | |
if (!saved) | |
- return NULL; | |
+ return nullptr; | |
auto saved_for = THPFunction_asFunction(self); | |
for (int i = 0; i < num_saved; i++) { | |
auto unpacked_var = saved_variables[i].unpack(saved_for); | |
@@ -1003,32 +901,36 @@ static PyObject *unpack_saved_variables( | |
PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused) | |
{ | |
+ HANDLE_TH_ERRORS | |
return unpack_saved_variables(self, [](const Variable& var) { | |
- return createPyObject(var.data()); | |
+ return THPVariable_Wrap(var); | |
}); | |
+ END_HANDLE_TH_ERRORS | |
} | |
PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused) | |
{ | |
+ HANDLE_TH_ERRORS | |
return unpack_saved_variables(self, [](const Variable& var) { | |
return THPVariable_Wrap(var); | |
}); | |
+ END_HANDLE_TH_ERRORS | |
} | |
PyObject *THPFunction_next_functions(THPFunction *self, void *_unused) | |
{ | |
- auto& next_fns = self->cdata.next_functions; | |
- int size = next_fns.size(); | |
- THPObjectPtr result(PyTuple_New(size)); | |
+ const auto num_outputs = self->cdata.num_outputs(); | |
+ THPObjectPtr result(PyTuple_New(num_outputs)); | |
if (!result) | |
- return NULL; | |
- for (int i = 0; i < size; i++) { | |
+ return nullptr; | |
+ for (uint32_t i = 0; i < num_outputs; i++) { | |
THPObjectPtr fn_tuple(PyTuple_New(2)); | |
- if (!fn_tuple) return NULL; | |
- PyObject* fn = functionToPyObject(next_fns[i].first); | |
- if (!fn) return NULL; | |
+ if (!fn_tuple) return nullptr; | |
+ const auto& edge = self->cdata.next_edge(i); | |
+ PyObject* fn = functionToPyObject(edge.function); | |
+ if (!fn) return nullptr; | |
PyTuple_SET_ITEM(fn_tuple.get(), 0, fn); | |
- PyTuple_SET_ITEM(fn_tuple.get(), 1, PyInt_FromLong(next_fns[i].second)); | |
+ PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr)); | |
PyTuple_SET_ITEM(result.get(), i, fn_tuple.release()); | |
} | |
return result.release(); | |
@@ -1075,43 +977,36 @@ PyObject* getImplMember(PyObject* obj, void* _unused) { | |
return Convert(self->cdata.*ptr); | |
} | |
-int setRequiresGrad(PyObject* obj, PyObject* value, void* _unused) { | |
- auto self = (THPFunction*)obj; | |
- if (!PyBool_Check(value)) { | |
- PyErr_Format(PyExc_TypeError, "'is_executable' must be a bool"); | |
- return -1; | |
- } | |
- self->cdata.is_executable = (value == Py_True); | |
- return 0; | |
+PyObject* getRequiresGrad(PyObject* obj, void* _unused) { | |
+ Py_RETURN_TRUE; | |
} | |
} | |
static struct PyGetSetDef THPFunction_properties[] = { | |
- {"saved_tensors", (getter)THPFunction_saved_tensors, NULL, NULL, NULL}, | |
- {"saved_variables", (getter)THPFunction_saved_variables, NULL, NULL, NULL}, | |
- {"next_functions", (getter)THPFunction_next_functions, NULL, NULL, NULL}, | |
- {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, NULL, NULL}, | |
- {"shared_pairs", &getObject<&THPFunction::shared_pairs>, &setObject<&THPFunction::shared_pairs>, NULL, NULL}, | |
- {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, NULL, NULL}, | |
- {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, NULL, NULL}, | |
- {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, NULL, NULL, NULL}, | |
- {"requires_grad", &getImplMember<bool, &Function::is_executable, PyBool_FromLong>, &setRequiresGrad, NULL, NULL}, | |
- {"_is_tracing", &getMember<char, &THPFunction::is_traced, PyBool_FromLong>, NULL, NULL, NULL}, | |
- {NULL} | |
+ {"saved_tensors", (getter)THPFunction_saved_tensors, nullptr, nullptr, nullptr}, | |
+ {"saved_variables", (getter)THPFunction_saved_variables, nullptr, nullptr, nullptr}, | |
+ {"next_functions", (getter)THPFunction_next_functions, nullptr, nullptr, nullptr}, | |
+ {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, nullptr, nullptr}, | |
+ {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, nullptr, nullptr}, | |
+ {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr}, | |
+ {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr}, | |
+ {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr}, | |
+ {"_is_tracing", &getMember<char, &THPFunction::is_traced, PyBool_FromLong>, nullptr, nullptr, nullptr}, | |
+ {nullptr} | |
}; | |
static struct PyMethodDef THPFunction_methods[] = { | |
- {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, NULL}, | |
- {(char*)"_do_forward", (PyCFunction)THPFunction_do_forward, METH_VARARGS, NULL}, | |
- {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, NULL}, | |
- {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, NULL}, | |
- {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, NULL}, | |
- {NULL} | |
+ {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr}, | |
+ {(char*)"_do_forward", (PyCFunction)THPFunction_do_forward, METH_VARARGS, nullptr}, | |
+ {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, nullptr}, | |
+ {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, nullptr}, | |
+ {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, nullptr}, | |
+ {nullptr} | |
}; | |
PyTypeObject THPFunctionType = { | |
- PyVarObject_HEAD_INIT(NULL, 0) | |
+ PyVarObject_HEAD_INIT(nullptr, 0) | |
"torch._C._FunctionBase", /* tp_name */ | |
sizeof(THPFunction), /* tp_basicsize */ | |
0, /* tp_itemsize */ | |
@@ -1131,7 +1026,7 @@ PyTypeObject THPFunctionType = { | |
0, /* tp_setattro */ | |
0, /* tp_as_buffer */ | |
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ | |
- NULL, /* tp_doc */ | |
+ nullptr, /* tp_doc */ | |
(traverseproc)THPFunction_traverse, /* tp_traverse */ | |
(inquiry)THPFunction_clear, /* tp_clear */ | |
0, /* tp_richcompare */ | |
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h | |
index 837b0add..72e894a3 100644 | |
--- a/torch/csrc/autograd/python_function.h | |
+++ b/torch/csrc/autograd/python_function.h | |
@@ -34,9 +34,9 @@ struct PyFunction : public Function { | |
virtual variable_list apply(const variable_list& inputs) override; | |
variable_list legacy_apply(const variable_list& inputs); | |
- virtual void releaseVariables() override; | |
+ virtual void release_variables() override; | |
virtual std::string name() override; | |
- virtual std::shared_ptr<Function> getSharedPtr() override; | |
+ virtual std::shared_ptr<Function> get_shared_ptr() override; | |
virtual bool is_traceable() override; | |
// THPFunction this Function is wrapping. | |
@@ -66,19 +66,15 @@ struct THPFunction { | |
PyObject *needs_input_grad; | |
// Python tuple of tensors whose variables we should save. Set | |
- // by Python with 'save_for_backward'. If NULL, no tensors were | |
+ // by Python with 'save_for_backward'. If nullptr, no tensors were | |
// saved. | |
PyObject *to_save; | |
- // Python pairs of distinct tensors which share storage. Set by | |
- // Python with 'mark_shared_storage'. If NULL, no tensors share | |
- // storage. | |
- PyObject *shared_pairs; | |
// Python tuple of tensors which are not differentiable. Set by | |
- // Python with 'mark_non_differentiable'. If NULL, no tensors were | |
+ // Python with 'mark_non_differentiable'. If nullptr, no tensors were | |
// non-differentiable. | |
PyObject *non_differentiable; | |
// Python tuple of tensors which had inplace updates in the forward() | |
- // pass. Set by Python with 'mark_dirty'. If NULL, no tensors were | |
+ // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were | |
// modified inplace. | |
PyObject *dirty_tensors; | |
@@ -98,8 +94,6 @@ struct THPFunction { | |
bool THPFunction_initModule(PyObject *module); | |
extern PyTypeObject THPFunctionType; | |
extern PyObject *THPFunctionClass; | |
-extern PyObject *THPStochasticFunctionClass; | |
-extern PyObject *THPBatchNormBackwardBackwardFunction; // Temporarily here until we move it to C++ | |
// XXX: this function requires the GIL (it can have side effects). | |
std::shared_ptr<torch::autograd::PyFunction> THPFunction_asFunction(THPFunction* self); | |
diff --git a/torch/csrc/autograd/python_hook.cpp b/torch/csrc/autograd/python_hook.cpp | |
index c87669aa..3ceb1f4a 100644 | |
--- a/torch/csrc/autograd/python_hook.cpp | |
+++ b/torch/csrc/autograd/python_hook.cpp | |
@@ -169,11 +169,11 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject | |
throw std::runtime_error(ss.str()); | |
} | |
- if (original.type().isCuda() != result.type().isCuda()) { | |
+ if (original.type().is_cuda() != result.type().is_cuda()) { | |
std::stringstream ss; | |
auto name = hook_name(hook); | |
ss << "hook '" << name << "' has changed the type of value"; | |
- if (original.type().isCuda()) { | |
+ if (original.type().is_cuda()) { | |
ss << " (was CUDA tensor got CPU tensor)"; | |
} else { | |
ss << " (was CPU tensor got CUDA tensor)"; | |
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp | |
index 40321519..ecb59008 100644 | |
--- a/torch/csrc/autograd/python_variable.cpp | |
+++ b/torch/csrc/autograd/python_variable.cpp | |
@@ -1,23 +1,38 @@ | |
#include "torch/csrc/autograd/python_variable.h" | |
-#include <structmember.h> | |
- | |
#include "THP.h" | |
#include "torch/csrc/DynamicTypes.h" | |
+#include "torch/csrc/Exceptions.h" | |
+#include "torch/csrc/Size.h" | |
#include "torch/csrc/Types.h" | |
+#include "torch/csrc/autograd/edge.h" | |
#include "torch/csrc/autograd/python_cpp_function.h" | |
#include "torch/csrc/autograd/python_hook.h" | |
+#include "torch/csrc/autograd/python_variable_indexing.h" | |
+#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/autograd/functions/accumulate_grad.h" | |
-#include "torch/csrc/cuda/AutoGPU.h" | |
+#include "torch/csrc/autograd/function.h" | |
+#include "torch/csrc/autograd/generated/VariableType.h" | |
+#include "torch/csrc/autograd/utils/wrap_outputs.h" | |
+#include "torch/csrc/jit/tracer_state.h" | |
+#include "torch/csrc/tensor/python_tensor.h" | |
#include "torch/csrc/utils/auto_gil.h" | |
-#include "torch/csrc/Exceptions.h" | |
-#include "torch/csrc/autograd/variable.h" | |
+#include "torch/csrc/utils/python_strings.h" | |
+ | |
+#include <ATen/ATen.h> | |
+#include <list> | |
+#include <memory> | |
+#include <structmember.h> | |
using namespace at; | |
using namespace torch::autograd; | |
-PyObject *THPVariableClass = NULL; | |
+PyObject *THPVariableClass = nullptr; | |
+ | |
+static const char* VOLATILE_WARNING = | |
+ "volatile was removed and now has no effect. Use " | |
+ "`with torch.no_grad():` instead."; | |
// Creates a new Python object for a Variable. The Variable must not already | |
// have a PyObject* associated with it. | |
@@ -27,11 +42,13 @@ static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var) | |
if (obj) { | |
auto v = (THPVariable*) obj; | |
new (&v->cdata) Variable(std::move(var)); | |
- v->cdata.get()->pyobj = obj; | |
- if (auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn().get())) { | |
+ v->cdata.set_pyobj(obj); | |
+ if (auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn_unsafe())) { | |
// Create a new reference to the THPFunction. This ensures that ref count | |
// of the THPFunction is at least the number of referring THPVariables. | |
- v->cdata.grad_fn() = THPFunction_asFunction((THPFunction*)fn->obj); | |
+ const auto output_nr = v->cdata.output_nr(); | |
+ auto grad_fn = THPFunction_asFunction((THPFunction*)fn->obj); | |
+ v->cdata.set_gradient_edge({std::move(grad_fn), output_nr}); | |
} | |
} | |
return obj; | |
@@ -43,11 +60,7 @@ PyObject * THPVariable_Wrap(Variable var) | |
Py_RETURN_NONE; | |
} | |
- if (var.dim() == 0) { | |
- throw std::runtime_error("Variable API does not support Scalars"); | |
- } | |
- | |
- if (auto obj = var.get()->pyobj) { | |
+ if (auto obj = var.pyobj()) { | |
Py_INCREF(obj); | |
return obj; | |
} | |
@@ -55,50 +68,8 @@ PyObject * THPVariable_Wrap(Variable var) | |
return THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, std::move(var)); | |
} | |
-// This function DOES NOT steal a reference to data | |
-PyObject * THPVariable_NewWithFunction(PyObject *data, const std::shared_ptr<torch::autograd::Function>& grad_fn) | |
-{ | |
- THPUtils_assert(THPModule_isTensor(data), "data must be a Tensor"); | |
- | |
- Variable v = make_variable(torch::createTensor(data)); | |
- v.requires_grad() = grad_fn->is_executable; | |
- v.grad_fn() = grad_fn; | |
- | |
- PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v)); | |
- if (obj) { | |
- ((THPVariable*)obj)->data = data; | |
- Py_INCREF(data); | |
- } | |
- return obj; | |
-} | |
- | |
-// This function DOES NOT steal a reference to data | |
-PyObject * THPVariable_NewVolatile(PyObject *data) | |
-{ | |
- Variable v = make_variable(torch::createTensor(data), false, true); | |
- PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v)); | |
- if (obj) { | |
- ((THPVariable*)obj)->data = data; | |
- Py_INCREF(data); | |
- } | |
- return obj; | |
-} | |
- | |
-// This function DOES NOT steal a reference to data | |
-PyObject * THPVariable_NewLeaf(PyObject *data) | |
-{ | |
- Variable v = make_variable(torch::createTensor(data)); | |
- PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v)); | |
- if (obj) { | |
- ((THPVariable*)obj)->data = data; | |
- Py_INCREF(data); | |
- } | |
- return obj; | |
-} | |
- | |
static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg) | |
{ | |
- Py_VISIT(self->data); | |
Py_VISIT(self->backward_hooks); | |
// We don't want to traverse the grad_fn, even if the Variable owns it and the | |
// shared pointer's use count is 1. This is because we would need to treat | |
@@ -115,7 +86,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg) | |
// for more details about the race condition involving traversing the grad_fn | |
// and the python GC. | |
if (self->cdata.defined()) { | |
- for (auto& hook : self->cdata.hooks()) { | |
+ for (const auto& hook : self->cdata.hooks()) { | |
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) { | |
Py_VISIT(pyhook->dict); | |
} | |
@@ -126,13 +97,12 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg) | |
static int THPVariable_clear(THPVariable *self) | |
{ | |
- Py_CLEAR(self->data); | |
Py_CLEAR(self->backward_hooks); | |
if (self->cdata.defined()) { | |
- if (auto grad_acc = self->cdata.get()->grad_accumulator.lock()) { | |
- grad_acc->pre_hooks.clear(); | |
+ if (auto grad_acc = self->cdata.try_get_grad_accumulator()) { | |
+ grad_acc->pre_hooks().clear(); | |
} | |
- self->cdata.get()->pyobj = nullptr; | |
+ self->cdata.set_pyobj(nullptr); | |
} | |
self->cdata.reset(); | |
return 0; | |
@@ -148,25 +118,24 @@ static void THPVariable_dealloc(THPVariable* self) | |
PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) | |
{ | |
+ HANDLE_TH_ERRORS | |
THPObjectPtr _data; | |
- PyObject *data = NULL; | |
- PyObject *grad_fn = NULL; | |
+ PyObject *data = nullptr; | |
+ PyObject *grad_fn = nullptr; | |
char is_volatile = 0; | |
char requires_grad = 0; | |
+ const char* name = nullptr; | |
- const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", NULL}; | |
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbO", (char**)accepted_args, | |
- &data, &requires_grad, &is_volatile, &grad_fn)) | |
- return NULL; | |
+ const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr}; | |
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", (char**)accepted_args, | |
+ &data, &requires_grad, &is_volatile, &grad_fn, &name)) | |
+ return nullptr; | |
if (grad_fn == Py_None) | |
- grad_fn = NULL; | |
+ grad_fn = nullptr; | |
- if (data == NULL || data == Py_None) { | |
- // For legacy serialization code, create an empty tensor temporarily. | |
- at::Tensor tensor = at::CPU(at::kFloat).tensor(); | |
- _data = torch::createPyObject(tensor); | |
- data = _data.get(); | |
+ if (is_volatile) { | |
+ PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1); | |
} | |
THPUtils_assert(!(is_volatile && requires_grad), | |
@@ -174,23 +143,34 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) | |
THPUtils_assert(!grad_fn || THPFunction_Check(grad_fn), | |
"Variable _grad_fn has to be a Function object or None, but got %s", | |
THPUtils_typename(grad_fn)); | |
- THPUtils_assert(THPModule_isTensor(data), "Variable data has to " | |
- "be a tensor, but got %s", THPUtils_typename(data)); | |
+ Tensor tensor; | |
+ if (!data || data == Py_None) { | |
+ // For legacy serialization code, create an empty tensor. This is also used | |
+ // by nn.Parameter() with no arguments. | |
+ auto var = torch::tensor::get_default_tensor_type().tensor(); | |
+ tensor = static_cast<Variable&>(var).data(); | |
+ } else if (THPVariable_Check(data)) { | |
+ tensor = ((THPVariable*)data)->cdata.data(); | |
+ } else { | |
+ throw torch::TypeError("Variable data has to be a tensor, but got %s", | |
+ THPUtils_typename(data)); | |
+ } | |
Variable var; | |
if (grad_fn) { | |
auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn); | |
- var = make_variable(torch::createTensor(data), grad_fn_); | |
+ Edge edge(grad_fn_, grad_fn_->bump_inputs()); | |
+ var = make_variable(std::move(tensor), std::move(edge)); | |
} else { | |
- var = make_variable(torch::createTensor(data), requires_grad, is_volatile); | |
+ var = make_variable(std::move(tensor), requires_grad); | |
} | |
- PyObject* self = THPVariable_NewWithVar(type, std::move(var)); | |
- if (self) { | |
- ((THPVariable*)self)->data = data; | |
- Py_INCREF(data); | |
+ if (name) { | |
+ var.set_name(name); | |
} | |
- return self; | |
+ | |
+ return THPVariable_NewWithVar(type, std::move(var)); | |
+ END_HANDLE_TH_ERRORS | |
} | |
int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds) | |
@@ -199,13 +179,14 @@ int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds) | |
// The 'data' argument is optional in __new__ to handle legacy serialized | |
// Variables. | |
PyObject *data; | |
- PyObject *grad_fn = NULL; | |
+ PyObject *grad_fn = nullptr; | |
char is_volatile = 0; | |
char requires_grad = 0; | |
+ const char* name = nullptr; | |
- const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", NULL}; | |
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbO", (char**)accepted_args, | |
- &data, &requires_grad, &is_volatile, &grad_fn)) | |
+ const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr}; | |
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", (char**)accepted_args, | |
+ &data, &requires_grad, &is_volatile, &grad_fn, &name)) | |
return -1; | |
return 0; | |
@@ -214,151 +195,144 @@ int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds) | |
typedef PyObject *(*getter)(PyObject *, void *); | |
typedef int (*setter)(PyObject *, PyObject *, void *); | |
+PyObject *THPVariable_get_cdata(THPVariable *self) | |
+{ | |
+ HANDLE_TH_ERRORS | |
+ auto& var = self->cdata; | |
+ return PyLong_FromVoidPtr(var.unsafeGetTH(false)); | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
PyObject *THPVariable_get_version(THPVariable *self) | |
{ | |
+ HANDLE_TH_ERRORS | |
auto& var = self->cdata; | |
return PyInt_FromLong(var.current_version()); | |
+ END_HANDLE_TH_ERRORS | |
} | |
PyObject *THPVariable_get_grad_fn(THPVariable *self) | |
{ | |
+ HANDLE_TH_ERRORS | |
auto& var = self->cdata; | |
if (!var.grad_fn()) { | |
Py_RETURN_NONE; | |
} | |
return functionToPyObject(var.grad_fn()); | |
+ END_HANDLE_TH_ERRORS | |
} | |
-int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj) | |
+static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj) | |
{ | |
+ HANDLE_TH_ERRORS | |
THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None"); | |
- self->cdata.grad_fn() = nullptr; | |
+ self->cdata.detach_(); | |
return 0; | |
+ END_HANDLE_TH_ERRORS_RET(-1) | |
} | |
-PyObject *THPVariable_is_leaf(THPVariable *self) | |
+static PyObject *THPVariable_is_leaf(THPVariable *self) | |
{ | |
+ HANDLE_TH_ERRORS | |
return PyBool_FromLong(!self->cdata.grad_fn()); | |
+ END_HANDLE_TH_ERRORS | |
} | |
-PyObject * THPVariable_get_data(THPVariable *self) | |
+static PyObject * THPVariable_get_data(THPVariable *self) | |
{ | |
- if (!self->data) { | |
- self->data = torch::createPyObject(self->cdata.data()); | |
- } | |
- Py_XINCREF(self->data); | |
- return self->data; | |
-} | |
- | |
-namespace { | |
- | |
-// XXX: This is a hack to access private TensorImpl::type_ | |
-// http://bloglitb.blogspot.com/2011/12/access-to-private-members-safer.html | |
-// This is currently needed because module.float() changes the type of the | |
-// data field of each variable. We should fix this and not allow changing the | |
-// type of var.data. | |
- | |
-template<typename Tag, typename Tag::type M> | |
-struct Rob { | |
- friend typename Tag::type get(Tag) { | |
- return M; | |
- } | |
-}; | |
- | |
-struct TensorImpl_Type { | |
- typedef Type* TensorImpl::*type; | |
- friend type get(TensorImpl_Type); | |
-}; | |
- | |
-template struct Rob<TensorImpl_Type, &TensorImpl::type_>; | |
- | |
+ HANDLE_TH_ERRORS | |
+ return THPVariable_Wrap(make_variable(self->cdata.data(), false)); | |
+ END_HANDLE_TH_ERRORS | |
} | |
int THPVariable_set_data(THPVariable *self, PyObject *data) | |
{ | |
- THPUtils_assertRet(-1, THPModule_isTensor(data), "Variable data has to " | |
- "be a tensor, but got %s", THPUtils_typename(data)); | |
- Py_INCREF(data); | |
- Py_XDECREF(self->data); | |
- self->data = data; | |
- Tensor tensor = torch::createTensor(data); | |
- if (&self->cdata.data().type() != &tensor.type()) { | |
+ HANDLE_TH_ERRORS | |
+ if (!THPVariable_Check(data)) { | |
+ throw torch::TypeError("Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name); | |
+ } | |
+ Tensor tensor = THPVariable_UnpackData(data); | |
+ if (self->cdata.data().type() != tensor.type()) { | |
// we change the type of var.data so we must change the type of var | |
- auto newType = VariableImpl::getType(tensor); | |
- self->cdata.get()->*get(TensorImpl_Type()) = newType; | |
+ auto newType = VariableType::getType(tensor); | |
+ self->cdata.temporary_hack_set_type(newType); | |
} | |
- self->cdata.data() = tensor; | |
+ self->cdata.data() = std::move(tensor); | |
return 0; | |
+ END_HANDLE_TH_ERRORS_RET(-1) | |
} | |
PyObject *THPVariable_get_grad(THPVariable *self) | |
{ | |
+ HANDLE_TH_ERRORS | |
return THPVariable_Wrap(self->cdata.grad()); | |
+ END_HANDLE_TH_ERRORS | |
} | |
-int THPVariable_set_grad(THPVariable *self, PyObject *other) | |
+int THPVariable_set_grad(THPVariable *self, PyObject *py_grad) | |
{ | |
+ HANDLE_TH_ERRORS | |
auto& var = self->cdata; | |
- if (other == Py_None) { | |
- var.grad().reset(); | |
+ if (py_grad == Py_None) { | |
+ var.reset_grad(); | |
return 0; | |
} | |
- THPUtils_assertRet(-1, THPVariable_Check(other), | |
- "expected Variable or None (got %s)", THPUtils_typename(other)); | |
- THPUtils_assertRet(-1, self != (THPVariable*)other, | |
+ THPUtils_assertRet(-1, THPVariable_Check(py_grad), | |
+ "expected Variable or None (got %s)", THPUtils_typename(py_grad)); | |
+ THPUtils_assertRet(-1, self != (THPVariable*)py_grad, | |
"can't assign Variable as its own grad"); | |
- auto& data = var.data(); | |
- auto& other_var = ((THPVariable*)other)->cdata; | |
- auto& other_data = other_var.data(); | |
+ auto& grad = ((THPVariable*)py_grad)->cdata; | |
+ auto& sparseType = var.type().toBackend(var.is_cuda() ? kSparseCUDA : kSparseCPU); | |
- // Make sure the data is ok | |
- THPUtils_assertRet(-1, other_data.type().ID() == data.type().ID(), | |
+ THPUtils_assertRet(-1, grad.type() == var.type() || grad.type() == sparseType, | |
"assigned grad has data of a different type"); | |
- THPUtils_assertRet(-1, other_data.type().isCuda() == data.type().isCuda(), | |
- "assigned grad has data located on a different device"); | |
- if (data.type().isCuda()) { | |
- THPUtils_assertRet(-1, other_data.get_device() == data.get_device(), | |
+ if (var.type().is_cuda()) { | |
+ THPUtils_assertRet(-1, grad.get_device() == var.get_device(), | |
"assigned grad has data located on a different device"); | |
} | |
- THPUtils_assertRet(-1, other_data.sizes().vec() == data.sizes().vec(), | |
+ THPUtils_assertRet(-1, grad.sizes().equals(var.sizes()), | |
"assigned grad has data of a different size"); | |
- var.grad() = other_var; | |
+ var.grad() = grad; | |
return 0; | |
+ END_HANDLE_TH_ERRORS_RET(-1) | |
} | |
PyObject *THPVariable_get_volatile(THPVariable *self) | |
{ | |
- auto& var = self->cdata; | |
- return PyBool_FromLong(var.is_volatile()); | |
+ const char* msg = "volatile was removed (Variable.volatile is always False)"; | |
+ PyErr_WarnEx(PyExc_UserWarning, msg, 1); | |
+ Py_RETURN_FALSE; | |
} | |
int THPVariable_set_volatile(THPVariable *self, PyObject *obj) | |
{ | |
- THPUtils_assertRet(-1, PyBool_Check(obj), "volatile must be a bool"); | |
- THPUtils_assertRet(-1, !self->cdata.grad_fn(), | |
- "volatile can only be set on leaf variables"); | |
- self->cdata.is_volatile() = (obj == Py_True); | |
- return 0; | |
+ return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1); | |
} | |
PyObject *THPVariable_get_output_nr(THPVariable *self) | |
{ | |
- return PyInt_FromLong(self->cdata.output_nr()); | |
+ HANDLE_TH_ERRORS | |
+ const auto output_nr = static_cast<long>(self->cdata.output_nr()); | |
+ return PyInt_FromLong(output_nr); | |
+ END_HANDLE_TH_ERRORS | |
} | |
PyObject *THPVariable_get_requires_grad(THPVariable *self) | |
{ | |
+ HANDLE_TH_ERRORS | |
return PyBool_FromLong(self->cdata.requires_grad()); | |
+ END_HANDLE_TH_ERRORS | |
} | |
int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj) | |
{ | |
+ HANDLE_TH_ERRORS | |
THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool"); | |
auto& var = self->cdata; | |
- if (var.grad_fn()) { | |
+ if (!var.is_leaf()) { | |
const char *hint = ""; | |
if (obj == Py_False) { | |
hint = " If you want to use a computed variable in a subgraph " | |
@@ -368,54 +342,119 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj) | |
THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint); | |
return -1; | |
} | |
- var.requires_grad() = (obj == Py_True); | |
- if (auto grad_accumulator = var.get()->grad_accumulator.lock()) { | |
- grad_accumulator->is_executable = var.requires_grad(); | |
- } | |
+ var.set_requires_grad(obj == Py_True); | |
return 0; | |
+ END_HANDLE_TH_ERRORS_RET(-1) | |
+} | |
+ | |
+PyObject *THPVariable_get_name(THPVariable* self) | |
+{ | |
+ if (self->cdata.name() == "") | |
+ Py_RETURN_NONE; | |
+ return THPUtils_packString(self->cdata.name().c_str()); | |
} | |
PyObject *THPVariable_get_backwards_hooks(THPVariable *self) | |
{ | |
+ HANDLE_TH_ERRORS | |
if (self->backward_hooks) { | |
Py_INCREF(self->backward_hooks); | |
return self->backward_hooks; | |
} | |
Py_RETURN_NONE; | |
+ END_HANDLE_TH_ERRORS | |
} | |
int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj) | |
{ | |
+ HANDLE_TH_ERRORS | |
if (obj == Py_None) { | |
obj = nullptr; | |
} | |
Py_XINCREF(obj); | |
Py_XDECREF(self->backward_hooks); | |
self->backward_hooks = obj; | |
- self->cdata.hooks().clear(); | |
+ self->cdata.clear_hooks(); | |
if (obj) { | |
- self->cdata.hooks().emplace_back(new PyFunctionPreHook(obj, 0)); | |
+ self->cdata.add_hook(std::make_shared<PyFunctionPreHook>(obj, 0)); | |
} | |
return 0; | |
+ END_HANDLE_TH_ERRORS_RET(-1) | |
+} | |
+ | |
+PyObject *THPVariable_get_base(THPVariable *self) | |
+{ | |
+ HANDLE_TH_ERRORS | |
+ if (self->cdata.is_view()) { | |
+ return THPVariable_Wrap(self->cdata.base()); | |
+ } | |
+ Py_RETURN_NONE; | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
+PyObject *THPVariable_get_shape(THPVariable *self) | |
+{ | |
+ HANDLE_TH_ERRORS | |
+ auto& self_ = self->cdata; | |
+ auto sizes = self_.sizes(); | |
+ return THPSize_New(sizes.size(), (int64_t *)sizes.data()); | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
+PyObject *THPVariable_is_cuda(THPVariable *self) | |
+{ | |
+ HANDLE_TH_ERRORS | |
+ auto& self_ = self->cdata; | |
+ return torch::autograd::utils::wrap(self_.is_cuda()); | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
+PyObject *THPVariable_is_sparse(THPVariable *self) | |
+{ | |
+ HANDLE_TH_ERRORS | |
+ auto& self_ = self->cdata; | |
+ return torch::autograd::utils::wrap(self_.is_sparse()); | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
+PyObject *THPVariable_dtype(THPVariable *self) | |
+{ | |
+ HANDLE_TH_ERRORS | |
+ auto& self_ = self->cdata; | |
+ return torch::autograd::utils::wrap(torch::getDtype(self_.type())); | |
+ END_HANDLE_TH_ERRORS | |
} | |
static struct PyGetSetDef THPVariable_properties[] = { | |
- {"_version", (getter)THPVariable_get_version, NULL, NULL, NULL}, | |
- {"grad_fn", (getter)THPVariable_get_grad_fn, NULL, NULL, NULL}, | |
- {"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, NULL, NULL}, | |
- {"is_leaf", (getter)THPVariable_is_leaf, NULL, NULL, NULL}, | |
- {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, NULL, NULL}, | |
- {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL}, // only for legacy reasons | |
- {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL}, | |
- {"volatile", (getter)THPVariable_get_volatile, (setter)THPVariable_set_volatile, NULL, NULL}, | |
- {"output_nr", (getter)THPVariable_get_output_nr, NULL, NULL, NULL}, | |
- {"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad, NULL, NULL}, | |
- {"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, NULL, NULL}, | |
- {NULL} | |
+ {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr}, | |
+ {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr}, | |
+ {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr}, | |
+ {"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, nullptr, nullptr}, | |
+ {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr}, | |
+ {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, nullptr, nullptr}, | |
+ {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr}, // only for legacy reasons | |
+ {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr}, | |
+ {"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr}, | |
+ {"volatile", (getter)THPVariable_get_volatile, (setter)THPVariable_set_volatile, nullptr, nullptr}, | |
+ {"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr}, | |
+ {"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad, nullptr, nullptr}, | |
+ {"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, nullptr, nullptr}, | |
+ {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr}, | |
+ {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr}, | |
+ {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr}, | |
+ {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr}, | |
+ {"dtype", (getter)THPVariable_dtype, NULL, NULL, NULL}, | |
+ {nullptr} | |
+}; | |
+ | |
+static PyMappingMethods THPVariable_as_mapping = { | |
+ THPVariable_length, | |
+ THPVariable_getitem, | |
+ THPVariable_setitem, | |
}; | |
PyTypeObject THPVariableType = { | |
- PyVarObject_HEAD_INIT(NULL, 0) | |
+ PyVarObject_HEAD_INIT(nullptr, 0) | |
"torch._C._VariableBase", /* tp_name */ | |
sizeof(THPVariable), /* tp_basicsize */ | |
0, /* tp_itemsize */ | |
@@ -427,7 +466,7 @@ PyTypeObject THPVariableType = { | |
0, /* tp_repr */ | |
0, /* tp_as_number */ | |
0, /* tp_as_sequence */ | |
- 0, /* tp_as_mapping */ | |
+ &THPVariable_as_mapping, /* tp_as_mapping */ | |
0, /* tp_hash */ | |
0, /* tp_call */ | |
0, /* tp_str */ | |
@@ -435,7 +474,7 @@ PyTypeObject THPVariableType = { | |
0, /* tp_setattro */ | |
0, /* tp_as_buffer */ | |
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ | |
- NULL, /* tp_doc */ | |
+ nullptr, /* tp_doc */ | |
(traverseproc)THPVariable_traverse, /* tp_traverse */ | |
(inquiry)THPVariable_clear, /* tp_clear */ | |
0, /* tp_richcompare */ | |
@@ -458,6 +497,7 @@ PyTypeObject THPVariableType = { | |
namespace torch { namespace autograd { | |
extern PyMethodDef variable_methods[]; | |
+extern void initTorchFunctions(PyObject *module); | |
}} | |
@@ -470,5 +510,6 @@ bool THPVariable_initModule(PyObject *module) | |
return false; | |
Py_INCREF(&THPVariableType); | |
PyModule_AddObject(module, "_VariableBase", (PyObject *)&THPVariableType); | |
+ torch::autograd::initTorchFunctions(module); | |
return true; | |
} | |
diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h | |
index 10b64fc3..21ab0596 100644 | |
--- a/torch/csrc/autograd/python_variable.h | |
+++ b/torch/csrc/autograd/python_variable.h | |
@@ -5,32 +5,34 @@ | |
#include <ATen/ATen.h> | |
#include "torch/csrc/autograd/variable.h" | |
+#include "torch/csrc/THP_export.h" | |
// Python object that backs torch.autograd.Variable | |
struct THPVariable { | |
PyObject_HEAD | |
// Payload | |
torch::autograd::Variable cdata; | |
- // Tensor this wraps (corresponds to Python attr 'data'). | |
- // It assumed that a THPVariable is *uniquely* identified by the | |
- // tensor it wraps. | |
- // Invariant: v->data == v->cdata->data | |
- PyObject* data; | |
// Hooks to be run on backwards pass (corresponds to Python attr | |
// '_backwards_hooks', set by 'register_hook') | |
PyObject* backward_hooks; | |
}; | |
-extern PyObject *THPVariableClass; | |
+THP_API PyObject *THPVariableClass; | |
bool THPVariable_initModule(PyObject *module); | |
-PyObject * THPVariable_NewVolatile(PyObject *data); | |
-PyObject * THPVariable_NewLeaf(PyObject *data); | |
-PyObject * THPVariable_NewWithFunction(PyObject *data, const std::shared_ptr<torch::autograd::Function>& var); | |
PyObject * THPVariable_Wrap(torch::autograd::Variable var); | |
-PyObject * THPVariable_get_data(THPVariable *self); | |
inline bool THPVariable_Check(PyObject *obj) | |
{ | |
return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass); | |
} | |
+ | |
+inline torch::autograd::Variable& THPVariable_Unpack(PyObject* obj) { | |
+ auto var = (THPVariable*)obj; | |
+ return var->cdata; | |
+} | |
+ | |
+inline at::Tensor& THPVariable_UnpackData(PyObject* obj) { | |
+ auto var = (THPVariable*)obj; | |
+ return var->cdata.data(); | |
+} | |
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp | |
new file mode 100644 | |
index 00000000..320c5cab | |
--- /dev/null | |
+++ b/torch/csrc/autograd/python_variable_indexing.cpp | |
@@ -0,0 +1,350 @@ | |
+#include "torch/csrc/autograd/python_variable_indexing.h" | |
+ | |
+#include "torch/csrc/DynamicTypes.h" | |
+#include "torch/csrc/Exceptions.h" | |
+#include "torch/csrc/THP_export.h" | |
+#include "torch/csrc/autograd/function.h" | |
+#include "torch/csrc/autograd/python_variable.h" | |
+#include "torch/csrc/autograd/utils/wrap_outputs.h" | |
+#include "torch/csrc/autograd/variable.h" | |
+#include "torch/csrc/utils/python_compat.h" | |
+#include "torch/csrc/utils/python_numbers.h" | |
+#include "torch/csrc/utils/tensor_new.h" | |
+ | |
+#include <ATen/ExpandUtils.h> | |
+#include <vector> | |
+ | |
+using namespace at; | |
+using namespace torch::autograd::utils; | |
+ | |
+namespace torch { namespace autograd { | |
+ | |
+Py_ssize_t THPVariable_length(PyObject* self) { | |
+ HANDLE_TH_ERRORS | |
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata; | |
+ if (self_.dim() == 0) { | |
+ return 0; | |
+ } | |
+ return (Py_ssize_t)self_.size(0); | |
+ END_HANDLE_TH_ERRORS_RET(-1) | |
+} | |
+ | |
+ | |
+// We allow indexing by integers, slices, ellipsis, None, Variables, | |
+// and tuples of those types. We also handle bools as if they were a | |
+// Variable[ByteTensor]. | |
+ | |
+static int64_t count_specified_dimensions(PyObject* index) { | |
+ // Count the number of indexed dimensions (everything but ellipsis and None) | |
+ int64_t count = 0; | |
+ auto size = PyTuple_GET_SIZE(index); | |
+ for (Py_ssize_t i = 0; i < size; i++) { | |
+ PyObject* obj = PyTuple_GET_ITEM(index, i); | |
+ if (THPVariable_Check(obj)) { | |
+ auto& var = reinterpret_cast<THPVariable*>(obj)->cdata; | |
+ if (var.type().scalarType() == kByte) { | |
+ count += var.dim(); | |
+ } else { | |
+ count++; | |
+ } | |
+ } else if (obj != Py_None && obj != Py_Ellipsis) { | |
+ count++; | |
+ } | |
+ } | |
+ return count; | |
+} | |
+ | |
+[[noreturn]] | |
+static void invalid_index(PyObject* obj) { | |
+ throw IndexError( | |
+ "only integers, slices (`:`), ellipsis (`...`), None and long or byte " | |
+ "Variables are valid indices (got %s)", Py_TYPE(obj)->tp_name); | |
+} | |
+ | |
+static Variable applySlice(const Variable& self, int64_t dim, PyObject* slice, bool ensure_view=false) { | |
+ Py_ssize_t start, stop, step, slicelength; | |
+ auto length = self.size(dim); | |
+ if (!THPUtils_parseSlice(slice, length, &start, &stop, &step, &slicelength)) { | |
+ throw python_error(); | |
+ } | |
+ if (step == 0) { | |
+ throw ValueError("step cannot be zero"); | |
+ } | |
+ if (step < 0) { | |
+ // TODO: implement negative step | |
+ throw ValueError("negative step not yet supported"); | |
+ } | |
+ if (!ensure_view && start == 0 && stop == length && step == 1) { | |
+ return self; | |
+ } | |
+ return self.slice(dim, start, stop, step); | |
+} | |
+ | |
+static Variable applySelect(const Variable& self, int64_t dim, int64_t index) { | |
+ if (index == 0 && dim == 0 && self.dim() == 0) { | |
+ // Deprecated support for indexing 0-dim tensors as if they were 1-dim. | |
+ PyErr_WarnEx(PyExc_UserWarning, | |
+ "invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. " | |
+ "Use tensor.item() to convert a 0-dim tensor to a Python number", 1); | |
+ return at::alias(self); | |
+ } | |
+ int64_t size = self.size(dim); | |
+ if (index < -size || index >= size) { | |
+ throw IndexError("index %lld is out of bounds for dimension %lld with size %lld", | |
+ index, dim, size); | |
+ } | |
+ if (index < 0) { | |
+ index += size; | |
+ } | |
+ return self.select(dim, index); | |
+} | |
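
applySelect follows the usual Python convention for a single integer index: negative values count back from the end of the dimension, and anything outside [-size, size) raises IndexError. A minimal sketch of that normalization on plain integers (the function name is illustrative):

// Sketch only: integer index normalization as used by applySelect.
#include <cstdint>
#include <stdexcept>

int64_t normalize_index(int64_t index, int64_t size) {
  if (index < -size || index >= size) {
    throw std::out_of_range("index out of bounds for this dimension");
  }
  return index < 0 ? index + size : index;
}
// normalize_index(-1, 4) == 3, normalize_index(0, 4) == 0,
// normalize_index(4, 4) throws.
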
+ | |
+static Variable sequenceToVariable(const Type& type, PyObject* seq) { | |
+ auto& idx_type = type.toScalarType(kLong); | |
+ return torch::utils::new_from_data(idx_type, -1, seq); | |
+} | |
+ | |
+static Variable valueToTensor(const Type & type, PyObject* value) { | |
+ if (THPVariable_Check(value)) { | |
+ return reinterpret_cast<THPVariable*>(value)->cdata; | |
+ } | |
+ if (THPUtils_checkLong(value)) { | |
+ return type.scalarTensor(Scalar(THPUtils_unpackLong(value))); | |
+ } | |
+ if (PyFloat_Check(value)) { | |
+ return type.scalarTensor(Scalar(THPUtils_unpackDouble(value))); | |
+ } | |
+ throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString()); | |
+} | |
+ | |
+static Variable applySlicing(const Variable& self, PyObject* index, variable_list& outIndices) { | |
+ int64_t size = PyTuple_GET_SIZE(index); | |
+ int64_t dim = 0; | |
+ int64_t specified_dims = count_specified_dimensions(index); | |
+ | |
+ auto handle_var = [&](const Variable& var) { | |
+ // TODO: check scalarType | |
+ outIndices.resize(dim + 1); | |
+ outIndices[dim] = var; | |
+ dim++; | |
+ }; | |
+ | |
+ if (specified_dims > self.dim()) { | |
+ throw IndexError("too many indices for tensor of dimension %d", (int)self.dim()); | |
+ } | |
+ | |
+ Variable result = self; | |
+ for (int64_t i = 0; i < size; i++) { | |
+ PyObject* obj = PyTuple_GET_ITEM(index, i); | |
+ if (THPUtils_checkLong(obj)) { | |
+ result = applySelect(result, dim, THPUtils_unpackLong(obj)); | |
+ } else if (PySlice_Check(obj)) { | |
+ result = applySlice(result, dim, obj); | |
+ dim++; | |
+ } else if (obj == Py_Ellipsis) { | |
+ dim += self.dim() - specified_dims; | |
+ } else if (obj == Py_None) { | |
+ result = result.unsqueeze(dim); | |
+ dim++; | |
+ } else if (THPVariable_Check(obj)) { | |
+ handle_var(reinterpret_cast<THPVariable*>(obj)->cdata); | |
+ } else if (PySequence_Check(obj)) { | |
+ handle_var(sequenceToVariable(self.type(), obj)); | |
+ } else { | |
+ auto index = THPObjectPtr(PyNumber_Index(obj)); | |
+ if (!index) { | |
+ PyErr_Clear(); | |
+ invalid_index(obj); | |
+ } | |
+ result = applySelect(result, dim, THPUtils_unpackLong(index)); | |
+ } | |
+ } | |
+ return result; | |
+} | |
+ | |
+static std::vector<Tensor> asTensorList(const variable_list& v) { | |
+ return std::vector<Tensor>(v.begin(), v.end()); | |
+} | |
+ | |
+static Variable dispatch_index(const Variable& self, const variable_list& indices) { | |
+ AutoNoGIL no_gil; | |
+ AutoGPU auto_gpu(self); | |
+ return self.index(asTensorList(indices)); | |
+} | |
+ | |
+static Variable dispatch_index_put_(Variable& self, const variable_list& indices, const Variable& value) { | |
+ AutoNoGIL no_gil; | |
+ AutoGPU auto_gpu(self); | |
+ return self.index_put_(asTensorList(indices), value); | |
+} | |
+ | |
+static bool treatSequenceAsTuple(PyObject* index) { | |
+ if (PyTuple_Check(index)) { | |
+ return true; | |
+ } | |
+ if (!PySequence_Check(index)) { | |
+ return false; | |
+ } | |
+ // This uses a heuristic from NumPy for determining whether to treat | |
+ // non-tuple sequences as if they were a tuple. From the NumPy code comments: | |
+ // | |
+ // "At this point, we're left with a non-tuple, non-array, sequence: | |
+ // typically, a list. We use some somewhat-arbitrary heuristics from here | |
+ // onwards to decided whether to treat that list as a single index, or a | |
+ // list of indices. Backwards compatibility only takes effect for short | |
+ // sequences - otherwise we treat it like any other scalar." | |
+ auto n = PySequence_Size(index); | |
+ if (n < 0) { | |
+ // Negative size indicates a Python error in the PySequence_Size call. | |
+ PyErr_Clear(); | |
+ return false; | |
+ } | |
+ if (n >= 32) { | |
+ return false; | |
+ } | |
+ for (Py_ssize_t i = 0; i < n; i++) { | |
+ auto obj = THPObjectPtr{PySequence_GetItem(index, i)}; | |
+ if (!obj.get()) { | |
+ PyErr_Clear(); | |
+ return false; | |
+ } | |
+ if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) || PySlice_Check(obj.get())) { | |
+ return true; | |
+ } | |
+ if (obj.get() == Py_Ellipsis || obj.get() == Py_None) { | |
+ return true; | |
+ } | |
+ } | |
+ return false; | |
+} | |
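
treatSequenceAsTuple decides whether a non-tuple sequence index (typically a list) is unpacked as if it were a tuple of per-dimension indices or passed along as a single index. Restated over an enum instead of PyObject*, the heuristic looks roughly like the sketch below; the enum and function names are illustrative, not part of the binding code.

// Sketch only: short sequences containing index-like elements are promoted to
// tuples of indices; everything else is treated as a single index.
#include <vector>

enum class IndexKind { Integer, Slice, Ellipsis, None, Variable, Sequence, Other };

bool treat_sequence_as_tuple(const std::vector<IndexKind>& elems, bool is_tuple) {
  if (is_tuple) return true;
  if (elems.size() >= 32) return false;
  for (auto kind : elems) {
    switch (kind) {
      case IndexKind::Slice:
      case IndexKind::Ellipsis:
      case IndexKind::None:
      case IndexKind::Variable:
      case IndexKind::Sequence:
        return true;
      default:
        break;
    }
  }
  return false;
}
// e.g. a short list like [0, :] is unpacked into two per-dimension indices,
// while a long list of plain integers becomes a single index tensor.
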
+ | |
+static THPObjectPtr wrapTuple(PyObject* index) { | |
+ THPObjectPtr res; | |
+ if (treatSequenceAsTuple(index)) { | |
+ res = PySequence_Tuple(index); | |
+ } else { | |
+ res = PyTuple_Pack(1, index); | |
+ } | |
+ if (!res) throw python_error(); | |
+ return res; | |
+} | |
+ | |
+static bool isSingleBoolScalar(const variable_list& vars) { | |
+ return vars.size() == 1 && vars[0].type().scalarType() == ScalarType::Byte && vars[0].dim() == 0; | |
+} | |
+ | |
+static PyObject* applyBoolGetitem(const Variable& self, bool index) { | |
+ if (index) { | |
+ return wrap(self.type().copy(self.unsqueeze(0))); | |
+ } else { | |
+ return wrap(self.type().tensor({0})); | |
+ } | |
+} | |
+ | |
+PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { | |
+ HANDLE_TH_ERRORS | |
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata; | |
+ AutoGPU auto_gpu(self_); | |
+ | |
+ // handle simple types: integers, slices, ellipsis | |
+ if (index == Py_None) { | |
+ return wrap(self_.unsqueeze(0)); | |
+ } else if (index == Py_Ellipsis) { | |
+ return wrap(at::alias(self_)); | |
+ } else if (THPUtils_checkLong(index)) { | |
+ return wrap(applySelect(self_, 0, THPUtils_unpackLong(index))); | |
+ } else if (PyBool_Check(index)) { | |
+ return applyBoolGetitem(self_, index == Py_True); | |
+ } else if (PySlice_Check(index)) { | |
+ return wrap(applySlice(self_, 0, index, true)); | |
+ } | |
+ | |
+ // wrap index in a tuple if it's not already one | |
+ THPObjectPtr holder = wrapTuple(index); | |
+ | |
+ variable_list variableIndices; | |
+ Variable sliced = applySlicing(self_, holder.get(), variableIndices); | |
+ if (variableIndices.empty()) { | |
+ if (sliced.is_same(self_)) { | |
+ // ensure we return a shallow copy for things like x[...] | |
+ sliced = at::alias(sliced); | |
+ } | |
+ return wrap(sliced); | |
+ } | |
+ if (isSingleBoolScalar(variableIndices)) { | |
+ return applyBoolGetitem(self_, variableIndices[0].toCByte()); | |
+ } | |
+ | |
+ // indexing by tensors ("advanced" indexing) | |
+ return wrap(dispatch_index(sliced, variableIndices)); | |
+ Py_RETURN_NONE; | |
+ END_HANDLE_TH_ERRORS | |
+} | |
+ | |
+static void copy_to(Variable dst, const Variable& src) { | |
+ Tensor b_src; | |
+ // To match numpy semantics: | |
+ // As a special case for backwards compatibility, | |
+ // strip away unit dimensions from the left of 'src' | |
+ auto src_sizes = src.sizes(); | |
+ size_t first_nonzero_src = src_sizes.size(); | |
+ for (size_t i = 0; i < src_sizes.size(); ++i) { | |
+ if (src_sizes[i] != 1) { | |
+ first_nonzero_src = i; | |
+ break; | |
+ } | |
+ } | |
+ | |
+ src_sizes = src_sizes.slice(first_nonzero_src); | |
+ std::tie(b_src) = expand_inplace(dst, src.view(src_sizes), "setitem"); | |
+ dst.copy_(b_src); | |
+} | |
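
copy_to mirrors NumPy assignment semantics: before the source is broadcast against the destination, leading dimensions of size 1 are stripped from the source's shape. The same step on a plain size vector, as a self-contained sketch (the function name is illustrative):

// Sketch only: strip leading unit dimensions from a source shape.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> strip_leading_unit_dims(std::vector<int64_t> sizes) {
  size_t first = 0;
  while (first < sizes.size() && sizes[first] == 1) {
    ++first;
  }
  return std::vector<int64_t>(sizes.begin() + first, sizes.end());
}

int main() {
  // A (1, 1, 3) source is viewed as (3) and can then broadcast into a (2, 3)
  // destination, matching NumPy's assignment behavior.
  assert((strip_leading_unit_dims({1, 1, 3}) == std::vector<int64_t>{3}));
  assert((strip_leading_unit_dims({2, 1, 3}) == std::vector<int64_t>{2, 1, 3}));
  return 0;
}
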
+ | |
+int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { | |
+ HANDLE_TH_ERRORS | |
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata; | |
+ AutoGPU auto_gpu(self_); | |
+ auto value = valueToTensor(self_.type(), py_value); | |
+ | |
+ // handle simple types: integers, slices, ellipsis, bool | |
+ if (index == Py_False) { | |
+ // do nothing for false (technically we should check the size, but we don't have | |
+ // real 0-sized shapes). | |
+ return 0; | |
+ } else if (index == Py_Ellipsis) { | |
+ copy_to(self_, value); | |
+ return 0; | |
+ } else if (index == Py_None || index == Py_True) { | |
+ copy_to(self_.unsqueeze(0), value); | |
+ return 0; | |
+ } else if (THPUtils_checkLong(index)) { | |
+ copy_to(applySelect(self_, 0, THPUtils_unpackLong(index)), value); | |
+ return 0; | |
+ } else if (PySlice_Check(index)) { | |
+ copy_to(applySlice(self_, 0, index), value); | |
+ return 0; | |
+ } | |
+ | |
+ // wrap index in a tuple if it's not already one | |
+ THPObjectPtr holder = wrapTuple(index); | |
+ | |
+ variable_list variableIndices; | |
+ Variable sliced = applySlicing(self_, holder.get(), variableIndices); | |
+ if (variableIndices.empty()) { | |
+ copy_to(sliced, value); | |
+ return 0; | |
+ } | |
+ if (isSingleBoolScalar(variableIndices)) { | |
+ if (variableIndices[0].toCByte()) { | |
+ copy_to(self_.unsqueeze(0), value); | |
+ } | |
+ return 0; | |
+ } | |
+ | |
+ // indexing by tensors ("advanced" indexing) | |
+ dispatch_index_put_(sliced, variableIndices, value); | |
+ return 0; | |
+ END_HANDLE_TH_ERRORS_RET(-1) | |
+} | |
+ | |
+}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/python_variable_indexing.h b/torch/csrc/autograd/python_variable_indexing.h | |
new file mode 100644 | |
index 00000000..e961ed5c | |
--- /dev/null | |
+++ b/torch/csrc/autograd/python_variable_indexing.h | |
@@ -0,0 +1,11 @@ | |
+#pragma once | |
+ | |
+#include <Python.h> | |
+ | |
+namespace torch { namespace autograd { | |
+ | |
+Py_ssize_t THPVariable_length(PyObject* self); | |
+PyObject* THPVariable_getitem(PyObject* self, PyObject* index); | |
+int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value); | |
+ | |
+}} // namespace torch::autograd | |
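
These three entry points match the signatures CPython expects for the mapping protocol (lenfunc, binaryfunc, objobjargproc). A hedged sketch of how they would typically be installed on the Python type object; the actual registration is done elsewhere in the tree (not shown by this hunk), so the struct below is an assumption, not part of the patch:

#include <Python.h>
#include "torch/csrc/autograd/python_variable_indexing.h"

// Fill the mapping-protocol slots: len(v), v[index], and v[index] = value.
static PyMappingMethods THPVariable_as_mapping = {
  torch::autograd::THPVariable_length,   // mp_length
  torch::autograd::THPVariable_getitem,  // mp_subscript
  torch::autograd::THPVariable_setitem,  // mp_ass_subscript
};
// A PyTypeObject for Variable would then set tp_as_mapping = &THPVariable_as_mapping.
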
diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp | |
index e112e515..0fbbaf97 100644 | |
--- a/torch/csrc/autograd/saved_variable.cpp | |
+++ b/torch/csrc/autograd/saved_variable.cpp | |
@@ -1,83 +1,91 @@ | |
+#include "Python.h" | |
#include "torch/csrc/autograd/saved_variable.h" | |
+#include "torch/csrc/autograd/edge.h" | |
#include "torch/csrc/autograd/function.h" | |
+#include "torch/csrc/autograd/variable.h" | |
+#include "torch/csrc/jit/tracer_state.h" | |
-using namespace at; | |
+#include <ATen/Tensor.h> | |
+ | |
+#include <cstdint> | |
+#include <list> | |
+#include <memory> | |
namespace torch { namespace autograd { | |
-SavedVariable::SavedVariable(const Variable& variable, Function* saved_for) | |
- : SavedVariable() { | |
- if (!variable.defined()) { | |
- return; | |
- } | |
- data = variable.data(); | |
- requires_grad = variable.requires_grad(); | |
- is_volatile = variable.is_volatile(); | |
- expected_version = variable.current_version(); | |
- version = variable.get()->version_counter.save(); | |
- has_grad_fn = variable.grad_fn() != nullptr; | |
- output_nr = variable.output_nr(); | |
- if (!has_grad_fn) { | |
- grad_accumulator = variable.grad_accumulator(); | |
- } | |
- if (variable.grad_fn().get() != saved_for) { | |
- grad_fn = variable.grad_fn(); | |
- } | |
- if (variable.tracing_state()) { | |
- tracing_state.reset(new jit::tracer::ValueTracingState(*variable.tracing_state())); | |
+SavedVariable::SavedVariable(const Variable& variable, bool is_output) { | |
+ if (variable.defined()) { | |
+ was_default_constructed_ = false; | |
+ output_nr_ = variable.output_nr(); | |
+ requires_grad_ = variable.requires_grad(); | |
+ has_grad_fn_ = !variable.is_leaf(); | |
+ // These copies are all shared_ptr copies, so slightly more expensive. | |
+ // Do them here instead of in the init list in case data is undefined. | |
+ data_ = variable.data(); | |
+ if (variable.is_leaf()) { | |
+ grad_accumulator_ = variable.grad_accumulator(); | |
+ } else if (!is_output) { | |
+ grad_fn_ = variable.grad_fn(); | |
+ } | |
+ version_counter_ = variable.version_counter(); | |
+ saved_version_ = version_counter_.current_version(); | |
+ if (variable.has_tracing_state()) { | |
+ tracing_state_.reset( | |
+ new jit::tracer::ValueTracingState(variable.tracing_state())); | |
+ } | |
} | |
} | |
-auto SavedVariable::unpack(std::shared_ptr<Function> saved_for) const -> Variable { | |
- if (!data.defined()) { | |
- if (version.defined()) { | |
+Variable SavedVariable::unpack(std::shared_ptr<Function> saved_for) const { | |
+ if (!data_.defined()) { | |
+ if (!was_default_constructed_) { | |
throw std::runtime_error(ERR_BACKWARD_TWICE); | |
} | |
return Variable(); | |
} | |
- if (version.is_modified()) { | |
+ if (saved_version_ != version_counter_.current_version()) { | |
throw std::runtime_error( | |
"one of the variables needed for gradient computation has been " | |
"modified by an inplace operation"); | |
} | |
- Variable var = make_variable(data, requires_grad, is_volatile); | |
- if (has_grad_fn && !grad_fn) { | |
+ auto grad_fn = grad_fn_; | |
+ if (has_grad_fn_ && !grad_fn) { | |
if (!saved_for) { | |
// If saving the grad_fn would create a circular reference, then it must | |
// be passed in to the unpack function. | |
throw std::runtime_error("No grad_fn for non-leaf saved variable"); | |
} | |
- var.grad_fn() = saved_for; | |
+ grad_fn = std::move(saved_for); | |
+ } | |
+ | |
+ // NB: saved views are unpacked as normal Variables (not views) even though | |
+ // they still share the same storage. This works only because we never call | |
+ // in-place functions on unpacked variables. | |
+ Variable var; | |
+ if (grad_fn) { | |
+ var = make_variable(data_, Edge(std::move(grad_fn), output_nr_)); | |
} else { | |
- var.grad_fn() = grad_fn; | |
+ var = make_variable(data_, requires_grad_); | |
} | |
- var.output_nr() = output_nr; | |
- var.version_counter() = version; | |
+ var.set_version_counter(saved_version_); | |
// If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we | |
// should have saved the grad accumulator. Even if the Variable no longer | |
- // alive, the accumulator should be kept alive by the references in the graph). | |
- if (requires_grad && !var.grad_fn() && grad_accumulator.expired()) | |
+ // alive, the accumulator should be kept alive by the references in the | |
+ // graph). | |
+ if (requires_grad_ && !var.grad_fn() && grad_accumulator_.expired()) | |
throw std::logic_error("No grad accumulator for a saved leaf!"); | |
- var.get()->grad_accumulator = grad_accumulator; | |
- if (tracing_state) | |
- var.tracing_state().reset(new jit::tracer::ValueTracingState(*tracing_state)); | |
+ var.set_grad_accumulator(grad_accumulator_); | |
+ if (tracing_state_) { | |
+ var.set_tracing_state(new jit::tracer::ValueTracingState(*tracing_state_)); | |
+ } | |
return var; | |
} | |
-auto SavedVariable::unpack_data(std::shared_ptr<Function> saved_for) const -> Tensor { | |
- auto var = unpack(saved_for); | |
- if (var.defined()) { | |
- return var.data(); | |
- } | |
- return Tensor(); | |
-} | |
- | |
- | |
const char* ERR_BACKWARD_TWICE = | |
"Trying to backward through the graph a second time, but the buffers have " | |
"already been freed. Specify retain_graph=True when calling backward " | |
diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h | |
index 6ec741da..7372d10c 100644 | |
--- a/torch/csrc/autograd/saved_variable.h | |
+++ b/torch/csrc/autograd/saved_variable.h | |
@@ -1,50 +1,55 @@ | |
#pragma once | |
-#include <mutex> | |
-#include <memory> | |
-#include <functional> | |
+#include "torch/csrc/autograd/variable_version.h" | |
+#include "torch/csrc/jit/tracer_state.h" | |
+ | |
#include <ATen/ATen.h> | |
-#include "torch/csrc/jit/tracer_state.h" | |
-#include "torch/csrc/autograd/variable.h" | |
-#include "torch/csrc/autograd/variable_version.h" | |
-#include "torch/csrc/Types.h" | |
+#include <cstdint> | |
+#include <list> | |
+#include <memory> | |
namespace torch { namespace autograd { | |
+struct Variable; | |
struct Function; | |
extern const char* ERR_BACKWARD_TWICE; | |
-struct SavedVariable { | |
- SavedVariable() | |
- : data() | |
- , has_grad_fn(false) | |
- , version() | |
- , requires_grad(false) | |
- , is_volatile(false) | |
- , expected_version(-1) {} | |
+/// A snapshot of a variable at a certain version. A `SavedVariable` stores | |
+/// enough information to reconstruct a variable from a certain point in time. | |
+class SavedVariable { | |
+ public: | |
+ SavedVariable() = default; | |
+ SavedVariable(const Variable& variable, bool is_output); | |
+ SavedVariable(SavedVariable&&) = default; | |
+ SavedVariable& operator=(SavedVariable&&) = default; | |
- SavedVariable(const Variable& variable, Function* saved_for); | |
+ /// Reconstructs the saved variable. Pass `saved_for` as the gradient | |
+ /// function if constructing the `SavedVariable` with it would have caused a | |
+ /// circular reference. | |
+ Variable unpack(std::shared_ptr<Function> saved_for = nullptr) const; | |
+ void reset_data() { | |
+ return data_.reset(); | |
+ } | |
+ | |
+ private: | |
+ at::Tensor data_; | |
- at::Tensor data; | |
// The gradient function associated with this node. If has_grad_fn | |
// is false, then this is a leaf node. Note that the grad_fn is not saved if | |
// it would create a circular reference. In that case, the grad_fn must be | |
// passed in to the unpack function when reconstructing the Variable. | |
- bool has_grad_fn; | |
- std::shared_ptr<Function> grad_fn; | |
- std::weak_ptr<Function> grad_accumulator; | |
- SavedVersion version; | |
- bool requires_grad; | |
- bool is_volatile; | |
- int expected_version; | |
- int output_nr; | |
- std::unique_ptr<jit::tracer::ValueTracingState> tracing_state; | |
- | |
- Variable unpack(std::shared_ptr<Function> saved_for=nullptr) const; | |
- at::Tensor unpack_data(std::shared_ptr<Function> saved_for=nullptr) const; | |
+ std::shared_ptr<Function> grad_fn_; | |
+ std::weak_ptr<Function> grad_accumulator_; | |
+ std::unique_ptr<jit::tracer::ValueTracingState> tracing_state_; | |
+ VariableVersion version_counter_; | |
+ | |
+ uint32_t saved_version_; | |
+ uint32_t output_nr_; | |
+ bool was_default_constructed_ = true; | |
+ bool requires_grad_; | |
+ bool has_grad_fn_; | |
}; | |
- | |
}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/symbolic.h b/torch/csrc/autograd/symbolic.h | |
index a81c22c9..b0fd5ccc 100644 | |
--- a/torch/csrc/autograd/symbolic.h | |
+++ b/torch/csrc/autograd/symbolic.h | |
@@ -8,20 +8,10 @@ namespace torch { namespace autograd { | |
struct SymbolicContext { | |
jit::Graph* graph; | |
- const std::unordered_map<void*, jit::Node*>* buffer_map; | |
- int batch_norm_count = 0; | |
}; | |
struct symbolic_unconvertible : public std::runtime_error { | |
using std::runtime_error::runtime_error; | |
}; | |
- | |
-struct HasSymbolic { | |
- // Add some nodes to the ONNX protobuf, under the assumption that this node | |
- // as a whole has the represented inputs and outputs. Raises a | |
- // symbolic_unconvertible exception if conversion is not supported. | |
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) = 0; | |
-}; | |
- | |
}} // namespace torch::autograd | |
diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h | |
index 0b83675d..5f087205 100644 | |
--- a/torch/csrc/autograd/utils/wrap_outputs.h | |
+++ b/torch/csrc/autograd/utils/wrap_outputs.h | |
@@ -6,6 +6,7 @@ | |
#include <Python.h> | |
#include <tuple> | |
+#include "torch/csrc/Dtype.h" | |
#include "torch/csrc/autograd/python_variable.h" | |
#include "torch/csrc/autograd/variable.h" | |
#include "torch/csrc/utils/python_numbers.h" | |
@@ -13,10 +14,6 @@ | |
namespace torch { namespace autograd { namespace utils { | |
inline PyObject* wrap(at::Tensor tensor) { | |
- if (tensor.defined() && tensor.dim() == 0) { | |
- // don't expose 0-dim tensors to Variable API. | |
- Variable(tensor).data().as_strided_({1}, {1}); | |
- } | |
return THPVariable_Wrap(Variable(std::move(tensor))); | |
} | |
@@ -37,6 +34,27 @@ inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor> tensors) { | |
return r.release(); | |
} | |
+inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> tensors) { | |
+ auto r = THPObjectPtr{PyTuple_New(4)}; | |
+ if (!r) throw python_error(); | |
+ PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); | |
+ PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); | |
+ PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); | |
+ PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors)))); | |
+ return r.release(); | |
+} | |
+ | |
+inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> tensors) { | |
+ auto r = THPObjectPtr{PyTuple_New(5)}; | |
+ if (!r) throw python_error(); | |
+ PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); | |
+ PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); | |
+ PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); | |
+ PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors)))); | |
+ PyTuple_SET_ITEM(r.get(), 4, wrap(std::move(std::get<4>(tensors)))); | |
+ return r.release(); | |
+} | |
+ | |
inline PyObject* wrap(at::TensorList tl) { | |
auto r = THPObjectPtr{PyTuple_New(tl.size())}; | |
if (!r) throw python_error(); | |
@@ -58,6 +76,10 @@ inline PyObject* wrap(int64_t value) { | |
return THPUtils_packInt64(value); | |
} | |
+inline PyObject* wrap(double value) { | |
+ return PyFloat_FromDouble(value); | |
+} | |
+ | |
inline PyObject* wrap(void* value) { | |
return THPUtils_packInt64(reinterpret_cast<intptr_t>(value)); | |
} | |
@@ -66,5 +88,9 @@ inline PyObject* wrap(at::Scalar scalar) { | |
return wrap(scalar.toTensor()); | |
} | |
+inline PyObject* wrap(THPDtype *dtype) { | |
+ Py_INCREF(dtype); | |
+ return (PyObject*)dtype; | |
+} | |
}}} // namespace torch::autograd::utils | |
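
The 4- and 5-tuple overloads added here repeat the existing 2- and 3-tuple pattern by hand. A hedged sketch of a single variadic helper that could cover all arities; it assumes the same THPObjectPtr, python_error, and wrap(at::Tensor) helpers from this header plus <utility> for std::index_sequence, and is not part of the patch:

template <typename Tuple, std::size_t... Is>
PyObject* wrap_tuple_impl(Tuple&& tensors, std::index_sequence<Is...>) {
  auto r = THPObjectPtr{PyTuple_New(sizeof...(Is))};
  if (!r) throw python_error();
  // PyTuple_SET_ITEM steals each reference returned by wrap().
  (void)std::initializer_list<int>{
      (PyTuple_SET_ITEM(r.get(), Is, wrap(std::move(std::get<Is>(tensors)))), 0)...};
  return r.release();
}

template <typename... Ts>
inline PyObject* wrap(std::tuple<Ts...> tensors) {
  return wrap_tuple_impl(std::move(tensors), std::index_sequence_for<Ts...>{});
}

Keeping the explicit overloads, as the patch does, trades a little repetition for simpler error messages and no template machinery in a widely included header.
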
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp | |
index 7ea75c80..ca4b65ec 100644 | |
--- a/torch/csrc/autograd/variable.cpp | |
+++ b/torch/csrc/autograd/variable.cpp | |
@@ -1,83 +1,86 @@ | |
#include "torch/csrc/autograd/variable.h" | |
-#include "torch/csrc/autograd/generated/VariableType.h" | |
+#include "torch/csrc/assertions.h" | |
+#include "torch/csrc/autograd/edge.h" | |
+#include "torch/csrc/autograd/function.h" | |
#include "torch/csrc/autograd/functions/accumulate_grad.h" | |
+#include "torch/csrc/autograd/functions/tensor.h" | |
+#include "torch/csrc/autograd/generated/Functions.h" | |
+#include "torch/csrc/autograd/generated/VariableType.h" | |
+#include "torch/csrc/autograd/variable_version.h" | |
+#include "torch/csrc/jit/tracer_state.h" | |
+#include "torch/csrc/utils/auto_unique_ptr.h" | |
-using namespace at; | |
+#include <ATen/ATen.h> | |
-namespace torch { namespace autograd { | |
+#include <list> | |
+#include <memory> | |
+#include <mutex> | |
+#include <stdexcept> | |
+#include <string> | |
+#include <vector> | |
-VariableImpl::VariableImpl(Tensor data_, bool requires_grad, bool is_volatile) | |
- : TensorImpl(getType(data_)) | |
- , data(std::move(data_)) | |
- , grad() | |
- , version_counter() | |
- , requires_grad(requires_grad) | |
- , is_volatile(is_volatile) | |
- , output_nr(0) | |
- , pyobj(nullptr) { | |
+namespace torch { namespace autograd { | |
+Variable::Impl::Impl(at::Tensor data_, bool requires_grad_, Edge gradient_edge_) | |
+ : TensorImpl(VariableType::getType(data_)), | |
+ data(std::move(data_)), | |
+ grad_fn(std::move(gradient_edge_.function)), | |
+ requires_grad(requires_grad_), | |
+ is_view(false), | |
+ output_nr(gradient_edge_.input_nr), | |
+ pyobj(nullptr) { | |
+ TORCH_ASSERTM( | |
+ !grad_fn || !requires_grad, | |
+ "_requires_grad should be false if grad_fn is set"); | |
if (!data.defined()) { | |
throw std::runtime_error("data is undefined"); | |
} | |
} | |
-VariableImpl::VariableImpl(Tensor data, std::shared_ptr<Function> grad_fn) | |
- : VariableImpl(std::move(data)) | |
-{ | |
- this->grad_fn = grad_fn; | |
- requires_grad = grad_fn->is_executable; | |
- output_nr = grad_fn->num_inputs++; | |
-} | |
+Variable::Impl::~Impl() = default; | |
-VariableImpl::VariableImpl(Tensor data) | |
- : VariableImpl(std::move(data), false, false) | |
-{ | |
-} | |
- | |
-VariableImpl::~VariableImpl() { | |
-} | |
- | |
-const char * VariableImpl::toString() const { | |
+const char* Variable::Impl::toString() const { | |
return "Variable"; | |
} | |
-IntList VariableImpl::sizes() const { | |
+IntList Variable::Impl::sizes() const { | |
return data.sizes(); | |
} | |
-IntList VariableImpl::strides() const { | |
+IntList Variable::Impl::strides() const { | |
return data.strides(); | |
} | |
-int64_t VariableImpl::dim() const { | |
+int64_t Variable::Impl::dim() const { | |
return data.dim(); | |
} | |
-const char * VariableImpl::typeString() { | |
+const char* Variable::Impl::typeString() { | |
return "VariableType"; | |
} | |
-void * VariableImpl::unsafeGetTH(bool retain) { | |
+void* Variable::Impl::unsafeGetTH(bool retain) { | |
return data.unsafeGetTH(retain); | |
} | |
-Scalar VariableImpl::localScalar() { | |
- return data.pImpl->localScalar(); | |
+std::unique_ptr<at::Storage> Variable::Impl::storage() { | |
+ return data.storage(); | |
} | |
-void VariableImpl::assign_(Scalar s) { | |
- data.assign_(s); | |
+Scalar Variable::Impl::localScalar() { | |
+ return data.pImpl->localScalar(); | |
} | |
-std::shared_ptr<Function> VariableImpl::get_grad_accumulator() { | |
+std::shared_ptr<Function> Variable::Impl::get_grad_accumulator() { | |
if (grad_fn) { | |
- throw std::logic_error("get_grad_accumulator() should be only called on leaf Variables"); | |
+ throw std::logic_error( | |
+ "get_grad_accumulator() should be only called on leaf Variables"); | |
} | |
if (!requires_grad) { | |
return nullptr; | |
} | |
- std::lock_guard<std::mutex> lock(grad_accumulator_lock); | |
+ std::lock_guard<std::mutex> lock(mutex); | |
auto result = grad_accumulator.lock(); | |
if (result) return result; | |
@@ -87,39 +90,86 @@ std::shared_ptr<Function> VariableImpl::get_grad_accumulator() { | |
return result; | |
} | |
-namespace { | |
- | |
-struct VariableTypes { | |
- VariableTypes() { | |
- auto& context = at::globalContext(); | |
- for (int p = 0; p < static_cast<int>(Backend::NumOptions); ++p) { | |
- for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) { | |
- auto baseType = context.type_registry[p][s].get(); | |
- if (baseType) { | |
- auto id = static_cast<int>(baseType->ID()); | |
- types[id].reset(new VariableType(&context, baseType)); | |
- } | |
- } | |
- } | |
+Variable::ViewImpl::ViewImpl( | |
+ Variable base_, | |
+ at::Tensor data_, | |
+ Edge gradient_edge_) | |
+ : Variable::Impl(std::move(data_), false, std::move(gradient_edge_)), | |
+ base(std::move(base_)) { | |
+ TORCH_ASSERTM(base.defined(), "base is undefined"); | |
+ if (base.is_view()) { | |
+ base = base.base(); | |
} | |
+ is_view = true; | |
+ version_counter = base.version_counter(); | |
+ attr_version = version_counter.current_version(); | |
+} | |
- std::unique_ptr<Type> types[static_cast<int>(TypeID::NumOptions)]; | |
-}; | |
+std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() { | |
+ std::lock_guard<std::mutex> lock(mutex); | |
+ if (!grad_fn && !base.requires_grad()) { | |
+ return grad_fn; | |
+ } | |
+ auto current_version = version_counter.current_version(); | |
+ if (attr_version != current_version) { | |
+ TORCH_ASSERT(output_nr == 0); | |
+ auto fn = std::make_shared<generated::AsStridedBackward>(); | |
+ fn->self_geometry = at::TensorGeometry(base); | |
+ fn->size = sizes(); | |
+ fn->stride = strides(); | |
+ fn->storage_offset = data.storage_offset(); | |
+ fn->set_next_edges(collect_next_edges(base)); | |
+ fn->set_num_inputs(1); | |
+ grad_fn = std::move(fn); | |
+ attr_version = current_version; | |
+ } | |
+ return grad_fn; | |
+} | |
-} // anonymous namespace | |
+void Variable::ViewImpl::rebase_history(Edge gradient_edge) { | |
+ TORCH_ASSERT(gradient_edge.input_nr == 0); | |
+ TORCH_ASSERT(gradient_edge.function); | |
+ TORCH_ASSERTM( | |
+ gradient_edge.function->num_inputs() == 1, | |
+ "Functions which modify views in-place must return a single Variable"); | |
+ this->output_nr = gradient_edge.input_nr; | |
+ auto copy_slices = std::make_shared<CopySlices>( | |
+ base, at::TensorGeometry(data), std::move(gradient_edge.function)); | |
+ base.set_gradient_edge({std::move(copy_slices), 0}); | |
+ get_grad_fn(); // trigger an update to the view's grad_fn | |
+} | |
-Type* VariableImpl::getType(const Tensor& tensor) | |
-{ | |
- if (!tensor.defined()) { | |
- throw std::runtime_error("tensor is undefined"); | |
+void Variable::rebase_history(Edge gradient_edge) { | |
+ TORCH_ASSERT(gradient_edge.function != nullptr); | |
+ if (is_view()) { | |
+ auto& impl = static_cast<Variable::ViewImpl&>(*get()); | |
+ impl.rebase_history(std::move(gradient_edge)); | |
+ } else { | |
+ set_gradient_edge(std::move(gradient_edge)); | |
} | |
- return getType(tensor.type()); | |
} | |
-Type* VariableImpl::getType(const Type& baseType) | |
-{ | |
- static VariableTypes vt; | |
- return vt.types[static_cast<int>(baseType.ID())].get(); | |
+Variable Variable::detach() const { | |
+ auto detached = make_variable(data(), /*requires_grad=*/false); | |
+ detached.set_version_counter(version_counter()); | |
+ return detached; | |
} | |
+void Variable::detach_() { | |
+ if (is_view()) { | |
+ throw std::runtime_error( | |
+ "Can't detach views in-place. Use detach() instead"); | |
+ } | |
+ set_requires_grad(false); | |
+ set_gradient_edge(Edge()); | |
+} | |
+ | |
+void Variable::set_tracing_state( | |
+ jit::tracer::ValueTracingState* new_tracing_state) { | |
+ get()->tracing_state.reset(new_tracing_state); | |
+} | |
+ | |
+jit::tracer::ValueTracingState& Variable::tracing_state() const noexcept { | |
+ return *get()->tracing_state; | |
+} | |
}} // namespace torch::autograd | |
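
ViewImpl::get_grad_fn above keeps the view's grad_fn fresh by stamping it with attr_version and rebuilding it whenever the shared version counter has moved. A standalone analogue of that stamp-and-rebuild pattern; the cached string stands in for the regenerated backward node, and everything here is illustrative:

#include <atomic>
#include <cstdint>
#include <memory>
#include <string>

struct VersionedCache {
  std::shared_ptr<std::atomic<uint32_t>> version;  // shared with the "base"
  uint32_t cached_at;                              // version at rebuild time
  std::string cached;                              // stand-in for the grad_fn

  const std::string& get() {
    const uint32_t current = version->load();
    if (cached.empty() || cached_at != current) {
      cached = "rebuilt@" + std::to_string(current);  // re-create the node
      cached_at = current;
    }
    return cached;
  }
};

int main() {
  VersionedCache view{std::make_shared<std::atomic<uint32_t>>(0), 0, ""};
  view.get();                   // built once
  view.version->fetch_add(1);   // base modified in place
  view.get();                   // stale stamp -> rebuilt, like the view's grad_fn
}
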
diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h | |
index 0f134378..c25fe4cc 100644 | |
--- a/torch/csrc/autograd/variable.h | |
+++ b/torch/csrc/autograd/variable.h | |
@@ -1,226 +1,605 @@ | |
#pragma once | |
-// A wrapper around at::Tensor to represent autograd Variables. Variables | |
-// can be implicitly converted to an at::Tensor. | |
+#include <Python.h> | |
-#include <mutex> | |
+#include "torch/csrc/autograd/edge.h" | |
+#include "torch/csrc/autograd/function_hook.h" | |
+#include "torch/csrc/autograd/variable_version.h" | |
+#include "torch/csrc/utils/auto_unique_ptr.h" | |
+ | |
+#include <ATen/ATen.h> | |
+ | |
+#include <list> | |
#include <memory> | |
+#include <mutex> | |
+#include <stdexcept> | |
+#include <string> | |
#include <vector> | |
-#include <functional> | |
-#include <ATen/ATen.h> | |
-#include "torch/csrc/jit/ir.h" | |
-#include "torch/csrc/jit/tracer_state.h" | |
-#include "torch/csrc/autograd/function_hook.h" | |
-#include "torch/csrc/utils/auto_unique_ptr.h" | |
-#include "torch/csrc/autograd/variable_version.h" | |
-#include "torch/csrc/Types.h" | |
+namespace torch { | |
+namespace autograd { | |
+struct Function; | |
+} // namespace autograd | |
+namespace jit { namespace tracer { | |
+// Has to be forward declared because tracer_state.h has a dependency on | |
+// variable.h. | |
+struct ValueTracingStateElem; | |
+using ValueTracingState = std::list<ValueTracingStateElem>; | |
+}} // namespace jit::tracer | |
+} // namespace torch | |
namespace torch { namespace autograd { | |
-using at::Tensor; | |
-struct VariableImpl; | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// Variable | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// A `Variable` augments a `Tensor` with the ability to interact in our | |
+/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between | |
+/// `Function`s in the autograd graph. A `Variable` can either be a leaf, like a | |
+/// weight in a neural network, or an interior variable, when it is the result | |
+/// of an operation between variables. Every `Variable` also stores another | |
+/// `Variable` called its `grad` (gradient). If the variable is a leaf, its | |
+/// gradient will be accumulated into this variable. | |
+/// | |
+/// Gradient Edges | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the | |
+/// edge in the autograd graph that connects the variable to a particular input | |
+/// of the gradient function that will be invoked with the variable during the | |
+/// backward pass. More precisely, this gradient function can be one of two | |
+/// things: | |
+/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the | |
+/// gradient of the function that produced the variable. | |
+/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a | |
+/// scalar gradient value into its `grad` variable. | |
+/// | |
+/// Versioning | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// Another major feature of `Variable`s are *versions*. Versions are | |
+/// incremented when an in-place mutation of a variable occurs. Versions are | |
+/// useful when constructing `SavedVariable`s, which take a snapshot of a | |
+/// `Variable` at a certain version. You can retrieve a `Variable`'s version | |
+/// through its `current_version()` method. | |
+/// | |
+/// Views | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// It is possible for a `Variable` to be a *view* of another `Variable`, in | |
+/// which case it tracks that `Variable`'s data and autograd history. Beyond | |
+/// construction, the interface of a view is identical to that of a regular | |
+/// `Variable`. You can determine whether `Variable` is in fact a view by | |
+/// probing its `is_view()` method. | |
+/// | |
+/// Interface | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+/// `Variable` inherits from `Tensor` and thus its API is a superset of that of | |
+/// `Tensor`. This means you can perform all the usual mathematical and other | |
+/// operations you can perform on `Tensor`s also on `Variable`s. Furthermore, | |
+/// `Variable` and `Tensor` actually convert implicitly between each other. You | |
+/// can thus call functions defined on `Tensor`s also with `Variable`s. For | |
+/// this, the `Variable` class allows implicit construction from `Tensor`. It is | |
+/// the responsibility of calling code to ensure that this constructor is | |
+/// invoked only when the `Tensor`'s dynamic type is actually `Variable`. Most | |
+/// notably, it is *not* correct to construct a brand new `Variable` from a | |
+/// `Tensor` using this constructor. To do so, you must use the `make_variable` | |
+/// free function instead. To create a view variable, use `make_variable_view`. | |
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
struct Variable : public at::Tensor { | |
- inline Variable(VariableImpl * self, bool retain); | |
- Variable() : Tensor() {} | |
- Variable(const Variable & rhs) : Tensor(rhs) {} | |
- Variable(Variable && rhs) noexcept : Tensor(std::move(rhs)) {} | |
+ /// Default constructor. | |
+ Variable() = default; | |
+ | |
+ // Factory Functions | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ // NOTE: These factory functions have to be friends to access the | |
+ // `Variable::Impl`. As a side effect, it allows us to keep them in the class. | |
+ | |
+ /// Creates a `Variable` that is a *view* of another (*base*) variable. | |
+ /// The `gradient_edge` is an optional (gradient_function, input_number) pair. | |
+ friend Variable | |
+ make_variable_view(Variable base, at::Tensor data, Edge gradient_edge); | |
+ | |
+ /// Creates a `Variable` from the given `Tensor`. `requires_grad` should be | |
+ /// set only for leaves, and determines whether the `Variable` will accumulate | |
+ /// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic | |
+ /// type *must* be `Tensor`. | |
+ friend Variable make_variable(at::Tensor data, bool requires_grad); | |
+ | |
+ /// Creates a `Variable` from the given `Tensor` and specify a | |
+ /// `gradient_edge`, i.e. a (function, input_nr) pair specifying the function | |
+ /// in the autograd graph, and what particular input of that function, this | |
+ /// variable is connected to. | |
+ friend Variable make_variable(at::Tensor data, Edge gradient_edge); | |
+ | |
+ // Tensor Conversions | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ // "Downcasts" a `Tensor` into a `Variable`. Only call this on tensors you | |
+ // know are Variables. | |
+ /*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {} | |
+ /*implicit*/ Variable(at::Tensor&& rhs) noexcept | |
+ : at::Tensor(std::move(rhs)) {} | |
+ | |
+ // NOTE: Assignment operators to Tensor come for free from the constructors. | |
+ | |
+ /// Downcasts the `Tensor` reference to a `Variable` reference. If compiling | |
+ /// in DEBUG mode and the tensor's dynamic type is not in fact `Variable`, | |
+ /// throws a `std::runtime_error` exception. | |
+ /// NOTE: Has to be a friend function because runtime type information is | |
+ /// available only for `TensorImpl`/`Impl` and not the `Tensor`/`Variable` | |
+ /// classes, as the latter are not polymorphic classes (`Tensor` has no | |
+ /// virtual methods). | |
+ friend Variable& as_variable_ref(at::Tensor& tensor); | |
+ | |
+ const at::Tensor& data() const noexcept; | |
+ at::Tensor& data() noexcept; | |
+ | |
+ // Gradient Function and Edges | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ /// Gets the gradient function of the `Variable`. If this is a leaf variable, | |
+ /// the pointer returned will be null. | |
+ const std::shared_ptr<Function>& grad_fn() const; | |
+ | |
+ /// Gets the raw gradient function pointer, whatever it currently is. | |
+ Function* grad_fn_unsafe() const; | |
+ | |
+ /// Set the gradient accumulator of the `Variable`. This is only applicable to | |
+ /// leaf variables. Interior variables should call `set_gradient_edge()`. | |
+ void set_grad_accumulator(std::weak_ptr<Function> grad_accumulator); | |
+ | |
+ /// Attempts to get a pointer to the gradient accumulator of the `Variable`, | |
+ /// if it still exists. If the gradient accumulator function has been | |
+ /// destroyed, returns a `nullptr`. | |
+ std::shared_ptr<Function> try_get_grad_accumulator() const; | |
+ | |
+ /// Gets the gradient accumulator of the `Variable` if it has one, or else | |
+ /// create one on the fly and return it. | |
+ std::shared_ptr<Function> grad_accumulator() const; | |
- // Implicitly casts a Tensor to a Variable. This should only be called on | |
- // Tensors which you know are actually Variables. | |
- /*implicit*/ Variable(Tensor const & rhs) : Tensor(rhs) {} | |
- /*implicit*/ Variable(Tensor && rhs) noexcept : Tensor(std::move(rhs)) {} | |
+ /// Returns the "canonical" gradient edge of this `Variable`, i.e. either the | |
+ /// gradient function if this is an interior `Variable`, or the gradient | |
+ /// accumulator otherwise. If the `Variable` is interior, the returned `Edge` | |
+ /// will store the input index of the `Function` to which this variable is | |
+ /// connected in its `input_nr` field. For leaves, the `input_nr` is always | |
+ /// zero. Note that `set_gradient_edge` and `gradient_edge` are not | |
+ /// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and | |
+ /// `set_grad_accumulator` to set the accumulator. | |
+ Edge gradient_edge() const { | |
+ // If grad_fn is null (as is the case for a leaf node), we instead | |
+ // interpret the gradient function to be a gradient accumulator, which will | |
+ // accumulate its inputs into the grad property of the variable. These | |
+ // nodes get suppressed in some situations, see "suppress gradient | |
+ // accumulation" below. Note that only variables which have `requires_grad = | |
+ // True` can have gradient accumulators. | |
+ if (const auto& gradient = grad_fn()) { | |
+ return Edge(gradient, output_nr()); | |
+ } else { | |
+ return Edge(grad_accumulator(), 0); | |
+ } | |
+ } | |
- inline VariableImpl* get() const; | |
+ /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the | |
+ /// `Variable`. | |
+ /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable, | |
+ /// and never the `grad_accumulator`. For the latter, use | |
+ /// `set_grad_accumulator`. This allows late construction of an interior | |
+ /// `Variable`. | |
+ void set_gradient_edge(Edge edge) noexcept; | |
- inline const Tensor & data() const; | |
- inline Tensor & data(); | |
+ /// Returns the input index of the gradient `Function` to which this | |
+ /// `Variable` is connected. | |
+ uint32_t output_nr() const noexcept; | |
- inline Tensor opt_data() const; | |
+ /// True if this `Variable` is a leaf and thus does not have a `grad_fn`. | |
+ bool is_leaf() const noexcept; | |
- inline const Variable & grad() const; | |
- inline Variable & grad(); | |
+ // The Grad Variable | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
- inline const std::shared_ptr<Function>& grad_fn() const; | |
- inline std::shared_ptr<Function>& grad_fn(); | |
+ /// Accesses the gradient `Variable` of this `Variable`. | |
+ const Variable& grad() const noexcept; | |
+ Variable& grad() noexcept; | |
+ void reset_grad() noexcept; | |
- std::shared_ptr<Function> grad_accumulator() const; | |
+ /// Sets the `requires_grad` property of `Variable`. This should be true for | |
+ /// leaf variables that want to accumulate gradients, and false for all other | |
+ /// variables. | |
+ void set_requires_grad(bool requires_grad) noexcept; | |
+ bool requires_grad() const noexcept; | |
- inline const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const; | |
- inline std::vector<std::shared_ptr<FunctionPreHook>>& hooks(); | |
+ // Versions | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
- inline auto_unique_ptr<jit::tracer::ValueTracingState>& tracing_state() const; | |
+ /// Increments the version count of this `Variable`. | |
+ void bump_version() noexcept; | |
+ void set_version_counter(const VariableVersion& version_counter) noexcept; | |
- inline int current_version() const; | |
+ /// Retrieves this `Variable`s version counter. | |
+ const VariableVersion& version_counter() const noexcept; | |
- inline VariableVersion& version_counter() const; | |
+ /// Retrieves the current value of the `Variable`'s version counter. | |
+ /// Equivalent to calling `version_counter().current_version()`. | |
+ uint32_t current_version() const noexcept; | |
- inline const int& output_nr() const; | |
- inline int& output_nr(); | |
+ // Autograd Graph Interaction | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ /// Update the `grad_fn` of an existing Variable. Called after in-place | |
+ /// modifications. | |
+ void rebase_history(Edge gradient_edge); | |
+ | |
+ /// Returns a copy of this `Variable` that is detached from its autograd graph | |
+ /// and has a blank version. This method is OK to call if the `Variable` is a | |
+ /// view. | |
+ Variable detach() const; | |
+ | |
+ /// Like `detach()`, but removes this `Variable` in-place. This method may | |
+ /// only be called on non-view `Variable`s. You can use `is_view()` to check | |
+ /// this. If this `Variable` is a view, throws an `std::runtime_error()`. | |
+ void detach_(); | |
+ | |
+ // Hooks | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ void add_hook(std::shared_ptr<FunctionPreHook> hook); | |
+ const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept; | |
+ void clear_hooks(); | |
+ | |
+ // JIT Tracing | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ void set_tracing_state(jit::tracer::ValueTracingState* new_tracing_state); | |
+ jit::tracer::ValueTracingState& tracing_state() const noexcept; | |
+ | |
+ /// Returns true if the `Variable`'s tracing state is not null. | |
+ bool has_tracing_state() const noexcept; | |
+ | |
+ // View Variables | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ /// Returns true if this `Variable` is a view of another `Variable`. | |
+ bool is_view() const noexcept; | |
+ | |
+ /// Returns the `Variable` that this `Variable` is a view of. If this | |
+ /// `Variable` is not a view, throw a `std::runtime_error`. | |
+ const Variable& base() const; | |
+ | |
+ // Miscellaneous | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ /// Compares this `Variable` to another `Variable` (or `Tensor`) via | |
+ /// pointer-equality. | |
+ bool is_same(const Variable& other) const noexcept { | |
+ return this->pImpl == other.pImpl; | |
+ } | |
- inline const bool& requires_grad() const; | |
- inline bool& requires_grad(); | |
+ void set_name(const std::string& name); | |
+ const std::string& name() const noexcept; | |
- inline const bool& is_volatile() const; | |
- inline bool& is_volatile(); | |
+ PyObject* pyobj() const noexcept; | |
+ void set_pyobj(PyObject* pyobj) noexcept; | |
- inline Variable & operator=(Variable && rhs) &; | |
- inline Variable & operator=(const Variable & rhs) &; | |
- inline Variable & operator=(Tensor && rhs) &; | |
- inline Variable & operator=(const Tensor & rhs) &; | |
+ // Hacks! | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ /// Sets the type of the underlying `Tensor`. Used for a bad (hopefully) | |
+ /// temporary hack in python_variable.h. If removed, also remove the `using | |
+ /// at::TensorImpl::type_;` in `Variable::Impl`. | |
+ void temporary_hack_set_type(at::Type*) noexcept; | |
+ | |
+ private: | |
+ /// Private implementation struct of the `Variable`. This struct declaration | |
+ /// and the `get()` method which exposes it shall forever remain private and | |
+ /// never be exposed to the public interface of this class. | |
+ struct Impl; | |
+ struct ViewImpl; | |
+ | |
+ // Private Methods | |
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+ Variable(Variable::Impl* self, bool retain); | |
+ Impl* get() const noexcept; | |
}; | |
-struct VariableImpl : public at::TensorImpl { | |
-public: | |
- explicit VariableImpl(at::Tensor data); | |
- VariableImpl(at::Tensor data, std::shared_ptr<Function> grad_fn); | |
- VariableImpl(at::Tensor data, bool requires_grad, bool is_volatile=false); | |
- virtual ~VariableImpl(); | |
- virtual const char * toString() const override; | |
- virtual at::IntList sizes() const override; | |
- virtual at::IntList strides() const override; | |
- virtual int64_t dim() const override; | |
- virtual at::Scalar localScalar() override; | |
- virtual void assign_(at::Scalar s) override; | |
- virtual void * unsafeGetTH(bool retain) override; | |
- static const char * typeString(); | |
- | |
- // Get the VariableType for a base Tensor type | |
- static at::Type* getType(const at::Type& baseType); | |
- static at::Type* getType(const at::Tensor& tensor); | |
- | |
-public: | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+// Variable::Impl | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+struct Variable::Impl : public at::TensorImpl { | |
+ explicit Impl( | |
+ at::Tensor data_, | |
+ bool requires_grad_ = false, | |
+ Edge edge = Edge()); | |
+ | |
+ virtual ~Impl(); | |
+ | |
+ const char* toString() const override; | |
+ at::IntList sizes() const override; | |
+ at::IntList strides() const override; | |
+ int64_t dim() const override; | |
+ at::Scalar localScalar() override; | |
+ void* unsafeGetTH(bool retain) override; | |
+ std::unique_ptr<at::Storage> storage() override; | |
+ static const char* typeString(); | |
+ | |
std::shared_ptr<Function> get_grad_accumulator(); | |
+ virtual std::shared_ptr<Function>& get_grad_fn() { | |
+ return grad_fn; | |
+ } | |
+ // Make this field public so we can access it from `Variable`. Part of | |
+ // temporary_hack_set_type. | |
+ using at::TensorImpl::type_; | |
+ | |
+ std::string name; | |
at::Tensor data; | |
+ | |
Variable grad; | |
std::shared_ptr<Function> grad_fn; | |
+ std::weak_ptr<Function> grad_accumulator; | |
+ | |
VariableVersion version_counter; | |
std::vector<std::shared_ptr<FunctionPreHook>> hooks; | |
- std::weak_ptr<Function> grad_accumulator; | |
- std::mutex grad_accumulator_lock; | |
- bool requires_grad; | |
- bool is_volatile; | |
+ | |
+ bool requires_grad; // only meaningful on leaf variables (must be false | |
+ // otherwise) | |
+ bool is_view; | |
// The "output number" of this variable; e.g., if this variable | |
// was the second output of a function, then output_nr == 1. | |
// We use this to make sure we can setup the backwards trace | |
// correctly when this variable is passed to another function. | |
- int output_nr; | |
- PyObject *pyobj; // weak reference | |
+ uint32_t output_nr; | |
+ PyObject* pyobj; // weak reference | |
+ | |
+ // Mutex to ensure that concurrent read operations that modify internal | |
+ // state are still thread-safe. Used by get_grad_fn and | |
+ // get_grad_accumulator. | |
+ std::mutex mutex; | |
// For use in torch::jit::tracer | |
auto_unique_ptr<jit::tracer::ValueTracingState> tracing_state; | |
- friend struct VariableType; | |
}; | |
-inline Variable make_variable(at::Tensor data) { | |
- return Variable(new VariableImpl(std::move(data)), false); | |
-} | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+// Variable::ViewImpl | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+/// A Variable that is a view on another Variable. The base and view share the | |
+/// same version_counter. The grad_fn field of the Variable may become stale | |
+/// due to in-place modifications of the shared data. Accesses should go | |
+/// through get_grad_fn(). All other fields are always valid. | |
+struct Variable::ViewImpl : public Variable::Impl { | |
+ ViewImpl(Variable base_, at::Tensor data_, Edge gradient_edge); | |
+ | |
+ /// Gets the up-to-date grad_fn. If the shared data or base was modified, we | |
+ /// re-create the grad_fn to express the up-to-date view relationship between | |
+ /// this and the base Variable. | |
+ virtual std::shared_ptr<Function>& get_grad_fn() override; | |
+ | |
+ /// Called after in-place modifications. Modifies the grad_fn of the base | |
+ /// Variable. | |
+ void rebase_history(Edge gradient_edge); | |
+ | |
+ /// The base `Variable` (never a view). | |
+ Variable base; | |
+ | |
+ /// The value of the version_counter at the time grad_fn was created. The | |
+ /// grad_fn field is stale if attr_version != | |
+ /// version_counter.current_version(). | |
+ uint32_t attr_version; | |
+}; | |
-inline Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn) { | |
- return Variable(new VariableImpl(std::move(data), std::move(grad_fn)), false); | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+// Variable Implementation | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+// Factory Functions | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline Variable make_variable_view( | |
+ Variable base, | |
+ at::Tensor data, | |
+ Edge gradient_edge = Edge()) { | |
+ if (data.defined()) { | |
+ auto impl = new Variable::ViewImpl( | |
+ std::move(base), std::move(data), std::move(gradient_edge)); | |
+ return Variable(impl, /*retain=*/false); | |
+ } | |
+ return Variable(); | |
} | |
-inline Variable make_variable(at::Tensor data, bool requires_grad, bool is_volatile=false) { | |
- return Variable(new VariableImpl(std::move(data), requires_grad, is_volatile), false); | |
+inline Variable make_variable(at::Tensor data, bool requires_grad = false) { | |
+ if (data.defined()) { | |
+ auto impl = new Variable::Impl(data, requires_grad); | |
+ return Variable(impl, /*retain=*/false); | |
+ } | |
+ return Variable(); | |
} | |
- | |
-inline Variable::Variable(VariableImpl * self, bool retain) : Tensor(self, retain) { | |
+inline Variable make_variable(at::Tensor data, Edge gradient_edge) { | |
+ if (data.defined()) { | |
+ auto impl = new Variable::Impl(data, false, std::move(gradient_edge)); | |
+ return Variable(impl, /*retain=*/false); | |
+ } | |
+ return Variable(); | |
} | |
-inline VariableImpl* Variable::get() const { | |
- return static_cast<VariableImpl*>(pImpl); | |
+// Tensor Conversion | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline Variable& as_variable_ref(at::Tensor& tensor) { | |
+#ifdef DEBUG | |
+ // dynamic_cast will return a nullptr if the `TensorImpl`'s dynamic type is | |
+ // not `Variable::Impl`. | |
+ if (dynamic_cast<Variable::Impl*>(tensor.get()) == nullptr) { | |
+ throw std::runtime_error( | |
+ "Attempted to cast a Tensor to a Variable, but " | |
+ "the dynamic type of the value is not Variable."); | |
+ } | |
+#endif | |
+ return static_cast<Variable&>(tensor); | |
} | |
-inline const Tensor & Variable::data() const { | |
+inline const at::Tensor& Variable::data() const noexcept { | |
return get()->data; | |
} | |
-inline Tensor & Variable::data() { | |
+ | |
+inline at::Tensor& Variable::data() noexcept { | |
return get()->data; | |
} | |
-inline Tensor Variable::opt_data() const { | |
- if (!defined()) { | |
- return Tensor(); | |
- } | |
- return data(); | |
+// Gradient Function and Edges | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline const std::shared_ptr<Function>& Variable::grad_fn() const { | |
+ return get()->get_grad_fn(); | |
} | |
-inline const Variable & Variable::grad() const { | |
- return get()->grad; | |
+inline Function* Variable::grad_fn_unsafe() const { | |
+ return get()->grad_fn.get(); | |
} | |
-inline Variable & Variable::grad() { | |
- return get()->grad; | |
+ | |
+inline void Variable::set_grad_accumulator( | |
+ std::weak_ptr<Function> grad_accumulator) { | |
+ get()->grad_accumulator = std::move(grad_accumulator); | |
+} | |
+ | |
+inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const { | |
+ return get()->grad_accumulator.lock(); | |
} | |
-inline const std::shared_ptr<Function>& Variable::grad_fn() const { | |
- return get()->grad_fn; | |
-}; | |
-inline std::shared_ptr<Function>& Variable::grad_fn() { | |
- return get()->grad_fn; | |
-}; | |
inline std::shared_ptr<Function> Variable::grad_accumulator() const { | |
return get()->get_grad_accumulator(); | |
-}; | |
+} | |
-inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() const { | |
- return get()->hooks; | |
-}; | |
-inline std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() { | |
- return get()->hooks; | |
-}; | |
+inline void Variable::set_gradient_edge(Edge edge) noexcept { | |
+ get()->grad_fn = std::move(edge.function); | |
+ get()->output_nr = edge.input_nr; | |
+} | |
-inline auto_unique_ptr<jit::tracer::ValueTracingState>& Variable::tracing_state() const { | |
- return get()->tracing_state; | |
-}; | |
+inline uint32_t Variable::output_nr() const noexcept { | |
+ return get()->output_nr; | |
+} | |
+ | |
+inline bool Variable::is_leaf() const noexcept { | |
+ return get()->grad_fn == nullptr; | |
+} | |
+ | |
+// The Grad Variable | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline const Variable& Variable::grad() const noexcept { | |
+ return get()->grad; | |
+} | |
+ | |
+inline Variable& Variable::grad() noexcept { | |
+ return get()->grad; | |
+} | |
+ | |
+inline void Variable::reset_grad() noexcept { | |
+ get()->grad.reset(); | |
+} | |
-inline int Variable::current_version() const { | |
+inline void Variable::set_requires_grad(bool requires_grad) noexcept { | |
+ get()->requires_grad = requires_grad; | |
+} | |
+ | |
+inline bool Variable::requires_grad() const noexcept { | |
+ return get()->requires_grad || get()->grad_fn || | |
+ (is_view() && base().requires_grad()); | |
+} | |
+ | |
+// Versions | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline void Variable::set_version_counter( | |
+ const VariableVersion& version_counter) noexcept { | |
+ get()->version_counter = version_counter; | |
+} | |
+ | |
+inline void Variable::bump_version() noexcept { | |
+ get()->version_counter.bump(); | |
+} | |
+ | |
+inline uint32_t Variable::current_version() const noexcept { | |
return get()->version_counter.current_version(); | |
} | |
-inline VariableVersion& Variable::version_counter() const { | |
+inline const VariableVersion& Variable::version_counter() const noexcept { | |
return get()->version_counter; | |
} | |
-inline const int& Variable::output_nr() const { | |
- return get()->output_nr; | |
+// Hooks | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) { | |
+ get()->hooks.push_back(std::move(hook)); | |
} | |
-inline int& Variable::output_nr() { | |
- return get()->output_nr; | |
+inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() | |
+ const noexcept { | |
+ return get()->hooks; | |
+} | |
+ | |
+inline void Variable::clear_hooks() { | |
+ get()->hooks.clear(); | |
} | |
-inline const bool& Variable::requires_grad() const { | |
- return get()->requires_grad; | |
+// JIT Tracing | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline bool Variable::has_tracing_state() const noexcept { | |
+ return get()->tracing_state != nullptr; | |
} | |
-inline bool& Variable::requires_grad() { | |
- return get()->requires_grad; | |
+ | |
+// View Variables | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline bool Variable::is_view() const noexcept { | |
+ return get()->is_view; | |
} | |
-inline const bool& Variable::is_volatile() const { | |
- return get()->is_volatile; | |
+inline const Variable& Variable::base() const { | |
+ if (is_view()) { | |
+ return static_cast<Variable::ViewImpl*>(get())->base; | |
+ } | |
+ throw std::runtime_error("Can't get base of non-view"); | |
} | |
-inline bool& Variable::is_volatile() { | |
- return get()->is_volatile; | |
+ | |
+// Miscellaneous | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline void Variable::set_name(const std::string& name) { | |
+ get()->name = name; | |
} | |
-inline Variable & Variable::operator=(Variable && rhs) & { | |
- rhs.swap(*this); | |
- return *this; | |
+inline const std::string& Variable::name() const noexcept { | |
+ return get()->name; | |
} | |
-inline Variable & Variable::operator=(const Variable & rhs) & { | |
- Variable(rhs).swap(*this); | |
- return *this; | |
+ | |
+inline void Variable::set_pyobj(PyObject* pyobj) noexcept { | |
+ get()->pyobj = pyobj; | |
} | |
-inline Variable & Variable::operator=(Tensor && rhs) & { | |
- rhs.swap(*this); | |
- return *this; | |
+ | |
+inline PyObject* Variable::pyobj() const noexcept { | |
+ return get()->pyobj; | |
} | |
-inline Variable & Variable::operator=(const Tensor & rhs) & { | |
- Variable(rhs).swap(*this); | |
- return *this; | |
+ | |
+// Hacks! | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline void Variable::temporary_hack_set_type(at::Type* new_type) noexcept { | |
+ get()->type_ = new_type; | |
} | |
+// Private Methods | |
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
+ | |
+inline Variable::Variable(Variable::Impl* self, bool retain) | |
+ : at::Tensor(self, retain) {} | |
+ | |
+inline Variable::Impl* Variable::get() const noexcept { | |
+ return static_cast<Variable::Impl*>(pImpl); | |
+} | |
}} // namespace torch::autograd | |
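
The doc comment above describes gradient_edge() as falling back from the grad_fn (interior variables) to the grad accumulator (leaves). A self-contained mock of just that fallback; the Mock* types are stand-ins, not the real classes:

#include <cstdint>
#include <iostream>
#include <memory>

struct MockFunction {};

struct MockEdge {
  std::shared_ptr<MockFunction> function;
  uint32_t input_nr;
};

struct MockVariable {
  std::shared_ptr<MockFunction> grad_fn;
  std::shared_ptr<MockFunction> grad_accumulator;
  uint32_t output_nr;

  MockEdge gradient_edge() const {
    if (grad_fn) {
      return {grad_fn, output_nr};   // interior: nth input of its grad_fn
    }
    return {grad_accumulator, 0};    // leaf: accumulator, always slot 0
  }
};

int main() {
  MockVariable leaf{nullptr, std::make_shared<MockFunction>(), 0};
  MockVariable interior{std::make_shared<MockFunction>(), nullptr, 1};
  std::cout << leaf.gradient_edge().input_nr << ' '
            << interior.gradient_edge().input_nr << '\n';  // prints: 0 1
}
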
diff --git a/torch/csrc/autograd/variable_version.h b/torch/csrc/autograd/variable_version.h | |
index 3dc28bd1..a5c5d9d3 100644 | |
--- a/torch/csrc/autograd/variable_version.h | |
+++ b/torch/csrc/autograd/variable_version.h | |
@@ -1,98 +1,38 @@ | |
#pragma once | |
+#include <atomic> | |
+#include <cstdint> | |
#include <memory> | |
-namespace torch { namespace autograd { | |
- | |
-struct VersionBlock { | |
- VersionBlock() : version(), live_refs(1) {} | |
- | |
- // monotonically increasing version | |
- std::atomic<int> version; | |
- // number of references excluding SavedVariables | |
- std::atomic<int> live_refs; | |
-}; | |
+// Every Variable has a version counter. Version counters are incremented | |
+// whenever the data or shape of a tensor changes through Variable operations. | |
+// These are typically in-place operations. Version counters are used to | |
+// detect modifications to saved variables which would result in incorrect | |
+// gradient calculations. Version counters may be shared between Variables: | |
+// | |
+// 1. A view shares the version counter of the base Variable, | |
+// 2. Detached variables share the version counter of the source, | |
+// 3. Unpacked saved variables share the version counter of the source. | |
-struct SavedVersion; | |
+namespace torch { namespace autograd { | |
struct VariableVersion { | |
- VariableVersion() : version_block(std::make_shared<VersionBlock>()) {} | |
- VariableVersion(const VariableVersion&) = delete; | |
- VariableVersion(VariableVersion&&) = delete; | |
- | |
- ~VariableVersion() { | |
- --version_block->live_refs; | |
+ public: | |
+ // NOTE: As of C++11 and 14, default-constructing a std::atomic variable | |
+ // leaves it in a persistently undefined state. See | |
+ // https://cplusplus.github.io/LWG/issue2334. | |
+ VariableVersion(uint32_t version = 0) | |
+ : version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {} | |
+ | |
+ void bump() noexcept { | |
+ version_block_->fetch_add(1); | |
} | |
- // increment the version counter | |
- void increment() { version_block->version++; } | |
- | |
- // current version | |
- int current_version() const { return version_block->version.load(); } | |
- | |
- // number of variables using this version counter (excludes SavedVariables) | |
- int live_refs() const { return version_block->live_refs.load(); } | |
- | |
- // creates a saved reference with the current version and the counter | |
- inline SavedVersion save() const; | |
- | |
- // Uses another variable's version counter. Used for variables which share storages | |
- // NOTE: not thread-safe to call this from multiple threads without synchronization | |
- VariableVersion& operator=(const VariableVersion& other) { | |
- other.version_block->live_refs++; | |
- version_block->live_refs--; | |
- version_block = other.version_block; | |
- return *this; | |
+ uint32_t current_version() const noexcept { | |
+ return version_block_->load(); | |
} | |
- // Uses the version counter from a SavedVariable | |
- // NOTE: not thread-safe to call this from multiple threads without synchronization | |
- inline VariableVersion& operator=(const SavedVersion& other); | |
- | |
-private: | |
- friend struct SavedVersion; | |
- std::shared_ptr<VersionBlock> version_block; // always non-null | |
+ private: | |
+ std::shared_ptr<std::atomic<uint32_t>> version_block_; | |
}; | |
- | |
-// The version counter used in SavedVariables. Saves the expected_version (the | |
-// version at the time of save) and a reference to the version counter's | |
-// version_block. | |
-struct SavedVersion { | |
- SavedVersion() {} | |
- SavedVersion(const VariableVersion& version) | |
- : expected_version(version.current_version()) | |
- , version_block(version.version_block) {} | |
- | |
- // if the version_block has been modified since when it was saved | |
- bool is_modified() const { | |
- return expected_version != version_block->version.load(); | |
- } | |
- | |
- // true if the version_block is defined | |
- bool defined() const { | |
- return static_cast<bool>(version_block); | |
- } | |
- | |
-private: | |
- friend struct VariableVersion; | |
- int expected_version; | |
- std::shared_ptr<VersionBlock> version_block; // may be null | |
-}; | |
- | |
-SavedVersion VariableVersion::save() const { | |
- return SavedVersion(*this); | |
-} | |
- | |
-VariableVersion& VariableVersion::operator=(const SavedVersion& other) { | |
- if (!other.version_block) { | |
- throw std::runtime_error( | |
- "Can't take version counter from empty SavedVersion. File a bug report."); | |
- } | |
- other.version_block->live_refs++; | |
- version_block->live_refs--; | |
- version_block = other.version_block; | |
- return *this; | |
-} | |
- | |
- | |
}} // namespace torch::autograd |
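
The new VariableVersion is just a shared std::atomic<uint32_t>, which is what makes rules (1)-(3) in the comment above work: copies of the counter object observe the same bumps. A standalone sketch of that sharing (illustrative only):

#include <atomic>
#include <cstdint>
#include <iostream>
#include <memory>

struct Counter {
  std::shared_ptr<std::atomic<uint32_t>> block;
  void bump() noexcept { block->fetch_add(1); }
  uint32_t current_version() const noexcept { return block->load(); }
};

int main() {
  Counter base{std::make_shared<std::atomic<uint32_t>>(0)};
  Counter view = base;                          // a view shares the same block
  view.bump();                                  // in-place op through the view...
  std::cout << base.current_version() << '\n';  // ...is visible to the base: 1
}
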