numba kernel
Andrei Panferov committed Jan 30, 2024
1 parent d7c4561 commit 823db17
Showing 2 changed files with 66 additions and 0 deletions.
inference_lib/src/aqlm/inference_kernels/kernel_selector.py (3 additions, 0 deletions)
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from aqlm.utils import _dequantize_weight, unpack_int_data

from .numba import numba_gemm_lut
from .triton_kernel import triton_matmul


@@ -27,6 +28,8 @@ def forward_pass_quantized_linear(
return cuda_gemm_2x8(input, codes, codebooks, scales, bias)
case (True, _, _, _, _):
return triton_matmul(input, codes, codebooks, scales, bias)
case (False, _, 256, 1, _):
return numba_gemm_lut(input, codes, codebooks, scales, bias)
case _:
dequantized_weight = _dequantize_weight(
unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
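The match subject for these patterns sits outside the visible hunk. Judging from the cases shown here and the asserts in numba.py below, it is presumably the tuple (input.is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size). A minimal sketch of the dispatch under that assumption, not a verbatim copy of the file:

match (input.is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
    case (True, _, _, _, _):
        # any remaining CUDA input falls through to the Triton kernel
        return triton_matmul(input, codes, codebooks, scales, bias)
    case (False, _, 256, 1, _):
        # CPU input, 256-entry (8-bit) codebooks, out_group_size 1:
        # the new numba LUT kernel added by this commit
        return numba_gemm_lut(input, codes, codebooks, scales, bias)
    case _:
        # fall back to explicit dequantization followed by F.linear
        ...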
inference_lib/src/aqlm/inference_kernels/numba.py (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
from typing import Optional

import numba
import numpy as np
import torch

COMPILED_KERNELS = {}


def numba_gemm_lut(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
bias: Optional[torch.Tensor],
) -> torch.Tensor:
input_shape = input.shape
input = input.reshape(-1, input_shape[-1])

device, dtype = codebooks.device, codebooks.dtype
num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
in_features = input.shape[1]
out_features = codes.shape[0] * out_group_size
assert input.ndim == 2
assert scales.shape == (out_features // out_group_size, 1, 1, 1)
assert in_features % in_group_size == 0
assert codebook_size == 2**8
assert codes.dtype == torch.int8
assert input.dtype == torch.float32 and codebooks.dtype == torch.float32

kernel_key = (in_group_size, out_features, in_features, num_codebooks)
if kernel_key not in COMPILED_KERNELS:

        @numba.njit(parallel=False)  # njit already implies nopython=True; passing it again only triggers a RuntimeWarning
        def numba_gemv_lut_(x, codebooks, codes_alt, scales):
            # Build the lookup table: dot product of every input group with every
            # codebook entry, shape [num_in_groups, num_codebooks, codebook_size].
            lut = x.reshape(-1, in_group_size) @ codebooks.reshape(-1, in_group_size).T
            lut = lut.reshape(-1, num_codebooks, codebook_size)

            # Accumulate one LUT entry per (input group, output feature, codebook).
            output_vec = np.zeros(out_features, dtype=x.dtype)
            for j in range(in_features // in_group_size):
                for i in range(out_features):
                    for c in range(num_codebooks):
                        # codes are int8; numpy's negative indexing wraps modulo 256,
                        # so a signed code selects the intended unsigned entry.
                        output_vec[i] += lut[j, c, codes_alt[j, i, c]]
            output_vec *= scales.flatten()
            return output_vec

COMPILED_KERNELS[kernel_key] = numba_gemv_lut_
compiled_kernel = COMPILED_KERNELS[kernel_key]

    # Reorder codes once per call, not once per row: [num_in_groups, out_features, num_codebooks].
    codes_alt = torch.permute(codes, (1, 0, 2)).contiguous().numpy()
    output = torch.zeros(input.shape[0], out_features, device=device, dtype=dtype)
    for i in range(input.shape[0]):
        output[i] = torch.from_numpy(
            compiled_kernel(
                input[i].numpy(),
                codebooks.numpy(),
                codes_alt,
                scales.numpy(),
            )
        )
    # Scales are already applied inside the compiled kernel; multiplying by them
    # again here would scale the output twice.
    if bias is not None:
        output += bias
    return output.reshape(input_shape[:-1] + (-1,))
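A quick way to sanity-check the kernel is a toy comparison against explicit dequantization. The shapes below are hypothetical (not part of this commit), chosen to satisfy the asserts above: float32 input and codebooks, int8 codes, 256-entry codebooks, out_group_size 1.

import torch
from aqlm.inference_kernels.numba import numba_gemm_lut

num_codebooks, codebook_size, out_group_size, in_group_size = 2, 256, 1, 8
out_features, in_features = 16, 32
num_out_groups = out_features // out_group_size
num_in_groups = in_features // in_group_size

codes = torch.randint(-128, 128, (num_out_groups, num_in_groups, num_codebooks), dtype=torch.int8)
codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
scales = torch.rand(num_out_groups, 1, 1, 1)
x = torch.randn(3, in_features)

out = numba_gemm_lut(x, codes, codebooks, scales, bias=None)

# Reference: materialize the weight by summing the selected codebook entries,
# apply scales, then do an ordinary matmul.
codes_u = codes.to(torch.int64) % codebook_size  # signed int8 -> unsigned index
weight = torch.zeros(out_features, in_features)
for o in range(num_out_groups):
    for g in range(num_in_groups):
        acc = torch.zeros(in_group_size)
        for c in range(num_codebooks):
            acc += codebooks[c, codes_u[o, g, c], 0]
        weight[o, g * in_group_size:(g + 1) * in_group_size] = acc * scales[o, 0, 0, 0]
ref = x @ weight.T
assert torch.allclose(out, ref, atol=1e-4)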
