Skip to content

Commit

Permalink
Fix previous_functions when it contains Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
colesbury authored and soumith committed Feb 17, 2017
1 parent 7117a90 commit dd844f7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
18 changes: 18 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,24 @@ def error():
y._backward_hooks['test'] = error
b.backward(torch.ones(5, 5))

def test_previous_functions(self):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)

a = x + y
self.assertIsNotNone(a.creator)
previous_functions = a.creator.previous_functions
self.assertEqual(len(previous_functions), 2)
self.assertIs(previous_functions[0][0], x)
self.assertEqual(previous_functions[0][1], 0)
self.assertIs(previous_functions[1][0], y)
self.assertEqual(previous_functions[1][1], 0)

b = a + 5
previous_functions = b.creator.previous_functions
self.assertEqual(len(previous_functions), 1)
self.assertIs(previous_functions[0][0], a.creator)

def test_inplace(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
Expand Down
7 changes: 5 additions & 2 deletions torch/csrc/autograd/python_cpp_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,16 @@ static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;

PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
{
auto pfw = dynamic_cast<PyFunction*>(cdata.get());
if (pfw) {
if (auto pfw = dynamic_cast<PyFunction*>(cdata.get())) {
PyObject* obj = pfw->obj;
Py_INCREF(obj);
return obj;
}

if (auto var = std::dynamic_pointer_cast<Variable>(cdata)) {
return THPVariable_Wrap(var);
}

auto it = cpp_function_types.find(std::type_index(typeid(*cdata)));
if (it == cpp_function_types.end()) {
return PyErr_Format(PyExc_TypeError,
Expand Down
7 changes: 4 additions & 3 deletions torch/csrc/autograd/python_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,9 +932,10 @@ PyObject *THPFunction_previous_functions(THPFunction *self, void *_unused)
return NULL;
for (int i = 0; i < size; i++) {
THPObjectPtr fn_tuple = PyTuple_New(2);
if (!fn_tuple)
return NULL;
PyTuple_SET_ITEM(fn_tuple.get(), 0, functionToPyObject(prev_fns[i].first));
if (!fn_tuple) return NULL;
PyObject* fn = functionToPyObject(prev_fns[i].first);
if (!fn) return NULL;
PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
PyTuple_SET_ITEM(fn_tuple.get(), 1, PyInt_FromLong(prev_fns[i].second));
PyTuple_SET_ITEM(result.get(), i, fn_tuple.release());
}
Expand Down

0 comments on commit dd844f7

Please sign in to comment.