forked from Vahe1994/AQLM
Commit: Merge pull request Vahe1994#26 from Vahe1994/backward
Integrating with `autograd`, adding static kernel routing.
Showing 9 changed files with 244 additions and 129 deletions.
setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = aqlm
-version = 1.0.0
+version = 1.0.2
 author = AQLM paper authors
 author_email = [email protected]
 description = Efficiently run models quantized with AQLM

@@ -15,6 +15,8 @@ classifiers =
     Intended Audience :: Science/Research
     License :: OSI Approved :: MIT License
     Programming Language :: Python :: 3
+    Programming Language :: Python :: 3.8
+    Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
     Programming Language :: Python :: 3.11
     Topic :: Scientific/Engineering

@@ -29,7 +31,7 @@ package_dir =
     = src
 packages = find:
 include_package_data = True
-python_requires = >=3.10
+python_requires = >=3.8
 install_requires =
     torch>=2.1.1
     transformers>=4.37.0
inference_lib/src/aqlm/__init__.py
@@ -1,2 +1,3 @@
 import aqlm.inference_kernels
 from aqlm.inference import QuantizedLinear
+from aqlm.inference_kernels import optimize_for_training
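
The `optimize_for_training` export above is the context manager defined in `kernel_selector.py` further down; its docstring says to wrap model initialization (e.g. `.from_pretrained(...)`) so that inference kernels optimized for larger batch sizes get selected. A minimal usage sketch, not part of this diff, with a placeholder checkpoint id:

```python
# Usage sketch (placeholder model id; not code from this PR).
import torch
from aqlm import optimize_for_training
from transformers import AutoModelForCausalLM

# Wrapping .from_pretrained(...) makes the kernel selector pick kernels suited
# to larger batch sizes, per the docstring in kernel_selector.py.
with optimize_for_training():
    model = AutoModelForCausalLM.from_pretrained(
        "org/aqlm-quantized-model",  # placeholder checkpoint id
        torch_dtype=torch.float16,   # the CUDA kernels assert float16 weights
    )
```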
inference_lib/src/aqlm/inference_kernels/__init__.py
@@ -1 +1 @@
-from .kernel_selector import forward_pass_quantized_linear
+from .kernel_selector import get_backward_pass_kernel, get_forward_pass_kernel, optimize_for_training
inference_lib/src/aqlm/inference_kernels/dequantization.py (new file: 21 additions)
@@ -0,0 +1,21 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from aqlm.utils import _dequantize_weight, unpack_int_data
+from torch import nn
+
+
+def dequantize_gemm(
+    input: torch.Tensor,  # [..., in_features]
+    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
+    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
+    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
+    bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+    dequantized_weight = _dequantize_weight(
+        unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
+        codebooks,
+        scales,
+    )
+    return F.linear(input, dequantized_weight, bias)
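
The shape comments above pin down the layout this fallback consumes. As a clarifying sketch in plain PyTorch (an assumption about what `_dequantize_weight` computes: one codebook vector per codebook summed for each weight group, then per-output-group scales; codes here are plain integer indices, so `unpack_int_data` is skipped):

```python
# Plain-PyTorch sketch of the dequantize-then-matmul fallback. Shapes follow the
# comments in dequantize_gemm; the reconstruction math is an assumption, not aqlm's code.
import torch
import torch.nn.functional as F

num_codebooks, codebook_size = 2, 256
out_group_size, in_group_size = 1, 8
out_features, in_features = 16, 64
num_out_groups = out_features // out_group_size  # 16
num_in_groups = in_features // in_group_size     # 8

codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
codes = torch.randint(codebook_size, (num_out_groups, num_in_groups, num_codebooks))
scales = torch.rand(num_out_groups, 1, 1, 1)

# Look up one [out_group_size, in_group_size] block per codebook and sum them.
selected = codebooks[torch.arange(num_codebooks), codes]  # [n_out_grp, n_in_grp, n_books, og, ig]
groups = selected.sum(dim=2) * scales                     # [n_out_grp, n_in_grp, og, ig]

# Stitch the groups back into a dense [out_features, in_features] weight, then matmul.
weight = groups.permute(0, 2, 1, 3).reshape(out_features, in_features)
x = torch.randn(3, in_features)
y = F.linear(x, weight)  # [3, out_features]
```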
inference_lib/src/aqlm/inference_kernels/kernel_selector.py (122 changes: 85 additions & 37 deletions)
@@ -1,46 +1,94 @@
-from typing import Optional
+from contextlib import contextmanager
+from typing import Callable, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from aqlm.utils import _dequantize_weight, unpack_int_data
 
+_OPTIMIZE_FOR_TRAINING = False
 
-def forward_pass_quantized_linear(
-    input: torch.Tensor,
-    codes: torch.IntTensor,
+
+@contextmanager
+def optimize_for_training():
+    """Use this context manager during model initialization (e.g. `.from_pretrained(...)`) to select inference kernels optimized for larger batch sizes"""
+    global _OPTIMIZE_FOR_TRAINING
+    _OPTIMIZE_FOR_TRAINING = True
+    try:
+        yield
+    finally:
+        _OPTIMIZE_FOR_TRAINING = False
+
+
+def get_forward_pass_kernel(
     codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-) -> torch.Tensor:
+    optimize_for_training: bool,
+) -> Callable[[torch.Tensor, torch.IntTensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
     num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    match (input.is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
-        case (True, 1, 65536, 1, 8):
-            from .cuda_kernel import CUDA_KERNEL
-
-            assert (
-                input.dtype == torch.float16
-            ), f"please load the model with `torch_dtype=torch.float16`, as {input.dtype} is not supported on GPU yet"
-            return CUDA_KERNEL.code1x16_matmat(input, codes, codebooks, scales) + (bias if bias is not None else 0)
-        case (True, 2, 256, 1, 8):
-            from .cuda_kernel import CUDA_KERNEL
-
-            assert (
-                input.dtype == torch.float16
-            ), f"please load the model with `torch_dtype=torch.float16`, as {input.dtype} is not supported on GPU yet"
-            return CUDA_KERNEL.code2x8_matmat(input, codes, codebooks, scales) + (bias if bias is not None else 0)
-        case (True, _, _, _, _):
-            from .triton_kernel import triton_matmul
-
-            return triton_matmul(input, codes, codebooks, scales, bias)
-        case (False, _, 256, 1, _):
-            from .numba_kernel import numba_gemm_lut
-
-            return numba_gemm_lut(input, codes, codebooks, scales, bias)
-        case _:
-            dequantized_weight = _dequantize_weight(
-                unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
-                codebooks,
-                scales,
-            )
-            return F.linear(input, dequantized_weight, bias)
+
+    if (optimize_for_training, codebooks.device.type, num_codebooks, codebook_size, out_group_size, in_group_size) == (
+        False,
+        "cuda",
+        1,
+        65536,
+        1,
+        8,
+    ):
+        from .cuda_kernel import CUDA_FOLDER
+
+        assert (
+            codebooks.dtype == torch.float16
+        ), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
+        return torch.ops.aqlm_cuda_kernel.code1x16_matmat
+    elif (
+        optimize_for_training,
+        codebooks.device.type,
+        num_codebooks,
+        codebook_size,
+        out_group_size,
+        in_group_size,
+    ) == (False, "cuda", 2, 256, 1, 8):
+        from .cuda_kernel import CUDA_FOLDER
+
+        assert (
+            codebooks.dtype == torch.float16
+        ), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
+        return torch.ops.aqlm_cuda_kernel.code2x8_matmat
+    elif (optimize_for_training, codebooks.device.type, out_group_size) == (False, "cuda", 1):
+        from .triton_kernel import triton_matmul
+
+        return triton_matmul
+    elif (optimize_for_training, codebooks.device.type, codebook_size, out_group_size) == (False, "cpu", 256, 1):
+        from .numba_kernel import numba_gemm_lut
+
+        return numba_gemm_lut
+    else:
+        from .dequantization import dequantize_gemm
+
+        return dequantize_gemm
+
+
+def get_backward_pass_kernel(
+    codebooks: torch.Tensor,
+    optimize_for_training: bool,
+) -> torch.Tensor:
+    forward_pass_kernel = get_forward_pass_kernel(
+        codebooks=codebooks.transpose(2, 3), optimize_for_training=optimize_for_training
+    )
+
+    def _backward_pass_kernel(
+        grad_output: torch.Tensor,  # [..., in_features]
+        codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
+        codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
+        scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        return forward_pass_kernel(
+            grad_output.contiguous(),
+            codes.transpose(0, 1).contiguous(),
+            codebooks.transpose(2, 3).contiguous(),
+            scales.transpose(0, 1).transpose(2, 3).contiguous(),
+            None,
+        )

+    return _backward_pass_kernel
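
Since a linear layer computes y = x @ W.T + b, the gradient with respect to the input is grad_output @ W, i.e. the same kind of matmul against the quantized weight with its input and output roles swapped. `get_backward_pass_kernel` therefore reuses `get_forward_pass_kernel` on codebooks transposed in their group dimensions and feeds it transposed codes and scales. The `QuantizedLinear`/`autograd` wiring itself lives in `aqlm/inference.py` and is not shown in this excerpt; the sketch below only illustrates how such a kernel pair is typically hooked into `torch.autograd.Function` (class name and signatures are assumptions, not this PR's code):

```python
# Illustrative sketch only: wiring a forward/backward kernel pair (as returned by
# get_forward_pass_kernel / get_backward_pass_kernel) into autograd. Names and
# signatures are assumptions for illustration, not the aqlm.inference implementation.
import torch


class _QuantizedMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, codes, codebooks, scales, bias, forward_kernel, backward_kernel):
        ctx.save_for_backward(codes, codebooks, scales)
        ctx.backward_kernel = backward_kernel
        return forward_kernel(input, codes, codebooks, scales, bias)

    @staticmethod
    def backward(ctx, grad_output):
        codes, codebooks, scales = ctx.saved_tensors
        # Gradient w.r.t. the activations is another quantized matmul, now against the
        # transposed weight; the frozen quantized weight itself gets no gradient.
        grad_input = ctx.backward_kernel(grad_output, codes, codebooks, scales, None)
        grad_bias = grad_output.flatten(0, -2).sum(0) if ctx.needs_input_grad[4] else None
        # One return value per forward() argument.
        return grad_input, None, None, None, grad_bias, None, None
```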