Skip to content

Commit

Permalink
[inductor] Fix unsoundness with negative-valued indexing expressions (p…
Browse files Browse the repository at this point in the history
…ytorch#131761)

This fixes a few instances where we assumed indexing expressions were
non-negative. This is not valid when we have more complicated
expressions involving masking e.g. pointwise cat.

Pull Request resolved: pytorch#131761
Approved by: https://github.com/ezyang
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Jul 31, 2024
1 parent e74ba1b commit 260c991
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 39 deletions.
4 changes: 2 additions & 2 deletions test/inductor/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def test_indexing_simplification(self):
self.assertEqual(FloorDiv(i0 * 4, 2), i0 * 2)

# Nested modular indexing is correctly simplified
var_ranges = {sympy.Symbol("i1"): 13, sympy.Symbol("i2"): 121}
var_ranges = {i1: 13, i2: 121}
expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28)
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28)
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
var_ranges = {sympy.Symbol("i2"): 784}
var_ranges = {i2: 784}
expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4)
expected = FloorDiv(ModularIndexing(i2, 1, 28), 7)
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected)
Expand Down
1 change: 1 addition & 0 deletions test/inductor/test_mkldnn_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ def forward(self, x, y):
binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
)
out_feature = 30

for binary_fn, input_shape, bias, dtype in options:
metrics.reset()
# addmm(mm) + (linear+add)
Expand Down
12 changes: 12 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11172,6 +11172,18 @@ def forward():

self.common(forward, ())

def test_flip_cat(self):
def forward(unsqueeze, unsqueeze_1):
cat_1 = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
view = torch.ops.aten.view.default(cat_1, [4])
slice_5 = torch.ops.aten.slice.Tensor(view, 0, 0, 3)
rev_1 = torch.ops.aten.flip.default(slice_5, [0])
return (rev_1,)

a = torch.randn(2, 1, requires_grad=True)
b = torch.randn(2, 1, requires_grad=True)
self.common(forward, (a, b))


@dataclasses.dataclass
class TestFailure:
Expand Down
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def run(*ex, **kwargs):
"test_issue102546_dynamic_shapes": TestFailure(("cpu",)),
"test_repeat_as_strided_dynamic_shapes": TestFailure(("cpu",)),
"test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)),
"test_flip_cat_dynamic_shapes": TestFailure(("cpu",)),
#
# Failed to find for loop/triton kernel:
#
Expand Down
30 changes: 14 additions & 16 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def gen_common_triton_imports():


block_offsets = {
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True)
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}

block_sizes = {
symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, nonzero=True)
symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}

Expand Down Expand Up @@ -383,24 +383,22 @@ def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"

# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right). If you are trying to hit this, maybe try something like
# torch.arange(n, device="cuda") - 1 and then do a modulus on it
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
quot, div = expr.args
quot_s = self._print(quot)
div_s = self._print(div)
if quot.is_nonnegative and div.is_nonnegative:
return f"{self.paren(quot_s)} % {self.paren(div_s)}"
return f"triton_helpers.remainder_integer({quot_s}, {div_s})"

# TODO: This is wrong, see
# https://github.com/triton-lang/triton/issues/955
# But for Sympy expressions, things will /mostly/ work out because we
# don't usually deal with negative numbers in the division
def _print_FloorDiv(self, expr):
assert expr.is_integer
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"({x} // {div})"
quot, div = expr.args
quot_s = self._print(quot)
div_s = self._print(div)
if quot.is_nonnegative and div.is_nonnegative:
return f"({self.paren(quot_s)} // {self.paren(div_s)})"
return f"triton_helpers.div_floor_integer({quot_s}, {div_s})"

# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
# precision algorithm, which we would need to replicate here
Expand Down
17 changes: 17 additions & 0 deletions torch/_inductor/runtime/triton_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ def promote_to_tensor(x):
return x + tl.zeros((1,), tl.int1)


@triton.jit
def div_floor_integer(a, b):
# NOTE: a // b is C division, but we want floor division
# Based on c10::div_floor_integer
quot = a // b
remainder = a % b
fixed = tl.where(remainder != 0, quot - 1, quot)
return tl.where((a < 0) != (b < 0), fixed, quot)


@triton.jit
def remainder_integer(a, b):
# NOTE: a % b matches C division, not floor division
remainder = a % b
return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder)


@triton.jit
def is_floating(x):
return promote_to_tensor(x).dtype.is_floating()
Expand Down
66 changes: 45 additions & 21 deletions torch/_inductor/sizevars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Expand All @@ -21,10 +22,11 @@
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy
from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges

from .runtime.runtime_utils import is_power_of_2
from .utils import (
has_free_symbols,
sympy_index_symbol,
sympy_index_symbol_with_prefix,
sympy_subs,
Expand Down Expand Up @@ -119,8 +121,43 @@ def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
expr = join_dimensions(self.simplify(expr))
original_expr = expr

var_to_range = dict(self.shape_env.var_to_range)
var_to_range.update(
{
k: ValueRanges(
0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity()
)
for k, v in var_ranges.items()
}
)
for var in expr.free_symbols:
if var not in var_to_range:
var_to_range[var] = ValueRanges(0, IntInfinity())

var_to_range_tuple = cast(
Tuple[Tuple[sympy.Symbol, ValueRanges[sympy.Expr]]],
tuple(var_to_range.items()),
)

axioms = []
for var, upper_bound in var_ranges.items():
axioms.append(0 <= var)
axioms.append(var < upper_bound)
axioms = tuple(axioms) + self.shape_env.get_axioms()

def statically_known(expr):
evaluated = self.shape_env._maybe_evaluate_static(
expr,
axioms=axioms,
var_to_range=var_to_range_tuple,
)
return bool(evaluated)

def remove_zero_terms(base, divisor):
"""Symbols smaller than the divisor are zero"""
if not statically_known(base >= 0):
return base

for v in base.free_symbols:
if v in var_ranges:
# var smaller than divisor can be removed
Expand All @@ -130,7 +167,7 @@ def remove_zero_terms(base, divisor):
if m and v not in m[rest].free_symbols:
gcd = sympy.gcd(m[rest], divisor)
if gcd == divisor:
if self.statically_known_leq(var_ranges[v], divisor):
if statically_known(v < divisor):
base = m[rest]
return base

Expand All @@ -139,25 +176,12 @@ def visit_indexing_div(base, divisor):

def visit_modular_indexing(base, divisor, modulus):
base = remove_zero_terms(base, divisor)
base_pos = True
if isinstance(base, ModularIndexing):
# for modular indexing, biggest values from the ranges don't necessarily result in
# the biggest result, the biggest result is modulus - 1
base_s = base.args[2] - 1
elif not base.has(ModularIndexing):
# actual iteration range is to size-1
iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
base_lowest = sympy_subs(base, iter_ranges_zero)
if self.statically_known_leq(0, base_lowest): # type: ignore[arg-type]
# can't replace with indexing div if base can be negative
base_pos = True
else:
base_pos = False
iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
base_s = sympy_subs(base, iter_ranges)
else:
base_s = base
if self.statically_known_lt(base_s, modulus * divisor) and base_pos:

can_remove_mod = statically_known(base >= 0) and statically_known(
base < modulus * divisor
)

if can_remove_mod:
return FloorDiv(base, divisor)
return ModularIndexing(base, divisor, modulus)

Expand Down

0 comments on commit 260c991

Please sign in to comment.