diff --git a/python/test/unit/language/test_standard.py b/python/test/unit/language/test_standard.py
index 27f78ab99bf7..b3392d4750c4 100644
--- a/python/test/unit/language/test_standard.py
+++ b/python/test/unit/language/test_standard.py
@@ -28,7 +28,7 @@ def test_maximum_minium(dtype, op, device):
 @pytest.mark.interpreter
 @pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
 @pytest.mark.parametrize("descending", [False, True])
-@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32'])
+@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
 def test_sort(M, N, descending, dtype_str, device):
 
     @triton.jit
@@ -55,7 +55,7 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr
 
 @pytest.mark.interpreter
 @pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
-@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32'])
+@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
 def test_flip(M, N, dtype_str, device):
 
     @triton.jit
diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py
index 0b6708491531..c1cae30c9e61 100644
--- a/python/triton/language/standard.py
+++ b/python/triton/language/standard.py
@@ -4,7 +4,7 @@
 from . import core
 from . import math
 
-# constexpr utilities (triton metaprogramming sucks)
+# constexpr utilities
 
 
 def _unwrap_if_constexpr(o):
@@ -39,7 +39,7 @@ def cdiv(x, div):
     :param x: the input number
     :type x: Block
     :param div: the divisor
-    :param div: Block
+    :type div: Block
     """
     return (x + div - 1) // div
 
@@ -129,7 +129,10 @@ def zeros(shape, dtype):
 @jit
 def zeros_like(input):
     """
-    Creates a tensor of zeros with the same shape and type as a given tensor.
+    Returns a tensor of zeros with the same shape and type as a given tensor.
+
+    :param input: input tensor
+    :type input: Tensor
     """
     return zeros(input.shape, input.dtype)
 
@@ -368,6 +371,16 @@ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core
 @core._tensor_member_fn
 @jit
 def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    """
+    Sorts a tensor along a specified dimension using the bitonic merge-sort algorithm.
+
+    :param x: The input tensor to be sorted.
+    :type x: Tensor
+    :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
+    :type dim: int, optional
+    :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
+    :type descending: bool, optional
+    """
     # handle default dimension or check that it is the most minor dim
     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
@@ -423,11 +436,13 @@ def flip(x, dim=None):
 @jit
 def interleave(a, b):
     """
-    Interleaves the values of two tensors along their last dimension.
-
-    The two tensors must have the same shape.
-
+    Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
     Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
+
+    :param a: The first input tensor.
+    :type a: Tensor
+    :param b: The second input tensor.
+    :type b: Tensor
     """
     c = core.join(a, b)
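
Not part of the patch above: a minimal sketch of how the `tl.sort` documented here is exercised with the newly parametrized `bfloat16` dtype, mirroring the load/sort/store pattern of `test_sort`. The kernel name, grid, and shapes are illustrative assumptions, not code from this diff.

```python
# Illustrative sketch (assumes a CUDA device and an installed Triton);
# mirrors the shape of the test_sort harness but is not code from this patch.
import torch
import triton
import triton.language as tl


@triton.jit
def sort_rows_kernel(X, Z, M: tl.constexpr, N: tl.constexpr, descending: tl.constexpr):
    # Build (M, N) offsets and load one block per program.
    offs = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
    x = tl.load(X + offs)
    # tl.sort is a bitonic merge-sort; only the minor (last) dimension is supported.
    z = tl.sort(x, descending=descending)
    tl.store(Z + offs, z)


x = torch.randn((8, 64), device="cuda", dtype=torch.bfloat16)
z = torch.empty_like(x)
sort_rows_kernel[(1, )](x, z, M=8, N=64, descending=False)
assert torch.equal(z, torch.sort(x, dim=-1).values)
```

The same harness extends to `tl.flip` and `tl.interleave`, whose docstrings this patch also updates.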