From 672a66508ca04f7c3802583448b45cc10ae022f1 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Sun, 18 Feb 2024 12:52:28 +0100
Subject: [PATCH] Optimize for training wrapper

---
 inference_lib/src/aqlm/__init__.py            |  1 +
 inference_lib/src/aqlm/inference.py           |  6 ++--
 .../src/aqlm/inference_kernels/__init__.py    |  2 +-
 .../aqlm/inference_kernels/kernel_selector.py | 29 ++++++++++++++-----
 4 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/inference_lib/src/aqlm/__init__.py b/inference_lib/src/aqlm/__init__.py
index cc7b33b6..e4f222b4 100644
--- a/inference_lib/src/aqlm/__init__.py
+++ b/inference_lib/src/aqlm/__init__.py
@@ -1,2 +1,3 @@
 import aqlm.inference_kernels
 from aqlm.inference import QuantizedLinear
+from aqlm.inference_kernels import optimize_for_training
diff --git a/inference_lib/src/aqlm/inference.py b/inference_lib/src/aqlm/inference.py
index 6ac4cd7f..202e349c 100644
--- a/inference_lib/src/aqlm/inference.py
+++ b/inference_lib/src/aqlm/inference.py
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 
+import aqlm
 from aqlm.inference_kernels import get_backward_pass_kernel, get_forward_pass_kernel
 from aqlm.utils import get_int_dtype
 
@@ -61,6 +62,7 @@ def __init__(
             self.register_parameter("bias", None)
 
         # MATMUL_OP
+        self.optimize_for_training: bool = aqlm.inference_kernels.kernel_selector._OPTIMIZE_FOR_TRAINING
         self.matmul_op = None
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -77,8 +79,8 @@ def prepare_matmul_op(self, input: torch.Tensor):
         ):
             self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous()  # TODO: fix this thing
 
-        forward_pass_kernel = get_forward_pass_kernel(self.codebooks)
-        backward_pass_kernel = get_backward_pass_kernel(self.codebooks)
+        forward_pass_kernel = get_forward_pass_kernel(self.codebooks, self.optimize_for_training)
+        backward_pass_kernel = get_backward_pass_kernel(self.codebooks, self.optimize_for_training)
 
         class _QuantizedMatmul(torch.autograd.Function):
             @staticmethod
diff --git a/inference_lib/src/aqlm/inference_kernels/__init__.py b/inference_lib/src/aqlm/inference_kernels/__init__.py
index 76e3a64e..348bed9c 100644
--- a/inference_lib/src/aqlm/inference_kernels/__init__.py
+++ b/inference_lib/src/aqlm/inference_kernels/__init__.py
@@ -1 +1 @@
-from .kernel_selector import get_backward_pass_kernel, get_forward_pass_kernel
+from .kernel_selector import get_backward_pass_kernel, get_forward_pass_kernel, optimize_for_training
diff --git a/inference_lib/src/aqlm/inference_kernels/kernel_selector.py b/inference_lib/src/aqlm/inference_kernels/kernel_selector.py
index e726c07e..2994aa3a 100644
--- a/inference_lib/src/aqlm/inference_kernels/kernel_selector.py
+++ b/inference_lib/src/aqlm/inference_kernels/kernel_selector.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import Callable, Optional
 
 import torch
@@ -6,35 +7,48 @@
 
 from aqlm.utils import _dequantize_weight, unpack_int_data
 
+_OPTIMIZE_FOR_TRAINING = False
+
+
+@contextmanager
+def optimize_for_training():
+    global _OPTIMIZE_FOR_TRAINING
+    _OPTIMIZE_FOR_TRAINING = True
+    try:
+        yield
+    finally:
+        _OPTIMIZE_FOR_TRAINING = False
+
 
 def get_forward_pass_kernel(
     codebooks: torch.Tensor,
+    optimize_for_training: bool,
 ) -> Callable[[torch.Tensor, torch.IntTensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
     num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    match (codebooks.device.type, num_codebooks, codebook_size, out_group_size, in_group_size):
-        case ("cuda", 1, 65536, 1, 8):
+    match (optimize_for_training, codebooks.device.type, num_codebooks, codebook_size, out_group_size, in_group_size):
+        case (False, "cuda", 1, 65536, 1, 8):
             from .cuda_kernel import CUDA_FOLDER
 
             assert (
                 codebooks.dtype == torch.float16
             ), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
             return torch.ops.aqlm_cuda_kernel.code1x16_matmat
-        case ("cuda", 2, 256, 1, 8):
+        case (False, "cuda", 2, 256, 1, 8):
             from .cuda_kernel import CUDA_FOLDER
 
             assert (
                 codebooks.dtype == torch.float16
             ), f"please load the model with `torch_dtype=torch.float16`, as {codebooks.dtype} is not supported on GPU yet"
             return torch.ops.aqlm_cuda_kernel.code2x8_matmat
-        case ("cuda", _, _, 1, _):
+        case (False, "cuda", _, _, 1, _):
            from .triton_kernel import triton_matmul
 
             return triton_matmul
-        case ("cpu", _, 256, 1, _):
+        case (False, "cpu", _, 256, 1, _):
             from .numba_kernel import numba_gemm_lut
 
             return numba_gemm_lut
-        case _:
+        case (True, *_):
             from .dequantization import dequantize_gemm
 
             return dequantize_gemm
@@ -42,8 +56,9 @@ def get_forward_pass_kernel(
 
 def get_backward_pass_kernel(
     codebooks: torch.Tensor,
+    optimize_for_training: bool,
 ) -> torch.Tensor:
-    forward_pass_kernel = get_forward_pass_kernel(codebooks=codebooks)
+    forward_pass_kernel = get_forward_pass_kernel(codebooks=codebooks, optimize_for_training=optimize_for_training)
 
     def _backward_pass_kernel(
         grad_output: torch.Tensor,  # [..., in_features]
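
Usage sketch (illustrative, not part of the diff above): the context manager has to wrap model construction, because each QuantizedLinear snapshots the module-level _OPTIMIZE_FOR_TRAINING flag in its __init__, and prepare_matmul_op later builds _QuantizedMatmul from that stored flag. The checkpoint id and the transformers loading call below are assumptions made for the example, not part of this change.

    import aqlm
    import torch
    from transformers import AutoModelForCausalLM

    # While the context is active, _OPTIMIZE_FOR_TRAINING is True, so every QuantizedLinear
    # constructed here records optimize_for_training=True and will later route both the
    # forward and the backward pass through the dequantize_gemm fallback.
    with aqlm.optimize_for_training():
        model = AutoModelForCausalLM.from_pretrained(
            "path/to/aqlm-quantized-checkpoint",  # placeholder checkpoint id
            torch_dtype=torch.float16,
            trust_remote_code=True,  # assumption: AQLM modeling code ships with the checkpoint
        )

    # On exit the flag is reset to False, so models loaded afterwards keep selecting the
    # fast CUDA / Triton / Numba inference kernels.

Wrapping only the training loop has no effect; layers created outside the context keep optimize_for_training=False.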