Fix bugs, rename differentiate to grad, make it more flexible
apaszke authored and soumith committed May 1, 2017
1 parent 87164f5 commit 5c74534
Showing 14 changed files with 339 additions and 100 deletions.
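As a quick orientation before the per-file diffs, here is a minimal sketch of how a call site changes under the rename; the variable names and shapes are illustrative only and are not taken from this commit. With the new torch.autograd.grad, grad_outputs may be omitted for scalar outputs and single Variables may be passed instead of sequences.

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(2, 2), requires_grad=True)
    y = (x ** 2).sum()

    # Old API removed by this commit:
    # dx, = torch.autograd.differentiate(
    #     outputs=[y], grad_outputs=[torch.ones(1)], inputs=[x],
    #     only_inputs=True, retain_variables=True)

    # New API added by this commit; grad_outputs defaults to ones for a
    # scalar output, and single Variables are accepted directly.
    dx, = torch.autograd.grad(outputs=y, inputs=x)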
2 changes: 1 addition & 1 deletion docs/source/autograd.rst
@@ -9,7 +9,7 @@ Automatic differentiation package - torch.autograd

.. autofunction:: backward

.. autofunction:: differentiate
.. autofunction:: grad

Variable
--------
92 changes: 84 additions & 8 deletions test/test_autograd.py
@@ -118,8 +118,34 @@ def backward(ctx, grad_output):
x, y = self._function_test(MyFunction)
x_grad_desc = graph_desc(x.grad.grad_fn)
y_grad_desc = graph_desc(y.grad.grad_fn)
self.assertEqual(graph_desc(x.grad.grad_fn), 'Identity(Error())')
self.assertEqual(graph_desc(y.grad.grad_fn), 'Identity(Error())')
self.assertEqual(graph_desc(x.grad.grad_fn),
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
self.assertEqual(graph_desc(y.grad.grad_fn),
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')

def test_accumulate_grad(self):
import sys

grad_output = Variable(torch.ones(5, 5))
for start_volatile, end_volatile in product((True, False), repeat=2):
go1 = grad_output.data if start_volatile else grad_output
go2 = grad_output.data if end_volatile else grad_output

x = Variable(torch.randn(5, 5), requires_grad=True)
y = x + 2
y.backward(go1, retain_variables=True)
x_grad = x.grad
x_grad_clone = x.grad.data.clone()

del x
y.backward(go2)

# That's the only case when we can accumulate in-place
if start_volatile and end_volatile:
expected_grad = x_grad_clone * 2
else:
expected_grad = x_grad_clone
self.assertEqual(x_grad.data, expected_grad)

def test_hessian_vector(self):
x = Variable(torch.randn(2, 2), requires_grad=True)
@@ -140,7 +166,7 @@ def test_hessian_vector(self):
self.assertEqual(x.grad.data, x_grad + x_hv)
self.assertEqual(y.grad.data, y_grad + y_hv)

def test_differentiate(self):
def test_grad(self):
x = Variable(torch.randn(2, 2), requires_grad=True)
y = Variable(torch.randn(2, 2), requires_grad=True)
z = x ** 2 + y * x + y ** 2
@@ -152,9 +178,9 @@ def test_differentiate(self):
self.assertEqual(y.grad.data, y_grad)

grad_sum = 2 * x.grad + y.grad
x_hv = torch.autograd.differentiate(
x_hv = torch.autograd.grad(
outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)],
inputs=[x], only_inputs=True, retain_variables=True)
inputs=[x], create_graph=True, only_inputs=True)
expected_x_hv = torch.ones(2, 2) * 5
expected_y_hv = torch.ones(2, 2) * 4

@@ -163,14 +189,64 @@ def test_differentiate(self):
self.assertEqual(y.grad.data, y_grad)

grad_sum = 2 * x.grad + y.grad
x_hv = torch.autograd.differentiate(
outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)],
inputs=[x], only_inputs=False)
x_hv = torch.autograd.grad(
outputs=grad_sum, inputs=x,
grad_outputs=torch.ones(2, 2),
only_inputs=False)

self.assertEqual(x_hv, expected_x_hv)
self.assertEqual(x.grad.data, x_grad)
self.assertEqual(y.grad.data, y_grad + expected_y_hv)

def test_grad_nonleaf(self):
x_init = Variable(torch.randn(2, 2), requires_grad=True)
x = x_init
y = Variable(torch.randn(2, 2), requires_grad=True)
grad_output = torch.ones(2, 2)

def fn(x):
return x ** 2 + y * x + y ** 2

for i in range(5):
grad_x, = torch.autograd.grad(
fn(x), x, grad_outputs=grad_output, create_graph=True)

grad_x_expected = 2 * x.data + y.data
self.assertIsNone(y.grad)
self.assertIsNone(x.grad)
self.assertEqual(grad_x, grad_x_expected)

x = x + 0.05 * grad_x

val_init = fn(x_init).data.sum()
val_final = fn(x).data.sum()
self.assertGreater(val_final, val_init)

x.backward(grad_output)
self.assertIsNotNone(y.grad)
self.assertIsNotNone(x_init.grad)

def test_grad_nonleaf_many_outputs(self):
# This checks an edge case for function callbacks
# We want to capture two grads of a function, but can only
# register a single callback.
x = Variable(torch.randn(4, 2), requires_grad=True)
a, b = x.chunk(2)

def hook(*grads):
hook_called[0] = True
hook_called = [False]
x.register_hook(hook)

go = torch.randn(2, 2)
grad_a, grad_b = torch.autograd.grad(
(a + 2 * b), [a, b], grad_outputs=go, create_graph=True)

self.assertEqual(grad_a, go)
self.assertEqual(grad_b, go * 2)
self.assertFalse(hook_called[0])
self.assertIsNone(x.grad)

def test_hooks(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
118 changes: 92 additions & 26 deletions torch/autograd/__init__.py
@@ -5,6 +5,7 @@
:class:`.Variable` objects.
"""
import torch
import warnings

from .variable import Variable
from .function import Function, NestedIOFunction
@@ -14,7 +15,35 @@
__all__ = ['Variable', 'Function', 'StochasticFunction', 'backward']


def backward(variables, grad_variables, retain_variables=False):
def _make_grads(outputs, grads, user_create_graph):
if user_create_graph is not None:
create_graph = user_create_graph
else:
create_graph = any(isinstance(grad, Variable) and not grad.volatile
for grad in grads)

new_grads = []
for out, grad in zip(outputs, grads):
if isinstance(grad, Variable):
new_grads.append(grad)
elif torch.is_tensor(grad):
new_grads.append(Variable(grad, volatile=not create_graph))
elif grad is None:
if out.requires_grad:
if out.numel() != 1:
raise RuntimeError("grad can be implicitly created only for scalar outputs")
data = out.data
new_grads.append(
Variable(data.new().resize_as_(data).fill_(1), volatile=not create_graph))
else:
new_grads.append(None)
else:
raise TypeError("gradients can be either Tensors, Variables or None, but got " +
type(grad).__name__)
return tuple(new_grads), create_graph


def backward(variables, grad_variables=None, retain_graph=None, create_graph=None, retain_variables=None):
"""Computes the sum of gradients of given variables w.r.t. graph leaves.
The graph is differentiated using the chain rule. If any of ``variables``
@@ -30,22 +59,41 @@ def backward(variables, grad_variables, retain_variables=False):
Arguments:
variables (sequence of Variable): Variables of which the derivative will be
computed.
grad_variables (sequence of Tensor): Gradients w.r.t. each element of
corresponding variables. Required only for non-scalar variables that
require gradient.
retain_variables (bool): If ``True``, buffers necessary for computing
gradients won't be freed after use. It is only necessary to
specify ``True`` if you want to differentiate some subgraph multiple
times.
grad_variables (sequence of (Tensor, Variable or None)): Gradients w.r.t.
each element of corresponding variables. Any tensors will be
automatically converted to Variables that are volatile unless
``create_graph`` is True. None values can be specified for scalar
Variables or ones that don't require grad. If a None value would
be acceptable for all grad_variables, then this argument is optional.
retain_graph (bool, optional): If False, the graph used to compute the grad
will be freed. Note that in nearly all cases setting this option to True
is not needed and often can be worked around in a much more efficient
way. Defaults to the value of ``create_graph``.
create_graph (bool, optional): If True, the graph of the derivative will
be constructed, allowing higher order derivative products to be computed.
Defaults to False, unless ``grad_variables`` contains at least one
non-volatile Variable.
"""
grad_variables = tuple(var if isinstance(var, Variable) or var is None
else Variable(var, volatile=True)
for var in grad_variables)
variables = tuple(variables)

if grad_variables is None:
grad_variables = (None,) * len(variables)
grad_variables, create_graph = _make_grads(variables, list(grad_variables), create_graph)

if retain_variables is not None:
if retain_graph is not None:
raise ValueError("only one of retain_graph and retain_variables can be specified")
retain_graph = retain_variables
warnings.warn("retain_variables option is deprecated and will be removed in 0.3. "
"Use retain_graph instead.")
elif retain_graph is None:
retain_graph = create_graph

Variable._execution_engine.run_backward(
tuple(variables), grad_variables, retain_variables)
variables, grad_variables, retain_graph)
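
A hedged usage sketch of the reworked backward signature documented above; the tensors and shapes are made up for illustration. retain_graph is only needed here because the same graph is backpropagated twice.

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(3), requires_grad=True)
    y = x * 2

    # Plain Tensors are accepted as grad_variables; they are wrapped in
    # volatile Variables internally unless create_graph is set.
    torch.autograd.backward([y], [torch.ones(3)], retain_graph=True)
    torch.autograd.backward([y], [torch.ones(3)])  # graph is freed after this call

    # The old keyword still works but now emits a deprecation warning:
    # torch.autograd.backward([y], [torch.ones(3)], retain_variables=True)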


def differentiate(outputs, grad_outputs, inputs, only_inputs=True, retain_variables=True):
def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=None, only_inputs=True):
"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.
``grad_outputs`` should be a sequence of length matching ``outputs``
@@ -61,23 +109,41 @@ def differentiate(outputs, grad_outputs, inputs, only_inputs=True, retain_variab
Arguments:
outputs (sequence of Variable): outputs of the differentiated function.
grad_outputs (sequence of Tensor or Variable): Gradients w.r.t each output.
The jacobian will be multiplied by these vectors from the left.
inputs (sequence of Variable): Inputs w.r.t. which the gradient will be
returned (and not accumulated into ``.grad``).
grad_outputs (sequence of Tensor or Variable): Gradients w.r.t. each output.
Any tensors will be automatically converted to Variables that are
volatile unless ``create_graph`` is True. None values can be
specified for scalar Variables or ones that don't require grad.
If a None value would be acceptable for all grad_outputs, then
this argument is optional.
retain_graph (bool, optional): If False, the graph used to compute the grad
will be freed. Note that in nearly all cases setting this option to True
is not needed and often can be worked around in a much more efficient
way. Defaults to the value of ``create_graph``.
create_graph (bool, optional): If True, the graph of the derivative will
be constructed, allowing higher order derivative products to be computed.
Defaults to False, unless ``grad_outputs`` contains at least one
non-volatile Variable.
only_inputs (bool, optional): If True, gradient w.r.t. leaves that are
part of the graph, but are not in ``inputs`` won't be computed and
accumulated.
retain_variables (bool, optional): If True, buffers necessary for
computing the gradients won't be freed after use. It is only
necessary to specify True if you want to differentiate any subgraph
again.
part of the graph, but don't appear in ``inputs`` won't be computed
and accumulated. Defaults to True.
"""
grad_outputs = tuple(var if isinstance(var, Variable) or var is None
else Variable(var, volatile=True)
for var in grad_outputs)

outputs = (outputs,) if isinstance(outputs, Variable) else tuple(outputs)
inputs = (inputs,) if isinstance(inputs, Variable) else tuple(inputs)
if grad_outputs is None:
grad_outputs = (None,) * len(outputs)
elif isinstance(grad_outputs, Variable) or torch.is_tensor(grad_outputs):
grad_outputs = (grad_outputs,)

grad_outputs, create_graph = _make_grads(outputs, grad_outputs, create_graph)
if retain_graph is None:
retain_graph = create_graph

return Variable._execution_engine.run_backward(
tuple(outputs), grad_outputs, retain_variables,
tuple(inputs), only_inputs)
outputs, grad_outputs, retain_graph,
inputs, only_inputs)


assert torch._C._autograd_init()
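
A sketch of the double-backward pattern that grad enables when create_graph is set, mirroring test_grad above; the expressions are illustrative, and the second call differentiates a function of the first-order gradients (a Hessian-vector product).

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(2, 2), requires_grad=True)
    y = Variable(torch.randn(2, 2), requires_grad=True)
    z = x ** 2 + y * x + y ** 2

    # First-order gradients; create_graph=True records the backward pass
    # so that the returned gradients can themselves be differentiated.
    grad_x, grad_y = torch.autograd.grad(
        z, [x, y], grad_outputs=torch.ones(2, 2), create_graph=True)

    # Hessian-vector product: differentiate a function of the gradients.
    # only_inputs=True (the default) leaves y.grad untouched.
    x_hv, = torch.autograd.grad(
        2 * grad_x + grad_y, x, grad_outputs=torch.ones(2, 2))
    # Since d/dx (2*(2x + y) + (x + 2y)) = 5, x_hv should equal
    # 5 * torch.ones(2, 2), the same value test_grad checks for.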
27 changes: 15 additions & 12 deletions torch/autograd/function.py
@@ -172,13 +172,6 @@ def backward(*grad_outputs):
raise NotImplementedError


class Error(_C._FunctionBase):

def apply(self, *args, **kwargs):
raise RuntimeError("trying to differentiate twice a function that was marked"
"with @once_differentiable")


def once_differentiable(fn):
from .variable import Variable

@@ -187,19 +180,29 @@ def wrapper(ctx, *args):
tensor_args = [arg.data if isinstance(arg, Variable) else arg
for arg in args]
outputs = fn(ctx, *tensor_args)
# XXX: this is only an approximation of these flags - there's no way
# to figure out if fn didn't use ctx.saved_variables and as a result
# some Variables might require grad, even if no args do.
# Unfortunately, this leads to unexpected error messages ("no nodes
# require computing gradients"), but I don't have a better idea.
# These functions would raise an error in backward anyway.
volatile = any(arg.volatile if isinstance(arg, Variable) else False
for arg in args)
requires_grad = any(arg.requires_grad if isinstance(arg, Variable) else False
for arg in args)
if volatile:
def err_fn(*args):
return args
kwargs = {'volatile': True}
else:
err_fn = Error()
err_fn.requires_grad = requires_grad
kwargs = {'requires_grad': requires_grad, '_grad_fn': err_fn}
err_fn = torch._C._functions.DelayedError(
b"trying to differentiate twice a function that was marked"
b"with @once_differentiable")
kwargs = {'requires_grad': requires_grad}
if not isinstance(outputs, tuple):
return Variable(outputs, **kwargs) if outputs is not None else None
return tuple([Variable(o, **kwargs) if o is not None else None
var = Variable(outputs, **kwargs) if outputs is not None else None
return err_fn(var)
return err_fn(*[Variable(o, **kwargs) if o is not None else None
for o in outputs])
return wrapper
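
To make the DelayedError plumbing above concrete, here is a hedged sketch of a custom Function marked @once_differentiable; the MulByTwo class and the requires_grad grad output are invented for illustration. Backpropagating through the resulting gradient is the kind of double differentiation that is expected to hit the "differentiate twice" RuntimeError.

    import torch
    from torch.autograd import Variable, Function
    from torch.autograd.function import once_differentiable

    class MulByTwo(Function):

        @staticmethod
        def forward(ctx, x):
            return x * 2

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            # grad_output arrives here as a raw Tensor, not a Variable.
            return grad_output * 2

    x = Variable(torch.randn(3), requires_grad=True)
    y = MulByTwo.apply(x)
    go = Variable(torch.ones(3), requires_grad=True)
    torch.autograd.backward([y], [go], create_graph=True)

    # x.grad is now connected through a DelayedError node, so trying to
    # differentiate it again is expected to raise the RuntimeError above:
    # torch.autograd.backward([x.grad], [torch.ones(3)])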
