Fix segfault in autograd (pytorch#1644)
* Fix segfault in autograd:

1) Every "output" variable must have a grad_fn or grad_accumulator
2) compute_partial_exec_callbacks uses Python errors

* assertRaisesRegexp was renamed to assertRaisesRegex in Python 3.2

* Use HANDLE_TH_ERRORS macro
colesbury authored and soumith committed May 24, 2017
1 parent 3d38e4f commit e1d257b
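
As context for the fix: the user-visible effect is that a bad torch.autograd.grad call now raises a catchable Python RuntimeError instead of crashing the interpreter. A minimal sketch follows (not from the diff), assuming the 0.x-era Variable API and mirroring the new test_grad_badcalls test below; the exact error text is an assumption taken from the added THPUtils_assert message.

import torch
from torch.autograd import Variable

x = Variable(torch.ones(1))  # plain input: no grad_fn and no grad_accumulator
y = x ** 2
try:
    # Asking for gradients *of* x, which does not require grad, previously
    # could segfault the engine; with this commit it raises a RuntimeError.
    torch.autograd.grad(x, y)
except RuntimeError as err:
    print('caught:', err)  # e.g. "... does not require grad"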
Showing 3 changed files with 21 additions and 7 deletions.
test/common.py: 4 additions & 0 deletions
@@ -248,6 +248,10 @@ def assertObjectIn(self, obj, iterable):
                 return
         raise AssertionError("object not found in iterable")
 
+    if sys.version_info < (3, 2):
+        # assertRaisesRegexp renamed assertRaisesRegex in 3.2
+        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
+
 
 def download_file(url, path, binary=True):
     if sys.version_info < (3,):
test/test_autograd.py: 14 additions & 4 deletions
@@ -117,16 +117,12 @@ def backward(ctx, grad_output):
                         grad_output * ctx.scalar + grad_output * t1)
 
         x, y = self._function_test(MyFunction)
-        x_grad_desc = graph_desc(x.grad.grad_fn)
-        y_grad_desc = graph_desc(y.grad.grad_fn)
         self.assertEqual(graph_desc(x.grad.grad_fn),
                          'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
         self.assertEqual(graph_desc(y.grad.grad_fn),
                          'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
 
     def test_accumulate_grad(self):
-        import sys
-
         grad_output = Variable(torch.ones(5, 5))
         for start_volatile, end_volatile in product((True, False), repeat=2):
             go1 = grad_output.data if start_volatile else grad_output
@@ -248,6 +244,20 @@ def hook(*grads):
         self.assertFalse(hook_called[0])
         self.assertIsNone(x.grad)
 
+    def test_grad_badcalls(self):
+        x = Variable(torch.ones(1))
+        y = x ** 2
+        with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
+            torch.autograd.grad(x, y)
+        with self.assertRaisesRegex(RuntimeError, 'not have been used in the graph'):
+            torch.autograd.grad(y, x)
+
+        x = Variable(torch.ones(1), requires_grad=True)
+        y = x ** 2
+        torch.autograd.grad(y, x)  # this should succeed now
+        with self.assertRaisesRegex(RuntimeError, 'unreachable'):
+            torch.autograd.grad(x, y)
+
     def test_hooks(self):
         x = Variable(torch.ones(5, 5), requires_grad=True)
         y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
torch/csrc/autograd/python_engine.cpp: 3 additions & 3 deletions
@@ -106,6 +106,7 @@ void compute_partial_exec_callbacks(const function_list& roots,
 
 PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
 {
+  HANDLE_TH_ERRORS
   PyObject *variables = NULL;
   PyObject *grad_variables = NULL;
   unsigned char keep_graph = 0;
@@ -137,6 +138,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
     THPUtils_assert(!variable->is_volatile,
         "element %d of variables tuple is volatile", i);
     auto grad_fn = variable->grad_fn ? variable->grad_fn : variable->get_grad_accumulator();
+    THPUtils_assert(grad_fn, "element %d of variables tuple does not require grad", i);
     int output_nr = variable->grad_fn ? variable->output_nr : 0;
     roots[i] = std::make_pair<>(std::move(grad_fn), output_nr);
 
@@ -201,16 +203,14 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
   } catch (python_error &e) {
     e.restore();
     return nullptr;
-  } catch (const std::exception &e) {
-    PyErr_SetString(PyExc_RuntimeError, e.what());
-    return nullptr;
   }
 
   if (ctx.outputs) {
     return ctx.outputs.release();
   } else {
     Py_RETURN_NONE;
   }
+  END_HANDLE_TH_ERRORS
 }
 
 PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
