Commit 7342655: dev11 fix from Elias
Andrei Panferov committed Feb 6, 2024
Parent: f2ef38b
Showing 3 changed files with 2 additions and 35 deletions.
2 changes: 1 addition & 1 deletion inference_lib/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = aqlm
-version = 1.0.0dev10
+version = 1.0.0dev11
 author = AQLM paper authors
 author_email = [email protected]
 description = Efficiently run models quantized with AQLM
2 changes: 1 addition & 1 deletion inference_lib/src/aqlm/inference_kernels/cuda_kernel.cu
@@ -96,7 +96,7 @@ __global__ void Code2x8MatVec(
     int4 dec = codebook[i];
 #pragma unroll
     for (int j = 0; j < 8; j++)
-      sh_code[8 * threadIdx.x + (j + lane) % 8] = dec;
+      sh_code[8 * i + (j + lane) % 8] = dec;
   }
   __syncthreads();
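This one-line change is the substance of the dev11 fix. The loop caches the two 256-entry codebooks in shared memory, replicating each int4 codeword 8 times with a per-lane rotation so the threads of a lane group later read from distinct banks. Indexing the destination with threadIdx.x meant that, in what is presumably a block-stride loop (i starting at threadIdx.x and advancing by blockDim.x), later passes overwrote the slots of earlier entries and the upper half of sh_code was never written. Indexing with the entry index i gives every codeword its own 8-slot region. A minimal NumPy simulation of the final buffer state (the block size and loop structure are assumptions for illustration; only the two index expressions come from the diff):

    import numpy as np

    BLOCK = 256          # assumed blockDim.x; not stated in the diff
    ENTRIES = 2 * 256    # two codebooks of 256 entries each

    def fill(fixed: bool) -> np.ndarray:
        # One slot per replica: each entry is copied 8 times, rotated by lane.
        sh_code = np.full(8 * ENTRIES, -1, dtype=np.int64)
        for tid in range(BLOCK):                  # every thread in the block
            lane = tid % 8
            for i in range(tid, ENTRIES, BLOCK):  # assumed block-stride loop
                base = 8 * i if fixed else 8 * tid  # the one changed expression
                for j in range(8):
                    sh_code[base + (j + lane) % 8] = i
        return sh_code

    # Buggy indexing leaves half the buffer unwritten (and the written half
    # holds second-pass entries); the fix covers every slot exactly once.
    print((fill(False) == -1).sum())  # 2048 slots never written
    print((fill(True) == -1).sum())   # 0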
33 changes: 0 additions & 33 deletions inference_lib/src/aqlm/inference_kernels/cuda_kernel.py
@@ -9,36 +9,3 @@
     name="codebook_cuda",
     sources=[os.path.join(CUDA_FOLDER, "cuda_kernel.cpp"), os.path.join(CUDA_FOLDER, "cuda_kernel.cu")],
 )
-
-
-def cuda_gemm_2x8(
-    input: torch.Tensor,  # [..., in_features]
-    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
-    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
-    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    bias: Optional[torch.Tensor],
-) -> torch.Tensor:
-    input_shape = input.shape
-    input = input.reshape(-1, input_shape[-1])
-
-    device, dtype = codebooks.device, codebooks.dtype
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    in_features = input.shape[1]
-    out_features = codes.shape[0] * out_group_size
-    assert input.ndim == 2
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
-    assert in_features % in_group_size == 0
-    assert codebook_size == 2**8
-    assert num_codebooks == 2
-    assert codes.dtype == torch.int8
-    assert input.dtype == torch.float16 and codebooks.dtype == torch.float16
-
-    output = torch.zeros(input.shape[0], out_features, device=device, dtype=dtype)
-    for i in range(input.shape[0]):
-        CUDA_KERNEL.code2x8_matvec(
-            codes.squeeze(2), input[i].unsqueeze(-1), output[i].unsqueeze(-1), codebooks.squeeze(0, 2)
-        )
-    output *= scales.flatten().unsqueeze(0)
-    if bias is not None:
-        output += bias
-    return output.reshape(input_shape[:-1] + (-1,))
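The commit removes the cuda_gemm_2x8 wrapper from this module without a replacement in the same file. For readers who want a CPU-checkable reference for what it computed, here is a hedged pure-PyTorch sketch; the function name is hypothetical, the shapes mirror the deleted signature, and the "% codebook_size" unpacking of the int8 codes is an assumption carried over from the library's conventions, not from this diff:

    from typing import Optional

    import torch

    def gemm_2x8_reference(
        input: torch.Tensor,      # [..., in_features]
        codes: torch.IntTensor,   # [num_out_groups, num_in_groups, num_codebooks]
        codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
        scales: torch.Tensor,     # [num_out_groups, 1, 1, 1]
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
        num_out_groups, num_in_groups, _ = codes.shape
        # int8 codes wrap around; % codebook_size maps them back to [0, 256).
        idx = codes.to(torch.int64) % codebook_size  # [O, I, C]
        # Look up codebooks[c, idx[o, i, c]] -> [O, I, C, og, ig], sum over C.
        groups = codebooks[torch.arange(num_codebooks, device=codes.device), idx].sum(2)
        # Stitch the groups into a dense weight matrix [O*og, I*ig].
        weight = groups.permute(0, 2, 1, 3).reshape(
            num_out_groups * out_group_size, num_in_groups * in_group_size
        )
        # Apply the per-output-group scales row-wise.
        weight = weight * scales.view(-1, 1).repeat_interleave(out_group_size, 0)
        output = input @ weight.transpose(0, 1)
        if bias is not None:
            output = output + bias
        return output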