Fix a bug in the implementation of dequantization for inference (deepspeedai#3433)

* bugfix in launch_dequantize(): get rid of `hid_cnt` and simply set #blocks to output size / #groups
* add a unit test for dequantization

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
1 parent 0c83436 · commit 9bf7778
Showing 3 changed files with 67 additions and 5 deletions.
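The launch_dequantize() change itself lives in the CUDA sources, which this excerpt does not show. Purely as an illustration of the arithmetic the commit message describes, here is a minimal Python sketch of deriving the block count directly from output size and group count; the function name and the assertion are assumptions for illustration, not DeepSpeed's actual code.

# Hypothetical sketch only -- illustrates "#blocks = output size / #groups"
# from the commit message; the name and shape of this helper are assumptions.
def launch_grid_blocks(output_size: int, num_groups: int) -> int:
    # After the fix there is no hid_cnt heuristic: the grid is sized
    # directly from how many output elements each quantization group covers.
    assert output_size % num_groups == 0
    return output_size // num_groups

The unit test added below exercises this code path end to end through dequantize_fp16 for several shapes and group counts.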
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright (c) 2023, Oracle and/or its affiliates.

import os
import torch
from unit.common import DistributedTest
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator


class TestDequantization(DistributedTest):

    def init(self):
        # Bind the test to this rank's accelerator device.
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.device = torch.device(get_accelerator().device_name(local_rank))

        # Build/load the inference op and grab the fp16 dequantization kernel.
        self.dequantize_func = InferenceBuilder().load().dequantize_fp16

    def run_dequantize_test(self, M, N, num_groups):
        weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device)
        scale = torch.rand(num_groups, 1).to(device=self.device)

        # Reference: apply each group's scale on the host, then compare
        # against the backend kernel's output.
        weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16).contiguous()
        weight_deq_backend = self.dequantize_func(weight, scale, num_groups)

        assert torch.allclose(weight_deq, weight_deq_backend)

    def test_dequantize(self):
        self.init()

        self.run_dequantize_test(14336, 7168, 32)
        self.run_dequantize_test(14336, 1792, 32)
        self.run_dequantize_test(768, 768, 32)
        self.run_dequantize_test(768, 768, 48)
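For context, the group-wise scheme the test checks is simple: the int8 weight is split into num_groups equal chunks, each sharing one scale, and dequantization is an elementwise multiply per chunk. A self-contained sketch of the same reference math in plain PyTorch (no DeepSpeed kernel; the tiny shapes are chosen only for illustration):

import torch

# Tiny worked example of group-wise dequantization: a 4x4 int8 weight
# split into 2 groups, each scaled by its own factor.
M, N, num_groups = 4, 4, 2
weight = torch.randint(-128, 128, (M, N), dtype=torch.int8)
scale = torch.rand(num_groups, 1)

# One row per group, broadcast-multiply by that group's scale, then
# restore the original shape -- the same reference path as the test above.
weight_deq = (weight.reshape(num_groups, -1) * scale).reshape(M, N).to(torch.float16)
print(weight_deq.shape)  # torch.Size([4, 4])

Running the test class itself goes through DeepSpeed's DistributedTest harness, so it needs a built inference op and an accelerator; something like pytest -k TestDequantization from the tests directory should pick it up, assuming the usual DeepSpeed unit-test setup.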