Skip to content

Commit

Permalink
parallelize cpu emb fwd across B
Browse files Browse the repository at this point in the history
Summary: Parallelize across the batch dimension within each table, instead of the collapsed parallelization over T * B. This improves locality because multiple cores constructively access the same embedding table at the same time.

Reviewed By: jianyuh

Differential Revision: D27567054

fbshipit-source-id: cacb6fce5aadd4d56888b9c84f27e65d518b50e7
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Apr 6, 2021
1 parent 9b623bb commit 1dfb0c3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
eps,
// fbgemm follows caffe2 convention of negative learning rate
-learning_rate);
TORCH_CHECK(success); // TODO more friendly error msg
// TODO: more friendly error msg.
// See report_error_ in embedding_forward_split_cpu.cpp
TORCH_CHECK(success);
}
}); // parallel_for
return;
Expand Down
37 changes: 16 additions & 21 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,19 @@ void split_embedding_forward_cpu_kernel(
constexpr bool use_fbgemm = std::is_same<weights_t, float>::value ||
std::is_same<weights_t, at::Half>::value;

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
int t_end = (tb_end + B - 1) / B;
for (int t = t_begin; t < t_end; ++t) {
const auto D_begin = D_offsets_data[t];
const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
const auto table_begin = weights_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

int b_begin = (t == t_begin) ? tb_begin % B : 0;
int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
for (int t = 0; t < T; ++t) {
const auto D_begin = D_offsets_data[t];
const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
const auto table_begin = weights_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
bool success = true;
if (use_fbgemm) {
using fbgemm_weight_t =
Expand Down Expand Up @@ -201,8 +196,8 @@ void split_embedding_forward_cpu_kernel(
report_error_(
t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
} // !success
} // for each t
}); // parallel for
}); // parallel for
} // for each t
}

Tensor split_embedding_codegen_forward_cpu(
Expand Down Expand Up @@ -273,8 +268,8 @@ void split_embedding_grad_indice_weights_cpu_kernel(
const auto offsets_data = offsets.accessor<int64_t, 1>();
const auto indices_data = indices.accessor<int64_t, 1>();

auto weights_data = weights.accessor<weights_t, 1>();
auto grad_output_data = grad_output.accessor<grad_t, 2>();
const auto weights_data = weights.accessor<weights_t, 1>();
const auto grad_output_data = grad_output.accessor<grad_t, 2>();
auto grad_indice_weights_data = grad_indice_weights.accessor<grad_t, 1>();
for (int64_t t = 0; t < T; ++t) {
if (feature_requires_grad.defined() &&
Expand Down

0 comments on commit 1dfb0c3

Please sign in to comment.