use fbgemm cpu embedding spmdm jit kernel (pytorch#550)
Summary:
Pull Request resolved: pytorch#550

Use the FBGEMM JIT'ed kernel except when the weights are in double precision.

Reviewed By: jiyuanzFB

Differential Revision: D27039292

fbshipit-source-id: 9fe4ec69ab6c4690c2a512375665647c44c6357e
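For context, fbgemm's EmbeddingSpMDM kernels compute a pooled embedding lookup: for each output row (bag), the selected embedding rows are summed, optionally scaled by per-index weights, and optionally divided by the bag length for MEAN pooling. A scalar sketch of that semantics, equivalent to the slow path kept in the diff below (illustrative only; the function and its names are ours, not fbgemm's):

#include <algorithm>
#include <cstdint>

// Reference semantics of the JIT'ed EmbeddingSpMDM kernel (scalar sketch,
// not the actual implementation). weights is a row-major [num_rows, D]
// table; offsets has B + 1 entries delimiting each bag's slice of indices.
void embedding_spmdm_ref(
    int64_t B,
    int64_t D,
    const float* weights,
    const int64_t* indices,
    const int64_t* offsets,
    const float* per_index_weights, // may be nullptr
    bool normalize_by_lengths, // true for MEAN pooling
    float* out) {
  for (int64_t b = 0; b < B; ++b) {
    float* out_row = out + b * D;
    std::fill(out_row, out_row + D, 0.0f);
    const int64_t L = offsets[b + 1] - offsets[b];
    for (int64_t p = offsets[b]; p < offsets[b + 1]; ++p) {
      const float w = per_index_weights ? per_index_weights[p] : 1.0f;
      const float* row = weights + indices[p] * D;
      for (int64_t d = 0; d < D; ++d) {
        out_row[d] += w * row[d];
      }
    }
    if (normalize_by_lengths && L > 0) {
      for (int64_t d = 0; d < D; ++d) {
        out_row[d] /= L;
      }
    }
  }
}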
jspark1105 authored and facebook-github-bot committed Mar 17, 2021
1 parent 4ab1eeb commit a06cb38
Showing 1 changed file with 100 additions and 27 deletions.

fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
@@ -5,11 +5,37 @@
  * LICENSE file in the root directory of this source tree.
  */
 #include "codegen/embedding_forward_split_cpu.h"
+#include "fbgemm/FbgemmEmbedding.h"
+#include "fbgemm/Types.h"
 
 #include <ATen/AccumulateType.h>
 
 using namespace at;
 
+namespace internal {
+// Helper traits to map types fbgemm doesn't support: double -> float,
+// at::Half -> fbgemm::float16.
+template <typename T>
+struct double2float {
+  using type = T;
+};
+
+template <>
+struct double2float<double> {
+  using type = float;
+};
+
+template <typename T>
+struct half2float16 {
+  using type = T;
+};
+
+template <>
+struct half2float16<at::Half> {
+  using type = fbgemm::float16;
+};
+
+} // namespace internal
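// [Illustration, not part of the commit: these traits resolve to
//    ::internal::double2float<double>::type   -> float
//    ::internal::double2float<float>::type    -> float (identity)
//    ::internal::half2float16<at::Half>::type -> fbgemm::float16
//  which is what lets the kernel call below reinterpret_cast its weight and
//  output pointers to types fbgemm accepts.]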

 template <typename weights_t, typename ind_weights_t, typename output_t>
 void split_embedding_forward_cpu_kernel(
     Tensor weights,
@@ -27,19 +53,30 @@ void split_embedding_forward_cpu_kernel(
   int64_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
 
+  offsets = offsets.contiguous();
+  indices = indices.contiguous();
+  weights = weights.contiguous();
+  if (indice_weights.defined()) {
+    indice_weights = indice_weights.contiguous();
+  }
+
   const auto D_offsets_data = D_offsets.accessor<int, 1>();
   const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
-  const auto offsets_data = offsets.accessor<int64_t, 1>();
-  const auto indices_data = indices.accessor<int64_t, 1>();
+  const auto offsets_data = offsets.data_ptr<int64_t>();
+  const auto indices_data = indices.data_ptr<int64_t>();
 
-  const auto weights_data = weights.accessor<weights_t, 1>();
+  const auto weights_data = weights.data_ptr<weights_t>();
   // If indice_weights is not defined, this pointer is null and won't be used.
-  const auto indice_weights_data = indice_weights.defined()
-      ? indice_weights.accessor<ind_weights_t, 1>()
-      : TensorAccessor<ind_weights_t, 1>(nullptr, nullptr, nullptr);
+  const auto indice_weights_data = indice_weights.defined()
+      ? indice_weights.data_ptr<ind_weights_t>()
+      : nullptr;
 
-  auto output_data = output.accessor<output_t, 2>();
+  auto output_data = output.data_ptr<output_t>();
+  auto output_stride = output.size(1);
+
+  constexpr bool use_fbgemm = std::is_same<weights_t, float>::value ||
+      std::is_same<weights_t, at::Half>::value;
 
   at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
     int t_begin = tb_begin / B;
@@ -51,28 +88,64 @@ void split_embedding_forward_cpu_kernel(
 
       int b_begin = (t == t_begin) ? tb_begin % B : 0;
       int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
-      for (int b = b_begin; b < b_end; ++b) {
-        const auto pool_begin = offsets_data[t * B + b];
-        const auto pool_end = offsets_data[t * B + b + 1];
-        const auto L = pool_end - pool_begin;
-        const double scale_factor =
-            // NOTE: MEAN pooling will not work with indice_weights!
-            (pooling_mode == MEAN && !indice_weights.defined() && L > 0)
-            ? 1.0 / L
-            : 1.0;
-        for (auto p = pool_begin; p < pool_end; ++p) {
-          const int64_t embedding_begin = table_begin + indices_data[p] * D;
-          for (int64_t d = 0; d < D; ++d) {
-            output_data[b][D_begin + d] += scale_factor *
-                (indice_weights.defined()
-                     ? static_cast<output_t>(
-                           weights_data[embedding_begin + d]) *
-                         static_cast<output_t>(indice_weights_data[p])
-                     : static_cast<output_t>(
-                           weights_data[embedding_begin + d]));
-          }
-        }
-      } // for each b
+
+      if (use_fbgemm) {
+        using fbgemm_weight_t =
+            typename ::internal::half2float16<weights_t>::type;
+        auto kernel = fbgemm::GenerateEmbeddingSpMDMWithOutputStride<
+            fbgemm_weight_t,
+            /*IndexType=*/int64_t,
+            /*OffsetType=*/int64_t>(
+            D,
+            indice_weights.defined(),
+            pooling_mode == MEAN,
+            /*prefetch=*/16,
+            /*is_weight_positional=*/false,
+            /*use_offsets=*/true,
+            output_stride);
+        auto offsets_begin_ptr = offsets_data + t * B + b_begin;
+        kernel(
+            b_end - b_begin,
+            offsets_data[t * B + b_end] - *offsets_begin_ptr,
+            // TODO: this elides array out-of-bound checking.
+            // Should pass hash_size_cumsum to do this.
+            /*data_size=*/std::numeric_limits<int64_t>::max(),
+            reinterpret_cast<const fbgemm_weight_t*>(
+                weights_data + table_begin),
+            indices_data + *offsets_begin_ptr,
+            offsets_begin_ptr,
+            indice_weights.defined()
+                ? reinterpret_cast<const typename ::internal::double2float<
+                      ind_weights_t>::type*>(
+                      indice_weights_data + *offsets_begin_ptr)
+                : nullptr,
+            reinterpret_cast<
+                typename ::internal::double2float<output_t>::type*>(
+                output_data + b_begin * output_stride + D_begin));
+      } else {
+        for (int b = b_begin; b < b_end; ++b) {
+          const auto pool_begin = offsets_data[t * B + b];
+          const auto pool_end = offsets_data[t * B + b + 1];
+          const auto L = pool_end - pool_begin;
+          const double scale_factor =
+              // NOTE: MEAN pooling will not work with indice_weights!
+              (pooling_mode == MEAN && !indice_weights.defined() && L > 0)
+              ? 1.0 / L
+              : 1.0;
+          for (auto p = pool_begin; p < pool_end; ++p) {
+            const int64_t embedding_begin = table_begin + indices_data[p] * D;
+            for (int64_t d = 0; d < D; ++d) {
+              output_data[b * output_stride + D_begin + d] += scale_factor *
+                  (indice_weights.defined()
+                       ? static_cast<output_t>(
+                             weights_data[embedding_begin + d]) *
+                           static_cast<output_t>(indice_weights_data[p])
+                       : static_cast<output_t>(
+                             weights_data[embedding_begin + d]));
+            }
+          }
+        } // for each b
+      }
     } // for each t
   }); // parallel for
 }
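For reference, the pattern above is generate-once, call-many: GenerateEmbeddingSpMDMWithOutputStride JIT-compiles a kernel for a fixed block size and flag combination and returns a callable that can be invoked repeatedly. A self-contained sketch of that usage (values are invented; the generator and kernel arguments mirror the call in this diff, but exact defaults and overloads should be checked against fbgemm/FbgemmEmbedding.h rather than taken from here):

#include <cstdint>
#include <vector>
#include "fbgemm/FbgemmEmbedding.h"

int main() {
  constexpr int64_t D = 4; // embedding dimension (block size)
  // A tiny embedding table with 3 rows.
  std::vector<float> weights = {
      0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 2.f};
  std::vector<int64_t> indices = {2, 0, 1}; // rows to pool
  std::vector<int64_t> offsets = {0, 2, 3}; // 2 bags: {2, 0} and {1}
  std::vector<float> out(2 * D, 0.f);

  // Generate once; the returned callable can be reused across batches.
  auto kernel = fbgemm::GenerateEmbeddingSpMDMWithOutputStride<
      float, /*IndexType=*/int64_t, /*OffsetType=*/int64_t>(
      D,
      /*has_weight=*/false,
      /*normalize_by_lengths=*/false, // SUM pooling
      /*prefetch=*/16,
      /*is_weight_positional=*/false,
      /*use_offsets=*/true,
      /*output_stride=*/D);

  bool ok = kernel(
      /*output_size=*/2,
      /*index_size=*/3,
      /*data_size=*/3, // number of table rows, for bounds checking
      weights.data(),
      indices.data(),
      offsets.data(),
      /*weights=*/nullptr, // no per-index weights
      out.data());
  // On success, out is {2,2,2,2, 1,1,1,1}.
  return ok ? 0 : 1;
}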
