Skip to content

Commit

Permalink
Update doc and test for standard (triton-lang#4093)
Browse files Browse the repository at this point in the history
Improve python/triton/language/standard.py and
python/test/unit/language/test_standard.py.
  • Loading branch information
lancerts authored Jun 7, 2024
1 parent bcf3678 commit 4830b2f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
4 changes: 2 additions & 2 deletions python/test/unit/language/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_maximum_minium(dtype, op, device):
@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32'])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_sort(M, N, descending, dtype_str, device):

@triton.jit
Expand All @@ -55,7 +55,7 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr

@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32'])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_flip(M, N, dtype_str, device):

@triton.jit
Expand Down
29 changes: 22 additions & 7 deletions python/triton/language/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from . import core
from . import math

# constexpr utilities (triton metaprogramming sucks)
# constexpr utilities


def _unwrap_if_constexpr(o):
Expand Down Expand Up @@ -39,7 +39,7 @@ def cdiv(x, div):
:param x: the input number
:type x: Block
:param div: the divisor
:param div: Block
:type div: Block
"""
return (x + div - 1) // div

Expand Down Expand Up @@ -129,7 +129,10 @@ def zeros(shape, dtype):
@jit
def zeros_like(input):
"""
Creates a tensor of zeros with the same shape and type as a given tensor.
Returns a tensor of zeros with the same shape and type as a given tensor.
:param input: input tensor
:type input: Tensor
"""
return zeros(input.shape, input.dtype)

Expand Down Expand Up @@ -368,6 +371,16 @@ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core
@core._tensor_member_fn
@jit
def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
"""
Sorts a tensor along a specified dimension using the bitonic merge-sort algorithm.
:param x: The input tensor to be sorted.
:type x: Tensor
:param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
:type dim: int, optional
:param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
:type descending: bool, optional
"""
# handle default dimension or check that it is the most minor dim
_dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
Expand Down Expand Up @@ -423,11 +436,13 @@ def flip(x, dim=None):
@jit
def interleave(a, b):
"""
Interleaves the values of two tensors along their last dimension.
The two tensors must have the same shape.
Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
:param a: The first input tensor.
:type a: Tensor
:param b: The second input tensor.
:type b: Tensor
"""
c = core.join(a, b)

Expand Down

0 comments on commit 4830b2f

Please sign in to comment.