Add nvFuser support for aten.native_batch_norm_backward (pytorch#84546)
Replacing `tensor.reshape(broadcast_mask)` with a sequence of unsqueezes makes the implementation of `batch_norm_backward` friendlier to PrimTorch+nvFuser.
Pull Request resolved: pytorch#84546
Approved by: https://github.com/Chillee
IvanYashchuk authored and pytorchmergebot committed Sep 6, 2022
1 parent 7243264 commit 6363b1b
Showing 2 changed files with 45 additions and 7 deletions.
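
The core of the change: instead of reshaping the per-channel statistics directly to the broadcast shape, the decomposition now unsqueezes every reduced axis. A minimal sketch of the equivalence, using illustrative NCHW-style shapes that are not taken from the diff:

import torch

# Hypothetical setup: per-channel stats of shape (C,); broadcast_mask keeps
# the channel axis and sets every other axis to 1.
C = 3
mean = torch.randn(C)
broadcast_mask = [1, C, 1, 1]

# Old formulation: reshape the stats directly to the broadcast shape.
reshaped = torch.reshape(mean, broadcast_mask)

# New formulation: unsqueeze each masked-out axis instead, mirroring the
# _broadcast_batch_norm_backward helper added in this commit.
unsqueezed = mean
for axis, mask in enumerate(broadcast_mask):
    if mask == 1 and not (axis < unsqueezed.ndim and unsqueezed.shape[axis] == mask):
        unsqueezed = unsqueezed.unsqueeze(axis)

assert reshaped.shape == unsqueezed.shape == (1, C, 1, 1)
assert torch.equal(reshaped, unsqueezed)

Both forms produce the same broadcastable view of the statistics; the unsqueeze form presumably lowers to simpler broadcasting primitives than a general reshape, which is what the commit message refers to as being friendlier to PrimTorch+nvFuser.
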
29 changes: 29 additions & 0 deletions test/test_prims.py
@@ -401,6 +401,35 @@ def func(a):
             self.assertFalse(node.target == torch.ops.prims.add.default)
             self.assertFalse(node.target == torch.ops.aten.add.default)
 
+    @dtypes(torch.float32, torch.float16)
+    def test_batch_norm_backward_nvprims(self, device, dtype):
+        # This test verifies that the backward pass of batch norm is correctly decomposed into nvprims
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm
+
+        samples_iter = sample_inputs_batch_norm(None, device, dtype, requires_grad=True)
+        sample = next(samples_iter)
+        grad = torch.randn_like(sample.input)
+
+        def func(grad, input, weight, rm, rv, eps, train):
+            return torch.ops.aten.native_batch_norm_backward.default(
+                grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
+            )
+
+        args = sample.args
+        kwargs = sample.kwargs
+        all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']]
+        with TorchRefsNvfuserCapabilityMode():
+            gm = make_fx(func)(*all_args)
+
+        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+        includes_batch_norm_backward = any(
+            torch.ops.aten.native_batch_norm_backward.default == node.target
+            for node in call_function_nodes
+        )
+        self.assertFalse(includes_batch_norm_backward)
+
     @onlyCUDA
     @skipCUDAIfRocm
     @dtypes(torch.float32)
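
The same check can be run as a standalone script, which may be convenient for reproducing the behavior outside the device-generic test harness. The shapes, eps, and the reuse of the running stats as the saved stats below are illustrative assumptions; the test itself draws its inputs from sample_inputs_batch_norm:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode

def func(grad, input, weight, rm, rv, eps, train):
    # Same callable as in the test: the running stats double as the saved stats.
    return torch.ops.aten.native_batch_norm_backward.default(
        grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
    )

# Assumed inputs: a small NCHW batch with a per-channel affine weight.
N, C, H, W = 2, 3, 4, 4
input = torch.randn(N, C, H, W)
grad = torch.randn_like(input)
weight = torch.randn(C)
running_mean = torch.zeros(C)
running_var = torch.ones(C)

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(grad, input, weight, running_mean, running_var, 1e-5, True)

# After this commit the traced graph should contain no aten batch norm backward node.
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(torch.ops.aten.native_batch_norm_backward.default in targets)  # expected: False
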
23 changes: 16 additions & 7 deletions torch/_decomp/decompositions.py
@@ -1314,6 +1314,13 @@ def cudnn_batch_norm(
     )
 
 
+def _broadcast_batch_norm_backward(x, broadcast_mask):
+    for axis, mask in enumerate(broadcast_mask):
+        if mask == 1 and not (axis < x.ndim and x.shape[axis] == broadcast_mask[axis]):
+            x = x.unsqueeze(axis)
+    return x
+
+
 @register_decomposition(aten.native_batch_norm_backward)
 def native_batch_norm_backward(
     grad_out: Tensor,
@@ -1372,21 +1379,23 @@ def native_batch_norm_backward(
         if i != axis:
             reduction_axes.append(i)
 
-    mean = torch.reshape(mean, broadcast_mask)  # type: ignore[arg-type]
+    mean = _broadcast_batch_norm_backward(mean, broadcast_mask)  # type: ignore[arg-type]
     norm = 1.0 / num_features
     grad_output_sum = torch.sum(grad_out_cast, reduction_axes)  # type: ignore[arg-type]
-    dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes)
+    dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes)  # type: ignore[operator]
 
-    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
-    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)  # type: ignore[operator]
+    grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask)
+    proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)  # type: ignore[operator]
 
     if weight_cast is None:
-        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0  # type: ignore[arg-type]
+        grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0  # type: ignore[arg-type]
     else:
-        grad_scale = torch.reshape(invstd * weight_cast, broadcast_mask)
+        grad_scale = _broadcast_batch_norm_backward(
+            invstd * weight_cast, broadcast_mask
+        )
 
     if train:
-        proj = (input_cast - mean) * proj_scale
+        proj = (input_cast - mean) * proj_scale  # type: ignore[operator]
         grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale
     else:
         grad_input = grad_out_cast * grad_scale

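Beyond graph structure, the rewritten decomposition can also be spot-checked numerically against the eager aten kernel. The sketch below is not part of the commit; it assumes the decomposition can be imported directly from torch._decomp.decompositions, uses arbitrary float64 shapes, and recomputes the saved per-channel statistics the way a training-mode forward pass would:

import torch
from torch._decomp.decompositions import native_batch_norm_backward as decomp_bn_bwd

# Assumed shapes and eps; float64 keeps the comparison tolerances comfortable.
N, C, H, W = 2, 3, 4, 4
eps = 1e-5
input = torch.randn(N, C, H, W, dtype=torch.float64)
grad_out = torch.randn_like(input)
weight = torch.randn(C, dtype=torch.float64)
running_mean = torch.zeros(C, dtype=torch.float64)
running_var = torch.ones(C, dtype=torch.float64)

# Saved statistics as a training-mode forward pass would produce them:
# per-channel mean and inverse standard deviation of the current batch.
save_mean = input.mean(dim=(0, 2, 3))
save_invstd = torch.rsqrt(input.var(dim=(0, 2, 3), unbiased=False) + eps)

args = (grad_out, input, weight, running_mean, running_var,
        save_mean, save_invstd, True, eps, [True, True, True])
expected = torch.ops.aten.native_batch_norm_backward.default(*args)
actual = decomp_bn_bwd(*args)

# grad_input, grad_weight, grad_bias should all match the aten kernel.
for e, a in zip(expected, actual):
    torch.testing.assert_close(a, e)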