From ab9da3b2b86931b3dd7d5b15516722b3e3c6352e Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Tue, 12 Sep 2023 04:59:13 +0100 Subject: [PATCH] [FRONTEND] Fix expand_dims and tl.full to handle scalar tensors (#2275) This fixes a few bugs related to scalar tensors: - `tl.full([], fill_value, dtype)` fails with `TypeError('0d block_type is forbidden')` - `scalar[None]` fails with `TypeError("'constexpr' object is not iterable")` - `scalar[None, None]` fails with `AttributeError("'dtype' object has no attribute 'shape'")` - `scalar.shape` returns `[1]` instead of 0-dim `[]` - Also related, `tl.zeros_like(scalar)` returns a 1d tensor instead of another scalar --- python/test/unit/language/test_core.py | 91 +++++++++++++++++--------- python/triton/language/core.py | 6 +- python/triton/language/semantic.py | 26 +++++--- 3 files changed, 79 insertions(+), 44 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f962facb4696..5bde5f1ddd94 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -568,7 +568,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # test broadcast # --------------- @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) -def test_broadcast(dtype): +def test_broadcast(dtype, device): @triton.jit def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): offset1 = tl.arange(0, M) @@ -585,41 +585,42 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con y = numpy_random(N, dtype_str=dtype, rs=rs) _, y_broadcasted_np = np.broadcast_arrays(x, y) - x_tri = to_triton(x, device='cuda', dst_type=dtype) - y_tri = to_triton(y, device='cuda', dst_type=dtype) - y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() +# ---------- +# test slice +# ---------- + + +def test_slice(device): -# --------------- -# test broadcast -# --------------- -@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) -def test_broadcast(dtype, device): @triton.jit - def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): - offset1 = tl.arange(0, M) - offset2 = tl.arange(0, N) - x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) - y = tl.load(y_ptr + offset2) - _, y_broadcasted = tl.broadcast(x, y) - tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) - M = 32 - N = 64 - rs = RandomState(17) - x = numpy_random((M, N), dtype_str=dtype, rs=rs) - y = numpy_random(N, dtype_str=dtype, rs=rs) - _, y_broadcasted_np = np.broadcast_arrays(x, y) + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) - x_tri = to_triton(x, device=device, dst_type=dtype) - y_tri = to_triton(y, device=device, dst_type=dtype) - y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + 
tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1,)](XBLOCK=32) - broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) - assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() # ------------------ # test invalid slice @@ -669,6 +670,14 @@ def expand_dims_kernel(dummy, N: tl.constexpr): t = tl.expand_dims(offset1, (3, 1, 2)) tl.static_assert(t.shape == [N, 1, 1, 1]) + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + N = 32 dummy_tensor = torch.empty((), device=device) expand_dims_kernel[(1,)](dummy_tensor, N) @@ -689,6 +698,13 @@ def dim_out_of_range2(dummy, N: tl.constexpr): t = tl.expand_dims(offset1, 1) t = tl.expand_dims(offset1, 2) + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + @triton.jit def duplicate_dim1(dummy, N: tl.constexpr): offset1 = tl.arange(0, N) @@ -710,6 +726,9 @@ def duplicate_dim2(dummy, N: tl.constexpr): with pytest.raises(triton.CompilationError, match="invalid axis 2"): dim_out_of_range2[(1,)](dummy_tensor, N) + with pytest.raises(triton.CompilationError, match="invalid axis 1"): + dim_out_of_range3[(1,)](dummy_tensor, N) + with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): duplicate_dim1[(1,)](dummy_tensor, N) @@ -2467,7 +2486,8 @@ def kernel(Z, X, Y, @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) -def test_full(dtype_str, device): +@pytest.mark.parametrize("shape", [(), (1,), (128,)]) +def test_full(dtype_str, shape, device): if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): # PyTorch only has unsigned 8, but not 16, 32, or 64 dtype = getattr(torch, dtype_str[1:]) # uintx -> intx @@ -2478,21 +2498,28 @@ def test_full(dtype_str, device): @triton.jit def kernel_static(out): a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) out_ptr = out + tl.arange(0, 128)[:] tl.store(out_ptr, a) @triton.jit def kernel_dynamic(out, val, dtype: tl.constexpr): - a = tl.full((128,), val, dtype) + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) out_ptr = out + tl.arange(0, 128)[:] tl.store(out_ptr, a) - kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"}) + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) out_static = torch.zeros((128), dtype=dtype, device=device) kernel_static_patched[(1,)](out_static) - out_dynamic = torch.zeros((128), dtype=dtype, device=device) - kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_dynamic == 2) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3caf8ee3fea5..3b9205f9d02a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -531,9 +531,7 @@ def __init__(self, handle, type: dtype): # IR 
handle self.handle = handle # Block shape - self.shape = (1, ) - if type.is_block(): - self.shape = type.shape + self.shape = type.shape if type.is_block() else () self.numel = 1 for s in self.shape: self.numel *= s @@ -743,7 +741,7 @@ def __not__(self, _builder=None): @builtin def __getitem__(self, slices, _builder=None): - if isinstance(slices, slice): + if isinstance(slices, (slice, constexpr)): slices = [slices] ret = self for dim, sl in enumerate(slices): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 1df2376da406..49597e4cb7af 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -501,25 +501,31 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te if isinstance(value, tl.tensor): assert value.numel.value == 1, "only accepts size-1 tensor" value = cast(value, dtype, builder) - ret_ty = tl.block_type(value.dtype, shape) - return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) else: # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") if value == 0: value = builder.get_null_value(dtype.to_ir(builder)) else: get_value_fn = getattr(builder, f"get_{dtype.name}") value = get_value_fn(value) - if dtype is None: - raise ValueError("dtype must be specified when value is not a tensor") - ret_ty = tl.block_type(dtype, shape) - return tl.tensor(builder.create_splat(value, shape), ret_ty) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) # ===----------------------------------------------------------------------===// # Shape Manipulation # ===----------------------------------------------------------------------===// +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + def view(input: tl.tensor, dst_shape: List[int], @@ -544,8 +550,12 @@ def reshape(input: tl.tensor, def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - dst_shape = list(input.type.shape) + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + ret_ty = tl.block_type(input.type.scalar, dst_shape) return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) @@ -1506,7 +1516,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor: def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: - if len(x.shape) != len(values): + if max(1, len(x.shape)) != len(values): raise ValueError("Shape of input to multiple_of does not match the length of values") x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) return x
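
Not part of the patch: a minimal usage sketch of the scalar-tensor behaviors listed in the commit message, mirroring the new `test_slice`/`test_full` assertions. It assumes a CUDA-capable device and a Triton build that includes this change; the kernel name, output buffer, and fill value are illustrative only, not taken from the diff.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scalar_demo_kernel(out_ptr):
    # 0-d tensor creation: previously raised TypeError('0d block_type is forbidden')
    scalar = tl.full([], 42, tl.int32)
    tl.static_assert(scalar.shape == [])  # shape is now [] rather than [1]

    # scalar[None] / scalar[None, None]: previously raised TypeError / AttributeError
    row = scalar[None]
    tl.static_assert(row.shape == [1])
    two_d = scalar[None, None]
    tl.static_assert(two_d.shape == [1, 1])

    # per the commit message, zeros_like of a scalar now stays a scalar
    zero = tl.zeros_like(scalar)
    tl.static_assert(zero.shape == [])

    # broadcast the scalar against the 1-d tensor and write the result out
    tl.store(out_ptr + tl.arange(0, 1), row + zero)


# hypothetical driver, assuming a CUDA device is available
out = torch.empty(1, dtype=torch.int32, device="cuda")
scalar_demo_kernel[(1,)](out)
assert out.item() == 42
```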