Fix a bug in the implementation of dequantization for inference (deepspeedai#3433)

* bugfix in launch_dequantize()

Get rid of `hid_cnt` and simply set #blocks to output size / #groups

* add a unit test for dequantization

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
3 people authored Sep 14, 2023
1 parent 0c83436 commit 9bf7778
Showing 3 changed files with 67 additions and 5 deletions.
8 changes: 4 additions & 4 deletions csrc/transformer/inference/csrc/dequantize.cu
@@ -130,16 +130,16 @@ void launch_dequantize(T* output,
 {
     unsigned threads = 1024;
     hidden_dim /= 4;
-    unsigned hid_cnt = threads / hidden_dim;
     unsigned thd_cnt = (hidden_dim - 1) / threads + 1;
-    hid_cnt = hid_cnt > 0 ? hid_cnt : 1;

-    unsigned blocks = (output_size + hid_cnt * groups - 1) / (hid_cnt * groups);
+    assert(output_size % groups == 0);
+    unsigned blocks = output_size / groups;

     dim3 block_dims(threads);
     dim3 grid_dims(groups, blocks);

     dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
-        output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt);
+        output, input, qscale, hidden_dim, hidden_dim, thd_cnt);
 }
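For a concrete sense of how the launch geometry changes, the sketch below replays the old and new block-count formulas in plain Python for one of the shapes exercised by the new unit test (M=768, N=768, groups=48). The helper functions are illustrative only and are not part of the DeepSpeed source; it is assumed, following the ds_dequantize binding added below, that weight.size(0) maps to output_size and weight.size(1) to hidden_dim.

# Illustrative arithmetic only: mirrors the block-count formulas in the diff
# above; this is not an executable CUDA launch.

def old_blocks(output_size, hidden_dim, groups, threads=1024):
    hidden_dim //= 4                          # kernel operates on 4-element vectors
    hid_cnt = max(threads // hidden_dim, 1)
    return (output_size + hid_cnt * groups - 1) // (hid_cnt * groups)

def new_blocks(output_size, groups):
    assert output_size % groups == 0          # precondition added by this commit
    return output_size // groups

M, N, groups = 768, 768, 48                   # shape from the new unit test
print(old_blocks(M, N, groups))               # 4  -> grid_dims(48, 4)
print(new_blocks(M, groups))                  # 16 -> grid_dims(48, 16)

With the old formula, one block per group had to cover hid_cnt rows of the output, which breaks down when that rounding does not match the kernel's indexing; the new grid simply assigns one block per (group, output row) pair.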

#define INSTANTIATE_DEQUANTIZE_NO_MERGE(T) \
26 changes: 25 additions & 1 deletion csrc/transformer/inference/csrc/pt_binding.cpp
@@ -1906,6 +1906,27 @@ void ds_release_workspace() { InferenceContext::Instance().release_workspace();

bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); }

template <typename T>
at::Tensor ds_dequantize(at::Tensor& weight, at::Tensor& qscale, int groups)
{
    auto options = at::TensorOptions()
                       .dtype(torch::kFloat16)
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);

    launch_dequantize((T*)weight16.data_ptr(),
                      (int8_t*)weight.data_ptr(),
                      (float*)qscale.data_ptr(),
                      weight.size(0),
                      weight.size(1),
                      groups,
                      InferenceContext::Instance().GetCurrentStream());

    return weight16;
}
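Functionally, the binding is expected to reproduce the group-wise scaling that the new unit test uses as its reference: the int8 weight is split into `groups` contiguous chunks, each chunk is multiplied by its own fp32 scale, and the result is cast to fp16. A minimal host-side sketch of that reference computation follows; the helper name is my own and mirrors the test below, it is not part of the binding.

import torch

def reference_dequantize(weight_int8, qscale, groups):
    # Group-wise dequantization: numel / groups int8 values share one fp32 scale.
    M, N = weight_int8.shape
    deq = weight_int8.reshape(groups, -1) * qscale.reshape(groups, 1)
    return deq.reshape(M, N).to(torch.float16).contiguous()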

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_context_int8",
@@ -1973,7 +1994,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
               "DeepSpeed residual add with " #_name " (CUDA)");                     \
     m.def("allocate_workspace_" #_name,                                             \
           &allocate_workspace<_dtype>,                                              \
-          "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)")
+          "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)");  \
+    m.def("dequantize_" #_name,                                                     \
+          &ds_dequantize<_dtype>,                                                   \
+          "DeepSpeed dequantize with " #_name " (CUDA)")

DEF_OPS(fp32, float);
DEF_OPS(fp16, __half);
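Once the extension is built, the macro above exposes the op as dequantize_fp32 / dequantize_fp16 on the loaded module, which is how the unit test below reaches it. A minimal, hypothetical manual invocation from Python, assuming a CUDA device and a successfully built InferenceBuilder op:

import torch
from deepspeed.ops.op_builder import InferenceBuilder

module = InferenceBuilder().load()
# Assumes a CUDA device; the test uses get_accelerator() to pick the device.
weight = torch.randint(-128, 127, (768, 768), dtype=torch.int8, device="cuda")
scale = torch.rand(48, 1, device="cuda")                  # one fp32 scale per group
weight_fp16 = module.dequantize_fp16(weight, scale, 48)   # fp16 tensor, shape (768, 768)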
38 changes: 38 additions & 0 deletions tests/unit/compression/test_dequantization.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright (c) 2023, 2023, Oracle and/or its affiliates.

import os
import torch
from unit.common import DistributedTest
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator


class TestDequantization(DistributedTest):

    def init(self):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(get_accelerator().device_name(local_rank))

        self.dequantize_func = InferenceBuilder().load().dequantize_fp16

    def run_dequantize_test(self, M, N, num_groups):
        weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device)
        scale = torch.rand(num_groups, 1).to(device=self.device)

        weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16).contiguous()
        weight_deq_backend = self.dequantize_func(weight, scale, num_groups)

        assert torch.allclose(weight_deq, weight_deq_backend)

    def test_dequantize(self):
        self.init()

        self.run_dequantize_test(14336, 7168, 32)
        self.run_dequantize_test(14336, 1792, 32)
        self.run_dequantize_test(768, 768, 32)
        self.run_dequantize_test(768, 768, 48)
