Skip to content

Commit

Permalink
Add support for int64_t indices and offsets in TBE inference [10/N] (pytorch#3263)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#3263

X-link: facebookresearch/FBGEMM#364

- Add int64_t support for `pruned_hashmap_insert_{{ wdesc }}_cpu` to prevent runtime errors in tests

Reviewed By: spcyppt

Differential Revision: D64705072

fbshipit-source-id: cccc7ea306316e15058f7a31bb481044a4011b00
  • Loading branch information
q10 authored and facebook-github-bot committed Oct 22, 2024
1 parent 2cf3606 commit d32fc6a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,62 +70,66 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
const int32_t B = (offsets.size(0) - 1) / T;
TORCH_CHECK(B > 0);

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
using uidx_t =
std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;

const auto* indices_acc = indices.data_ptr<index_t>();
const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();

auto hash_table_acc = hash_table.accessor<int32_t, 2>();
const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();

for (const auto t : c10::irange(T)) {
const auto table_start = hash_table_offsets_acc[t];
const auto table_end = hash_table_offsets_acc[t + 1];
if (table_start == table_end) {
continue;
}
const auto capacity = table_end - table_start;

for (const auto b : c10::irange(B)) {
const auto indices_start = offsets_acc[t * B + b];
const auto indices_end = offsets_acc[t * B + b + 1];
const auto L = indices_end - indices_start;

for (const auto l : c10::irange(L)) {
const auto idx = indices_acc[indices_start + l];
const auto dense_idx = dense_indices_acc[indices_start + l];
if (dense_idx == -1) {
// -1 means this row has been pruned, do not insert it.
continue;
}
AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu_0", [&] {
using hash_t = index_t;

auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
while (true) {
const auto ht_idx = table_start + static_cast<int64_t>(slot);
const auto slot_sparse_idx = hash_table_acc[ht_idx][0];

// Empty slot
if (slot_sparse_idx == -1) {
hash_table_acc[ht_idx][0] = idx;
hash_table_acc[ht_idx][1] = dense_idx;
break;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu_1", [&] {
using uidx_t =
std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;

const auto* indices_acc = indices.data_ptr<index_t>();
const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();

auto hash_table_acc = hash_table.accessor<hash_t, 2>();
const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();

for (const auto t : c10::irange(T)) {
const auto table_start = hash_table_offsets_acc[t];
const auto table_end = hash_table_offsets_acc[t + 1];
if (table_start == table_end) {
continue;
}
const auto capacity = table_end - table_start;

for (const auto b : c10::irange(B)) {
const auto indices_start = offsets_acc[t * B + b];
const auto indices_end = offsets_acc[t * B + b + 1];
const auto L = indices_end - indices_start;

for (const auto l : c10::irange(L)) {
const auto idx = indices_acc[indices_start + l];
const auto dense_idx = dense_indices_acc[indices_start + l];
if (dense_idx == -1) {
// -1 means this row has been pruned, do not insert it.
continue;
}

// Already exists (shouldn't happen in practice)
if (slot_sparse_idx == idx) {
hash_table_acc[ht_idx][1] = dense_idx;
break;

auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
while (true) {
const auto ht_idx = table_start + static_cast<int64_t>(slot);
const auto slot_sparse_idx = hash_table_acc[ht_idx][0];

// Empty slot
if (slot_sparse_idx == -1) {
hash_table_acc[ht_idx][0] = static_cast<hash_t>(idx);
hash_table_acc[ht_idx][1] = static_cast<hash_t>(dense_idx);
break;
}

// Already exists (shouldn't happen in practice)
if (slot_sparse_idx == idx) {
hash_table_acc[ht_idx][1] = static_cast<hash_t>(dense_idx);
break;
}

// Linear probe
slot = (slot + 1) % capacity;
}

// Linear probe
slot = (slot + 1) % capacity;
}
}
}
}
});
});

return;
Expand Down Expand Up @@ -519,14 +523,14 @@ Tensor pruned_array_lookup_cpu(
auto dense_indices = empty_like(indices);

AT_DISPATCH_INDEX_TYPES(index_remappings.scalar_type(), "pruned_array_lookup_cpu_0", [&] {
using hash_t = index_t;
using remap_t = index_t;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu_1", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();

const auto index_remappings_acc = index_remappings.data_ptr<hash_t>();
const auto index_remappings_acc = index_remappings.data_ptr<remap_t>();
const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();

at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
}
}

template <typename index_t, typename hash_t>
template <typename index_t, typename remap_t>
__global__
__launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
indices,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
offsets,
const pta::PackedTensorAccessor32<hash_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<remap_t, 1, at::RestrictPtrTraits>
index_remappings,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
index_remappings_offsets,
Expand Down Expand Up @@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(
AT_DISPATCH_INDEX_TYPES(
index_remappings.scalar_type(), "pruned_array_lookup_cuda_0", [&] {
using hash_t = index_t;
using remap_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {
Expand All @@ -249,7 +249,7 @@ Tensor pruned_array_lookup_cuda(
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, index_remappings, hash_t, 1, 32),
func_name, index_remappings, remap_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, index_remappings_offsets, int64_t, 1, 32),
B,
Expand Down

0 comments on commit d32fc6a

Please sign in to comment.