Skip to content

Commit

Permalink
[inductor] Avoid fallback case for custom scan op lowering (pytorch#1…
Browse files Browse the repository at this point in the history
…30936)

We currently can't generate split scans when there are multiple scan
values, so we normally fall back to ATen. However, for the higher-order
scan op, we can't fall back, so it makes sense to just generate the slower
kernel anyway. This avoids having special shapes where we fail to
codegen.

Pull Request resolved: pytorch#130936
Approved by: https://github.com/lezcano
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Jul 18, 2024
1 parent 367213a commit e7f7c5c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
39 changes: 39 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,6 +1910,45 @@ def argmax_combine(a, b):
actual = associative_scan(argmax_combine, (a, idx), 0)
self.assertEqual(expect, actual)

@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
@skip_if_halide  # scan ops
def test_custom_scan_would_split(self):
    """Compare a compiled two-value associative_scan with an eager reference.

    The scan runs along dim 1 of two (1, 129, 2) random tensors using a
    linear-recurrence combine function. The reference is computed step by
    step in float64 and cast back to float32 before comparison.

    NOTE(review): judging by the test name, the shape is presumably chosen
    so that the scan would otherwise be lowered as a split scan -- confirm
    against the scan lowering heuristics.
    """
    if self.device != "cuda":
        raise unittest.SkipTest("associative_scan only supported on GPU")

    def combine_linear_recurrence(left, right):
        # Each operand is an (x, f) pair; fold the left pair into the right.
        (xl, fl), (xr, fr) = left, right
        return xl * fr + xr, fl * fr

    def eager_scan(x, g):
        # Sequential reference implementation, accumulated in float64.
        x64 = x.to(torch.float64)
        g64 = g.to(torch.float64)
        x_acc = torch.empty_like(x64)
        g_acc = torch.empty_like(g64)

        # The first position along dim 1 is copied through unchanged.
        x_acc[:, 0] = x64[:, 0]
        g_acc[:, 0] = g64[:, 0]

        for step in range(1, x64.shape[1]):
            previous = (x_acc[:, step - 1], g_acc[:, step - 1])
            current = (x64[:, step], g64[:, step])
            x_acc[:, step], g_acc[:, step] = combine_linear_recurrence(
                previous, current
            )
        return x_acc.float(), g_acc.float()

    @torch.compile
    def compiled_scan(x, f):
        from torch._higher_order_ops.associative_scan import associative_scan

        scanned_x, scanned_f = associative_scan(
            combine_linear_recurrence, (x, f), dim=1
        )
        return scanned_x, scanned_f

    # Draw x before f so the random number stream is consumed in the same order.
    x = torch.randn(1, 129, 2, device=self.device)
    f = torch.randn(1, 129, 2, device=self.device)

    expect = eager_scan(x, f)
    actual = compiled_scan(x, f)
    self.assertEqual(expect, actual)

def test_embedding_bag_byte_unpack(self):
if self.device != "cpu":
raise unittest.SkipTest(f"No {GPU_TYPE} implementation (it returns empty)")
Expand Down
23 changes: 14 additions & 9 deletions torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,9 @@ def create(
axis: int,
combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]],
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
*,
# Whether we have the option to fallback to aten
can_fallback_to_aten: bool = True,
**kwargs,
) -> List[Optional[TensorBox]]:
pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
Expand Down Expand Up @@ -1711,15 +1714,17 @@ def create(
combine_fn=combine_fn,
scan_numel=scan_numel,
)
scan_type = Scan if num_splits <= 1 else SplitScan

if num_splits > 1 and torch.version.hip is not None:
# Fallback for split-scan on ROCm
return [None] * len(dtypes)

if num_splits > 1 and len(dtypes) > 1:
# Fallback for split-scans for multiple inputs
return [None] * len(dtypes)
scan_type = Scan
if num_splits > 1:
supports_split = torch.version.hip is None and len(dtypes) == 1
if not supports_split:
if can_fallback_to_aten:
# Fallback to ATen
return [None] * len(dtypes)
else:
num_splits = 1
else:
scan_type = SplitScan

def reindex(index, scan_index):
assert len(scan_index) == len(scan_ranges)
Expand Down
6 changes: 5 additions & 1 deletion torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6192,7 +6192,11 @@ def wrapped_combine_fn(lhs, rhs):
kwargs = _make_scan_inner(input[0], axis=dim, dtype=None)
kwargs["dtypes"] = tuple(x.get_dtype() for x in input)
kwargs["inner_fns"] = tuple(x.make_loader() for x in input)
result = ir.Scan.create(**kwargs, combine_fn=wrapped_combine_fn)
result = ir.Scan.create(
combine_fn=wrapped_combine_fn,
can_fallback_to_aten=False,
**kwargs,
)
if result[0] is None:
raise RuntimeError("Unable to generate code for associative_scan op")
return result
Expand Down

0 comments on commit e7f7c5c

Please sign in to comment.