
Commit

two ops
BlackSamorez committed Mar 5, 2024
1 parent a8e936a commit 596580e
Showing 1 changed file with 63 additions and 48 deletions.
111 changes: 63 additions & 48 deletions inference_lib/src/aqlm/inference.py
@@ -60,15 +60,19 @@ def __init__(
         else:
             self.register_parameter("bias", None)

-        # MATMUL_OP
-        self.optimize_for_training: bool = aqlm.inference_kernels.kernel_selector._OPTIMIZE_FOR_TRAINING
-        self.matmul_op = None
+        # MATMUL_OPS
+        self.gemv_op = None
+        self.gemm_op = None
+        self.use_gemv_rule = None

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.matmul_op is None:
+        if self.gemv_op is None:
             self.prepare_matmul_op(input)

-        return self.matmul_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)
+        if self.use_gemv_rule(input):
+            return self.gemv_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)
+        else:
+            return self.gemm_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)

     def prepare_matmul_op(self, input: torch.Tensor):
         if (
@@ -78,49 +82,60 @@ def prepare_matmul_op(self, input: torch.Tensor):
         ):
             self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous()  # TODO: fix this thing

-        forward_pass_kernel = get_forward_pass_kernel(self.codebooks, self.optimize_for_training)
-        backward_pass_kernel = get_backward_pass_kernel(self.codebooks, self.optimize_for_training)
-
-        class _QuantizedMatmul(torch.autograd.Function):
-            @staticmethod
-            def forward(
-                ctx: torch.Any,
-                input: torch.Tensor,
-                codes: torch.IntTensor,
-                codebooks: torch.Tensor,
-                scales: torch.Tensor,
-                bias: Optional[torch.Tensor],
-            ) -> torch.Tensor:
-                ctx.save_for_backward(
-                    input,
-                    codes,
-                    codebooks,
-                    scales,
-                    bias,
-                )
-                return forward_pass_kernel(
-                    input,
-                    codes,
-                    codebooks,
-                    scales,
-                    bias,
-                )
-
-            @staticmethod
-            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
-                input, codes, codebooks, scales, bias = ctx.saved_tensors
-                return (
-                    backward_pass_kernel(
-                        grad_output,
-                        codes,
-                        codebooks,
-                        scales,
-                        bias,
-                    ),
-                    None,
-                    None,
-                    None,
-                    None,
-                )
-
-        self.matmul_op = _QuantizedMatmul
+        self.gemv_op = _get_autograd_matmul_op(
+            get_forward_pass_kernel(self.codebooks, False),
+            get_backward_pass_kernel(self.codebooks, False),
+        )
+
+        self.gemm_op = _get_autograd_matmul_op(
+            get_forward_pass_kernel(self.codebooks, True),
+            get_backward_pass_kernel(self.codebooks, True),
+        )
+
+        self.use_gemv_rule = lambda input: sum(input.shape[:-1]) < 100
+
+
+def _get_autograd_matmul_op(forward_pass_kernel, backward_pass_kernel):
+    class _QuantizedMatmul(torch.autograd.Function):
+        @staticmethod
+        def forward(
+            ctx: torch.Any,
+            input: torch.Tensor,
+            codes: torch.IntTensor,
+            codebooks: torch.Tensor,
+            scales: torch.Tensor,
+            bias: Optional[torch.Tensor],
+        ) -> torch.Tensor:
+            ctx.save_for_backward(
+                input,
+                codes,
+                codebooks,
+                scales,
+                bias,
+            )
+            return forward_pass_kernel(
+                input,
+                codes,
+                codebooks,
+                scales,
+                bias,
+            )
+
+        @staticmethod
+        def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+            input, codes, codebooks, scales, bias = ctx.saved_tensors
+            return (
+                backward_pass_kernel(
+                    grad_output,
+                    codes,
+                    codebooks,
+                    scales,
+                    bias,
+                ),
+                None,
+                None,
+                None,
+                None,
+            )
+
+    return _QuantizedMatmul
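
The dispatch between the two ops hinges on the shape heuristic set in prepare_matmul_op: use_gemv_rule = lambda input: sum(input.shape[:-1]) < 100, so inputs whose leading (non-feature) dimensions sum to fewer than 100 take gemv_op, and everything else takes gemm_op. Judging by the optimize_for_training argument removed above, the False/True flags passed to get_forward_pass_kernel and get_backward_pass_kernel select the inference-oriented and training-oriented kernels for those two paths. A minimal sketch of how the rule routes typical shapes (the hidden size 4096 is an arbitrary placeholder, not taken from the commit):

import torch

# The dispatch rule introduced in this commit, copied from prepare_matmul_op:
use_gemv_rule = lambda input: sum(input.shape[:-1]) < 100

decode_step = torch.empty(1, 1, 4096)  # single-token decoding: 1 + 1 = 2     -> gemv_op
prefill = torch.empty(8, 512, 4096)    # batched prefill:       8 + 512 = 520 -> gemm_op

print(use_gemv_rule(decode_step))  # True
print(use_gemv_rule(prefill))      # False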

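The factored-out _get_autograd_matmul_op wraps a forward/backward kernel pair into a torch.autograd.Function whose backward returns a gradient only for the activation; codes, codebooks, scales and bias all receive None. Below is a minimal, self-contained sketch of that behavior, using dense stand-in kernels and placeholder tensors (the stand-ins are assumptions for illustration, not AQLM code, so the wrapper runs without the real quantized kernels):

import torch


def _get_autograd_matmul_op(forward_pass_kernel, backward_pass_kernel):
    # Condensed restatement of the factory added in this commit.
    class _QuantizedMatmul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, codes, codebooks, scales, bias):
            ctx.save_for_backward(input, codes, codebooks, scales, bias)
            return forward_pass_kernel(input, codes, codebooks, scales, bias)

        @staticmethod
        def backward(ctx, grad_output):
            input, codes, codebooks, scales, bias = ctx.saved_tensors
            # Only the activation receives a gradient; codes, codebooks, scales and bias get None.
            return backward_pass_kernel(grad_output, codes, codebooks, scales, bias), None, None, None, None

    return _QuantizedMatmul


# Dense stand-in "kernels" so the sketch runs without the real AQLM extensions.
weight = torch.randn(16, 8)
forward_kernel = lambda x, codes, codebooks, scales, bias: x @ weight.T
backward_kernel = lambda grad, codes, codebooks, scales, bias: grad @ weight

op = _get_autograd_matmul_op(forward_kernel, backward_kernel)

x = torch.randn(4, 8, requires_grad=True)
placeholder = torch.zeros(1)  # stands in for codes / codebooks / scales
out = op.apply(x, placeholder, placeholder, placeholder, None)
out.sum().backward()
print(x.grad.shape)  # torch.Size([4, 8]) -- gradient flows to the activation only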