Add torch.autograd.differentiate
apaszke authored and soumith committed May 1, 2017
1 parent 20aa5b0 commit e5db8f9
Showing 6 changed files with 221 additions and 27 deletions.
2 changes: 2 additions & 0 deletions docs/source/autograd.rst
@@ -9,6 +9,8 @@ Automatic differentiation package - torch.autograd

.. autofunction:: backward

.. autofunction:: differentiate

Variable
--------

34 changes: 34 additions & 0 deletions test/test_autograd.py
@@ -124,19 +124,53 @@ def backward(ctx, grad_output):
def test_hessian_vector(self):
x = Variable(torch.randn(2, 2), requires_grad=True)
y = Variable(torch.randn(2, 2), requires_grad=True)

z = x ** 2 + y * x + y ** 2
z.backward(Variable(torch.ones(2, 2), requires_grad=True), retain_variables=True)

x_grad = 2 * x.data + y.data
y_grad = x.data + 2 * y.data
self.assertEqual(x.grad.data, x_grad)
self.assertEqual(y.grad.data, y_grad)

grad_sum = 2 * x.grad + y.grad
grad_sum.backward(torch.ones(2, 2))
x_hv = torch.ones(2, 2) * 5
y_hv = torch.ones(2, 2) * 4
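# These constants follow from the chain rule: for z = x**2 + y*x + y**2,
# dz/dx = 2*x + y and dz/dy = x + 2*y, so grad_sum = 2*dz/dx + dz/dy = 5*x + 4*y,
# and backpropagating ones through grad_sum gives constant 5s for x and 4s for y.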
self.assertEqual(x.grad.data, x_grad + x_hv)
self.assertEqual(y.grad.data, y_grad + y_hv)

def test_differentiate(self):
x = Variable(torch.randn(2, 2), requires_grad=True)
y = Variable(torch.randn(2, 2), requires_grad=True)
z = x ** 2 + y * x + y ** 2
z.backward(Variable(torch.ones(2, 2)), retain_variables=True)

x_grad = 2 * x.data + y.data
y_grad = x.data + 2 * y.data
self.assertEqual(x.grad.data, x_grad)
self.assertEqual(y.grad.data, y_grad)

grad_sum = 2 * x.grad + y.grad
x_hv = torch.autograd.differentiate(
outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)],
inputs=[x], only_inputs=True, retain_variables=True)
expected_x_hv = torch.ones(2, 2) * 5
expected_y_hv = torch.ones(2, 2) * 4

self.assertEqual(x_hv, expected_x_hv)
self.assertEqual(x.grad.data, x_grad)
self.assertEqual(y.grad.data, y_grad)

grad_sum = 2 * x.grad + y.grad
x_hv = torch.autograd.differentiate(
outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)],
inputs=[x], only_inputs=False)

self.assertEqual(x_hv, expected_x_hv)
self.assertEqual(x.grad.data, x_grad)
self.assertEqual(y.grad.data, y_grad + expected_y_hv)

def test_hooks(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
36 changes: 36 additions & 0 deletions torch/autograd/__init__.py
@@ -44,4 +44,40 @@ def backward(variables, grad_variables, retain_variables=False):
Variable._execution_engine.run_backward(
tuple(variables), grad_variables, retain_variables)


def differentiate(outputs, grad_outputs, inputs, only_inputs=True, retain_variables=True):
"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.
``grad_outputs`` should be a sequence of length matching ``outputs``,
containing the pre-computed gradients w.r.t. each of the outputs. If an
output doesn't require gradient, then its gradient can be ``None``.
Gradients can be given as Tensors when one doesn't need the graph of the
derivative, or as Variables, in which case the graph will be created.
If ``only_inputs`` is True, the function will only return a list of gradients
w.r.t. the specified inputs. If it's False, gradients w.r.t. all remaining
leaves will still be computed and accumulated into their ``.grad``
attributes.
Arguments:
outputs (sequence of Variable): outputs of the differentiated function.
grad_outputs (sequence of Tensor or Variable): Gradients w.r.t each output.
The jacobian will be multiplied by these vectors from the left.
inputs (sequence of Variable): Inputs w.r.t. which the gradient will be
returned (and not accumulated into ``.grad``).
only_inputs (bool, optional): If True, gradients w.r.t. leaves that are
part of the graph, but are not among ``inputs``, won't be computed or
accumulated.
retain_variables (bool, optional): If True, buffers necessary for
computing the gradients won't be freed after use. It is only
necessary to specify True if you want to differentiate any subgraph
again.
"""
grad_outputs = tuple(var if isinstance(var, Variable) or var is None
else Variable(var, volatile=True)
for var in grad_outputs)
return Variable._execution_engine.run_backward(
tuple(outputs), grad_outputs, retain_variables,
tuple(inputs), only_inputs)

assert torch._C._autograd_init()
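
A minimal usage sketch of differentiate, mirroring test_differentiate above; the 2x2 shapes and the all-ones grad_outputs are illustrative, not required by the API:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(2, 2), requires_grad=True)
y = Variable(torch.randn(2, 2), requires_grad=True)
z = x ** 2 + y * x + y ** 2
# The first backward pass populates x.grad and y.grad; keep the graph so
# the gradients can themselves be differentiated.
z.backward(Variable(torch.ones(2, 2)), retain_variables=True)

# Double backward: differentiate a function of the gradients, asking only
# for the gradient w.r.t. x. With only_inputs=True the result is returned
# here and nothing extra is accumulated into .grad.
grad_sum = 2 * x.grad + y.grad
x_hv = torch.autograd.differentiate(
    outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)],
    inputs=[x], only_inputs=True, retain_variables=True)

With only_inputs=False (the second half of the test), gradients w.r.t. the remaining leaves, here y, are still computed and accumulated into their .grad attributes.
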
55 changes: 35 additions & 20 deletions torch/csrc/autograd/engine.cpp
@@ -3,6 +3,7 @@
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <set>
@@ -32,12 +33,12 @@ namespace torch { namespace autograd {
struct FunctionTask {
GraphTask* base;
std::shared_ptr<Function> fn;
InputBuffer grad;
InputBuffer inputs;

FunctionTask(GraphTask* base, std::shared_ptr<Function> fn, InputBuffer grad)
FunctionTask(GraphTask* base, std::shared_ptr<Function> fn, InputBuffer inputs)
: base(base)
, fn(fn)
, grad(std::move(grad)) {}
, inputs(std::move(inputs)) {}
};

struct ReadyQueue {
@@ -58,17 +59,19 @@ struct GraphTask {

std::mutex mutex;
std::condition_variable not_done;
const Engine::callback_map& function_callbacks;
std::unordered_map<Function*, InputBuffer> not_ready;
std::unordered_map<Function*, int> dependencies;

GraphTask(bool keep_graph)
GraphTask(bool keep_graph, const Engine::callback_map& function_callbacks)
: exception()
, has_error(false)
, outstanding_tasks(0)
, keep_graph(keep_graph)
, has_any_work(false)
, mutex()
, not_done()
, function_callbacks(function_callbacks)
, not_ready()
, dependencies() {}
};
@@ -120,28 +123,40 @@ auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void
}
}

static variable_list call_pre_hooks(Function& fn, variable_list grad_output) {
static variable_list call_pre_hooks(Function& fn, variable_list inputs) {
for (auto& hook : fn.pre_hooks) {
grad_output = (*hook)(grad_output);
inputs = (*hook)(inputs);
}
return grad_output;
return inputs;
}

static variable_list call_post_hooks(Function& fn, variable_list grad_input, variable_list grad_output) {
static variable_list call_post_hooks(Function& fn, variable_list outputs, variable_list inputs) {
for (auto& hook : fn.post_hooks) {
grad_input = (*hook)(grad_input, grad_output);
outputs = (*hook)(outputs, inputs);
}
return grad_input;
return outputs;
}

static variable_list call_function(FunctionTask& task) {
auto grad_output = call_pre_hooks(*task.fn, InputBuffer::variables(std::move(task.grad)));
auto grad_input = task.fn->apply(grad_output);
return call_post_hooks(*task.fn, std::move(grad_input), std::move(grad_output));
static std::pair<bool, variable_list> call_function(FunctionTask& task) {
auto& fn = *task.fn;
auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(task.inputs)));

auto& function_callbacks = task.base->function_callbacks;
auto callback_it = function_callbacks.find(&fn);
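// A callback registered for this function runs before fn is applied; if it
// returns false, execution of fn is vetoed, no outputs are produced, and
// evaluate_function returns early for this task.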
if (callback_it != function_callbacks.end()) {
auto& callback = callback_it->second;
if (!callback(&fn, inputs)) return std::make_pair(false, variable_list());
}

auto fn_outputs = fn.apply(inputs);
auto outputs = call_post_hooks(fn, std::move(fn_outputs), std::move(inputs));
return std::make_pair(true, std::move(outputs));
}

auto Engine::evaluate_function(FunctionTask& task) -> void {
auto outputs = call_function(task);
auto call_result = call_function(task);
if (!call_result.first) return;
auto outputs = call_result.second;

auto& fn = *task.fn;
if (!task.base->keep_graph) {
@@ -291,12 +306,12 @@ auto Engine::find_roots(const function_list& input_roots,
}

auto Engine::execute(const function_list& input_roots,
variable_list& inputs,
bool keep_graph) -> void {
static std::once_flag once_flag;
std::call_once(once_flag, &Engine::start_threads, this);
variable_list& inputs,
bool keep_graph,
const callback_map& callbacks) -> void {
std::call_once(start_threads_flag, &Engine::start_threads, this);

GraphTask graph_task(keep_graph);
GraphTask graph_task(keep_graph, callbacks);
std::unique_lock<std::mutex> lock(graph_task.mutex);

// Find the unique roots and backprop into variables.
7 changes: 6 additions & 1 deletion torch/csrc/autograd/engine.h
@@ -27,13 +27,17 @@ struct Engine {
using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
using function_queue = std::vector<Function*>;
using dependencies_type = std::unordered_map<Function*, int>;
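// Per-function callbacks run just before a Function executes; a callback that
// returns false skips that Function. The Python engine uses these to capture
// gradients for selected inputs and to disable unneeded parts of the graph.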
using callback_type = std::function<bool (Function*, variable_list&)>;
using callback_map = std::unordered_map<Function*, callback_type>;


// Given a list of (Function, int) pairs computes the value of the graph
// by following next_function references.
void execute(
const function_list& roots,
variable_list& inputs,
bool keep_graph);
bool keep_graph,
const callback_map& callbacks = callback_map());

protected:
function_queue find_roots(
@@ -48,6 +52,7 @@
virtual void thread_main(std::shared_ptr<ReadyQueue> queue);
virtual void thread_on_exception(FunctionTask& task, std::exception& e);

std::once_flag start_threads_flag;
std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
};

114 changes: 108 additions & 6 deletions torch/csrc/autograd/python_engine.cpp
@@ -5,6 +5,8 @@
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/utils/auto_gil.h"

#include <unordered_set>

using namespace torch::autograd;

struct THPEngine {
@@ -34,16 +36,79 @@ static PythonEngine engine;

PyObject *THPEngineClass = NULL;

// Main backward function
struct CallbackContext {
std::mutex mutex;
std::string error;
THPObjectPtr outputs;
};

void compute_partial_exec_callbacks(const function_list& roots,
const std::vector<Function*> inputs,
Engine::callback_map& map) {
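// Register an "abort" callback for every function that does not lie on a
// path from a root to one of the requested inputs (their grad accumulators).
// Two passes: a DFS from the roots records reversed next_function edges,
// then a walk from the inputs over the reversed graph collects the set of
// functions that are actually needed.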
static Engine::callback_type abort_callback(
[](Function* fn, variable_list &vars) { return false; });

std::vector<Function*> queue;
std::unordered_set<Function*> seen; // for the initial DFS
std::unordered_set<Function*> needed; // functions to compute
std::unordered_map<Function*, std::vector<Function*>> rev_graph;

// Reverse the next_fn edges
queue.reserve(roots.size());
for (auto& root : roots) {
auto ptr = root.first.get();
queue.emplace_back(ptr);
seen.insert(ptr);
}
while (!queue.empty()) {
auto fn = queue.back(); queue.pop_back();
for (auto& next_fn_pair : fn->next_functions) {
auto next_fn = next_fn_pair.first.get();
if (!next_fn) continue;
rev_graph[next_fn].push_back(fn);
if (seen.insert(next_fn).second) {
queue.push_back(next_fn);
}
}
}
auto all_functions = std::move(seen); // this is cheap and improves readability

// Find all functions we need to compute
queue.clear();
for (auto input: inputs) {
auto& rev_edges = rev_graph[input];
if (rev_edges.size() == 0) throw std::runtime_error("unreachable input");
queue.emplace_back(input);
needed.insert(input);
}

while (!queue.empty()) {
auto fn = queue.back(); queue.pop_back();
for (auto rev_next_fn : rev_graph[fn]) {
if (needed.insert(rev_next_fn).second) {
queue.push_back(rev_next_fn);
}
}
}

// Prevent expansion of functions in {all_functions} \ {needed}
for (auto fn : all_functions) {
if (needed.count(fn) > 0) continue;
map.emplace(fn, abort_callback);
}
}

PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
PyObject *variables = NULL;
PyObject *grad_variables = NULL;
unsigned char keep_graph = 0;
PyObject *inputs = NULL;
unsigned char only_inputs = 0;
const char *accepted_kwargs[] = {"variables", "grad_variables",
"keep_graph", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOb", (char**)accepted_kwargs,
&variables, &grad_variables, &keep_graph))
"keep_graph", "inputs", "only_inputs", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOb|Ob", (char**)accepted_kwargs,
&variables, &grad_variables, &keep_graph, &inputs, &only_inputs))
return NULL;

THPUtils_assert(PyTuple_Check(variables), "variables argument is expected to "
@@ -78,9 +143,42 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
}
}

Engine::callback_map callbacks;
CallbackContext ctx;
if (inputs != NULL) {
THPUtils_assert(PyTuple_Check(inputs), "inputs argument has to be a tuple");
int num_inputs = PyTuple_GET_SIZE(inputs);
ctx.outputs = PyTuple_New(num_inputs);
std::vector<Function*> grad_accumulators;
for (int i = 0; i < num_inputs; ++i) {
PyObject *input = PyTuple_GET_ITEM(inputs, i);
THPUtils_assert(THPVariable_Check(input),
"all inputs have to be Variables, but got %s", THPUtils_typename(input));
THPVariable *input_var = (THPVariable*)input;
auto grad_acc = input_var->cdata->grad_accumulator.lock();
// TODO: maybe just return a zero tensor?
THPUtils_assert(grad_acc, "One of the differentiated Variables appears to not have "
"been used in any computation");
grad_accumulators.push_back(grad_acc.get());
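// Intercept this input's grad accumulator: stash the incoming gradient in
// the output tuple and return false so it is not accumulated into .grad.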
callbacks.emplace(grad_acc.get(), [&ctx, i](Function* _unused, variable_list& grads) {
std::lock_guard<std::mutex> guard(ctx.mutex);
if (grads.size() != 1) {
ctx.error = "expected to get a single gradient, but got ";
ctx.error += std::to_string(grads.size());
}
PyTuple_SET_ITEM(ctx.outputs.get(), i, THPVariable_Wrap(grads[0]));
return false;
});
}
// Disable execution for all unneeded functions
if (only_inputs) {
compute_partial_exec_callbacks(roots, grad_accumulators, callbacks);
}
}

try {
AutoNoGIL no_gil;
engine.execute(roots, grads, keep_graph);
engine.execute(roots, grads, keep_graph, callbacks);
} catch (python_error &e) {
e.restore();
return nullptr;
@@ -89,7 +187,11 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
return nullptr;
}

Py_RETURN_NONE;
if (ctx.outputs) {
return ctx.outputs.release();
} else {
Py_RETURN_NONE;
}
}

PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)