[FRONTEND] Fix expand_dims and tl.full to handle scalar tensors (triton-lang#2275)

This fixes a few bugs related to scalar tensors:
- `tl.full([], fill_value, dtype)` fails with `TypeError('0d block_type is forbidden')`
- `scalar[None]` fails with `TypeError("'constexpr' object is not iterable")`
- `scalar[None, None]` fails with `AttributeError("'dtype' object has no attribute 'shape'")`
- `scalar.shape` returns `[1]` instead of the 0-dim `[]`
- Also related, `tl.zeros_like(scalar)` returns a 1d tensor instead of another scalar
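Taken together, the bullets above mean a `@triton.jit` kernel can now create and index 0-d tensors directly. Below is a minimal sketch modelled on the `test_slice` case added in this commit; the kernel name and fill value are illustrative, and it assumes a CUDA-capable device with a Triton build that includes this change:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scalar_demo_kernel(out_ptr):
    # 0-d tensor: previously raised TypeError('0d block_type is forbidden')
    scalar = tl.full([], 7, tl.int32)
    tl.static_assert(scalar.shape == [])   # previously reported [1]

    # Indexing a scalar with None now adds size-1 dimensions instead of raising
    t = scalar[None]
    tl.static_assert(t.shape == [1])
    t = scalar[None, None]
    tl.static_assert(t.shape == [1, 1])

    # Per the commit message, tl.zeros_like(scalar) now also stays 0-d
    tl.store(out_ptr, scalar)


out = torch.empty((), dtype=torch.int32, device="cuda")
scalar_demo_kernel[(1,)](out)
assert out.item() == 7
```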
peterbell10 authored Sep 12, 2023
1 parent bf4f937 commit ab9da3b
Showing 3 changed files with 79 additions and 44 deletions.
91 changes: 59 additions & 32 deletions python/test/unit/language/test_core.py
@@ -568,7 +568,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device):
# test broadcast
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype):
def test_broadcast(dtype, device):
@triton.jit
def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
offset1 = tl.arange(0, M)
@@ -585,41 +585,42 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
y = numpy_random(N, dtype_str=dtype, rs=rs)
_, y_broadcasted_np = np.broadcast_arrays(x, y)

x_tri = to_triton(x, device='cuda', dst_type=dtype)
y_tri = to_triton(y, device='cuda', dst_type=dtype)
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype)
x_tri = to_triton(x, device=device, dst_type=dtype)
y_tri = to_triton(y, device=device, dst_type=dtype)
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype)

broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()

# ----------
# test slice
# ----------


def test_slice(device):

# ---------------
# test broadcast
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype, device):
@triton.jit
def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
offset1 = tl.arange(0, M)
offset2 = tl.arange(0, N)
x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])
y = tl.load(y_ptr + offset2)
_, y_broadcasted = tl.broadcast(x, y)
tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)
def slice_kernel(XBLOCK: tl.constexpr):
data = tl.arange(0, XBLOCK)
tl.static_assert(data.shape == [XBLOCK])

M = 32
N = 64
rs = RandomState(17)
x = numpy_random((M, N), dtype_str=dtype, rs=rs)
y = numpy_random(N, dtype_str=dtype, rs=rs)
_, y_broadcasted_np = np.broadcast_arrays(x, y)
t = data[None, :]
tl.static_assert(t.shape == [1, XBLOCK])

x_tri = to_triton(x, device=device, dst_type=dtype)
y_tri = to_triton(y, device=device, dst_type=dtype)
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype)
t = data[None, :, None]
tl.static_assert(t.shape == [1, XBLOCK, 1])

scalar = tl.full([], 1, tl.int32)
tl.static_assert(scalar.shape == [])

t = scalar[None]
tl.static_assert(t.shape == [1])

t = scalar[None, None]
tl.static_assert(t.shape == [1, 1])

slice_kernel[(1,)](XBLOCK=32)

broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()

# ------------------
# test invalid slice
@@ -669,6 +670,14 @@ def expand_dims_kernel(dummy, N: tl.constexpr):
t = tl.expand_dims(offset1, (3, 1, 2))
tl.static_assert(t.shape == [N, 1, 1, 1])

scalar = tl.sum(offset1)
tl.static_assert(scalar.shape == [])
t = tl.expand_dims(scalar, 0)
tl.static_assert(t.shape == [1])

t = tl.expand_dims(scalar, -1)
tl.static_assert(t.shape == [1])

N = 32
dummy_tensor = torch.empty((), device=device)
expand_dims_kernel[(1,)](dummy_tensor, N)
@@ -689,6 +698,13 @@ def dim_out_of_range2(dummy, N: tl.constexpr):
t = tl.expand_dims(offset1, 1)
t = tl.expand_dims(offset1, 2)

@triton.jit
def dim_out_of_range3(dummy, N: tl.constexpr):
offset1 = tl.arange(0, 1)
scalar = tl.sum(offset1)

t = tl.expand_dims(scalar, 1)

@triton.jit
def duplicate_dim1(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)
@@ -710,6 +726,9 @@ def duplicate_dim2(dummy, N: tl.constexpr):
with pytest.raises(triton.CompilationError, match="invalid axis 2"):
dim_out_of_range2[(1,)](dummy_tensor, N)

with pytest.raises(triton.CompilationError, match="invalid axis 1"):
dim_out_of_range3[(1,)](dummy_tensor, N)

with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
duplicate_dim1[(1,)](dummy_tensor, N)
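The `expand_dims` test additions above come down to one rule: a full reduction such as `tl.sum(offsets)` yields a 0-d tensor, and `expand_dims` promotes it to a rank-1 block, while an out-of-range axis (as in `dim_out_of_range3`) still fails to compile. A rough sketch along those lines (hypothetical kernel name, CUDA device assumed):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def expand_scalar_kernel(out_ptr, N: tl.constexpr):
    offsets = tl.arange(0, N)
    scalar = tl.sum(offsets)             # full reduction -> 0-d tensor
    tl.static_assert(scalar.shape == [])

    col = tl.expand_dims(scalar, 0)      # a scalar is splatted to a rank-1 block
    tl.static_assert(col.shape == [1])
    col = tl.expand_dims(scalar, -1)     # negative axes are normalized the same way
    tl.static_assert(col.shape == [1])

    # tl.expand_dims(scalar, 1) would be rejected at compile time ("invalid axis 1")
    tl.store(out_ptr + tl.arange(0, 1), col)


out = torch.empty(1, dtype=torch.int32, device="cuda")
expand_scalar_kernel[(1,)](out, N=32)
assert out.item() == sum(range(32))      # 0 + 1 + ... + 31 == 496
```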

@@ -2467,7 +2486,8 @@ def kernel(Z, X, Y,


@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str, device):
@pytest.mark.parametrize("shape", [(), (1,), (128,)])
def test_full(dtype_str, shape, device):
if dtype_str in uint_dtypes and not hasattr(torch, dtype_str):
# PyTorch only has unsigned 8, but not 16, 32, or 64
dtype = getattr(torch, dtype_str[1:]) # uintx -> intx
@@ -2478,21 +2498,28 @@ def test_full(dtype_str, device):
@triton.jit
def kernel_static(out):
a = GENERATE_TEST_HERE
tl.static_assert(a.shape == SHAPE)
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)

@triton.jit
def kernel_dynamic(out, val, dtype: tl.constexpr):
a = tl.full((128,), val, dtype)
a = tl.full(SHAPE, val, dtype)
tl.static_assert(a.shape == SHAPE)
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)

kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
kernel_static_patched = patch_kernel(kernel_static, {
'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})",
'SHAPE': str(list(shape)),
})
out_static = torch.zeros((128), dtype=dtype, device=device)
kernel_static_patched[(1,)](out_static)
out_dynamic = torch.zeros((128), dtype=dtype, device=device)
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
assert torch.all(out_static == 2)

kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))})
out_dynamic = torch.zeros((128), dtype=dtype, device=device)
kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
assert torch.all(out_dynamic == 2)


6 changes: 2 additions & 4 deletions python/triton/language/core.py
@@ -531,9 +531,7 @@ def __init__(self, handle, type: dtype):
# IR handle
self.handle = handle
# Block shape
self.shape = (1, )
if type.is_block():
self.shape = type.shape
self.shape = type.shape if type.is_block() else ()
self.numel = 1
for s in self.shape:
self.numel *= s
@@ -743,7 +741,7 @@ def __not__(self, _builder=None):

@builtin
def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice):
if isinstance(slices, (slice, constexpr)):
slices = [slices]
ret = self
for dim, sl in enumerate(slices):
26 changes: 18 additions & 8 deletions python/triton/language/semantic.py
@@ -501,25 +501,31 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
if isinstance(value, tl.tensor):
assert value.numel.value == 1, "only accepts size-1 tensor"
value = cast(value, dtype, builder)
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
else:
# scalar
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
if value == 0:
value = builder.get_null_value(dtype.to_ir(builder))
else:
get_value_fn = getattr(builder, f"get_{dtype.name}")
value = get_value_fn(value)
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(value, shape), ret_ty)
value = tl.tensor(value, dtype)

return splat(value, shape, builder)


# ===----------------------------------------------------------------------===//
# Shape Manipulation
# ===----------------------------------------------------------------------===//

def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
assert not value.type.is_block(), "Cannot splat a block tensor"
if len(shape) == 0:
return value
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)


def view(input: tl.tensor,
dst_shape: List[int],
@@ -544,8 +550,12 @@ def reshape(input: tl.tensor,


def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
dst_shape = list(input.type.shape)
dst_shape = [tl._constexpr_to_value(x) for x in input.shape]
dst_shape.insert(axis, 1)

if not input.type.is_block():
return splat(input, shape=dst_shape, builder=builder)

ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)

@@ -1506,7 +1516,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor:


def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
if max(1, len(x.shape)) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x
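A plain-Python paraphrase of the new control flow in semantic.py, for illustration only (the `Tensor` class below is a stand-in, not Triton's IR): `full` now always builds a scalar and defers to the new `splat` helper, `splat` returns the scalar unchanged when the target shape is empty, and `expand_dims` splats 0-d inputs instead of calling `create_expand_dims`. The `multiple_of` tweak follows the same idea, treating a 0-d tensor as carrying one value so a single divisibility hint is still accepted.

```python
from typing import List


class Tensor:
    """Stand-in for tl.tensor: a value plus a (possibly empty) block shape."""
    def __init__(self, value, shape: List[int]):
        self.value, self.shape = value, list(shape)


def splat(scalar: Tensor, shape: List[int]) -> Tensor:
    assert scalar.shape == [], "can only splat a scalar"
    if len(shape) == 0:
        return scalar                    # 0-d request: hand the scalar back unchanged
    return Tensor(scalar.value, shape)   # otherwise broadcast it to a block


def full(shape: List[int], value) -> Tensor:
    return splat(Tensor(value, []), shape)   # build the scalar first, then splat


def expand_dims(x: Tensor, axis: int) -> Tensor:
    dst_shape = list(x.shape)
    dst_shape.insert(axis, 1)
    if x.shape == []:
        return splat(x, dst_shape)       # scalar input: splat to the new rank-1 shape
    return Tensor(x.value, dst_shape)    # block input: genuine dimension insertion


assert full([], 2).shape == []           # no more "0d block_type is forbidden"
assert full([128], 2).shape == [128]
assert expand_dims(full([], 2), 0).shape == [1]
```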
