Add AutoGPU guard and properly reference Python args from BatchNormBackwardBackward.
gchanan committed Aug 10, 2017
1 parent 50c208a commit 2f624df
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions torch/csrc/autograd/functions/batch_normalization.cpp
@@ -221,31 +221,37 @@ auto BatchNormBackwardBackward::apply(const variable_list& grad_grad_inputs) ->
   auto ggW = grad_grad_inputs[1];
   auto ggb = grad_grad_inputs[2];
 
-  auto gO = grad_output.unpack();
   auto input_var = input.unpack();
   auto weight_var = weight.unpack();
+  auto gO_var = grad_output.unpack();
+
+  auto input = input_var->data;
+  AutoGPU guard(input);
   AutoGIL gil;
 
   THPObjectPtr input_pvar(THPVariable_Wrap(input_var));
-  THPObjectPtr weight_pvar(weight_var ? THPVariable_Wrap(weight_var) : Py_None);
-  THPObjectPtr ggi_pvar(ggI ? THPVariable_Wrap(ggI) : Py_None);
-  THPObjectPtr ggW_pvar(ggW ? THPVariable_Wrap(ggW) : Py_None);
-  THPObjectPtr ggb_pvar(ggb ? THPVariable_Wrap(ggb) : Py_None);
-  THPObjectPtr gO_pvar(THPVariable_Wrap(gO));
+  THPObjectPtr weight_pvar(THPVariable_Wrap(weight_var));
+
+  THPObjectPtr ggi_pvar(THPVariable_Wrap(ggI));
+  THPObjectPtr ggW_pvar(THPVariable_Wrap(ggW));
+  THPObjectPtr ggb_pvar(THPVariable_Wrap(ggb));
+  THPObjectPtr gO_pvar(THPVariable_Wrap(gO_var));
   THPObjectPtr eps_py(PyFloat_FromDouble(eps));
   THPObjectPtr save_mean_py(createPyObject(save_mean));
   THPObjectPtr save_std_py(createPyObject(save_std));
   THPObjectPtr running_mean_py(createPyObject(running_mean));
   THPObjectPtr running_var_py(createPyObject(running_var));
-  THPObjectPtr training_py(training ? Py_True : Py_False);
-
-  PyObject* args = PyTuple_Pack(12, input_pvar.get(), weight_pvar.get(),
-                                ggi_pvar.get(), ggW_pvar.get(), ggb_pvar.get(),
-                                gO_pvar.get(), eps_py.get(),
-                                save_mean_py.get(), save_std_py.get(),
-                                running_mean_py.get(), running_var_py.get(),
-                                training_py.get());
-  THPObjectPtr r(PyObject_CallObject(THPBatchNormBackwardBackwardFunction, args));
+  PyObject *training_pyo = training ? Py_True : Py_False;
+  Py_INCREF(training_pyo);
+  THPObjectPtr training_py(training_pyo);
+
+  THPObjectPtr args(PyTuple_Pack(12, input_pvar.get(), weight_pvar.get(),
+                                 ggi_pvar.get(), ggW_pvar.get(), ggb_pvar.get(),
+                                 gO_pvar.get(), eps_py.get(),
+                                 save_mean_py.get(), save_std_py.get(),
+                                 running_mean_py.get(), running_var_py.get(),
+                                 training_py.get()));
+  THPObjectPtr r(PyObject_CallObject(THPBatchNormBackwardBackwardFunction, args.get()));
   if (!r) throw python_error();
   if (!PyTuple_Check(r.get())) {
     throw std::runtime_error("expected PyTuple return from BatchNormBackwardBackward");
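
Added context, not part of the commit: the new `AutoGPU guard(input);` presumably sets the current CUDA device to match the input tensor for the duration of the Python call, so any tensors the Python-side double-backward allocates land on the same GPU. A rough sketch of the RAII device-guard idea it relies on; the `DeviceGuard` name and body here are illustrative assumptions, not AutoGPU's actual implementation:

#include <cuda_runtime.h>

// Hypothetical stand-in for an AutoGPU-style guard: switch the current CUDA
// device on construction, restore the previous one on scope exit.
class DeviceGuard {
 public:
  explicit DeviceGuard(int device) {
    cudaGetDevice(&previous_);  // remember whatever device was current
    cudaSetDevice(device);      // switch to the requested device
  }
  ~DeviceGuard() {
    cudaSetDevice(previous_);   // restore on destruction
  }
 private:
  int previous_ = 0;
};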
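Added context on the reference-counting side of the change: `THPObjectPtr` is an owning holder that drops a reference when it goes out of scope, so it should only be handed references the caller owns. The old `x ? THPVariable_Wrap(x) : Py_None` ternaries and the bare `training ? Py_True : Py_False` gave it borrowed singletons, which would over-release `Py_None`/`Py_True`/`Py_False`; the new code relies on `THPVariable_Wrap` to return an owned reference in every case and explicitly `Py_INCREF`s the training flag. A minimal sketch of the pattern, using the plain CPython API and a hypothetical `OwnedRef` assumed to mirror `THPObjectPtr`'s behavior:

#include <Python.h>

// Hypothetical RAII owner: releases one reference to whatever it holds
// when it is destroyed.
struct OwnedRef {
  PyObject* ptr;
  explicit OwnedRef(PyObject* p) : ptr(p) {}
  ~OwnedRef() { Py_XDECREF(ptr); }
  OwnedRef(const OwnedRef&) = delete;
  OwnedRef& operator=(const OwnedRef&) = delete;
  PyObject* get() const { return ptr; }
};

void wrap_flag(bool training) {
  // Buggy pattern: Py_True/Py_False are borrowed singletons, so the
  // destructor would release a reference we never took.
  //   OwnedRef bad(training ? Py_True : Py_False);

  // Fixed pattern (what the commit does): take our own reference first.
  PyObject* flag = training ? Py_True : Py_False;
  Py_INCREF(flag);
  OwnedRef good(flag);
  (void)good;  // the extra reference is released when 'good' is destroyed
}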
