Add support for the new Function format
apaszke authored and soumith committed May 1, 2017
1 parent 702a2e3 commit de9998e
Showing 15 changed files with 568 additions and 281 deletions.
47 changes: 47 additions & 0 deletions test/test_autograd.py
@@ -33,6 +33,53 @@ def backward_engine(engine):

class TestAutograd(TestCase):

def test_function(self):
class MyFunction(Function):

@staticmethod
def forward(ctx, tensor1, scalar, tensor2):
ctx.scalar = scalar
ctx.save_for_backward(tensor1, tensor2)
return tensor1 + scalar * tensor2 + tensor1 * tensor2

@staticmethod
def backward(ctx, grad_output):
var1, var2 = ctx.saved_variables
return (grad_output + grad_output * var2,
grad_output * ctx.scalar + grad_output * var1)

x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)
result = MyFunction.apply(x, 2, y)
go = Variable(torch.ones(1), requires_grad=True)
result.sum().backward(go)

self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
self.assertEqual(y.grad.data, x.data + torch.ones(5, 5) * 2)

self.assertFalse(x.grad.volatile)
self.assertFalse(y.grad.volatile)
self.assertIsNotNone(x.grad.grad_fn)
self.assertIsNotNone(y.grad.grad_fn)

def desc_graph(fn):
result = type(fn).__name__ + '('
next_functions = fn.next_functions
for next_fn, _ in next_functions:
result += desc_graph(next_fn)
result += ', '
if next_functions:
result = result[:-2]
return result + ')'
x_grad_desc = desc_graph(x.grad.grad_fn)
y_grad_desc = desc_graph(y.grad.grad_fn)
self.assertEqual(
x_grad_desc,
'Identity(Add(Error(AccumulateGrad()), Mul(Error(AccumulateGrad()), AccumulateGrad())))')
self.assertEqual(
y_grad_desc,
'Identity(Add(MulConstant(Error(AccumulateGrad())), Mul(Error(AccumulateGrad()), AccumulateGrad())))')

def test_hooks(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
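The gradient values asserted in test_function above follow from differentiating the forward formula element-wise. With scalar = 2 and an all-ones grad_output, as in the test:

    f(x, y) = x + 2y + x \odot y, \qquad \frac{\partial f}{\partial x} = 1 + y, \qquad \frac{\partial f}{\partial y} = 2 + x

so the two assertEqual calls check exactly that x.grad equals y + 1 and y.grad equals x + 2.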
15 changes: 5 additions & 10 deletions test/test_nn.py
@@ -221,16 +221,11 @@ def _zero_grad_parameters(self, module):
def _get_parameters(self, module):
params = []
d_params = []
if hasattr(module, 'weight') and module.weight is not None:
params += [module.weight.data]
if module.weight.grad is None:
module.weight._grad = Variable(module.weight.data.clone().zero_())
d_params += [module.weight.grad.data]
if hasattr(module, 'bias') and module.bias is not None:
params += [module.bias.data]
if module.bias.grad is None:
module.bias._grad = Variable(module.bias.data.clone().zero_())
d_params += [module.bias.grad.data]
for p in module.parameters():
if p.grad is None:
p._grad = Variable(p.data.clone().zero_(), volatile=True)
params.append(p.data)
d_params.append(p.grad.data)
return params, d_params

def test_hooks(self):
31 changes: 31 additions & 0 deletions torch/_six.py
@@ -0,0 +1,31 @@
# Copyright (c) 2010-2017 Benjamin Peterson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(meta):

def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
return type.__new__(metaclass, 'temporary_class', (), {})
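
A short usage sketch for the helper above (the Meta, Base, and Child names are illustrative, not part of this commit): with_metaclass builds a temporary class whose metaclass replaces itself with meta(name, bases, d) at class-creation time, which is how function.py below can declare Function under FunctionMeta in a way that works on both Python 2 and Python 3.

from torch._six import with_metaclass

class Meta(type):
    # runs once for every class created under this metaclass
    def __init__(cls, name, bases, attrs):
        cls.created_by_meta = True
        super(Meta, cls).__init__(name, bases, attrs)

class Base(object):
    pass

class Child(with_metaclass(Meta, Base)):
    pass

assert Child.created_by_meta        # Meta.__init__ ran when Child was created
assert issubclass(Child, Base)      # the real bases were preserved
assert type(Child) is Meta          # the temporary metaclass replaced itself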
115 changes: 79 additions & 36 deletions torch/autograd/function.py
@@ -1,43 +1,11 @@
import torch
import torch._C as _C
import torch.utils.hooks as hooks
from torch._six import with_metaclass
from collections import OrderedDict


class Function(_C._FunctionBase):
"""Records operation history and defines formulas for differentiating ops.
Every operation performed on :class:`Variable` s creates a new function
object that performs the computation and records that it happened.
The history is retained in the form of a DAG of functions, with edges
denoting data dependencies (``input <- output``). Then, when backward is
called, the graph is processed in the topological ordering, by calling
:func:`backward` methods of each :class:`Function` object, and passing
returned gradients on to next :class:`Function` s.
Normally, the only way users interact with functions is by creating
subclasses and defining new operations. This is a recommended way of
extending torch.autograd.
Since Function logic is a hotspot in most scripts, almost all of it
was moved to our C backend, to ensure that the framework overhead is
minimal.
Each function is meant to be used only once (in the forward pass).
Attributes:
saved_tensors: Tuple of Tensors that were saved in the call to
:func:`forward`.
needs_input_grad: Tuple of booleans of length :attr:`num_inputs`,
indicating whether a given input requires gradient. This can be
used to optimize buffers saved for backward, and ignoring gradient
computation in :func:`~Function.backward`.
num_inputs: Number of inputs given to :func:`forward`.
num_outputs: Number of tensors returned by :func:`forward`.
requires_grad: Boolean indicating whether the :func:`backward` will
ever need to be called.
"""
__call__ = _C._FunctionBase._do_forward
class _ContextMethodMixin(object):

def save_for_backward(self, *tensors):
"""Saves given tensors for a future call to :func:`~Function.backward`.
@@ -102,6 +70,9 @@ def mark_non_differentiable(self, *args):
"""
self.non_differentiable = args


class _HookMixin(object):

@staticmethod
def _register_hook(backward_hooks, hook):
if backward_hooks is None:
@@ -110,7 +81,78 @@ def _register_hook(backward_hooks, hook):
backward_hooks[handle.id] = hook
return backward_hooks, handle

def forward(self, *input):

class CFunction(object):
_is_legacy = False

def apply(self, *args, **kwargs):
raise NotImplementedError


class BackwardCFunction(CFunction, _C._FunctionBase, _ContextMethodMixin, _HookMixin):

def apply(self, *args):
return self._forward_cls.backward(self, *args)


class FunctionMeta(type):

def __init__(cls, name, bases, attrs):
for super_cls in cls.mro():
if 'forward' in super_cls.__dict__:
has_static_forward = isinstance(super_cls.__dict__['forward'], staticmethod)
break

# old-style functions
if not has_static_forward:
setattr(cls, '_is_legacy', True)
return super(FunctionMeta, cls).__init__(name, bases, attrs)

backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
setattr(cls, '_backward_cls', backward_fn)

return super(FunctionMeta, cls).__init__(name, bases, attrs)


class Function(with_metaclass(FunctionMeta, _C._FunctionBase, CFunction, _ContextMethodMixin, _HookMixin)):
"""Records operation history and defines formulas for differentiating ops.
Every operation performed on :class:`Variable` s creates a new function
object that performs the computation and records that it happened.
The history is retained in the form of a DAG of functions, with edges
denoting data dependencies (``input <- output``). Then, when backward is
called, the graph is processed in the topological ordering, by calling
:func:`backward` methods of each :class:`Function` object, and passing
returned gradients on to next :class:`Function` s.
Normally, the only way users interact with functions is by creating
subclasses and defining new operations. This is a recommended way of
extending torch.autograd.
Since Function logic is a hotspot in most scripts, almost all of it
was moved to our C backend, to ensure that the framework overhead is
minimal.
Each function is meant to be used only once (in the forward pass).
Attributes:
saved_tensors: Tuple of Tensors that were saved in the call to
:func:`forward`.
needs_input_grad: Tuple of booleans of length :attr:`num_inputs`,
indicating whether a given input requires gradient. This can be
used to optimize buffers saved for backward, and ignoring gradient
computation in :func:`~Function.backward`.
num_inputs: Number of inputs given to :func:`forward`.
num_outputs: Number of tensors returned by :func:`forward`.
requires_grad: Boolean indicating whether the :func:`backward` will
ever need to be called.
"""

# only for backward compatibility
__call__ = _C._FunctionBase._do_forward

@staticmethod
def forward(*args, **kwargs):
"""Performs the operation.
This function is to be overridden by all subclasses.
@@ -119,7 +161,8 @@ def forward(self, *input):
"""
raise NotImplementedError

def backward(self, *grad_output):
@staticmethod
def backward(*grad_outputs):
"""Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
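
Taken together, the pieces above define the new Function format: forward and backward are staticmethods that receive a context object (for backward, an instance of the *Backward class that FunctionMeta generates), and the function is invoked through .apply rather than by instantiating it. A minimal sketch in the spirit of the MyFunction test at the top of this diff; the ScaledSquare name and the choice of operation are illustrative, not part of the commit.

import torch
from torch.autograd import Function, Variable

class ScaledSquare(Function):

    @staticmethod
    def forward(ctx, tensor, scalar):
        ctx.scalar = scalar             # non-Variable state is stored on ctx
        ctx.save_for_backward(tensor)   # tensors go through the ctx mixin
        return tensor * tensor * scalar

    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_variables
        # as in the test above, gradients are returned only for the Variable inputs
        return grad_output * 2 * tensor * ctx.scalar

x = Variable(torch.randn(5, 5), requires_grad=True)
out = ScaledSquare.apply(x, 3)   # new-style functions are called via .apply
out.sum().backward()             # x.grad.data now equals 6 * x.data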
6 changes: 5 additions & 1 deletion torch/csrc/autograd/function.h
@@ -10,7 +10,6 @@
#include <THPP/THPP.h>
#include <vector>

#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/autograd/function_hook.h"

namespace torch { namespace autograd {
@@ -70,6 +69,11 @@ struct Function {
return fn && fn->is_executable;
}

inline void set_flags(FunctionFlags&& flags) {
is_executable = flags.is_executable;
next_functions = std::move(flags.next_functions);
}

int num_inputs;
function_list next_functions;
bool is_executable;
12 changes: 6 additions & 6 deletions torch/csrc/autograd/functions/batch_normalization.cpp
@@ -78,16 +78,16 @@ auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
return std::make_shared<BatchNormBackward>(
f, *this, std::move(save_mean), std::move(save_std),
input->save(),
Variable::save_opt(weight.get()),
Variable::save_opt(bias.get()));
input->save(this),
Variable::save_opt(weight.get(), this),
Variable::save_opt(bias.get(), this));
});
};

auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_list {
auto& input = this->input.unpack();
auto& weight = this->weight.unpack();
auto& bias = this->bias.unpack();
auto input = this->input.unpack_data();
auto weight = this->weight.unpack_data();
auto bias = this->bias.unpack_data();
AutoGPU guard(input->getDevice());

bool use_cudnn = false;
8 changes: 4 additions & 4 deletions torch/csrc/autograd/functions/convolution.cpp
@@ -193,7 +193,7 @@ auto ConvForward::apply(const variable_list& inputs) -> variable_list {
return wrap_outputs(inputs, std::move(outputs), [&](FunctionFlags f) {
return std::make_shared<ConvBackward>(
f, *this,
inputs[0]->save(), inputs[1]->save(), Variable::save_opt(inputs[2].get()),
inputs[0]->save(this), inputs[1]->save(this), Variable::save_opt(inputs[2].get(), this),
std::move(columns), std::move(ones), std::move(convolution));
});
};
@@ -205,9 +205,9 @@ auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list {

AutoGPU guard(input_.data->getDevice());

auto input = input_.unpack()->contiguous();
std::unique_ptr<Tensor> weight(weight_.unpack()->clone_shallow());
std::unique_ptr<Tensor> bias(bias_.unpack() ? bias_.unpack()->clone_shallow() : nullptr);
auto input = input_.unpack_data()->contiguous();
std::unique_ptr<Tensor> weight(weight_.unpack_data()->clone_shallow());
auto bias = bias_.unpack_data();
auto grad_output = grad_outputs[0]->data->contiguous();

int k = input->nDim();
15 changes: 15 additions & 0 deletions torch/csrc/autograd/functions/init.cpp
@@ -2,6 +2,8 @@
#include "batch_normalization.h"
#include "convolution.h"
#include "accumulate_grad.h"
#include "basic_ops.h"
#include "tensor.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/utils/tuple_parser.h"

@@ -74,6 +76,19 @@ bool THPAutograd_initFunctions(PyObject* _unused)
static PyTypeObject AccumulateGradClass;
addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad");

static PyTypeObject AddClass, AddBackwardClass;
addClass<Add, NoCtor>(module, AddClass, "Add");
addClass<AddBackward, NoCtor>(module, AddBackwardClass, "AddBackward");

static PyTypeObject ErrorClass;
addClass<Error, NoCtor>(module, ErrorClass, "Error");

static PyTypeObject CloneClass;
addClass<Clone, NoCtor>(module, CloneClass, "Clone");

static PyTypeObject IdentityClass;
addClass<Identity, NoCtor>(module, IdentityClass, "Identity");

THPObjectPtr parent = PyImport_ImportModule("torch._C");
if (!parent) return false;
PyModule_AddObject(parent.get(), "_functions", module.release());
6 changes: 5 additions & 1 deletion torch/csrc/autograd/functions/utils.cpp
@@ -14,7 +14,11 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
} else {
auto grad_fn = ctr(std::move(flags));
for (auto& output : outputs) {
result.emplace_back(std::make_shared<Variable>(std::move(output), grad_fn));
if (output) {
result.emplace_back(std::make_shared<Variable>(std::move(output), grad_fn));
} else {
result.emplace_back(nullptr);
}
}
}
return result;