
Commit

Merge pull request Vahe1994#26 from Vahe1994/backward
Integrating with `autograd`, adding static kernel routing.
BlackSamorez authored Feb 20, 2024
2 parents 163c748 + abbb0d9 commit b0683b2
Showing 9 changed files with 244 additions and 129 deletions.
6 changes: 4 additions & 2 deletions inference_lib/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = aqlm
version = 1.0.0
version = 1.0.2
author = AQLM paper authors
author_email = [email protected]
description = Efficiently run models quantized with AQLM
@@ -15,6 +15,8 @@ classifiers =
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Programming Language :: Python :: 3
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering
@@ -29,7 +31,7 @@ package_dir =
= src
packages = find:
include_package_data = True
python_requires = >=3.10
python_requires = >=3.8
install_requires =
torch>=2.1.1
transformers>=4.37.0
1 change: 1 addition & 0 deletions inference_lib/src/aqlm/__init__.py
@@ -1,2 +1,3 @@
import aqlm.inference_kernels
from aqlm.inference import QuantizedLinear
from aqlm.inference_kernels import optimize_for_training
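With this re-export, both the quantized layer and the new kernel-routing context manager are available from the package root; a minimal import check (not part of the diff):

from aqlm import QuantizedLinear, optimize_for_training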
73 changes: 67 additions & 6 deletions inference_lib/src/aqlm/inference.py
@@ -1,7 +1,10 @@
""" Core mathematics for Additive Quantization (AQ): initialization, reconstruction and beam search"""
from typing import Optional

import aqlm
import torch
import torch.nn as nn
from aqlm.inference_kernels import forward_pass_quantized_linear
from aqlm.inference_kernels import get_backward_pass_kernel, get_forward_pass_kernel
from aqlm.utils import get_int_dtype


@@ -35,31 +38,89 @@ def __init__(
# CODES & CODEBOOKS
self.codebooks = nn.Parameter(
torch.empty((num_codebooks, self.codebook_size, out_group_size, in_group_size), **factory_kwargs),
requires_grad=True,
requires_grad=False,
) # [num_codebooks, codebook_size, out_group_size, in_group_size]
self.codes = nn.Parameter(
torch.empty(
(num_out_groups, num_in_groups, num_codebooks), device=device, dtype=get_int_dtype(nbits_per_codebook)
(num_out_groups, num_in_groups, num_codebooks),
device=device,
dtype=get_int_dtype(nbits_per_codebook),
),
requires_grad=False,
) # [num_out_groups, num_in_groups, num_codebooks]

# SCALES
self.scales = nn.Parameter(
torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True
torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=False
) # [num_out_groups, 1, 1, 1]

# BIAS
if bias:
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs), requires_grad=False)
else:
self.register_parameter("bias", None)

# MATMUL_OP
self.optimize_for_training: bool = aqlm.inference_kernels.kernel_selector._OPTIMIZE_FOR_TRAINING
self.matmul_op = None

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.matmul_op is None:
self.prepare_matmul_op(input)

return self.matmul_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)

def prepare_matmul_op(self, input: torch.Tensor):
if (
not input.is_cuda
and self.codebook_size == 256
and self.codes.shape[0] == self.out_features // self.out_group_size
):
self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous() # TODO: fix this thing
return forward_pass_quantized_linear(input, self.codes, self.codebooks, self.scales, self.bias)

forward_pass_kernel = get_forward_pass_kernel(self.codebooks, self.optimize_for_training)
backward_pass_kernel = get_backward_pass_kernel(self.codebooks, self.optimize_for_training)

class _QuantizedMatmul(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.Any,
input: torch.Tensor,
codes: torch.IntTensor,
codebooks: torch.Tensor,
scales: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
ctx.save_for_backward(
input,
codes,
codebooks,
scales,
bias,
)
return forward_pass_kernel(
input,
codes,
codebooks,
scales,
bias,
)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
input, codes, codebooks, scales, bias = ctx.saved_tensors
return (
backward_pass_kernel(
grad_output,
codes,
codebooks,
scales,
bias,
),
None,
None,
None,
None,
)

self.matmul_op = _QuantizedMatmul
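The `_QuantizedMatmul` class above binds the statically selected forward and backward kernels into a `torch.autograd.Function`, so gradients flow to the activations while the frozen quantized parameters receive `None`. A self-contained sketch of the same pattern with dense stand-in kernels (an illustration, not part of the diff):

import torch

def make_matmul_op(weight: torch.Tensor):
    # 'weight' stands in for the dequantized [out_features, in_features] matrix
    def forward_kernel(x: torch.Tensor) -> torch.Tensor:  # analogous to forward_pass_kernel
        return x @ weight.t()

    def backward_kernel(grad_out: torch.Tensor) -> torch.Tensor:  # analogous to backward_pass_kernel
        return grad_out @ weight  # a "forward pass" against the transposed weight

    class _Matmul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return forward_kernel(x)

        @staticmethod
        def backward(ctx, grad_output):
            return backward_kernel(grad_output)

    return _Matmul

weight = torch.randn(32, 64)
op = make_matmul_op(weight)  # bound once, like prepare_matmul_op
x = torch.randn(5, 64, requires_grad=True)
op.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.ones(5, 32) @ weight)  # the gradient reaches the input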
2 changes: 1 addition & 1 deletion inference_lib/src/aqlm/inference_kernels/__init__.py
@@ -1 +1 @@
from .kernel_selector import forward_pass_quantized_linear
from .kernel_selector import get_backward_pass_kernel, get_forward_pass_kernel, optimize_for_training
31 changes: 22 additions & 9 deletions inference_lib/src/aqlm/inference_kernels/cuda_kernel.cpp
@@ -1,5 +1,6 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>

void code1x16_matvec_cuda(
const void* A,
@@ -25,6 +26,7 @@ void code1x16_matvec(
torch::Tensor& C,
const torch::Tensor& codebook
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code1x16_matvec_cuda(
@@ -41,7 +43,8 @@ torch::Tensor code1x16_matmat(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales
const torch::Tensor& scales,
const std::optional<torch::Tensor>& bias
) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
@@ -63,10 +66,14 @@ torch::Tensor code1x16_matmat(
);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}

auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.view(output_sizes);
auto output = flat_output.reshape(output_sizes).clone();
return output;
}

@@ -76,6 +83,7 @@ void code2x8_matvec(
torch::Tensor& C,
const torch::Tensor& codebook
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code2x8_matvec_cuda(
@@ -92,7 +100,8 @@ torch::Tensor code2x8_matmat(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales
const torch::Tensor& scales,
const std::optional<torch::Tensor>& bias
) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
@@ -114,16 +123,20 @@ torch::Tensor code2x8_matmat(
);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}

auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.view(output_sizes);
auto output = flat_output.reshape(output_sizes).clone();
return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("code1x16_matvec", &code1x16_matvec, "1x16 (2bit) codebook matrix-vector product.");
m.def("code1x16_matmat", &code1x16_matmat, "1x16 (2bit) codebook matrix-matrix product.");
m.def("code2x8_matvec", &code2x8_matvec, "2x8 (2bit) codebook matrix-vector product.");
m.def("code2x8_matmat", &code2x8_matmat, "2x8 (2bit) codebook matrix-matrix product.");
TORCH_LIBRARY(aqlm_cuda_kernel, m) {
m.def("code1x16_matvec", code1x16_matvec);
m.def("code1x16_matmat", code1x16_matmat);
m.def("code2x8_matvec", code2x8_matvec);
m.def("code2x8_matmat", code2x8_matmat);
}
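Two changes here: the matmat entry points now accept an optional bias and add it on the device, and the kernels are registered through `TORCH_LIBRARY` (reachable as `torch.ops.aqlm_cuda_kernel.*`) instead of a pybind11 Python module. The switch from `.view(...)` to `.reshape(...).clone()` presumably guards against outputs whose strides no longer admit a view; a small Python aside illustrating that distinction:

import torch

t = torch.arange(12).reshape(3, 4).t()  # the transpose makes this tensor non-contiguous
try:
    t.view(-1)                          # .view() requires compatible strides and raises here
except RuntimeError as err:
    print("view failed:", err)
flat = t.reshape(-1)                    # .reshape() falls back to copying when a view is impossible
print(flat.is_contiguous())             # True; .clone() additionally guarantees an independent copy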
3 changes: 2 additions & 1 deletion inference_lib/src/aqlm/inference_kernels/cuda_kernel.py
@@ -5,7 +5,8 @@
from torch.utils.cpp_extension import load

CUDA_FOLDER = os.path.dirname(os.path.abspath(__file__))
CUDA_KERNEL = load(
load(
name="codebook_cuda",
sources=[os.path.join(CUDA_FOLDER, "cuda_kernel.cpp"), os.path.join(CUDA_FOLDER, "cuda_kernel.cu")],
is_python_module=False,
)
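Because the ops are now registered via `TORCH_LIBRARY`, `load(...)` no longer needs to hand back a Python module; importing this file simply builds and loads the shared library as a side effect. A sketch of the resulting call path (assuming a CUDA toolchain is available so the extension actually compiles):

import torch
import aqlm.inference_kernels.cuda_kernel  # side effect: compiles and loads the extension

matmat = torch.ops.aqlm_cuda_kernel.code1x16_matmat  # resolved from the TORCH_LIBRARY registration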
21 changes: 21 additions & 0 deletions inference_lib/src/aqlm/inference_kernels/dequantization.py
@@ -0,0 +1,21 @@
from typing import Optional

import torch
import torch.nn.functional as F
from aqlm.utils import _dequantize_weight, unpack_int_data
from torch import nn


def dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
bias: Optional[torch.Tensor],
) -> torch.Tensor:
dequantized_weight = _dequantize_weight(
unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
codebooks,
scales,
)
return F.linear(input, dequantized_weight, bias)
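This is the generic fallback: unpack the codes, rebuild the dense weight from codebooks and scales, and run an ordinary `F.linear`. A rough, self-contained reconstruction of the weight-rebuilding step for the `out_group_size == 1` case (shapes follow the comments above; the real `_dequantize_weight` also handles packed codes and general group shapes, so treat this as an approximation):

import torch
import torch.nn.functional as F

num_out_groups, num_in_groups, num_codebooks = 16, 8, 2
codebook_size, out_group_size, in_group_size = 256, 1, 8

codes = torch.randint(codebook_size, (num_out_groups, num_in_groups, num_codebooks))
codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
scales = torch.randn(num_out_groups, 1, 1, 1)

# look up each group's codebook entries and sum them across codebooks (additive quantization)
selected = torch.stack(
    [codebooks[c, codes[..., c]] for c in range(num_codebooks)]
).sum(0)                                        # [num_out_groups, num_in_groups, out_group_size, in_group_size]
weight = (selected * scales).permute(0, 2, 1, 3).reshape(
    num_out_groups * out_group_size, num_in_groups * in_group_size
)                                               # [out_features, in_features]

x = torch.randn(3, num_in_groups * in_group_size)
y = F.linear(x, weight)                         # the same call the fallback makes
print(y.shape)                                  # torch.Size([3, 16])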
122 changes: 85 additions & 37 deletions inference_lib/src/aqlm/inference_kernels/kernel_selector.py
@@ -1,46 +1,94 @@
from typing import Optional
from contextlib import contextmanager
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from aqlm.utils import _dequantize_weight, unpack_int_data

_OPTIMIZE_FOR_TRAINING = False

def forward_pass_quantized_linear(
input: torch.Tensor,
codes: torch.IntTensor,

@contextmanager
def optimize_for_training():
"""Use this context manager during model initialization (e.g. `.from_pretrained(...)`) to select inference kernels optimized for larger batch sizes"""
global _OPTIMIZE_FOR_TRAINING
_OPTIMIZE_FOR_TRAINING = True
try:
yield
finally:
_OPTIMIZE_FOR_TRAINING = False


def get_forward_pass_kernel(
codebooks: torch.Tensor,
scales: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
optimize_for_training: bool,
) -> Callable[[torch.Tensor, torch.IntTensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
match (input.is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
case (True, 1, 65536, 1, 8):
from .cuda_kernel import CUDA_KERNEL

assert (
input.dtype == torch.float16
), f"please load the model with `torch_dtype=torch.float16`, as {input.dtype} is not supported on GPU yet"
return CUDA_KERNEL.code1x16_matmat(input, codes, codebooks, scales) + (bias if bias is not None else 0)
case (True, 2, 256, 1, 8):
from .cuda_kernel import CUDA_KERNEL

assert (
input.dtype == torch.float16
), f"please load the model with `torch_dtype=torch.float16`, as {input.dtype} is not supported on GPU yet"
return CUDA_KERNEL.code2x8_matmat(input, codes, codebooks, scales) + (bias if bias is not None else 0)
case (True, _, _, _, _):
from .triton_kernel import triton_matmul

return triton_matmul(input, codes, codebooks, scales, bias)
case (False, _, 256, 1, _):
from .numba_kernel import numba_gemm_lut

return numba_gemm_lut(input, codes, codebooks, scales, bias)
case _:
dequantized_weight = _dequantize_weight(
unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
codebooks,
scales,
)
return F.linear(input, dequantized_weight, bias)

if (optimize_for_training, codebooks.device.type, num_codebooks, codebook_size, out_group_size, in_group_size) == (
False,
"cuda",
1,
65536,
1,
8,
):
from .cuda_kernel import CUDA_FOLDER

assert (
codebooks.dtype == torch.float16
), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
return torch.ops.aqlm_cuda_kernel.code1x16_matmat
elif (
optimize_for_training,
codebooks.device.type,
num_codebooks,
codebook_size,
out_group_size,
in_group_size,
) == (False, "cuda", 2, 256, 1, 8):
from .cuda_kernel import CUDA_FOLDER

assert (
codebooks.dtype == torch.float16
), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
return torch.ops.aqlm_cuda_kernel.code2x8_matmat
elif (optimize_for_training, codebooks.device.type, out_group_size) == (False, "cuda", 1):
from .triton_kernel import triton_matmul

return triton_matmul
elif (optimize_for_training, codebooks.device.type, codebook_size, out_group_size) == (False, "cpu", 256, 1):
from .numba_kernel import numba_gemm_lut

return numba_gemm_lut
else:
from .dequantization import dequantize_gemm

return dequantize_gemm


def get_backward_pass_kernel(
codebooks: torch.Tensor,
optimize_for_training: bool,
) -> torch.Tensor:
forward_pass_kernel = get_forward_pass_kernel(
codebooks=codebooks.transpose(2, 3), optimize_for_training=optimize_for_training
)

def _backward_pass_kernel(
grad_output: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return forward_pass_kernel(
grad_output.contiguous(),
codes.transpose(0, 1).contiguous(),
codebooks.transpose(2, 3).contiguous(),
scales.transpose(0, 1).transpose(2, 3).contiguous(),
None,
)

return _backward_pass_kernel
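The `optimize_for_training` context manager at the top of this file is meant to wrap model construction: every `QuantizedLinear` created inside it captures `_OPTIMIZE_FOR_TRAINING = True` and therefore falls through to the dequantize-and-matmul kernels, which behave better at larger (training-time) batch sizes. A usage sketch (the checkpoint id is a placeholder):

import torch
from aqlm import optimize_for_training
from transformers import AutoModelForCausalLM

with optimize_for_training():
    model = AutoModelForCausalLM.from_pretrained(
        "org/some-aqlm-quantized-model",  # placeholder id for an AQLM-quantized checkpoint
        torch_dtype=torch.float16,
    )
# layers built inside the block route to the training-friendly kernels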