TorchDynamo: Add convolution binary+unary fusion for CPU in inference mode (pytorch#88412)

This PR enables the fusion of **conv+binary+relu** (a convolution followed by a binary op such as add, then ReLU), which improves the performance of vision models.

Pull Request resolved: pytorch#88412
Approved by: https://github.com/jgong5, https://github.com/jansel
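
For context, the targeted pattern looks like the sketch below (illustrative only, not code from this PR; on builds from around this commit the entry point was torch._dynamo.optimize("inductor") rather than torch.compile):

import torch
import torch.nn as nn

class ConvAddRelu(nn.Module):
    # conv + binary (add) + unary (relu): the pattern this PR fuses
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv1(x) + self.conv2(x))

mod = ConvAddRelu().eval()  # fusion only applies in inference mode
x = torch.randn(1, 3, 112, 112)
with torch.no_grad():
    out = torch.compile(mod)(x)  # Inductor backend on CPU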
XiaobingSuper authored and pytorchmergebot committed Nov 14, 2022
1 parent cb4842c commit 072920c
Showing 2 changed files with 33 additions and 5 deletions.
15 changes: 13 additions & 2 deletions test/inductor/test_torchinductor.py
@@ -1449,6 +1449,7 @@ def __init__(
         dilation,
         groups,
         bias,
+        has_relu,
         **kwargs,
     ):
         super(M, self).__init__()
@@ -1471,16 +1472,18 @@ def __init__(
             )
         )
         self.binary_fn = binary_fn
+        self.relu = torch.nn.ReLU() if has_relu else torch.nn.Identity()

     def forward(self, x):
         x1 = self.conv1(x)
         x2 = self.conv2(x)
-        return self.binary_fn(x1, x2)
+        return self.relu(self.binary_fn(x1, x2))

 test_memory_format = [torch.contiguous_format, torch.channels_last]
 options = itertools.product(
     binary_list,
+    [True, False],  # has_relu
     [True, False],
     [1, 3],
     [1, 2],
     [1, 4],
@@ -1489,6 +1492,7 @@ def forward(self, x):

 for (
     binary_fn,
+    has_relu,
     bias,
     kernel_size,
     dilation,
@@ -1499,7 +1503,14 @@
     iC = 3 * groups
     x_shape = (1, iC, 112, 112)
     mod = M(
-        binary_fn, iC, oC, dilation, groups, bias, kernel_size=kernel_size
+        binary_fn,
+        iC,
+        oC,
+        dilation,
+        groups,
+        bias,
+        has_relu,
+        kernel_size=kernel_size,
     ).eval()
     mod = mod.to(memory_format=memory_format)
     # TODO: add bf16 test
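The updated test boils down to an equivalence check along these lines (a hedged, self-contained analogue; the real test sweeps binary ops, bias, kernel size, dilation, groups, and memory formats via itertools.product, and its shapes and tolerances may differ):

import torch

conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=True)
conv2 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=True)
relu = torch.nn.ReLU()

def fn(x):
    # conv + add + relu, the fused pattern under test
    return relu(torch.add(conv1(x), conv2(x)))

x = torch.randn(1, 3, 112, 112).to(memory_format=torch.channels_last)
with torch.no_grad():
    expected = fn(x)
    actual = torch.compile(fn)(x)  # illustrative; the test uses its own harness
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)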
23 changes: 20 additions & 3 deletions torch/_inductor/overrides.py
@@ -157,6 +157,11 @@ def _update_module_params(self, conv, binary_op_name):
         self.unary_scalars = []
         self.unary_algorithm = None

+    def _update_unary_params(self, unary):
+        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
+            unary.__class__
+        ](unary)
+
     def _conv_forward(self, input, other, weight, bias):
         if self.padding_mode != "zeros":
             return torch.ops.mkldnn._convolution_pointwise(
@@ -226,9 +231,9 @@ def _update_module_params(self, conv, binary_op_name):
         self.unary_algorithm = None

     def _update_unary_params(self, unary):
-        self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
-            unary
-        )
+        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
+            unary.__class__
+        ](unary)

     def _conv_forward(self, input, other, weight, bias):
         if self.padding_mode != "zeros":
@@ -344,6 +349,13 @@ def fused_conv_binary_inplace_eval(conv: nn.Module, binary_op_name: str):
     )


+def fused_binary_unary_eval(conv_binary: nn.Module, unary: nn.Module):
+    assert not (conv_binary.training), "Fusion only for eval!"
+    # Reuse the original conv module and just update its unary attributes.
+    conv_binary._update_unary_params(unary)
+    return conv_binary
+
+
 def is_bfloat16_module(m):
     weight_is_bf16 = m.weight.dtype == torch.bfloat16
     bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16
@@ -430,6 +442,9 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
     gm = fuse_unary(gm)
     gm = fuse_binary_inplace(gm)
     gm = fuse_binary(gm)
+    # Re-run fuse_unary to enable conv+binary+unary fusion,
+    # e.g. conv+add+relu in vision models.
+    gm = fuse_unary(gm)

     return gm

@@ -741,6 +756,8 @@ def rand_like(x, **kwargs):
 computation_op_unary_op_fusion_map = {
     nn.Conv2d: fused_conv_unary_eval,
     nn.Linear: fused_linear_unary_eval,
+    ConvBinary2d: fused_binary_unary_eval,
+    ConvBinaryInplace2d: fused_binary_unary_eval,
 }


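How the new map entries are consumed: after fuse_binary has replaced conv+add with a ConvBinary2d/ConvBinaryInplace2d module, the second fuse_unary pass finds that module's class in computation_op_unary_op_fusion_map and folds the trailing ReLU into it. Below is a much-simplified sketch of that dispatch, over a plain list of modules rather than an FX graph (fuse_unary_sketch and its arguments are hypothetical, not Inductor APIs):

import torch.nn as nn

def fuse_unary_sketch(ordered_modules, fusion_map, unary_classes=(nn.ReLU,)):
    # Walk modules in order; when a module whose type is a key of
    # fusion_map (e.g. ConvBinary2d after the earlier fusion passes) is
    # immediately followed by a unary module, merge the pair by calling
    # the mapped handler (e.g. fused_binary_unary_eval).
    fused, skip = [], False
    for mod, nxt in zip(ordered_modules, ordered_modules[1:] + [None]):
        if skip:
            skip = False
            continue
        handler = fusion_map.get(type(mod))
        if handler is not None and isinstance(nxt, unary_classes):
            fused.append(handler(mod, nxt))  # unary absorbed into mod
            skip = True
        else:
            fused.append(mod)
    return fused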
