Skip to content

Commit

Permalink
[inductor] Fix can_merge check for expr=q0*q1 (pytorch#129806)
Browse files Browse the repository at this point in the history
Fixes pytorch#111884

In the minimised reproducer, we have a loop with the index expression `-q0*q1`
for which in the merge tester we get:
```
expr1 = - 0 * (_merge_tester * 16) = 0
expr2 = - _merge_tester * 0 = 0
```
so it decides we can merge the dimensions and `q0` is set to `0`, meaning `-q0*q1` is always zero!

Here I change the test so we have at least one case where no zeros are
substituted so we can catch this situation. In the normal strided case we get
e.g.
```
expr = 16 * q0 + q1
expr1 = 16 * _merge_tester2 + (16 * _merge_tester1)
expr2 = 16 * (_merge_tester2 + _merge_tester1)
```
which are still equivalent expressions.

Pull Request resolved: pytorch#129806
Approved by: https://github.com/lezcano
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Jul 2, 2024
1 parent 37e3c60 commit dc75ec2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
22 changes: 22 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10831,6 +10831,28 @@ def forward(float_1, view_1):

self.common(forward, (a, b))

def test_mul_index_expr(self):
# Minified repro from https://github.com/pytorch/pytorch/issues/111884
def forward():
iota = torch.ops.prims.iota.default(
16,
start=0,
step=1,
dtype=torch.int64,
device=self.device,
requires_grad=False,
)
unsqueeze = torch.ops.aten.unsqueeze.default(iota, -1)
mul = torch.ops.aten.mul.Tensor(unsqueeze, iota)
unsqueeze = iota = None
neg = torch.ops.aten.neg.default(mul)
mul = None
div = torch.ops.aten.div.Tensor(neg, 16)
neg = None
return (div,)

self.common(forward, ())


@dataclasses.dataclass
class TestFailure:
Expand Down
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def run(*ex, **kwargs):
"test_uint_dynamic_shapes": TestFailure(("cpu",)),
"test_issue102546_dynamic_shapes": TestFailure(("cpu",)),
"test_repeat_as_strided_dynamic_shapes": TestFailure(("cpu",)),
"test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)),
#
# Failed to find for loop/triton kernel:
#
Expand Down
9 changes: 6 additions & 3 deletions torch/_inductor/sizevars.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,12 @@ def can_merge_dims(a, b):
# approximate test passed, try sound version
va = index_vars[a]
vb = index_vars[b]
v = sympy_index_symbol("_merge_tester")
expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
m1 = sympy_index_symbol("_merge_tester1")
m2 = sympy_index_symbol("_merge_tester2")
# NOTE: can't sub vb=0 here in case va * vb appears in the expression,
# in which case both expr1 and expr2 would be zero!
expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2})
expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)})
if self.simplify(expr1) == self.simplify(expr2):
continue
return False
Expand Down

0 comments on commit dc75ec2

Please sign in to comment.