Add dtype <-> SparseType conversion util function (pytorch#1057)
Summary:
Pull Request resolved: pytorch#1057

As title

Reviewed By: geyyer

Differential Revision: D35532366

fbshipit-source-id: 73891dd0eadcb0c79d6d0a06d7e0da911bd2519a
jianyuh authored and facebook-github-bot committed Apr 26, 2022
1 parent aa1eefd commit 9948726
Showing 3 changed files with 43 additions and 7 deletions.
29 changes: 29 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -8,6 +8,8 @@
 import enum
 from typing import Dict
 
+import torch
+
 
 @enum.unique
 class EmbOptimType(enum.Enum):
@@ -71,6 +73,33 @@ def as_int(self) -> int:
             SparseType.BF16.value: 5,
         }[self.value]
 
+    @staticmethod
+    def from_dtype(dtype: torch.dtype) -> "SparseType":
+        if dtype == torch.float32:
+            return SparseType("fp32")
+        elif dtype == torch.float16:
+            return SparseType("fp16")
+        elif dtype == torch.int8 or dtype == torch.uint8:
+            return SparseType("int8")
+        elif dtype == torch.quint4x2:
+            return SparseType("int4")
+        elif dtype == torch.quint2x4:
+            return SparseType("int2")
+        elif dtype == torch.bfloat16:
+            return SparseType("bf16")
+        else:
+            raise ValueError(f"Unsupported sparse dtype: {dtype}")
+
+    def as_dtype(self) -> torch.dtype:
+        return {
+            SparseType.FP32.value: torch.float32,
+            SparseType.FP16.value: torch.float16,
+            SparseType.INT8.value: torch.uint8,
+            SparseType.INT4.value: torch.quint4x2,
+            SparseType.INT2.value: torch.quint2x4,
+            SparseType.BF16.value: torch.bfloat16,
+        }[self.value]
+
     def bit_rate(self) -> int:
         return {
             SparseType.FP32.value: 32,
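A minimal usage sketch (not part of the diff) of the new conversion helpers, assuming fbgemm_gpu is installed and that SparseType is importable from fbgemm_gpu.split_embedding_configs, the module changed above:

    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType

    # dtype -> SparseType
    assert SparseType.from_dtype(torch.float16) == SparseType.FP16
    assert SparseType.from_dtype(torch.bfloat16) == SparseType.BF16

    # SparseType -> dtype
    assert SparseType.FP16.as_dtype() == torch.float16
    assert SparseType.INT8.as_dtype() == torch.uint8

    # Unsupported dtypes raise ValueError
    try:
        SparseType.from_dtype(torch.float64)
    except ValueError as e:
        print(e)  # Unsupported sparse dtype: torch.float64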
10 changes: 3 additions & 7 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -352,17 +352,13 @@ def __init__( # noqa C901
             cacheable=True,
             precision=weights_precision,
         )
-        table_embedding_dtype = torch.float32
-        if weights_precision == SparseType.FP16:
-            table_embedding_dtype = torch.float16
-        elif weights_precision == SparseType.INT8:
-            table_embedding_dtype = torch.uint8
+        table_embedding_dtype = weights_precision.as_dtype()
 
         self._apply_split(
             weight_split,
             prefix="weights",
-            # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param but
-            # got `Type[typing.Union[torch.float16, torch.float32, torch.uint8]]`.
+            # pyre-fixme[6]: For 3rd param expected `Type[Type[_dtype]]` but got
+            # `Type[_dtype]`.
             dtype=table_embedding_dtype,
             enforce_hbm=enforce_hbm,
         )
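For reference, a small check (a sketch, not part of the diff) that weights_precision.as_dtype() reproduces what the removed if/elif chain computed for the precisions it covered, while also handling the remaining SparseType values:

    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType

    # The old chain only distinguished FP32 / FP16 / INT8; as_dtype() agrees
    # with it on those and additionally covers INT4, INT2, and BF16.
    for weights_precision, expected in [
        (SparseType.FP32, torch.float32),
        (SparseType.FP16, torch.float16),
        (SparseType.INT8, torch.uint8),
    ]:
        assert weights_precision.as_dtype() == expected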
11 changes: 11 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/embedding_common.h
@@ -48,6 +48,10 @@ at::ScalarType getScalarType(SparseType dtype) {
       return at::kByte;
     case SparseType::BF16:
       return at::kBFloat16;
+    case SparseType::INT4:
+      return at::kQUInt4x2;
+    case SparseType::INT2:
+      return at::kQUInt2x4;
     default:
       return at::ScalarType::Undefined;
   }
@@ -60,9 +64,16 @@ SparseType getSparseType(at::ScalarType dtype) {
     case at::kHalf:
       return SparseType::FP16;
     case at::kByte:
+    case at::kChar:
+    case at::kQUInt8:
+    case at::kQInt8:
       return SparseType::INT8;
     case at::kBFloat16:
       return SparseType::BF16;
+    case at::kQUInt4x2:
+      return SparseType::INT4;
+    case at::kQUInt2x4:
+      return SparseType::INT2;
     default:
       return SparseType::INVALID;
   }
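The C++ helpers above mirror the Python mapping, and like it they are not strict inverses for 8-bit types: several scalar types collapse onto SparseType::INT8, while the reverse direction always returns kByte. A sketch of the same behavior on the Python side (assuming the import path below):

    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType

    # Both signed and unsigned 8-bit dtypes map to INT8 ...
    assert SparseType.from_dtype(torch.int8) == SparseType.INT8
    assert SparseType.from_dtype(torch.uint8) == SparseType.INT8

    # ... but the reverse mapping returns a single canonical dtype (uint8 here,
    # kByte on the C++ side), so the round trip from torch.int8 is not exact.
    assert SparseType.INT8.as_dtype() == torch.uint8

    # The sub-byte quantized dtypes round-trip exactly.
    assert SparseType.from_dtype(torch.quint4x2).as_dtype() == torch.quint4x2
    assert SparseType.from_dtype(torch.quint2x4).as_dtype() == torch.quint2x4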
