Optimize for training wrapper
BlackSamorez committed Feb 18, 2024
1 parent 6ba7711 commit 672a665
Showing 4 changed files with 28 additions and 10 deletions.
1 change: 1 addition & 0 deletions inference_lib/src/aqlm/__init__.py
```diff
@@ -1,2 +1,3 @@
 import aqlm.inference_kernels
 from aqlm.inference import QuantizedLinear
+from aqlm.inference_kernels import optimize_for_training
```
6 changes: 4 additions & 2 deletions inference_lib/src/aqlm/inference.py
```diff
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 
+import aqlm
 from aqlm.inference_kernels import get_backward_pass_kernel, get_forward_pass_kernel
 from aqlm.utils import get_int_dtype
 
@@ -61,6 +62,7 @@ def __init__(
         self.register_parameter("bias", None)
 
         # MATMUL_OP
+        self.optimize_for_training: bool = aqlm.inference_kernels.kernel_selector._OPTIMIZE_FOR_TRAINING
         self.matmul_op = None
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -77,8 +79,8 @@ def prepare_matmul_op(self, input: torch.Tensor):
         ):
             self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous()  # TODO: fix this thing
 
-        forward_pass_kernel = get_forward_pass_kernel(self.codebooks)
-        backward_pass_kernel = get_backward_pass_kernel(self.codebooks)
+        forward_pass_kernel = get_forward_pass_kernel(self.codebooks, self.optimize_for_training)
+        backward_pass_kernel = get_backward_pass_kernel(self.codebooks, self.optimize_for_training)
 
         class _QuantizedMatmul(torch.autograd.Function):
             @staticmethod
```
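The body of `_QuantizedMatmul` is collapsed in this view. For orientation only, here is a minimal sketch of how such an `autograd.Function` could wire the two selected kernels together; the argument order follows the `Callable` type hint in `kernel_selector.py`, and everything else is an assumption rather than the hidden code:

```python
# Sketch under assumptions: the real _QuantizedMatmul body is collapsed in
# this diff. The kernel signature (input, codes, codebooks, scales, bias)
# is taken from the Callable type hint in kernel_selector.py.
class _QuantizedMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, codes, codebooks, scales, bias):
        ctx.save_for_backward(codes, codebooks, scales)
        return forward_pass_kernel(input, codes, codebooks, scales, bias)

    @staticmethod
    def backward(ctx, grad_output):
        codes, codebooks, scales = ctx.saved_tensors
        # Only the activation gradient is produced; the quantized weight
        # tensors (codes, codebooks, scales) stay frozen in this sketch.
        grad_input = backward_pass_kernel(grad_output, codes, codebooks, scales, None)
        return grad_input, None, None, None, None
```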
2 changes: 1 addition & 1 deletion inference_lib/src/aqlm/inference_kernels/__init__.py
```diff
@@ -1 +1 @@
-from .kernel_selector import get_backward_pass_kernel, get_forward_pass_kernel
+from .kernel_selector import get_backward_pass_kernel, get_forward_pass_kernel, optimize_for_training
```
29 changes: 22 additions & 7 deletions inference_lib/src/aqlm/inference_kernels/kernel_selector.py
```diff
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import Callable, Optional
 
 import torch
@@ -6,44 +7,58 @@
 
 from aqlm.utils import _dequantize_weight, unpack_int_data
 
+_OPTIMIZE_FOR_TRAINING = False
+
+
+@contextmanager
+def optimize_for_training():
+    global _OPTIMIZE_FOR_TRAINING
+    _OPTIMIZE_FOR_TRAINING = True
+    try:
+        yield
+    finally:
+        _OPTIMIZE_FOR_TRAINING = False
+
 
 def get_forward_pass_kernel(
     codebooks: torch.Tensor,
+    optimize_for_training: bool,
 ) -> Callable[[torch.Tensor, torch.IntTensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
     num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    match (codebooks.device.type, num_codebooks, codebook_size, out_group_size, in_group_size):
-        case ("cuda", 1, 65536, 1, 8):
+    match (optimize_for_training, codebooks.device.type, num_codebooks, codebook_size, out_group_size, in_group_size):
+        case (False, "cuda", 1, 65536, 1, 8):
             from .cuda_kernel import CUDA_FOLDER
 
             assert (
                 codebooks.dtype == torch.float16
             ), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
             return torch.ops.aqlm_cuda_kernel.code1x16_matmat
-        case ("cuda", 2, 256, 1, 8):
+        case (False, "cuda", 2, 256, 1, 8):
             from .cuda_kernel import CUDA_FOLDER
 
             assert (
                 codebooks.dtype == torch.float16
             ), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
             return torch.ops.aqlm_cuda_kernel.code2x8_matmat
-        case ("cuda", _, _, 1, _):
+        case (False, "cuda", _, _, 1, _):
             from .triton_kernel import triton_matmul
 
             return triton_matmul
-        case ("cpu", _, 256, 1, _):
+        case (False, "cpu", _, 256, 1, _):
             from .numba_kernel import numba_gemm_lut
 
             return numba_gemm_lut
-        case _:
+        case (True, *_):
             from .dequantization import dequantize_gemm
 
             return dequantize_gemm
 
 
 def get_backward_pass_kernel(
     codebooks: torch.Tensor,
+    optimize_for_training: bool,
 ) -> torch.Tensor:
-    forward_pass_kernel = get_forward_pass_kernel(codebooks=codebooks)
+    forward_pass_kernel = get_forward_pass_kernel(codebooks=codebooks, optimize_for_training=optimize_for_training)
 
     def _backward_pass_kernel(
         grad_output: torch.Tensor,  # [..., in_features]
```
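Taken together: `optimize_for_training()` temporarily sets a module-level flag, `QuantizedLinear.__init__` snapshots that flag, and the kernel selector's `(True, *_)` case then routes both passes through `dequantize_gemm`, a plain dequantize-then-matmul that autograd can differentiate, instead of the fused inference kernels. A usage sketch, with the loading helper and tensor shape as placeholders rather than anything from this commit:

```python
import torch
from aqlm import optimize_for_training

# The flag is read in QuantizedLinear.__init__, so the model has to be
# *constructed* inside the context; wrapping only the forward call is too late.
with optimize_for_training():
    model = build_quantized_model()  # hypothetical helper that creates QuantizedLinear layers

# Layers built above captured optimize_for_training=True for their lifetime,
# so forward and backward both run through the differentiable dequantize_gemm.
out = model(torch.randn(1, 4096, dtype=torch.float16))  # shape illustrative
out.sum().backward()
```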
