@goldsborough
Created March 17, 2018 06:17
diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h
index f04b472e..7ff9a39f 100644
--- a/torch/csrc/autograd/autograd.h
+++ b/torch/csrc/autograd/autograd.h
@@ -2,12 +2,14 @@
#define THP_AUTOGRAD_H
PyObject * THPAutograd_initExtension(PyObject *_unused);
-bool THPAutograd_initFunctions(PyObject* module);
+void THPAutograd_initFunctions();
namespace torch { namespace autograd {
void initAutogradClosureBindings(PyObject* module);
+PyMethodDef* python_functions();
+
}}
#include "torch/csrc/autograd/python_function.h"
diff --git a/torch/csrc/autograd/edge.h b/torch/csrc/autograd/edge.h
new file mode 100644
index 00000000..e6b34911
--- /dev/null
+++ b/torch/csrc/autograd/edge.h
@@ -0,0 +1,56 @@
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+
+#include "torch/csrc/utils/hash.h"
+
+namespace torch { namespace autograd {
+
+struct Function;
+
+/// Represents a particular input of a function.
+struct Edge {
+ Edge() noexcept : function(nullptr), input_nr(0) {}
+
+ Edge(std::shared_ptr<Function> function_, uint32_t input_nr_) noexcept
+ : function(std::move(function_)), input_nr(input_nr_) {}
+
+ /// Convenience method to test if an edge is valid.
+ bool is_valid() const noexcept {
+ return function != nullptr;
+ }
+
+ // Required for use in associative containers.
+ bool operator==(const Edge& other) const noexcept {
+ return this->function == other.function && this->input_nr == other.input_nr;
+ }
+
+ bool operator!=(const Edge& other) const noexcept {
+ return !(*this == other);
+ }
+
+ /// The function this `Edge` points to.
+ std::shared_ptr<Function> function;
+
+ /// The identifier of a particular input to the function.
+ uint32_t input_nr;
+};
+}} // namespace torch::autograd
+
+// The idiomatic way of enabling use of a custom type as the key of hash
+// containers in C++11. This method removes the requirement of having to pass
+// a custom hasher to std::unordered_{map, set}.
+// See http://en.cppreference.com/w/cpp/utility/hash for more information.
+namespace std {
+template <>
+struct hash<torch::autograd::Edge> {
+ // These type aliases are required by the standard.
+ using argument_type = torch::autograd::Edge;
+ using return_type = size_t;
+ return_type operator()(const argument_type& edge) const noexcept {
+ return torch::get_hash(edge.function, edge.input_nr);
+ }
+};
+} // namespace std
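
The new edge.h above gives `Edge` value semantics, equality, and a `std::hash` specialization, so graph code can key the standard unordered containers on edges directly. A minimal sketch of what that enables (illustrative only, not part of the patch; it assumes the header introduced above is on the include path):

#include <cstdint>
#include <memory>
#include <unordered_set>
#include "torch/csrc/autograd/edge.h"

using torch::autograd::Edge;
using torch::autograd::Function;

// Track which (function, input_nr) pairs have been visited, with no
// custom hasher passed to the container.
bool visit(std::unordered_set<Edge>& seen,
           const std::shared_ptr<Function>& fn,
           uint32_t input_nr) {
  // insert() returns {iterator, inserted}; duplicate edges hash and
  // compare equal, so they are inserted only once.
  return seen.insert(Edge(fn, input_nr)).second;
}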
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 756f56e4..369ca601 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -1,5 +1,9 @@
#include "torch/csrc/autograd/engine.h"
+
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/functions/basic_ops.h"
+#include "torch/csrc/autograd/grad_mode.h"
+#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/auto_gpu.h"
#include <atomic>
@@ -7,6 +11,7 @@
#include <cstdint>
#include <functional>
#include <iostream>
+#include <memory>
#include <mutex>
#include <set>
#include <string>
@@ -14,6 +19,7 @@
#include <unordered_set>
#include <typeinfo>
#include <sstream>
+#include <queue>
#include <TH/TH.h>
#ifdef WITH_CUDA
@@ -48,13 +54,19 @@ struct FunctionTask {
, inputs(std::move(inputs)) {}
};
+struct CompareFunctionTaskTime {
+ bool operator()(FunctionTask const & t1, FunctionTask const & t2) {
+ return t1.fn->sequence_nr() < t2.fn->sequence_nr();
+ }
+};
+
struct ReadyQueue {
- std::deque<FunctionTask> queue;
+ std::priority_queue<FunctionTask, std::vector<FunctionTask>, CompareFunctionTaskTime> heap;
std::condition_variable not_empty;
std::mutex mutex;
- void push_front(FunctionTask item);
- FunctionTask pop_back();
+ void push(FunctionTask item);
+ FunctionTask pop();
};
struct GraphTask {
@@ -64,47 +76,66 @@ struct GraphTask {
std::atomic_bool has_error;
std::atomic<uint64_t> outstanding_tasks;
bool keep_graph;
- bool has_any_work;
+ bool grad_mode;
std::mutex mutex;
// Notified when a task finishes executing. Check outstanding_tasks to see
// if all tasks are done.
std::condition_variable not_done;
- const Engine::pre_callback_map& pre_callbacks;
- const Engine::post_callback_map& post_callbacks;
std::unordered_map<Function*, InputBuffer> not_ready;
std::unordered_map<Function*, int> dependencies;
+ struct ExecInfo {
+ struct Capture {
+ Capture(int input_idx, int output_idx) : input_idx(input_idx), output_idx(output_idx) {}
+ int input_idx; // within Function inputs
+ int output_idx; // within the output vector of a GraphTask
+ };
+
+ bool should_execute() const {
+ return needed || captures;
+ }
+
+ bool needed = false;
+ std::unique_ptr<std::vector<Capture>> captures;
+ };
+ // Exec info has somewhat complicated semantics. If it's empty, the task is
+ // run in "default" mode, which means that all next_edges we encounter should
+ // get executed. If it's not empty, only functions that have an entry whose
+ // needed flag is true should be executed.
+ std::unordered_map<Function*, ExecInfo> exec_info;
+ std::vector<Variable> captured_vars;
+
+ void init_to_execute(Function& graph_root, const edge_list& captures);
+
int owner;
- GraphTask(bool keep_graph, const Engine::pre_callback_map& pre_callbacks, const Engine::post_callback_map& post_callbacks)
+ GraphTask(bool keep_graph, bool grad_mode)
: exception()
, has_error(false)
, outstanding_tasks(0)
, keep_graph(keep_graph)
- , has_any_work(false)
+ , grad_mode(grad_mode)
, mutex()
, not_done()
- , pre_callbacks(pre_callbacks)
- , post_callbacks(post_callbacks)
, not_ready()
, dependencies()
, owner(NO_DEVICE) {}
};
-auto ReadyQueue::push_front(FunctionTask item) -> void {
+auto ReadyQueue::push(FunctionTask item) -> void {
{
std::lock_guard<std::mutex> lock(mutex);
++item.base->outstanding_tasks;
- queue.push_front(std::move(item));
+ heap.push(std::move(item));
}
not_empty.notify_one();
}
-auto ReadyQueue::pop_back() -> FunctionTask {
+auto ReadyQueue::pop() -> FunctionTask {
std::unique_lock<std::mutex> lock(mutex);
- not_empty.wait(lock, [this]{ return !queue.empty(); });
- auto task = std::move(queue.back()); queue.pop_back();
+ not_empty.wait(lock, [this]{ return !heap.empty(); });
+ auto task = std::move(const_cast<FunctionTask&>(heap.top())); heap.pop();
return task;
}
@@ -138,8 +169,9 @@ auto Engine::thread_init(int device) -> void {
auto Engine::thread_main(GraphTask *graph_task) -> void {
auto queue = ready_queues[worker_device + 1];
while (!graph_task || graph_task->outstanding_tasks > 0) {
- FunctionTask task = queue->pop_back();
+ FunctionTask task = queue->pop();
if (task.fn && !task.base->has_error.load()) {
+ GradMode::set_enabled(task.base->grad_mode);
try {
evaluate_function(task);
} catch (std::exception& e) {
@@ -166,7 +198,7 @@ auto Engine::thread_main(GraphTask *graph_task) -> void {
if (--task.base->outstanding_tasks == 0) {
// Synchronize outstanding_tasks with queue mutex
std::atomic_thread_fence(std::memory_order_release);
- ready_queue(base_owner).push_front(FunctionTask(task.base, nullptr, InputBuffer(0)));
+ ready_queue(base_owner).push(FunctionTask(task.base, nullptr, InputBuffer(0)));
}
}
}
@@ -182,14 +214,14 @@ auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void
}
static variable_list call_pre_hooks(Function& fn, variable_list inputs) {
- for (auto& hook : fn.pre_hooks) {
+ for (const auto& hook : fn.pre_hooks()) {
inputs = (*hook)(inputs);
}
return inputs;
}
static variable_list call_post_hooks(Function& fn, variable_list outputs, variable_list inputs) {
- for (auto& hook : fn.post_hooks) {
+ for (const auto& hook : fn.post_hooks()) {
outputs = (*hook)(outputs, inputs);
}
return outputs;
@@ -199,61 +231,57 @@ static variable_list call_function(FunctionTask& task) {
auto& fn = *task.fn;
auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(task.inputs)));
- auto& pre_callbacks = task.base->pre_callbacks;
- for (auto it_p = pre_callbacks.equal_range(&fn); it_p.first != it_p.second; ++it_p.first) {
- auto& callback = it_p.first->second;
- if (!callback(&fn, inputs)) return variable_list(fn.next_functions.size());
+ if(!task.base->keep_graph) {
+ fn.will_release_variables();
}
-
auto outputs = fn(inputs);
- auto& post_callbacks = task.base->post_callbacks;
- for (auto it_p = post_callbacks.equal_range(&fn); it_p.first != it_p.second; ++it_p.first) {
- auto& callback = it_p.first->second;
- if (!callback(&fn, inputs, outputs)) return variable_list(fn.next_functions.size());
- }
-
return call_post_hooks(fn, std::move(outputs), std::move(inputs));
}
auto Engine::evaluate_function(FunctionTask& task) -> void {
+ // If exec_info is not empty, we have to instrument the execution
+ auto & exec_info = task.base->exec_info;
+ if (!exec_info.empty()) {
+ auto & fn_info = exec_info.at(task.fn.get());
+ if (auto *capture_vec = fn_info.captures.get()) {
+ std::lock_guard<std::mutex> lock(task.base->mutex);
+ for (auto capture : *capture_vec) {
+ task.base->captured_vars[capture.output_idx] = task.inputs[capture.input_idx];
+ }
+ }
+ if (!fn_info.needed) return;
+ }
+
auto outputs = call_function(task);
auto& fn = *task.fn;
if (!task.base->keep_graph) {
- fn.releaseVariables();
+ fn.release_variables();
}
- if (outputs.size() != fn.next_functions.size()) {
+ if (outputs.size() != fn.num_outputs()) {
std::stringstream ss;
ss << "Function '" << fn.name() << "' returned an invalid number of outputs - expected ";
- ss << fn.next_functions.size() << ", but got " << outputs.size();
+ ss << fn.num_outputs() << ", but got " << outputs.size();
throw std::runtime_error(ss.str());
}
int num_outputs = outputs.size();
+ if (num_outputs == 0) return; // Don't even acquire the mutex
+ std::lock_guard<std::mutex> lock(task.base->mutex);
for (int i = 0; i < num_outputs; ++i) {
auto& output = outputs[i];
- auto& next_fn = fn.next_functions[i].first;
- int input_nr = fn.next_functions[i].second;
-
- if (!next_fn) {
- continue;
- }
+ const auto& next = fn.next_edge(i);
- // Stochastic functions are placed in the ready queue by
- // compute_dependencies, so we have to skip them here.
- if (next_fn->is_stochastic || !next_fn->is_executable) {
- continue;
- }
+ if (!next.is_valid()) continue;
- std::lock_guard<std::mutex> lock(task.base->mutex);
// Check if the next function is ready to be computed
bool is_ready = false;
auto& dependencies = task.base->dependencies;
- auto it = dependencies.find(next_fn.get());
+ auto it = dependencies.find(next.function.get());
if (it == dependencies.end()) {
- auto name = next_fn->name();
+ auto name = next.function->name();
throw std::runtime_error(std::string("dependency not found for ") + name);
} else if (--it->second == 0) {
dependencies.erase(it);
@@ -261,72 +289,53 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
}
auto& not_ready = task.base->not_ready;
- auto not_ready_it = not_ready.find(next_fn.get());
+ auto not_ready_it = not_ready.find(next.function.get());
if (not_ready_it == not_ready.end()) {
+ // Skip functions that aren't supposed to be executed
+ if (!exec_info.empty()) {
+ auto it = exec_info.find(next.function.get());
+ if (it == exec_info.end() || !it->second.should_execute()) {
+ continue;
+ }
+ }
// No buffers have been allocated for the function
- InputBuffer input_buffer(next_fn->num_inputs);
- input_buffer.add(input_nr, std::move(output));
+ InputBuffer input_buffer(next.function->num_inputs());
+ input_buffer.add(next.input_nr, std::move(output));
if (is_ready) {
auto& queue = ready_queue(input_buffer.device());
- queue.push_front(FunctionTask(task.base, next_fn, std::move(input_buffer)));
+ queue.push(FunctionTask(task.base, next.function, std::move(input_buffer)));
} else {
- not_ready.emplace(next_fn.get(), std::move(input_buffer));
+ not_ready.emplace(next.function.get(), std::move(input_buffer));
}
} else {
// The function already has a buffer
auto &input_buffer = not_ready_it->second;
- input_buffer.add(input_nr, std::move(output));
+ input_buffer.add(next.input_nr, std::move(output));
if (is_ready) {
auto& queue = ready_queue(input_buffer.device());
- queue.push_front(FunctionTask(task.base, next_fn, std::move(input_buffer)));
+ queue.push(FunctionTask(task.base, next.function, std::move(input_buffer)));
not_ready.erase(not_ready_it);
}
}
}
}
-/** Finds all stochastic functions and appends them to the queue */
-auto Engine::find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task) -> void {
- std::unordered_set<Function*> seen {graph_root};
- function_queue search_queue {graph_root};
- while (search_queue.size() > 0) {
- auto fn = search_queue.back(); search_queue.pop_back();
- for (auto& next_fn_pair : fn->next_functions) {
- auto& next_fn = next_fn_pair.first;
- Function* next_ptr = next_fn.get();
- if (!next_ptr) continue;
- if (next_ptr->is_stochastic && next_ptr->is_executable && seen.count(next_ptr) == 0) {
- ready_queue(-1).push_front(FunctionTask(&task, next_fn, InputBuffer(0)));
- queue.push_back(next_ptr);
- task.has_any_work = true;
- }
- if (seen.count(next_ptr) == 0) {
- seen.insert(next_ptr);
- search_queue.push_back(next_ptr);
- }
- }
- }
-}
-
-/** Computes the number of dependencies for each function which requires grad */
-auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void {
+/* Computes the number of dependencies for each function which requires grad */
+auto Engine::compute_dependencies(Function* root, GraphTask& task) -> void {
// Just to make sure that they will never be added to the queue again
- std::unordered_set<Function*> seen(queue.begin(), queue.end());
+ std::unordered_set<Function*> seen;
+ std::vector<Function*> queue { root };
// Queue contains all nodes that will start propagating gradients.
// We no longer have to expand functions that don't require grad.
auto& dependencies = task.dependencies;
while (queue.size() > 0) {
- auto fn = std::move(queue.back()); queue.pop_back();
- for (auto& next_fn_pair : fn->next_functions) {
- Function* next_ptr = next_fn_pair.first.get();
- if (!next_ptr) continue;
- if (!next_ptr->is_executable) continue;
- if (next_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already
- dependencies[next_ptr] += 1;
- if (seen.count(next_ptr) == 0) {
- seen.insert(next_ptr);
- queue.push_back(next_ptr);
+ auto fn = queue.back(); queue.pop_back();
+ for (const auto& edge : fn->next_edges()) {
+ if (auto next_ptr = edge.function.get()) {
+ dependencies[next_ptr] += 1;
+ const bool was_inserted = seen.insert(next_ptr).second;
+ if (was_inserted) queue.push_back(next_ptr);
}
}
}
@@ -348,40 +357,25 @@ struct ClearCallbacks {
std::mutex& callbacks_lock;
};
-auto Engine::execute(const function_list& input_roots,
+auto Engine::execute(const edge_list& input_roots,
const variable_list& inputs,
bool keep_graph,
- const pre_callback_map& pre_callbacks,
- const post_callback_map& post_callbacks) -> void {
+ bool create_graph,
+ const edge_list& outputs) -> variable_list {
std::call_once(start_threads_flag, &Engine::start_threads, this);
// Callbacks are only valid for the duration of this run and should always be cleared
ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock);
- GraphTask graph_task(keep_graph, pre_callbacks, post_callbacks);
-
+ GraphTask graph_task(keep_graph, create_graph);
std::unique_lock<std::mutex> lock(graph_task.mutex);
+ // Now compute the dependencies for all executable functions and queue the root
auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
- function_queue roots;
- for (auto entry : input_roots) {
- if (entry.first->is_executable) {
- graph_task.has_any_work = true;
- roots.push_back(graph_root.get());
- ready_queue(-1).push_front(FunctionTask(&graph_task, graph_root, InputBuffer(0)));
- break;
- }
+ compute_dependencies(graph_root.get(), graph_task);
+ if (!outputs.empty()) {
+ graph_task.init_to_execute(*graph_root, outputs);
}
-
- // Search the graph and find all stochastic functions. Append them to the queue.
- find_stochastic_functions(roots, graph_root.get(), graph_task);
-
- if (!graph_task.has_any_work) {
- throw std::runtime_error(
- "there are no graph nodes that require computing gradients");
- }
-
- // Now compute the dependencies for all executable functions
- compute_dependencies(std::move(roots), graph_task);
+ ready_queue(-1).push(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0)));
// Not a worker
if (worker_device == NO_DEVICE) {
@@ -413,6 +407,8 @@ auto Engine::execute(const function_list& input_roots,
final_callbacks[i]();
cb_lock.lock();
}
+
+ return graph_task.captured_vars;
}
void Engine::queue_callback(std::function<void()> callback) {
@@ -444,4 +440,69 @@ auto Engine::start_threads() -> void {
}
}
+void GraphTask::init_to_execute(Function& graph_root, const edge_list& outputs) {
+ exec_info[&graph_root].needed = true;
+
+ int output_idx = 0;
+ for (auto & output_edge : outputs) {
+ Function *output = output_edge.function.get();
+ auto & info = exec_info[output];
+ if (!info.captures)
+ info.captures.reset(new std::vector<ExecInfo::Capture>());
+ info.captures->emplace_back(output_edge.input_nr, output_idx++);
+ }
+ captured_vars.resize(output_idx);
+
+ // NB: this is an uglier version (recursion replaced with iteration) of the following code:
+ // is_needed = {}
+ // def compute_is_needed(fn):
+ // if fn not in is_needed:
+ // is_needed[fn] = any(compute_is_needed(next_edge)
+ // for next_edge in fn.next_edges)
+ // return is_needed[fn]
+ struct Frame {
+ Frame (Function *fn) : fn(fn), next_next_fn(0) {}
+ Function *fn;
+ std::size_t next_next_fn;
+
+ Function* get_next_fn() {
+ const auto & next = fn->next_edges();
+ auto num_next = next.size();
+ while (next_next_fn < num_next) {
+ auto fn = next[next_next_fn++].function.get();
+ if (fn) return fn;
+ }
+ return nullptr;
+ }
+ };
+ std::vector<Frame> stack;
+ std::unordered_set<Function*> seen;
+ for (const auto & input : graph_root.next_edges()) {
+ if (seen.count(input.function.get()) > 0) continue;
+ stack.emplace_back(input.function.get());
+ while (!stack.empty()) {
+ auto &frame = stack.back();
+ if (Function *next_fn = frame.get_next_fn()) {
+ if (/* bool unseen = */ seen.emplace(next_fn).second) {
+ stack.emplace_back(next_fn);
+ continue; // recurse
+ }
+ } else {
+ // NB: if we were using real recursion we could have saved some lookups
+ // using a return value from recursive call. It would make this manually unrolled
+ // version a lot more complicated, so I skipped that.
+ const auto & next_edges = frame.fn->next_edges();
+ const bool needed = std::any_of(
+ next_edges.begin(), next_edges.end(), [&](const Edge& edge) {
+ auto it = exec_info.find(edge.function.get());
+ return it != exec_info.end() && it->second.should_execute();
+ });
+ exec_info[frame.fn].needed = needed;
+ stack.pop_back();
+ }
+ }
+ }
+}
+
+
}} // namespace torch::autograd
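
One behavioural change in engine.cpp worth calling out: ReadyQueue is no longer a FIFO deque but a priority queue keyed on each task's sequence number, so a worker always pops the pending function with the highest `sequence_nr()` first, i.e. the most recently constructed node. A standalone sketch of that ordering (illustrative only, using a stand-in task type rather than the real FunctionTask):

#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

struct FakeTask {
  uint64_t sequence_nr;  // stands in for task.fn->sequence_nr()
};

// Same shape as CompareFunctionTaskTime: a "less than" on sequence_nr
// makes std::priority_queue a max-heap on the sequence number.
struct CompareBySequenceNr {
  bool operator()(const FakeTask& a, const FakeTask& b) const {
    return a.sequence_nr < b.sequence_nr;
  }
};

int main() {
  std::priority_queue<FakeTask, std::vector<FakeTask>, CompareBySequenceNr> heap;
  for (uint64_t nr : {3, 1, 7, 5}) heap.push(FakeTask{nr});
  while (!heap.empty()) {
    std::cout << heap.top().sequence_nr << ' ';  // prints: 7 5 3 1
    heap.pop();
  }
  std::cout << '\n';
}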
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index ff1ae45c..eac677a3 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -27,32 +27,21 @@ struct Engine {
virtual ~Engine();
using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
- using function_queue = std::vector<Function*>;
using dependencies_type = std::unordered_map<Function*, int>;
- using pre_callback_type = std::function<bool (Function*, variable_list&)>;
- using pre_callback_map = std::unordered_multimap<Function*, pre_callback_type>;
- using post_callback_type = std::function<bool (Function*, variable_list&, variable_list&)>;
- using post_callback_map = std::unordered_multimap<Function*, post_callback_type>;
-
// Given a list of (Function, input number) pairs computes the value of the graph
- // by following next_function references.
- virtual void execute(
- const function_list& roots,
+ // by following next_edge references.
+ virtual variable_list execute(
+ const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
- const pre_callback_map& pre_callbacks = pre_callback_map(),
- const post_callback_map& post_callbacks = post_callback_map());
+ bool create_graph,
+ const edge_list& outputs = {});
void queue_callback(std::function<void()> callback);
protected:
- function_queue find_roots(
- const function_list& roots,
- variable_list& inputs,
- GraphTask& task);
- void find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task);
- void compute_dependencies(function_queue queue, GraphTask& task);
+ void compute_dependencies(Function* root, GraphTask& task);
void evaluate_function(FunctionTask& task);
ReadyQueue& ready_queue(int device);
void start_threads();
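
The header change above is the user-visible part of the refactor: execute() now takes explicit `create_graph` and `outputs` arguments and returns the captured gradients instead of relying on pre/post callback maps. A hedged sketch of a call site under the new signature (the `engine`, `loss` and `x` arguments are assumptions supplied by the caller, not code from this patch):

using namespace torch::autograd;

// Compute d(loss)/d(x): seed the graph at loss's gradient edge, ask the
// engine to capture the gradient flowing into x's gradient edge, and read
// it out of the returned list.
variable_list grad_sketch(Engine& engine, Variable loss, Variable x) {
  edge_list roots{loss.gradient_edge()};
  variable_list grad_outputs{
      make_variable(loss.data().type().tensor(loss.data().sizes()).fill_(1))};
  edge_list outputs{x.gradient_edge()};
  return engine.execute(roots, grad_outputs,
                        /*keep_graph=*/false, /*create_graph=*/false, outputs);
}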
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index b22c3d51..42af461f 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -1,51 +1,24 @@
-#include "function.h"
+#include <Python.h>
-#include <string>
+#include "torch/csrc/autograd/function.h"
-#include "variable.h"
-#include "torch/csrc/jit/ir.h"
#include "torch/csrc/autograd/functions/special.h"
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/jit/ir.h"
-namespace torch { namespace autograd {
+#include <ATen/ATen.h>
-template<typename T>
-auto makeFlags(const T &inputs) -> FunctionFlags {
- int num_inputs = inputs.size();
- FunctionFlags f;
- f.is_executable = false;
- f.is_volatile = false;
- f.next_functions.resize(num_inputs);
- {
- int i = 0;
- for (auto it = inputs.begin(); it != inputs.end(); ++it, ++i) {
- auto& var = *it;
- if (var.defined()) {
- f.is_executable |= var.requires_grad();
- f.is_volatile |= var.is_volatile();
- if (var.grad_fn()) {
- f.next_functions[i] = std::make_pair<>(var.grad_fn(), var.output_nr());
- } else {
- f.next_functions[i] = std::make_pair<>(var.grad_accumulator(), 0);
- }
- }
- }
- }
- f.is_executable &= !f.is_volatile;
- return f;
-}
-
-auto Function::flags(const variable_list& inputs) -> FunctionFlags {
- return makeFlags(inputs);
-}
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
-auto Function::flags(const std::initializer_list<Variable>& inputs) -> FunctionFlags {
- return makeFlags(inputs);
-}
+namespace torch { namespace autograd {
-auto Function::flags(at::TensorList inputs) -> FunctionFlags {
- // TODO: Eliminate the intermediate vector allocation
- return makeFlags(variable_list(inputs.begin(), inputs.end()));
-}
+thread_local uint64_t Function::next_sequence_nr_ = 0;
auto Function::name() -> std::string {
return std::string(typeid(*this).name());
@@ -54,7 +27,7 @@ auto Function::name() -> std::string {
// This function is analogous to make_trace which operates on PythonOp, but this
// function instead works for C++ implemented autograd Functions, which don't
// actually have any backing Python class. We still need to trace them!
-variable_list Function::tracedApply(variable_list inputs) {
+variable_list Function::traced_apply(variable_list inputs) {
using namespace torch::jit;
// Traceable Functions are completely transparent to the JIT.
if (is_traceable()) {
@@ -65,7 +38,14 @@ variable_list Function::tracedApply(variable_list inputs) {
// Insert a CppOp in the trace.
auto& graph = state->graph;
- auto* this_node = graph->createCppOp(getSharedPtr());
+ std::vector<VariableFlags> var_flags;
+ for(auto & input: inputs) {
+ var_flags.push_back(VariableFlags::of(input));
+ }
+ auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
+ this_node->setSourceLocation(std::make_shared<StringSourceLocation>(
+ jit::tracer::getPythonInterpreterStackTrace()
+ ));
for (auto& input: inputs) {
this_node->addInput(tracer::getValueTrace(state, input));
}
@@ -80,7 +60,7 @@ variable_list Function::tracedApply(variable_list inputs) {
int num_outputs = outputs.size();
for (int i = 0; i < num_outputs; ++i) {
auto& output = outputs[i];
- Node* sel = graph->appendNode(graph->createSelect(this_node, i));
+ auto sel = this_node->addOutput();
// TODO: At the moment, C++ does not track shared storage. It
// should. Update this when that happens.
if (output.defined()) {
@@ -97,29 +77,30 @@ variable_list Function::tracedApply(variable_list inputs) {
// There's no point in wrapping functions in Eval, if we know they already are
// part of another Eval subgraph. This is both a small optimization, and
// it allows us to not implement saved_variables() in many functions.
- bool should_trace_backward = tracing_state->in_eval_subgraph;
+ const bool should_trace_backward = tracing_state_->in_eval_subgraph;
if (!should_trace_backward) {
auto saved_vars = saved_variables();
if (!saved_vars)
- throw std::runtime_error(std::string("saved_variables() needed but not implemented in ") + name());
+ throw std::runtime_error("saved_variables() needed but not implemented in " + name());
variable_list bw_subgraph_inputs(inputs);
for (auto& saved_var : *saved_vars) {
- bw_subgraph_inputs.emplace_back(saved_var.unpack(getSharedPtr()));
+ bw_subgraph_inputs.emplace_back(saved_var.unpack(get_shared_ptr()));
}
tracer::nontraceableBackwardSubgraph(bw_subgraph_inputs, outputs);
}
bool has_backwards_eval = !should_trace_backward || this_eval;
if (has_backwards_eval)
- setUpContextEdge(this_node, num_outputs, inputs, outputs);
+ set_up_context_edge(this_node, inputs, outputs);
}
return outputs;
}
-void Function::setUpContextEdge(jit::Node* node, int ctx_output_nr,
- const variable_list& inputs, const variable_list& outputs) {
- jit::Graph* graph = node->owningGraph();
- jit::Node* ctx_select = graph->appendNode(graph->createSelect(node, ctx_output_nr));
- ctx_select->setType(std::make_shared<jit::HandleType>());
+void Function::set_up_context_edge(
+ jit::Node* this_node,
+ const variable_list& inputs,
+ const variable_list& outputs) {
+ auto ctx_select = this_node->addOutput();
+ ctx_select->setType(jit::HandleType::get());
auto backward_eval = Eval::getBackwardEval(inputs, outputs);
if (backward_eval)
backward_eval->forward_ctx_select = ctx_select;
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index a7efd3eb..ed6955c5 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -1,213 +1,374 @@
#pragma once
-// Function is an abstract class that represents a single operation from one or
-// more variables to one more or variables.
-//
-// Subclasses may represent "forward" or "backward" operations (i.e functions
-// and their derivatives). Some functions may be used as both.
-
+#include "torch/csrc/assertions.h"
+#include "torch/csrc/autograd/edge.h"
+#include "torch/csrc/autograd/grad_mode.h"
+#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/utils/auto_unique_ptr.h"
#include "torch/csrc/utils/python_stub.h"
-#include "torch/csrc/autograd/function_hook.h"
-#include "torch/csrc/autograd/profiler.h"
-#include "torch/csrc/jit/tracer.h"
+#include "torch/csrc/utils/variadic.h"
#include <ATen/ATen.h>
+#include <algorithm>
+#include <cstdint>
+#include <initializer_list>
#include <memory>
+#include <string>
+#include <utility>
#include <vector>
namespace torch { namespace autograd {
-struct Function;
-struct Variable;
+struct Edge;
+struct FunctionPostHook;
+struct FunctionPreHook;
using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
-using edge_type = std::pair<std::shared_ptr<Function>, int>;
-using function_list = std::vector<edge_type>;
+using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
-
-struct edge_hasher {
- std::size_t operator()(const edge_type& edge) const {
-#define HASH_IDX(idx) std::hash<std::tuple_element<idx, edge_type>::type>()(std::get<idx>(edge))
- // TODO: that's probably a bad hash function, but whatever
- return HASH_IDX(0) ^ HASH_IDX(1);
- }
-};
-
-// State used to create "backward" functions
-struct FunctionFlags {
- // Roughly speaking, is_executable corresponds to requires_grad.
- // See http://pytorch.org/docs/notes/autograd.html for more details:
- // both is_executable and is_volatile specify whether or not backwards
- // gradient computation will be performed for a function, but they differ in
- // their precedence.
- bool is_executable = false;
- bool is_volatile = false;
- // What functions take the output of this function as input.
- // There is one function per output of this function.
- function_list next_functions;
-};
-
+using IndexRange = std::pair<size_t, size_t>;
+
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// Function
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// A `Function` is an abstract class that represents an operation taking zero
+/// or more input `Variable`s and producing zero or more output `Variable`s. All
+/// functions in PyTorch's autograd machinery derive from this class and
+/// override its `apply` method. Instances of such subclasses will then be
+/// invocable via the call operator.
+///
+/// Functions in the Autograd Graph
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// When viewing the autograd system as a graph, `Function`s are the vertices or
+/// nodes, connected to each other via (directed) `Edge`s, which themselves are
+/// represented via (`Function`, input_nr) pairs. `Variable`s are the outputs to
+/// and inputs of `Function`s, and travel between these edges during execution
+/// of the graph. When two or more `Edge`s (from different sources) point at the
+/// same input to a `Function`, the values produced along all of these edges are
+/// implicitly summed prior to being forwarded to the target `Function`.
+///
+/// Hierarchy
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// Subclasses usually represent differentiable functions as well as their
+/// gradient operators. Note, however, that due to the very general definition
+/// of a `Function` taking *zero* or more inputs and producing *zero* or more
+/// outputs, uses of `Function`s are flexible and extend beyond purely
+/// mathematical operations. For example, the `AccumulateGrad` function is a
+/// *sink*: it takes one input, but produces no outputs, instead accumulating
+/// the input as a side effect. At the other extreme, the `GraphRoot` function
+/// receives no inputs from other functions, but produces multiple outputs.
+///
+/// Interface
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// The most important method on `Function` is the call operator, which takes in
+/// a list of variables and produces a list of variables. The precise size of
+/// these lists can be determined with `num_inputs()` and `num_outputs()`.
+/// `Function`s are stitched together via their `next_edge` interface, which lets
+/// you manipulate the set of outgoing edges of a `Function`. You can add an
+/// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and
+/// iterate over them via the `next_edges()` method. Other methods exist for
+/// integration with the JIT and other parts of PyTorch. Every `Function` has a
+/// *sequence number* that increases monotonically in the order of `Function`
+/// construction. It can be retrieved via the `sequence_nr()` method. Note that
+/// this sequence number is *thread local*. This means that when `Function`s
+/// `A`, `B` and `C` are created consecutively in the same thread, their
+/// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B`
+/// are created in one thread and `C` is created in a new thread, there are *no
+/// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`.
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct Function : std::enable_shared_from_this<Function> {
- Function()
- : num_inputs(0)
- , next_functions()
- , is_executable(false)
- , is_stochastic(false)
- , pre_hooks()
- , post_hooks()
- , pyobj(nullptr)
- {}
-
- Function(FunctionFlags&& flags)
- : num_inputs(0)
- , next_functions(std::move(flags.next_functions))
- , is_executable(flags.is_executable)
- , is_stochastic(false)
- , pre_hooks()
- , post_hooks()
- , pyobj(nullptr)
- {}
-
+ public:
+ /// Construct a new `Function` with `num_inputs` inputs and the given
+ /// `next_edges`.
+ explicit Function(
+ uint32_t num_inputs = 0,
+ edge_list&& next_edges = edge_list())
+ : sequence_nr_(next_sequence_nr_++),
+ num_inputs_(num_inputs),
+ next_edges_(std::move(next_edges)) {}
+
+ /// Functions are neither copyable nor moveable.
Function(const Function& other) = delete;
Function(Function&& other) = delete;
- virtual ~Function() {}
-
- // Implements the operation
- // NOTE: Don't call this function directly. Use apply_fn or operator() instead.
- virtual variable_list apply(const variable_list& inputs) = 0;
- variable_list tracedApply(variable_list inputs);
+ Function& operator=(const Function& other) = delete;
+ Function& operator=(Function&& other) = delete;
+ virtual ~Function() = default;
+ /// Evaluates the function on the given inputs and returns the result of the
+ /// function call.
variable_list operator()(const variable_list& inputs) {
profiler::RecordFunction rec(this);
- if (jit::tracer::isTracing(inputs)) {
- return tracedApply(inputs);
+ if (jit::tracer::isTracingVar(inputs)) {
+ return traced_apply(inputs);
}
return apply(inputs);
}
- // PyFunctions are not managed by shared_ptrs by default, but are bound to the
- // lifetime of their Python object instead.
- virtual std::shared_ptr<Function> getSharedPtr() {
- return shared_from_this();
- };
+ // Graph Connectivity API
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // Computes is_executable, is_volatile, and next_functions from a list
- // of input variables
- static FunctionFlags flags(const variable_list& inputs);
- static FunctionFlags flags(const std::initializer_list<Variable>& inputs);
- static FunctionFlags flags(at::TensorList inputs);
+ // Inputs
- // Releases saved variables if the operation won't be reused
- virtual inline void releaseVariables() {}
+ /// Increments the number of inputs of the function and returns the previous
+ /// value.
+ uint32_t bump_inputs() noexcept {
+ return num_inputs_++;
+ }
- // Function name for debugging
- virtual std::string name();
+ void set_num_inputs(uint32_t num_inputs) noexcept {
+ num_inputs_ = num_inputs;
+ }
- inline bool should_compute_output(int i) const {
- auto& fn = next_functions[i].first;
- return fn && fn->is_executable;
+ uint32_t num_inputs() const noexcept {
+ return num_inputs_;
}
- inline bool should_compute_any_outputs() const {
- for (size_t i = 0; i < next_functions.size(); ++i) {
- if (should_compute_output((int)i)) {
- return true;
- }
- }
- return false;
+ // Outputs ("Next Edges")
+
+ const Edge& next_edge(size_t index) const noexcept {
+ return next_edges_[index];
+ }
+
+ void set_next_edge(size_t index, Edge edge) {
+ next_edges_[index] = std::move(edge);
+ }
+
+ void add_next_edge(Edge edge) {
+ next_edges_.push_back(std::move(edge));
+ }
+
+ void set_next_edges(edge_list&& next_edges) {
+ next_edges_ = std::move(next_edges);
+ }
+
+ const edge_list& next_edges() const noexcept {
+ return next_edges_;
+ }
+
+ edge_list& next_edges() noexcept {
+ return next_edges_;
+ }
+
+ uint32_t num_outputs() const noexcept {
+ return next_edges_.size();
+ }
+
+ // Miscellaneous Methods
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// The sequence number of this `Function`.
+ uint64_t sequence_nr() const noexcept {
+ return sequence_nr_;
}
- inline bool should_compute_output(std::initializer_list<int> idxs) const {
- return std::any_of(idxs.begin(), idxs.end(), [this](int i) {
- return should_compute_output(i);
+ /// Returns a shared pointer to `this`. `PyFunction`s are not managed by
+ /// `shared_ptr`s by default, but are bound to the lifetime of their Python
+ /// object instead.
+ virtual std::shared_ptr<Function> get_shared_ptr() {
+ return shared_from_this();
+ }
+
+ /// Returns the name of the dynamic type of the function, for debugging.
+ virtual std::string name();
+
+ /// Returns true if the particular output edge is active, and that particular
+ /// output of this function should be computed.
+ bool should_compute_output(size_t output_edge_index) const {
+ TORCH_ASSERTM(output_edge_index < num_outputs(), "Index out of range");
+ return next_edges_[output_edge_index].is_valid();
+ }
+
+ /// Returns true if any of the output edges in any of the ranges are active.
+ bool should_compute_output(std::initializer_list<IndexRange> idxs) const {
+ return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
+ for (auto i = range.first; i < range.second; i++) {
+ if (should_compute_output(i))
+ return true;
+ }
+ return false;
});
}
- inline void set_flags(FunctionFlags&& flags) {
- is_executable = flags.is_executable;
- next_functions = std::move(flags.next_functions);
+ jit::tracer::FunctionTracingState& tracing_state() noexcept {
+ // Dereferencing will create the `TracingState` if the pointer is empty.
+ return *tracing_state_;
}
- // An op is traceable if all operations happening within apply() are performed
- // on autograd Variables (i.e. apply mostly instantiates and applies other functions).
- virtual inline bool is_traceable() { return false; };
+ /// Returns the `PyObject` stored for this `Function` (for Python
+ /// interaction).
+ PyObject* pyobj() const noexcept {
+ return pyobj_;
+ }
- // An op is said to pass state transparently to backward, if the state consists
- // only of (Saved)Variables and only non-variable objects that parametrize the
- // operation in some way that defines the graph structure AND the backward function
- // is traceable. In particular, parametrization MUST NOT depend on the data
- // of any Variable.
- // TODO: it might be possible to handle cases where backward is non-traceable
- // but state passing could be considered transparent. This will probably depend
- // on saved_variable_list being mutable.
- // NOTE: this value matters only if is_traceable() returns false.
- virtual inline bool passes_state_transparently() { return false; };
+ /// Sets the `PyObject` stored for this `Function` (for Python interaction).
+ void set_pyobj(PyObject* pyobj) noexcept {
+ pyobj_ = pyobj;
+ }
- // Let's the JIT find inputs to apply that are not present explicitly in arguments.
- // Required only for functions that are not traceable, don't pass state to
- // backward transparently, and are not backwards closures of functions that don't
- // pass the state transparently. Which means that hopefully they will hardly ever
- // need to be implemented :)
- virtual inline std::unique_ptr<saved_variable_list> saved_variables() { return nullptr; }
+ /// Create a context edge for the JIT.
+ static void set_up_context_edge(
+ jit::Node* this_node,
+ const variable_list& inputs,
+ const variable_list& outputs);
- static void setUpContextEdge(jit::Node* this_node, int ctx_output_nr,
- const variable_list& inputs, const variable_list& outputs);
+ // Hook API
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- int num_inputs;
- function_list next_functions;
- bool is_executable;
- bool is_stochastic;
- std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
- std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
+ void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
+ post_hooks_.push_back(std::move(post_hook));
+ }
- PyObject *pyobj; // weak reference
+ const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
+ noexcept {
+ return post_hooks_;
+ }
- auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state;
-};
+ std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
+ return post_hooks_;
+ }
-// Actually what is a ForwardFunction here applies to all functions that are
-// applied only in forward OR are backward closures that don't save any Variables.
-// I chose this name, because the second situation is quite rare.
-template<bool transparent_state = false>
-struct ForwardFunction : public Function {
- using Function::Function;
+ void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
+ pre_hooks_.push_back(std::move(pre_hook));
+ }
- virtual inline std::unique_ptr<saved_variable_list> saved_variables() final {
- return std::unique_ptr<saved_variable_list>(new saved_variable_list());
+ const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() const
+ noexcept {
+ return pre_hooks_;
}
- virtual inline bool is_traceable() final { return false; };
+ std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
+ return pre_hooks_;
+ }
- virtual inline bool passes_state_transparently() final { return transparent_state; };
-};
+ // Customization Points for Subclasses
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// See Function::is_traceable() for definition.
-struct TraceableFunction : public Function {
- using Function::Function;
+ /// Releases saved variables if the operation won't be reused.
+ virtual void release_variables() {}
- virtual inline bool is_traceable() final { return true; };
-};
+ /// Called before an apply if `release_variables()` is going to be called.
+ /// Allows larger ops like `InterpreterAutogradFunction` to incrementally
+ /// release variables as they run.
+ virtual void will_release_variables() {}
-template<typename T>
-struct apply_fn {
- template<typename... Args>
- apply_fn(Args&& ...args)
- : fn_(std::make_shared<T>(std::forward<Args>(args)...)) {}
+ /// Returns true if this function is traceable. An op is traceable if all
+ /// operations happening within `apply()` are performed on autograd
+ /// `Variables` (i.e. apply mostly instantiates and applies other functions).
+ virtual bool is_traceable() {
+ return false;
+ }
- Variable operator()(const variable_list& inputs) {
- return (*fn_)(inputs)[0];
+ /// A `Function` is said to pass state transparently to backward, if the
+ /// state consists only of (Saved)Variables and only non-variable objects
+ /// that parameterize the operation in some way that defines the graph
+ /// structure AND the backward function is traceable. In particular,
+ /// parametrization MUST NOT depend on the data of any `Variable`.
+ /// TODO: it might be possible to handle cases where backward is
+ /// non-traceable but state passing could be considered transparent. This
+ /// will probably depend on saved_variable_list being mutable.
+ /// NOTE: this value matters only if is_traceable() returns false.
+ virtual bool passes_state_transparently() {
+ return false;
}
- template<typename... Args>
- Variable operator()(Args&& ...inputs) {
- return (*fn_)(variable_list{inputs...})[0];
+ /// Returns `Variable`s saved by this `Function`.
+ /// This lets the JIT find inputs to apply that are not present explicitly
+ /// in arguments. Required only for functions that are not traceable, don't
+ /// pass state to backward transparently, and are not backwards closures of
+ /// functions that don't pass the state transparently. Which means that
+ /// hopefully they will hardly ever need to be implemented :)
+ virtual std::unique_ptr<saved_variable_list> saved_variables() {
+ return nullptr;
}
- std::shared_ptr<T> fn_;
+ protected:
+ /// Monotonically incrementing (thread local!) counter to supply sequence
+ /// numbers.
+ static thread_local uint64_t next_sequence_nr_;
+
+ /// Performs the `Function`'s actual operation.
+ virtual variable_list apply(const variable_list& inputs) = 0;
+
+ /// Calls `apply()`, but instruments it with tracing machinery.
+ variable_list traced_apply(variable_list inputs);
+
+ // Since `Function`s are neither copyable nor moveable, we can have const
+ // fields.
+ const uint64_t sequence_nr_;
+
+ uint32_t num_inputs_;
+ edge_list next_edges_;
+ PyObject* pyobj_ = nullptr; // weak reference
+ std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
+ std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
+ auto_unique_ptr<jit::tracer::FunctionTracingState> tracing_state_;
};
+/// See Function::is_traceable() for definition.
+struct TraceableFunction : public Function {
+ using Function::Function;
+ bool is_traceable() final override {
+ return true;
+ }
+};
+
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Associated Free Functions
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+namespace detail {
+// Implementation of `collect_next_edges` (see below).
+struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
+ edge_list next_edges;
+ using IterArgs<MakeNextFunctionList>::operator();
+ void operator()(const Variable& variable) {
+ if (variable.defined()) {
+ next_edges.push_back(variable.gradient_edge());
+ } else {
+ next_edges.emplace_back();
+ }
+ }
+};
+} // namespace detail
+
+/// Create an `Edge` between the given `variable` and the `function`, which is
+/// assumed to be the gradient function of this variable (i.e. the function
+/// through which this variable is backpropagated during the backward pass).
+/// This sets the `grad_fn` property of the `variable`. This function assumes
+/// that the `Variable` is a new input to the gradient function and its
+/// `input_nr` thus equal to `function->num_inputs()`. Additionally, it
+/// increments the `Function`'s number of inputs by one. Approximately
+/// equivalent to `variable.set_gradient_edge(function,
+/// function->bump_inputs())`. If you don't want the `Function`'s `num_inputs`
+/// to be incremented, use `set_gradient_edge` directly.
+inline void create_gradient_edge(
+ Variable& variable,
+ std::shared_ptr<Function> function) {
+ // Copy before move.
+ const auto input_nr = function->bump_inputs();
+ variable.set_gradient_edge({std::move(function), input_nr});
+}
+
+/// Return true if any of the variables in the list require a gradient.
+inline bool any_variable_requires_grad(const variable_list& variables) {
+ return std::any_of(
+ variables.begin(), variables.end(), [](const Variable& variable) {
+ return variable.defined() && variable.requires_grad();
+ });
+}
+
+/// Return the next edges of all the given variables, or tuples of variables.
+template <typename... Variables>
+edge_list collect_next_edges(Variables&&... variables) {
+ if (!GradMode::is_enabled())
+ return {};
+ detail::MakeNextFunctionList make;
+ make.apply(std::forward<Variables>(variables)...);
+ return std::move(make.next_edges);
+}
}} // namespace torch::autograd
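
To make the documented interface concrete, here is a hedged sketch of the subclassing pattern the new header describes: a tiny pass-through backward node wired up with `collect_next_edges` and `create_gradient_edge`. The `PassThroughBackward` type and the wiring snippet are illustrative assumptions, not functions added by this patch:

using namespace torch::autograd;

struct PassThroughBackward : public Function {
  explicit PassThroughBackward(edge_list&& next_edges)
      : Function(/*num_inputs=*/1, std::move(next_edges)) {}

  variable_list apply(const variable_list& grads) override {
    // One output slot per next edge; only populate edges that are active.
    variable_list outputs(num_outputs());
    if (num_outputs() > 0 && should_compute_output(0)) {
      outputs[0] = grads[0];
    }
    return outputs;
  }
};

// Wiring, following the conventions documented above (input/output are
// hypothetical Variables):
//   auto fn = std::make_shared<PassThroughBackward>(collect_next_edges(input));
//   create_gradient_edge(output, fn);  // sets output's grad_fn, input_nr = 0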
diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h
index 4c195917..03c52fea 100644
--- a/torch/csrc/autograd/function_hook.h
+++ b/torch/csrc/autograd/function_hook.h
@@ -1,6 +1,5 @@
#pragma once
-#include <memory>
#include <vector>
// A hook that's called on gradients
diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp
index 9c1ff87e..d8c9457c 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.cpp
+++ b/torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -1,5 +1,8 @@
-#include "accumulate_grad.h"
+#include <Python.h>
+#include "torch/csrc/autograd/functions/accumulate_grad.h"
+
+#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/functions/basic_ops.h"
#include "torch/csrc/autograd/functions/tensor.h"
@@ -11,10 +14,7 @@ using at::Tensor;
namespace torch { namespace autograd {
AccumulateGrad::AccumulateGrad(Variable variable_)
- : variable(std::move(variable_)) {
- num_inputs = 1;
- is_executable = 1;
-}
+ : Function(/*num_inputs=*/1), variable(std::move(variable_)) {}
auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
// XXX: this method is not thread-safe!
@@ -24,8 +24,6 @@ auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
return {};
if (variable.grad_fn())
throw std::logic_error("leaf variable has been moved into the graph interior");
- if (variable.current_version() != 0)
- throw std::runtime_error("leaf variable was used in an inplace operation");
if (!variable.requires_grad())
return {};
@@ -34,31 +32,22 @@ auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
new_grad = (*hook)({new_grad})[0];
}
- // TODO: Currently if var.grad is volatile and new-grad is non-volatile we
- // accumulate in-place. We should reconsider this and perhaps add the
- // gradients out-of-place.
-
auto& grad = variable.grad();
if (!grad.defined()) {
- grad = apply_fn<Clone>()(new_grad);
- } else if (grad.is_volatile()) {
+ variable.grad() = new_grad.clone();
+ } else if (!GradMode::is_enabled()) {
// This case is not strictly necessary, but it makes the first-order only case
// slightly more efficient and, what's more important, more predictable for
// the users. Thanks to this case we can avoid changing the grad tensor,
// a thing never promised and documented, but used in some hacks seen
// on the internet.
- AutoGPU guard(grad);
- if (grad.type().isSparse() && !new_grad.type().isSparse()) {
+ if (grad.type().is_sparse() && !new_grad.type().is_sparse()) {
grad.data() = new_grad.data() + grad.data();
} else {
grad.data() += new_grad.data();
}
} else {
- // If grad is non-volatile, it should stay like that
- if (new_grad.is_volatile()) {
- new_grad = make_variable(new_grad.data());
- }
- variable.grad() = apply_fn<Add>()(grad, new_grad);
+ variable.grad() = grad + new_grad;
}
return variable_list();
diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h
index ff6dd0b8..4a3e3bd2 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.h
+++ b/torch/csrc/autograd/functions/accumulate_grad.h
@@ -6,11 +6,11 @@
namespace torch { namespace autograd {
struct AccumulateGrad : public Function {
- AccumulateGrad(Variable variable);
+ explicit AccumulateGrad(Variable variable);
virtual variable_list apply(const variable_list& inputs) override;
Variable variable;
};
-}}
+}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp
index 2b2a7eb8..813706ad 100644
--- a/torch/csrc/autograd/functions/basic_ops.cpp
+++ b/torch/csrc/autograd/functions/basic_ops.cpp
@@ -1,9 +1,15 @@
-#include "basic_ops.h"
+#include "torch/csrc/autograd/functions/basic_ops.h"
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/utils/auto_gpu.h"
+#include <ATen/ATen.h>
+
+#include <memory>
+#include <utility>
+
namespace torch { namespace autograd {
auto Error::apply(const variable_list& grad_outputs) -> variable_list {
@@ -15,51 +21,11 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list {
outputs.reserve(inputs.size());
for (auto& var : inputs) {
// FIXME: share version counters
- outputs.emplace_back(var.defined() ? var.data() : Tensor());
- }
- return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
- return std::make_shared<Error>(msg, std::move(f));
- });
-};
-
-auto Add::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("Add", inputs, 2);
- auto& input1 = inputs[0].data();
- auto& input2 = inputs[1].data();
- AutoGPU guard(input1);
-
- at::Tensor output;
- if (input1.type().isSparse()) {
- output = input2 + input1;
- } else {
- output = input1 + input2;
+ outputs.emplace_back(var.defined() ? var.data() : at::Tensor());
}
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<AddBackward_Deprecated>(std::move(f));
- });
-};
-
-auto AddBackward_Deprecated::apply(const variable_list& grad_outputs) -> variable_list {
- check_input_variables("AddBackward_Deprecated", grad_outputs, 1);
- return {grad_outputs[0], grad_outputs[0]};
-};
-
-auto Mul::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("Mul", inputs, 2);
- AutoGPU guard(inputs[0]);
- auto& input1 = inputs[0].data();
- auto& input2 = inputs[1].data();
-
- auto output = input1 * input2;
-
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<MulBackward>(std::move(f));
+ return wrap_outputs(inputs, std::move(outputs), [&](edge_list&& next_edges) {
+ return std::make_shared<Error>(msg, std::move(next_edges));
});
};
-auto MulBackward::apply(const variable_list& grad_outputs) -> variable_list {
- check_input_variables("MulBackward", grad_outputs, 1);
- throw std::runtime_error("MulBackward::apply not implemented");
-};
-
}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h
index 0c165d53..a37934cb 100644
--- a/torch/csrc/autograd/functions/basic_ops.h
+++ b/torch/csrc/autograd/functions/basic_ops.h
@@ -1,18 +1,20 @@
#pragma once
#include <Python.h>
-#include <memory>
-#include <string>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/symbolic.h"
+#include <memory>
+#include <string>
+#include <vector>
+
namespace torch { namespace autograd {
struct Error : public Function {
- Error(std::string msg, FunctionFlags&& flags)
- : Function(std::move(flags))
+ Error(std::string msg, edge_list&& next_edges)
+ : Function(/*num_inputs=*/0, std::move(next_edges))
, msg(std::move(msg)) {}
Error(std::string msg)
@@ -34,11 +36,9 @@ struct DelayedError : public Function {
};
struct GraphRoot : public Function {
- GraphRoot(function_list functions, variable_list inputs)
- : outputs(std::move(inputs)) {
- next_functions = std::move(functions);
- is_executable = true;
- };
+ GraphRoot(edge_list functions, variable_list inputs)
+ : Function(/*num_inputs=*/0, std::move(functions)),
+ outputs(std::move(inputs)) {}
virtual variable_list apply(const variable_list& inputs) {
return outputs;
@@ -47,33 +47,4 @@ struct GraphRoot : public Function {
variable_list outputs;
};
-struct Add : public ForwardFunction<true>, public HasSymbolic {
- Add() {}
-
- virtual variable_list apply(const variable_list& inputs) override;
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
-};
-
-
-struct AddBackward_Deprecated : public Function {
- AddBackward_Deprecated(FunctionFlags&& flags)
- : Function(std::move(flags)) {}
-
- virtual variable_list apply(const variable_list& gradOutputs) override;
- virtual bool is_traceable() override { return true; }
-};
-
-struct Mul : public ForwardFunction<> {
- Mul() {}
-
- virtual variable_list apply(const variable_list& inputs) override;
-};
-
-struct MulBackward : public Function {
- MulBackward(FunctionFlags&& flags)
- : Function(std::move(flags)) {}
-
- virtual variable_list apply(const variable_list& gradOutputs) override;
-};
-
}}
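
GraphRoot is the other half of the new execute() path: it stores the root gradients and simply returns them from apply(), and the engine routes output i of that call along the i-th root edge. A hedged sketch of how such a root might be built, mirroring what Engine::execute does internally (`roots` and `grad_outputs` are caller-supplied assumptions):

using namespace torch::autograd;

std::shared_ptr<GraphRoot> make_graph_root(edge_list roots,
                                           variable_list grad_outputs) {
  // apply() on the returned node ignores its inputs and just returns
  // grad_outputs, which the engine then forwards along `roots`.
  return std::make_shared<GraphRoot>(std::move(roots), std::move(grad_outputs));
}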
diff --git a/torch/csrc/autograd/functions/batch_normalization.cpp b/torch/csrc/autograd/functions/batch_normalization.cpp
deleted file mode 100644
index 13d83881..00000000
--- a/torch/csrc/autograd/functions/batch_normalization.cpp
+++ /dev/null
@@ -1,256 +0,0 @@
-#include "batch_normalization.h"
-
-#include "torch/csrc/autograd/python_function.h"
-#include "torch/csrc/autograd/python_variable.h"
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/functions/utils.h"
-#include "torch/csrc/autograd/functions/basic_ops.h"
-#include "torch/csrc/utils/auto_gil.h"
-#include "torch/csrc/utils/auto_gpu.h"
-#include "torch/csrc/DynamicTypes.h"
-#include "torch/csrc/Exceptions.h"
-#include <sstream>
-
-#ifdef WITH_CUDNN
-#include "torch/csrc/cudnn/BatchNorm.h"
-#include "torch/csrc/cudnn/Handles.h"
-#include "torch/csrc/cudnn/Types.h"
-extern THCState* state;
-#endif
-
-namespace {
- void check_dims_match_num_input_features(const std::string& arg_name, long expected, long actual){
- if (actual != expected){
- std::stringstream ss;
- ss << arg_name << " should contain " << expected << " elements not " << actual ;
- throw std::runtime_error(ss.str());
- }
- }
-}
-
-namespace torch { namespace autograd {
-
-#ifndef CUDNN_BN_MIN_EPSILON
-#define CUDNN_BN_MIN_EPSILON 0
-#endif
-
-auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("BatchNorm", inputs, 3, 1);
-
- AutoGPU guard(inputs[0]);
- auto& input = inputs[0];
- auto& weight = inputs[1];
- auto& bias = inputs[2];
-
- auto num_features = input.sizes()[1];
- check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
- check_dims_match_num_input_features("running_var", num_features, running_var.numel());
- if (weight.defined()) {
- check_dims_match_num_input_features("weight", num_features, weight.numel());
- }
- if (bias.defined()) {
- check_dims_match_num_input_features("bias", num_features, bias.numel());
- }
-
- bool use_cudnn = false;
-#ifdef WITH_CUDNN
- use_cudnn = (input.type().isCuda()
- && (input.type().scalarType() != at::kHalf
- || weight.type().scalarType() == at::kFloat)
- && weight.defined() && bias.defined()
- && input.size(0) <= 131070
- && cudnn_enabled && CUDNN_VERSION >= 5110L);
-#endif
-
- auto input_data = input.data();
- Tensor output;
- auto save_mean = running_mean.type().tensor(running_mean.sizes());
- auto save_std = running_var.type().tensor(running_var.sizes());
-
- if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
-#ifdef WITH_CUDNN
- output = input_data.type().tensor(input.sizes());
- torch::cudnn::cudnn_batch_norm_forward(
- state,
- torch::cudnn::getCudnnHandle(),
- torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)input.unsafeGetTH(false),
- (THVoidTensor*)output.unsafeGetTH(false),
- (THVoidTensor*)weight.unsafeGetTH(false),
- (THVoidTensor*)bias.unsafeGetTH(false),
- (THVoidTensor*)running_mean.unsafeGetTH(false),
- (THVoidTensor*)running_var.unsafeGetTH(false),
- (THVoidTensor*)save_mean.unsafeGetTH(false),
- (THVoidTensor*)save_std.unsafeGetTH(false),
- training,
- momentum,
- eps);
-#endif
- } else {
- output = at::batch_norm_forward(
- input_data, weight.opt_data(), bias.opt_data(),
- running_mean, running_var, training, momentum, eps,
- save_mean, save_std);
- }
-
- auto outputs = as_tensor_list(std::move(output));
- return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
- return std::make_shared<BatchNormBackward>(
- f, *this, std::move(save_mean), std::move(save_std),
- input, weight, bias);
- });
-};
-
-auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_list {
- check_input_variables("BatchNormBackward", grad_outputs, 1);
- auto input_var = this->input.unpack();
- auto weight_var = this->weight.unpack();
- auto bias_var = this->bias.unpack();
-
- auto input = input_var.data();
- auto weight = weight_var.opt_data();
- auto bias = bias_var.opt_data();
-
- AutoGPU guard(input);
-
- bool use_cudnn = false;
-#ifdef WITH_CUDNN
- use_cudnn = (input.type().backend() == at::kCUDA
- && (input.type().scalarType() != at::kHalf
- || weight.type().scalarType() == at::kFloat)
- && weight.defined() && bias.defined() && training
- && input.size(0) <= 131070
- && cudnn_enabled && CUDNN_VERSION >= 5110L);
-#endif
-
- at::Tensor grad_input;
- at::Tensor grad_weight;
- at::Tensor grad_bias;
-
- auto grad_output = grad_outputs[0].data().contiguous();
-
- if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
-#ifdef WITH_CUDNN
- grad_input = input.type().tensor(input.sizes());
- grad_weight = weight.type().tensor(weight.sizes());
- grad_bias = bias.type().tensor(bias.sizes());
- torch::cudnn::cudnn_batch_norm_backward(
- state,
- torch::cudnn::getCudnnHandle(),
- torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)input.unsafeGetTH(false),
- (THVoidTensor*)grad_output.unsafeGetTH(false),
- (THVoidTensor*)grad_input.unsafeGetTH(false),
- (THVoidTensor*)grad_weight.unsafeGetTH(false),
- (THVoidTensor*)grad_bias.unsafeGetTH(false),
- (THVoidTensor*)weight.unsafeGetTH(false),
- (THVoidTensor*)running_mean.unsafeGetTH(false),
- (THVoidTensor*)running_var.unsafeGetTH(false),
- (THVoidTensor*)save_mean.unsafeGetTH(false),
- (THVoidTensor*)save_std.unsafeGetTH(false),
- training,
- eps);
-#endif
- } else {
- std::array<bool, 3> mask = {
- should_compute_output(0),
- should_compute_output(1),
- should_compute_output(2),
- };
- std::tie(grad_input, grad_weight, grad_bias) = at::batch_norm_backward(
- grad_output, input, weight, running_mean, running_var,
- training, eps, save_mean, save_std,
- mask);
- }
-
- // Add saved variables used out of the pure autograd to inputs
- variable_list all_inputs(grad_outputs);
- all_inputs.push_back(input_var);
- if (weight.defined()) {
- all_inputs.push_back(weight_var);
- }
- auto outputs = as_tensor_list(std::move(grad_input),
- std::move(grad_weight),
- std::move(grad_bias));
- return wrap_outputs(all_inputs, std::move(outputs), [&](FunctionFlags f) {
- return std::make_shared<BatchNormBackwardBackward>(
- f, *this, save_mean, save_std,
- input_var, weight_var,
- grad_outputs[0]);
- });
-};
-
-auto BatchNormBackward::releaseVariables() -> void {
- input.data.reset();
- weight.data.reset();
- bias.data.reset();
-}
-
-Variable getReturnTupleVar(PyObject *p, Py_ssize_t pos) {
- PyObject *item = PyTuple_GET_ITEM(p, pos);
- if (item != Py_None) {
- return ((THPVariable*)item)->cdata;
- }
- return Variable();
-}
-
-auto BatchNormBackwardBackward::apply(const variable_list& grad_grad_inputs) -> variable_list {
- check_input_variables("BatchNormBackwardBackward", grad_grad_inputs, 3, 0);
- auto ggI = grad_grad_inputs[0];
- auto ggW = grad_grad_inputs[1];
- auto ggb = grad_grad_inputs[2];
-
- auto input_var = input.unpack();
- AutoGPU guard(input_var);
-
- auto weight_var = weight.unpack();
- auto gO_var = grad_output.unpack();
-
- auto input = input_var.data();
- AutoGIL gil;
-
- THPObjectPtr input_pvar(THPVariable_Wrap(input_var));
- THPObjectPtr weight_pvar(THPVariable_Wrap(weight_var));
-
- THPObjectPtr ggi_pvar(THPVariable_Wrap(ggI));
- THPObjectPtr ggW_pvar(THPVariable_Wrap(ggW));
- THPObjectPtr ggb_pvar(THPVariable_Wrap(ggb));
- THPObjectPtr gO_pvar(THPVariable_Wrap(gO_var));
- THPObjectPtr eps_py(PyFloat_FromDouble(eps));
- THPObjectPtr save_mean_py(createPyObject(save_mean));
- THPObjectPtr save_std_py(createPyObject(save_std));
- THPObjectPtr running_mean_py(createPyObject(running_mean));
- THPObjectPtr running_var_py(createPyObject(running_var));
- PyObject *training_pyo = training ? Py_True : Py_False;
-
- THPObjectPtr args(PyTuple_Pack(12, input_pvar.get(), weight_pvar.get(),
- ggi_pvar.get(), ggW_pvar.get(), ggb_pvar.get(),
- gO_pvar.get(), eps_py.get(),
- save_mean_py.get(), save_std_py.get(),
- running_mean_py.get(), running_var_py.get(),
- training_pyo));
- THPObjectPtr r(PyObject_CallObject(THPBatchNormBackwardBackwardFunction, args.get()));
- if (!r) throw python_error();
- if (!PyTuple_Check(r.get())) {
- throw std::runtime_error("expected PyTuple return from BatchNormBackwardBackward");
- }
-
- auto gI_var = getReturnTupleVar(r, 0);
- auto gG_var = getReturnTupleVar(r, 1);
- auto ggO_var = getReturnTupleVar(r, 2);
-
- if (weight_var.defined()) {
- return {ggO_var, gI_var, gG_var};
- } else {
- return {ggO_var, gI_var};
- }
-};
-
-auto BatchNormBackwardBackward::releaseVariables() -> void {
- input.data.reset();
- weight.data.reset();
- grad_output.data.reset();
-}
-
-
-}} // namespace torch::autograd
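
Note on the deleted file above: the handwritten BatchNorm forward/backward picked between the cuDNN kernel and the ATen fallback at runtime based on device, dtype, batch size, and epsilon. A rough, self-contained sketch of that eligibility test, with a hypothetical TensorMeta standing in for the few at::Tensor properties the real check inspected (the version and epsilon constants here are assumptions, not the compiled-in values):

    #include <cstdint>

    enum class Scalar { Float, Half, Other };

    struct TensorMeta {
      bool defined;
      bool is_cuda;
      Scalar dtype;
      int64_t batch;  // size(0)
    };

    constexpr long kCudnnVersion = 7102;     // assumption: compile-time CUDNN_VERSION
    constexpr double kCudnnBnMinEps = 1e-5;  // assumption for CUDNN_BN_MIN_EPSILON

    bool use_cudnn_batch_norm(const TensorMeta& input, const TensorMeta& weight,
                              const TensorMeta& bias, bool cudnn_enabled,
                              double eps) {
      return input.is_cuda
          // half inputs are only taken with float weights
          && (input.dtype != Scalar::Half || weight.dtype == Scalar::Float)
          && weight.defined && bias.defined
          && input.batch <= 131070          // batch-size cap used by the old path
          && cudnn_enabled && kCudnnVersion >= 5110L
          && eps >= kCudnnBnMinEps;         // otherwise fall back to the ATen kernel
    }
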
diff --git a/torch/csrc/autograd/functions/batch_normalization.h b/torch/csrc/autograd/functions/batch_normalization.h
deleted file mode 100644
index aafb5d68..00000000
--- a/torch/csrc/autograd/functions/batch_normalization.h
+++ /dev/null
@@ -1,93 +0,0 @@
-#pragma once
-
-#include <Python.h>
-#include <memory>
-#include <ATen/ATen.h>
-
-#include "torch/csrc/autograd/function.h"
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/symbolic.h"
-#include "torch/csrc/autograd/saved_variable.h"
-
-namespace torch { namespace autograd {
-
-struct BatchNormParams {
- at::Tensor running_mean;
- at::Tensor running_var;
- bool training;
- double momentum;
- double eps;
- bool cudnn_enabled;
-};
-
-struct BatchNormForward : public ForwardFunction<>, public BatchNormParams, public HasSymbolic {
- BatchNormForward(BatchNormParams params)
- : BatchNormParams(std::move(params)) {}
-
- virtual variable_list apply(const variable_list& inputs) override;
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
-};
-
-struct BatchNormBackward : public Function, public BatchNormParams {
- BatchNormBackward(
- FunctionFlags flags,
- BatchNormParams params,
- at::Tensor save_mean,
- at::Tensor save_std,
- Variable input,
- Variable weight,
- Variable bias)
- : Function(std::move(flags))
- , BatchNormParams(std::move(params)) {
- if (is_executable) {
- this->save_mean = std::move(save_mean);
- this->save_std = std::move(save_std);
- this->input = SavedVariable(input, this);
- this->weight = SavedVariable(weight, this);
- this->bias = SavedVariable(bias, this);
- }
- }
-
- virtual variable_list apply(const variable_list& gradOutputs) override;
-
- virtual void releaseVariables() override;
-
- at::Tensor save_mean;
- at::Tensor save_std;
- SavedVariable input;
- SavedVariable weight;
- SavedVariable bias;
-};
-
-struct BatchNormBackwardBackward : public Function, public BatchNormParams {
- BatchNormBackwardBackward(
- FunctionFlags flags,
- BatchNormParams params,
- at::Tensor save_mean,
- at::Tensor save_std,
- Variable input,
- Variable weight,
- Variable grad_output)
- : Function(std::move(flags))
- , BatchNormParams(std::move(params)) {
- if (is_executable) {
- this->save_mean = std::move(save_mean);
- this->save_std = std::move(save_std);
- this->input = SavedVariable(input, this);
- this->weight = SavedVariable(weight, this);
- this->grad_output = SavedVariable(grad_output, this);
- }
- }
-
- virtual variable_list apply(const variable_list& grad_grad_inputs) override;
-
- virtual void releaseVariables() override;
-
- at::Tensor save_mean;
- at::Tensor save_std;
- SavedVariable input;
- SavedVariable weight;
- SavedVariable grad_output;
-};
-
-}}
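
Note on the deleted header above: it relied on a "params struct + multiple inheritance" idiom, so the forward, backward, and double-backward nodes all carried the same BatchNormParams members, and saved-variable capture happened only when the node was executable. A minimal sketch of that idiom with illustrative names (NormParams, NormForwardNode, NormBackwardNode are not real torch types):

    #include <cstdio>
    #include <utility>

    struct NormParams {
      bool training = true;
      double momentum = 0.1;
      double eps = 1e-5;
    };

    struct NodeBase {
      virtual ~NodeBase() = default;
    };

    struct NormForwardNode : NodeBase, NormParams {
      explicit NormForwardNode(NormParams p) : NormParams(std::move(p)) {}
    };

    struct NormBackwardNode : NodeBase, NormParams {
      // Built from the same params, so eps/momentum/training are read
      // directly off `this` when the gradient is computed.
      explicit NormBackwardNode(NormParams p) : NormParams(std::move(p)) {}
    };

    int main() {
      NormParams p;
      p.training = false;
      NormBackwardNode bwd(p);  // backward sees the exact same hyperparameters
      std::printf("eps=%g training=%d\n", bwd.eps, (int)bwd.training);
    }
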
diff --git a/torch/csrc/autograd/functions/convolution.cpp b/torch/csrc/autograd/functions/convolution.cpp
deleted file mode 100644
index 7c36242d..00000000
--- a/torch/csrc/autograd/functions/convolution.cpp
+++ /dev/null
@@ -1,918 +0,0 @@
-#include "convolution.h"
-
-#include <sstream>
-
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/functions/utils.h"
-#include "torch/csrc/autograd/functions/basic_ops.h"
-#include "torch/csrc/autograd/functions/tensor.h"
-#include "torch/csrc/utils/auto_gpu.h"
-
-#include <ATen/ATen.h>
-
-#ifdef WITH_CUDNN
-#include "torch/csrc/cudnn/Conv.h"
-#include "torch/csrc/cudnn/Handles.h"
-#include "torch/csrc/cudnn/Types.h"
-extern THCState* state;
-using namespace torch::cudnn;
-#endif
-
-#ifdef WITH_NNPACK
-#include "torch/csrc/nnpack/NNPACK.h"
-#endif
-
-using torch::cudnn::Convolution;
-using at::Tensor;
-using tensor_pair = std::pair<at::Tensor, at::Tensor>;
-
-namespace torch { namespace autograd {
-
-// Forward function definition and utility functions
-
-static at::Tensor compute_output(
- at::Tensor& input, at::Tensor& weight, at::Tensor& bias, at::Tensor& columns, at::Tensor& ones,
- const ConvForward& params);
-
-static std::tuple<Tensor, Tensor, Tensor> compute_backward(
- at::Tensor& input, at::Tensor& grad_output, at::Tensor& weight, at::Tensor& columns, at::Tensor& ones,
- const ConvBackward& params, std::array<bool, 3> output_mask);
-
-auto ConvParams::is_strided() const -> bool {
- bool is_strided = false;
- for (int s : stride) {
- is_strided |= (s != 1);
- }
- return is_strided;
-}
-
-auto ConvParams::is_dilated() const -> bool {
- bool is_dilated = false;
- for (int d : dilation) {
- is_dilated |= (d != 1);
- }
- return is_dilated;
-}
-
-auto ConvParams::is_padded() const -> bool {
- bool is_padded = false;
- for (int p : padding) {
- is_padded |= (p != 0);
- }
- return is_padded;
-}
-
-auto ConvParams::is_output_padding_neg() const -> bool {
- bool is_non_neg = false;
- for (int p : output_padding) {
- is_non_neg |= (p < 0);
- }
- return is_non_neg;
-}
-
-auto ConvParams::is_output_padding_big() const -> bool {
- bool is_big = false;
- for (size_t i = 0; i < output_padding.size(); i++) {
- is_big |= (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i]);
- }
- return is_big;
-}
-
-auto ConvParams::is_padding_neg() const -> bool {
- bool is_non_neg = false;
- for (int p : padding) {
- is_non_neg |= (p < 0);
- }
- return is_non_neg;
-}
-
-
-auto ConvParams::view1d_as_2d() -> void {
- if (stride.size() == 1) {
- stride.insert(stride.begin(), 1);
- padding.insert(padding.begin(), 0);
- dilation.insert(dilation.begin(), 1);
- output_padding.insert(output_padding.begin(), 0);
- }
-}
-
-auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool {
-#ifdef WITH_CUDNN
- if (!input.type().isCuda() || !cudnn_enabled) {
- return false;
- }
- if (deterministic && is_dilated()) {
- // cudnn doesn't support deterministic dilated convolution fully yet
- return false;
- }
- if (is_dilated()) {
- cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
- // NOTE: extra parenthesis around numbers disable clang warnings about dead code
- return ((CUDNN_VERSION >= (6021)) || (CUDNN_VERSION >= (6000) && prop->major >= 5)) && !is_output_padding_big();
- }
- return !is_output_padding_big();
-#endif
- return false;
-}
-
-auto ConvParams::use_nnpack(const at::Tensor& input) const -> bool {
-#ifdef WITH_NNPACK
- return input.type().ID() == at::TypeID::CPUFloat && // only on CPU Float Tensors
- !is_strided() && // doesn't support strides
- !is_dilated() && // or dilation
- !transposed && // or transposed tensors
- input.ndimension() == 4 && // must be in NCHW format
- input.size(0) >= 16; // ensure large enough batch size to ensure perf, tuneable
-#endif
- return false;
-}
-
-// We currently only have depthwise support for the case where groups ==
-// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
-// a depthwise multiplier)
-auto ConvParams::is_depthwise(
- const at::Tensor& input, const at::Tensor& weight, int groups) const -> bool {
- return input.type().isCuda() &&
- !transposed &&
- input.ndimension() == 4 &&
- input.size(1) == groups &&
- groups > 1 && // no point if there is only a single group
- weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels
-}
-
-std::string ConvForward::name() { return "ConvForward"; }
-
-auto ConvForward::output_size(at::Tensor& input, at::Tensor& weight) const -> std::vector<int64_t> {
- auto in_size = input.sizes();
- auto weight_size = weight.sizes();
- auto dim = input.ndimension();
-
- std::vector<int64_t> output_size(dim);
- output_size[0] = in_size[0];
- output_size[1] = transposed ? weight_size[1] * groups : weight_size[0];
- for (int d = 2; d < dim; ++d) {
- int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
- if (transposed) {
- output_size[d] = (in_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) +
- kernel + output_padding[d - 2];
- } else {
- output_size[d] = (in_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
- }
- }
- return output_size;
-}
-
-static auto view4d(const at::Tensor& tensor) -> at::Tensor {
- if (tensor.ndimension() != 3) throw std::runtime_error("expected 3D tensor");
- return tensor.unsqueeze(2);
-}
-
-static auto view3d(const at::Tensor& tensor) -> at::Tensor {
- if (tensor.ndimension() != 4) throw std::runtime_error("expected 4D tensor");
- return tensor.squeeze(2);
-}
-
-static void check_input_shape_forward(const at::Tensor& input,
- const at::Tensor& weight, const at::Tensor& bias,
- int64_t groups, bool transposed) {
- int k = input.ndimension();
-
- if (weight.ndimension() != k) {
- std::stringstream ss;
- ss << "Expected " << k << "-dimensional input for " << k
- << "-dimensional weight " << weight.sizes() << ", but got input of size "
- << input.sizes() << " instead";
- throw std::runtime_error(ss.str());
- }
- if (weight.size(0) < groups) {
- std::stringstream ss;
- ss << "Given groups=" << groups << ", expected weight to be at least "
- << groups << " at dimension 0, but got weight of size " << weight.sizes()
- << " instead";
- throw std::runtime_error(ss.str());
- }
-
- if (!transposed) {
- if (input.size(1) != (weight.size(1) * groups)) {
- std::stringstream ss;
- ss << "Given groups=" << groups << ", weight" << weight.sizes()
- << ", so expected input" << input.sizes() << " to have "
- << (weight.size(1) * groups) << " channels, but got " << input.size(1)
- << " channels instead";
- throw std::runtime_error(ss.str());
- }
- if (bias.defined() && (bias.ndimension() != 1 || bias.size(0) != weight.size(0))) {
- std::stringstream ss;
- ss << "Given weight of size " << weight.sizes()
- << ", expected bias to be 1-dimensional with " << weight.size(0) << " elements"
- << ", but got bias of size " << bias.sizes() << " instead";
- throw std::runtime_error(ss.str());
- }
- } else { // transposed
- if (input.size(1) != weight.size(0)) {
- std::stringstream ss;
- ss << "Given transposed=" << transposed << ", weight" << weight.sizes()
- << ", so expected input" << input.sizes() << " to have "
- << weight.size(0) << " channels, but got " << input.size(1)
- << " channels instead";
- throw std::runtime_error(ss.str());
- }
- if (bias.defined() && (bias.ndimension() != 1 || bias.size(0) != weight.size(1) * groups)) {
- std::stringstream ss;
- ss << "Given transposed=" << transposed << ", weight of size " << weight.sizes()
- << ", expected bias to be 1-dimensional with " << weight.size(1) * groups << " elements"
- << ", but got bias of size " << bias.sizes() << " instead";
- throw std::runtime_error(ss.str());
- }
- }
-}
-
-static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
- if (!tensor.defined()) {
- return at::Tensor();
- }
- int64_t n = tensor.sizes()[dim] / groups;
- return tensor.narrow(dim, n * g, n).contiguous();
-}
-
-static Variable subvariable(const Variable& var, int dim, int groups, int g) {
- int64_t n = var.sizes()[dim] / groups;
- auto result = apply_fn<Narrow>(dim, n * g, n)(var);
- return result;
-}
-
-static std::vector<int64_t> vecToInt64(const std::vector<int>& src) {
- std::vector<int64_t> res(src.size());
- for (size_t i = 0; i < src.size(); i++) {
- res[i] = static_cast<int64_t>(src[i]);
- }
- return res;
-}
-
-static at::Tensor cat(const tensor_list& tensors, int dim) {
- int num_inputs = tensors.size();
- if (num_inputs == 0) {
- return at::Tensor();
- }
-
- auto output = tensors[0].type().tensor();
- at::cat_out(output, tensors, dim);
- return output;
-}
-
-// ConvForward implementation
-
-auto ConvForward::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("ConvNd", inputs, 3, 2);
- if (is_padding_neg()) throw std::runtime_error("negative padding is not supported");
- if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported");
-
- AutoGPU guard(inputs[0]);
-
- auto input = inputs[0].data().contiguous();
- auto weight = inputs[1].data();
- auto bias = inputs[2].opt_data();
-
- check_input_shape_forward(input, weight, bias, groups, transposed);
-
- int k = input.ndimension();
-
- if (k == 3) {
- view1d_as_2d();
- input = view4d(input);
- weight = view4d(weight);
- }
-
- auto output = input.type().tensor();
- tensor_list columns(groups);
- tensor_list ones(groups);
- std::unique_ptr<Convolution> convolution;
-
- if (is_depthwise(input, weight, groups)) {
- /* output.resize_(output_size(input, weight)); */
-
- auto kernel_size = weight.sizes().slice(2);
- auto stride = vecToInt64(this->stride);
- auto padding = vecToInt64(this->padding);
- auto dilation = vecToInt64(this->dilation);
-
- output = at::conv_depthwise2d_forward(input, weight, kernel_size, bias, stride, padding, dilation);
- } else if (use_cudnn(input)) {
-#ifdef WITH_CUDNN
- if (input.type().ID() != weight.type().ID()){
- std::stringstream ss;
- ss << "Input type (" << input.toString() << ") and weight type (" << weight.toString() << ") should be the same";
- throw std::runtime_error(ss.str());
- }
- if (bias.defined() && input.type().ID() != bias.type().ID()){
- std::stringstream ss;
- ss << "Input type (" << input.toString() << ") and bias type (" << bias.toString() << ") should be the same";
- throw std::runtime_error(ss.str());
- }
-
- output = input.type().tensor();
- output.resize_(output_size(input, weight));
- if (transposed) {
- convolution.reset(cudnn_convolution_transpose_full_forward(
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false),
- bias.defined() ? (THVoidTensor*)bias.unsafeGetTH(false) : nullptr, (THVoidTensor*)output.unsafeGetTH(false),
- padding, stride, dilation, groups, benchmark, deterministic));
- } else {
- convolution.reset(cudnn_convolution_full_forward(
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false),
- bias.defined() ? (THVoidTensor*)bias.unsafeGetTH(false) : nullptr, (THVoidTensor*)output.unsafeGetTH(false),
- padding, stride, dilation, groups, benchmark, deterministic));
- }
-#endif
- } else {
- for (int g = 0; g < groups; ++g) {
- columns[g] = input.type().tensor();
- ones[g] = input.type().tensor();
- }
- if (groups == 1) {
- output = compute_output(
- input, weight, bias,
- columns[0], ones[0], *this);
- } else {
- tensor_list outputs(groups);
- for (int g = 0; g < groups; ++g) {
- auto input_g = subtensor(input, 1, groups, g);
- auto weight_g = subtensor(weight, 0, groups, g);
- auto bias_g = subtensor(bias, 0, groups, g);
- outputs[g] = compute_output(
- input_g, weight_g, bias_g,
- columns[g], ones[g], *this);
- }
- output = cat(outputs, 1);
- }
- }
-
- if (k == 3) {
- output = view3d(output);
- }
-
- auto outputs = as_tensor_list(std::move(output));
- return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
- return std::make_shared<ConvBackward>(
- f, *this,
- inputs[0], inputs[1], inputs[2],
- std::move(columns), std::move(ones), std::move(convolution));
- });
-};
-
-// For Convolution strategies that don't implicitly handle grad_bias, we add a helper
-// function here to perform it using simple Tensor operators
-static at::Tensor compute_grad_bias(const at::Tensor& grad_output) {
- // grad_output is in N, C, H, W, we re-shape and reduce over spatial dims and batches
- return grad_output.contiguous().view({grad_output.size(0), grad_output.size(1), -1}).sum(0).sum(1);
-}
-
-// ConvBackward implementation
-
-auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list {
- check_input_variables("ConvNdBackward", grad_outputs, 1);
- if (is_padding_neg()) throw std::runtime_error("negative padding is not supported");
- if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported");
-
- auto input_var = input_.unpack();
- auto weight_var = weight_.unpack();
- auto bias_var = bias_.unpack();
-
- auto input = input_var.data();
- auto weight = weight_var.data();
-
- AutoGPU guard(input);
-
- auto bias = bias_var.defined() ? bias_var.data() : Tensor();
-
- input = input.contiguous();
- auto grad_output = grad_outputs[0].data().contiguous();
-
- int k = input.ndimension();
- if (k == 3) {
- input = view4d(input);
- weight = view4d(weight);
- grad_output = view4d(grad_output);
- }
-
-
- bool use_depthwise = this->is_depthwise(input, weight, groups);
- bool use_cudnn = this->use_cudnn(input);
-
- at::Tensor grad_input;
- at::Tensor grad_weight;
- at::Tensor grad_bias;
-
- std::array<bool, 3> output_mask = {
- should_compute_output(0),
- should_compute_output(1),
- should_compute_output(2) && bias.defined(),
- };
-
- if (use_depthwise) {
- if (output_mask[0] || output_mask[1]) {
- auto kernel_size = weight.sizes().slice(2);
- auto stride = vecToInt64(this->stride);
- auto padding = vecToInt64(this->padding);
- auto dilation = vecToInt64(this->dilation);
-
- std::tie(grad_input, grad_weight) = at::conv_depthwise2d_backward(
- grad_output, input, weight, kernel_size, stride, padding, dilation,
- {output_mask[0], output_mask[1]});
- }
-
- // THCUNN implementation does not handle bias, so we do it ourselves
- if (output_mask[2]) {
- grad_bias = compute_grad_bias(grad_output);
- }
- } else if (use_cudnn) {
-#ifdef WITH_CUDNN
- if (output_mask[0]) {
- grad_input = input.type().tensor();
- grad_input.resize_as_(input);
- if (transposed) {
- // ConvTranspose uses the same kernels as regular convolution
- // but swaps forward and backward calls
- cudnn_convolution_forward(
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false), (THVoidTensor*)grad_input.unsafeGetTH(false),
- convolution.get(), benchmark, deterministic);
- } else {
- cudnn_convolution_backward_data(
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)grad_input.unsafeGetTH(false), (THVoidTensor*)weight.unsafeGetTH(false),
- convolution.get(), benchmark, deterministic);
- }
- }
- if (output_mask[1] || output_mask[2]) {
- grad_weight = weight.type().tensor();
- grad_weight.resize_as_(weight);
- cudnn_convolution_backward_filter(
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)input.unsafeGetTH(false), (THVoidTensor*)grad_weight.unsafeGetTH(false),
- convolution.get(), benchmark, deterministic);
-
- if (output_mask[2]) {
- grad_bias = bias.type().tensor();
- grad_bias.resize_as_(bias);
- cudnn_convolution_backward_bias(
- state, torch::cudnn::getCudnnHandle(), torch::cudnn::getCudnnDataType(input),
- (THVoidTensor*)grad_output.unsafeGetTH(false), (THVoidTensor*)grad_bias.unsafeGetTH(false),
- convolution.get());
- }
- }
-#endif
- } else if (groups == 1) {
- std::tie(grad_input, grad_weight, grad_bias) = compute_backward(
- input, grad_output, weight, columns[0], ones[0],
- *this, output_mask);
- } else {
- tensor_list grad_inputs(groups);
- tensor_list grad_weights(groups);
- tensor_list grad_biases(groups);
- for (int g = 0; g < groups; ++g) {
- auto input_g = subtensor(input, 1, groups, g);
- auto grad_output_g = subtensor(grad_output, 1, groups, g);
- auto weight_g = subtensor(weight, 0, groups, g);
- std::tie(grad_inputs[g], grad_weights[g], grad_biases[g]) = compute_backward(
- input_g, grad_output_g, weight_g, columns[g], ones[g],
- *this, output_mask);
- }
- if (output_mask[0]) {
- grad_input = cat(grad_inputs, 1);
- }
- if (output_mask[1]) {
- grad_weight = cat(grad_weights, 0);
- }
- if (output_mask[2]) {
- grad_bias = cat(grad_biases, 0);
- }
- }
-
- if (k == 3) {
- if (grad_input.defined()) {
- grad_input = view3d(grad_input);
- }
- if (grad_weight.defined()) {
- grad_weight = view3d(grad_weight);
- }
- }
-
- // Add saved variables used out of the pure autograd to inputs
- variable_list all_inputs(grad_outputs);
- all_inputs.push_back(input_var);
- all_inputs.push_back(weight_var);
-
- auto outputs = as_tensor_list(std::move(grad_input),
- std::move(grad_weight),
- std::move(grad_bias));
- return wrap_outputs(all_inputs, std::move(outputs), [&](FunctionFlags f) {
- return std::make_shared<ConvBackwardBackward>(
- f, *this,
- input_var, weight_var,
- bias_var, grad_outputs[0]);
- });
-};
-
-auto ConvBackward::releaseVariables() -> void {
- input_.data.reset();
- weight_.data.reset();
- bias_.data.reset();
-}
-
-
-// ConvBackwardBackward implementation
-
-auto ConvBackwardBackward::apply(const variable_list& grad_grad_inputs) -> variable_list {
- check_input_variables("ConvNdBackwardBackward", grad_grad_inputs, 3, 0);
-
- auto ggI = grad_grad_inputs[0];
- auto ggW = grad_grad_inputs[1];
- auto ggb = grad_grad_inputs[2];
-
- auto gO = grad_output_.unpack();
- auto weight = weight_.unpack();
- auto input = input_.unpack();
-
- AutoGPU guard(input.data());
-
- // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
- Variable ggO;
- if (ggI.defined()) {
- if (weight.type().isCuda()) {
- weight = apply_fn<Contiguous>()(weight);
- }
- ggO = apply_fn<ConvForward>(*this)(ggI, weight, Variable());
- }
-
- if (ggW.defined()) {
- if (ggW.type().isCuda()) {
- ggW = apply_fn<Contiguous>()(ggW);
- }
- auto ggW_term = apply_fn<ConvForward>(*this)(input, ggW, Variable());
- if (ggO.defined()) {
- ggO = apply_fn<Add>()(ggO, ggW_term);
- } else {
- ggO = ggW_term;
- }
- }
-
- if (ggb.defined()) {
- // View as (1, ggb.size(0), 1, 1...)
-
- // Expand
- std::vector<int64_t> new_size(gO.ndimension(), 1);
- new_size[1] = ggb.sizes()[0];
- auto ggb_contiguous = apply_fn<Contiguous>()(ggb);
- auto ggb_view = apply_fn<View>(new_size)(ggb_contiguous);
-
- // Expand
- auto ggb_expanded = apply_fn<Expand>(gO.sizes())(ggb_view);
-
- if (ggO.defined()) {
- ggO = apply_fn<Add>()(ggO, ggb_expanded);
- } else {
- ggO = ggb_expanded;
- }
- }
-
- // Compute gW = conv(ggI, gO)
- Variable gW;
- if (ggI.defined()) {
- // Modified params with correct padding
- ConvParams gw_conv_params(*this);
-
- // Disable groups as they are handled separately
- auto groups = gw_conv_params.groups;
- gw_conv_params.groups = 1;
- std::swap(gw_conv_params.dilation, gw_conv_params.stride);
-
- // Transpose gO and ggI to accumulate over batch
- auto gOt = apply_fn<Transpose>(0, 1)(gO);
- auto ggIt = apply_fn<Transpose>(0, 1)(ggI);
-
- Variable gWt;
- // Compute conv
- if (groups == 1) {
- if (gOt.type().isCuda()) {
- gOt = apply_fn<Contiguous>()(gOt);
- }
-
- // Compute conv
- if (transposed) {
- gw_conv_params.transposed = false;
- gWt = apply_fn<ConvForward>(gw_conv_params)(gOt, ggIt, Variable());
- } else {
- gWt = apply_fn<ConvForward>(gw_conv_params)(ggIt, gOt, Variable());
- }
- } else {
- variable_list gWt_list(groups);
- for (int g = 0; g < groups; ++g) {
- auto ggIt_g = subvariable(ggIt, 0, groups, g);
- auto gOt_g = subvariable(gOt, 0, groups, g);
- if (gOt_g.type().isCuda()) {
- gOt_g = apply_fn<Contiguous>()(gOt_g);
- }
-
- // Compute conv
- if (transposed) {
- gw_conv_params.transposed = false;
- gWt_list[g] = apply_fn<ConvForward>(gw_conv_params)(gOt_g, ggIt_g, Variable());
- } else {
- gWt_list[g] = apply_fn<ConvForward>(gw_conv_params)(ggIt_g, gOt_g, Variable());
- }
- }
-
- gWt = apply_fn<Cat>(1)(gWt_list);
- }
-
- // Transpose gW to match chan_in and chan_out
- gW = apply_fn<Transpose>(0, 1)(gWt);
-
- // narrow gW to only relevant portion
- // we do it this way instead of narrowing the input itself because
- // the ConvForward kernels don't support asymmetric padding.
- auto gW_size = gW.sizes();
- auto w_size = weight.sizes();
- for (size_t i = 2; i < gW_size.size(); ++i) {
- if (gW_size[i] > w_size[i]) {
- gW = apply_fn<Narrow>(i, 0, w_size[i])(gW);
- gW_size = gW.sizes();
- }
- }
- }
-
- // Compute gI = convT(ggW, gO.t()) if !transposed
- // gI = conv(go, ggw) if transposed
- Variable gI;
- if (ggW.defined()) {
- ConvParams gi_conv_params(*this);
- gi_conv_params.transposed = !transposed;
-
- if (transposed) {
- if (gO.type().isCuda()) {
- gO = apply_fn<Contiguous>()(gO);
- }
- gI = apply_fn<ConvForward>(gi_conv_params)(gO, ggW, Variable());
-
- // narrow gI to only relevant portion
- // we do it this way because negative output_padding is not supported
- // TODO: figure out if we can narrow gO and save some compute,
- // rather than narrowing the computed gI
- auto gI_size = gI.sizes();
- auto i_size = input.sizes();
- for (size_t i = 2; i < gI_size.size(); ++i) {
- if (gI_size[i] > i_size[i]) {
- gI = apply_fn<Narrow>(i, 0, i_size[i])(gI);
- gI_size = gI.sizes();
- }
- }
- } else {
- auto groups = gi_conv_params.groups;
- gi_conv_params.groups = 1;
- // swap stride and dilation
- std::swap(gi_conv_params.dilation, gi_conv_params.stride);
-
- auto ggWt = apply_fn<Transpose>(0, 1)(ggW);
- auto gOt = apply_fn<Transpose>(0, 1)(gO);
-
- // calculate output_padding
- auto kernel_size = weight.sizes().slice(2);
- auto input_shape = input.sizes().slice(2);
- auto grad_output_shape = gO.sizes().slice(2);
-
- if (kernel_size.size() == 1) {
- auto expected_input_shape = (kernel_size[0] - 1) * gi_conv_params.stride[1]
- - 2 * gi_conv_params.padding[1]
- + (gi_conv_params.dilation[1] * (grad_output_shape[0] - 1) + 1);
- if (expected_input_shape != input_shape[0]) {
- gi_conv_params.output_padding[1] = input_shape[0] - expected_input_shape;
- }
- } else {
- for(size_t i = 0; i < kernel_size.size(); ++i) {
- // Check if whole input has been used or not
- auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.stride[i]
- - 2 * gi_conv_params.padding[i]
- + (gi_conv_params.dilation[i] * (grad_output_shape[i] - 1) + 1);
- if (expected_input_shape != input_shape[i]) {
- gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape;
- }
- }
- }
-
- Variable gIt;
- if (groups == 1) {
- if (gOt.type().isCuda()) {
- gOt = apply_fn<Contiguous>()(gOt);
- }
-
- gIt = apply_fn<ConvForward>(gi_conv_params)(ggWt, gOt, Variable());
- } else {
- variable_list gIt_list(groups);
- for (int g = 0; g < groups; ++g) {
- auto ggWt_g = subvariable(ggWt, 1, groups, g);
- auto gOt_g = subvariable(gOt, 0, groups, g);
- if (gOt_g.type().isCuda()) {
- gOt_g = apply_fn<Contiguous>()(gOt_g);
- }
-
- gIt_list[g] = apply_fn<ConvForward>(gi_conv_params)(ggWt_g, gOt_g, Variable());
- }
-
- gIt = apply_fn<Cat>(0)(gIt_list);
- }
-
- gI = apply_fn<Transpose>(0, 1)(gIt);
- }
- }
-
- if (should_compute_output(0) && !ggO.defined()) ggO = at::zeros_like(gO);
- if (should_compute_output(1) && !gI.defined()) gI = at::zeros_like(input);
- if (should_compute_output(2) && !gW.defined()) gW = at::zeros_like(weight);
- bool is_volatile = std::any_of(grad_grad_inputs.begin(), grad_grad_inputs.end(), [](const Variable& v){
- return v.defined() && v.is_volatile();
- });
- auto results = variable_list({ggO, gI, gW});
- for (auto& result : results) {
- result.is_volatile() |= is_volatile;
- }
- return results;
-}
-
-auto ConvBackwardBackward::releaseVariables() -> void {
- input_.data.reset();
- weight_.data.reset();
- bias_.data.reset();
- grad_output_.data.reset();
-}
-
-// Forward and backward functions for Tensor
-
-static at::Tensor compute_output(
- at::Tensor& input, at::Tensor& weight, at::Tensor& bias,
- at::Tensor& columns, at::Tensor& ones,
- const ConvForward& params) {
-
- auto dim = input.ndimension();
- auto dilated = params.is_dilated();
- auto kernel_size = weight.sizes().slice(2);
- auto stride = vecToInt64(params.stride);
- auto padding = vecToInt64(params.padding);
- auto dilation = vecToInt64(params.dilation);
- auto output_padding = vecToInt64(params.output_padding);
-
- if (params.transposed) {
- if (dim == 4) {
- return at::conv_transpose2d_forward(
- input, weight, kernel_size, bias,
- stride, padding, output_padding, dilation,
- columns, ones);
- } else if (dim == 5) {
- return at::conv_transpose3d_forward(
- input, weight, bias,
- stride, padding, output_padding, dilation,
- columns, ones);
- }
- } else { /* Not transposed */
- if (dim == 4) {
- if (dilated) {
- return at::conv_dilated2d_forward(
- input, weight, kernel_size, bias,
- stride, padding, dilation,
- columns, ones);
- } else { /* dim == 4, non-dilated */
- if (params.use_nnpack(input)) {
-#ifdef WITH_NNPACK
- // THNN functions handle resizing the output Tensor themselves,
- // but NNPACK expects the Tensors to be in the appropriate shape
- // already, so we resize here
- auto output = input.type().tensor(params.output_size(input, weight));
- nnpack::SpatialConvolution_updateOutput(
- input, output, weight, bias,
- kernel_size[1], kernel_size[0],
- params.padding[1], params.padding[0]);
- return output;
-#endif
- } else {
- /* CPU implementation has specialized MM kernels
- for non-dilated case here */
- return at::conv2d_forward(
- input, weight, kernel_size, bias,
- stride, padding,
- columns, ones);
- }
- }
- } else if (dim == 5 && (input.type().isCuda() || dilated)) {
- return at::conv_dilated3d_forward(
- input, weight, kernel_size, bias,
- stride, padding, dilation,
- columns, ones);
- } else if (dim == 5) { /* dim == 5, CPU, non-dilated */
- /* CPU implementation has specialized MM kernels
- for non-dilated case here */
- return at::conv3d_forward(
- input, weight, kernel_size, bias,
- stride, padding,
- columns);
- }
- }
-
- throw std::runtime_error("unsupported ConvNd parameters");
-}
-
-static std::tuple<Tensor, Tensor, Tensor> compute_backward(
- at::Tensor& input, at::Tensor& grad_output, at::Tensor& weight,
- at::Tensor& columns, at::Tensor& ones,
- const ConvBackward& params,
- std::array<bool, 3> output_mask) {
-
- auto kernel_size = weight.sizes().slice(2);
- auto stride = vecToInt64(params.stride);
- auto padding = vecToInt64(params.padding);
- auto dilation = vecToInt64(params.dilation);
- auto output_padding = vecToInt64(params.output_padding);
-
- auto dim = input.ndimension();
- auto dilated = params.is_dilated();
-
- if (params.transposed) {
- if (dim == 4) {
- return at::conv_transpose2d_backward(
- grad_output, input, weight, kernel_size,
- stride, padding, output_padding, dilation,
- columns, ones, output_mask);
- } else if (dim == 5) {
- return at::conv_transpose3d_backward(
- grad_output, input, weight,
- stride, padding, output_padding, dilation,
- columns, ones, output_mask);
- }
- } else { /* Not transposed */
- if (dim == 4) {
- if (dilated) {
- return at::conv_dilated2d_backward(
- grad_output, input, weight, kernel_size,
- stride, padding, dilation,
- columns, ones, output_mask);
- } else {
- if (params.use_nnpack(input)) {
-#ifdef WITH_NNPACK
- Tensor grad_input;
- Tensor grad_weight;
- Tensor grad_bias;
-
- if (output_mask[0]) {
- grad_input = input.type().tensor(input.sizes());
- nnpack::SpatialConvolution_updateGradInput(
- input, grad_output, grad_input, weight,
- kernel_size[1], kernel_size[0],
- params.padding[1], params.padding[0]);
- }
-
- // NNPACK does not have a bias gradient calculation, so we split
- // into two calls here if necessary
- if (output_mask[1]) {
- grad_weight = weight.type().tensor(weight.sizes());
- grad_weight.zero_();
- nnpack::SpatialConvolution_accGradWeight(
- input, grad_output, grad_weight,
- kernel_size[1], kernel_size[0],
- params.padding[1], params.padding[0]);
- }
-
- if (output_mask[2]) {
- grad_bias = compute_grad_bias(grad_output);
- }
-
- return std::make_tuple(grad_input, grad_weight, grad_bias);
-#endif
- } else {
- /* CPU implementation has specialized MM kernels
- for non-dilated case here */
- return at::conv2d_backward(
- grad_output, input, weight, kernel_size,
- stride, padding,
- columns, ones, output_mask);
- }
- }
- } else if (dim == 5 && (input.type().isCuda() || dilated)) {
- return at::conv_dilated3d_backward(
- grad_output, input, weight, kernel_size,
- stride, padding, dilation,
- columns, ones, output_mask);
- } else if (dim == 5) { /* dim == 5, CPU, non-dilated */
- /* CPU implementation has specialized MM kernels
- for non-dilated case here */
- return at::conv3d_backward(
- grad_output, input, weight, kernel_size,
- stride, padding,
- columns, ones, output_mask);
- }
- }
-
- throw std::runtime_error("unsupported ConvNdBackward parameters");
-}
-
-}} // namespace torch::autograd
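
Note on the deleted file above: among the things it computed by hand was the output shape for plain and transposed convolutions (ConvForward::output_size). A standalone sketch of that arithmetic, assuming int64_t sizes and per-spatial-dimension stride/padding/dilation/output_padding vectors:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Mirrors the deleted output_size logic: sizes are {N, C, spatial...},
    // spatial dims start at index 2.
    std::vector<int64_t> conv_output_size(
        const std::vector<int64_t>& in,       // input sizes, e.g. {N, C, H, W}
        const std::vector<int64_t>& w,        // weight sizes
        const std::vector<int64_t>& stride,
        const std::vector<int64_t>& padding,
        const std::vector<int64_t>& dilation,
        const std::vector<int64_t>& output_padding,
        bool transposed, int64_t groups) {
      const size_t dim = in.size();
      std::vector<int64_t> out(dim);
      out[0] = in[0];                              // batch
      out[1] = transposed ? w[1] * groups : w[0];  // output channels
      for (size_t d = 2; d < dim; ++d) {
        const int64_t kernel = dilation[d - 2] * (w[d] - 1) + 1;
        out[d] = transposed
            ? (in[d] - 1) * stride[d - 2] - 2 * padding[d - 2]
                  + kernel + output_padding[d - 2]
            : (in[d] + 2 * padding[d - 2] - kernel) / stride[d - 2] + 1;
      }
      return out;
    }

    int main() {
      // 3x3 conv, stride 1, padding 1 keeps the spatial size: 32x32 -> 32x32.
      auto out = conv_output_size({8, 3, 32, 32}, {16, 3, 3, 3},
                                  {1, 1}, {1, 1}, {1, 1}, {0, 0},
                                  /*transposed=*/false, /*groups=*/1);
      std::printf("%lld x %lld x %lld x %lld\n",
                  (long long)out[0], (long long)out[1],
                  (long long)out[2], (long long)out[3]);
    }
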
diff --git a/torch/csrc/autograd/functions/convolution.h b/torch/csrc/autograd/functions/convolution.h
deleted file mode 100644
index f8ba1bcb..00000000
--- a/torch/csrc/autograd/functions/convolution.h
+++ /dev/null
@@ -1,123 +0,0 @@
-#pragma once
-
-#include <Python.h>
-#include <ATen/ATen.h>
-#include <memory>
-#include <vector>
-#include <iostream>
-
-#include "torch/csrc/autograd/function.h"
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/symbolic.h"
-#include "torch/csrc/autograd/saved_variable.h"
-
-#ifdef WITH_CUDNN
-#include "torch/csrc/cudnn/Conv.h"
-#else
-namespace torch { namespace cudnn {
-struct Convolution {};
-}}
-#endif
-
-namespace torch { namespace autograd {
-
-struct ConvParams {
- std::vector<int> stride;
- std::vector<int> padding;
- std::vector<int> dilation;
- bool transposed;
- std::vector<int> output_padding;
- int groups;
- bool benchmark;
- bool deterministic;
- bool cudnn_enabled;
-
- bool is_strided() const;
- bool is_dilated() const;
- bool is_padded() const;
- bool is_output_padding_neg() const;
- bool is_output_padding_big() const;
- bool is_padding_neg() const;
- void view1d_as_2d();
- bool use_cudnn(const at::Tensor& input) const;
- bool use_nnpack(const at::Tensor& input) const;
- bool is_depthwise(const at::Tensor& input, const at::Tensor& weight, int groups) const;
-};
-
-struct ConvForward : public ForwardFunction<>, public ConvParams, public HasSymbolic {
- explicit ConvForward(ConvParams params) : ConvParams(std::move(params)) {}
-
- virtual std::string name() override;
- virtual variable_list apply(const variable_list& inputs) override;
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
-
- std::vector<int64_t> output_size(at::Tensor& input, at::Tensor& weight) const;
-};
-
-struct ConvBackward : public Function, public ConvParams {
- ConvBackward(
- FunctionFlags flags,
- ConvParams params,
- Variable input,
- Variable weight,
- Variable bias,
- tensor_list columns,
- tensor_list ones,
- std::unique_ptr<torch::cudnn::Convolution> convolution)
- : Function(std::move(flags))
- , ConvParams(std::move(params))
- , convolution(std::move(convolution)) {
- if (is_executable) {
- this->input_ = SavedVariable(input, this);
- this->weight_ = SavedVariable(weight, this);
- if (bias.defined()) {
- this->bias_ = SavedVariable(bias, this);
- }
- this->columns = std::move(columns);
- this->ones = std::move(ones);
- }
- }
-
- virtual variable_list apply(const variable_list& gradOutputs) override;
-
- virtual void releaseVariables() override;
-
- SavedVariable input_;
- SavedVariable weight_;
- SavedVariable bias_;
- tensor_list columns;
- tensor_list ones;
- std::unique_ptr<torch::cudnn::Convolution> convolution;
-};
-
-struct ConvBackwardBackward : public Function, public ConvParams {
- ConvBackwardBackward(
- FunctionFlags flags,
- ConvParams params,
- Variable input,
- Variable weight,
- Variable bias,
- Variable grad_output)
- : Function(std::move(flags))
- , ConvParams(std::move(params)) {
- if (is_executable) {
- this->input_ = SavedVariable(input, this);
- this->weight_ = SavedVariable(weight, this);
- if (bias.defined()) {
- this->bias_ = SavedVariable(bias, this);
- }
- this->grad_output_ = SavedVariable(grad_output, this);
- }
- }
-
- virtual variable_list apply(const variable_list& grad_grad_inputs) override;
-
- virtual void releaseVariables() override;
-
- SavedVariable input_;
- SavedVariable weight_;
- SavedVariable bias_;
- SavedVariable grad_output_;
-};
-
-}} // namespace torch::autograd
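
Note on the deleted header above: ConvParams gathered every per-call knob into one plain struct, and ConvBackward/ConvBackwardBackward inherited it, so the same small predicates drove both forward dispatch and gradient computation. A trimmed sketch of those predicates (the real struct also carried transposed, groups, benchmark, deterministic, and cudnn_enabled):

    #include <vector>

    struct ConvParamsSketch {
      std::vector<int> stride;
      std::vector<int> padding;
      std::vector<int> dilation;
      std::vector<int> output_padding;

      bool is_strided() const {
        for (int s : stride) if (s != 1) return true;
        return false;
      }
      bool is_dilated() const {
        for (int d : dilation) if (d != 1) return true;
        return false;
      }
      bool is_padded() const {
        for (int p : padding) if (p != 0) return true;
        return false;
      }
      // True when any output_padding entry is too large for the matching
      // stride/dilation, which ruled out the cuDNN path in the deleted code.
      bool is_output_padding_big() const {
        for (size_t i = 0; i < output_padding.size(); ++i)
          if (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i])
            return true;
        return false;
      }
    };
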
diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp
index 3f5da160..fe5c52c0 100644
--- a/torch/csrc/autograd/functions/init.cpp
+++ b/torch/csrc/autograd/functions/init.cpp
@@ -1,10 +1,9 @@
-#include "batch_normalization.h"
-#include "convolution.h"
+#include "Python.h"
#include "accumulate_grad.h"
#include "basic_ops.h"
#include "tensor.h"
#include "special.h"
-#include "jit_closure.h"
+#include "torch/csrc/jit/interpreter_autograd_function.h"
#include "torch/csrc/autograd/functions/pybind.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/autograd/generated/python_functions.h"
@@ -15,41 +14,6 @@
using namespace torch::autograd;
using torch::TupleParser;
-struct BatchNormCtor {
- BatchNormForward* operator()(PyObject* args) {
- BatchNormParams params;
-
- TupleParser parser(args, 6);
- parser.parse(params.running_mean, "running_mean");
- parser.parse(params.running_var, "running_var");
- parser.parse(params.training, "training");
- parser.parse(params.momentum, "momentum");
- parser.parse(params.eps, "eps");
- parser.parse(params.cudnn_enabled, "cudnn_enabled");
-
- return new BatchNormForward(std::move(params));
- }
-};
-
-struct ConvCtor {
- ConvForward* operator()(PyObject* args) {
- ConvParams params;
-
- TupleParser parser(args, 9);
- parser.parse(params.stride, "stride");
- parser.parse(params.padding, "padding");
- parser.parse(params.dilation, "dilation");
- parser.parse(params.transposed, "transposed");
- parser.parse(params.output_padding, "output_padding");
- parser.parse(params.groups, "groups");
- parser.parse(params.benchmark, "benchmark");
- parser.parse(params.deterministic, "deterministic");
- parser.parse(params.cudnn_enabled, "cudnn_enabled");
-
- return new ConvForward(std::move(params));
- }
-};
-
struct DelayedErrorCtor {
DelayedError* operator()(PyObject* args) {
std::string msg;
@@ -69,7 +33,7 @@ struct NoCtor {
template<typename C, typename T>
static void addClass(PyObject* module, PyTypeObject& type, const char* name,
- PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL)
+ PyGetSetDef* function_properties=nullptr, PyMethodDef* function_methods=nullptr)
{
createForwardFunctionPyTypeObject<T>(type, name, function_properties, function_methods);
Py_INCREF(&type);
@@ -86,7 +50,7 @@ PyObject* getTupleAttr(PyObject* obj, void* _unused)
auto& arr = ((T*)(self->cdata.get()))->*ptr;
auto num_elems = arr.size();
THPObjectPtr py_tuple(PyTuple_New(num_elems));
- if (!py_tuple) return NULL;
+ if (!py_tuple) return nullptr;
for (size_t i = 0; i < num_elems; ++i) {
PyTuple_SET_ITEM(py_tuple.get(), i, Convert(arr[i]));
}
@@ -105,125 +69,6 @@ PyObject* getValueAttr(PyObject* obj, void* _unused)
END_HANDLE_TH_ERRORS
}
-template<typename T, typename ParamsT, at::Tensor ParamsT::*ptr>
-PyObject* getTensorAttr(PyObject* obj, void* _unused)
-{
- HANDLE_TH_ERRORS
- THPCppFunction* self = (THPCppFunction*)obj;
- auto& val = ((T*)(self->cdata.get()))->*ptr;
- THPObjectPtr py_tensor;
- if (!val.defined()) {
- Py_INCREF(Py_None);
- py_tensor = Py_None;
- } else {
- py_tensor = torch::createPyObject(val);
- }
- return py_tensor.release();
- END_HANDLE_TH_ERRORS
-}
-
-static struct PyGetSetDef batch_norm_forward_properties[] = {
- THP_FUNCTION_DEFAULT_PROPERTIES,
- {(char*)"running_mean", (getter)getTensorAttr<BatchNormForward, BatchNormParams,
- &BatchNormParams::running_mean>, NULL, NULL, NULL},
- {(char*)"running_var", (getter)getTensorAttr<BatchNormForward, BatchNormParams,
- &BatchNormParams::running_var>, NULL, NULL, NULL},
- {(char*)"training", (getter)getValueAttr<BatchNormForward, bool, BatchNormParams,
- &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL},
- {(char*)"momentum", (getter)getValueAttr<BatchNormForward, double, BatchNormParams,
- &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL},
- {(char*)"eps", (getter)getValueAttr<BatchNormForward, double, BatchNormParams,
- &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL},
- {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormForward, bool, BatchNormParams,
- &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL},
- {NULL}
-};
-
-static struct PyGetSetDef batch_norm_backward_properties[] = {
- THP_FUNCTION_DEFAULT_PROPERTIES,
- {(char*)"running_mean", (getter)getTensorAttr<BatchNormBackward, BatchNormParams,
- &BatchNormParams::running_mean>, NULL, NULL, NULL},
- {(char*)"running_var", (getter)getTensorAttr<BatchNormBackward, BatchNormParams,
- &BatchNormParams::running_var>, NULL, NULL, NULL},
- {(char*)"training", (getter)getValueAttr<BatchNormBackward, bool, BatchNormParams,
- &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL},
- {(char*)"momentum", (getter)getValueAttr<BatchNormBackward, double, BatchNormParams,
- &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL},
- {(char*)"eps", (getter)getValueAttr<BatchNormBackward, double, BatchNormParams,
- &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL},
- {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormBackward, bool, BatchNormParams,
- &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL},
- {NULL}
-};
-
-static struct PyGetSetDef batch_norm_backward_backward_properties[] = {
- THP_FUNCTION_DEFAULT_PROPERTIES,
- {(char*)"running_mean", (getter)getTensorAttr<BatchNormBackwardBackward, BatchNormParams,
- &BatchNormParams::running_mean>, NULL, NULL, NULL},
- {(char*)"running_var", (getter)getTensorAttr<BatchNormBackwardBackward, BatchNormParams,
- &BatchNormParams::running_var>, NULL, NULL, NULL},
- {(char*)"training", (getter)getValueAttr<BatchNormBackwardBackward, bool, BatchNormParams,
- &BatchNormParams::training, long, PyBool_FromLong>, NULL, NULL, NULL},
- {(char*)"momentum", (getter)getValueAttr<BatchNormBackwardBackward, double, BatchNormParams,
- &BatchNormParams::momentum, double, PyFloat_FromDouble>, NULL, NULL, NULL},
- {(char*)"eps", (getter)getValueAttr<BatchNormBackwardBackward, double, BatchNormParams,
- &BatchNormParams::eps, double, PyFloat_FromDouble>, NULL, NULL, NULL},
- {(char*)"cudnn_enabled", (getter)getValueAttr<BatchNormBackwardBackward, bool, BatchNormParams,
- &BatchNormParams::cudnn_enabled, long, PyBool_FromLong>, NULL, NULL, NULL},
- {NULL}
-};
-
-static struct PyGetSetDef conv_forward_properties[] = {
- THP_FUNCTION_DEFAULT_PROPERTIES,
- {(char*)"stride", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
- &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"padding", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
- &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"dilation", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
- &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"transposed", (getter)getValueAttr<ConvForward, bool, ConvParams,
- &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL},
- {(char*)"output_padding", (getter)getTupleAttr<ConvForward, std::vector<int>, ConvParams,
- &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"groups", (getter)getValueAttr<ConvForward, int, ConvParams,
- &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL},
- {NULL}
-};
-
-static struct PyGetSetDef conv_backward_properties[] = {
- THP_FUNCTION_DEFAULT_PROPERTIES,
- {(char*)"stride", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
- &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"padding", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
- &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"dilation", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
- &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"transposed", (getter)getValueAttr<ConvBackward, bool, ConvParams,
- &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL},
- {(char*)"output_padding", (getter)getTupleAttr<ConvBackward, std::vector<int>, ConvParams,
- &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"groups", (getter)getValueAttr<ConvBackward, int, ConvParams,
- &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL},
- {NULL}
-};
-
-static struct PyGetSetDef conv_backward_backward_properties[] = {
- THP_FUNCTION_DEFAULT_PROPERTIES,
- {(char*)"stride", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
- &ConvParams::stride, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"padding", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
- &ConvParams::padding, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"dilation", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
- &ConvParams::dilation, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"transposed", (getter)getValueAttr<ConvBackwardBackward, bool, ConvParams,
- &ConvParams::transposed, long, PyBool_FromLong>, NULL, NULL, NULL},
- {(char*)"output_padding", (getter)getTupleAttr<ConvBackwardBackward, std::vector<int>, ConvParams,
- &ConvParams::output_padding, long, PyInt_FromLong>, NULL, NULL, NULL},
- {(char*)"groups", (getter)getValueAttr<ConvBackwardBackward, int, ConvParams,
- &ConvParams::groups, long, PyInt_FromLong>, NULL, NULL, NULL},
- {NULL}
-};
-
static PyObject* accumulateGradVar(PyObject *_self, void* _unused)
{
THPCppFunction* self = (THPCppFunction*)_self;
@@ -233,79 +78,57 @@ static PyObject* accumulateGradVar(PyObject *_self, void* _unused)
static struct PyGetSetDef accumulate_grad_properties[] = {
THP_FUNCTION_DEFAULT_PROPERTIES,
- {(char*)"variable", accumulateGradVar, NULL, NULL, NULL},
- {NULL}
+ {(char*)"variable", accumulateGradVar, nullptr, nullptr, nullptr},
+ {nullptr}
};
-bool THPAutograd_initFunctions(PyObject* _unused)
+void THPAutograd_initFunctions()
{
THPObjectPtr module(PyModule_New("torch._C._functions"));
- if (!module) return false;
-
- static PyTypeObject BatchNormClass, BatchNormBackwardClass, BatchNormBackwardBackwardClass;
- addClass<BatchNormForward, BatchNormCtor>(module, BatchNormClass, "BatchNorm", batch_norm_forward_properties);
- addClass<BatchNormBackward, NoCtor>(module, BatchNormBackwardClass, "BatchNormBackward", batch_norm_backward_properties);
- addClass<BatchNormBackwardBackward, NoCtor>(module, BatchNormBackwardBackwardClass, "BatchNormBackwardBackward", batch_norm_backward_backward_properties);
-
- static PyTypeObject ConvClass, ConvBackwardClass, ConvBackwardBackwardClass;
- addClass<ConvForward, ConvCtor>(module, ConvClass, "ConvNd", conv_forward_properties);
- addClass<ConvBackward, NoCtor>(module, ConvBackwardClass, "ConvNdBackward", conv_backward_properties);
- addClass<ConvBackwardBackward, NoCtor>(module, ConvBackwardBackwardClass, "ConvNdBackwardBackward", conv_backward_backward_properties);
+ if (!module) throw python_error();
static PyTypeObject AccumulateGradClass;
addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties);
- static PyTypeObject AddClass, AddBackwardClass;
- addClass<Add, NoCtor>(module, AddClass, "Add");
- addClass<AddBackward_Deprecated, NoCtor>(module, AddBackwardClass, "AddBackward_Deprecated");
-
static PyTypeObject ErrorClass;
addClass<Error, NoCtor>(module, ErrorClass, "Error");
static PyTypeObject DelayedErrorClass;
addClass<DelayedError, DelayedErrorCtor>(module, DelayedErrorClass, "DelayedError");
- static PyTypeObject CloneClass;
- addClass<Clone, NoCtor>(module, CloneClass, "Clone");
- static PyTypeObject ContiguousClass;
- addClass<Contiguous, NoCtor>(module, ContiguousClass, "Contiguous");
- static PyTypeObject IdentityClass;
- addClass<Identity, NoCtor>(module, IdentityClass, "Identity");
- static PyTypeObject TransposeClass;
- addClass<Transpose, NoCtor>(module, TransposeClass, "Transpose");
- static PyTypeObject ViewClass;
- addClass<View, NoCtor>(module, ViewClass, "View");
- static PyTypeObject ExpandClass;
- addClass<Expand, NoCtor>(module, ExpandClass, "Expand");
- static PyTypeObject NarrowClass;
- addClass<Narrow, NoCtor>(module, NarrowClass, "Narrow");
- static PyTypeObject CatClass;
- addClass<Cat, NoCtor>(module, CatClass, "Cat");
-
static PyTypeObject EvalClass;
addClass<Eval, NoCtor>(module, EvalClass, "Eval");
- static PyTypeObject AutogradClosureClass;
- addClass<AutogradClosure, NoCtor>(module, AutogradClosureClass, "AutogradClosure");
+ static PyTypeObject InterpreterAutogradClass;
+ addClass<torch::jit::InterpreterAutogradFunction, NoCtor>(module, InterpreterAutogradClass, "InterpreterAutogradFunction");
+
+ static PyTypeObject CopyBackwardsClass;
+ addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards");
+
+ static PyTypeObject CopySlicesClass;
+ addClass<CopySlices, NoCtor>(module, CopySlicesClass, "CopySlices");
generated::initialize_autogenerated_functions();
- THPObjectPtr parent(PyImport_ImportModule("torch._C"));
- if (!parent) return false;
- PyModule_AddObject(parent.get(), "_functions", module.release());
- return true;
+ auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
+ if (!c_module) throw python_error();
+
+ Py_INCREF(module);
+ if (PyModule_AddObject(c_module, "_functions", module) < 0) {
+ throw python_error();
+ }
}
namespace torch { namespace autograd {
void initAutogradClosureBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
- py::class_<AutogradClosureFactory,std::shared_ptr<AutogradClosureFactory>>(m, "AutogradClosureFactory")
- .def("__call__", &AutogradClosureFactory::construct)
+ py::class_<jit::InterpreterFunctionFactory,std::shared_ptr<jit::InterpreterFunctionFactory>>(m, "InterpreterFunctionFactory")
+ .def("__call__", &jit::InterpreterFunctionFactory::construct_function)
;
- m.def("_jit_createAutogradClosure", [](jit::tracer::TracingState* tracing_state) {
- return std::make_shared<AutogradClosureFactory>(tracing_state);
+ m.def("_jit_createInterpreterFactory", [](jit::tracer::TracingState* tracing_state) {
+ return std::make_shared<jit::InterpreterFunctionFactory>(tracing_state);
});
}
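
Note on the hunk above: THPAutograd_initFunctions now reports failure by throwing (python_error) instead of returning false, and the module is Py_INCREF'd before PyModule_AddObject, presumably because AddObject steals a reference on success while the THPObjectPtr still releases its own. A hedged, self-contained sketch of that throw-on-failure shape against the raw CPython API, using std::runtime_error in place of torch's python_error and raw pointers instead of THPObjectPtr:

    #include <Python.h>
    #include <stdexcept>

    static void init_functions_module() {
      PyObject* module = PyModule_New("torch._C._functions");
      if (!module) throw std::runtime_error("failed to create module");

      // ... addClass<...>(module, ...) registrations would go here ...

      PyObject* parent = PyImport_ImportModule("torch._C");
      if (!parent) {
        Py_DECREF(module);
        throw std::runtime_error("failed to import torch._C");
      }

      // PyModule_AddObject steals the reference to `module` only on success,
      // so release our reference ourselves on the failure path.
      if (PyModule_AddObject(parent, "_functions", module) < 0) {
        Py_DECREF(module);
        Py_DECREF(parent);
        throw std::runtime_error("failed to attach _functions");
      }
      Py_DECREF(parent);
    }
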
diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp
deleted file mode 100644
index 33cff95a..00000000
--- a/torch/csrc/autograd/functions/jit_closure.cpp
+++ /dev/null
@@ -1,835 +0,0 @@
-#include "torch/csrc/autograd/functions/jit_closure.h"
-
-#include "torch/csrc/Exceptions.h"
-#include "torch/csrc/utils/auto_gil.h"
-#include "torch/csrc/utils/functional.h"
-#include "torch/csrc/autograd/engine.h"
-#include "torch/csrc/autograd/functions/special.h"
-#include "torch/csrc/autograd/functions/basic_ops.h"
-#include "torch/csrc/autograd/functions/tensor.h"
-#include "torch/csrc/autograd/functions/utils.h"
-#include "torch/csrc/autograd/python_engine.h"
-#include "torch/csrc/autograd/python_variable.h"
-#include "torch/csrc/autograd/python_function.h"
-#include "torch/csrc/jit/generated/aten_dispatch.h"
-#ifdef WITH_CUDA
-#include "torch/csrc/jit/fusion_compiler.h"
-#endif
-namespace torch { namespace autograd {
-
-using namespace torch::jit;
-using namespace torch::jit::tracer;
-
-// Used when an output has multiple uses (there's only one entry
-// in next_functions per output).
-struct Replicate : public Function {
- Replicate() {
- is_executable = true;
- num_inputs = 1;
- }
-
- virtual variable_list apply(const variable_list& inputs) {
- return variable_list(next_functions.size(), inputs[0]);
- }
-};
-
-// This class is never put in the autograd graph: see InputPlaceholder
-// and EvalPlaceholder.
-struct Placeholder : public Function {
- virtual variable_list apply(const variable_list& inputs) {
- return inputs;
- }
-};
-
-// Used for inputs of previous stages
-struct PrevStageInput : public Replicate {};
-// Used for inputs to the closure (execution roots)
-struct InputPlaceholder : public Placeholder {};
-// Used to mark places that will have to apply Evals from previous stages.
-//
-// Why do we need this? Let us recall the raison d'etre of Eval nodes: they
-// exist so that when we compute a backwards autograd closure on the fly
-// while executing forwards, we can use exactly that closure when backwards
-// executes. Morally, this closure is simply an input to the backwards
-// computation, and in the compiler IR representation, it's represented
-// precisely this way (with opaque Handle nodes.)
-//
-// However, the *autograd* execution model only accounts for Variable
-// input/outputs, which a Handle is not! "So why not add non-Variable inputs
-// to autograd"? Perhaps this could be made to work, but it is a bit awkward:
-// it would involve adding an entirely new type of input to the execution model.
-// Autograd is not intended to be a general purpose programming language
-// runtime, so on balance, we decided to consider solutions which don't involve
-// adding new types of inputs to autograd, instead passing the closures "out
-// of band".
-//
-// By what mechanism, then, can we actually pass the closure? Here is the idea.
-// Instead of actually inserting an "Eval" node, we instead insert an
-// EvalPlaceholder, which doesn't know anything about evaluating a closure.
-// Then, at the time when we want to partially apply the actual closure
-// (computed from the forwards pass), we stick a pre-callback on the EvalPlaceholder
-// that takes the inputs, does the actual Eval, and then passes on the outputs
-// (which the EvalPlaceholder subsequently passes through.)
-//
-// Remember that callbacks are NOT stored on a Function object itself: they are
-// registered on a per AutogradClosure (for which there may be multiple per
-// graph). So we can't do something like mutate a
-// Eval Function to give it the autograd closure to run inside its main body:
-// that violates the invariant that autograd graphs are immutable! (In other
-// words, the same EvalPlaceholder may be participating in multiple engine
-// executions.) You truly must somehow associate these closures with the graph as
-// a whole, and there must be a mechanism to ferry this data to the Function
-// itself. A callback is just the ticket.
-struct EvalPlaceholder : public Placeholder {};
-
-// Used for the graph output. Execution should be stopped by a callback before apply().
-struct Output : public Function {
- Output(int ninputs) {
- is_executable = true;
- num_inputs = ninputs;
- }
-
- virtual variable_list apply(const variable_list& inputs) {
- throw std::runtime_error("Output::apply called");
- }
-};
-
-struct SimpleEval : public Function {
- SimpleEval(const std::shared_ptr<Function>& fn)
- : fn(fn) {}
-
- virtual variable_list apply(const variable_list& inputs) override {
- return fn->apply(inputs);
- }
-
- std::shared_ptr<Function> fn;
-};
-
-struct EmitNull : public Function {
- EmitNull() {
- is_executable = true;
- num_inputs = 0;
- }
-
- virtual variable_list apply(const variable_list& inputs) {
- return {Variable()};
- };
-};
-
-struct LambdaFunction : public Function {
- LambdaFunction(const jit::TensorOp& op)
- : LambdaFunction(op.num_inputs, op.op) {
- this->name_ = op.name;
- }
-
- LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn)
- : fn_(fn) {
- this->is_executable = true;
- this->num_inputs = num_inputs;
- }
-
- virtual std::string name() override {
- return name_.size() == 0 ? "LambdaFunction" : name_;
- }
-
- virtual variable_list apply(const variable_list& inputs) override {
- return fn_(inputs);
- }
-
- std::string name_;
- std::function<variable_list(const variable_list&)> fn_;
-};
-
-// Wraps a PythonOp and dispatches calls to Functions implemented in Python
-struct PythonCall : public Function {
- PythonCall(PythonOp *op)
- : cconv(op->cconv)
- , scalar_args() {
-
- Py_INCREF(op->pyobj.get());
- pyobj = op->pyobj.get();
-
- scalar_args.reserve(op->scalar_args.size());
- for (auto& arg_ptr : op->scalar_args) {
- Py_INCREF(arg_ptr.get());
- scalar_args.emplace_back(arg_ptr.get());
- }
- }
-
- virtual variable_list apply(const variable_list& inputs) {
- AutoGIL gil;
-
- THPObjectPtr apply_fn {PyObject_GetAttrString(pyobj, "apply")};
- if (!apply_fn) throw python_error();
-
- THPObjectPtr py_inputs { packInputs(inputs) };
- THPObjectPtr result { PyObject_Call(apply_fn.get(), py_inputs.get(), NULL) };
- if (!result) throw python_error();
- return unpackOutputs(result);
- }
-
- THPObjectPtr packInputs(const variable_list& inputs) {
- THPObjectPtr py_inputs { PyTuple_New(cconv.size()) };
- if (!py_inputs) throw python_error();
-
- auto var_it = inputs.begin();
- auto scalar_it = scalar_args.begin();
- int input_nr = 0;
-
- for (auto arg_type : cconv) {
- PyObject *obj = nullptr;
- if (arg_type == 's') {
- if (scalar_it == scalar_args.end())
- throw std::runtime_error("expected too many scalar args");
- obj = (scalar_it++)->get();
- Py_INCREF(obj);
- } else if (arg_type == 't') {
- if (var_it == inputs.end())
- throw std::runtime_error("expected too many inputs");
- obj = THPVariable_Wrap(*(var_it++));
- } else {
- throw std::runtime_error("unexpected calling convention");
- }
- PyTuple_SET_ITEM(py_inputs.get(), input_nr++, obj);
- }
-
- return py_inputs;
- }
-
- variable_list unpackOutputs(THPObjectPtr& result) {
- variable_list var_result;
-
- ensure_tuple(result);
- auto num_outputs = PyTuple_GET_SIZE(result.get());
- for (int i = 0; i < num_outputs; ++i) {
- PyObject *output = PyTuple_GET_ITEM(result.get(), i);
- if (!THPVariable_Check(output))
- throw std::runtime_error("Function.apply returned a non-Variable output");
- THPVariable *var = (THPVariable*)output;
- var_result.emplace_back(var->cdata);
- }
-
- return var_result;
- }
-
- THPObjectPtr pyobj;
- std::string cconv;
- std::vector<THPObjectPtr> scalar_args;
-};
-
-// Note [Handling nullary functions in the autograd engine]
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// Today, the autograd engine cannot handle nullary functions, because
-// it assumes that every non-input function has at least one input.
-// This fits nicely with the scheduling strategy, which schedules a
-// function for execution when all of its inputs are ready. Unfortunately,
-// constants are nullary.
-//
-// Instead, we use a little hack. Rather than creating an extra root
-// for every constant, we add a single new root, ConstantFactory, which
-// when run triggers all of the actual constant functions, WrapConstant,
-// which actually contribute a constant. Furthermore, we use a single
-// null input to ensure that the next_function index has a valid offset.
-//
-// One possible alternative to represent this might be to special case constants
-// in the execution engine, as a separate vector of roots. But the current
-// strategy seems to work fine and isn't too difficult to construct a trace
-// for.
-
-struct WrapConstant : public Function {
- WrapConstant(at::Tensor value)
- : value(std::move(value)) {
- is_executable = true;
- num_inputs = 1;
- }
-
- virtual variable_list apply(const variable_list& inputs) {
- if (inputs.size() != 1 || inputs[0].defined())
- throw std::logic_error("WrapConstant nodes should only receive a single NULL input");
- AutoGPU guard(value);
- return {make_variable(value.clone())};
- }
-
- at::Tensor value;
-};
-
-// See Note [Handling nullary functions in the autograd engine]
-struct ConstantFactory : public Function {
- ConstantFactory() {
- is_executable = true;
- num_inputs = 1;
- }
-
- virtual variable_list apply(const variable_list& inputs) {
- if (inputs.size() != 1 || inputs[0].defined())
- throw std::logic_error("ConstantFactory nodes should only receive a single NULL input");
- return variable_list(next_functions.size());
- }
-};
-
-#ifdef WITH_CUDA
-struct FusionGroupFunction : public Function {
- FusionGroupFunction(const std::shared_ptr<CompiledFusionFunction> & function)
- : function(function) {}
- virtual variable_list apply(const variable_list& inputs) {
- //TODO: handle the case where inputs do not match the device function was
- // compiled for
- std::vector<at::Tensor> data;
- for(auto & input : inputs)
- data.push_back(input.data());
- AutoGPU guard(data.back());
- std::vector<at::Tensor> outputs;
- outputs.reserve(function->outputDescriptors().size());
- for(auto & od : function->outputDescriptors()) {
- outputs.push_back(at::CUDA(od.scalar_type).tensor());
- }
- function->launch(data, outputs);
- return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) {
- return std::make_shared<torch::autograd::Error>("FusionGroupFunction is not differentiable", std::move(f));
- });
- }
-private:
- std::shared_ptr<CompiledFusionFunction> function;
-};
-#endif
-
-////////////////////////////////////////////////////////////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-// A helper struct that precomputes and caches information regarding cross-stage
-// dependencies and state passing.
-//
-// Example:
-// graph (%1,
-// %2,
-// ------ stage 1 ------
-// %9,
-// ------ stage 2 ------
-// %31,
-// %32) {
-// %3.0, %3.1 = MulConstant(2)(%2)
-// %6.0, %6.1 = Mul()(%3.0, %1)
-// ---------------- stage 1 ----------------
-// %10.0, %10.1, %10.2 = Eval(%9, %6.1)
-// %23.0, %23.1 = Eval(%10.0, %3.1)
-// ---------------- stage 2 ----------------
-// %33.0, %33.1 = Eval(%32, %23.1)
-// %44.0, %44.1, %44.2, %44.3 = Eval(%33.0, %31, %10.2)
-// %78.0, %78.1 = Eval(%44.1, %3.1)
-// return (%6.0, %10.1, %23.0, %44.0, %44.2, %78.0);
-// }
-//
-// Then:
-//
-// graph->stage() = 2
-// stage_begins = [%3, %10, %33, %0] (0 = return node)
-// stage_inputs = [
-// [%1, %2],
-// [%9],
-// [%31, %32]
-// ]
-// stage_outputs = [
-// [%6.0],
-// [%10.1, %23],
-// [%44.0, %44.2, %78.0]
-// ]
-// prev_stage_inputs = [
-// [], # Always empty!
-// [%6.1, %3.1],
-// [%23.1, %10.2, %3.1]
-// ]
-// cur_stage_captures = [
-// [%6.1, %3.1],
-// [%23.1, %10.2],
-// [] # Always empty!
-// ]
-struct CrossStageStateDesc {
- CrossStageStateDesc(Graph* graph)
- // FYI: graph->stage() == the last stage we have traced
- // (e.g., forwards+backwards = 1)
- : stage_inputs(graph->stage() + 1)
- , stage_outputs(graph->stage() + 1)
- , prev_stage_inputs(graph->stage() + 1)
- , cur_stage_captures(graph->stage() + 1) {
-
- std::size_t current_stage = -1;
- for (auto node : graph->nodes()) {
- // Look for stage boundaries
- if (node->stage() != current_stage) {
- JIT_ASSERT(node->stage() == current_stage + 1);
- current_stage = node->stage();
- stage_begins.push_back(node);
- }
- // Look for things we need to save
- for (auto input : node->inputs()) {
- if (input->stage() != current_stage) {
- JIT_ASSERT(input->stage() < current_stage);
- // We need to save it in all intermediate stages too
- for (auto i = current_stage; i > input->stage(); --i) {
- prev_stage_inputs[i].insert(input);
- }
- cur_stage_captures[input->stage()].insert(input);
- }
- }
- }
-
- // It's convenient to pretend output is one more stage - we can always
- // take an iterator for stage i and i+1 as loop boundaries
- stage_begins.push_back(graph->return_node());
-
- // Scatter inputs and outputs across stage buckets
- for (auto input : graph->inputs())
- stage_inputs[input->stage()].push_back(input);
- for (auto output : graph->outputs())
- stage_outputs[output->stage()].push_back(output);
-
- JIT_ASSERT(prev_stage_inputs.front().empty());
- JIT_ASSERT(cur_stage_captures.back().empty());
- }
-
- // For each stage, the first Node in Graph's topological sort which
- // is a member of this stage. In general, the stages of nodes in
- // a graph will look like this:
- //
- // 000000011111112222222E (E is the Return node)
- // ^ ^ ^ ^
- //
- // We have pointers to the caret'ed nodes.
- std::vector<Node*> stage_begins;
- std::vector<std::vector<Node*>> stage_inputs;
- std::vector<std::vector<Node*>> stage_outputs;
-  // A set of all Nodes from the previous stage that pass anything (both Variables
-  // and handles) to the current stage.
- std::vector<std::unordered_set<Node*>> prev_stage_inputs;
- // A set of all Nodes from this stage, that need their values to be captured
- // for future stages (applies to both Variables and handles).
- std::vector<std::unordered_set<Node*>> cur_stage_captures;
-};
-
-// Creates a graph for a given stage and stores information necessary to construct
-// an AutogradClosure with it
-struct StageClosure {
- using node_fn_map_type = std::unordered_map<Node*, std::shared_ptr<Function>>;
-
- StageClosure(TracingState *state, const CrossStageStateDesc& xstate, std::size_t stage)
- : var_flags(state->var_flags.at(stage))
- , const_factory(std::make_shared<ConstantFactory>()) {
- auto graph = state->graph.get();
- node_fn_map_type node_map;
- // This map caches PrevStageInputs for a given node, so that you don't
- // create multiple PrevStageInput for the same node.
- node_fn_map_type prev_stage_input_map;
-
- // Prepare output node and compute an offset within return node inputs where
-    // nodes from this stage appear.
- output = std::make_shared<Output>(xstate.stage_outputs[stage].size());
- node_map[graph->return_node()] = output;
- std::size_t output_offset = 0;
- for (std::size_t i = 0; i < stage; ++i)
- output_offset += xstate.stage_outputs[i].size();
-
- // Builds up a closure for node. It assumes that it has been called
- // for all nodes that use outputs of node, which is why we iterate
- // in reverse topological order.
- auto add_node = [&](Node *node) {
- JIT_ASSERT(node->stage() == stage);
-
- // Get function object
- auto fn = getFunction(node);
- if (!fn) return; // This node is a no-op
-
- // Initialize function fields
- fn->is_executable = true;
- if (fn->num_inputs == 0) {
- fn->num_inputs = node->inputs().size();
- }
- fillNextFunctions(node, fn, node_map, output_offset, stage);
-
- registerPrevStageInputs(node, fn, prev_stage_input_map);
- node_map[node] = fn;
- };
-
- for (auto it = std::next(xstate.stage_begins[stage+1]->reverseIterator()),
- end = std::next(xstate.stage_begins[stage]->reverseIterator()); it != end; ++it) {
- add_node(*it);
- }
- for (auto node : xstate.stage_inputs[stage]) {
- add_node(node);
- }
-
- // Prepare inputs.
- for (Node *input : xstate.stage_inputs[stage]) {
- roots.emplace_back(node_map.at(input), 0);
- }
- for (auto & entry : prev_stage_input_map) {
- roots.emplace_back(entry.second, 0);
- prev_stage_variables.emplace_back(entry.first->unique());
- }
- // NOTE: prev_stage_handles have been already filled in by add_node
- JIT_ASSERT(prev_stage_variables.size() + prev_stage_handles.size() == xstate.prev_stage_inputs[stage].size());
-
- // Prepare a list of values / handles to capture
- for (auto captured_node : xstate.cur_stage_captures[stage]) {
- if (captured_node->kind() == kSelect) {
- auto & fn = node_map.at(captured_node->input());
- if (captured_node->type()->kind() == TypeKind::TensorType) {
- captured_variables.emplace_back(fn.get(), captured_node->i(kOffset), captured_node->unique());
- } else {
- JIT_ASSERT(captured_node->type()->kind() == TypeKind::HandleType);
- captured_handles.emplace(fn.get(), captured_node->unique());
- }
- } else {
- JIT_ASSERT(captured_node->type()->kind() == TypeKind::TensorType);
- auto & fn = node_map.at(captured_node);
- captured_variables.emplace_back(fn.get(), 0, captured_node->unique());
- }
- }
-
- roots.emplace_back(const_factory, 0);
-
- findCopiedNextFunctions(state, stage);
- }
-
- // Returns a function implementing functionality of a given node,
- // or nullptr if it's a no-op for autograd.
- std::shared_ptr<Function> getFunction(Node *node) {
- IR_IFM(node, PythonOp)
- return std::make_shared<PythonCall>(value);
- IR_ELSEIFM(CppOp)
- if (dynamic_cast<Eval*>(value->fn.get())) {
- auto fn = std::make_shared<EvalPlaceholder>();
-
- // All Eval nodes take context edges as an input, and we need to register
- // all such places
- auto inputs = value->inputs();
- JIT_ASSERT(inputs.size() > 0);
- auto handle_input = inputs[inputs.size() - 1];
- JIT_ASSERT(handle_input->type()->kind() == TypeKind::HandleType);
- prev_stage_handles.emplace_back(fn.get(), handle_input->unique());
-
- fn->num_inputs = node->inputs().size() - 1;
- return fn;
- } else {
- return std::make_shared<SimpleEval>(value->fn);
- }
- IR_ELSEIF(Select)
- // No-op. Selects are handled by their inputs.
- return nullptr;
- IR_ELSEIF(FusionGroup)
-#ifdef WITH_CUDA
- // TODO: make this more robust - handle device and contiguity changes!
- auto fusion_fn = sharedFusionCompiler().getOrCompile(*value->g(kSubgraph));
- return std::make_shared<FusionGroupFunction>(std::move(fusion_fn));
-#else
- throw std::runtime_error("don't know how to execute FusionGroups without CUDA");
-#endif
- IR_ELSEIF(Param)
- auto fn = std::make_shared<InputPlaceholder>();
- fn->num_inputs = 1;
- return fn;
- IR_ELSEIF(Constant)
- auto fn = std::make_shared<torch::autograd::WrapConstant>(value->t(kvalue));
- const_factory->next_functions.emplace_back(fn, 0);
- fn->num_inputs = 1;
- return fn;
- IR_ELSEIF(Undefined)
- return std::make_shared<EmitNull>();
- IR_ELSEIF(Transpose) // NOTE: Transpose in ONNX is Permute in Torch
- auto permutation = value->is(kperm);
- if (permutation != std::vector<int64_t>({1, 0}))
- throw std::runtime_error("Transpose isn't fully supported in closure compiler");
- return std::make_shared<LambdaFunction>(1, [](const variable_list& vars) -> variable_list {
- return {make_variable(vars[0].data().transpose(1, 0), vars[0].requires_grad())};
- });
- IR_ELSEIF(Reshape)
- auto shape = value->is(kshape);
- return std::make_shared<LambdaFunction>(1, [shape](const variable_list& vars) -> variable_list {
- return {make_variable(vars[0].data().view(shape), vars[0].requires_grad())};
- });
- IR_ELSEIF(Gemm)
- auto beta = value->f(kbeta);
- auto alpha = value->f(kalpha);
- return std::make_shared<LambdaFunction>(3, [beta, alpha](const variable_list& vars) -> variable_list {
- return {vars[2].addmm(vars[0], vars[1], beta, alpha)};
- });
- IR_ELSE()
- return std::make_shared<LambdaFunction>(getTensorOp(node));
- IR_END()
- }
-
- // Fill in the next_functions of the Function we just allocated
- void fillNextFunctions(Node *node, const std::shared_ptr<Function>& fn, node_fn_map_type& node_map, int output_offset, std::size_t stage) {
- auto graph = node->owningGraph();
- // Gather uses of each output
- std::vector<std::reference_wrapper<const use_list>> output_uses_refs;
- if (node->hasMultipleOutputs()) {
- // Each use is a single Select node corresponding to an output
- for (auto& use : node->uses()) {
- if (use.user->isHandle()) continue;
- auto& select_uses = use.user->uses();
- output_uses_refs.emplace_back(select_uses);
- }
- } else {
- output_uses_refs.emplace_back(node->uses());
- }
-
-    // Fill next_functions according to the uses of each output
- // There's some fiddling required for fixing the offset of uses for return node, so it's
- // better to keep this logic in a lambda.
- auto append_use = [&node_map, graph, output_offset](const std::shared_ptr<Function>& fn, Use& use) {
- int offset = use.offset;
- if (use.user == graph->return_node()) offset -= output_offset;
- fn->next_functions.emplace_back(node_map.at(use.user), offset);
- };
- for (auto& output_uses_ref : output_uses_refs) {
- // Filter out uses from future stages (except for output!)
- auto output_uses = filter(output_uses_ref.get(), [stage, graph](const Use& use) {
- return use.user->stage() == stage || use.user == graph->return_node();
- });
- // Optimize out unnecessary Replicate nodes
- if (output_uses.size() == 1) {
- append_use(fn, output_uses[0]);
- // If an output was used more than once, we need to insert a Replicate node
- // because there's only a single entry for an output in next_functions
- } else {
- auto replicate = std::make_shared<Replicate>();
- for (auto& use : output_uses) append_use(replicate, use);
- fn->next_functions.emplace_back(std::move(replicate), 0);
- }
- }
- }
-
- // Possibly create PrevStageInputs for any uses of nodes from previous
- // stages, and fill in their next_functions with our use.
- void registerPrevStageInputs(Node *node, const std::shared_ptr<Function>& fn,
- node_fn_map_type& prev_stage_input_map) {
- const auto& inputs = node->inputs();
- for (std::size_t i = 0; i < inputs.size(); ++i) {
- auto input_node = inputs[i];
- if (input_node->type()->kind() == TypeKind::HandleType) continue;
- JIT_ASSERT(input_node->type()->kind() == TypeKind::TensorType);
- if (input_node->stage() < node->stage()) {
- auto info = prev_stage_input_map.emplace(input_node, nullptr);
- auto & input_fn_ptr = info.first->second;
- // Create a node if insertion took place
- if (info.second) input_fn_ptr.reset(new PrevStageInput());
- input_fn_ptr->next_functions.emplace_back(fn, i);
- }
- }
- }
-
- // If this stage produces gradients of any of previous stage inputs,
- // it needs to include them in its next_functions. However, we do not
- // necessarily keep them as SavedVariables, so it's not straightforward
- // to use wrap_outputs for this purpose. Here, we find all next_functions
- // from the previous stage that will need to be copied as next_functions
- // of this stage (the copy is done explicitly in lambda constructor given to
- // wrap_outputs).
- // NOTE: we depend on the Eval input ordering here (i.e. inherited/prev stage
- // outputs come after this stage inputs and remain sorted).
- void findCopiedNextFunctions(TracingState *state, std::size_t stage) {
- if (stage == 0) return;
- auto & current_outputs = state->output_edges[stage];
- auto & prev_outputs = state->output_edges[stage - 1];
- for (auto & output : current_outputs) {
- auto prev_it = std::find(prev_outputs.begin(), prev_outputs.end(), output);
- if (prev_it == prev_outputs.end()) continue;
- copied_next_fns.push_back(std::distance(prev_outputs.begin(), prev_it));
- }
- }
-
-  // Roots for a call to the engine. The list contains functions in this order:
- // [ apply input roots | prev stage input roots | constant factory ]
- function_list roots;
- std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>> var_flags;
-
- // Output node
- std::shared_ptr<Function> output;
- std::shared_ptr<ConstantFactory> const_factory;
-
- std::vector<int> copied_next_fns;
-
- // These will be used by each instantiation of AutogradClosure to register hooks.
- std::vector<int> prev_stage_variables; // unique
- std::vector<std::pair<Function*, int>> prev_stage_handles; // (placeholder, unique)
- // After the function is run, take either a Variable or a backward handle, and
- // put it in the environment under 'unique'.
- std::vector<std::tuple<Function*, int, int>> captured_variables; // (function, output_nr, unique)
- std::unordered_map<Function*, int> captured_handles; // (function, unique)
-};
-
-// Computes and stores an array of StageClosures for each stage in the graph
-struct MultiStageClosure {
- MultiStageClosure(TracingState* state) {
- auto graph = state->graph.get();
- CrossStageStateDesc xstate {graph};
- auto num_stages = graph->stage() + 1;
- for (std::size_t i = 0; i < num_stages; ++i) {
- stages.emplace_back(state, xstate, i);
- }
- }
-
- std::vector<StageClosure> stages;
-};
-
-AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc)
- : AutogradClosure(desc, 0) {}
-
-// TODO: there's a lot of processing involved in creating a new AutogradClosure instance,
-// so it might be worth keeping a pool of unused instances (or at least their attrs)
-// for all stages. We can't save saved_vars and saved_handles, but all callbacks
-// can be made reusable.
-AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage)
- : desc(desc)
- , stage(stage) {
- auto & stage_desc = desc->stages[stage];
-
- // Callbacks to capture Variables for backward closure
- for (auto & entry : stage_desc.captured_variables) {
- auto & fn = std::get<0>(entry);
- auto output_offset = std::get<1>(entry);
- auto saved_idx = std::get<2>(entry);
- post_callbacks.emplace(fn, [this, saved_idx, output_offset](Function* fn, variable_list& inputs, variable_list& outputs) {
- std::lock_guard<std::mutex> lock(this->capture_mutex);
- this->captured_vars[saved_idx] = outputs[output_offset].data();
- return true;
- });
- }
-
- // Callbacks to capture handles (backward subgraphs) for backward closure
- for (auto & entry : stage_desc.captured_handles) {
- auto & fn = entry.first;
- auto saved_idx = entry.second;
- // Evals already wrap their backwards and they will be handled in the
- // next loop if needed
- if (dynamic_cast<EvalPlaceholder*>(fn)) continue;
- // Otherwise we have to wrap the backwards in a handle ourselves
- post_callbacks.emplace(fn, [this, saved_idx](Function* fn, variable_list& inputs, variable_list& outputs) {
- auto eval_fn = std::make_shared<Eval>();
- eval_fn->replaceSubgraph(inputs, outputs);
- std::lock_guard<std::mutex> lock(this->capture_mutex);
- this->captured_handles[saved_idx] = std::move(eval_fn);
- return true;
- });
- }
-
- // Callbacks that run closures received from forward and optionally capture
- // them for next stages
- for (auto & entry : stage_desc.prev_stage_handles) {
- int unique = entry.second;
- // Check if we need to capture the handle for next stage too
- auto it = stage_desc.captured_handles.find(entry.first);
- int saved_idx = it == stage_desc.captured_handles.end() ? -1 : it->second;
- post_callbacks.emplace(entry.first, [this, unique, saved_idx](Function* fn, variable_list& inputs, variable_list& outputs) {
- outputs = (*this->saved_handles.at(unique))(inputs);
- if (saved_idx != -1) {
- auto eval_fn = Eval::getBackwardEval(inputs, outputs);
- std::lock_guard<std::mutex> lock(this->capture_mutex);
- this->captured_handles[saved_idx] = std::move(eval_fn);
- }
- return true;
- });
- }
-
- // A callback to capture the output
- pre_callbacks.emplace(stage_desc.output.get(), [this](Function*, variable_list& inputs) {
- std::lock_guard<std::mutex> lock(this->capture_mutex);
- this->outputs.reserve(inputs.size());
- for (auto & input : inputs)
- this->outputs.emplace_back(input.opt_data());
- return false; // Stop execution
- });
-}
-
-variable_list AutogradClosure::apply(const variable_list& inputs) {
- auto& stage_closure = desc->stages[stage];
-
- // Validate inputs
- auto num_inputs = inputs.size();
- if (num_inputs != stage_closure.var_flags.first.size())
- throw std::runtime_error("AutogradClosure received an incorrect number of inputs");
- for (std::size_t i = 0; i < num_inputs; ++i) {
- auto & flags = stage_closure.var_flags.first[i];
- if (!flags.verify(inputs[i]))
- throw std::runtime_error("AutogradClosure received inputs with different flags");
- }
-
- // TODO: we could run all this with volatile variables, but we need to somehow capture handles
- // we should enable requires_grad only for the parts that need it
- auto input_leaves = fmap(inputs, [](const Variable& v) {
- return v.defined() ? make_variable(v.data(), v.requires_grad(), v.is_volatile()) : Variable();
- });
- for (auto unique : desc->stages[stage].prev_stage_variables)
- input_leaves.emplace_back(make_variable(saved_vars.at(unique), true, false));
- input_leaves.emplace_back(Variable()); // for ConstantFactory
-
- auto& engine = python::PythonEngine::getDefaultEngine();
- engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks);
-
- // Create the backward function lazily
- auto make_grad_fn = [this]() -> std::shared_ptr<Function> {
- if (this->stage == this->desc->stages.size() - 1) {
- std::string msg = "JIT closure compiled only for ";
- msg += std::to_string(this->stage);
- msg += " backwards";
- return std::make_shared<Error>(std::move(msg));
- }
- auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1));
- // TODO: don't make a full copy of saved_* - copy only the things that bw needs
- bw_fn->saved_vars = this->saved_vars;
- bw_fn->saved_vars.insert(std::make_move_iterator(this->captured_vars.begin()),
- std::make_move_iterator(this->captured_vars.end()));
- bw_fn->saved_handles = this->saved_handles;
- bw_fn->saved_handles.insert(std::make_move_iterator(this->captured_handles.begin()),
- std::make_move_iterator(this->captured_handles.end()));
-    // Patch next_functions to include previous stage next_functions
- for (auto copied_idx : this->desc->stages[this->stage + 1].copied_next_fns) {
- bw_fn->next_functions.push_back(this->next_functions[copied_idx]);
- }
-    // This is needed because of copied functions (even if none of this stage's inputs
-    // required grad, a copied function can), and is always valid because we assert
- // that flags are the same as when we compiled the closure (and the tracing Eval
- // was run, so it must have been executable).
- bw_fn->is_executable = true;
- return bw_fn;
- };
-
- // See Note [Null-edge pruning]
- variable_list result;
- auto num_outputs = outputs.size();
- std::shared_ptr<Function> grad_fn;
- JIT_ASSERT(outputs.size() == stage_closure.var_flags.second.size());
- for (std::size_t i = 0; i < num_outputs; ++i) {
- auto & flags = stage_closure.var_flags.second[i];
- if (flags.requires_grad) {
- if (!grad_fn) grad_fn = make_grad_fn();
- result.push_back(make_variable(outputs[i], grad_fn));
- } else {
- result.push_back(make_variable(outputs[i], flags.requires_grad, flags.is_volatile));
- }
- }
-
- // If we created grad_fn for any of the outputs, we also need to fill in next_functions
- if (grad_fn) {
- for (auto & input : inputs) {
- if (!input.requires_grad()) continue;
- grad_fn->next_functions.emplace_back(
- input.grad_fn() ? input.grad_fn() : input.grad_accumulator(),
- input.output_nr());
- }
- }
-
- captured_vars.clear();
- captured_handles.clear();
- outputs.clear();
- return result;
-}
-
-AutogradClosureFactory::AutogradClosureFactory(TracingState *state)
- : desc(std::make_shared<MultiStageClosure>(state)) {}
-
-std::shared_ptr<Function> AutogradClosureFactory::construct() {
- return std::make_shared<AutogradClosure>(desc);
-}
-
-}}
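The long comment deleted above (around EvalPlaceholder) describes passing backward closures to placeholder nodes "out of band", through per-closure callback maps keyed by the raw Function pointer, so the shared autograd graph never has to be mutated. Below is a condensed, self-contained sketch of that idea using simplified stand-in types; the real engine's callback map and signatures differ in detail.

#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>

struct Function { virtual ~Function() = default; };
using variable_list = std::vector<int>;  // stand-in for a list of Variables
using pre_callback_map =
    std::unordered_multimap<Function*, std::function<bool(Function*, variable_list&)>>;

int main() {
  Function placeholder;      // stands in for an EvalPlaceholder node in the graph
  pre_callback_map callbacks;

  // Registered per closure instance, so the graph stays immutable and the same
  // placeholder can participate in several executions with different closures.
  callbacks.emplace(&placeholder, [](Function*, variable_list& inputs) {
    inputs.push_back(42);    // "run the saved closure" on the incoming values
    return true;             // keep executing this node afterwards
  });

  variable_list inputs{1, 2, 3};
  auto range = callbacks.equal_range(&placeholder);
  for (auto it = range.first; it != range.second; ++it) {
    it->second(&placeholder, inputs);
  }
  std::cout << "inputs now has " << inputs.size() << " entries\n";  // 4
  return 0;
}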
diff --git a/torch/csrc/autograd/functions/jit_closure.h b/torch/csrc/autograd/functions/jit_closure.h
deleted file mode 100644
index 6d905e63..00000000
--- a/torch/csrc/autograd/functions/jit_closure.h
+++ /dev/null
@@ -1,50 +0,0 @@
-#pragma once
-
-#include <Python.h>
-#include <memory>
-#include <unordered_map>
-
-#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/autograd/engine.h"
-#include "torch/csrc/autograd/function.h"
-#include "torch/csrc/autograd/variable.h"
-
-namespace torch { namespace autograd {
-
-struct MultiStageClosure;
-
-struct AutogradClosureFactory {
- AutogradClosureFactory(torch::jit::tracer::TracingState *graph);
-
- std::shared_ptr<Function> construct();
-
- std::shared_ptr<MultiStageClosure> desc;
-};
-
-struct AutogradClosure : public Function {
- AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc);
-
- virtual variable_list apply(const variable_list& inputs) override;
-
-private:
- AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage);
-
- variable_list rewrapInputs(const variable_list& inputs);
-
- std::shared_ptr<MultiStageClosure> desc;
- std::size_t stage;
-
- std::unordered_map<int, at::Tensor> saved_vars;
- std::unordered_map<int, std::shared_ptr<Function>> saved_handles;
-
- Engine::pre_callback_map pre_callbacks;
- Engine::post_callback_map post_callbacks;
-
- std::unordered_map<int, at::Tensor> captured_vars;
- std::unordered_map<int, std::shared_ptr<Function>> captured_handles;
- tensor_list outputs;
- std::mutex capture_mutex;
-};
-
-}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/functions/onnx/basic_ops.cpp b/torch/csrc/autograd/functions/onnx/basic_ops.cpp
deleted file mode 100644
index 5a89a1fa..00000000
--- a/torch/csrc/autograd/functions/onnx/basic_ops.cpp
+++ /dev/null
@@ -1,12 +0,0 @@
-#include "torch/csrc/autograd/functions/basic_ops.h"
-
-namespace torch { namespace autograd {
-
-jit::node_list Add::symbolic(SymbolicContext* ctx, jit::node_list inputs) {
- auto & g = ctx->graph;
- auto node = g->create(jit::kAdd, inputs);
- g->appendNode(node);
- return {node};
-}
-
-}}
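The deleted file above, like the other ONNX export files removed in this patch, follows one pattern: a symbolic() hook creates a node of the right kind from the traced inputs, appends it to the export graph, and returns it. Here is a standalone sketch of that shape with simplified stand-in Graph and Node types; the real jit::Graph API is richer.

#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string kind;
  std::vector<std::shared_ptr<Node>> inputs;
};

struct Graph {
  std::vector<std::shared_ptr<Node>> nodes;  // appended nodes, in topological order
  std::shared_ptr<Node> create(std::string kind, std::vector<std::shared_ptr<Node>> inputs) {
    return std::make_shared<Node>(Node{std::move(kind), std::move(inputs)});
  }
  std::shared_ptr<Node> appendNode(std::shared_ptr<Node> n) {
    nodes.push_back(n);  // the graph keeps the node alive and records its position
    return n;
  }
};

// Shape of an Add::symbolic-style hook: build the node, append it, return it.
std::vector<std::shared_ptr<Node>> add_symbolic(
    Graph& g, std::vector<std::shared_ptr<Node>> inputs) {
  auto node = g.create("Add", std::move(inputs));
  g.appendNode(node);
  return {node};
}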
diff --git a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp b/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
deleted file mode 100644
index 08f675c4..00000000
--- a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-#include "torch/csrc/autograd/functions/batch_normalization.h"
-#include <sstream>
-
-#include "torch/csrc/jit/ir.h"
-
-namespace torch {
-namespace autograd {
-
-jit::node_list BatchNormForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) {
- auto & g = ctx->graph;
- // X, Scale, Bias
- auto bn = g->appendNode(g->create(jit::kBatchNormalization, {inputs.at(0),inputs.at(1),inputs.at(2)}));
- bn->addInput(jit::tracer::getBufferTrace(*ctx->buffer_map, running_mean));
- bn->addInput(jit::tracer::getBufferTrace(*ctx->buffer_map, running_var));
- bn->i_(jit::kis_test, !this->training);
- bn->f_(jit::kepsilon, eps);
- //bn->s_(jit::korder, "NCHW");
- bn->f_(jit::kmomentum, 1 - momentum);
-
- auto orig_output = g->appendNode(g->createSelect(bn, 0));
-
- if(this->training) {
- g->appendNode(g->createSelect(bn, 1)->setType(bn->input(3)->type()));
- g->appendNode(g->createSelect(bn, 2)->setType(bn->input(4)->type()));
- // dummy output
- for(int i = 3; i < 5; i++) {
- g->appendNode(g->createSelect(bn, i)->setDebugName("batch_norm_dead_output"));
- }
- }
- bn->is_(jit::kconsumed_inputs,{0,0,0,1,1});
-
- ctx->batch_norm_count++;
- return {orig_output};
-}
-
-
-} // torch::autograd
-} // torch
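One detail worth noting in the deleted BatchNorm export above is bn->f_(jit::kmomentum, 1 - momentum): PyTorch's momentum weights the new batch statistic, while the Caffe2/ONNX attribute weights the previous running value, so the flip is presumably the conversion between the two conventions. A tiny self-contained check of that equivalence, with made-up numbers:

#include <iostream>

int main() {
  const double running = 10.0, batch = 20.0;
  const double m_pytorch = 0.1;            // weight on the batch statistic
  const double m_export = 1 - m_pytorch;   // weight on the previous running value

  const double pytorch_update = (1 - m_pytorch) * running + m_pytorch * batch;
  const double export_update = m_export * running + (1 - m_export) * batch;
  std::cout << pytorch_update << " == " << export_update << "\n";  // 11 == 11
  return 0;
}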
diff --git a/torch/csrc/autograd/functions/onnx/convolution.cpp b/torch/csrc/autograd/functions/onnx/convolution.cpp
deleted file mode 100644
index a8270fc8..00000000
--- a/torch/csrc/autograd/functions/onnx/convolution.cpp
+++ /dev/null
@@ -1,78 +0,0 @@
-#include "torch/csrc/autograd/functions/convolution.h"
-
-namespace torch { namespace autograd {
-
-// Note [Caffe2ConvTranspose]
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~
-// ConvTranspose in Caffe2 is a bit silly: bias is mandatory. But ONNX
-// has removed bias input from official ConvTranspose. How can the Caffe2
-// backend do the translation? It can't! It's impossible! So as a temporary
-// hack while we wait for Caffe2 to make bias optional, we are using a
-// Caffe2ConvTranspose experimental ONNX op which has a mandatory bias.
-// PyTorch has no trouble making the zero-filled tensor.
-//
-// For code simplicity, even if PyTorch was given a bias tensor, it is NOT
-// passed here; it's done as an external addition. This is less efficient
-// but this code should be temporary anyway.
-
-jit::node_list ConvForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) {
- auto & g = ctx->graph;
- // See Note [Caffe2ConvTranspose]
- auto n = g->create(!transposed ? jit::kConv : jit::kConvTranspose,
- {inputs.at(0), inputs.at(1)});
-
- // Irritatingly, Caffe2 requires us to specify kernels,
- // but we don't actually have that information directly
- // recorded in ConvForward. So we have to reverse
- // engineer it from the input types...
- // TODO: dynamic_cast ew
- auto weight_type = inputs.at(1)->type()->cast<jit::TensorType>();
- JIT_ASSERT(weight_type);
- auto weight_size = weight_type->sizes();
-
- // See Note [Caffe2ConvTranspose]
- if(transposed) {
- n->addInput(g->appendNode(g->createConstant(at::CPU(at::kFloat).zeros({weight_size[1]}))));
- }
-
- g->appendNode(n);
-
- std::vector<int64_t> kernel_size(weight_size.begin() + 2, weight_size.end());
- n->is_(jit::kkernel_shape, std::move(kernel_size));
- std::vector<int64_t> kernel_stride(stride.begin(),stride.end());
- n->is_(jit::kstrides, std::move(kernel_stride));
-
- std::vector<int64_t> kernel_pads(padding.begin(),padding.end());
-  // NB: Caffe2 lets you specify top and bottom pads separately;
- // PyTorch assumes it's symmetric
- for (int p : padding) {
- kernel_pads.push_back(p);
- }
- n->is_(jit::kpads,std::move(kernel_pads));
-
- std::vector<int64_t> kernel_dilations(dilation.begin(),dilation.end());
- n->is_(jit::kdilations,std::move(kernel_dilations));
- n->i_(jit::kgroup,groups);
-
- // Not in ONNX?
- // TODO: implement it once ConvTranspose in ONNX gets `adj` argument instead
- // of providing `output_shape`
- for (int p : output_padding) {
- JIT_EXPECTM(p == 0, "output padding is not supported.");
- }
-
- // ignore benchmark/cudnn_enabled
-
- if (inputs.at(2)->kind() != jit::kUndefined) {
- // TODO: Set type here based on RETURN type (not available atm)
- auto a_n = g->create(jit::kAdd, {g->appendNode(g->createSelect(n, 0)), inputs.at(2)});
- a_n->i_(jit::kbroadcast, 1);
- a_n->i_(jit::kaxis, 1);
- g->appendNode(a_n);
- return {a_n};
- } else {
- return {n};
- }
-}
-
-}}
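The deleted convolution export above converts PyTorch's single symmetric pad per spatial dimension into the exporter's begin/end pads attribute by appending the same values a second time (the loop over padding that fills kernel_pads). A minimal standalone illustration of that expansion:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> padding{1, 2};  // symmetric (height, width) padding
  std::vector<int64_t> kernel_pads(padding.begin(), padding.end());
  for (int64_t p : padding) {
    kernel_pads.push_back(p);          // duplicate as the "end" pads
  }
  for (int64_t p : kernel_pads) {
    std::cout << p << ' ';             // prints: 1 2 1 2
  }
  std::cout << '\n';
  return 0;
}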
diff --git a/torch/csrc/autograd/functions/special.cpp b/torch/csrc/autograd/functions/special.cpp
index c0e75f69..f750c39f 100644
--- a/torch/csrc/autograd/functions/special.cpp
+++ b/torch/csrc/autograd/functions/special.cpp
@@ -1,15 +1,37 @@
#include "torch/csrc/autograd/functions/special.h"
+#include "torch/csrc/assertions.h"
#include "torch/csrc/autograd/python_engine.h"
+#include "torch/csrc/autograd/edge.h"
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/edge.h"
+
+#include <cstdint>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <utility> // for swap
namespace torch { namespace autograd {
+// Used when an output has multiple uses (there's only one entry
+// in next_edges per output).
+struct Replicate : public Function {
+ Replicate() : Function(/*num_inputs=*/1) {}
+
+ virtual variable_list apply(const variable_list& inputs) {
+ TORCH_ASSERT(inputs.size() == 1);
+ return variable_list(num_outputs(), inputs[0]);
+ }
+};
+
// Note [Null-edge pruning]
// Evals have a problem with null edges appearing in the graph, because there's
// no way to tell the identity of the input (i.e. each nullptr might have been
// a different input, all of them might have been a single input, etc.).
// However, null edges are generally quite useless, so we can safely prune them,
-// by removing them from next_functions of Eval node and never allocating
+// by removing them from next_edges of Eval node and never allocating
// placeholders for them. This is a bit annoying because backward subgraphs may
// have many fewer outputs than the forward graph had inputs, but I don't think there's
// a way around it. It's a tiny perf optimization too :)
@@ -38,11 +60,11 @@ auto Eval::getSubgraph(const variable_list& inputs, const variable_list& outputs
input_edges.reserve(inputs.size());
for (auto & input : inputs) {
if (!input.defined()) continue;
- input_edges.emplace(input.grad_fn() ? input.grad_fn() : input.grad_accumulator(), input.output_nr());
+ input_edges.emplace(input.gradient_edge());
}
// This is used to stop the search in situation 2 and find the corresponding placeholders.
- std::unordered_map<edge_type, std::shared_ptr<EvalOutput>, edge_hasher> inherited_edges;
+ std::unordered_map<Edge, std::shared_ptr<EvalOutput>> inherited_edges;
inherited_edges.reserve(inherited_placeholders.size());
for (auto & placeholder : inherited_placeholders) {
input_edges.emplace(placeholder->next_edge);
@@ -62,15 +84,14 @@ auto Eval::getSubgraph(const variable_list& inputs, const variable_list& outputs
while (!queue.empty()) {
auto fn = queue.back(); queue.pop_back();
JIT_ASSERT(fn);
- fn->tracing_state->in_eval_subgraph = true;
- int num_edges = fn->next_functions.size();
- for (int i = 0; i < num_edges; ++i) {
- auto & edge = fn->next_functions[i];
- auto & next_fn = edge.first;
- if (!next_fn) continue; // See Note [Null-edge pruning]
+ fn->tracing_state().in_eval_subgraph = true;
+ const auto num_outputs = fn->num_outputs();
+ for (size_t i = 0; i < num_outputs; ++i) {
+ const auto& edge = fn->next_edge(i);
+ if (!edge.function) continue; // See Note [Null-edge pruning]
// Edge belongs to subgraph boundary. Register that and don't search along it.
if (input_edges.count(edge) > 0) {
- subgraph.boundary.begins.emplace(fn->getSharedPtr(), i);
+ subgraph.boundary.begins.emplace(fn->get_shared_ptr(), i);
subgraph.boundary.ends.emplace(edge);
auto it = inherited_edges.find(edge);
// Situation 2. If that edge is actually pointing to an earlier stage subgraph,
@@ -81,14 +102,14 @@ auto Eval::getSubgraph(const variable_list& inputs, const variable_list& outputs
continue;
}
// Situation 1. If we end up in a placeholder, we need to inherit it.
- if (auto placeholder = std::dynamic_pointer_cast<EvalOutput>(next_fn)) {
+ if (auto placeholder = std::dynamic_pointer_cast<EvalOutput>(edge.function)) {
extra_placeholders.emplace(placeholder);
subgraph.boundary.ends.emplace(placeholder->next_edge);
continue;
}
- bool unseen = seen.emplace(next_fn.get()).second;
+ bool unseen = seen.emplace(edge.function.get()).second;
if (unseen)
- queue.emplace_back(next_fn.get());
+ queue.emplace_back(edge.function.get());
}
}
@@ -106,47 +127,46 @@ bool Eval::trySimpleEval(const variable_list& inputs, const variable_list& outpu
if (inherited_placeholders.size() != 0) return false;
auto& grad_fn = outputs[0].grad_fn();
- if (static_cast<std::size_t>(grad_fn->num_inputs) >= max_outputs) return false;
- if (static_cast<std::size_t>(grad_fn->num_inputs) != outputs.size()) return false;
+ if (static_cast<std::size_t>(grad_fn->num_inputs()) >= max_outputs) return false;
+ if (static_cast<std::size_t>(grad_fn->num_inputs()) != outputs.size()) return false;
// Check that all outputs have the same grad_fn and cover all its inputs
bitset_type output_nrs = 0;
- bitset_type expected_bitset = ((1 << grad_fn->num_inputs) - 1);
+ bitset_type expected_bitset = ((1 << grad_fn->num_inputs()) - 1);
for (auto & output : outputs) {
if (output.grad_fn() != grad_fn) return false;
output_nrs |= (1 << output.output_nr());
}
if (output_nrs != expected_bitset) return false;
- // Check that grad_fn->next_functions matches the inputs exactly
+ // Check that grad_fn's next_edges match the inputs exactly.
auto num_inputs = inputs.size();
- auto& grad_next_fns = grad_fn->next_functions;
- if (num_inputs != grad_next_fns.size()) return false;
+ if (num_inputs != grad_fn->num_outputs()) return false;
for (std::size_t i = 0; i < num_inputs; ++i) {
- // Unfortunately, null edge pruning (see Note [Null-edge pruning]) applies to
- // autograd functions which would otherwise be eligible for the SimpleEval
- // optimization. This makes everything more complicated, so for now we just don't
- // attempt the optimization in this case. To fix it properly,
- // we'd need to filter grad_next_fns and outputs of apply in Eval::apply.
- // The check below tests if null edge pruning occurred.
- if (!inputs[i].defined() || !grad_next_fns[i].first) return false;
- const auto& input_grad = inputs[i].grad_fn() ? inputs[i].grad_fn() : inputs[i].grad_accumulator();
- if (grad_next_fns[i].first != input_grad || grad_next_fns[i].second != inputs[i].output_nr()) return false;
+ const auto& next_grad_edge = grad_fn->next_edge(i);
+ // Unfortunately, null edge pruning (see Note [Null-edge pruning]) applies
+ // to autograd functions which would otherwise be eligible for the
+ // SimpleEval optimization. This makes everything more complicated, so for
+ // now we just don't attempt the optimization in this case. To fix it
+ // properly, we'd need to filter grad_fn's output edges and outputs of
+ // apply in Eval::apply. The check below tests if null edge pruning
+ // occurred.
+ if (!inputs[i].defined() || !next_grad_edge.is_valid()) return false;
+ if (next_grad_edge != inputs[i].gradient_edge()) return false;
}
// Success! We still need to set up placeholders for next stages and to drop
// references to the graph.
- std::swap(next_functions, grad_next_fns);
- grad_next_fns.reserve(num_inputs);
+ std::swap(next_edges_, grad_fn->next_edges());
+ grad_fn->next_edges().reserve(num_inputs);
placeholders.reserve(num_inputs);
- for (std::size_t i = 0; i < num_inputs; ++i) {
- auto placeholder = std::make_shared<EvalOutput>(next_functions[i]);
- grad_next_fns.emplace_back(placeholder, 0);
+ for (const auto& input : next_edges_) {
+ auto placeholder = std::make_shared<EvalOutput>(input);
+ grad_fn->add_next_edge({placeholder, 0});
placeholders.emplace_back(std::move(placeholder));
}
- is_executable = grad_fn->is_executable;
simple_graph = grad_fn;
- grad_fn->tracing_state->in_eval_subgraph = true;
+ grad_fn->tracing_state().in_eval_subgraph = true;
return true;
}
@@ -160,12 +180,12 @@ variable_list Eval::filterRelevantOutputs(const variable_list& inputs, const var
ignored_grad_fns.reserve(inputs.size());
for (auto& input : inputs) {
if (!input.defined()) continue;
- ignored_grad_fns.emplace(input.grad_fn(), input.output_nr());
+ ignored_grad_fns.insert(input.gradient_edge());
}
for (auto& output : outputs) {
if (!output.defined()) continue;
if (!output.grad_fn()) continue;
- if (ignored_grad_fns.count(std::make_pair(output.grad_fn(), output.output_nr())) > 0) continue;
+ if (ignored_grad_fns.count(output.gradient_edge()) > 0) continue;
relevant_outputs.emplace_back(output);
}
return relevant_outputs;
@@ -176,10 +196,7 @@ auto Eval::computeInputOrder(const variable_list& inputs, const placeholder_list
int idx = 0;
for (auto & input : inputs) {
if (!input.defined()) continue;
- input_order.emplace(
- std::make_pair(input.grad_fn() ? input.grad_fn() : input.grad_accumulator(), input.output_nr()),
- idx++
- );
+ input_order.emplace(input.gradient_edge(), idx++);
}
for (auto & placeholder : inherited_placeholders)
input_order.emplace(placeholder->next_edge, idx++);
@@ -200,12 +217,12 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
if (!trySimpleEval(inputs, relevant_outputs, inherited_placeholders)) {
roots.reserve(relevant_outputs.size());
for (auto & output : relevant_outputs)
- roots.emplace_back(output.grad_fn(), output.output_nr());
+ roots.push_back(output.gradient_edge());
auto subgraph = getSubgraph(inputs, relevant_outputs, inherited_placeholders);
// Prepare output placeholder nodes for each end.
- std::unordered_map<edge_type, std::shared_ptr<EvalOutput>, edge_hasher> ends_to_outputs;
+ std::unordered_map<Edge, std::shared_ptr<EvalOutput>> ends_to_outputs;
for (auto & placeholder : placeholders) {
ends_to_outputs[placeholder->next_edge] = placeholder;
}
@@ -218,20 +235,18 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
// Replace begins with pointers to output nodes.
// This detaches the subgraph from the full backward graph.
- for (auto & begin : subgraph.boundary.begins) {
- auto & fn = begin.first;
- auto offset = begin.second;
- fn->next_functions[offset] = std::make_pair(ends_to_outputs.at(fn->next_functions[offset]), 0);
+ for (auto& begin : subgraph.boundary.begins) {
+ const auto& edge = begin.function->next_edge(begin.input_nr);
+ begin.function->set_next_edge(
+ begin.input_nr, Edge(ends_to_outputs.at(edge), 0));
}
// Replace subgraph with this node.
- next_functions.insert(next_functions.begin(), subgraph.boundary.ends.begin(), subgraph.boundary.ends.end());
- is_executable = std::any_of(relevant_outputs.begin(), relevant_outputs.end(),
- [](const Variable& var) { return var.requires_grad(); });
+ next_edges_.insert(next_edges_.begin(), subgraph.boundary.ends.begin(), subgraph.boundary.ends.end());
// Ensure placeholders and inputs are sorted in the same way.
edge_order input_order = computeInputOrder(inputs, inherited_placeholders);
- std::sort(next_functions.begin(), next_functions.end(), [&input_order](const edge_type &a, const edge_type &b) {
+ std::sort(next_edges_.begin(), next_edges_.end(), [&input_order](const Edge &a, const Edge &b) {
return input_order.at(a) < input_order.at(b);
});
std::sort(placeholders.begin(), placeholders.end(), [&input_order](const std::shared_ptr<EvalOutput> &a, const std::shared_ptr<EvalOutput> &b) {
@@ -240,9 +255,30 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
}
// Rebase outputs.
+ auto this_shared = shared_from_this();
+ std::unordered_set<Variable*> repeated_outputs;
+ // NB: every output can be in 3 states:
+ // - unique so far - only the else of second if is taken
+ // - repeated first time - first if + first branch of second if
+ // - repeated many times - first branch of second if only
for (auto & output : relevant_outputs) {
- output.grad_fn() = shared_from_this();
- output.get()->output_nr = num_inputs++;
+    // This output is already rebased. This happens when the same
+    // Variable has been returned multiple times and is therefore
+    // repeated in this list.
+ if (output.grad_fn_unsafe() == this) {
+ auto replicate = std::make_shared<Replicate>();
+ replicate->add_next_edge({this_shared, output.output_nr()});
+ output.set_gradient_edge({std::move(replicate), 0});
+ repeated_outputs.emplace(&output);
+ }
+ // NOTE: this check should be fairly cheap, and the set shouldn't
+ // perform any allocations until we actually see repeated outputs.
+ if (repeated_outputs.count(&output) > 0) {
+ auto & replicate = output.grad_fn();
+ replicate->add_next_edge({this_shared, num_inputs_++});
+ } else {
+ autograd::create_gradient_edge(output, this_shared);
+ }
}
return true;
@@ -253,11 +289,12 @@ variable_list Eval::apply(const variable_list& inputs) {
if (simple_graph) {
outputs = (*simple_graph)(inputs);
} else {
- std::mutex outputs_mutex;
- outputs.resize(placeholders.size());
auto& engine = python::PythonEngine::getDefaultEngine();
auto exec_data = filterRoots(inputs);
- engine.execute(exec_data.first, exec_data.second, true, getCallbacks(outputs, outputs_mutex));
+ auto next_edges = fmap(
+ placeholders,
+ [](const std::shared_ptr<EvalOutput>& o) { return Edge(o, 0); });
+ outputs = engine.execute(exec_data.first, exec_data.second, true, true, next_edges);
}
auto bw_eval = newEval();
@@ -267,7 +304,7 @@ variable_list Eval::apply(const variable_list& inputs) {
// This node already does it (backward of non-traceable backward is implicitly non-traceable),
// and it passes more information (backward Eval may inherit placeholders) than
// Function::traced_apply has available.
- tracing_state->in_eval_subgraph = true;
+ tracing_state_->in_eval_subgraph = true;
return outputs;
}
@@ -275,9 +312,9 @@ variable_list Eval::apply(const variable_list& inputs) {
// TODO: once we clean up the stochastic function mess it should be possible to ignore
// nullptr inputs in the Engine (it implies that the Variable is 0, so the jacobian vector
// product will be all zero too).
-std::pair<function_list, variable_list> Eval::filterRoots(const variable_list& inputs) {
+std::pair<edge_list, variable_list> Eval::filterRoots(const variable_list& inputs) {
variable_list filtered_inputs;
- function_list filtered_roots;
+ edge_list filtered_roots;
auto num_inputs = inputs.size();
if (roots.size() != num_inputs)
throw std::logic_error("inputs.size() != roots.size()");
@@ -305,20 +342,4 @@ std::pair<function_list, variable_list> Eval::filterRoots(const variable_list& i
return std::make_pair(std::move(filtered_roots), std::move(filtered_inputs));
}
-Engine::pre_callback_map Eval::getCallbacks(variable_list& outputs, std::mutex& outputs_mutex) {
- Engine::pre_callback_map callbacks;
- int num_outputs = placeholders.size();
- for (int i = 0; i < num_outputs; ++i) {
- auto& output_fn = placeholders[i];
- callbacks.emplace(output_fn.get(), [&outputs, &outputs_mutex, i](Function* _unused, variable_list& inputs) -> bool {
- if (inputs.size() != 1)
- throw std::logic_error("placeholder callback received too many inputs");
- std::lock_guard<std::mutex> lock(outputs_mutex);
- outputs[i] = inputs[0];
- return false; // Stop at output nodes
- });
- }
- return callbacks;
-}
-
}} // namespace torch::autograd
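Eval::trySimpleEval, patched above, verifies that the candidate outputs fill every input slot of the shared grad_fn by OR-ing 1 << output_nr into a bitset and comparing it against (1 << num_inputs) - 1 (the output count itself is checked separately). A self-contained sketch of that coverage test with made-up values:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::uint64_t num_grad_fn_inputs = 3;
  const std::uint64_t expected = (1ull << num_grad_fn_inputs) - 1;  // 0b111

  std::vector<std::uint64_t> output_nrs{0, 2, 1};  // slot filled by each output
  std::uint64_t seen = 0;
  for (auto nr : output_nrs) {
    seen |= (1ull << nr);
  }
  std::cout << (seen == expected ? "outputs cover all inputs" : "mismatch") << "\n";
  return 0;
}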
diff --git a/torch/csrc/autograd/functions/special.h b/torch/csrc/autograd/functions/special.h
index 4af9f251..7676083b 100644
--- a/torch/csrc/autograd/functions/special.h
+++ b/torch/csrc/autograd/functions/special.h
@@ -1,37 +1,34 @@
#pragma once
#include <Python.h>
-#include <memory>
-#include <string>
-#include <mutex>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/engine.h"
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
namespace torch { namespace autograd {
struct EvalOutput : Function {
- EvalOutput(const edge_type& next_edge)
- : next_edge(next_edge) {
- num_inputs = 1;
- // It would be nice if we could inherit this from the function of next_edge,
- // but we want to always run this node to capture the output. This might
- // confuse some of the functions causing them to do unnecessary work.
- // TODO: it should be possible to improve this once we get rid of NULL Variables
- is_executable = true;
- }
+ explicit EvalOutput(const Edge& next_edge_)
+ : Function(/*num_inputs=*/1), next_edge(next_edge_) {}
virtual variable_list apply(const variable_list& inputs) override {
throw std::logic_error("EvalOutput::apply() called");
}
- edge_type next_edge;
+ Edge next_edge;
};
struct Eval : Function {
- using edge_set = std::unordered_set<edge_type, edge_hasher>;
- using edge_order = std::unordered_map<edge_type, int, edge_hasher>;
+ using edge_set = std::unordered_set<Edge>;
+ using edge_order = std::unordered_map<Edge, int>;
using placeholder_list = std::vector<std::shared_ptr<EvalOutput>>;
// This struct has only one member, but it's useful to e.g. add a set of all
@@ -40,12 +37,12 @@ struct Eval : Function {
struct Boundary {
// All nodes from within the subgraph that connect to the outside.
// These are the places that will need to be patched to point to placeholders.
- // Contains pairs of (fn, offset into next_functions).
+ // Contains pairs of (fn, offset into next_edges).
edge_set begins;
// All nodes that are not in the subgraph, but are in the union of
- // next_functions of all nodes from the subgraph. These are the places that
+ // next_edges of all nodes from the subgraph. These are the places that
// will be modeled by placeholders.
- // Contains pairs of (fn, input_nr) and is equivalent to next_functions
+ // Contains pairs of (fn, input_nr) and is equivalent to next_edges
// of an Eval node that will replace the subgraph.
edge_set ends;
};
@@ -78,20 +75,19 @@ struct Eval : Function {
return std::make_shared<Eval>();
}
- // Roots are empty if simple_graph is not NULL.
+ // Roots are empty if simple_graph is not nullptr.
// simple_graph is an optimization of first backward stage - in this case
// all Eval subgraphs contain only a single gradient function, and the
// graph search on creation + call to the engine in apply can be elided
- function_list roots;
+ edge_list roots;
std::shared_ptr<Function> simple_graph;
placeholder_list placeholders;
- jit::Node* forward_ctx_select = nullptr;
+ jit::Value* forward_ctx_select = nullptr;
bool traceable = false;
private:
- std::pair<function_list, variable_list> filterRoots(const variable_list& inputs);
- Engine::pre_callback_map getCallbacks(variable_list& outputs, std::mutex& outputs_mutex);
+ std::pair<edge_list, variable_list> filterRoots(const variable_list& inputs);
Subgraph getSubgraph(
const variable_list& inputs,
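
The edge_set and edge_order aliases above drop their explicit hasher argument because the key type now hashes through a std::hash specialization. Below is a minimal, generic sketch of that standard-library idiom; `Key` and its two fields are hypothetical stand-ins for illustration, not the real autograd types.

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <memory>
    #include <unordered_set>

    struct Key {
      std::shared_ptr<int> owner;  // stand-in for a node pointer
      uint32_t index;              // stand-in for an input number

      bool operator==(const Key& other) const noexcept {
        return owner == other.owner && index == other.index;
      }
    };

    namespace std {
    template <>
    struct hash<Key> {
      std::size_t operator()(const Key& key) const noexcept {
        // Combine the two field hashes; any reasonable mixer works here.
        const std::size_t h1 = std::hash<std::shared_ptr<int>>()(key.owner);
        const std::size_t h2 = std::hash<uint32_t>()(key.index);
        return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
      }
    };
    } // namespace std

    int main() {
      // No explicit hasher argument needed once std::hash<Key> exists.
      std::unordered_set<Key> keys;
      keys.insert(Key{nullptr, 0});
      keys.insert(Key{nullptr, 1});
      return keys.size() == 2 ? 0 : 1;
    }

The payoff is that std::unordered_set<Key> and std::unordered_map<Key, V> work with their default template arguments, which is exactly what the new edge_set and edge_order aliases rely on.
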
diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp
index 00086cd5..dfd8baf3 100644
--- a/torch/csrc/autograd/functions/tensor.cpp
+++ b/torch/csrc/autograd/functions/tensor.cpp
@@ -1,117 +1,92 @@
-#include "tensor.h"
+#include "torch/csrc/autograd/functions/tensor.h"
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/functions/basic_ops.h"
#include "torch/csrc/autograd/functions/utils.h"
+#include "torch/csrc/autograd/generated/Functions.h"
+#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/auto_gpu.h"
-namespace torch { namespace autograd {
-
-auto Identity::apply(const variable_list& inputs) -> variable_list {
- return inputs;
-};
-
-auto Clone::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("Clone", inputs, 1);
- auto& input = inputs[0].data();
- AutoGPU guard(input);
-
- at::Tensor output = input.clone();
-
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<Identity>(std::move(f));
- });
-};
-
-auto Contiguous::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("Contiguous", inputs, 1);
- auto& input = inputs[0].data();
- AutoGPU guard(input);
+#include <cstdint>
+#include <memory>
+#include <utility>
- at::Tensor output = input.contiguous();
+namespace torch { namespace autograd {
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<Identity>(std::move(f));
- });
+auto CopyBackwards::apply(const variable_list& grads) -> variable_list {
+ check_input_variables("CopyBackwards", grads, 1);
+ auto& grad = grads[0];
+ variable_list grad_inputs(2);
+ if (should_compute_output(0)) {
+ grad_inputs[0] = at::zeros_like(grad);
+ }
+ if (should_compute_output(1)) {
+ AutoGPU autoGPU(src_device);
+ if (grad.is_cuda() && grad.get_device() != src_device) {
+ grad_inputs[1] = src_type->copy(grad);
+ } else {
+ grad_inputs[1] = grad.toType(*src_type);
+ }
+ }
+ return grad_inputs;
};
-auto Transpose::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("Transpose", inputs, 1);
-
- auto& input = inputs[0].data();
- AutoGPU guard(input);
-
- at::Tensor output = input.transpose(dim1, dim2);
-
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<Transpose>(dim1, dim2);
- });
-}
-
-auto View::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("View", inputs, 1);
-
- auto& input = inputs[0].data();
- AutoGPU guard(input);
-
- at::Tensor output = input.view(size);
-
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<View>(input.sizes());
- });
-}
-
-auto Expand::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("Expand", inputs, 1);
-
- auto& input = inputs[0].data();
- AutoGPU guard(input);
-
- at::Tensor output = input.expand(size);
-
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<Error>("Expand is not differentiable", std::move(f));
- });
+CopySlices::CopySlices(
+ const Variable& base_var,
+ at::TensorGeometry view_,
+ std::shared_ptr<Function> fn_)
+ : Function(/*num_inputs=*/1),
+ base(base_var),
+ view(std::move(view_)),
+ fn(std::move(fn_)) {
+ // Take the next_edges of fn as our own, except for index 0 which goes
+ // to base instead of the view.
+ const auto num_outputs = fn->num_outputs();
+ next_edges_.reserve(num_outputs);
+ add_next_edge(base_var.gradient_edge());
+ for (size_t i = 1; i < num_outputs; i++) {
+ add_next_edge(fn->next_edge(i));
+ }
}
-auto Narrow::apply(const variable_list& inputs) -> variable_list {
- check_input_variables("Narrow", inputs, 1);
-
- auto& input = inputs[0].data();
- AutoGPU guard(input);
-
- at::Tensor output = input.narrow(dim, start, size);
+auto CopySlices::apply(const variable_list& inputs) -> variable_list {
+ check_input_variables("CopySlices", inputs, 1);
+ auto& grad = inputs[0];
- return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
- return std::make_shared<Error>("Narrow is not differentiable", std::move(f));
- });
-}
-
-auto Cat::apply(const variable_list& inputs) -> variable_list {
- int num_inputs = inputs.size();
- if (num_inputs == 0) {
- throw std::runtime_error("Cat operation expect at least one argument.");
+ if (!fn) {
+ throw std::runtime_error(ERR_BACKWARD_TWICE);
}
- auto& input = inputs[0].data();
- AutoGPU guard(input);
-
- std::vector<at::Tensor> tensors(num_inputs);
- for (int i = 0; i < num_inputs; ++i) {
- tensors[i] = inputs[i].data();
+ auto result = grad.type().tensor(base.sizes(), base.strides());
+ result.copy_(grad);
+
+ auto offset = view.storage_offset() - base.storage_offset();
+ auto grad_slice = result.as_strided(view.sizes(), view.strides(), offset);
+
+ // TODO: We clone grad_slice because we modify it below and "fn" might save
+ // it for the backward of res. We might be able to avoid the clone() if
+ // double-backprop is disabled.
+ auto res = (*fn)({ grad_slice.clone() });
+
+ variable_list grad_inputs(num_outputs());
+ for (size_t i = 0; i < res.size(); i++) {
+ if (should_compute_output(i)) {
+ TORCH_ASSERT(res[i].defined());
+ if (i == 0) {
+ grad_slice.copy_(res[i]);
+ grad_inputs[i] = std::move(result);
+ } else {
+ grad_inputs[i] = std::move(res[i]);
+ }
+ }
}
- auto output = input.type().cat(tensors, dim);
- return wrap_outputs(inputs, as_tensor_list(output), [&](FunctionFlags f) {
- return std::make_shared<Error>("Cat is not differentiable", std::move(f));
- });
+ return grad_inputs;
}
-auto Chunk::apply(const variable_list& inputs) -> variable_list {
- auto outputs = chunk(inputs[0].data(), chunks, dim);
- return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) {
- return std::make_shared<Error>("Chunk is not differentiable", std::move(f));
- });
+void CopySlices::release_variables() {
+ fn = nullptr;
}
}} // namespace torch::autograd
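
CopySlices above recovers the gradient of the viewed region through plain sizes/strides/storage-offset arithmetic: the relative offset of the view inside the base selects which elements of the flat buffer belong to the slice. The following is a self-contained standard-C++ sketch of that addressing scheme; the 4x4 base and 2x2 corner view are made-up shapes, and a raw std::vector<float> stands in for tensor storage.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Writes `value` into the elements of `storage` addressed by a 2-D
    // (sizes, strides, offset) view -- the same arithmetic as_strided uses.
    void fill_strided_view(std::vector<float>& storage, std::size_t offset,
                           std::size_t size0, std::size_t size1,
                           std::size_t stride0, std::size_t stride1,
                           float value) {
      for (std::size_t i = 0; i < size0; ++i) {
        for (std::size_t j = 0; j < size1; ++j) {
          storage[offset + i * stride0 + j * stride1] = value;
        }
      }
    }

    int main() {
      // A 4x4 "base" stored contiguously (strides {4, 1}).
      std::vector<float> base(16, 0.0f);
      // A 2x2 view of the bottom-right corner: relative offset 10, same strides.
      fill_strided_view(base, /*offset=*/10, 2, 2, /*stride0=*/4, /*stride1=*/1, 1.0f);
      for (std::size_t i = 0; i < 16; ++i) {
        std::cout << base[i] << ((i % 4 == 3) ? '\n' : ' ');
      }
      return 0;
    }
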
diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h
index 50208182..0c1463f9 100644
--- a/torch/csrc/autograd/functions/tensor.h
+++ b/torch/csrc/autograd/functions/tensor.h
@@ -1,91 +1,37 @@
#pragma once
#include <Python.h>
-#include <memory>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
-namespace torch { namespace autograd {
-
-struct Identity : public TraceableFunction {
- using TraceableFunction::TraceableFunction;
-
- virtual variable_list apply(const variable_list& inputs) override;
-};
-
-struct Clone : public ForwardFunction<> {
- Clone() {}
-
- virtual variable_list apply(const variable_list& inputs) override;
-};
-
-struct Contiguous : public ForwardFunction<> {
- Contiguous() {}
-
- virtual variable_list apply(const variable_list& inputs) override;
-};
-
-struct Transpose : public ForwardFunction<> {
- Transpose(int64_t dim1, int64_t dim2)
- : dim1(dim1)
- , dim2(dim2) {}
-
- virtual variable_list apply(const variable_list& inputs) override;
-
- int64_t dim1;
- int64_t dim2;
-};
-
-struct View : public ForwardFunction<> {
- View(std::vector<int64_t> size)
- : size(size) {}
-
- virtual variable_list apply(const variable_list& inputs) override;
-
- std::vector<int64_t> size;
-};
-
-struct Expand : public ForwardFunction<> {
- Expand(std::vector<int64_t> size)
- : size(size) {}
-
- virtual variable_list apply(const variable_list& inputs) override;
+#include "ATen/Type.h"
+#include <ATen/TensorGeometry.h>
- std::vector<int64_t> size;
-};
-
-struct Narrow : public ForwardFunction<> {
- Narrow(int64_t dim, int64_t start, int64_t size)
- : dim(dim)
- , start(start)
- , size(size) {}
-
- virtual variable_list apply(const variable_list& inputs) override;
-
- int64_t dim;
- int64_t start;
- int64_t size;
-};
+#include <cstdint>
+#include <memory>
-struct Cat : public ForwardFunction<> {
- Cat(int64_t dim)
- : dim(dim) {}
+namespace torch { namespace autograd {
+struct CopyBackwards : public Function {
virtual variable_list apply(const variable_list& inputs) override;
- int64_t dim;
+ at::Type *src_type;
+ int64_t src_device;
};
-struct Chunk : public Function {
- Chunk(int64_t chunks, int64_t dim)
- : chunks(chunks), dim(dim) {}
+// Performs grad[idx] = fn(grad[idx]), but out-of-place. The slicing operation
+// grad[idx] is defined by the relative sizes, strides, and offset of base and
+// view.
+struct CopySlices : public Function {
+ CopySlices(const Variable& base, at::TensorGeometry view, std::shared_ptr<Function> fn);
- virtual variable_list apply(const variable_list& inputs) override;
+ virtual variable_list apply(const variable_list& grads) override;
+ virtual void release_variables() override;
-private:
- int64_t chunks;
- int64_t dim;
+ at::TensorGeometry base;
+ at::TensorGeometry view;
+ std::shared_ptr<Function> fn;
};
}}
diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp
index 6b5c81ac..09939a10 100644
--- a/torch/csrc/autograd/functions/utils.cpp
+++ b/torch/csrc/autograd/functions/utils.cpp
@@ -1,33 +1,35 @@
#include "torch/csrc/autograd/functions/utils.h"
-#include "torch/csrc/utils/functional.h"
-#include "torch/csrc/jit/tracer.h"
+#include "torch/csrc/autograd/edge.h"
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include <sstream>
+#include <vector>
namespace torch { namespace autograd {
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
function_constructor ctr) {
- auto flags = Function::flags(inputs);
variable_list result;
result.reserve(outputs.size());
- if (flags.is_volatile) {
+ if (!any_variable_requires_grad(inputs)) {
for (auto& output : outputs) {
if (output.defined()) {
- result.emplace_back(make_variable(output, false, true));
+ result.push_back(make_variable(output, /*requires_grad=*/false));
} else {
result.emplace_back();
}
}
} else {
- auto grad_fn = ctr(std::move(flags));
+ auto grad_fn = ctr(collect_next_edges(inputs));
for (auto& output : outputs) {
if (output.defined()) {
- result.emplace_back(make_variable(output, grad_fn));
+ auto variable = autograd::make_variable(output, /*requires_grad=*/false);
+ autograd::create_gradient_edge(variable, grad_fn);
+ result.push_back(std::move(variable));
} else {
- ++grad_fn->num_inputs;
+ grad_fn->bump_inputs();
result.emplace_back();
}
}
@@ -53,5 +55,4 @@ void check_input_variables(const char* name, const variable_list& inputs, int ar
}
}
}
-
-}}
+}} // namespace torch::autograd
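
wrap_outputs above now feeds the collected next edges straight into the `ctr` callback rather than a FunctionFlags object. A generic sketch of that constructor-callback shape is shown below; `Node`, `Edge`, and `make_grad_fn` are stand-in names for illustration, not the real autograd classes.

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <utility>
    #include <vector>

    struct Edge {};  // stand-in for a (function, input_nr) edge

    struct Node {
      explicit Node(std::vector<Edge>&& edges) : next_edges(std::move(edges)) {}
      std::vector<Edge> next_edges;
    };

    // The callback decides *which* node type to build; the caller only
    // supplies the edges it collected from the inputs.
    using node_constructor = std::function<std::shared_ptr<Node>(std::vector<Edge>&&)>;

    std::shared_ptr<Node> make_grad_fn(std::vector<Edge>&& edges, node_constructor ctr) {
      return ctr(std::move(edges));
    }

    int main() {
      std::vector<Edge> collected(3);
      auto fn = make_grad_fn(std::move(collected), [](std::vector<Edge>&& edges) {
        return std::make_shared<Node>(std::move(edges));
      });
      std::cout << fn->next_edges.size() << "\n";  // prints 3
      return 0;
    }
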
diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h
index 622652ca..18fc8628 100644
--- a/torch/csrc/autograd/functions/utils.h
+++ b/torch/csrc/autograd/functions/utils.h
@@ -1,44 +1,28 @@
#pragma once
#include <Python.h>
-#include <functional>
-#include <memory>
-#include <array>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
-namespace torch { namespace autograd {
-
-using function_constructor = std::function<std::shared_ptr<Function>(FunctionFlags)>;
-
-template<typename ...Args>
-inline variable_list as_variable_list(Args&& ... args) {
- std::array<variable_list::value_type, sizeof...(args)> arr = { {std::move(args)...} };
- return variable_list(std::make_move_iterator(arr.begin()),
- std::make_move_iterator(arr.end()));
-}
+#include <functional>
+#include <memory>
+#include <vector>
-template<typename ...Args>
-inline tensor_list as_tensor_list(Args&& ... args) {
- std::array<tensor_list::value_type, sizeof...(args)> arr = { {std::move(args)...} };
- return tensor_list(std::make_move_iterator(arr.begin()),
- std::make_move_iterator(arr.end()));
-}
+namespace torch { namespace autograd {
+using function_constructor = std::function<std::shared_ptr<Function>(edge_list&&)>;
/**
- * Wraps the tensor outputs in variables, and if necessary (i.e., none of the
- * inputs are volatile), uses the function ctr and inputs to create a grad_fn
- * for each of them.
+ * Wraps the tensor outputs in variables. If any of the inputs require grad,
+ * uses `ctr` to create a grad_fn and sets it on each of the new variables.
*/
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
function_constructor ctr);
/**
* Checks that inputs contains exactly `args` items and that the first `required_args`
- * items are not NULL. If not specified, `required_args` defaults to `args`.
+ * items are not nullptr. If not specified, `required_args` defaults to `args`.
*/
void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1);
-
}}
diff --git a/torch/csrc/autograd/grad_mode.cpp b/torch/csrc/autograd/grad_mode.cpp
new file mode 100644
index 00000000..6409c697
--- /dev/null
+++ b/torch/csrc/autograd/grad_mode.cpp
@@ -0,0 +1,7 @@
+#include "grad_mode.h"
+
+namespace torch { namespace autograd {
+
+thread_local bool GradMode::_enabled = true;
+
+}}
diff --git a/torch/csrc/autograd/grad_mode.h b/torch/csrc/autograd/grad_mode.h
new file mode 100644
index 00000000..98b04d7a
--- /dev/null
+++ b/torch/csrc/autograd/grad_mode.h
@@ -0,0 +1,26 @@
+#pragma once
+
+namespace torch { namespace autograd {
+
+struct GradMode {
+ static bool is_enabled() {
+ return _enabled;
+ }
+ static void set_enabled(bool enabled) {
+ _enabled = enabled;
+ }
+private:
+ static thread_local bool _enabled;
+};
+
+struct AutoGradMode {
+ AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) {
+ GradMode::set_enabled(enabled);
+ }
+ ~AutoGradMode() {
+ GradMode::set_enabled(prev_mode);
+ }
+ bool prev_mode;
+};
+
+}}
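
A short usage sketch of the guard defined above, assuming the grad_mode.h shown here is on the include path and grad_mode.cpp is linked in; the helper name run_without_grad is made up.

    #include "torch/csrc/autograd/grad_mode.h"

    #include <cassert>

    using torch::autograd::AutoGradMode;
    using torch::autograd::GradMode;

    static void run_without_grad() {  // hypothetical helper for illustration
      // The destructor restores the previous thread-local value, so early
      // returns and exceptions cannot leave the mode disabled by accident.
      AutoGradMode no_grad(/*enabled=*/false);
      assert(!GradMode::is_enabled());
    }

    int main() {
      assert(GradMode::is_enabled());   // enabled by default
      run_without_grad();
      assert(GradMode::is_enabled());   // restored once the guard is gone
      return 0;
    }
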
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index d0fedc4b..81b5958b 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -1,46 +1,15 @@
#include <Python.h>
#include "torch/csrc/utils/pybind.h"
+#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/autograd/profiler.h"
#include "THP.h"
-namespace pybind11 { namespace detail {
-
-template <> struct type_caster<torch::autograd::profiler::EventKind> {
-public:
- PYBIND11_TYPE_CASTER(torch::autograd::profiler::EventKind, _("torch::autograd::profiler::EventKind"));
-
- bool load(handle src, bool) {
- try {
- auto str = py::cast<std::string>(src);
- if (str == "push") {
- value = torch::autograd::profiler::EventKind::PushRange;
- } else if (str == "pop") {
- value = torch::autograd::profiler::EventKind::PopRange;
- } else if (str == "mark") {
- value = torch::autograd::profiler::EventKind::Mark;
- } else {
- return false;
- }
- } catch (std::exception& e) {
- return false;
- }
- return true;
- }
- static handle cast(torch::autograd::profiler::EventKind src, return_value_policy /* policy */, handle /* parent */) {
- switch (src) {
- case torch::autograd::profiler::EventKind::PushRange:
- return py::cast("push").release();
- case torch::autograd::profiler::EventKind::PopRange:
- return py::cast("pop").release();
- case torch::autograd::profiler::EventKind::Mark:
- return py::cast("mark").release();
- }
- __builtin_unreachable();
- }
-};
-
-}} // namespace pybind11::detail
+#ifdef _MSC_VER
+#define ENSURE_UNREACHABLE __assume(0);
+#else
+#define ENSURE_UNREACHABLE __builtin_unreachable();
+#endif
PyObject * THPAutograd_initExtension(PyObject *_unused)
{
@@ -50,20 +19,75 @@ PyObject * THPAutograd_initExtension(PyObject *_unused)
THPVariableClass = PyMapping_GetItemString(autograd_dict,(char*)"Variable");
THPFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"Function");
- THPUtils_assert_PyImport("torch.nn._functions.thnn", thnn_functions);
- THPBatchNormBackwardBackwardFunction = PyObject_GetAttrString(thnn_functions,(char*)"batchnorm_double_backwards_fn");
-
- THPStochasticFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"StochasticFunction");
THPUtils_assert(THPVariableClass, "couldn't find Variable class in "
"torch.autograd module");
THPUtils_assert(THPFunctionClass, "couldn't find Function class in "
"torch.autograd module");
- THPUtils_assert(THPStochasticFunctionClass, "couldn't find "
- "StochasticFunction class in torch.autograd module");
auto m = py::handle(autograd_module).cast<py::module>();
+
+ py::class_<torch::autograd::profiler::Event>(m,"ProfilerEvent")
+ .def("kind",&torch::autograd::profiler::Event::kind)
+ .def("name",&torch::autograd::profiler::Event::name)
+ .def("thread_id",&torch::autograd::profiler::Event::thread_id)
+ .def("device",&torch::autograd::profiler::Event::device)
+ .def("cpu_elapsed_us",&torch::autograd::profiler::Event::cpu_elapsed_us)
+ .def("cuda_elapsed_us",&torch::autograd::profiler::Event::cuda_elapsed_us)
+ .def("has_cuda",&torch::autograd::profiler::Event::has_cuda);
+ py::enum_<torch::autograd::profiler::ProfilerState>(m,"ProfilerState")
+ .value("Disabled", torch::autograd::profiler::ProfilerState::Disabled)
+ .value("CPU", torch::autograd::profiler::ProfilerState::CPU)
+ .value("CUDA", torch::autograd::profiler::ProfilerState::CUDA)
+ .value("NVTX", torch::autograd::profiler::ProfilerState::NVTX);
+
m.def("_enable_profiler", torch::autograd::profiler::enableProfiler);
m.def("_disable_profiler", torch::autograd::profiler::disableProfiler);
+ m.def("_push_range", [](const char *name) {
+ using namespace torch::autograd::profiler;
+ if (state == ProfilerState::Disabled) return;
+ pushRange(name);
+ });
+ m.def("_pop_range", []() {
+ using namespace torch::autograd::profiler;
+ if (state == ProfilerState::Disabled) return;
+ popRange();
+ });
+
Py_RETURN_TRUE;
}
+
+namespace torch { namespace autograd {
+
+static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
+ HANDLE_TH_ERRORS
+ if (!PyBool_Check(arg)) {
+ at::runtime_error("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+ }
+ GradMode::set_enabled(arg == Py_True);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * is_grad_enabled(PyObject* _unused, PyObject *arg) {
+ HANDLE_TH_ERRORS
+ if (GradMode::is_enabled()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+ END_HANDLE_TH_ERRORS
+}
+
+// autograd methods on torch._C
+static PyMethodDef methods[] = {
+ {"set_grad_enabled", (PyCFunction)set_grad_enabled, METH_O, nullptr},
+ {"is_grad_enabled", (PyCFunction)is_grad_enabled, METH_NOARGS, nullptr},
+ {nullptr, nullptr, 0, nullptr}
+};
+
+PyMethodDef* python_functions() {
+ return methods;
+}
+
+}} // namespace torch::autograd
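
The two bindings above are plain CPython PyMethodDef entries: METH_O receives exactly one PyObject* argument and METH_NOARGS receives none. Below is a self-contained sketch of the same pattern as a stand-alone extension module; the module name flag_demo and the static (non-thread-local) flag are made up, and a TypeError is raised directly instead of going through the TH error macros.

    #include <Python.h>

    static int flag = 1;  // stand-in for a thread-local mode flag

    static PyObject* set_flag(PyObject* /*self*/, PyObject* arg) {
      if (!PyBool_Check(arg)) {
        PyErr_Format(PyExc_TypeError, "expected a bool (got %s)",
                     Py_TYPE(arg)->tp_name);
        return nullptr;
      }
      flag = (arg == Py_True);
      Py_RETURN_NONE;
    }

    static PyObject* is_flag_set(PyObject* /*self*/, PyObject* /*unused*/) {
      if (flag) {
        Py_RETURN_TRUE;
      }
      Py_RETURN_FALSE;
    }

    static PyMethodDef methods[] = {
      {"set_flag", (PyCFunction)set_flag, METH_O, nullptr},
      {"is_flag_set", (PyCFunction)is_flag_set, METH_NOARGS, nullptr},
      {nullptr, nullptr, 0, nullptr}  // sentinel
    };

    static struct PyModuleDef moduledef = {
      PyModuleDef_HEAD_INIT, "flag_demo", nullptr, -1, methods,
      nullptr, nullptr, nullptr, nullptr
    };

    PyMODINIT_FUNC PyInit_flag_demo(void) {
      return PyModule_Create(&moduledef);
    }
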
diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp
index 4b6d40e0..77946214 100644
--- a/torch/csrc/autograd/input_buffer.cpp
+++ b/torch/csrc/autograd/input_buffer.cpp
@@ -6,43 +6,36 @@
namespace torch { namespace autograd {
-InputBuffer::InputBuffer(size_t size)
- : buffer(size)
- {}
void InputBuffer::add(size_t pos, Variable var) {
TORCH_ASSERT(pos >= 0 && pos < buffer.size());
if (!var.defined()) {
return;
}
- auto& item = buffer[pos];
- if (!item.first.defined()) {
- auto current_version = var.current_version();
- buffer[pos] = std::make_pair<>(std::move(var), current_version);
+ auto& old_var = buffer[pos];
+ if (!old_var.defined()) {
+ buffer[pos] = std::move(var);
} else {
- auto result = apply_fn<Add>()(item.first, std::move(var));
- buffer[pos] = std::make_pair<>(std::move(result), 0);
+ // ATen doesn't route sparse additions correctly...
+ if (old_var.type().is_sparse()) {
+ buffer[pos] = var + old_var;
+ } else {
+ buffer[pos] = old_var + var;
+ }
}
}
auto InputBuffer::device() const -> int {
- for (auto& pair : buffer) {
- if (pair.first.defined() && pair.first.type().isCuda()) {
- return pair.first.get_device();
+ for (auto& var : buffer) {
+ if (var.defined() && var.type().is_cuda()) {
+ return var.get_device();
}
}
return -1;
}
auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
- InputBuffer _buffer = std::move(g);
- auto& buffer = _buffer.buffer;
- int size = buffer.size();
- std::vector<Variable> result;
- result.reserve(size);
- for (int i = 0; i != size; ++i) {
- result.emplace_back(buffer[i].first);
- }
+ std::vector<Variable> result = std::move(g.buffer);
return result;
}
diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h
index bf3eaef9..2f9aba4f 100644
--- a/torch/csrc/autograd/input_buffer.h
+++ b/torch/csrc/autograd/input_buffer.h
@@ -16,21 +16,24 @@
namespace torch { namespace autograd {
struct InputBuffer {
- explicit InputBuffer(size_t size);
+ explicit InputBuffer(size_t size)
+ : buffer(size) {}
InputBuffer(const InputBuffer& other) = delete;
InputBuffer(InputBuffer&& other) = default;
+ InputBuffer& operator=(InputBuffer&& other) = default;
// Accumulates the variable at a specified index.
void add(size_t idx, Variable var);
int device() const;
+ Variable operator[](std::size_t pos) { return buffer[pos]; }
+
// Returns the inputs as a list of variables. Destroys given InputBuffer.
static std::vector<Variable> variables(InputBuffer&& buffer);
private:
- // (Variable, version at save)
- std::vector<std::pair<Variable, int>> buffer;
+ std::vector<Variable> buffer;
};
}} // namespace torch::autograd
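
InputBuffer::add above boils down to "first gradient stores, later gradients accumulate". A plain standard-C++ (C++17) sketch of that behavior, with std::optional<double> standing in for Variable and no sparse special case:

    #include <cassert>
    #include <cstddef>
    #include <optional>
    #include <vector>

    struct AccumBuffer {
      explicit AccumBuffer(std::size_t size) : slots(size) {}

      void add(std::size_t pos, double value) {
        assert(pos < slots.size());
        if (!slots[pos]) {
          slots[pos] = value;    // first gradient for this input: store it
        } else {
          *slots[pos] += value;  // later gradients: accumulate
        }
      }

      std::vector<std::optional<double>> slots;
    };

    int main() {
      AccumBuffer buffer(2);
      buffer.add(0, 1.5);
      buffer.add(0, 2.5);
      buffer.add(1, 4.0);
      assert(*buffer.slots[0] == 4.0);
      assert(*buffer.slots[1] == 4.0);
      return 0;
    }
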
diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp
index 4b6f5429..13c2e21f 100644
--- a/torch/csrc/autograd/profiler.cpp
+++ b/torch/csrc/autograd/profiler.cpp
@@ -1,40 +1,73 @@
+#include "Python.h"
#include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/autograd/function.h"
namespace torch { namespace autograd { namespace profiler {
-bool profiling = false;
-bool using_cuda;
+ProfilerState state = ProfilerState::Disabled;
+uint32_t next_thread_id = 0;
std::mutex all_event_lists_mutex;
std::list<std::shared_ptr<RangeEventList>> all_event_lists;
thread_local std::shared_ptr<RangeEventList> event_list;
+thread_local int32_t thread_id;
void RecordFunction::pushFunctionRange(Function* fn) {
pushRange(fn->name());
}
-void enableProfiler(bool use_cuda) {
+#ifdef WITH_CUDA
+static void onEachDevice(std::function<void(int)> op) {
+ AutoGPU gpu_guard;
+ int count;
+ TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
+ for(int i = 0; i < count; i++) {
+ gpu_guard.setDevice(i);
+ op(i);
+ }
+}
+#endif
+
+void enableProfiler(ProfilerState new_state) {
+ TORCH_ASSERT(new_state != ProfilerState::Disabled);
#ifndef WITH_CUDA
- if (use_cuda)
- throw std::runtime_error("Can't use CUDA profiler - PyTorch was compiled without CUDA");
+ if (new_state == ProfilerState::NVTX)
+ throw std::runtime_error("Can't use NVTX profiler - PyTorch was compiled without CUDA");
#endif
- if (profiling) {
- if (use_cuda != using_cuda)
- throw std::runtime_error("can't change use_cuda flag while profiler is running");
- return;
+ if (state != ProfilerState::Disabled && new_state != state) {
+ throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running");
}
- profiling = true;
- using_cuda = use_cuda;
- mark("__start_profile");
+ state = new_state;
+
+#ifdef WITH_CUDA
+ if(state == ProfilerState::CUDA) {
+ // event recording appears to have some startup overhead, so we need to
+ // generate some dummy events first before recording synchronization events
+ for(int i = 0; i < 5; i++) {
+ onEachDevice([](int d) {
+ mark("__cuda_startup");
+ cudaDeviceSynchronize();
+ });
+ }
+
+ // CUDA events must be on the same device, so we need a start event recorded
+ // for each GPU. We then use this event to synchronize time on the GPU
+ // with the CPU clock.
+ onEachDevice([](int d) {
+ mark("__cuda_start_event");
+ });
+ }
+#endif
+ mark("__start_profile", false);
}
thread_event_lists disableProfiler() {
- if (!profiling) {
+ if (state == ProfilerState::Disabled) {
throw std::runtime_error("can't disable profiler when it's not running");
}
+ ProfilerState old_state = state;
mark("__stop_profile");
- profiling = false;
- if (using_cuda) {
+ state = ProfilerState::Disabled;
+ if (old_state == ProfilerState::NVTX) {
return thread_event_lists();
} else {
thread_event_lists result;
diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h
index 2f2be145..d87e891a 100644
--- a/torch/csrc/autograd/profiler.h
+++ b/torch/csrc/autograd/profiler.h
@@ -11,8 +11,14 @@
#include <cstdint>
#include <string>
#include <list>
+#include <sstream>
#include <forward_list>
#include <tuple>
+#include "ATen/ATen.h"
+#include "torch/csrc/cuda/cuda_check.h"
+#ifdef WITH_CUDA
+#include <cuda_runtime.h>
+#endif
namespace torch { namespace autograd {
@@ -24,16 +30,95 @@ constexpr inline std::size_t ceilToMultiple(std::size_t a, std::size_t b) {
return ((a + b - 1) / b) * b;
}
+inline uint64_t getTime() {
+ using namespace std::chrono;
+ using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
+ return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
+}
+
enum class EventKind {
Mark,
PushRange,
PopRange
};
-// NOTE: we don't need a flag saying if an event is a kernel, because it's
-// used only for the CPU-side perf recording.
-using Event = std::tuple<std::string, uint64_t, EventKind>; // (name, time, kind)
+struct Event {
+ Event(EventKind kind, std::string name, uint32_t thread_id, bool record_cuda)
+ : kind_(kind)
+ , name_(std::move(name))
+ , thread_id_(thread_id) {
+#ifdef WITH_CUDA
+ if(record_cuda) {
+ TORCH_CUDA_CHECK(cudaGetDevice(&device_));
+ TORCH_CUDA_CHECK(cudaEventCreate(&event));
+ auto stream = at::globalContext().getCurrentCUDAStream();
+ cpu_ns_ = getTime();
+ TORCH_CUDA_CHECK(cudaEventRecord(event, stream));
+ } else {
+ cpu_ns_ = getTime();
+ }
+#else
+ cpu_ns_ = getTime();
+#endif
+ }
+ std::string kind() const {
+ switch(kind_) {
+ case EventKind::Mark: return "mark";
+ case EventKind::PushRange: return "push";
+ case EventKind::PopRange: return "pop";
+ }
+ throw std::runtime_error("unknown EventKind");
+ }
+ const std::string & name() const {
+ return name_;
+ }
+ uint32_t thread_id() const {
+ return thread_id_;
+ }
+ double cpu_elapsed_us(const Event & e) {
+ return (e.cpu_ns_ - cpu_ns_)/(1000.0);
+ }
+ double cuda_elapsed_us(const Event & e) {
+#ifdef WITH_CUDA
+ if(!e.has_cuda() || !has_cuda()) {
+ throw std::logic_error("Events were not recorded for CUDA");
+ }
+ if(e.device() != device()) {
+ throw std::logic_error("Events are not on the same device");
+ }
+ TORCH_CUDA_CHECK(cudaEventSynchronize(event));
+ TORCH_CUDA_CHECK(cudaEventSynchronize(e.event));
+ float ms;
+ TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event, e.event));
+ return ms*1000.0;
+#else
+ throw std::logic_error("CUDA not enabled");
+#endif
+ }
+ bool has_cuda() const {
+#ifdef WITH_CUDA
+ return event != nullptr;
+#else
+ return false;
+#endif
+ }
+ int device() const {
+ return device_;
+ }
+private:
+ EventKind kind_;
+ std::string name_;
+ uint32_t thread_id_;
+ int64_t cpu_ns_; // signed to allow for negative intervals
+#ifdef WITH_CUDA
+ cudaEvent_t event = nullptr;
+#endif
+ int device_ = -1;
+};
+// A linked list of fixed-size vectors, used to avoid a std::vector
+// resize from taking a large amount of time inside a profiling
+// event.
struct RangeEventList {
constexpr static std::size_t MB = 1024 * 1024;
constexpr static std::size_t event_block_size = 16 * MB;
@@ -70,81 +155,85 @@ struct RangeEventList {
std::forward_list<block_type> blocks;
};
-extern bool profiling;
-extern bool using_cuda;
+enum class ProfilerState {
+ Disabled,
+ CPU, // CPU-only profiling
+ CUDA, // CPU + CUDA events
+ NVTX, // only emit NVTX markers
+};
+
+extern ProfilerState state;
+extern uint32_t next_thread_id;
extern std::mutex all_event_lists_mutex;
extern std::list<std::shared_ptr<RangeEventList>> all_event_lists;
+
extern thread_local std::shared_ptr<RangeEventList> event_list;
+extern thread_local int32_t thread_id;
inline RangeEventList& getEventList() {
if (!event_list) {
std::lock_guard<std::mutex> guard(all_event_lists_mutex);
event_list = std::make_shared<RangeEventList>();
+ thread_id = next_thread_id++;
all_event_lists.emplace_front(event_list);
}
return *event_list;
}
-inline uint64_t getTime() {
- using namespace std::chrono;
- using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
- return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
-}
-
-inline void mark(std::string name) {
- if (using_cuda) {
+inline void mark(std::string name, bool include_cuda = true) {
+ if (state == ProfilerState::NVTX) {
#ifdef WITH_CUDA
nvtxMarkA(name.c_str());
#else
- throw std::logic_error("mark called with use_cuda=True, but compiled without CUDA");
+ throw std::logic_error("mark called with NVTX tracing, but compiled without CUDA");
#endif
} else {
- getEventList().record(std::move(name), getTime(), EventKind::Mark);
+ getEventList().record(EventKind::Mark, std::move(name), thread_id, include_cuda && state == ProfilerState::CUDA);
}
}
inline void pushRange(std::string name) {
- if (using_cuda) {
+ if (state == ProfilerState::NVTX) {
#ifdef WITH_CUDA
nvtxRangePushA(name.c_str());
#else
- throw std::logic_error("pushRange called with use_cuda=True, but compiled without CUDA");
+ throw std::logic_error("pushRange called with NVTX tracing, but compiled without CUDA");
#endif
} else {
- getEventList().record(std::move(name), getTime(), EventKind::PushRange);
+ getEventList().record(EventKind::PushRange, std::move(name), thread_id, state == ProfilerState::CUDA);
}
}
inline void popRange() {
- if (using_cuda) {
+ if (state == ProfilerState::NVTX) {
#ifdef WITH_CUDA
nvtxRangePop();
#else
- throw std::logic_error("popRange called with use_cuda=True, but compiled without CUDA");
+ throw std::logic_error("popRange called with NVTX tracing, but compiled without CUDA");
#endif
} else {
- getEventList().record(std::string(), getTime(), EventKind::PopRange);
+ getEventList().record(EventKind::PopRange, std::string(), thread_id, state == ProfilerState::CUDA);
}
}
struct RecordFunction {
explicit RecordFunction(Function *fn) {
- if (!profiling) return;
+ if (state == ProfilerState::Disabled) return;
pushFunctionRange(fn);
}
explicit RecordFunction(std::string name) {
- if (!profiling) return;
+ if (state == ProfilerState::Disabled) return;
pushRange(std::move(name));
}
explicit RecordFunction(const char *name) {
- if (!profiling) return;
+ if (state == ProfilerState::Disabled) return;
pushRange(name);
}
~RecordFunction() {
- if (!profiling) return;
+ if (state == ProfilerState::Disabled) return;
popRange();
}
@@ -155,7 +244,7 @@ struct RecordFunction {
using thread_event_lists = std::vector<std::vector<Event>>;
// NOTE: changing profiler modes is **NOT THREAD SAFE**. You should ensure that
// no autograd functions are being executed when these functions are used.
-void enableProfiler(bool use_cuda);
+void enableProfiler(ProfilerState state);
thread_event_lists disableProfiler();
} // namespace profiler
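
The RangeEventList described above keeps a linked list of fixed-capacity blocks so that recording an event never triggers a std::vector reallocation mid-measurement. A minimal sketch of that layout, with int standing in for Event and a deliberately tiny block capacity:

    #include <cstddef>
    #include <forward_list>
    #include <iostream>
    #include <vector>

    struct BlockList {
      static constexpr std::size_t kBlockCapacity = 4;  // tiny, for illustration

      void record(int event) {
        if (blocks.empty() || blocks.front().size() == kBlockCapacity) {
          // Start a fresh block; existing blocks are never moved,
          // unlike the contents of a growing std::vector.
          blocks.emplace_front();
          blocks.front().reserve(kBlockCapacity);
        }
        blocks.front().push_back(event);
      }

      std::size_t size() const {
        std::size_t total = 0;
        for (const auto& block : blocks) total += block.size();
        return total;
      }

      std::forward_list<std::vector<int>> blocks;
    };

    int main() {
      BlockList list;
      for (int i = 0; i < 10; ++i) list.record(i);
      std::cout << list.size() << "\n";  // 10, spread over three blocks
      return 0;
    }
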
diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp
index bc88c92d..7cc06082 100644
--- a/torch/csrc/autograd/python_cpp_function.cpp
+++ b/torch/csrc/autograd/python_cpp_function.cpp
@@ -62,12 +62,12 @@ PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs)
int THPCppFunction_traverse(PyObject* self, visitproc visit, void *arg)
{
auto& fn = *((THPCppFunction*)self)->cdata;
- for (auto& hook : fn.pre_hooks) {
+ for (const auto& hook : fn.pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
- for (auto& hook : fn.post_hooks) {
+ for (const auto& hook : fn.post_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
@@ -80,7 +80,7 @@ int THPCppFunction_clear(PyObject* self)
auto f = (THPCppFunction*)self;
// Remove the weak ref of the c++ object if it exist
if (f->cdata) {
- f->cdata->pyobj = nullptr;
+ f->cdata->set_pyobj(nullptr);
}
f->cdata.reset();
return 0;
@@ -97,19 +97,18 @@ void THPCppFunction_dealloc(PyObject* self)
PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook)
{
- auto& next_functions = self->cdata->next_functions;
- auto num_next = next_functions.size();
+ const auto num_next = self->cdata->num_outputs();
THPObjectPtr py_functions(PyTuple_New(num_next));
- if (!py_functions) return NULL;
+ if (!py_functions) return nullptr;
for (size_t i = 0; i < num_next; ++i) {
- auto& c_tuple = next_functions[i];
+ auto& c_tuple = self->cdata->next_edge(i);
THPObjectPtr tuple(PyTuple_New(2));
- if (!tuple) return NULL;
- PyObject *py_fn = functionToPyObject(c_tuple.first);
- if (!py_fn) return NULL;
+ if (!tuple) return nullptr;
+ PyObject *py_fn = functionToPyObject(c_tuple.function);
+ if (!py_fn) return nullptr;
PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
- PyObject *py_idx = PyLong_FromLong(c_tuple.second);
- if (!py_idx) return NULL;
+ PyObject *py_idx = PyLong_FromLong(c_tuple.input_nr);
+ if (!py_idx) return nullptr;
PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
}
@@ -117,7 +116,7 @@ PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook)
}
PyObject* THPCppFunction_requires_grad(THPCppFunction* self) {
- return PyBool_FromLong(self->cdata->is_executable);
+ Py_RETURN_TRUE;
}
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var)
@@ -127,8 +126,9 @@ PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var)
}
auto var = (THPVariable*)_var;
auto& fn = *((THPCppFunction*)self)->cdata;
- fn.pre_hooks.push_back(std::make_shared<PyFunctionPreHook>(
- var->backward_hooks, var->cdata.output_nr()));
+ std::unique_ptr<FunctionPreHook> hook(
+ new PyFunctionPreHook(var->backward_hooks, var->cdata.output_nr()));
+ fn.add_pre_hook(std::move(hook));
Py_RETURN_NONE;
}
@@ -141,12 +141,12 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook)
static struct PyMethodDef default_methods[] = {
THP_FUNCTION_DEFAULT_METHODS,
- {NULL}
+ {nullptr}
};
static struct PyGetSetDef default_properties[] = {
THP_FUNCTION_DEFAULT_PROPERTIES,
- {NULL}
+ {nullptr}
};
PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name,
@@ -182,8 +182,8 @@ PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
return obj;
}
- if (cdata->pyobj) {
- Py_INCREF(cdata->pyobj);
+ if (cdata->pyobj()) {
+ Py_INCREF(cdata->pyobj());
} else {
auto& fn = *cdata;
auto it = cpp_function_types.find(std::type_index(typeid(fn)));
@@ -194,15 +194,15 @@ PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
PyTypeObject* type = (PyTypeObject*)it->second.get();
THPObjectPtr obj(type->tp_alloc(type, 0));
- if (!obj) return NULL;
+ if (!obj) return nullptr;
THPCppFunction* f = (THPCppFunction*)obj.get();
new (&f->cdata) std::shared_ptr<Function>(cdata);
// No INCREF here as we only have a weak reference
- cdata->pyobj = obj.release();
+ cdata->set_pyobj(obj.release());
}
- return cdata->pyobj;
+ return cdata->pyobj();
}
void registerCppFunction(const std::type_info& type, PyTypeObject* pytype)
@@ -214,7 +214,7 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype)
PyObject* registerFunctionHook(Function& fn, PyObject* hook)
{
PyObject* dict = Py_None;
- for (auto& hook : fn.post_hooks) {
+ for (const auto& hook : fn.post_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
dict = pyhook->dict;
break;
@@ -222,13 +222,14 @@ PyObject* registerFunctionHook(Function& fn, PyObject* hook)
}
THPObjectPtr register_fn(PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
- if (!register_fn) return NULL;
- THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, NULL));
- if (!res) return NULL;
+ if (!register_fn) return nullptr;
+ THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
+ if (!res) return nullptr;
if (dict == Py_None) {
dict = PyTuple_GET_ITEM(res.get(), 0);
- fn.post_hooks.push_back(std::make_shared<PyFunctionPostHook>(dict));
+ std::unique_ptr<FunctionPostHook> hook(new PyFunctionPostHook(dict));
+ fn.add_post_hook(std::move(hook));
}
PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h
index 35b9d252..aa518c1c 100644
--- a/torch/csrc/autograd/python_cpp_function.h
+++ b/torch/csrc/autograd/python_cpp_function.h
@@ -19,24 +19,24 @@ template<typename Ctor>
PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
THPObjectPtr obj(type->tp_alloc(type, 0));
- if (!obj) return NULL;
+ if (!obj) return nullptr;
THPCppFunction* f = (THPCppFunction*)obj.get();
HANDLE_TH_ERRORS
new (&f->cdata) std::shared_ptr<Function>(Ctor()(args));
END_HANDLE_TH_ERRORS
if (!f->cdata) {
- return NULL;
+ return nullptr;
}
return obj.release();
}
#define THP_FUNCTION_DEFAULT_METHODS \
- {(char*)"_register_hook_dict", (PyCFunction)THPCppFunction_register_hook_dict, METH_O, NULL}, \
- {(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, NULL}
+ {(char*)"_register_hook_dict", (PyCFunction)THPCppFunction_register_hook_dict, METH_O, nullptr}, \
+ {(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, nullptr}
#define THP_FUNCTION_DEFAULT_PROPERTIES \
- {(char*)"next_functions", (getter)THPCppFunction_next_functions, NULL, NULL, NULL}, \
- {(char*)"requires_grad", (getter)THPCppFunction_requires_grad, NULL, NULL, NULL}
+ {(char*)"next_functions", (getter)THPCppFunction_next_functions, nullptr, nullptr, nullptr}, \
+ {(char*)"requires_grad", (getter)THPCppFunction_requires_grad, nullptr, nullptr, nullptr}
PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook);
PyObject* THPCppFunction_requires_grad(THPCppFunction* self);
@@ -50,7 +50,7 @@ PyObject* registerFunctionHook(Function& fn, PyObject* hook);
template<typename Ctor>
PyTypeObject* createForwardFunctionPyTypeObject(PyTypeObject& type, const char* name,
- PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL)
+ PyGetSetDef* function_properties=nullptr, PyMethodDef* function_methods=nullptr)
{
type.tp_new = &CppFunction_pynew<Ctor>;
return _initFunctionPyTypeObject(type, name, function_properties, function_methods);
diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp
index e47eed4c..f5cce00b 100644
--- a/torch/csrc/autograd/python_engine.cpp
+++ b/torch/csrc/autograd/python_engine.cpp
@@ -1,12 +1,18 @@
#include "torch/csrc/autograd/python_engine.h"
-#include "torch/csrc/autograd/engine.h"
-#include "torch/csrc/autograd/python_function.h"
-#include "torch/csrc/THP.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/PtrWrapper.h"
+#include "torch/csrc/THP.h"
+#include "torch/csrc/autograd/engine.h"
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/edge.h"
+#include "torch/csrc/autograd/python_function.h"
#include "torch/csrc/utils/auto_gil.h"
+#ifndef _WIN32
+#include <pthread.h>
+#endif
+
#include <unordered_set>
using namespace torch::autograd;
@@ -36,14 +42,14 @@ void PythonEngine::thread_on_exception(FunctionTask& task, std::exception& e) {
Engine::thread_on_exception(task, e);
}
-void PythonEngine::execute(
- const function_list& roots,
+variable_list PythonEngine::execute(
+ const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
- const pre_callback_map& pre_callbacks,
- const post_callback_map& post_callbacks) {
+ bool create_graph,
+ const edge_list& outputs) {
try {
- Engine::execute(roots, inputs, keep_graph, pre_callbacks, post_callbacks);
+ return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
} catch (python_error& e) {
e.restore();
throw;
@@ -56,77 +62,21 @@ PythonEngine& PythonEngine::getDefaultEngine() {
}}} // namespace torch::autograd::python
-PyObject *THPEngineClass = NULL;
-
-struct CallbackContext {
- std::string error;
- THPObjectPtr outputs;
- // Used to determine which callback arguments should be used to
- // fill outputs.
- // Function -> ([grad_nr, outputs_idx], is_leaf)
- std::unordered_map<
- std::shared_ptr<Function>,
- std::pair<std::vector<std::pair<int, int>>, bool>> output_map;
-};
-
-void compute_partial_exec_callbacks(const function_list& roots,
- const CallbackContext& ctx,
- Engine::pre_callback_map& map,
- bool allow_unreachable) {
- // This callback is used to suppress the computation of a node
- // if it is not necessary.
- static Engine::pre_callback_type abort_callback(
- [](Function* fn, variable_list &vars) { return false; });
-
- std::vector<Function*> queue;
- std::unordered_set<Function*> seen; // for the initial DFS
- std::unordered_set<Function*> needed; // functions to compute
- std::unordered_map<Function*, std::vector<Function*>> rev_graph;
-
- // Reverse the next_fn edges
- queue.reserve(roots.size());
- for (auto& root : roots) {
- auto ptr = root.first.get();
- bool unseen;
- std::tie(std::ignore, unseen) = seen.insert(ptr);
- if (unseen) queue.emplace_back(ptr);
- }
- while (!queue.empty()) {
- auto fn = queue.back(); queue.pop_back();
- for (auto& next_fn_pair : fn->next_functions) {
- auto next_fn = next_fn_pair.first.get();
- if (!next_fn) continue;
- rev_graph[next_fn].push_back(fn);
- if (seen.insert(next_fn).second) {
- queue.push_back(next_fn);
- }
- }
- }
- auto all_functions = std::move(seen); // this is cheap and improves readability
-
- // Find all functions we need to compute
- queue.clear();
- for (auto input_info: ctx.output_map) {
- auto input = input_info.first.get();
- auto rev_edges_it = rev_graph.find(input);
- if (!allow_unreachable && rev_edges_it == rev_graph.end())
- throw std::runtime_error("differentiated input is unreachable");
- queue.emplace_back(input);
- needed.insert(input);
- }
- while (!queue.empty()) {
- auto fn = queue.back(); queue.pop_back();
- for (auto rev_next_fn : rev_graph[fn]) {
- if (needed.insert(rev_next_fn).second) {
- queue.push_back(rev_next_fn);
- }
- }
- }
-
- // Prevent expansion for functions in {all_vertices} \ {needed}
- for (auto fn : all_functions) {
- if (needed.count(fn) > 0) continue;
- map.emplace(fn, abort_callback);
+PyObject *THPEngineClass = nullptr;
+
+static bool _reinitialize_engine = false;
+
+static void _maybe_reinitialize_engine_after_fork() {
+ // This is "probably" thread-safe because the flag is set in a fork handler
+ // before any threads are created, and this function is only called with the
+ // GIL held. However, using fork + threads is playing with fire so this is
+ // more of a "best effort" thing. For example, if the fork occurs while the
+ // backwards threads hold a lock, we'll probably deadlock in the engine
+ // destructor.
+ if (_reinitialize_engine) {
+ engine.~PythonEngine();
+ new (&engine) torch::autograd::python::PythonEngine();
+ _reinitialize_engine = false;
}
}
@@ -134,17 +84,19 @@ void compute_partial_exec_callbacks(const function_list& roots,
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
- PyObject *variables = NULL;
- PyObject *grad_variables = NULL;
+ _maybe_reinitialize_engine_after_fork();
+ PyObject *variables = nullptr;
+ PyObject *grad_variables = nullptr;
unsigned char keep_graph = 0;
- PyObject *inputs = NULL;
- unsigned char only_inputs = 0;
- unsigned char allow_unreachable = 0;
- const char *accepted_kwargs[] = {"variables", "grad_variables",
- "keep_graph", "inputs", "only_inputs", "allow_unreachable", NULL};
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOb|Obb", (char**)accepted_kwargs,
- &variables, &grad_variables, &keep_graph, &inputs, &only_inputs, &allow_unreachable))
- return NULL;
+ unsigned char create_graph = 0;
+ PyObject *inputs = nullptr;
+ const char *accepted_kwargs[] = {
+ "variables", "grad_variables", "keep_graph", "create_graph", "inputs",
+ nullptr
+ };
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|O", (char**)accepted_kwargs,
+ &variables, &grad_variables, &keep_graph, &create_graph, &inputs))
+ return nullptr;
THPUtils_assert(PyTuple_Check(variables), "variables argument is expected to "
"be a tuple, but got %s", THPUtils_typename(variables));
@@ -156,32 +108,23 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
THPUtils_assert(num_variables == num_gradients, "got %ld variables and %ld "
"gradients", num_variables, num_gradients);
- function_list roots(num_variables);
- variable_list grads(num_variables);
+ edge_list roots;
+ roots.reserve(num_variables);
+ variable_list grads;
+ grads.reserve(num_variables);
for (int i = 0; i < num_variables; i++) {
PyObject *_variable = PyTuple_GET_ITEM(variables, i);
THPUtils_assert(THPVariable_Check(_variable), "element %d of variables "
"tuple is not a Variable", i);
auto& variable = ((THPVariable*)_variable)->cdata;
- THPUtils_assert(!variable.is_volatile(),
- "element %d of variables tuple is volatile", i);
- // If grad_fn is NULL (as is the case for a leaf node), we instead
- // interpret the gradient function to be a grad accumulator,
- // which will accumulate its inputs into the grad property of the
- // variable. These nodes get suppressed in some situations,
- // see "suppress grad accumulation" below. Note that only variables which
- // have requires_grad=True can have grad accumulators.
- auto grad_fn = variable.grad_fn() ? variable.grad_fn() : variable.grad_accumulator();
- int output_nr = variable.grad_fn() ? variable.output_nr() : 0;
- THPUtils_assert(!variable.is_volatile(),
- "element %d of variables tuple is volatile", i);
- THPUtils_assert(grad_fn,
+ auto gradient_edge = variable.gradient_edge();
+ THPUtils_assert(gradient_edge.function,
"element %d of variables does not require grad and does not have a grad_fn", i);
- roots[i] = std::make_pair<>(std::move(grad_fn), output_nr);
+ roots.push_back(std::move(gradient_edge));
PyObject *grad = PyTuple_GET_ITEM(grad_variables, i);
if (THPVariable_Check(grad)) {
- grads[i] = ((THPVariable*)grad)->cdata;
+ grads.push_back(((THPVariable*)grad)->cdata);
} else {
THPUtils_assert(grad == Py_None,
"element %d of gradients tuple is not a Variable or None", i);
@@ -190,74 +133,46 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
}
}
- Engine::pre_callback_map callbacks;
- CallbackContext ctx;
- if (inputs != NULL) {
- THPUtils_assert(PyTuple_Check(inputs), "inputs argument has to be a tuple");
+ edge_list output_edges;
+ if (inputs != nullptr) {
int num_inputs = PyTuple_GET_SIZE(inputs);
- ctx.outputs = PyTuple_New(num_inputs);
- if (!ctx.outputs) return NULL;
- // First, find all relevant functions and fill ctx.output_map
+ output_edges.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
PyObject *input = PyTuple_GET_ITEM(inputs, i);
THPUtils_assert(THPVariable_Check(input),
"all inputs have to be Variables, but got %s", THPUtils_typename(input));
THPVariable *input_var = (THPVariable*)input;
+ const auto output_nr = input_var->cdata.output_nr();
auto grad_fn = input_var->cdata.grad_fn();
- int output_nr = input_var->cdata.output_nr();
- bool is_leaf = !grad_fn;
- if (is_leaf) {
- grad_fn = input_var->cdata.get()->grad_accumulator.lock();
+ if (!grad_fn) {
+ grad_fn = input_var->cdata.try_get_grad_accumulator();
}
THPUtils_assert(input_var->cdata.requires_grad(),
"One of the differentiated Variables does not require grad");
- if (allow_unreachable && !grad_fn) continue;
- THPUtils_assert(grad_fn,
- "One of the differentiated Variables appears to not have been used in the graph");
- THPUtils_assert(grad_fn->is_executable,
- "One of the differentiated Variables has a non-executable grad_fn. Submit a bug report.");
- auto& fn_info = ctx.output_map[grad_fn];
- fn_info.first.emplace_back(output_nr, i);
- fn_info.second = is_leaf;
- }
- // Register callbacks that will gather the outputs
- for (auto& entry : ctx.output_map) {
- auto& fn_info = entry.second;
- callbacks.emplace(entry.first.get(), [&ctx, &fn_info](Function* _unused, variable_list& grads) {
- auto& saved_outputs = fn_info.first;
- bool is_leaf = fn_info.second;
- AutoGIL gil;
- for (auto& saved_out : saved_outputs) {
- PyTuple_SET_ITEM(ctx.outputs.get(), saved_out.second,
- THPVariable_Wrap(grads[saved_out.first]));
- }
- // Suppress grad accumulation.
- // If the variable is a leaf, the next function to execute
- // is a grad_accumulator. But when inputs != NULL, we should
- // NOT accumulate, so terminate execution.
- return !is_leaf;
- });
- }
- // Disable execution for all unneeded functions
- if (only_inputs) {
- compute_partial_exec_callbacks(roots, ctx, callbacks, allow_unreachable);
+ if (!grad_fn) {
+ output_edges.emplace_back();
+ } else {
+ THPUtils_assert(grad_fn,
+ "One of the differentiated Variables appears to not have been used in the graph");
+ output_edges.emplace_back(grad_fn, output_nr);
+ }
}
}
+ variable_list outputs;
{
AutoNoGIL no_gil;
- engine.execute(roots, grads, keep_graph, callbacks);
+ outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
}
- if (ctx.outputs) {
- for (int i = 0; i < PyTuple_GET_SIZE(inputs); i++) {
- // XXX: initializing tuples with NULL pointers might be a CPython
- // implementation detail
- if (PyTuple_GET_ITEM(ctx.outputs.get(), i)) continue;
- Py_INCREF(Py_None);
- PyTuple_SET_ITEM(ctx.outputs.get(), i, Py_None);
+ if (inputs != nullptr) {
+ int num_inputs = PyTuple_GET_SIZE(inputs);
+ THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
+ if (!py_outputs) return nullptr;
+ for (int i = 0; i < num_inputs; i++) {
+ PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
}
- return ctx.outputs.release();
+ return py_outputs.release();
} else {
Py_RETURN_NONE;
}
@@ -265,14 +180,17 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
}
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
+ HANDLE_TH_ERRORS
+ _maybe_reinitialize_engine_after_fork();
std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) { AutoGIL gil; Py_DECREF(obj); });
Py_INCREF(_callback);
engine.queue_callback([callback]() {
AutoGIL gil;
- THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), NULL)};
+ THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
if (!result) throw python_error();
});
Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
}
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -281,14 +199,14 @@ PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
}
static struct PyMethodDef THPEngine_methods[] = {
- {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, NULL},
- {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, NULL},
- {NULL}
+ {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
+ {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
+ {nullptr}
};
PyTypeObject THPEngineType = {
- PyVarObject_HEAD_INIT(NULL, 0)
+ PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._EngineBase", /* tp_name */
sizeof(THPEngine), /* tp_basicsize */
0, /* tp_itemsize */
@@ -308,7 +226,7 @@ PyTypeObject THPEngineType = {
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
- NULL, /* tp_doc */
+ nullptr, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
@@ -328,8 +246,17 @@ PyTypeObject THPEngineType = {
THPEngine_new /* tp_new */
};
+static void child_atfork() {
+ _reinitialize_engine = true;
+}
+
bool THPEngine_initModule(PyObject *module)
{
+#ifndef _WIN32
+ if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
+ throw std::runtime_error("unable to set pthread_atfork handler");
+ }
+#endif
if (PyType_Ready(&THPEngineType) < 0)
return false;
Py_INCREF(&THPEngineType);
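
child_atfork above only raises a flag; the engine is torn down and placement-new'd again lazily, the next time Python enters it with the GIL held. A minimal POSIX-only sketch of the same flag-plus-lazy-reinit pattern, with a trivial struct standing in for the engine and no GIL involved:

    #include <pthread.h>
    #include <sys/wait.h>
    #include <unistd.h>

    #include <cstdio>
    #include <new>

    struct Engine {
      int generation = 0;
    };

    static Engine engine;
    static bool reinitialize_engine = false;

    static void child_after_fork() {
      // Fork handlers run in a restricted context, so only flip a flag here
      // and defer the heavy work to the next regular entry point.
      reinitialize_engine = true;
    }

    static void maybe_reinitialize_engine() {
      if (reinitialize_engine) {
        engine.~Engine();
        new (&engine) Engine();
        engine.generation = static_cast<int>(getpid());  // arbitrary marker
        reinitialize_engine = false;
      }
    }

    int main() {
      pthread_atfork(nullptr, nullptr, child_after_fork);
      pid_t pid = fork();
      if (pid == 0) {  // child: the atfork handler has set the flag
        maybe_reinitialize_engine();
        std::printf("child engine generation: %d\n", engine.generation);
        _exit(0);
      }
      waitpid(pid, nullptr, 0);
      maybe_reinitialize_engine();  // parent: flag never set, so this is a no-op
      std::printf("parent engine generation: %d\n", engine.generation);
      return 0;
    }
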
diff --git a/torch/csrc/autograd/python_engine.h b/torch/csrc/autograd/python_engine.h
index 14627cd6..dd9300c8 100644
--- a/torch/csrc/autograd/python_engine.h
+++ b/torch/csrc/autograd/python_engine.h
@@ -1,6 +1,8 @@
#pragma once
#include <Python.h>
+
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/engine.h"
bool THPEngine_initModule(PyObject *module);
@@ -10,12 +12,12 @@ namespace torch { namespace autograd { namespace python {
struct PythonEngine : public Engine {
virtual void thread_init(int device) override;
virtual void thread_on_exception(FunctionTask& task, std::exception& e) override;
- virtual void execute(
- const function_list& roots,
+ virtual variable_list execute(
+ const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
- const pre_callback_map& pre_callbacks = pre_callback_map(),
- const post_callback_map& post_callbacks = post_callback_map()) override;
+ bool create_graph,
+ const edge_list& outputs = {}) override;
static PythonEngine& getDefaultEngine();
};
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 68a2ae2f..7d84562f 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -8,30 +8,25 @@
#include <ATen/ATen.h>
#include "THP.h"
+#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/autograd/functions/accumulate_grad.h"
#include "torch/csrc/autograd/functions/basic_ops.h"
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/autograd/python_hook.h"
-#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/utils/auto_gpu.h"
#include "torch/csrc/Exceptions.h"
-#ifdef WITH_CUDA
-#include "cuda/AutoGPU.h"
-#endif
-
using namespace torch;
using namespace torch::autograd;
using namespace torch::jit;
using at::Tensor;
-PyObject *THPFunctionClass = NULL;
-PyObject *THPStochasticFunctionClass = NULL;
-PyObject *THPBatchNormBackwardBackwardFunction = NULL;
+PyObject *THPFunctionClass = nullptr;
#define THPFunction_assert(condition, ...) \
if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); }
@@ -43,7 +38,7 @@ VariableInfo::VariableInfo(const Variable& var)
, device(-1)
, size(var.sizes())
, requires_grad(var.requires_grad()) {
- if (var.type().isCuda()) {
+ if (var.type().is_cuda()) {
device = var.get_device();
}
}
@@ -60,15 +55,7 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
if (!pyInputs) throw python_error();
for (size_t i = 0; i != inputs.size(); ++i) {
- PyObject* input;
- if (inputs[i].defined()) {
- input = createPyObject(inputs[i].data());
- if (!input) throw python_error();
- } else {
- input = Py_None;
- Py_INCREF(input);
- }
- PyTuple_SET_ITEM(pyInputs.get(), i, input);
+ PyTuple_SET_ITEM(pyInputs.get(), i, THPVariable_Wrap(inputs[i]));
}
THPObjectPtr r(PyObject_CallMethod(
@@ -80,13 +67,13 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
for (int i = 0; i != num_outputs; ++i) {
PyObject* obj = PyTuple_GET_ITEM(r.get(), i);
if (obj != Py_None) {
- if (!THPModule_isTensor(obj)) {
- std::string msg("expected Tensor (got '");
+ if (!THPVariable_Check(obj)) {
+ std::string msg("expected Variable (got '");
msg += THPUtils_typename(obj);
msg += "')'";
throw std::runtime_error(msg);
}
- tensor_results[i] = createTensor(obj);
+ tensor_results[i] = ((THPVariable*)obj)->cdata.data();
}
}
@@ -96,9 +83,13 @@ auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
// leads to unexpected error messages ("no nodes require computing gradients"),
// but I don't have a better idea. These functions would raise an error
// in backward anyway.
- return wrap_outputs(inputs, std::move(tensor_results), [this](FunctionFlags &&f) {
- return std::make_shared<Error>(name() + " is not differentiable twice", std::move(f));
- });
+ return wrap_outputs(
+ inputs,
+ std::move(tensor_results),
+ [this](edge_list&& next_edges) {
+ return std::make_shared<Error>(
+ name() + " is not differentiable twice", std::move(next_edges));
+ });
}
// NOTE: this function is written in a way that assumes it's only called for backward;
@@ -208,7 +199,7 @@ auto PyFunction::is_traceable() -> bool {
return traceable_py_bool == Py_True;
}
-auto PyFunction::releaseVariables() -> void {
+auto PyFunction::release_variables() -> void {
AutoGIL gil;
auto f = (THPFunction*) obj;
f->saved_variables.clear();
@@ -226,7 +217,7 @@ auto PyFunction::name() -> std::string {
return name;
}
-auto PyFunction::getSharedPtr() -> std::shared_ptr<Function> {
+auto PyFunction::get_shared_ptr() -> std::shared_ptr<Function> {
return THPFunction_asFunction((THPFunction*)obj);
}
@@ -235,18 +226,17 @@ auto PyFunction::getSharedPtr() -> std::shared_ptr<Function> {
// Traverse and clear are required for supporting Python's GC cycle handling.
static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
{
- for (auto& hook : self->cdata.pre_hooks) {
+ for (const auto& hook : self->cdata.pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
- for (auto& hook : self->cdata.post_hooks) {
+ for (const auto& hook : self->cdata.post_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
Py_VISIT(self->to_save);
- Py_VISIT(self->shared_pairs);
Py_VISIT(self->non_differentiable);
Py_VISIT(self->dirty_tensors);
return 0;
@@ -254,12 +244,11 @@ static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
static int THPFunction_clear(THPFunction *self)
{
- self->cdata.num_inputs = 0;
+ self->cdata.set_num_inputs(0);
Py_CLEAR(self->needs_input_grad);
Py_CLEAR(self->to_save);
- Py_CLEAR(self->shared_pairs);
Py_CLEAR(self->non_differentiable);
Py_CLEAR(self->dirty_tensors);
@@ -268,10 +257,13 @@ static int THPFunction_clear(THPFunction *self)
self->saved_variables.clear();
self->is_variable_input.clear();
- // XXX: this will clear all hooks (not only Python ones)
- // I guess it's ok to leave it as is for now.
- auto pre_hooks = std::move(self->cdata.pre_hooks);
- auto post_hooks = std::move(self->cdata.post_hooks);
+ // Moving the hooks out makes sure to first disassociate them from the
+ // function, but without destroying any of them. They will get deleted when
+ // exiting this scope. This is important, because deleting Python objects can
+ // trigger deletion of other objects, and they can reference this function,
+ // seeing it in a half-deleted state.
+ auto pre_hooks = std::move(self->cdata.pre_hooks());
+ auto post_hooks = std::move(self->cdata.post_hooks());
return 0;
}
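// [Sketch, not part of this patch] A minimal standalone illustration of the
// "move the hooks out before returning" pattern used in THPFunction_clear
// above. The types here are stand-ins, not the real autograd classes.
// Destroying a hook may run arbitrary code that looks back at its owner, so
// the owner's members are emptied first and the hooks are only destroyed when
// the locals go out of scope, at which point the owner is already consistent.
#include <memory>
#include <utility>
#include <vector>

struct Hook {
  virtual ~Hook() = default;
};

struct Owner {
  std::vector<std::unique_ptr<Hook>> pre_hooks;
  std::vector<std::unique_ptr<Hook>> post_hooks;

  void clear() {
    // After these moves the members are left moved-from (in practice empty);
    // any code triggered by a hook destructor that inspects this Owner sees
    // it already cleared.
    auto pre = std::move(pre_hooks);
    auto post = std::move(post_hooks);
    // pre/post (and the hooks they own) are destroyed here, on scope exit.
  }
};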
@@ -291,7 +283,7 @@ static void THPFunction_dealloc(THPFunction* self)
PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
PyObject* obj = type->tp_alloc(type, 0);
- if (!obj) return NULL;
+ if (!obj) return nullptr;
// Python zero-initializes the object memory, so there's no need to initialize
// most fields
THPFunction* self = (THPFunction*)obj;
@@ -300,8 +292,7 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
new (&self->input_info) std::vector<VariableInfo>();
new (&self->saved_variables) std::vector<SavedVariable>();
new (&self->is_variable_input) std::vector<bool>();
- self->cdata.num_inputs = -1;
- self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass);
+ self->cdata.set_num_inputs(0);
return obj;
}
@@ -313,67 +304,37 @@ using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
// Bump the counters of all recorded dirty input tensors, adding each of them
// into dirty_inputs. Also does some sanity checking.
-static void _mark_dirty(THPFunction *self, t2var_type &t2var,
- std::unordered_set<PyObject *> &dirty_inputs)
+static std::vector<PyObject*> _mark_dirty(THPFunction *self)
{
// Increase versions of modified tensors
- if (!self->dirty_tensors) return;
+ std::vector<PyObject*> dirty_inputs;
+ if (!self->dirty_tensors) return dirty_inputs;
THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd "
"internal error: dirty_tensors attribute is expected to be a tuple "
"but is %s", THPUtils_typename(self->dirty_tensors));
Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
for (int i = 0; i < num_dirty; i++) {
- PyObject *tensor = PyTuple_GET_ITEM(self->dirty_tensors, i);
- dirty_inputs.insert(tensor);
- THPVariable *variable;
- try {
- variable = t2var.at(tensor);
- } catch (std::out_of_range &e) {
- THPFunction_assert(THPModule_isTensor(tensor), "mark_dirty can "
- "only accept tensors, but argument %d is of type %s", i,
- THPUtils_typename(tensor));
- THPFunction_assert(false, "mark_dirty only accepts input tensors, but "
- "argument %d isn't one", i);
- }
- auto& version_counter = variable->cdata.version_counter();
- THPFunction_assert(version_counter.live_refs() == 1,
- "in-place operations can be only used on variables that don't share "
- "storage with any other variables, but detected that there are %d "
- "objects sharing it",
- version_counter.live_refs());
- version_counter.increment();
+ PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
+ THPFunction_assert(THPVariable_Check(obj), "mark_dirty can "
+ "only accept variables, but argument %d is of type %s", i,
+ THPUtils_typename(obj));
+
+ dirty_inputs.push_back(obj);
+ auto variable = (THPVariable*)obj;
+ variable->cdata.bump_version();
}
// We're not going to ever need this so let's remove references now
- Py_DECREF(self->dirty_tensors);
- self->dirty_tensors = NULL;
+ Py_CLEAR(self->dirty_tensors);
+ return dirty_inputs;
}
-static void _transplant_var(VariableImpl& var, const std::shared_ptr<Function>& fn, int output_nr, bool is_volatile)
-{
- if (is_volatile) {
- var.grad_fn = nullptr;
- var.requires_grad = false;
- var.is_volatile = true;
- var.output_nr = 0;
- } else {
- var.grad_fn = fn;
- var.requires_grad = fn->is_executable;
- var.is_volatile = is_volatile;
- var.output_nr = output_nr;
- }
- var.grad.reset();
- var.hooks.clear();
- if (auto grad_acc_fn = var.grad_accumulator.lock()) {
- auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
- grad_acc->variable.reset();
- }
-}
+static std::unordered_set<PyObject*> _parse_non_differentiable(THPFunction *self);
// Given a Python tuple of raw output tensors (raw_output), set each of
// the corresponding entries in a different Python tuple (outputs) with
// these tensors wrapped with variables. We save the gradient function (self)
-// to the variable if the output is not volatile (is_volatile).
+// to the variable if the output requires grad.
//
// There is a considerable amount of complexity to handle if the operation
// that produced these output tensors is inplace. A mapping of *input*
@@ -381,92 +342,93 @@ static void _transplant_var(VariableImpl& var, const std::shared_ptr<Function>&
// the set of dirty tensors (dirty_inputs) is used to figure out what to
// do in this case. After this method is run, t2var is extended with
// mappings for output tensors as well.
-static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
- std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
- PyObject *outputs, bool is_volatile)
+static void _wrap_outputs(THPFunction *self,
+ PyObject* inputs_tuple, PyObject *raw_output, PyObject *outputs, bool is_executable)
{
- auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self);
+ auto cdata = is_executable ? THPFunction_asFunction(self) : nullptr;
Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
- if (self->cdata.is_executable) {
+ if (is_executable) {
self->output_info.clear();
self->output_info.reserve(num_outputs);
}
- for (int i = 0; i < num_outputs; i++) {
- PyObject *output = PyTuple_GET_ITEM(raw_output, i);
- THPVariable *output_var;
- auto it = t2var.find(output);
- if (it == t2var.end()) {
- // A completely new tensor - just wrap it and continue
- if (is_volatile) {
- output_var = (THPVariable*)THPVariable_NewVolatile(output);
- } else {
- output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata);
+
+ std::unordered_set<PyObject*> inputs;
+ int num_inputs = PyTuple_GET_SIZE(inputs_tuple);
+ for (int i = 0; i < num_inputs; i++) {
+ inputs.emplace(PyTuple_GET_ITEM(inputs_tuple, i));
+ }
+
+ auto non_differentiable = _parse_non_differentiable(self);
+ auto dirty_inputs = _mark_dirty(self);
+
+ auto as_variable = [&](PyObject* obj, int i) -> Variable {
+ if (THPVariable_Check(obj)) {
+ return ((THPVariable*)obj)->cdata;
+ }
+ throw TypeError("%s.forward: expected Variable (got %s) for return value %d",
+ Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
+ };
+
+ // Sets the grad_fn and output_nr of an output Variable.
+ auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
+ bool is_differentiable) {
+ if (!is_differentiable) {
+ if (!var.requires_grad()) return;
+ // NB: we don't support returning non-differentiable views that could require grad
+ // (this could happen if someone were to return an input to the function).
+ if (var.is_view()) {
+ throw std::runtime_error("Returning Variables sharing storage with other Variables "
+ "that require grad is not supported in Python functions. "
+ "Please submit a feature request if you hit this error.");
}
- } else {
- // If one of the outputs was also an input tensor it's a bit more complicated.
- THPVariable *input_var = it->second;
- auto& input_var_ = input_var->cdata;
- if (input_var_.grad_fn()) {
- Py_INCREF(input_var);
- output_var = input_var;
- // If it's not a leaf we want to move it in the graph so backprop
- // will be computed correctly, but only if it was modified. Otherwise
- // it's better to minimize the number of operations that mutate the graph.
- // grad_fn <- variable <- self ==> grad_fn <- self <- variable
- if (dirty_inputs.count(output) > 0) {
- _transplant_var(*input_var_.get(), cdata, i, is_volatile);
- }
- } else {
- // If the leaf Variable has been returned, we have to move it after the
- // current function to ensure the gradient is computed correctly.
- // There are two cases now:
- // 1. It has been modified in-place. If it didn't require_grad it's ok,
- // but if it does, then it's a clear error.
- // 2. It hasn't been modified. This means that it must have been
- // returned unchanged, and we can simply return a new Variable
- // referencing the same storage.
- if (dirty_inputs.count(output) > 0) {
- if (!input_var_.requires_grad()) {
- Py_INCREF(input_var);
- output_var = input_var;
- _transplant_var(*input_var_.get(), cdata, i, is_volatile);
- } else { // input_var_.requires_grad
- throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
- }
- } else {
- // An input has been returned, but it wasn't modified. It's better
- // not to move the Variable, because there are some legitimate cases
- // where making it non-leaf would break stuff (e.g. broadcast). Also,
- // returning the input Variable is not a good option either,
- // because if someone registers hooks on it, they will fire with grads
- // from all usages, not only from usages of this output. This is why
- // we'll return a copy and join their version counters. This has
- // a side-effect of making in-place ops on any of these Variables an
- // immediate error, but it would be raised anyway once someone
- // calls backward.
- if (is_volatile) {
- output_var = (THPVariable*)THPVariable_NewVolatile(output);
- } else {
- output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata);
- }
- if (!output_var) throw python_error();
- output_var->cdata.version_counter() = input_var->cdata.version_counter();
- }
+ var.detach_();
+ } else if (is_modified) {
+ if (var.is_leaf() && var.requires_grad()) {
+ throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
}
+ // If the input was modified, transplant the grad_fn in the graph:
+ // grad_fn <- variable <- self ==> grad_fn <- self <- variable
+ var.reset_grad();
+ var.clear_hooks();
+ if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
+ auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
+ grad_acc->variable.reset();
+ }
+ if (cdata) {
+ var.rebase_history({cdata, output_nr});
+ }
+ } else if (is_input) {
+ // An input has been returned, but it wasn't modified. Return it as a view
+ // so that we can attach a new grad_fn to the Variable.
+ var = var.slice();
+ var.set_gradient_edge({cdata, output_nr});
+ } else if (cdata) {
+ var.set_gradient_edge({cdata, output_nr});
}
- if (!output_var) throw python_error();
+ };
+
+ for (int i = 0; i < num_outputs; i++) {
+ PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
+
+ bool is_input = inputs.count(obj) > 0;
+ bool is_modified = std::find(dirty_inputs.begin(), dirty_inputs.end(), obj) != dirty_inputs.end();
+ bool is_differentiable = is_executable && non_differentiable.count(obj) == 0;
- if (self->cdata.is_executable) {
- self->output_info.emplace_back(output_var->cdata);
+ // Note that output Variables may be repeated. In that case, the last call
+ // to set_history wins.
+ auto var = as_variable(obj, i);
+ set_history(var, i, is_input, is_modified, is_differentiable);
+
+ if (is_executable) {
+ self->output_info.emplace_back(var);
}
- t2var[output] = output_var;
- output_var->cdata.get()->output_nr = i;
- PyTuple_SET_ITEM(outputs, i, (PyObject*)output_var);
+
+ PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(var));
}
}
// Save any variables requested by to_save
-static void _save_variables(THPFunction* self, t2var_type &t2var)
+static void _save_variables(THPFunction* self)
{
if (!self->to_save) return;
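// [Sketch, not part of this patch] A standalone rendering of the per-output
// decision that the new _wrap_outputs/set_history logic above encodes. The
// enum and flags are stand-ins; the real code operates on Variable and Edge.
#include <stdexcept>

enum class OutputAction {
  DetachFromGraph,    // non-differentiable output: keep no history
  RebaseHistory,      // an input modified in place: splice this node into its graph
  ViewWithNewGradFn,  // an input returned unmodified: hand back a view with a fresh edge
  SetGradFn           // a brand-new output: point it at this node
};

OutputAction classify_output(bool is_differentiable, bool is_input,
                             bool is_modified, bool is_leaf_requiring_grad) {
  if (!is_differentiable) {
    return OutputAction::DetachFromGraph;
  }
  if (is_modified) {
    if (is_leaf_requiring_grad) {
      // Mirrors the error raised above for in-place ops on leaf Variables.
      throw std::runtime_error(
          "a leaf Variable that requires grad has been used in an in-place operation.");
    }
    return OutputAction::RebaseHistory;
  }
  if (is_input) {
    return OutputAction::ViewWithNewGradFn;
  }
  return OutputAction::SetGradFn;
}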
@@ -478,110 +440,54 @@ static void _save_variables(THPFunction* self, t2var_type &t2var)
self->saved_variables.reserve(num_saved);
auto cdata_ptr = &self->cdata;
for (int i = 0; i < num_saved; i++) {
- PyObject *tensor = PyTuple_GET_ITEM(self->to_save, i);
- if (tensor == Py_None) {
+ PyObject *obj = PyTuple_GET_ITEM(self->to_save, i);
+ if (obj == Py_None) {
self->saved_variables.emplace_back();
continue;
+ } else if (THPVariable_Check(obj)) {
+ auto variable = (THPVariable*)obj;
+ bool is_output = variable->cdata.grad_fn().get() == cdata_ptr;
+ self->saved_variables.emplace_back(variable->cdata, is_output);
+ } else {
+ throw TypeError(
+ "save_for_backward can only save variables, but argument %d is of "
+ "type %s", i, Py_TYPE(obj)->tp_name);
}
-
- THPVariable *variable;
- try {
- variable = t2var.at(tensor);
- } catch(std::out_of_range &e) {
- THPFunction_assert(THPModule_isTensor(tensor),
- "save_for_backward can only save tensors, but argument %d is of "
- "type %s", i, THPUtils_typename(tensor));
- THPFunction_assert(false, "save_for_backward can only save input or output "
- "tensors, but argument %d doesn't satisfy this condition", i);
- }
-
- self->saved_variables.emplace_back(variable->cdata, cdata_ptr);
}
// Free .to_save
- Py_DECREF(self->to_save);
- self->to_save = NULL;
-}
-
-// t2var maps input and output tensors to variables
-static void _join_version_counters(THPFunction *self, t2var_type &t2var)
-{
- if (!self->shared_pairs) return;
- THPFunction_assert(PyTuple_Check(self->shared_pairs), "autograd internal "
- "error: shared_pairs attribute is expected to be a tuple but is %s",
- THPUtils_typename(self->shared_pairs));
- Py_ssize_t num_shared = PyTuple_GET_SIZE(self->shared_pairs);
- for (int i = 0; i < num_shared; i++) {
- PyObject *shared_tuple = PyTuple_GET_ITEM(self->shared_pairs, i);
- THPFunction_assert(PyTuple_Check(shared_tuple), "mark_shared_storages "
- "accepts a number of pairs, but one of the arguments is of type %s",
- THPUtils_typename(shared_tuple));
- THPFunction_assert(PyTuple_GET_SIZE(shared_tuple) == 2,
- "mark_shared_storages accepts pairs, but argument %d is a tuple of "
- "%d elements", i, PyTuple_GET_SIZE(shared_tuple));
-
- // Now we're sure it's really a pair!
- THPVariable *v1, *v2;
- try {
- // NB: According to the documentation, v1 is an input tensor, and v2
- // is an output tensor, but we don't actually check this
- v1 = t2var.at(PyTuple_GET_ITEM(shared_tuple, 0));
- v2 = t2var.at(PyTuple_GET_ITEM(shared_tuple, 1));
- } catch(std::out_of_range &e) {
- // One tuple items wasn't present in t2var, so there are two cases:
- // 1. it's not a tensor
- // 2. it's not an input nor an output
- PyObject *t1 = PyTuple_GET_ITEM(shared_tuple, 0);
- PyObject *t2 = PyTuple_GET_ITEM(shared_tuple, 1);
- THPFunction_assert(THPModule_isTensor(t1) && THPModule_isTensor(t2),
- "mark_shared_storages accepts pairs of tensors, but one of them "
- "contains %s and %s", THPUtils_typename(t1), THPUtils_typename(t2));
- THPFunction_assert(false, "mark_shared_storages only accepts pairs of input "
- "and output tensors, but argument %d doesn't satify this "
- "condition", i);
- }
- v2->cdata.version_counter() = v1->cdata.version_counter();
- }
- // Free .shared_pairs
- Py_DECREF(self->shared_pairs);
- self->shared_pairs = NULL;
+ Py_CLEAR(self->to_save);
}
// Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable)
-static void _mark_non_differentiable(THPFunction *self, t2var_type &t2var)
+static std::unordered_set<PyObject*>
+_parse_non_differentiable(THPFunction *self)
{
- if (!self->non_differentiable) return;
+ std::unordered_set<PyObject*> set;
+ if (!self->non_differentiable) return set;
THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd "
"internal error: non_differentiable attribute is expected to be a "
"tuple but is %s", THPUtils_typename(self->non_differentiable));
Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable);
+ set.reserve(num_nondiff);
for (int i = 0; i < num_nondiff; i++) {
PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i);
- THPVariable *var;
- try {
- var = t2var.at(t);
- THPFunction_assert(var->cdata.grad_fn().get() == &self->cdata,
- "mark_non_differentiable only accepts output tensors, but "
- "argument %d isn't an output", i);
- } catch (std::out_of_range &e) {
- THPFunction_assert(THPModule_isTensor(t), "mark_non_differentiable "
- "only accepts tensor arguments, but got %s", THPUtils_typename(t));
- THPFunction_assert(false, "mark_non_differentiable only accepts function "
- "outputs");
- }
- var->cdata.requires_grad() = false;
+ THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable "
+ "only accepts variable arguments, but got %s", THPUtils_typename(t));
+ set.insert(t);
}
- Py_DECREF(self->non_differentiable);
- self->non_differentiable = NULL;
+ Py_CLEAR(self->non_differentiable);
+ return set;
}
struct UnpackedInput {
- THPObjectPtr tensor_input;
+ THPObjectPtr input_tuple;
variable_list input_vars;
};
struct InputFlags {
- FunctionFlags flags;
+ bool is_executable = false;
+ edge_list next_edges;
THPObjectPtr needs_input_grad;
std::vector<bool> is_variable_input;
};
@@ -592,60 +498,53 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
InputFlags flags;
auto num_args = PyTuple_GET_SIZE(args);
- unpacked.tensor_input = PyTuple_New(num_args);
+ unpacked.input_tuple = PyTuple_New(num_args);
flags.needs_input_grad = PyTuple_New(num_args);
for (int i = 0; i < num_args; i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
- PyObject *new_arg;
bool is_variable = THPVariable_Check(arg);
flags.is_variable_input.push_back(is_variable);
if (!is_variable) {
+ // TODO: remove this code path once Variable and Tensor are merged in Python
if (enforce_variables) {
THPUtils_setError("expected a Variable argument, but got %s",
THPUtils_typename(arg));
throw python_error();
}
- Py_INCREF(arg);
- new_arg = arg;
Py_INCREF(Py_False);
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
} else {
THPVariable* variable = (THPVariable*)arg;
- new_arg = THPVariable_get_data(variable);
unpacked.input_vars.push_back(variable->cdata);
PyObject* needs_grad = variable->cdata.requires_grad() ? Py_True : Py_False;
Py_INCREF(needs_grad);
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
}
- PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg);
+ Py_INCREF(arg);
+ PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
}
- flags.flags = Function::flags(unpacked.input_vars);
+ flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
+ flags.next_edges = collect_next_edges(unpacked.input_vars);
return std::make_pair(std::move(unpacked), std::move(flags));
}
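// [Sketch, not part of this patch] The executability rule that unpack_input
// above applies: a backward node is only recorded when grad mode is enabled
// and at least one input requires grad. Input and the thread-local flag are
// stand-ins for Variable and GradMode.
#include <vector>

struct Input {
  bool requires_grad;
};

thread_local bool grad_mode_enabled = true;

bool any_input_requires_grad(const std::vector<Input>& inputs) {
  for (const auto& in : inputs) {
    if (in.requires_grad) return true;
  }
  return false;
}

bool is_executable(const std::vector<Input>& inputs) {
  // Otherwise forward runs without recording any history.
  return grad_mode_enabled && any_input_requires_grad(inputs);
}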
-static void _trace_create(PyObject* op_obj, THPFunction* bw_obj,
- PyObject *input_objects, PyObject *output_objects,
- const variable_list& input_vars, bool is_inplace) {
- if (!tracer::isTracing(input_vars))
- return;
-
- if (!op_obj) {
+static void _assert_not_tracing(const char* name, const variable_list& input_vars) {
+ if (tracer::isTracingVar(input_vars)) {
std::ostringstream oss;
- oss << "Attempted to trace " << Py_TYPE(bw_obj)->tp_name;
+ oss << "Attempted to trace " << name;
oss << ", but tracing of legacy functions is not supported";
throw std::runtime_error(oss.str());
}
+}
- auto tracing_state = tracer::getTracingState(input_vars);
- bw_obj->is_traced = true;
-
- // Isolate C variable ptrs in a vector
- variable_list output_vars;
- for (int i = 0; i < PyTuple_GET_SIZE(output_objects); ++i) {
- THPVariable *var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i);
- output_vars.emplace_back(var->cdata);
+static jit::tracer::PreTraceInfo _trace_pre_record(
+ PyObject* op_obj,
+ PyObject *input_objects,
+ const variable_list& input_vars) {
+ if (!tracer::isTracingVar(input_vars)) {
+ return jit::tracer::PreTraceInfo();
}
// Save scalar args and the calling convention
@@ -665,49 +564,37 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj,
}
}
- auto state_lock = tracing_state->lock();
-
- // Note [getValueTrace can allocate nodes]
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // When an input variable is not traced, we create a constant instruction
- // to represent it. This means that you must invoke getValueTrace() BEFORE
- // actually constructing the function that takes these variables as inputs.
- // If we do it the other order, the graph will be in the wrong topological
- // order.
-
- // See Note [getValueTrace can allocate nodes]
- std::vector<Node*> value_traces;
- value_traces.reserve(input_vars.size());
- for (auto& i : input_vars)
- value_traces.emplace_back(tracer::getValueTrace(tracing_state, i));
+ Py_INCREF(op_obj);
+ auto pyobj = THPObjectPtr(op_obj);
+ return jit::tracer::preRecordPythonTrace(
+ std::move(pyobj),
+ std::move(arg_types),
+ input_vars,
+ std::move(scalar_args));
+}
- // NB: this function is called only from THPFunction_apply, which is used only
- // when computing forward. All these functions are non-traceable by definition,
- // because they are implemented in terms of tensor operations. Hence, there's no
- // need for any conditionals in here and we can always create the node.
+static void _trace_post_record(
+ const jit::tracer::PreTraceInfo& trace_info,
+ PyObject* op_obj,
+ const variable_list& input_vars,
+ PyObject *output_objects,
+ bool is_inplace) {
+ if (!trace_info.state) {
+ return;
+ }
- // Construct the IR Node and its Selects
- Py_INCREF(op_obj);
- auto& graph = tracing_state->graph;
- auto this_expr = graph->appendNode(graph->createPythonOp(
- THPObjectPtr(op_obj),
- arg_types,
- false, // TODO: remove is_legacy
- std::move(scalar_args)));
- for (auto t : value_traces)
- this_expr->addInput(t);
-
- int num_outputs = output_vars.size();
+ // Isolate C variable ptrs in a vector
+ int num_outputs = PyTuple_GET_SIZE(output_objects);
+ variable_list output_vars(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
- auto& output = output_vars[i];
- // NOTE: normally we don't add Select nodes when there's only a single
- // output, but Python nodes can't be optimized away, so we simplify the
- // code here.
- Node* sel = graph->appendNode(graph->createSelect(this_expr, i));
- sel->inferTypeFrom(output.data());
- tracer::setValueTrace(tracing_state, output, sel);
+ auto var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i);
+ output_vars[i] = var->cdata;
}
- this_expr->i_(kinplace, is_inplace);
+
+ jit::tracer::postRecordTrace(trace_info, output_vars);
+
+ auto state_lock = trace_info.state->lock();
+ trace_info.n->i_(kinplace, is_inplace);
// See definition in function.cpp.
THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};
@@ -717,12 +604,15 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj,
// tracing_state->in_eval_subgraph (it's always false, because they are never part of backward
// subgraphs AND we don't even materialize the forward function).
if (!passes_state_transparently) {
+ // TODO: sgross and ezyang don't know if this is right
tracer::nontraceableBackwardSubgraph(input_vars, output_vars);
- Function::setUpContextEdge(this_expr, num_outputs, input_vars, output_vars);
+ Function::set_up_context_edge(trace_info.n, input_vars, output_vars);
}
}
-PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked, PyObject *inputs, THPObjectPtr&& raw_output, bool is_volatile) {
+PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked,
+ PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable,
+ const jit::tracer::PreTraceInfo& trace_info) {
bool unpack_output = ensure_tuple(raw_output);
auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
@@ -730,10 +620,10 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
THPObjectPtr outputs(PyTuple_New(num_outputs));
if (!outputs) throw python_error();
- grad_fn->cdata.num_inputs = num_outputs;
+ grad_fn->cdata.set_num_inputs(num_outputs);
// Record type, device, and size information about inputs
- if (grad_fn->cdata.is_executable) {
+ if (is_executable) {
grad_fn->input_info.clear();
grad_fn->input_info.reserve(unpacked.input_vars.size());
for (auto& var : unpacked.input_vars) {
@@ -741,36 +631,22 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
}
}
- // Initialize t2var map with input tensors
- t2var_type t2var;
- for (auto& c_var : unpacked.input_vars) {
- THPVariable* py_var = (THPVariable*)c_var.get()->pyobj;
- t2var.emplace(py_var->data, py_var);
- }
-
- std::unordered_set<PyObject *> dirty_inputs;
bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
- _mark_dirty(grad_fn, t2var, dirty_inputs);
- _wrap_outputs(grad_fn, t2var, dirty_inputs, raw_output, outputs, is_volatile);
- // At this point, t2var contains output tensors as well
- _join_version_counters(grad_fn, t2var);
- if (grad_fn->cdata.is_executable) {
- _mark_non_differentiable(grad_fn, t2var);
- }
- // NOTE: _trace_create has to run before _save_variables, because we need
+ _wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable);
+ // NOTE: _trace_post_record has to run before _save_variables, because we need
// to assign traces to outputs before we convert them to SavedVariables.
// On the other hand, it needs to go after _mark_non_differentiable, because
// it might be wrapping backwards in Evals, and _mark_non_differentiable uses
// grad_fn pointer equality for error checking.
- _trace_create(op_obj, grad_fn, inputs, outputs, unpacked.input_vars, is_inplace);
- if (grad_fn->cdata.is_executable) {
- _save_variables(grad_fn, t2var);
+ _trace_post_record(trace_info, op_obj, unpacked.input_vars, outputs, is_inplace);
+ if (is_executable) {
+ _save_variables(grad_fn);
} else {
// Remove unnecessary attributes
Py_XDECREF(grad_fn->to_save);
- grad_fn->to_save = NULL;
+ grad_fn->to_save = nullptr;
Py_XDECREF(grad_fn->non_differentiable);
- grad_fn->non_differentiable = NULL;
+ grad_fn->non_differentiable = nullptr;
}
// Unpack the output, unless .forward() returned a tuple
@@ -792,17 +668,25 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
auto info_pair = unpack_input<true>(_inputs);
auto& unpacked_input = info_pair.first;
auto& input_info = info_pair.second;
- bool is_volatile = input_info.flags.is_volatile;
- self->cdata.set_flags(std::move(input_info.flags));
+ bool is_executable = input_info.is_executable;
+ self->cdata.set_next_edges(std::move(input_info.next_edges));
self->needs_input_grad = input_info.needs_input_grad.release();
+ // We don't support tracing in the legacy code path
+ _assert_not_tracing(Py_TYPE(self)->tp_name, unpacked_input.input_vars);
+
// Now we're ready to call a forward (implemented in Python)
- THPObjectPtr forward_fn(PyObject_GetAttrString((PyObject*)self, "forward"));
- if (!forward_fn) return NULL;
- THPObjectPtr raw_output(PyObject_CallObject(forward_fn, unpacked_input.tensor_input));
- if (!raw_output) return NULL;
+ THPObjectPtr raw_output;
+ {
+ AutoGradMode grad_mode(false);
+ THPObjectPtr forward_fn(PyObject_GetAttrString((PyObject*)self, "forward"));
+ if (!forward_fn) return nullptr;
+ raw_output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
+ if (!raw_output) return nullptr;
+ }
- return process_outputs(nullptr, self, unpacked_input, _inputs, std::move(raw_output), is_volatile);
+ return process_outputs(nullptr, self, unpacked_input, _inputs, std::move(raw_output),
+ is_executable, jit::tracer::PreTraceInfo());
END_HANDLE_TH_ERRORS
}
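// [Sketch, not part of this patch] The RAII guard pattern used above: grad
// mode is switched off only for the duration of the Python forward() call and
// restored on every exit path, including exceptions. The flag is a stand-in
// for the real GradMode state.
thread_local bool grad_enabled = true;

struct AutoGradModeSketch {
  explicit AutoGradModeSketch(bool enabled) : prev_(grad_enabled) {
    grad_enabled = enabled;
  }
  ~AutoGradModeSketch() { grad_enabled = prev_; }
 private:
  bool prev_;
};

void call_forward_without_grad() {
  AutoGradModeSketch guard(false);  // history is not recorded in this scope
  // ... invoke the user's forward() here ...
}  // the previous grad mode is restored here, even if forward() throws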
@@ -812,9 +696,9 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
torch::autograd::profiler::RecordFunction record(((PyTypeObject*)cls)->tp_name);
THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
- if (!backward_cls) return NULL;
- THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, NULL));
- if (!ctx_obj) return NULL;
+ if (!backward_cls) return nullptr;
+ THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
+ if (!ctx_obj) return nullptr;
THPFunction* ctx = (THPFunction*)ctx_obj.get();
// Prepare inputs and allocate context (grad fn)
@@ -822,32 +706,41 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
UnpackedInput& unpacked_input = info_pair.first;
InputFlags& input_info = info_pair.second;
+ // Record input nodes if tracing
+ auto trace_info = _trace_pre_record(cls, inputs, unpacked_input.input_vars);
+ if (trace_info.state) {
+ // TODO: ezyang suggests this is unused and can be removed
+ ctx->is_traced = true;
+ }
+
// Initialize backward function (and ctx)
- bool is_volatile = input_info.flags.is_volatile;
- ctx->cdata.set_flags(std::move(input_info.flags));
+ bool is_executable = input_info.is_executable;
+ ctx->cdata.set_next_edges(std::move(input_info.next_edges));
ctx->needs_input_grad = input_info.needs_input_grad.release();
ctx->is_variable_input = std::move(input_info.is_variable_input);
- // Prepend ctx to tensor_input, in preparation for static method call
+ // Prepend ctx to input_tuple, in preparation for static method call
auto num_args = PyTuple_GET_SIZE(inputs);
- THPObjectPtr ctx_tensor_input(PyTuple_New(num_args + 1));
- PyTuple_SET_ITEM(ctx_tensor_input.get(), 0, ctx_obj.release());
+ THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
+ PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, ctx_obj.release());
for (int i = 0; i < num_args; ++i) {
- PyObject *arg = PyTuple_GET_ITEM(unpacked_input.tensor_input.get(), i);
+ PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
Py_INCREF(arg);
- PyTuple_SET_ITEM(ctx_tensor_input.get(), i + 1, arg);
+ PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
}
// Call forward
- THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
- if (!forward_fn) return NULL;
- THPObjectPtr tensor_outputs(PyObject_CallObject(forward_fn, ctx_tensor_input));
- if (!tensor_outputs) return NULL;
-
- THPObjectPtr outputs {process_outputs(cls, ctx, unpacked_input, inputs,
- std::move(tensor_outputs), is_volatile)};
+ THPObjectPtr tensor_outputs;
+ {
+ AutoGradMode grad_mode(false);
+ THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
+ if (!forward_fn) return nullptr;
+ tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
+ if (!tensor_outputs) return nullptr;
+ }
- return outputs.release();
+ return process_outputs(cls, ctx, unpacked_input, inputs, std::move(tensor_outputs),
+ is_executable, trace_info);
END_HANDLE_TH_ERRORS
}
@@ -856,51 +749,52 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
// Backward
////////////////////////////////////////////////////////////////////////////////
-static void _prepare_grad_output(THPFunction *self, THPObjectPtr& raw_grad_output)
+static void _prepare_grads(THPFunction *self, THPObjectPtr& raw_grads, bool is_grad_output)
{
AutoGPU gpu_guard(-1);
- int num_grad_output = PyTuple_GET_SIZE(raw_grad_output.get());
- // First, check if any of grad_outputs is None. If not, there's nothing to do
+ int num_grads = PyTuple_GET_SIZE(raw_grads.get());
+ // First, check if any of grads is None. If not, there's nothing to do
bool has_none = false;
- for (int i = 0; i < num_grad_output; i++) {
- has_none |= PyTuple_GET_ITEM(raw_grad_output.get(), i) == Py_None;
+ for (int i = 0; i < num_grads; i++) {
+ has_none |= PyTuple_GET_ITEM(raw_grads.get(), i) == Py_None;
}
if (!has_none)
return;
- THPObjectPtr grad_output;
- grad_output = PyTuple_New(num_grad_output);
- if (!grad_output) throw python_error();
+ THPObjectPtr grads;
+ grads = PyTuple_New(num_grads);
+ if (!grads) throw python_error();
// Look for Nones and replace them with new buffers
- auto& output_info = self->output_info;
- for (int i = 0; i < num_grad_output; i++) {
- PyObject *grad = PyTuple_GET_ITEM(raw_grad_output.get(), i);
+ auto& grads_info = is_grad_output ? self->output_info : self->input_info;
+ TORCH_ASSERT(grads_info.size() == (size_t)num_grads);
+ for (int i = 0; i < num_grads; i++) {
+ PyObject *grad = PyTuple_GET_ITEM(raw_grads.get(), i);
if (grad == Py_None) {
- grad = createPyObject(output_info[i].zeros(gpu_guard).data());
+ grad = THPVariable_Wrap(grads_info[i].zeros(gpu_guard));
if (!grad) throw python_error();
} else {
Py_INCREF(grad);
}
- PyTuple_SET_ITEM(grad_output.get(), i, grad);
+ PyTuple_SET_ITEM(grads.get(), i, grad);
}
- raw_grad_output = grad_output.release();
+ raw_grads = grads.release();
}
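// [Sketch, not part of this patch] The idea behind _prepare_grads above: at
// forward time, per-output metadata (here just the shape) is recorded, and at
// backward time any gradient that was left out is materialized as zeros of
// the recorded shape. std::optional stands in for a Py_None entry and
// std::vector<double> for a tensor.
#include <cstdint>
#include <optional>
#include <vector>

struct GradInfo {
  std::vector<int64_t> sizes;  // recorded when the forward output was produced
};

using Grad = std::vector<double>;

std::vector<Grad> prepare_grads(const std::vector<std::optional<Grad>>& raw,
                                const std::vector<GradInfo>& info) {
  std::vector<Grad> out;
  out.reserve(raw.size());
  for (size_t i = 0; i < raw.size(); ++i) {
    if (raw[i]) {
      out.push_back(*raw[i]);
    } else {
      // Missing gradient: allocate a zero buffer matching the recorded size.
      size_t numel = 1;
      for (auto s : info[i].sizes) numel *= static_cast<size_t>(s);
      out.emplace_back(numel, 0.0);
    }
  }
  return out;
}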
static void _trim_grad_input(THPFunction *self, THPObjectPtr& grad_input)
{
int num_grads = PyTuple_GET_SIZE(grad_input.get());
- int num_next_fns = self->cdata.next_functions.size();
- if (num_grads > num_next_fns) {
+ const int num_outputs = self->cdata.num_outputs();
+ if (num_grads > num_outputs) {
// Check that all extra grads are none
bool all_none = true;
- for (int i = num_next_fns; i < num_grads; i++) {
+ for (int i = num_outputs; i < num_grads; i++) {
all_none = (PyTuple_GET_ITEM(grad_input.get(), i) == Py_None);
if (!all_none) break;
}
// If yes, slice the tuple
if (all_none) {
- num_grads = num_next_fns;
+ num_grads = num_outputs;
grad_input = PyTuple_GetSlice(grad_input.get(), 0, num_grads);
if (!grad_input) throw python_error();
}
@@ -915,44 +809,46 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
PyObject *raw_grad_output = PyTuple_GET_ITEM(args, 0);
PyObject *retain_variables = PyTuple_GET_ITEM(args, 1);
if (!PyTuple_Check(raw_grad_output) || !PyBool_Check(retain_variables)) {
- THPUtils_invalidArguments(args, NULL, "_do_backward", 1, "(tuple, bool)");
- return NULL;
+ THPUtils_invalidArguments(args, nullptr, "_do_backward", 1, "(tuple, bool)");
+ return nullptr;
}
- THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs,
+ THPUtils_assert(PyTuple_GET_SIZE(raw_grad_output) == self->cdata.num_inputs(),
"%s got an invalid number of gradients (expected %d got %d)",
- THPUtils_typename(self), self->cdata.num_inputs,
+ THPUtils_typename(self), self->cdata.num_inputs(),
PyTuple_GET_SIZE(raw_grad_output));
// Some of the output might have been unused, so we have to allocate
// zero-filled buffers instead
Py_INCREF(raw_grad_output);
THPObjectPtr grad_output(raw_grad_output);
- _prepare_grad_output(self, grad_output);
+ _prepare_grads(self, grad_output, true);
// self.backward(*grad_output)
THPObjectPtr backward_fn(PyObject_GetAttrString((PyObject*)self, "backward"));
THPUtils_assert(backward_fn.get(), "function %s doesn't implement a required "
"'backward' method", THPUtils_typename((PyObject*)self));
THPObjectPtr grad_input(PyObject_CallObject(backward_fn, grad_output.get()));
- if (!grad_input) return NULL;
+ if (!grad_input) return nullptr;
ensure_tuple(grad_input);
// We allow functions to return more gradients, than there were outputs,
// if and only if the additional ones are all None
_trim_grad_input(self, grad_input);
int num_grads = PyTuple_GET_SIZE(grad_input.get());
- int num_next_fns = self->cdata.next_functions.size();
- THPUtils_assert(num_grads == num_next_fns, "%s returned an invalid number of "
+ int num_outputs = self->cdata.num_outputs();
+ THPUtils_assert(num_grads == num_outputs, "%s returned an invalid number of "
"gradient tensors (expected %d, but got %d)", THPUtils_typename(self),
- num_next_fns, num_grads);
+ num_outputs, num_grads);
+ // If any of the remaining grad_inputs are None, zero them.
+ _prepare_grads(self, grad_input, false);
return grad_input.release();
} catch (python_error& e) {
- return NULL;
+ return nullptr;
} catch (std::exception& e) {
THPUtils_setError(e.what());
- return NULL;
+ return nullptr;
}
}
@@ -964,7 +860,9 @@ PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var)
{
THPUtils_assert(THPVariable_Check(_var), "_register_hook_dict expected a variable");
THPVariable *var = (THPVariable*)_var;
- self->cdata.pre_hooks.emplace_back(new PyFunctionPreHook(var->backward_hooks, var->cdata.output_nr()));
+ std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(
+ var->backward_hooks, var->cdata.output_nr()));
+ self->cdata.add_pre_hook(std::move(hook));
Py_RETURN_NONE;
}
@@ -985,7 +883,7 @@ static PyObject *unpack_saved_variables(
int num_saved = saved_variables.size();
THPObjectPtr saved(PyTuple_New(num_saved));
if (!saved)
- return NULL;
+ return nullptr;
auto saved_for = THPFunction_asFunction(self);
for (int i = 0; i < num_saved; i++) {
auto unpacked_var = saved_variables[i].unpack(saved_for);
@@ -1003,32 +901,36 @@ static PyObject *unpack_saved_variables(
PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused)
{
+ HANDLE_TH_ERRORS
return unpack_saved_variables(self, [](const Variable& var) {
- return createPyObject(var.data());
+ return THPVariable_Wrap(var);
});
+ END_HANDLE_TH_ERRORS
}
PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused)
{
+ HANDLE_TH_ERRORS
return unpack_saved_variables(self, [](const Variable& var) {
return THPVariable_Wrap(var);
});
+ END_HANDLE_TH_ERRORS
}
PyObject *THPFunction_next_functions(THPFunction *self, void *_unused)
{
- auto& next_fns = self->cdata.next_functions;
- int size = next_fns.size();
- THPObjectPtr result(PyTuple_New(size));
+ const auto num_outputs = self->cdata.num_outputs();
+ THPObjectPtr result(PyTuple_New(num_outputs));
if (!result)
- return NULL;
- for (int i = 0; i < size; i++) {
+ return nullptr;
+ for (uint32_t i = 0; i < num_outputs; i++) {
THPObjectPtr fn_tuple(PyTuple_New(2));
- if (!fn_tuple) return NULL;
- PyObject* fn = functionToPyObject(next_fns[i].first);
- if (!fn) return NULL;
+ if (!fn_tuple) return nullptr;
+ const auto& edge = self->cdata.next_edge(i);
+ PyObject* fn = functionToPyObject(edge.function);
+ if (!fn) return nullptr;
PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
- PyTuple_SET_ITEM(fn_tuple.get(), 1, PyInt_FromLong(next_fns[i].second));
+ PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr));
PyTuple_SET_ITEM(result.get(), i, fn_tuple.release());
}
return result.release();
@@ -1075,43 +977,36 @@ PyObject* getImplMember(PyObject* obj, void* _unused) {
return Convert(self->cdata.*ptr);
}
-int setRequiresGrad(PyObject* obj, PyObject* value, void* _unused) {
- auto self = (THPFunction*)obj;
- if (!PyBool_Check(value)) {
- PyErr_Format(PyExc_TypeError, "'is_executable' must be a bool");
- return -1;
- }
- self->cdata.is_executable = (value == Py_True);
- return 0;
+PyObject* getRequiresGrad(PyObject* obj, void* _unused) {
+ Py_RETURN_TRUE;
}
}
static struct PyGetSetDef THPFunction_properties[] = {
- {"saved_tensors", (getter)THPFunction_saved_tensors, NULL, NULL, NULL},
- {"saved_variables", (getter)THPFunction_saved_variables, NULL, NULL, NULL},
- {"next_functions", (getter)THPFunction_next_functions, NULL, NULL, NULL},
- {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, NULL, NULL},
- {"shared_pairs", &getObject<&THPFunction::shared_pairs>, &setObject<&THPFunction::shared_pairs>, NULL, NULL},
- {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, NULL, NULL},
- {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, NULL, NULL},
- {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, NULL, NULL, NULL},
- {"requires_grad", &getImplMember<bool, &Function::is_executable, PyBool_FromLong>, &setRequiresGrad, NULL, NULL},
- {"_is_tracing", &getMember<char, &THPFunction::is_traced, PyBool_FromLong>, NULL, NULL, NULL},
- {NULL}
+ {"saved_tensors", (getter)THPFunction_saved_tensors, nullptr, nullptr, nullptr},
+ {"saved_variables", (getter)THPFunction_saved_variables, nullptr, nullptr, nullptr},
+ {"next_functions", (getter)THPFunction_next_functions, nullptr, nullptr, nullptr},
+ {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, nullptr, nullptr},
+ {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, nullptr, nullptr},
+ {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr},
+ {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr},
+ {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
+ {"_is_tracing", &getMember<char, &THPFunction::is_traced, PyBool_FromLong>, nullptr, nullptr, nullptr},
+ {nullptr}
};
static struct PyMethodDef THPFunction_methods[] = {
- {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, NULL},
- {(char*)"_do_forward", (PyCFunction)THPFunction_do_forward, METH_VARARGS, NULL},
- {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, NULL},
- {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, NULL},
- {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, NULL},
- {NULL}
+ {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
+ {(char*)"_do_forward", (PyCFunction)THPFunction_do_forward, METH_VARARGS, nullptr},
+ {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, nullptr},
+ {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, nullptr},
+ {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, nullptr},
+ {nullptr}
};
PyTypeObject THPFunctionType = {
- PyVarObject_HEAD_INIT(NULL, 0)
+ PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._FunctionBase", /* tp_name */
sizeof(THPFunction), /* tp_basicsize */
0, /* tp_itemsize */
@@ -1131,7 +1026,7 @@ PyTypeObject THPFunctionType = {
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
- NULL, /* tp_doc */
+ nullptr, /* tp_doc */
(traverseproc)THPFunction_traverse, /* tp_traverse */
(inquiry)THPFunction_clear, /* tp_clear */
0, /* tp_richcompare */
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 837b0add..72e894a3 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -34,9 +34,9 @@ struct PyFunction : public Function {
virtual variable_list apply(const variable_list& inputs) override;
variable_list legacy_apply(const variable_list& inputs);
- virtual void releaseVariables() override;
+ virtual void release_variables() override;
virtual std::string name() override;
- virtual std::shared_ptr<Function> getSharedPtr() override;
+ virtual std::shared_ptr<Function> get_shared_ptr() override;
virtual bool is_traceable() override;
// THPFunction this Function is wrapping.
@@ -66,19 +66,15 @@ struct THPFunction {
PyObject *needs_input_grad;
// Python tuple of tensors whose variables we should save. Set
- // by Python with 'save_for_backward'. If NULL, no tensors were
+ // by Python with 'save_for_backward'. If nullptr, no tensors were
// saved.
PyObject *to_save;
- // Python pairs of distinct tensors which share storage. Set by
- // Python with 'mark_shared_storage'. If NULL, no tensors share
- // storage.
- PyObject *shared_pairs;
// Python tuple of tensors which are not differentiable. Set by
- // Python with 'mark_non_differentiable'. If NULL, no tensors were
+ // Python with 'mark_non_differentiable'. If nullptr, no tensors were
// non-differentiable.
PyObject *non_differentiable;
// Python tuple of tensors which had inplace updates in the forward()
- // pass. Set by Python with 'mark_dirty'. If NULL, no tensors were
+ // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were
// modified inplace.
PyObject *dirty_tensors;
@@ -98,8 +94,6 @@ struct THPFunction {
bool THPFunction_initModule(PyObject *module);
extern PyTypeObject THPFunctionType;
extern PyObject *THPFunctionClass;
-extern PyObject *THPStochasticFunctionClass;
-extern PyObject *THPBatchNormBackwardBackwardFunction; // Temporarily here until we move it to C++
// XXX: this function requires the GIL (it can have side effects).
std::shared_ptr<torch::autograd::PyFunction> THPFunction_asFunction(THPFunction* self);
diff --git a/torch/csrc/autograd/python_hook.cpp b/torch/csrc/autograd/python_hook.cpp
index c87669aa..3ceb1f4a 100644
--- a/torch/csrc/autograd/python_hook.cpp
+++ b/torch/csrc/autograd/python_hook.cpp
@@ -169,11 +169,11 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject
throw std::runtime_error(ss.str());
}
- if (original.type().isCuda() != result.type().isCuda()) {
+ if (original.type().is_cuda() != result.type().is_cuda()) {
std::stringstream ss;
auto name = hook_name(hook);
ss << "hook '" << name << "' has changed the type of value";
- if (original.type().isCuda()) {
+ if (original.type().is_cuda()) {
ss << " (was CUDA tensor got CPU tensor)";
} else {
ss << " (was CPU tensor got CUDA tensor)";
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 40321519..ecb59008 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -1,23 +1,38 @@
#include "torch/csrc/autograd/python_variable.h"
-#include <structmember.h>
-
#include "THP.h"
#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/Size.h"
#include "torch/csrc/Types.h"
+#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/autograd/python_hook.h"
+#include "torch/csrc/autograd/python_variable_indexing.h"
+#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/functions/accumulate_grad.h"
-#include "torch/csrc/cuda/AutoGPU.h"
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/generated/VariableType.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/jit/tracer_state.h"
+#include "torch/csrc/tensor/python_tensor.h"
#include "torch/csrc/utils/auto_gil.h"
-#include "torch/csrc/Exceptions.h"
-#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/utils/python_strings.h"
+
+#include <ATen/ATen.h>
+#include <list>
+#include <memory>
+#include <structmember.h>
using namespace at;
using namespace torch::autograd;
-PyObject *THPVariableClass = NULL;
+PyObject *THPVariableClass = nullptr;
+
+static const char* VOLATILE_WARNING =
+ "volatile was removed and now has no effect. Use "
+ "`with torch.no_grad():` instead.";
// Creates a new Python object for a Variable. The Variable must not already
// have a PyObject* associated with it.
@@ -27,11 +42,13 @@ static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var)
if (obj) {
auto v = (THPVariable*) obj;
new (&v->cdata) Variable(std::move(var));
- v->cdata.get()->pyobj = obj;
- if (auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn().get())) {
+ v->cdata.set_pyobj(obj);
+ if (auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn_unsafe())) {
// Create a new reference to the THPFunction. This ensures that ref count
// of the THPFunction is at least the number of referring THPVariables.
- v->cdata.grad_fn() = THPFunction_asFunction((THPFunction*)fn->obj);
+ const auto output_nr = v->cdata.output_nr();
+ auto grad_fn = THPFunction_asFunction((THPFunction*)fn->obj);
+ v->cdata.set_gradient_edge({std::move(grad_fn), output_nr});
}
}
return obj;
@@ -43,11 +60,7 @@ PyObject * THPVariable_Wrap(Variable var)
Py_RETURN_NONE;
}
- if (var.dim() == 0) {
- throw std::runtime_error("Variable API does not support Scalars");
- }
-
- if (auto obj = var.get()->pyobj) {
+ if (auto obj = var.pyobj()) {
Py_INCREF(obj);
return obj;
}
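// [Sketch, not part of this patch] The wrapper-caching pattern used in
// THPVariable_Wrap above: the C++ object keeps a back-pointer to its Python
// wrapper, so wrapping it again returns the cached wrapper with its reference
// count bumped instead of allocating a second one. Wrapper and the manual
// refcount are stand-ins for PyObject and Py_INCREF.
struct Wrapper {
  int refcount = 1;
};

struct CppObject {
  Wrapper* wrapper = nullptr;  // back-pointer, reset when the wrapper is destroyed
};

Wrapper* wrap(CppObject& obj) {
  if (obj.wrapper) {
    ++obj.wrapper->refcount;    // reuse the existing wrapper
    return obj.wrapper;
  }
  obj.wrapper = new Wrapper();  // first wrap: allocate and cache
  return obj.wrapper;
}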
@@ -55,50 +68,8 @@ PyObject * THPVariable_Wrap(Variable var)
return THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, std::move(var));
}
-// This function DOES NOT steal a reference to data
-PyObject * THPVariable_NewWithFunction(PyObject *data, const std::shared_ptr<torch::autograd::Function>& grad_fn)
-{
- THPUtils_assert(THPModule_isTensor(data), "data must be a Tensor");
-
- Variable v = make_variable(torch::createTensor(data));
- v.requires_grad() = grad_fn->is_executable;
- v.grad_fn() = grad_fn;
-
- PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v));
- if (obj) {
- ((THPVariable*)obj)->data = data;
- Py_INCREF(data);
- }
- return obj;
-}
-
-// This function DOES NOT steal a reference to data
-PyObject * THPVariable_NewVolatile(PyObject *data)
-{
- Variable v = make_variable(torch::createTensor(data), false, true);
- PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v));
- if (obj) {
- ((THPVariable*)obj)->data = data;
- Py_INCREF(data);
- }
- return obj;
-}
-
-// This function DOES NOT steal a reference to data
-PyObject * THPVariable_NewLeaf(PyObject *data)
-{
- Variable v = make_variable(torch::createTensor(data));
- PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, std::move(v));
- if (obj) {
- ((THPVariable*)obj)->data = data;
- Py_INCREF(data);
- }
- return obj;
-}
-
static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
{
- Py_VISIT(self->data);
Py_VISIT(self->backward_hooks);
// We don't want to traverse the grad_fn, even if the Variable owns it and the
// shared pointer's use count is 1. This is because we would need to treat
@@ -115,7 +86,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
// for more details about the race condition involving traversing the grad_fn
// and the python GC.
if (self->cdata.defined()) {
- for (auto& hook : self->cdata.hooks()) {
+ for (const auto& hook : self->cdata.hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
@@ -126,13 +97,12 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
static int THPVariable_clear(THPVariable *self)
{
- Py_CLEAR(self->data);
Py_CLEAR(self->backward_hooks);
if (self->cdata.defined()) {
- if (auto grad_acc = self->cdata.get()->grad_accumulator.lock()) {
- grad_acc->pre_hooks.clear();
+ if (auto grad_acc = self->cdata.try_get_grad_accumulator()) {
+ grad_acc->pre_hooks().clear();
}
- self->cdata.get()->pyobj = nullptr;
+ self->cdata.set_pyobj(nullptr);
}
self->cdata.reset();
return 0;
@@ -148,25 +118,24 @@ static void THPVariable_dealloc(THPVariable* self)
PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
+ HANDLE_TH_ERRORS
THPObjectPtr _data;
- PyObject *data = NULL;
- PyObject *grad_fn = NULL;
+ PyObject *data = nullptr;
+ PyObject *grad_fn = nullptr;
char is_volatile = 0;
char requires_grad = 0;
+ const char* name = nullptr;
- const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", NULL};
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbO", (char**)accepted_args,
- &data, &requires_grad, &is_volatile, &grad_fn))
- return NULL;
+ const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", (char**)accepted_args,
+ &data, &requires_grad, &is_volatile, &grad_fn, &name))
+ return nullptr;
if (grad_fn == Py_None)
- grad_fn = NULL;
+ grad_fn = nullptr;
- if (data == NULL || data == Py_None) {
- // For legacy serialization code, create an empty tensor temporarily.
- at::Tensor tensor = at::CPU(at::kFloat).tensor();
- _data = torch::createPyObject(tensor);
- data = _data.get();
+ if (is_volatile) {
+ PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
}
THPUtils_assert(!(is_volatile && requires_grad),
@@ -174,23 +143,34 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)
THPUtils_assert(!grad_fn || THPFunction_Check(grad_fn),
"Variable _grad_fn has to be a Function object or None, but got %s",
THPUtils_typename(grad_fn));
- THPUtils_assert(THPModule_isTensor(data), "Variable data has to "
- "be a tensor, but got %s", THPUtils_typename(data));
+ Tensor tensor;
+ if (!data || data == Py_None) {
+ // For legacy serialization code, create an empty tensor. This is also used
+ // by nn.Parameter() with no arguments.
+ auto var = torch::tensor::get_default_tensor_type().tensor();
+ tensor = static_cast<Variable&>(var).data();
+ } else if (THPVariable_Check(data)) {
+ tensor = ((THPVariable*)data)->cdata.data();
+ } else {
+ throw torch::TypeError("Variable data has to be a tensor, but got %s",
+ THPUtils_typename(data));
+ }
Variable var;
if (grad_fn) {
auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
- var = make_variable(torch::createTensor(data), grad_fn_);
+ Edge edge(grad_fn_, grad_fn_->bump_inputs());
+ var = make_variable(std::move(tensor), std::move(edge));
} else {
- var = make_variable(torch::createTensor(data), requires_grad, is_volatile);
+ var = make_variable(std::move(tensor), requires_grad);
}
- PyObject* self = THPVariable_NewWithVar(type, std::move(var));
- if (self) {
- ((THPVariable*)self)->data = data;
- Py_INCREF(data);
+ if (name) {
+ var.set_name(name);
}
- return self;
+
+ return THPVariable_NewWithVar(type, std::move(var));
+ END_HANDLE_TH_ERRORS
}
int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds)
@@ -199,13 +179,14 @@ int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds)
// The 'data' argument is optional in __new__ to handle legacy serialized
// Variables.
PyObject *data;
- PyObject *grad_fn = NULL;
+ PyObject *grad_fn = nullptr;
char is_volatile = 0;
char requires_grad = 0;
+ const char* name = nullptr;
- const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", NULL};
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbO", (char**)accepted_args,
- &data, &requires_grad, &is_volatile, &grad_fn))
+ const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", (char**)accepted_args,
+ &data, &requires_grad, &is_volatile, &grad_fn, &name))
return -1;
return 0;
@@ -214,151 +195,144 @@ int THPVariable_pyinit(PyObject *self, PyObject *args, PyObject *kwds)
typedef PyObject *(*getter)(PyObject *, void *);
typedef int (*setter)(PyObject *, PyObject *, void *);
+PyObject *THPVariable_get_cdata(THPVariable *self)
+{
+ HANDLE_TH_ERRORS
+ auto& var = self->cdata;
+ return PyLong_FromVoidPtr(var.unsafeGetTH(false));
+ END_HANDLE_TH_ERRORS
+}
+
PyObject *THPVariable_get_version(THPVariable *self)
{
+ HANDLE_TH_ERRORS
auto& var = self->cdata;
return PyInt_FromLong(var.current_version());
+ END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_get_grad_fn(THPVariable *self)
{
+ HANDLE_TH_ERRORS
auto& var = self->cdata;
if (!var.grad_fn()) {
Py_RETURN_NONE;
}
return functionToPyObject(var.grad_fn());
+ END_HANDLE_TH_ERRORS
}
-int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj)
+static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj)
{
+ HANDLE_TH_ERRORS
THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None");
- self->cdata.grad_fn() = nullptr;
+ self->cdata.detach_();
return 0;
+ END_HANDLE_TH_ERRORS_RET(-1)
}
-PyObject *THPVariable_is_leaf(THPVariable *self)
+static PyObject *THPVariable_is_leaf(THPVariable *self)
{
+ HANDLE_TH_ERRORS
return PyBool_FromLong(!self->cdata.grad_fn());
+ END_HANDLE_TH_ERRORS
}
-PyObject * THPVariable_get_data(THPVariable *self)
+static PyObject * THPVariable_get_data(THPVariable *self)
{
- if (!self->data) {
- self->data = torch::createPyObject(self->cdata.data());
- }
- Py_XINCREF(self->data);
- return self->data;
-}
-
-namespace {
-
-// XXX: This is a hack to access private TensorImpl::type_
-// http://bloglitb.blogspot.com/2011/12/access-to-private-members-safer.html
-// This is currently needed because module.float() changes the type of the
-// data field of each variable. We should fix this and not allow changing the
-// type of var.data.
-
-template<typename Tag, typename Tag::type M>
-struct Rob {
- friend typename Tag::type get(Tag) {
- return M;
- }
-};
-
-struct TensorImpl_Type {
- typedef Type* TensorImpl::*type;
- friend type get(TensorImpl_Type);
-};
-
-template struct Rob<TensorImpl_Type, &TensorImpl::type_>;
-
+ HANDLE_TH_ERRORS
+ return THPVariable_Wrap(make_variable(self->cdata.data(), false));
+ END_HANDLE_TH_ERRORS
}
int THPVariable_set_data(THPVariable *self, PyObject *data)
{
- THPUtils_assertRet(-1, THPModule_isTensor(data), "Variable data has to "
- "be a tensor, but got %s", THPUtils_typename(data));
- Py_INCREF(data);
- Py_XDECREF(self->data);
- self->data = data;
- Tensor tensor = torch::createTensor(data);
- if (&self->cdata.data().type() != &tensor.type()) {
+ HANDLE_TH_ERRORS
+ if (!THPVariable_Check(data)) {
+ throw torch::TypeError("Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
+ }
+ Tensor tensor = THPVariable_UnpackData(data);
+ if (self->cdata.data().type() != tensor.type()) {
// we change the type of var.data so we must change the type of var
- auto newType = VariableImpl::getType(tensor);
- self->cdata.get()->*get(TensorImpl_Type()) = newType;
+ auto newType = VariableType::getType(tensor);
+ self->cdata.temporary_hack_set_type(newType);
}
- self->cdata.data() = tensor;
+ self->cdata.data() = std::move(tensor);
return 0;
+ END_HANDLE_TH_ERRORS_RET(-1)
}
PyObject *THPVariable_get_grad(THPVariable *self)
{
+ HANDLE_TH_ERRORS
return THPVariable_Wrap(self->cdata.grad());
+ END_HANDLE_TH_ERRORS
}
-int THPVariable_set_grad(THPVariable *self, PyObject *other)
+int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
{
+ HANDLE_TH_ERRORS
auto& var = self->cdata;
- if (other == Py_None) {
- var.grad().reset();
+ if (py_grad == Py_None) {
+ var.reset_grad();
return 0;
}
- THPUtils_assertRet(-1, THPVariable_Check(other),
- "expected Variable or None (got %s)", THPUtils_typename(other));
- THPUtils_assertRet(-1, self != (THPVariable*)other,
+ THPUtils_assertRet(-1, THPVariable_Check(py_grad),
+ "expected Variable or None (got %s)", THPUtils_typename(py_grad));
+ THPUtils_assertRet(-1, self != (THPVariable*)py_grad,
"can't assign Variable as its own grad");
- auto& data = var.data();
- auto& other_var = ((THPVariable*)other)->cdata;
- auto& other_data = other_var.data();
+ auto& grad = ((THPVariable*)py_grad)->cdata;
+ auto& sparseType = var.type().toBackend(var.is_cuda() ? kSparseCUDA : kSparseCPU);
- // Make sure the data is ok
- THPUtils_assertRet(-1, other_data.type().ID() == data.type().ID(),
+ THPUtils_assertRet(-1, grad.type() == var.type() || grad.type() == sparseType,
"assigned grad has data of a different type");
- THPUtils_assertRet(-1, other_data.type().isCuda() == data.type().isCuda(),
- "assigned grad has data located on a different device");
- if (data.type().isCuda()) {
- THPUtils_assertRet(-1, other_data.get_device() == data.get_device(),
+ if (var.type().is_cuda()) {
+ THPUtils_assertRet(-1, grad.get_device() == var.get_device(),
"assigned grad has data located on a different device");
}
- THPUtils_assertRet(-1, other_data.sizes().vec() == data.sizes().vec(),
+ THPUtils_assertRet(-1, grad.sizes().equals(var.sizes()),
"assigned grad has data of a different size");
- var.grad() = other_var;
+ var.grad() = grad;
return 0;
+ END_HANDLE_TH_ERRORS_RET(-1)
}
PyObject *THPVariable_get_volatile(THPVariable *self)
{
- auto& var = self->cdata;
- return PyBool_FromLong(var.is_volatile());
+ const char* msg = "volatile was removed (Variable.volatile is always False)";
+ PyErr_WarnEx(PyExc_UserWarning, msg, 1);
+ Py_RETURN_FALSE;
}
int THPVariable_set_volatile(THPVariable *self, PyObject *obj)
{
- THPUtils_assertRet(-1, PyBool_Check(obj), "volatile must be a bool");
- THPUtils_assertRet(-1, !self->cdata.grad_fn(),
- "volatile can only be set on leaf variables");
- self->cdata.is_volatile() = (obj == Py_True);
- return 0;
+ return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
}
PyObject *THPVariable_get_output_nr(THPVariable *self)
{
- return PyInt_FromLong(self->cdata.output_nr());
+ HANDLE_TH_ERRORS
+ const auto output_nr = static_cast<long>(self->cdata.output_nr());
+ return PyInt_FromLong(output_nr);
+ END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_get_requires_grad(THPVariable *self)
{
+ HANDLE_TH_ERRORS
return PyBool_FromLong(self->cdata.requires_grad());
+ END_HANDLE_TH_ERRORS
}
int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
{
+ HANDLE_TH_ERRORS
THPUtils_assertRet(-1, PyBool_Check(obj), "requires_grad must be a bool");
auto& var = self->cdata;
- if (var.grad_fn()) {
+ if (!var.is_leaf()) {
const char *hint = "";
if (obj == Py_False) {
hint = " If you want to use a computed variable in a subgraph "
@@ -368,54 +342,119 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
THPUtils_setError("you can only change requires_grad flags of leaf variables.%s", hint);
return -1;
}
- var.requires_grad() = (obj == Py_True);
- if (auto grad_accumulator = var.get()->grad_accumulator.lock()) {
- grad_accumulator->is_executable = var.requires_grad();
- }
+ var.set_requires_grad(obj == Py_True);
return 0;
+ END_HANDLE_TH_ERRORS_RET(-1)
+}
+
+PyObject *THPVariable_get_name(THPVariable* self)
+{
+ if (self->cdata.name() == "")
+ Py_RETURN_NONE;
+ return THPUtils_packString(self->cdata.name().c_str());
}
PyObject *THPVariable_get_backwards_hooks(THPVariable *self)
{
+ HANDLE_TH_ERRORS
if (self->backward_hooks) {
Py_INCREF(self->backward_hooks);
return self->backward_hooks;
}
Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
}
int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj)
{
+ HANDLE_TH_ERRORS
if (obj == Py_None) {
obj = nullptr;
}
Py_XINCREF(obj);
Py_XDECREF(self->backward_hooks);
self->backward_hooks = obj;
- self->cdata.hooks().clear();
+ self->cdata.clear_hooks();
if (obj) {
- self->cdata.hooks().emplace_back(new PyFunctionPreHook(obj, 0));
+ self->cdata.add_hook(std::make_shared<PyFunctionPreHook>(obj, 0));
}
return 0;
+ END_HANDLE_TH_ERRORS_RET(-1)
+}
+
+PyObject *THPVariable_get_base(THPVariable *self)
+{
+ HANDLE_TH_ERRORS
+ if (self->cdata.is_view()) {
+ return THPVariable_Wrap(self->cdata.base());
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPVariable_get_shape(THPVariable *self)
+{
+ HANDLE_TH_ERRORS
+ auto& self_ = self->cdata;
+ auto sizes = self_.sizes();
+ return THPSize_New(sizes.size(), (int64_t *)sizes.data());
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPVariable_is_cuda(THPVariable *self)
+{
+ HANDLE_TH_ERRORS
+ auto& self_ = self->cdata;
+ return torch::autograd::utils::wrap(self_.is_cuda());
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPVariable_is_sparse(THPVariable *self)
+{
+ HANDLE_TH_ERRORS
+ auto& self_ = self->cdata;
+ return torch::autograd::utils::wrap(self_.is_sparse());
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPVariable_dtype(THPVariable *self)
+{
+ HANDLE_TH_ERRORS
+ auto& self_ = self->cdata;
+ return torch::autograd::utils::wrap(torch::getDtype(self_.type()));
+ END_HANDLE_TH_ERRORS
}
static struct PyGetSetDef THPVariable_properties[] = {
- {"_version", (getter)THPVariable_get_version, NULL, NULL, NULL},
- {"grad_fn", (getter)THPVariable_get_grad_fn, NULL, NULL, NULL},
- {"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, NULL, NULL},
- {"is_leaf", (getter)THPVariable_is_leaf, NULL, NULL, NULL},
- {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, NULL, NULL},
- {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL}, // only for legacy reasons
- {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL},
- {"volatile", (getter)THPVariable_get_volatile, (setter)THPVariable_set_volatile, NULL, NULL},
- {"output_nr", (getter)THPVariable_get_output_nr, NULL, NULL, NULL},
- {"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad, NULL, NULL},
- {"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, NULL, NULL},
- {NULL}
+ {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr},
+ {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr},
+ {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr},
+ {"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, nullptr, nullptr},
+ {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr},
+ {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, nullptr, nullptr},
+ {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr}, // only for legacy reasons
+ {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr},
+ {"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr},
+ {"volatile", (getter)THPVariable_get_volatile, (setter)THPVariable_set_volatile, nullptr, nullptr},
+ {"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr},
+ {"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad, nullptr, nullptr},
+ {"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, nullptr, nullptr},
+ {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
+ {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
+ {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
+ {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
+ {"dtype", (getter)THPVariable_dtype, NULL, NULL, NULL},
+ {nullptr}
+};
+
+static PyMappingMethods THPVariable_as_mapping = {
+ THPVariable_length,
+ THPVariable_getitem,
+ THPVariable_setitem,
};
PyTypeObject THPVariableType = {
- PyVarObject_HEAD_INIT(NULL, 0)
+ PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._VariableBase", /* tp_name */
sizeof(THPVariable), /* tp_basicsize */
0, /* tp_itemsize */
@@ -427,7 +466,7 @@ PyTypeObject THPVariableType = {
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
- 0, /* tp_as_mapping */
+ &THPVariable_as_mapping, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
@@ -435,7 +474,7 @@ PyTypeObject THPVariableType = {
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */
- NULL, /* tp_doc */
+ nullptr, /* tp_doc */
(traverseproc)THPVariable_traverse, /* tp_traverse */
(inquiry)THPVariable_clear, /* tp_clear */
0, /* tp_richcompare */
@@ -458,6 +497,7 @@ PyTypeObject THPVariableType = {
namespace torch { namespace autograd {
extern PyMethodDef variable_methods[];
+extern void initTorchFunctions(PyObject *module);
}}
@@ -470,5 +510,6 @@ bool THPVariable_initModule(PyObject *module)
return false;
Py_INCREF(&THPVariableType);
PyModule_AddObject(module, "_VariableBase", (PyObject *)&THPVariableType);
+ torch::autograd::initTorchFunctions(module);
return true;
}
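
To make the accessor pattern in this file concrete: every getter now reads the C++ Variable held in self->cdata and is wrapped in HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS so that C++ exceptions surface as Python errors. The sketch below is hypothetical (an "ndim" property and THPVariable_get_ndim do not appear in this diff); it only illustrates how a further read-only property would be added under the same conventions.

static PyObject* THPVariable_get_ndim(THPVariable* self)
{
  HANDLE_TH_ERRORS
  // Forwarded to the underlying Variable/Tensor held in cdata.
  return PyInt_FromLong(static_cast<long>(self->cdata.dim()));
  END_HANDLE_TH_ERRORS
}

// and the corresponding entry in THPVariable_properties:
//   {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
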
diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h
index 10b64fc3..21ab0596 100644
--- a/torch/csrc/autograd/python_variable.h
+++ b/torch/csrc/autograd/python_variable.h
@@ -5,32 +5,34 @@
#include <ATen/ATen.h>
#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/THP_export.h"
// Python object that backs torch.autograd.Variable
struct THPVariable {
PyObject_HEAD
// Payload
torch::autograd::Variable cdata;
- // Tensor this wraps (corresponds to Python attr 'data').
- // It assumed that a THPVariable is *uniquely* identified by the
- // tensor it wraps.
- // Invariant: v->data == v->cdata->data
- PyObject* data;
// Hooks to be run on backwards pass (corresponds to Python attr
// '_backwards_hooks', set by 'register_hook')
PyObject* backward_hooks;
};
-extern PyObject *THPVariableClass;
+THP_API PyObject *THPVariableClass;
bool THPVariable_initModule(PyObject *module);
-PyObject * THPVariable_NewVolatile(PyObject *data);
-PyObject * THPVariable_NewLeaf(PyObject *data);
-PyObject * THPVariable_NewWithFunction(PyObject *data, const std::shared_ptr<torch::autograd::Function>& var);
PyObject * THPVariable_Wrap(torch::autograd::Variable var);
-PyObject * THPVariable_get_data(THPVariable *self);
inline bool THPVariable_Check(PyObject *obj)
{
return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass);
}
+
+inline torch::autograd::Variable& THPVariable_Unpack(PyObject* obj) {
+ auto var = (THPVariable*)obj;
+ return var->cdata;
+}
+
+inline at::Tensor& THPVariable_UnpackData(PyObject* obj) {
+ auto var = (THPVariable*)obj;
+ return var->cdata.data();
+}
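
The two inline helpers added here are the intended way for binding code to reach the C++ Variable behind a Python object. A minimal sketch under that assumption follows; print_variable_info is a hypothetical helper, not part of this diff, and the object is checked with THPVariable_Check before unpacking.

#include "torch/csrc/autograd/python_variable.h"

#include <iostream>

static void print_variable_info(PyObject* obj) {
  if (!THPVariable_Check(obj)) {
    return;  // not a Variable; nothing to do
  }
  // Both helpers return references borrowed from the THPVariable; no refcounts change.
  torch::autograd::Variable& var = THPVariable_Unpack(obj);
  at::Tensor& data = THPVariable_UnpackData(obj);
  std::cout << "requires_grad=" << var.requires_grad()
            << " dim=" << data.dim() << std::endl;
}
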
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
new file mode 100644
index 00000000..320c5cab
--- /dev/null
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
@@ -0,0 +1,350 @@
+#include "torch/csrc/autograd/python_variable_indexing.h"
+
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/THP_export.h"
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/utils/python_compat.h"
+#include "torch/csrc/utils/python_numbers.h"
+#include "torch/csrc/utils/tensor_new.h"
+
+#include <ATen/ExpandUtils.h>
+#include <vector>
+
+using namespace at;
+using namespace torch::autograd::utils;
+
+namespace torch { namespace autograd {
+
+Py_ssize_t THPVariable_length(PyObject* self) {
+ HANDLE_TH_ERRORS
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+ if (self_.dim() == 0) {
+ return 0;
+ }
+ return (Py_ssize_t)self_.size(0);
+ END_HANDLE_TH_ERRORS_RET(-1)
+}
+
+
+// We allow indexing by integers, slices, ellipsis, None, Variables,
+// and tuples of those types. We also handle bools as if they were a
+// Variable[ByteTensor].
+
+static int64_t count_specified_dimensions(PyObject* index) {
+ // Count the number of indexed dimensions (everything but ellipsis and None)
+ int64_t count = 0;
+ auto size = PyTuple_GET_SIZE(index);
+ for (Py_ssize_t i = 0; i < size; i++) {
+ PyObject* obj = PyTuple_GET_ITEM(index, i);
+ if (THPVariable_Check(obj)) {
+ auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
+ if (var.type().scalarType() == kByte) {
+ count += var.dim();
+ } else {
+ count++;
+ }
+ } else if (obj != Py_None && obj != Py_Ellipsis) {
+ count++;
+ }
+ }
+ return count;
+}
+
+[[noreturn]]
+static void invalid_index(PyObject* obj) {
+ throw IndexError(
+ "only integers, slices (`:`), ellipsis (`...`), None and long or byte "
+ "Variables are valid indices (got %s)", Py_TYPE(obj)->tp_name);
+}
+
+static Variable applySlice(const Variable& self, int64_t dim, PyObject* slice, bool ensure_view=false) {
+ Py_ssize_t start, stop, step, slicelength;
+ auto length = self.size(dim);
+ if (!THPUtils_parseSlice(slice, length, &start, &stop, &step, &slicelength)) {
+ throw python_error();
+ }
+ if (step == 0) {
+ throw ValueError("step cannot be zero");
+ }
+ if (step < 0) {
+ // TODO: implement negative step
+ throw ValueError("negative step not yet supported");
+ }
+ if (!ensure_view && start == 0 && stop == length && step == 1) {
+ return self;
+ }
+ return self.slice(dim, start, stop, step);
+}
+
+static Variable applySelect(const Variable& self, int64_t dim, int64_t index) {
+ if (index == 0 && dim == 0 && self.dim() == 0) {
+ // Deprecated support for indexing 0-dim tensors as if they were 1-dim.
+ PyErr_WarnEx(PyExc_UserWarning,
+ "invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. "
+ "Use tensor.item() to convert a 0-dim tensor to a Python number", 1);
+ return at::alias(self);
+ }
+ int64_t size = self.size(dim);
+ if (index < -size || index >= size) {
+ throw IndexError("index %lld is out of bounds for dimension %lld with size %lld",
+ index, dim, size);
+ }
+ if (index < 0) {
+ index += size;
+ }
+ return self.select(dim, index);
+}
+
+static Variable sequenceToVariable(const Type& type, PyObject* seq) {
+ auto& idx_type = type.toScalarType(kLong);
+ return torch::utils::new_from_data(idx_type, -1, seq);
+}
+
+static Variable valueToTensor(const Type & type, PyObject* value) {
+ if (THPVariable_Check(value)) {
+ return reinterpret_cast<THPVariable*>(value)->cdata;
+ }
+ if (THPUtils_checkLong(value)) {
+ return type.scalarTensor(Scalar(THPUtils_unpackLong(value)));
+ }
+ if (PyFloat_Check(value)) {
+ return type.scalarTensor(Scalar(THPUtils_unpackDouble(value)));
+ }
+ throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
+}
+
+static Variable applySlicing(const Variable& self, PyObject* index, variable_list& outIndices) {
+ int64_t size = PyTuple_GET_SIZE(index);
+ int64_t dim = 0;
+ int64_t specified_dims = count_specified_dimensions(index);
+
+ auto handle_var = [&](const Variable& var) {
+ // TODO: check scalarType
+ outIndices.resize(dim + 1);
+ outIndices[dim] = var;
+ dim++;
+ };
+
+ if (specified_dims > self.dim()) {
+ throw IndexError("too many indices for tensor of dimension %d", (int)self.dim());
+ }
+
+ Variable result = self;
+ for (int64_t i = 0; i < size; i++) {
+ PyObject* obj = PyTuple_GET_ITEM(index, i);
+ if (THPUtils_checkLong(obj)) {
+ result = applySelect(result, dim, THPUtils_unpackLong(obj));
+ } else if (PySlice_Check(obj)) {
+ result = applySlice(result, dim, obj);
+ dim++;
+ } else if (obj == Py_Ellipsis) {
+ dim += self.dim() - specified_dims;
+ } else if (obj == Py_None) {
+ result = result.unsqueeze(dim);
+ dim++;
+ } else if (THPVariable_Check(obj)) {
+ handle_var(reinterpret_cast<THPVariable*>(obj)->cdata);
+ } else if (PySequence_Check(obj)) {
+ handle_var(sequenceToVariable(self.type(), obj));
+ } else {
+ auto index = THPObjectPtr(PyNumber_Index(obj));
+ if (!index) {
+ PyErr_Clear();
+ invalid_index(obj);
+ }
+ result = applySelect(result, dim, THPUtils_unpackLong(index));
+ }
+ }
+ return result;
+}
+
+static std::vector<Tensor> asTensorList(const variable_list& v) {
+ return std::vector<Tensor>(v.begin(), v.end());
+}
+
+static Variable dispatch_index(const Variable& self, const variable_list& indices) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(self);
+ return self.index(asTensorList(indices));
+}
+
+static Variable dispatch_index_put_(Variable& self, const variable_list& indices, const Variable& value) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(self);
+ return self.index_put_(asTensorList(indices), value);
+}
+
+static bool treatSequenceAsTuple(PyObject* index) {
+ if (PyTuple_Check(index)) {
+ return true;
+ }
+ if (!PySequence_Check(index)) {
+ return false;
+ }
+ // This uses a heuristic from NumPy for determining whether to treat
+ // non-tuple sequences as if they were a tuple. From the NumPy code comments:
+ //
+ // "At this point, we're left with a non-tuple, non-array, sequence:
+ // typically, a list. We use some somewhat-arbitrary heuristics from here
+ // onwards to decide whether to treat that list as a single index, or a
+ // list of indices. Backwards compatibility only takes effect for short
+ // sequences - otherwise we treat it like any other scalar."
+ auto n = PySequence_Size(index);
+ if (n < 0) {
+ // Negative size indicates a Python error in the PySequence_Size call.
+ PyErr_Clear();
+ return false;
+ }
+ if (n >= 32) {
+ return false;
+ }
+ for (Py_ssize_t i = 0; i < n; i++) {
+ auto obj = THPObjectPtr{PySequence_GetItem(index, i)};
+ if (!obj.get()) {
+ PyErr_Clear();
+ return false;
+ }
+ if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) || PySlice_Check(obj.get())) {
+ return true;
+ }
+ if (obj.get() == Py_Ellipsis || obj.get() == Py_None) {
+ return true;
+ }
+ }
+ return false;
+}
+
+static THPObjectPtr wrapTuple(PyObject* index) {
+ THPObjectPtr res;
+ if (treatSequenceAsTuple(index)) {
+ res = PySequence_Tuple(index);
+ } else {
+ res = PyTuple_Pack(1, index);
+ }
+ if (!res) throw python_error();
+ return res;
+}
+
+static bool isSingleBoolScalar(const variable_list& vars) {
+ return vars.size() == 1 && vars[0].type().scalarType() == ScalarType::Byte && vars[0].dim() == 0;
+}
+
+static PyObject* applyBoolGetitem(const Variable& self, bool index) {
+ if (index) {
+ return wrap(self.type().copy(self.unsqueeze(0)));
+ } else {
+ return wrap(self.type().tensor({0}));
+ }
+}
+
+PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
+ HANDLE_TH_ERRORS
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+ AutoGPU auto_gpu(self_);
+
+ // handle simple types: integers, slices, ellipsis
+ if (index == Py_None) {
+ return wrap(self_.unsqueeze(0));
+ } else if (index == Py_Ellipsis) {
+ return wrap(at::alias(self_));
+ } else if (THPUtils_checkLong(index)) {
+ return wrap(applySelect(self_, 0, THPUtils_unpackLong(index)));
+ } else if (PyBool_Check(index)) {
+ return applyBoolGetitem(self_, index == Py_True);
+ } else if (PySlice_Check(index)) {
+ return wrap(applySlice(self_, 0, index, true));
+ }
+
+ // wrap index in a tuple if it's not already one
+ THPObjectPtr holder = wrapTuple(index);
+
+ variable_list variableIndices;
+ Variable sliced = applySlicing(self_, holder.get(), variableIndices);
+ if (variableIndices.empty()) {
+ if (sliced.is_same(self_)) {
+ // ensure we return a shallow copy for things like x[...]
+ sliced = at::alias(sliced);
+ }
+ return wrap(sliced);
+ }
+ if (isSingleBoolScalar(variableIndices)) {
+ return applyBoolGetitem(self_, variableIndices[0].toCByte());
+ }
+
+ // indexing by tensors ("advanced" indexing)
+ return wrap(dispatch_index(sliced, variableIndices));
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static void copy_to(Variable dst, const Variable& src) {
+ Tensor b_src;
+ // To match numpy semantics:
+ // As a special case for backwards compatibility,
+ // strip away unit dimensions from the left of 'src'
+ auto src_sizes = src.sizes();
+ size_t first_nonzero_src = src_sizes.size();
+ for (size_t i = 0; i < src_sizes.size(); ++i) {
+ if (src_sizes[i] != 1) {
+ first_nonzero_src = i;
+ break;
+ }
+ }
+
+ src_sizes = src_sizes.slice(first_nonzero_src);
+ std::tie(b_src) = expand_inplace(dst, src.view(src_sizes), "setitem");
+ dst.copy_(b_src);
+}
+
+int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
+ HANDLE_TH_ERRORS
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+ AutoGPU auto_gpu(self_);
+ auto value = valueToTensor(self_.type(), py_value);
+
+ // handle simple types: integers, slices, ellipsis, bool
+ if (index == Py_False) {
+ // do nothing for false (technically we should check the size, but we don't have
+ // real 0-sized shapes).
+ return 0;
+ } else if (index == Py_Ellipsis) {
+ copy_to(self_, value);
+ return 0;
+ } else if (index == Py_None || index == Py_True) {
+ copy_to(self_.unsqueeze(0), value);
+ return 0;
+ } else if (THPUtils_checkLong(index)) {
+ copy_to(applySelect(self_, 0, THPUtils_unpackLong(index)), value);
+ return 0;
+ } else if (PySlice_Check(index)) {
+ copy_to(applySlice(self_, 0, index), value);
+ return 0;
+ }
+
+ // wrap index in a tuple if it's not already one
+ THPObjectPtr holder = wrapTuple(index);
+
+ variable_list variableIndices;
+ Variable sliced = applySlicing(self_, holder.get(), variableIndices);
+ if (variableIndices.empty()) {
+ copy_to(sliced, value);
+ return 0;
+ }
+ if (isSingleBoolScalar(variableIndices)) {
+ if (variableIndices[0].toCByte()) {
+ copy_to(self_.unsqueeze(0), value);
+ }
+ return 0;
+ }
+
+ // indexing by tensors ("advanced" indexing)
+ dispatch_index_put_(sliced, variableIndices, value);
+ return 0;
+ END_HANDLE_TH_ERRORS_RET(-1)
+}
+
+}} // namespace torch::autograd
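
To summarize the basic-indexing path above: integer indices go through applySelect (select), slices through applySlice (slice), None inserts a dimension, and only Variable or sequence indices reach dispatch_index. A rough sketch of the C++ operations that an expression like x[0, 1:3, None] reduces to; getitem_sketch is a hypothetical name used only for illustration.

#include "torch/csrc/autograd/variable.h"

using torch::autograd::Variable;

Variable getitem_sketch(const Variable& x) {
  Variable result = x.select(0, 0);   // integer index 0: applySelect -> select()
  result = result.slice(0, 1, 3, 1);  // slice 1:3: applySlice -> slice(dim, 1, 3, step=1)
  result = result.unsqueeze(1);       // None: insert a new dimension at the current dim
  return result;                      // no tensor indices, so dispatch_index() is not needed
}
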
diff --git a/torch/csrc/autograd/python_variable_indexing.h b/torch/csrc/autograd/python_variable_indexing.h
new file mode 100644
index 00000000..e961ed5c
--- /dev/null
+++ b/torch/csrc/autograd/python_variable_indexing.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <Python.h>
+
+namespace torch { namespace autograd {
+
+Py_ssize_t THPVariable_length(PyObject* self);
+PyObject* THPVariable_getitem(PyObject* self, PyObject* index);
+int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value);
+
+}} // namespace torch::autograd
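
The setitem path mirrors getitem: simple indices reduce to copy_to over a selected or sliced view, while tensor indices dispatch to index_put_. A rough sketch of the advanced-assignment case x[idx] = value, where idx is a long or byte Variable; setitem_sketch is hypothetical.

#include "torch/csrc/autograd/variable.h"

#include <vector>

using torch::autograd::Variable;

void setitem_sketch(Variable& x, const Variable& idx, const Variable& value) {
  // Advanced assignment goes through index_put_ with the collected tensor indices.
  x.index_put_(std::vector<at::Tensor>{idx}, value);
}
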
diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp
index e112e515..0fbbaf97 100644
--- a/torch/csrc/autograd/saved_variable.cpp
+++ b/torch/csrc/autograd/saved_variable.cpp
@@ -1,83 +1,91 @@
+#include "Python.h"
#include "torch/csrc/autograd/saved_variable.h"
+#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/jit/tracer_state.h"
-using namespace at;
+#include <ATen/Tensor.h>
+
+#include <cstdint>
+#include <list>
+#include <memory>
namespace torch { namespace autograd {
-SavedVariable::SavedVariable(const Variable& variable, Function* saved_for)
- : SavedVariable() {
- if (!variable.defined()) {
- return;
- }
- data = variable.data();
- requires_grad = variable.requires_grad();
- is_volatile = variable.is_volatile();
- expected_version = variable.current_version();
- version = variable.get()->version_counter.save();
- has_grad_fn = variable.grad_fn() != nullptr;
- output_nr = variable.output_nr();
- if (!has_grad_fn) {
- grad_accumulator = variable.grad_accumulator();
- }
- if (variable.grad_fn().get() != saved_for) {
- grad_fn = variable.grad_fn();
- }
- if (variable.tracing_state()) {
- tracing_state.reset(new jit::tracer::ValueTracingState(*variable.tracing_state()));
+SavedVariable::SavedVariable(const Variable& variable, bool is_output) {
+ if (variable.defined()) {
+ was_default_constructed_ = false;
+ output_nr_ = variable.output_nr();
+ requires_grad_ = variable.requires_grad();
+ has_grad_fn_ = !variable.is_leaf();
+ // These copies are all shared_ptr copies, so slightly more expensive.
+ // Do them here instead of in the init list in case data is undefined.
+ data_ = variable.data();
+ if (variable.is_leaf()) {
+ grad_accumulator_ = variable.grad_accumulator();
+ } else if (!is_output) {
+ grad_fn_ = variable.grad_fn();
+ }
+ version_counter_ = variable.version_counter();
+ saved_version_ = version_counter_.current_version();
+ if (variable.has_tracing_state()) {
+ tracing_state_.reset(
+ new jit::tracer::ValueTracingState(variable.tracing_state()));
+ }
}
}
-auto SavedVariable::unpack(std::shared_ptr<Function> saved_for) const -> Variable {
- if (!data.defined()) {
- if (version.defined()) {
+Variable SavedVariable::unpack(std::shared_ptr<Function> saved_for) const {
+ if (!data_.defined()) {
+ if (!was_default_constructed_) {
throw std::runtime_error(ERR_BACKWARD_TWICE);
}
return Variable();
}
- if (version.is_modified()) {
+ if (saved_version_ != version_counter_.current_version()) {
throw std::runtime_error(
"one of the variables needed for gradient computation has been "
"modified by an inplace operation");
}
- Variable var = make_variable(data, requires_grad, is_volatile);
- if (has_grad_fn && !grad_fn) {
+ auto grad_fn = grad_fn_;
+ if (has_grad_fn_ && !grad_fn) {
if (!saved_for) {
// If saving the grad_fn would create a circular reference, then it must
// be passed in to the unpack function.
throw std::runtime_error("No grad_fn for non-leaf saved variable");
}
- var.grad_fn() = saved_for;
+ grad_fn = std::move(saved_for);
+ }
+
+ // NB: saved views are unpacked as normal Variables (not views) even though
+ // they still share the same storage. This works only because we never call
+ // in-place functions on unpacked variables.
+ Variable var;
+ if (grad_fn) {
+ var = make_variable(data_, Edge(std::move(grad_fn), output_nr_));
} else {
- var.grad_fn() = grad_fn;
+ var = make_variable(data_, requires_grad_);
}
- var.output_nr() = output_nr;
- var.version_counter() = version;
+ var.set_version_counter(saved_version_);
// If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
// should have saved the grad accumulator. Even if the Variable no longer
- // alive, the accumulator should be kept alive by the references in the graph).
- if (requires_grad && !var.grad_fn() && grad_accumulator.expired())
+ // alive, the accumulator should be kept alive by the references in the
+ // graph).
+ if (requires_grad_ && !var.grad_fn() && grad_accumulator_.expired())
throw std::logic_error("No grad accumulator for a saved leaf!");
- var.get()->grad_accumulator = grad_accumulator;
- if (tracing_state)
- var.tracing_state().reset(new jit::tracer::ValueTracingState(*tracing_state));
+ var.set_grad_accumulator(grad_accumulator_);
+ if (tracing_state_) {
+ var.set_tracing_state(new jit::tracer::ValueTracingState(*tracing_state_));
+ }
return var;
}
-auto SavedVariable::unpack_data(std::shared_ptr<Function> saved_for) const -> Tensor {
- auto var = unpack(saved_for);
- if (var.defined()) {
- return var.data();
- }
- return Tensor();
-}
-
-
const char* ERR_BACKWARD_TWICE =
"Trying to backward through the graph a second time, but the buffers have "
"already been freed. Specify retain_graph=True when calling backward "
diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h
index 6ec741da..7372d10c 100644
--- a/torch/csrc/autograd/saved_variable.h
+++ b/torch/csrc/autograd/saved_variable.h
@@ -1,50 +1,55 @@
#pragma once
-#include <mutex>
-#include <memory>
-#include <functional>
+#include "torch/csrc/autograd/variable_version.h"
+#include "torch/csrc/jit/tracer_state.h"
+
#include <ATen/ATen.h>
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/Types.h"
+#include <cstdint>
+#include <list>
+#include <memory>
namespace torch { namespace autograd {
+struct Variable;
struct Function;
extern const char* ERR_BACKWARD_TWICE;
-struct SavedVariable {
- SavedVariable()
- : data()
- , has_grad_fn(false)
- , version()
- , requires_grad(false)
- , is_volatile(false)
- , expected_version(-1) {}
+/// A snapshot of a variable at a certain version. A `SavedVariable` stores
+/// enough information to reconstruct a variable from a certain point in time.
+class SavedVariable {
+ public:
+ SavedVariable() = default;
+ SavedVariable(const Variable& variable, bool is_output);
+ SavedVariable(SavedVariable&&) = default;
+ SavedVariable& operator=(SavedVariable&&) = default;
- SavedVariable(const Variable& variable, Function* saved_for);
+ /// Reconstructs the saved variable. Pass `saved_for` as the gradient
+ /// function if constructing the `SavedVariable` with it would have caused a
+ /// circular reference.
+ Variable unpack(std::shared_ptr<Function> saved_for = nullptr) const;
+ void reset_data() {
+ return data_.reset();
+ }
+
+ private:
+ at::Tensor data_;
- at::Tensor data;
// The gradient function associated with this node. If has_grad_fn
// is false, then this is a leaf node. Note that the grad_fn is not saved if
// it would create a circular reference. In that case, the grad_fn must be
// passed in to the unpack function when reconstructing the Variable.
- bool has_grad_fn;
- std::shared_ptr<Function> grad_fn;
- std::weak_ptr<Function> grad_accumulator;
- SavedVersion version;
- bool requires_grad;
- bool is_volatile;
- int expected_version;
- int output_nr;
- std::unique_ptr<jit::tracer::ValueTracingState> tracing_state;
-
- Variable unpack(std::shared_ptr<Function> saved_for=nullptr) const;
- at::Tensor unpack_data(std::shared_ptr<Function> saved_for=nullptr) const;
+ std::shared_ptr<Function> grad_fn_;
+ std::weak_ptr<Function> grad_accumulator_;
+ std::unique_ptr<jit::tracer::ValueTracingState> tracing_state_;
+ VariableVersion version_counter_;
+
+ uint32_t saved_version_;
+ uint32_t output_nr_;
+ bool was_default_constructed_ = true;
+ bool requires_grad_;
+ bool has_grad_fn_;
};
-
}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/symbolic.h b/torch/csrc/autograd/symbolic.h
index a81c22c9..b0fd5ccc 100644
--- a/torch/csrc/autograd/symbolic.h
+++ b/torch/csrc/autograd/symbolic.h
@@ -8,20 +8,10 @@ namespace torch { namespace autograd {
struct SymbolicContext {
jit::Graph* graph;
- const std::unordered_map<void*, jit::Node*>* buffer_map;
- int batch_norm_count = 0;
};
struct symbolic_unconvertible : public std::runtime_error {
using std::runtime_error::runtime_error;
};
-
-struct HasSymbolic {
- // Add some nodes to the ONNX protobuf, under the assumption that this node
- // as a whole has the represented inputs and outputs. Raises a
- // symbolic_unconvertible exception if conversion is not supported.
- virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) = 0;
-};
-
}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h
index 0b83675d..5f087205 100644
--- a/torch/csrc/autograd/utils/wrap_outputs.h
+++ b/torch/csrc/autograd/utils/wrap_outputs.h
@@ -6,6 +6,7 @@
#include <Python.h>
#include <tuple>
+#include "torch/csrc/Dtype.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/python_numbers.h"
@@ -13,10 +14,6 @@
namespace torch { namespace autograd { namespace utils {
inline PyObject* wrap(at::Tensor tensor) {
- if (tensor.defined() && tensor.dim() == 0) {
- // don't expose 0-dim tensors to Variable API.
- Variable(tensor).data().as_strided_({1}, {1});
- }
return THPVariable_Wrap(Variable(std::move(tensor)));
}
@@ -37,6 +34,27 @@ inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor> tensors) {
return r.release();
}
+inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> tensors) {
+ auto r = THPObjectPtr{PyTuple_New(4)};
+ if (!r) throw python_error();
+ PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
+ return r.release();
+}
+
+inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> tensors) {
+ auto r = THPObjectPtr{PyTuple_New(5)};
+ if (!r) throw python_error();
+ PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 4, wrap(std::move(std::get<4>(tensors))));
+ return r.release();
+}
+
inline PyObject* wrap(at::TensorList tl) {
auto r = THPObjectPtr{PyTuple_New(tl.size())};
if (!r) throw python_error();
@@ -58,6 +76,10 @@ inline PyObject* wrap(int64_t value) {
return THPUtils_packInt64(value);
}
+inline PyObject* wrap(double value) {
+ return PyFloat_FromDouble(value);
+}
+
inline PyObject* wrap(void* value) {
return THPUtils_packInt64(reinterpret_cast<intptr_t>(value));
}
@@ -66,5 +88,9 @@ inline PyObject* wrap(at::Scalar scalar) {
return wrap(scalar.toTensor());
}
+inline PyObject* wrap(THPDtype *dtype) {
+ Py_INCREF(dtype);
+ return (PyObject*)dtype;
+}
}}} // namespace torch::autograd::utils
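
These overloads let binding code return tuples of higher arity (and plain doubles or dtypes) with a single wrap() call. A small sketch using the new four-tensor overload; pack_four_sketch is hypothetical and the tensors are assumed to come from some ATen call.

#include "torch/csrc/autograd/utils/wrap_outputs.h"

#include <tuple>

// Turns four result tensors into a Python 4-tuple of Variables.
PyObject* pack_four_sketch(at::Tensor a, at::Tensor b, at::Tensor c, at::Tensor d) {
  return torch::autograd::utils::wrap(std::make_tuple(
      std::move(a), std::move(b), std::move(c), std::move(d)));
}
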
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 7ea75c80..ca4b65ec 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -1,83 +1,86 @@
#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/generated/VariableType.h"
+#include "torch/csrc/assertions.h"
+#include "torch/csrc/autograd/edge.h"
+#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/functions/accumulate_grad.h"
+#include "torch/csrc/autograd/functions/tensor.h"
+#include "torch/csrc/autograd/generated/Functions.h"
+#include "torch/csrc/autograd/generated/VariableType.h"
+#include "torch/csrc/autograd/variable_version.h"
+#include "torch/csrc/jit/tracer_state.h"
+#include "torch/csrc/utils/auto_unique_ptr.h"
-using namespace at;
+#include <ATen/ATen.h>
-namespace torch { namespace autograd {
+#include <list>
+#include <memory>
+#include <mutex>
+#include <stdexcept>
+#include <string>
+#include <vector>
-VariableImpl::VariableImpl(Tensor data_, bool requires_grad, bool is_volatile)
- : TensorImpl(getType(data_))
- , data(std::move(data_))
- , grad()
- , version_counter()
- , requires_grad(requires_grad)
- , is_volatile(is_volatile)
- , output_nr(0)
- , pyobj(nullptr) {
+namespace torch { namespace autograd {
+Variable::Impl::Impl(at::Tensor data_, bool requires_grad_, Edge gradient_edge_)
+ : TensorImpl(VariableType::getType(data_)),
+ data(std::move(data_)),
+ grad_fn(std::move(gradient_edge_.function)),
+ requires_grad(requires_grad_),
+ is_view(false),
+ output_nr(gradient_edge_.input_nr),
+ pyobj(nullptr) {
+ TORCH_ASSERTM(
+ !grad_fn || !requires_grad,
+ "_requires_grad should be false if grad_fn is set");
if (!data.defined()) {
throw std::runtime_error("data is undefined");
}
}
-VariableImpl::VariableImpl(Tensor data, std::shared_ptr<Function> grad_fn)
- : VariableImpl(std::move(data))
-{
- this->grad_fn = grad_fn;
- requires_grad = grad_fn->is_executable;
- output_nr = grad_fn->num_inputs++;
-}
+Variable::Impl::~Impl() = default;
-VariableImpl::VariableImpl(Tensor data)
- : VariableImpl(std::move(data), false, false)
-{
-}
-
-VariableImpl::~VariableImpl() {
-}
-
-const char * VariableImpl::toString() const {
+const char* Variable::Impl::toString() const {
return "Variable";
}
-IntList VariableImpl::sizes() const {
+IntList Variable::Impl::sizes() const {
return data.sizes();
}
-IntList VariableImpl::strides() const {
+IntList Variable::Impl::strides() const {
return data.strides();
}
-int64_t VariableImpl::dim() const {
+int64_t Variable::Impl::dim() const {
return data.dim();
}
-const char * VariableImpl::typeString() {
+const char* Variable::Impl::typeString() {
return "VariableType";
}
-void * VariableImpl::unsafeGetTH(bool retain) {
+void* Variable::Impl::unsafeGetTH(bool retain) {
return data.unsafeGetTH(retain);
}
-Scalar VariableImpl::localScalar() {
- return data.pImpl->localScalar();
+std::unique_ptr<at::Storage> Variable::Impl::storage() {
+ return data.storage();
}
-void VariableImpl::assign_(Scalar s) {
- data.assign_(s);
+Scalar Variable::Impl::localScalar() {
+ return data.pImpl->localScalar();
}
-std::shared_ptr<Function> VariableImpl::get_grad_accumulator() {
+std::shared_ptr<Function> Variable::Impl::get_grad_accumulator() {
if (grad_fn) {
- throw std::logic_error("get_grad_accumulator() should be only called on leaf Variables");
+ throw std::logic_error(
+ "get_grad_accumulator() should be only called on leaf Variables");
}
if (!requires_grad) {
return nullptr;
}
- std::lock_guard<std::mutex> lock(grad_accumulator_lock);
+ std::lock_guard<std::mutex> lock(mutex);
auto result = grad_accumulator.lock();
if (result) return result;
@@ -87,39 +90,86 @@ std::shared_ptr<Function> VariableImpl::get_grad_accumulator() {
return result;
}
-namespace {
-
-struct VariableTypes {
- VariableTypes() {
- auto& context = at::globalContext();
- for (int p = 0; p < static_cast<int>(Backend::NumOptions); ++p) {
- for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
- auto baseType = context.type_registry[p][s].get();
- if (baseType) {
- auto id = static_cast<int>(baseType->ID());
- types[id].reset(new VariableType(&context, baseType));
- }
- }
- }
+Variable::ViewImpl::ViewImpl(
+ Variable base_,
+ at::Tensor data_,
+ Edge gradient_edge_)
+ : Variable::Impl(std::move(data_), false, std::move(gradient_edge_)),
+ base(std::move(base_)) {
+ TORCH_ASSERTM(base.defined(), "base is undefined");
+ if (base.is_view()) {
+ base = base.base();
}
+ is_view = true;
+ version_counter = base.version_counter();
+ attr_version = version_counter.current_version();
+}
- std::unique_ptr<Type> types[static_cast<int>(TypeID::NumOptions)];
-};
+std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!grad_fn && !base.requires_grad()) {
+ return grad_fn;
+ }
+ auto current_version = version_counter.current_version();
+ if (attr_version != current_version) {
+ TORCH_ASSERT(output_nr == 0);
+ auto fn = std::make_shared<generated::AsStridedBackward>();
+ fn->self_geometry = at::TensorGeometry(base);
+ fn->size = sizes();
+ fn->stride = strides();
+ fn->storage_offset = data.storage_offset();
+ fn->set_next_edges(collect_next_edges(base));
+ fn->set_num_inputs(1);
+ grad_fn = std::move(fn);
+ attr_version = current_version;
+ }
+ return grad_fn;
+}
-} // anonymous namespace
+void Variable::ViewImpl::rebase_history(Edge gradient_edge) {
+ TORCH_ASSERT(gradient_edge.input_nr == 0);
+ TORCH_ASSERT(gradient_edge.function);
+ TORCH_ASSERTM(
+ gradient_edge.function->num_inputs() == 1,
+ "Functions which modify views in-place must return a single Variable");
+ this->output_nr = gradient_edge.input_nr;
+ auto copy_slices = std::make_shared<CopySlices>(
+ base, at::TensorGeometry(data), std::move(gradient_edge.function));
+ base.set_gradient_edge({std::move(copy_slices), 0});
+ get_grad_fn(); // trigger an update to the view's grad_fn
+}
-Type* VariableImpl::getType(const Tensor& tensor)
-{
- if (!tensor.defined()) {
- throw std::runtime_error("tensor is undefined");
+void Variable::rebase_history(Edge gradient_edge) {
+ TORCH_ASSERT(gradient_edge.function != nullptr);
+ if (is_view()) {
+ auto& impl = static_cast<Variable::ViewImpl&>(*get());
+ impl.rebase_history(std::move(gradient_edge));
+ } else {
+ set_gradient_edge(std::move(gradient_edge));
}
- return getType(tensor.type());
}
-Type* VariableImpl::getType(const Type& baseType)
-{
- static VariableTypes vt;
- return vt.types[static_cast<int>(baseType.ID())].get();
+Variable Variable::detach() const {
+ auto detached = make_variable(data(), /*requires_grad=*/false);
+ detached.set_version_counter(version_counter());
+ return detached;
}
+void Variable::detach_() {
+ if (is_view()) {
+ throw std::runtime_error(
+ "Can't detach views in-place. Use detach() instead");
+ }
+ set_requires_grad(false);
+ set_gradient_edge(Edge());
+}
+
+void Variable::set_tracing_state(
+ jit::tracer::ValueTracingState* new_tracing_state) {
+ get()->tracing_state.reset(new_tracing_state);
+}
+
+jit::tracer::ValueTracingState& Variable::tracing_state() const noexcept {
+ return *get()->tracing_state;
+}
}} // namespace torch::autograd
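
The detach semantics implemented above can be summarized in a short sketch; detach_sketch is a hypothetical function used only for illustration. detach() produces a new Variable that shares the same data but has no gradient edge, while detach_() strips the edge in place and is disallowed on views.

#include "torch/csrc/autograd/variable.h"

using torch::autograd::Variable;
using torch::autograd::make_variable;

void detach_sketch(at::Tensor data) {
  Variable v = make_variable(std::move(data), /*requires_grad=*/true);
  Variable d = v.detach();  // new Variable, shares data, requires_grad == false
  v.detach_();              // in place; throws std::runtime_error if v were a view
}
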
diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h
index 0f134378..c25fe4cc 100644
--- a/torch/csrc/autograd/variable.h
+++ b/torch/csrc/autograd/variable.h
@@ -1,226 +1,605 @@
#pragma once
-// A wrapper around at::Tensor to represent autograd Variables. Variables
-// can be implicitly converted to an at::Tensor.
+#include <Python.h>
-#include <mutex>
+#include "torch/csrc/autograd/edge.h"
+#include "torch/csrc/autograd/function_hook.h"
+#include "torch/csrc/autograd/variable_version.h"
+#include "torch/csrc/utils/auto_unique_ptr.h"
+
+#include <ATen/ATen.h>
+
+#include <list>
#include <memory>
+#include <mutex>
+#include <stdexcept>
+#include <string>
#include <vector>
-#include <functional>
-#include <ATen/ATen.h>
-#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/autograd/function_hook.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
-#include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/Types.h"
+namespace torch {
+namespace autograd {
+struct Function;
+} // namespace autograd
+namespace jit { namespace tracer {
+// Has to be forward declared because tracer_state.h has a dependency on
+// variable.h.
+struct ValueTracingStateElem;
+using ValueTracingState = std::list<ValueTracingStateElem>;
+}} // namespace jit::tracer
+} // namespace torch
namespace torch { namespace autograd {
-using at::Tensor;
-struct VariableImpl;
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// Variable
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// A `Variable` augments a `Tensor` with the ability to interact in our
+/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
+/// `Function`s in the autograd graph. A `Variable` can either be a leaf, like a
+/// weight in a neural network, or an interior variable, when it is the result
+/// of an operation between variables. Every `Variable` also stores another
+/// `Variable` called its `grad` (gradient). If the variable is a leaf, its
+/// gradient will be accumulated into this variable.
+///
+/// Gradient Edges
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
+/// edge in the autograd graph that connects the variable to a particular input
+/// of the gradient function that will be invoked with the variable during the
+/// backward pass. More precisely, this gradient function can be one of two
+/// things:
+/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
+/// gradient of the function that produced the variable.
+/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a
+/// scalar gradient value into its `grad` variable.
+///
+/// Versioning
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// Another major feature of `Variable`s are *versions*. Versions are
+/// incremented when an in-place mutation of a variable occurs. Versions are
+/// useful when constructing `SavedVariable`s, which take a snapshot of a
+/// `Variable` at a certain version. You can retrieve a `Variable`'s version
+/// through its `current_version()` method.
+///
+/// Views
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// It is possible for a `Variable` to be a *view* of another `Variable`, in
+/// which case it tracks that `Variable`'s data and autograd history. Beyond
+/// construction, the interface of a view is identical to that of a regular
+/// `Variable`. You can determine whether a `Variable` is in fact a view by
+/// probing its `is_view()` method.
+///
+/// Interface
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// `Variable` inherits from `Tensor` and thus its API is a superset of that of
+/// `Tensor`. This means you can perform all the usual mathematical and other
+/// operations you can perform on `Tensor`s also on `Variable`s. Furthermore,
+/// `Variable` and `Tensor` actually convert implicitly between each other. You
+/// can thus call functions defined on `Tensor`s also with `Variable`s. For
+/// this, the `Variable` class allows implicit construction from `Tensor`. It is
+/// the responsibility of calling code to ensure that this constructor is
+/// invoked only when the `Tensor`'s dynamic type is actually `Variable`. Most
+/// notably, it is *not* correct to construct a brand new `Variable` from a
+/// `Tensor` using this constructor. To do so, you must use the `make_variable`
+/// free function instead. To create a view variable, use `make_variable_view`.
+///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
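// Illustrative sketch, not part of this diff: how the factory functions and
// conversions described above are meant to be used. `t` is assumed to be a
// plain at::Tensor (dynamic type Tensor, not Variable) and `fn` an existing
// std::shared_ptr<Function>.
//
//   Variable leaf = make_variable(t, /*requires_grad=*/true);  // leaf variable
//   Variable interior = make_variable(t, Edge(fn, 0));         // output 0 of fn
//   Variable view = make_variable_view(leaf, t, Edge());       // view of `leaf`
//   at::Tensor& as_tensor = leaf;                 // implicit Variable -> Tensor
//   Variable& back = as_variable_ref(as_tensor);  // checked Tensor -> Variable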
struct Variable : public at::Tensor {
- inline Variable(VariableImpl * self, bool retain);
- Variable() : Tensor() {}
- Variable(const Variable & rhs) : Tensor(rhs) {}
- Variable(Variable && rhs) noexcept : Tensor(std::move(rhs)) {}
+ /// Default constructor.
+ Variable() = default;
+
+ // Factory Functions
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ // NOTE: These factory functions have to be friends to access the
+ // `Variable::Impl`. As a side effect, it allows us to keep them in the class.
+
+ /// Creates a `Variable` that is a *view* of another (*base*) variable.
+ /// The `gradient_edge` is an optional (gradient_function, input_number) pair.
+ friend Variable
+ make_variable_view(Variable base, at::Tensor data, Edge gradient_edge);
+
+ /// Creates a `Variable` from the given `Tensor`. `requires_grad` should be
+ /// set only for leaves, and determines whether the `Variable` will accumulate
+ /// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic
+ /// type *must* be `Tensor`.
+ friend Variable make_variable(at::Tensor data, bool requires_grad);
+
+ /// Creates a `Variable` from the given `Tensor` and specifies a
+ /// `gradient_edge`, i.e. a (function, input_nr) pair specifying the function
+ /// in the autograd graph, and which particular input of that function this
+ /// variable is connected to.
+ friend Variable make_variable(at::Tensor data, Edge gradient_edge);
+
+ // Tensor Conversions
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ // "Downcasts" a `Tensor` into a `Variable`. Only call this on tensors you
+ // know are Variables.
+ /*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {}
+ /*implicit*/ Variable(at::Tensor&& rhs) noexcept
+ : at::Tensor(std::move(rhs)) {}
+
+ // NOTE: Assignment operators to Tensor come for free from the constructors.
+
+ /// Downcasts the `Tensor` reference to a `Variable` reference. If compiling
+ /// in DEBUG mode and the tensor's dynamic type is not in fact `Variable`,
+ /// throws a `std::runtime_error` exception.
+ /// NOTE: Has to be a friend function because runtime type information is
+ /// available only for `TensorImpl`/`Impl` and not the `Tensor`/`Variable`
+ /// classes, as the latter are not polymorphic classes (`Tensor` has no
+ /// virtual methods).
+ friend Variable& as_variable_ref(at::Tensor& tensor);
+
+ const at::Tensor& data() const noexcept;
+ at::Tensor& data() noexcept;
+
+ // Gradient Function and Edges
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Gets the gradient function of the `Variable`. If this is a leaf variable,
+ /// the pointer returned will be null.
+ const std::shared_ptr<Function>& grad_fn() const;
+
+ /// Gets the raw gradient function pointer, whatever it currently is.
+ Function* grad_fn_unsafe() const;
+
+ /// Set the gradient accumulator of the `Variable`. This is only applicable to
+ /// leaf variables. Interior variables should call `set_gradient_edge()`.
+ void set_grad_accumulator(std::weak_ptr<Function> grad_accumulator);
+
+ /// Attempts to get a pointer to the gradient accumulator of the `Variable`,
+ /// if it still exists. If the gradient accumulator function has been
+ /// destroyed, returns a `nullptr`.
+ std::shared_ptr<Function> try_get_grad_accumulator() const;
+
+ /// Gets the gradient accumulator of the `Variable` if it has one, or else
+ /// creates one on the fly and returns it.
+ std::shared_ptr<Function> grad_accumulator() const;
- // Implicitly casts a Tensor to a Variable. This should only be called on
- // Tensors which you know are actually Variables.
- /*implicit*/ Variable(Tensor const & rhs) : Tensor(rhs) {}
- /*implicit*/ Variable(Tensor && rhs) noexcept : Tensor(std::move(rhs)) {}
+ /// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
+ /// gradient function if this is an interior `Variable`, or the gradient
+ /// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
+ /// will store the input index of the `Function` to which this variable is
+ /// connected in its `input_nr` field. For leaves, the `input_nr` is always
+ /// zero. Note that `set_gradient_edge` and `gradient_edge` are not
+ /// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
+ /// `set_grad_accumulator` to set the accumulator.
+ Edge gradient_edge() const {
+ // If grad_fn is null (as is the case for a leaf node), we instead
+ // interpret the gradient function to be a gradient accumulator, which will
+ // accumulate its inputs into the grad property of the variable. These
+ // nodes get suppressed in some situations, see "suppress gradient
+ // accumulation" below. Note that only variables which have `requires_grad =
+ // True` can have gradient accumulators.
+ if (const auto& gradient = grad_fn()) {
+ return Edge(gradient, output_nr());
+ } else {
+ return Edge(grad_accumulator(), 0);
+ }
+ }
- inline VariableImpl* get() const;
+ /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
+ /// `Variable`.
+ /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
+ /// and never the `grad_accumulator`. For the latter, use
+ /// `set_grad_accumulator`. This allows late construction of an interior
+ /// `Variable`.
+ void set_gradient_edge(Edge edge) noexcept;
- inline const Tensor & data() const;
- inline Tensor & data();
+ /// Returns the input index of the gradient `Function` to which this
+ /// `Variable` is connected.
+ uint32_t output_nr() const noexcept;
- inline Tensor opt_data() const;
+ /// True if this `Variable` is a leaf and thus does not have a `grad_fn`.
+ bool is_leaf() const noexcept;
- inline const Variable & grad() const;
- inline Variable & grad();
+ // The Grad Variable
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- inline const std::shared_ptr<Function>& grad_fn() const;
- inline std::shared_ptr<Function>& grad_fn();
+ /// Accesses the gradient `Variable` of this `Variable`.
+ const Variable& grad() const noexcept;
+ Variable& grad() noexcept;
+ void reset_grad() noexcept;
- std::shared_ptr<Function> grad_accumulator() const;
+ /// Sets the `requires_grad` property of `Variable`. This should be true for
+ /// leaf variables that want to accumulate gradients, and false for all other
+ /// variables.
+ void set_requires_grad(bool requires_grad) noexcept;
+ bool requires_grad() const noexcept;
- inline const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const;
- inline std::vector<std::shared_ptr<FunctionPreHook>>& hooks();
+ // Versions
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- inline auto_unique_ptr<jit::tracer::ValueTracingState>& tracing_state() const;
+ /// Increments the version count of this `Variable`.
+ void bump_version() noexcept;
+ void set_version_counter(const VariableVersion& version_counter) noexcept;
- inline int current_version() const;
+ /// Retrieves this `Variable`s version counter.
+ const VariableVersion& version_counter() const noexcept;
- inline VariableVersion& version_counter() const;
+ /// Retrieves the current value of the `Variable`'s version counter.
+ /// Equivalent to calling `version_counter().current_version()`.
+ uint32_t current_version() const noexcept;
- inline const int& output_nr() const;
- inline int& output_nr();
+ // Autograd Graph Interaction
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Update the `grad_fn` of an existing Variable. Called after in-place
+ /// modifications.
+ void rebase_history(Edge gradient_edge);
+
+ /// Returns a copy of this `Variable` that is detached from its autograd graph
+ /// and has a blank version. This method is OK to call if the `Variable` is a
+ /// view.
+ Variable detach() const;
+
+ /// Like `detach()`, but detaches this `Variable` in-place. This method may
+ /// only be called on non-view `Variable`s. You can use `is_view()` to check
+ /// this. If this `Variable` is a view, throws an `std::runtime_error()`.
+ void detach_();
+
+ // Hooks
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ void add_hook(std::shared_ptr<FunctionPreHook> hook);
+ const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
+ void clear_hooks();
+
+ // JIT Tracing
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ void set_tracing_state(jit::tracer::ValueTracingState* new_tracing_state);
+ jit::tracer::ValueTracingState& tracing_state() const noexcept;
+
+ /// Returns true if the `Variable`'s tracing state is not null.
+ bool has_tracing_state() const noexcept;
+
+ // View Variables
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Returns true if this `Variable` is a view of another `Variable`.
+ bool is_view() const noexcept;
+
+ /// Returns the `Variable` that this `Variable` is a view of. If this
+ /// `Variable` is not a view, throw a `std::runtime_error`.
+ const Variable& base() const;
+
+ // Miscellaneous
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Compares this `Variable` to another `Variable` (or `Tensor`) via
+ /// pointer-equality.
+ bool is_same(const Variable& other) const noexcept {
+ return this->pImpl == other.pImpl;
+ }
- inline const bool& requires_grad() const;
- inline bool& requires_grad();
+ void set_name(const std::string& name);
+ const std::string& name() const noexcept;
- inline const bool& is_volatile() const;
- inline bool& is_volatile();
+ PyObject* pyobj() const noexcept;
+ void set_pyobj(PyObject* pyobj) noexcept;
- inline Variable & operator=(Variable && rhs) &;
- inline Variable & operator=(const Variable & rhs) &;
- inline Variable & operator=(Tensor && rhs) &;
- inline Variable & operator=(const Tensor & rhs) &;
+ // Hacks!
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Sets the type of the underlying `Tensor`. Used for a bad (hopefully)
+ /// temporary hack in python_variable.h. If removed, also remove the `using
+ /// at::TensorImpl::type_;` in `Variable::Impl`.
+ void temporary_hack_set_type(at::Type*) noexcept;
+
+ private:
+ /// Private implementation struct of the `Variable`. This struct declaration
+ /// and the `get()` method which exposes it shall forever remain private and
+ /// never be exposed to the public interface of this class.
+ struct Impl;
+ struct ViewImpl;
+
+ // Private Methods
+ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ Variable(Variable::Impl* self, bool retain);
+ Impl* get() const noexcept;
};
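
A minimal usage sketch of the accessor-based API declared above (illustrative only, not part of the patch). It assumes `data` is a defined at::Tensor and uses the `make_variable` factory declared further down in this header; the function name is made up for the example.

// Sketch: exercising the accessor-based Variable API declared above.
#include <cstdint>
#include "torch/csrc/autograd/variable.h"

using torch::autograd::Variable;
using torch::autograd::make_variable;

void api_sketch(at::Tensor data) {
  Variable v = make_variable(data, /*requires_grad=*/true);

  bool leaf = v.is_leaf();              // true: no grad_fn was attached
  bool needs_grad = v.requires_grad();  // true: set at construction

  // detach() returns a copy cut off from the autograd graph;
  // detach_() severs this Variable in place (non-views only).
  Variable detached = v.detach();
  if (!v.is_view()) {
    v.detach_();
  }

  // Version bookkeeping used to track in-place modifications.
  v.bump_version();
  uint32_t version = v.current_version();

  // Identity is pointer equality on the underlying Impl, so the
  // detached copy is not the same Variable.
  bool same = v.is_same(detached);      // expected: false
  (void)leaf; (void)needs_grad; (void)version; (void)same;
}
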
-struct VariableImpl : public at::TensorImpl {
-public:
- explicit VariableImpl(at::Tensor data);
- VariableImpl(at::Tensor data, std::shared_ptr<Function> grad_fn);
- VariableImpl(at::Tensor data, bool requires_grad, bool is_volatile=false);
- virtual ~VariableImpl();
- virtual const char * toString() const override;
- virtual at::IntList sizes() const override;
- virtual at::IntList strides() const override;
- virtual int64_t dim() const override;
- virtual at::Scalar localScalar() override;
- virtual void assign_(at::Scalar s) override;
- virtual void * unsafeGetTH(bool retain) override;
- static const char * typeString();
-
- // Get the VariableType for a base Tensor type
- static at::Type* getType(const at::Type& baseType);
- static at::Type* getType(const at::Tensor& tensor);
-
-public:
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Variable::Impl
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+struct Variable::Impl : public at::TensorImpl {
+ explicit Impl(
+ at::Tensor data_,
+ bool requires_grad_ = false,
+ Edge edge = Edge());
+
+ virtual ~Impl();
+
+ const char* toString() const override;
+ at::IntList sizes() const override;
+ at::IntList strides() const override;
+ int64_t dim() const override;
+ at::Scalar localScalar() override;
+ void* unsafeGetTH(bool retain) override;
+ std::unique_ptr<at::Storage> storage() override;
+ static const char* typeString();
+
std::shared_ptr<Function> get_grad_accumulator();
+ virtual std::shared_ptr<Function>& get_grad_fn() {
+ return grad_fn;
+ }
+ // Make this field public so we can access it from `Variable`. Part of
+ // temporary_hack_set_type.
+ using at::TensorImpl::type_;
+
+ std::string name;
at::Tensor data;
+
Variable grad;
std::shared_ptr<Function> grad_fn;
+ std::weak_ptr<Function> grad_accumulator;
+
VariableVersion version_counter;
std::vector<std::shared_ptr<FunctionPreHook>> hooks;
- std::weak_ptr<Function> grad_accumulator;
- std::mutex grad_accumulator_lock;
- bool requires_grad;
- bool is_volatile;
+
+ bool requires_grad; // only meaningful on leaf variables (must be false
+ // otherwise)
+ bool is_view;
// The "output number" of this variable; e.g., if this variable
// was the second output of a function, then output_nr == 1.
// We use this to make sure we can setup the backwards trace
// correctly when this variable is passed to another function.
- int output_nr;
- PyObject *pyobj; // weak reference
+ uint32_t output_nr;
+ PyObject* pyobj; // weak reference
+
+ // Mutex to ensure that concurrent read operations that modify internal
+ // state are still thread-safe. Used by get_grad_fn and
+ // get_grad_accumulator.
+ std::mutex mutex;
// For use in torch::jit::tracer
auto_unique_ptr<jit::tracer::ValueTracingState> tracing_state;
- friend struct VariableType;
};
-inline Variable make_variable(at::Tensor data) {
- return Variable(new VariableImpl(std::move(data)), false);
-}
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Variable::ViewImpl
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// A Variable that is a view on another Variable. The base and view share the
+/// same version_counter. The grad_fn field of the Variable may become stale
+/// due to in-place modifications of the shared data. Accesses should go
+/// through get_grad_fn(). All other fields are always valid.
+struct Variable::ViewImpl : public Variable::Impl {
+ ViewImpl(Variable base_, at::Tensor data_, Edge gradient_edge);
+
+ /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
+ /// re-create the grad_fn to express the up-to-date view relationship between
+ /// this and the base Variable.
+ virtual std::shared_ptr<Function>& get_grad_fn() override;
+
+ /// Called after in-place modifications. Modifies the grad_fn of the base
+ /// Variable.
+ void rebase_history(Edge gradient_edge);
+
+ /// The base `Variable` (never a view).
+ Variable base;
+
+ /// The value of the version_counter at the time grad_fn was created. The
+ /// grad_fn field is stale if attr_version !=
+ /// version_counter.current_version().
+ uint32_t attr_version;
+};
-inline Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn) {
- return Variable(new VariableImpl(std::move(data), std::move(grad_fn)), false);
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Variable Implementation
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+// Factory Functions
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline Variable make_variable_view(
+ Variable base,
+ at::Tensor data,
+ Edge gradient_edge = Edge()) {
+ if (data.defined()) {
+ auto impl = new Variable::ViewImpl(
+ std::move(base), std::move(data), std::move(gradient_edge));
+ return Variable(impl, /*retain=*/false);
+ }
+ return Variable();
}
-inline Variable make_variable(at::Tensor data, bool requires_grad, bool is_volatile=false) {
- return Variable(new VariableImpl(std::move(data), requires_grad, is_volatile), false);
+inline Variable make_variable(at::Tensor data, bool requires_grad = false) {
+ if (data.defined()) {
+ auto impl = new Variable::Impl(data, requires_grad);
+ return Variable(impl, /*retain=*/false);
+ }
+ return Variable();
}
-
-inline Variable::Variable(VariableImpl * self, bool retain) : Tensor(self, retain) {
+inline Variable make_variable(at::Tensor data, Edge gradient_edge) {
+ if (data.defined()) {
+ auto impl = new Variable::Impl(data, false, std::move(gradient_edge));
+ return Variable(impl, /*retain=*/false);
+ }
+ return Variable();
}
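
A short construction sketch for the factory functions above (illustrative, not part of the patch). It assumes `data` and `view_data` are defined at::Tensors; the shared-version-counter remark follows the `Variable::ViewImpl` comment earlier in this header.

// Sketch: constructing Variables through the factories declared above.
#include "torch/csrc/autograd/variable.h"

using namespace torch::autograd;

void factory_sketch(at::Tensor data, at::Tensor view_data) {
  // Plain leaf Variable; requires_grad defaults to false.
  Variable leaf = make_variable(data);

  // Leaf that accumulates gradients.
  Variable weight = make_variable(data, /*requires_grad=*/true);

  // A view on `weight`: per ViewImpl above, it shares the base's version
  // counter and keeps its grad_fn current through ViewImpl::get_grad_fn().
  Variable view = make_variable_view(weight, view_data);
  bool is_view = view.is_view();           // true
  const Variable& view_base = view.base(); // `weight`

  // Undefined input tensors yield an undefined Variable.
  Variable empty = make_variable(at::Tensor());
  bool defined = empty.defined();          // false
  (void)is_view; (void)view_base; (void)defined;
}
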
-inline VariableImpl* Variable::get() const {
- return static_cast<VariableImpl*>(pImpl);
+// Tensor Conversion
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline Variable& as_variable_ref(at::Tensor& tensor) {
+#ifdef DEBUG
+ // dynamic_cast will return a nullptr if the `TensorImpl`'s dynamic type is
+ // not `Variable::Impl`.
+ if (dynamic_cast<Variable::Impl*>(tensor.get()) == nullptr) {
+ throw std::runtime_error(
+ "Attempted to cast a Tensor to a Variable, but "
+ "the dynamic type of the value is not Variable.");
+ }
+#endif
+ return static_cast<Variable&>(tensor);
}
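
A sketch of how the checked cast above might be used on a code path that receives a plain at::Tensor& already known to be Variable-backed (illustrative, not part of the patch):

// Sketch: treating a Variable-backed at::Tensor as a Variable in place.
#include "torch/csrc/autograd/variable.h"

using torch::autograd::Variable;
using torch::autograd::as_variable_ref;

void bump_if_modified(at::Tensor& tensor) {
  // `tensor` must actually be backed by a Variable::Impl; in DEBUG builds
  // as_variable_ref verifies this with a dynamic_cast and throws otherwise.
  Variable& var = as_variable_ref(tensor);
  var.bump_version();  // record an in-place modification
}
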
-inline const Tensor & Variable::data() const {
+inline const at::Tensor& Variable::data() const noexcept {
return get()->data;
}
-inline Tensor & Variable::data() {
+
+inline at::Tensor& Variable::data() noexcept {
return get()->data;
}
-inline Tensor Variable::opt_data() const {
- if (!defined()) {
- return Tensor();
- }
- return data();
+// Gradient Function and Edges
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline const std::shared_ptr<Function>& Variable::grad_fn() const {
+ return get()->get_grad_fn();
}
-inline const Variable & Variable::grad() const {
- return get()->grad;
+inline Function* Variable::grad_fn_unsafe() const {
+ return get()->grad_fn.get();
}
-inline Variable & Variable::grad() {
- return get()->grad;
+
+inline void Variable::set_grad_accumulator(
+ std::weak_ptr<Function> grad_accumulator) {
+ get()->grad_accumulator = std::move(grad_accumulator);
+}
+
+inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
+ return get()->grad_accumulator.lock();
}
-inline const std::shared_ptr<Function>& Variable::grad_fn() const {
- return get()->grad_fn;
-};
-inline std::shared_ptr<Function>& Variable::grad_fn() {
- return get()->grad_fn;
-};
inline std::shared_ptr<Function> Variable::grad_accumulator() const {
return get()->get_grad_accumulator();
-};
+}
-inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() const {
- return get()->hooks;
-};
-inline std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() {
- return get()->hooks;
-};
+inline void Variable::set_gradient_edge(Edge edge) noexcept {
+ get()->grad_fn = std::move(edge.function);
+ get()->output_nr = edge.input_nr;
+}
-inline auto_unique_ptr<jit::tracer::ValueTracingState>& Variable::tracing_state() const {
- return get()->tracing_state;
-};
+inline uint32_t Variable::output_nr() const noexcept {
+ return get()->output_nr;
+}
+
+inline bool Variable::is_leaf() const noexcept {
+ return get()->grad_fn == nullptr;
+}
+
+// The Grad Variable
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline const Variable& Variable::grad() const noexcept {
+ return get()->grad;
+}
+
+inline Variable& Variable::grad() noexcept {
+ return get()->grad;
+}
+
+inline void Variable::reset_grad() noexcept {
+ get()->grad.reset();
+}
-inline int Variable::current_version() const {
+inline void Variable::set_requires_grad(bool requires_grad) noexcept {
+ get()->requires_grad = requires_grad;
+}
+
+inline bool Variable::requires_grad() const noexcept {
+ return get()->requires_grad || get()->grad_fn ||
+ (is_view() && base().requires_grad());
+}
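
The accessor above encodes a three-way rule: a `Variable` requires grad if its leaf flag is set, if it has a `grad_fn`, or if it is a view of a base that requires grad. A small sketch of the first and last cases (illustrative, not part of the patch; tensors assumed defined):

// Sketch: the requires_grad() rule for leaves and views.
#include "torch/csrc/autograd/variable.h"

using namespace torch::autograd;

void requires_grad_sketch(at::Tensor data, at::Tensor view_data) {
  Variable leaf = make_variable(data, /*requires_grad=*/false);
  leaf.set_requires_grad(true);          // leaf flag => requires_grad() is true

  Variable view = make_variable_view(leaf, view_data);
  bool inherited = view.requires_grad(); // true: view of a requiring base
  (void)inherited;
}
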
+
+// Versions
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline void Variable::set_version_counter(
+ const VariableVersion& version_counter) noexcept {
+ get()->version_counter = version_counter;
+}
+
+inline void Variable::bump_version() noexcept {
+ get()->version_counter.bump();
+}
+
+inline uint32_t Variable::current_version() const noexcept {
return get()->version_counter.current_version();
}
-inline VariableVersion& Variable::version_counter() const {
+inline const VariableVersion& Variable::version_counter() const noexcept {
return get()->version_counter;
}
-inline const int& Variable::output_nr() const {
- return get()->output_nr;
+// Hooks
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
+ get()->hooks.push_back(std::move(hook));
}
-inline int& Variable::output_nr() {
- return get()->output_nr;
+inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
+ const noexcept {
+ return get()->hooks;
+}
+
+inline void Variable::clear_hooks() {
+ get()->hooks.clear();
}
-inline const bool& Variable::requires_grad() const {
- return get()->requires_grad;
+// JIT Tracing
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline bool Variable::has_tracing_state() const noexcept {
+ return get()->tracing_state != nullptr;
}
-inline bool& Variable::requires_grad() {
- return get()->requires_grad;
+
+// View Variables
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline bool Variable::is_view() const noexcept {
+ return get()->is_view;
}
-inline const bool& Variable::is_volatile() const {
- return get()->is_volatile;
+inline const Variable& Variable::base() const {
+ if (is_view()) {
+ return static_cast<Variable::ViewImpl*>(get())->base;
+ }
+ throw std::runtime_error("Can't get base of non-view");
}
-inline bool& Variable::is_volatile() {
- return get()->is_volatile;
+
+// Miscellaneous
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline void Variable::set_name(const std::string& name) {
+ get()->name = name;
}
-inline Variable & Variable::operator=(Variable && rhs) & {
- rhs.swap(*this);
- return *this;
+inline const std::string& Variable::name() const noexcept {
+ return get()->name;
}
-inline Variable & Variable::operator=(const Variable & rhs) & {
- Variable(rhs).swap(*this);
- return *this;
+
+inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
+ get()->pyobj = pyobj;
}
-inline Variable & Variable::operator=(Tensor && rhs) & {
- rhs.swap(*this);
- return *this;
+
+inline PyObject* Variable::pyobj() const noexcept {
+ return get()->pyobj;
}
-inline Variable & Variable::operator=(const Tensor & rhs) & {
- Variable(rhs).swap(*this);
- return *this;
+
+// Hacks!
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline void Variable::temporary_hack_set_type(at::Type* new_type) noexcept {
+ get()->type_ = new_type;
}
+// Private Methods
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+inline Variable::Variable(Variable::Impl* self, bool retain)
+ : at::Tensor(self, retain) {}
+
+inline Variable::Impl* Variable::get() const noexcept {
+ return static_cast<Variable::Impl*>(pImpl);
+}
}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/variable_version.h b/torch/csrc/autograd/variable_version.h
index 3dc28bd1..a5c5d9d3 100644
--- a/torch/csrc/autograd/variable_version.h
+++ b/torch/csrc/autograd/variable_version.h
@@ -1,98 +1,38 @@
#pragma once
+#include <atomic>
+#include <cstdint>
#include <memory>
-namespace torch { namespace autograd {
-
-struct VersionBlock {
- VersionBlock() : version(), live_refs(1) {}
-
- // monotonically increasing version
- std::atomic<int> version;
- // number of references excluding SavedVariables
- std::atomic<int> live_refs;
-};
+// Every Variable has a version counter. Version counters are incremented
+// whenever the data or shape of a tensor changes through Variable operations.
+// These are typically in-place operations. Version counters are used to
+// detect modifications to saved variables which would result in incorrect
+// gradient calculations. Version counters may be shared between Variables:
+//
+// 1. A view shares the version counter of the base Variable,
+// 2. Detached variables share the version counter of the source,
+// 3. Unpacked saved variables share the version counter of the source.
-struct SavedVersion;
+namespace torch { namespace autograd {
struct VariableVersion {
- VariableVersion() : version_block(std::make_shared<VersionBlock>()) {}
- VariableVersion(const VariableVersion&) = delete;
- VariableVersion(VariableVersion&&) = delete;
-
- ~VariableVersion() {
- --version_block->live_refs;
+ public:
+ // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
+ // leaves it in a persistently undefined state. See
+ // https://cplusplus.github.io/LWG/issue2334.
+ VariableVersion(uint32_t version = 0)
+ : version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {}
+
+ void bump() noexcept {
+ version_block_->fetch_add(1);
}
- // increment the version counter
- void increment() { version_block->version++; }
-
- // current version
- int current_version() const { return version_block->version.load(); }
-
- // number of variables using this version counter (excludes SavedVariables)
- int live_refs() const { return version_block->live_refs.load(); }
-
- // creates a saved reference with the current version and the counter
- inline SavedVersion save() const;
-
- // Uses another variable's version counter. Used for variables which share storages
- // NOTE: not thread-safe to call this from multiple threads without synchronization
- VariableVersion& operator=(const VariableVersion& other) {
- other.version_block->live_refs++;
- version_block->live_refs--;
- version_block = other.version_block;
- return *this;
+ uint32_t current_version() const noexcept {
+ return version_block_->load();
}
- // Uses the version counter from a SavedVariable
- // NOTE: not thread-safe to call this from multiple threads without synchronization
- inline VariableVersion& operator=(const SavedVersion& other);
-
-private:
- friend struct SavedVersion;
- std::shared_ptr<VersionBlock> version_block; // always non-null
+ private:
+ std::shared_ptr<std::atomic<uint32_t>> version_block_;
};
-
-// The version counter used in SavedVariables. Saves the expected_version (the
-// version at the time of save) and a reference to the version counter's
-// version_block.
-struct SavedVersion {
- SavedVersion() {}
- SavedVersion(const VariableVersion& version)
- : expected_version(version.current_version())
- , version_block(version.version_block) {}
-
- // if the version_block has been modified since when it was saved
- bool is_modified() const {
- return expected_version != version_block->version.load();
- }
-
- // true if the version_block is defined
- bool defined() const {
- return static_cast<bool>(version_block);
- }
-
-private:
- friend struct VariableVersion;
- int expected_version;
- std::shared_ptr<VersionBlock> version_block; // may be null
-};
-
-SavedVersion VariableVersion::save() const {
- return SavedVersion(*this);
-}
-
-VariableVersion& VariableVersion::operator=(const SavedVersion& other) {
- if (!other.version_block) {
- throw std::runtime_error(
- "Can't take version counter from empty SavedVersion. File a bug report.");
- }
- other.version_block->live_refs++;
- version_block->live_refs--;
- version_block = other.version_block;
- return *this;
-}
-
-
}} // namespace torch::autograd
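
A behavioral sketch of the slimmed-down counter above: copies share state through the shared_ptr, and a snapshot of `current_version()` taken at save time is enough to detect later in-place writes (illustrative only, not part of the patch):

// Sketch: shared version counters and staleness detection.
#include <cassert>
#include <cstdint>

#include "torch/csrc/autograd/variable_version.h"

using torch::autograd::VariableVersion;

int main() {
  VariableVersion counter;          // starts at 0
  VariableVersion alias = counter;  // copy shares the underlying atomic

  // A SavedVariable-style snapshot: remember the version at save time.
  const uint32_t expected_version = counter.current_version();

  counter.bump();                   // e.g. after an in-place modification

  assert(alias.current_version() == 1);                   // shared state
  assert(expected_version != counter.current_version());  // staleness detected
  return 0;
}
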