[MPS] LSTM grad_y missing fix (pytorch#96601)
Fixes pytorch#96416
Added tests that do not use the LSTM output, similarly to the issue.

It seems this fix once again introduces a backward incompatibility, since the lstm_mps_backward schema changes.
Pull Request resolved: pytorch#96601
Approved by: https://github.com/albanD, https://github.com/kulinseth
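
For context, here is a minimal sketch of the call pattern the new tests exercise: the LSTM's sequence output is discarded, so autograd never produces a gradient for it and the MPS backward previously failed. The module sizes and tensor shapes below are illustrative, not taken from the issue.

import torch

# Only the final hidden/cell states are used, so no gradient flows to the
# LSTM's sequence output and the backward receives no grad_y on MPS.
lstm = torch.nn.LSTM(input_size=4, hidden_size=8, num_layers=2).to("mps")
inp = torch.randn(5, 3, 4, device="mps", requires_grad=True)

_, (hn, cn) = lstm(inp)      # sequence output discarded
loss = (hn * cn).sum()       # loss depends only on the final states
loss.backward()              # previously errored on MPS; works after this fix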
alexdremov authored and pytorchmergebot committed Mar 16, 2023
1 parent b249b44 commit 62eb7a2
Showing 4 changed files with 43 additions and 22 deletions.
18 changes: 13 additions & 5 deletions aten/src/ATen/native/mps/operations/RnnOps.mm
@@ -371,7 +371,7 @@
   }
 }
 
-std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y,
+std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const c10::optional<Tensor>& grad_y_opt,
                                                                                const c10::optional<Tensor>& grad_hy_opt,
                                                                                const c10::optional<Tensor>& grad_cy_opt,
                                                                                const Tensor& z_state,
@@ -387,10 +387,20 @@
                                                                                bool bidirectional,
                                                                                bool batch_first) {
   using namespace mps;
+  const Tensor& grad_y_r = c10::value_or_else(grad_y_opt, [] { return Tensor(); });
   const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] { return Tensor(); });
   const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] { return Tensor(); });
-  auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx[0], input.options());
-  auto grad_cy = grad_cy_r.defined() ? grad_cy_r : at::zeros_like(hx[1], input.options());
+  const auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx[0], input.options());
+  const auto grad_cy = grad_cy_r.defined() ? grad_cy_r : at::zeros_like(hx[1], input.options());
+
+  const auto hidden_size = hx[0].sizes()[2];
+  const auto batch_size = hx[0].sizes()[1];
+  const auto seq_len = input.sizes()[batch_first ? 1 : 0];
+  const auto grad_y = grad_y_r.defined() ? grad_y_r
+                                         : at::zeros({batch_first ? batch_size : seq_len,
+                                                      batch_first ? seq_len : batch_size,
+                                                      hidden_size * (bidirectional ? 2 : 1)},
+                                                     input.options());
 
   std::vector<Tensor> kernel_weights;
   std::vector<Tensor> recurrent_kernel_weights;
@@ -515,8 +525,6 @@
     NSMutableArray<MPSGraphTensor*>* gradStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
     NSMutableArray<MPSGraphTensor*>* gradCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
 
-    auto hidden_size = hx[0].sizes()[2];
-
     for (int i = num_layers - 1; i >= 0; i--) {
       MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor dimension:0 start:i length:1 name:nil];
       zState = [mpsGraph squeezeTensor:zState axis:0 name:nil];
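
The heart of the fix is the zero-filled grad_y fallback added above. For readers who prefer Python to Objective-C++, a rough equivalent of that shape logic is sketched below; the helper name and signature are invented for illustration and are not a PyTorch API.

import torch

def fallback_grad_y(grad_y, input, hx, batch_first, bidirectional):
    # Mirrors the C++ fallback above: if autograd supplied no gradient for the
    # LSTM's sequence output, substitute zeros of the matching shape.
    if grad_y is not None:
        return grad_y
    hidden_size = hx[0].size(2)
    batch_size = hx[0].size(1)
    seq_len = input.size(1 if batch_first else 0)
    leading = (batch_size, seq_len) if batch_first else (seq_len, batch_size)
    return torch.zeros(*leading, hidden_size * (2 if bidirectional else 1),
                       dtype=input.dtype, device=input.device)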
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -7213,7 +7213,7 @@
     MPS: _lstm_mps
   autogen: _lstm_mps.out
 
-- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
+- func: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
   dispatch:
     MPS: lstm_mps_backward
   autogen: lstm_mps_backward.out
43 changes: 28 additions & 15 deletions test/test_mps.py
@@ -9878,12 +9878,18 @@ def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True
         self.assertEqual(cpu_hn, hn)
         self.assertEqual(cpu_cn, cn)
 
-        def get_backward_results(rnn, device, inp, hx, cx):
+        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
             rnn = rnn.to(device)
             inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
 
-            output, _ = rnn(inp, (hx, cx))
-            f = 3 * output.sum() + (hx * cx).sum()
+            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
+            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"
+
+            f = 0
+            if output_grad_presented:
+                f = f + 3 * output.sum()
+            if states_grad_presented:
+                f = f + (hx_out * cx_out).sum()
 
             param_names, params = zip(*rnn.named_parameters())
             param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
@@ -9892,18 +9898,25 @@ def get_backward_results(rnn, device, inp, hx, cx):
             return output, param_grads, input_grad, hx_grad, cx_grad
 
         if backward:
-            cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
-                get_backward_results(rnn, "cpu", input, hx, cx)
-            mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
-                get_backward_results(rnn, device, input, hx, cx)
-
-            self.assertEqual(cpu_hx_grad, mps_hx_grad)
-            self.assertEqual(cpu_cx_grad, mps_cx_grad)
-            self.assertEqual(cpu_output, mps_output)
-            self.assertEqual(cpu_input_grad, mps_input_grad)
-            for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
-                self.assertEqual(cpu_weight_grad, mps_weight_grad,
-                                 f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
+            grad_cases = [
+                dict(output_grad_presented=True, states_grad_presented=True),
+                dict(output_grad_presented=False, states_grad_presented=True),
+                dict(output_grad_presented=True, states_grad_presented=False),
+            ]
+
+            for grad_case in grad_cases:
+                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
+                    get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
+                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
+                    get_backward_results(rnn, device, input, hx, cx, **grad_case)
+
+                self.assertEqual(cpu_hx_grad, mps_hx_grad)
+                self.assertEqual(cpu_cx_grad, mps_cx_grad)
+                self.assertEqual(cpu_output, mps_output)
+                self.assertEqual(cpu_input_grad, mps_input_grad)
+                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
+                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
+                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
 
     LSTM_TEST_CASES = [
         dict(),  # default
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -2572,7 +2572,7 @@
   output_differentiability: [True, True, True, False, False, False]
   input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
 
-- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
+- name: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
 
 
