triton kernel fix (still not working)
BlackSamorez committed Feb 17, 2024
1 parent ab842da commit 43f1c3a
Showing 1 changed file with 40 additions and 73 deletions: inference_lib/src/aqlm/inference_kernels/triton_kernel.py
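For orientation while reading the diff: an AQLM linear layer stores per-group integer codes, one or more codebooks of shape [num_codebooks, codebook_size, out_group_size, in_group_size], scales, and an optional bias. A rough PyTorch reference of the matvec these kernels compute is sketched below; the shapes follow the asserts in this file, but the helper name and the exact way codes index the codebooks are illustrative assumptions, not repository code.

import torch

def aqlm_matvec_reference(input_vec, codes, codebooks, scales, bias=None):
    # Assumed shapes (per the asserts in triton_kernel.py):
    #   codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    #   codes:     [out_features // out_group_size, in_features // in_group_size, num_codebooks]
    #   scales:    [out_features // out_group_size, 1, 1, 1]
    #   input_vec: [1, in_features]
    num_codebooks, _, out_group_size, in_group_size = codebooks.shape
    num_out_groups, num_in_groups, _ = codes.shape
    # Look up every group's codeword in each codebook and sum over codebooks.
    blocks = codebooks[torch.arange(num_codebooks), codes.long()].sum(dim=2)
    blocks = blocks * scales  # one scale per output group, broadcast over the rest
    weight = blocks.permute(0, 2, 1, 3).reshape(
        num_out_groups * out_group_size, num_in_groups * in_group_size
    )
    out = input_vec @ weight.T
    return out + bias if bias is not None else out

The Triton kernels in this file compute the same quantity one output group at a time, without materializing the dense weight matrix.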
@@ -22,6 +22,7 @@
         "num_input_groups",
         "num_input_groups_next_power_of_2",
         "compute_in_fp32",
+        "has_output_scale",
         "has_bias",
     ],
 )
@@ -42,11 +43,12 @@ def _aqlm_gemv_simple(
     num_input_groups: tl.constexpr,
     num_input_groups_next_power_of_2: tl.constexpr,
     compute_in_fp32: tl.constexpr,
+    has_output_scale: tl.constexpr,
     has_bias: tl.constexpr,
     UNUSED: tl.constexpr,
 ):
     # variables ending with "_i" mean "for i-th output unit"
-    pid = tl.program_id(axis=0)  # [0, 1, ... {out_features-1}]
+    pid = tl.program_id(axis=0)  # [0, 1, ... {num_out_groups-1}]

     # Stage 1: load input data
     input_vec = tl.load(
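The updated comment on pid reflects the launch geometry: the wrappers below launch a 1-D grid with one program per output group, so pid indexes groups of out_group_size output features rather than individual features. Roughly (the sizes here are made-up examples):

# one Triton program instance per output group (illustrative numbers)
out_features, out_group_size = 4096, 8
num_out_groups = out_features // out_group_size   # = 512
grid = (num_out_groups,)                          # pid = tl.program_id(0) in [0, num_out_groups)
# program pid is responsible for output features [pid * out_group_size, (pid + 1) * out_group_size)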
@@ -60,6 +62,7 @@
     # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features))  # [in_features]
     # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
     # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring
+    dtype = input_vec.dtype

     # Stage 2: load integer codes for the active row
     # [in_features // in_group_size, num_codebooks]
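The added dtype = input_vec.dtype records the activation dtype before the optional fp32 upcast further down, so the final store can cast back to it; the old code cast the result to input_vec.dtype, which is already float32 once compute_in_fp32 has upcast the input. The same pattern in plain PyTorch terms (a sketch, not repository code):

import torch

def upcast_compute_downcast(w: torch.Tensor, x: torch.Tensor, compute_in_fp32: bool = True) -> torch.Tensor:
    dtype = x.dtype              # remember the storage dtype (e.g. torch.float16)
    if compute_in_fp32:          # do the multiply-accumulate in fp32 for accuracy
        w, x = w.float(), x.float()
    y = (w * x).sum()
    return y.to(dtype)           # cast back before writing the result out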
@@ -101,65 +104,28 @@
         input_vec = input_vec.to(tl.float32)
     # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]

+    output_i = weights_i * input_vec  # [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
+
     if out_group_size == 1:
-        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
-        output_i = tl.sum(weights_i * input_vec) * scale
-        if has_bias:
-            output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
-        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
+        output_i = tl.sum(output_i)  # []
     else:
-        output_i = tl.sum(tl.sum(weights_i, axis=2) * input_vec, axis=0)  # [out_group_size]
-        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
-        if has_bias:
-            output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
-        tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))
-
-
-def next_power_of_2(x):
-    return 1 if x == 0 else 2 ** (x - 1).bit_length()
+        output_i = tl.sum(output_i, axis=1)  # [in_features // in_group_size, out_group_size, in_group_size]
+        output_i = tl.sum(output_i, axis=2)  # [in_features // in_group_size, out_group_size]
+        output_i = tl.sum(output_i, axis=0)  # [out_group_size]

+    if has_output_scale:
+        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
+    if has_bias:
+        output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)

-def aqlm_gemv_simple(
-    input_vec: torch.Tensor,
-    codes_i16: torch.ShortTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    compute_in_fp32: bool = True,
-):
-    device, dtype = codebooks.device, codebooks.dtype
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    in_features = input_vec.shape[1]
-    out_features = codes_i16.shape[0] * out_group_size
-    num_input_groups = codes_i16.shape[1]
-    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
-    assert in_features % in_group_size == 0
-    assert codebooks.shape[1] < 2**32
+    if out_group_size == 1:
+        tl.store(output_vec_ptr + pid, output_i.to(dtype))
+    else:
+        tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(dtype))

-    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
-    # 1D launch kernel where each block computes output unit
-    grid = lambda META: (out_features // out_group_size,)
-    _aqlm_gemv_simple[grid](
-        input_vec,
-        output_vec,
-        codes_i16,
-        codebooks,
-        scales,
-        bias,
-        in_features,
-        out_features,
-        num_codebooks,
-        codebook_size,
-        out_group_size,
-        in_group_size,
-        num_input_groups,
-        next_power_of_2(num_input_groups),
-        compute_in_fp32,
-        bias is not None,
-    )
-
-    return output_vec
+
+def next_power_of_2(x):
+    return 1 if x == 0 else 2 ** (x - 1).bit_length()


 def aqlm_gemm_stupid(
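In the rewritten kernel body, the elementwise product is formed once and then reduced in stages: over codebooks, over the within-group input dimension, and over input groups (or in a single tl.sum when out_group_size == 1). The scale and bias are applied afterwards, gated by the new has_output_scale and has_bias flags, and the store casts back to the dtype captured before the fp32 upcast. A plain-PyTorch sketch of that per-program reduction, together with the next_power_of_2 helper the launches pass alongside num_input_groups (tensor names are assumptions; only the shapes come from the comments above):

import torch

def reduce_one_output_group(weights_i, grouped_input, scale=None, bias_i=None):
    # weights_i:     [num_input_groups, num_codebooks, out_group_size, in_group_size]
    # grouped_input: [num_input_groups, in_group_size]
    output_i = weights_i * grouped_input[:, None, None, :]
    output_i = output_i.sum(dim=1)  # [num_input_groups, out_group_size, in_group_size]
    output_i = output_i.sum(dim=2)  # [num_input_groups, out_group_size]
    output_i = output_i.sum(dim=0)  # [out_group_size]
    if scale is not None:           # has_output_scale path
        output_i = output_i * scale
    if bias_i is not None:          # has_bias path
        output_i = output_i + bias_i
    return output_i

def next_power_of_2(x):
    # e.g. 5 -> 8, 8 -> 8, 0 -> 1; Triton block ranges (tl.arange) need power-of-2 sizes,
    # which is why num_input_groups_next_power_of_2 is passed as a constexpr next to num_input_groups
    return 1 if x == 0 else 2 ** (x - 1).bit_length()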
@@ -176,10 +142,20 @@ def aqlm_gemm_stupid(
     out_features = codes_i16.shape[0] * out_group_size
     num_input_groups = codes_i16.shape[1]
     assert input.ndim == 2
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
     assert in_features % in_group_size == 0
     assert codebooks.shape[1] < 2**32

+    if scales.shape == (out_features // out_group_size, 1, 1, 1):
+        has_output_scales = True
+    elif scales.shape == (1, in_features // in_group_size, 1, 1) and in_group_size == 1:
+        has_output_scales = False
+        input *= scales.squeeze()
+    else:
+        raise NotImplementedError(f"Can't do Triton AQLM matmul with scales of shape {scales.shape}")
+
+    if not input.is_contiguous():
+        raise Exception("AAAA")
+
     output = torch.empty(input.shape[0], out_features, device=device, dtype=dtype)
     for i in range(input.shape[0]):
         # 1D launch kernel where each block computes output unit
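The new scale handling accepts two layouts: per-output-group scales of shape (out_features // out_group_size, 1, 1, 1), which the kernel applies after the reduction (has_output_scales = True), and per-input-feature scales of shape (1, in_features // in_group_size, 1, 1) restricted to in_group_size == 1, which are folded into the activations before the launch so the kernel skips scaling. A functional sketch of that dispatch (the in-place input *= above becomes an out-of-place multiply here; the helper name is illustrative):

import torch

def split_scales(input, scales, out_features, in_features, out_group_size, in_group_size):
    if scales.shape == (out_features // out_group_size, 1, 1, 1):
        return input, True                 # let the kernel scale each output group
    if scales.shape == (1, in_features // in_group_size, 1, 1) and in_group_size == 1:
        # one scale per input feature: equivalent to pre-scaling the activations
        return input * scales.squeeze(), False
    raise NotImplementedError(f"Unsupported scale shape {tuple(scales.shape)}")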
@@ -200,6 +176,7 @@ def aqlm_gemm_stupid(
             num_input_groups,
             next_power_of_2(num_input_groups),
             compute_in_fp32,
+            has_output_scales,
             bias is not None,
)

@@ -217,21 +194,11 @@ def triton_matmul(
     input_shape = input.shape
     input = input.reshape(-1, input_shape[-1])

-    if input.shape[0] == 1:
-        return aqlm_gemv_simple(
-            input,
-            codes,
-            codebooks,
-            scales,
-            bias,
-            compute_in_fp32,
-        ).reshape(input_shape[:-1] + (-1,))
-    else:
-        return aqlm_gemm_stupid(
-            input,
-            codes,
-            codebooks,
-            scales,
-            bias,
-            compute_in_fp32,
-        ).reshape(input_shape[:-1] + (-1,))
+    return aqlm_gemm_stupid(
+        input,
+        codes,
+        codebooks,
+        scales,
+        bias,
+        compute_in_fp32,
+    ).reshape(input_shape[:-1] + (-1,))
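With the gemv fast path removed, triton_matmul always flattens the activations to 2-D, calls aqlm_gemm_stupid (which loops over the flattened rows), and restores the leading dimensions at the end. The reshape round trip, with made-up sizes:

import torch

x = torch.randn(2, 5, 4096)                          # [batch, seq, in_features]; sizes assumed
flat = x.reshape(-1, x.shape[-1])                    # what triton_matmul does before dispatch
out_features = 11008                                 # assumed
flat_out = torch.empty(flat.shape[0], out_features)  # stand-in for the kernel's output
y = flat_out.reshape(x.shape[:-1] + (-1,))           # back to [2, 5, out_features]
assert y.shape == (2, 5, out_features)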
