int8 table batched embedding bag on cpu (pytorch#603)
Summary:
Pull Request resolved: pytorch#603

The GPU table batched embedding bag already supports int8; this diff brings the CPU implementation to parity. This requires moving a few quantization ops into fbgemm_gpu.

Reviewed By: jianyuh

Differential Revision: D28141945

fbshipit-source-id: 15bad9ebdc3b4bf762d0f543fa9823c8255bdf6c
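
Background for the change: the int8 path stores each embedding row with 8-bit rowwise quantization, mapping the row to uint8 codes plus a per-row float scale and bias. The snippet below is a plain-PyTorch approximation of that scheme for orientation only; the function names are made up here, and the fbgemm ops used later in this diff fuse the scale and bias into the same uint8 row and may round slightly differently.

import torch

# Minimal sketch of 8-bit rowwise quantization (illustrative only).
def rowwise_quantize(weights: torch.Tensor):
    mins = weights.min(dim=1, keepdim=True).values
    maxs = weights.max(dim=1, keepdim=True).values
    scales = (maxs - mins) / 255.0
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)  # constant rows
    q = torch.round((weights - mins) / scales).clamp_(0, 255).to(torch.uint8)
    return q, scales, mins

def rowwise_dequantize(q, scales, mins):
    return q.float() * scales + mins

w = torch.randn(4, 16)
q, s, b = rowwise_quantize(w)
w_hat = rowwise_dequantize(q, s, b)
print((w - w_hat).abs().max())  # error is bounded by roughly scale / 2 per row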
jspark1105 authored and facebook-github-bot committed May 11, 2021
1 parent bcba0cf commit dc66331
Showing 3 changed files with 54 additions and 47 deletions.
59 changes: 21 additions & 38 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
@@ -15,30 +15,6 @@
 
 using namespace at;
 
-namespace internal {
-// A helper trait to handle that fbgemm doesn't support double precision
-template <typename T>
-struct double2float {
-  using type = T;
-};
-
-template <>
-struct double2float<double> {
-  using type = float;
-};
-
-template <typename T>
-struct half2float16 {
-  using type = T;
-};
-
-template <>
-struct half2float16<at::Half> {
-  using type = fbgemm::float16;
-};
-
-} // namespace internal
-
 namespace {
 void report_error_(
     int t,
@@ -108,7 +84,8 @@ void split_embedding_forward_cpu_kernel(
   auto output_stride = output.size(1);
 
   constexpr bool use_fbgemm = std::is_same<weights_t, float>::value ||
-      std::is_same<weights_t, at::Half>::value;
+      std::is_same<weights_t, at::Half>::value ||
+      std::is_same<weights_t, uint8_t>::value;
 
   at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
     for (int t = 0; t < T; ++t) {
@@ -125,8 +102,10 @@
 
       bool success = true;
       if (use_fbgemm) {
-        using fbgemm_weight_t =
-            typename ::internal::half2float16<weights_t>::type;
+        using fbgemm_weight_t = typename std::conditional<
+            std::is_same<weights_t, at::Half>::value,
+            fbgemm::float16,
+            weights_t>::type;
         auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
             fbgemm_weight_t,
             /*IndexType=*/int64_t,
@@ -149,12 +128,10 @@
           indices_data + *offsets_begin_ptr,
           offsets_begin_ptr,
           indice_weights.defined()
-              ? reinterpret_cast<const typename ::internal::double2float<
-                    ind_weights_t>::type*>(
+              ? reinterpret_cast<const float*>(
                     indice_weights_data + *offsets_begin_ptr)
               : nullptr,
-          reinterpret_cast<
-              typename ::internal::double2float<output_t>::type*>(
+          reinterpret_cast<float*>(
               output_data + b_begin * output_stride + D_begin));
       } else {
         output_t output_buf[D];
@@ -220,7 +197,8 @@ Tensor split_embedding_codegen_forward_cpu(
   TORCH_CHECK(B >= 0);
 
   Tensor output;
-  if (weights.scalar_type() == at::kHalf) {
+  if (weights.scalar_type() == at::kHalf ||
+      weights.scalar_type() == ScalarType::Byte) {
     output = empty({B, total_D}, weights.options().dtype(at::kFloat));
   } else {
     output = empty({B, total_D}, weights.options());
@@ -229,12 +207,17 @@
   // It is assumed that the indice_weights will always be float
   TORCH_CHECK(
       !indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      weights.scalar_type(), "split_embedding_cpu_forward", [&]() {
-        split_embedding_forward_cpu_kernel<
-            scalar_t,
-            acc_type<scalar_t, true>,
-            acc_type<scalar_t, true>>(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::Half,
+      ScalarType::Byte,
+      weights.scalar_type(),
+      "split_embedding_cpu_forward",
+      [&]() {
+        using output_t = std::conditional<
+            std::is_same<scalar_t, double>::value,
+            double,
+            float>::type;
+        split_embedding_forward_cpu_kernel<scalar_t, output_t, output_t>(
             weights,
             weights_offsets,
             D_offsets,
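In short, the C++ changes above route uint8 weights through the fbgemm path and keep emitting a floating-point result: half and int8 (Byte) weight tables produce a float32 output tensor, while float and double tables keep their own dtype. A small Python restatement of that rule, purely for illustration (the function name is invented here):

import torch

def expected_cpu_forward_output_dtype(weights_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the dispatch in split_embedding_codegen_forward_cpu above:
    # Half and Byte (int8) weights produce float32 output; float/double pass through.
    if weights_dtype in (torch.half, torch.uint8):
        return torch.float32
    return weights_dtype

assert expected_cpu_forward_output_dtype(torch.uint8) == torch.float32
assert expected_cpu_forward_output_dtype(torch.float64) == torch.float64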
22 changes: 15 additions & 7 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -870,15 +870,23 @@ def _apply_split(
                 torch.empty(0, device=self.current_device, dtype=dtype),
             )
         if split.host_size > 0:
-            setattr(
-                self,
-                f"{prefix}_host",
-                nn.Parameter(
+            if dtype == torch.uint8:
+                self.register_buffer(
+                    f"{prefix}_host",
                     torch.zeros(
                         split.host_size, device=self.current_device, dtype=dtype
-                    )
-                ),
-            )
+                    ),
+                )
+            else:
+                setattr(
+                    self,
+                    f"{prefix}_host",
+                    nn.Parameter(
+                        torch.zeros(
+                            split.host_size, device=self.current_device, dtype=dtype
+                        )
+                    ),
+                )
         else:
             self.register_buffer(
                 f"{prefix}_host",
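The Python change above registers uint8 host weights as a buffer instead of wrapping them in nn.Parameter: integer tensors cannot require gradients, so nn.Parameter would raise, while a buffer still follows the module across devices and appears in state_dict. A minimal sketch of the same pattern (class and attribute names are hypothetical):

import torch
import torch.nn as nn

class HostWeights(nn.Module):
    def __init__(self, size: int, dtype: torch.dtype) -> None:
        super().__init__()
        if dtype == torch.uint8:
            # uint8 tensors cannot require grad, so nn.Parameter would raise here;
            # a plain buffer is enough for an inference-only int8 table.
            self.register_buffer("weights_host", torch.zeros(size, dtype=dtype))
        else:
            self.weights_host = nn.Parameter(torch.zeros(size, dtype=dtype))

fp32_table = HostWeights(128, torch.float32)  # trainable parameter
int8_table = HostWeights(128, torch.uint8)    # plain buffer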
20 changes: 18 additions & 2 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -23,6 +23,9 @@
 MAX_EXAMPLES = 40
 Deviceable = TypeVar("Deviceable", torch.nn.EmbeddingBag, Tensor)
 
+torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+
 
 def div_round_up(a: int, b: int) -> int:
     return int((a + b - 1) // b) * b
@@ -139,7 +142,9 @@ class SplitTableBatchedEmbeddingsTest(unittest.TestCase):
         B=st.integers(min_value=1, max_value=128),
         log_E=st.integers(min_value=3, max_value=5),
         L=st.integers(min_value=0, max_value=20),
-        weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]),
+        weights_precision=st.sampled_from(
+            [SparseType.INT8, SparseType.FP16, SparseType.FP32]
+        ),
         weighted=st.booleans(),
         mixed=st.booleans(),
         use_cache=st.booleans(),
@@ -250,6 +255,13 @@ def test_forward(
             to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
             for (E, D) in zip(Es, Ds)
         ]
+        if weights_precision == SparseType.INT8:
+            for t in range(T):
+                bs[t].weight.data.copy_(
+                    torch.ops.fb.Fused8BitRowwiseQuantizedToFloat(
+                        torch.ops.fb.FloatToFused8BitRowwiseQuantized(bs[t].weight.data)
+                    )
+                )
 
         if weights_precision == SparseType.FP16 and not use_cpu:
             # NOTE: CPU version of torch.nn.EmbeddingBag doesn't support fp16.
@@ -306,7 +318,11 @@ def test_forward(
             cc = torch.jit.script(cc)
 
         for t in range(T):
-            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)
+            cc.split_embedding_weights()[t].data.copy_(
+                bs[t].weight
+                if weights_precision != SparseType.INT8
+                else torch.ops.fb.FloatToFused8BitRowwiseQuantized(bs[t].weight)
+            )
 
         x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
         xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0)
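The test changes above bake the int8 quantization error into the reference torch.nn.EmbeddingBag weights (quantize, then dequantize) and copy the fused uint8 rows into the table-batched module, so both sides hold the same values. A condensed sketch of that preparation step; it assumes the sparse-ops libraries loaded at the top of the test are available, and the (E, D + 8) fused layout with a trailing per-row scale and bias noted in the comments is my reading of the op, not something stated in this diff:

import torch

def make_int8_reference_and_fused(weight: torch.Tensor):
    # weight: (E, D) float32 embedding table.
    fused = torch.ops.fb.FloatToFused8BitRowwiseQuantized(weight)
    # fused: uint8, assumed shape (E, D + 8) with per-row fp32 scale and bias appended.
    reference = torch.ops.fb.Fused8BitRowwiseQuantizedToFloat(fused)
    # reference: (E, D) float32 with the quantization error already applied;
    # this is what the plain torch.nn.EmbeddingBag reference receives.
    return reference, fused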
