Partially back out "[fbgemm_gpu] Add support for int64_t indices and offsets in TBE inference [7C/N]" (pytorch#3257)

Summary:
Pull Request resolved: pytorch#3257
X-link: facebookresearch/FBGEMM#358

Original commit changeset: 270834722e8b
Original Phabricator Diff: D63778645

The original diff D63778645 contained two features:

(a) extending the remap index kernels from supporting int32-only to supporting both int32 and int64
(b) updating the frontend code to construct int64 remap index arrays

Some downstream code has already picked up D63778645 and generated models with int64 remapping indices, which fail three downstream unit tests.

Since (b) is the feature that breaks those downstream tests, only (b) is reverted in this diff; (a) is kept, as it is now needed for those unit tests to pass.
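
To make the (a)/(b) split concrete, here is a minimal sketch, not the FBGEMM source (remap_indices_cpu, remap_indices_kernel_cpu, and the naive linear-scan lookup are illustrative stand-ins): the kernel stays templated on the index type and is dispatched from the indices dtype (feature (a), kept), while the hash table it reads is pinned back to int32 (feature (b), reverted).

// Hypothetical sketch of the state after this backout -- not the FBGEMM source.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// (a) kept: the kernel is templated on index_t, so it compiles for both
// int32 and int64 indices/offsets tensors.
template <typename index_t>
void remap_indices_kernel_cpu(
    const at::Tensor& indices,
    at::Tensor& dense_indices,
    const at::Tensor& hash_table) {
  const auto* indices_acc = indices.data_ptr<index_t>();
  auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
  // (b) reverted: the frontend builds the hash table as int32 again,
  // so the accessor is fixed to int32_t regardless of index_t.
  const auto hash_table_acc = hash_table.accessor<int32_t, 2>();
  const auto n_rows = hash_table_acc.size(0);
  for (int64_t i = 0; i < indices.numel(); ++i) {
    // Naive linear scan standing in for the real hash probe: rows are
    // (key, remapped_index) pairs, and -1 marks an empty slot.
    index_t remapped = -1;
    for (int64_t r = 0; r < n_rows; ++r) {
      if (hash_table_acc[r][0] == static_cast<int32_t>(indices_acc[i])) {
        remapped = static_cast<index_t>(hash_table_acc[r][1]);
        break;
      }
    }
    dense_indices_acc[i] = remapped;
  }
}

// Entry point: AT_DISPATCH_INDEX_TYPES (a standard ATen macro) instantiates
// the lambda for int32 or int64 based on the dtype of `indices`, binding the
// selected type to index_t.
void remap_indices_cpu(
    const at::Tensor& indices,
    at::Tensor& dense_indices,
    const at::Tensor& hash_table) {
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "remap_indices_cpu", [&] {
    remap_indices_kernel_cpu<index_t>(indices, dense_indices, hash_table);
  });
}

This mirrors the hunks below: the index_t-templated pointers stay, while the hash-table accessor goes back to int32_t.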

Reviewed By: jianyuh

Differential Revision: D64618221

fbshipit-source-id: 6c4838dbfaf301f1204d469d7f4bf7cfe4926b2e
q10 authored and facebook-github-bot committed Oct 19, 2024
1 parent 4ba523c commit f728c94
Showing 3 changed files with 9 additions and 10 deletions.
@@ -78,7 +78,7 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
   const auto* offsets_acc = offsets.data_ptr<index_t>();
 
-  auto hash_table_acc = hash_table.accessor<int64_t, 2>();
+  auto hash_table_acc = hash_table.accessor<int32_t, 2>();
   const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
 
   for (const auto t : c10::irange(T)) {

@@ -397,22 +397,21 @@ def max_ty_D(ty: SparseType) -> int:
         self.assign_embedding_weights(weight_lists)
 
         # Handle index remapping for embedding pruning.
-        # All buffers are int64 in order to support both int32 and int64 indices.
         self.register_buffer(
             "index_remappings_array_offsets",
             torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "index_remappings_array",
-            torch.empty(0, device=self.current_device, dtype=torch.int64),
+            torch.empty(0, device=self.current_device, dtype=torch.int32),
         )
         self.register_buffer(
             "index_remapping_hash_table_offsets",
             torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "index_remapping_hash_table",
-            torch.empty(0, device=self.current_device, dtype=torch.int64),
+            torch.empty(0, device=self.current_device, dtype=torch.int32),
         )
         self.register_buffer(
             "original_rows_per_table",
@@ -1529,11 +1528,11 @@ def set_index_remappings_array(
             index_remappings_filter_nones.append(mapping)
         if len(index_remappings_filter_nones) == 0:
             self.index_remappings_array = torch.empty(
-                0, dtype=torch.int64, device=self.current_device
+                0, dtype=torch.int32, device=self.current_device
             )
         else:
             self.index_remappings_array = torch.cat(index_remappings_filter_nones).to(
-                dtype=torch.int64, device=self.current_device
+                self.current_device
             )
 
     def set_index_remappings(
@@ -1556,7 +1555,7 @@
             ]
             hash_table = torch.empty(
                 (sum(capacities), 2),
-                dtype=torch.int64,
+                dtype=torch.int32,
             )
             hash_table[:, :] = -1
             hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()

6 changes: 3 additions & 3 deletions fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py
@@ -469,7 +469,7 @@ def test_pruning(
         # Initialize and insert Hashmap index remapping based data structure
         hash_table = torch.empty(
             (sum(capacities), 2),
-            dtype=torch.int64,
+            dtype=torch.int32,
         )
         hash_table[:, :] = -1
         hash_table_offsets = torch.tensor([0] + np.cumsum(capacities).tolist()).long()
@@ -486,7 +486,7 @@
         # Initialize and insert Array index remapping based data structure
         index_remappings_array = torch.tensor(
             [-1] * original_E * T,
-            dtype=torch.int64,
+            dtype=torch.int32,
             device=current_device,
         )
         index_remappings_array_offsets = torch.empty(
@@ -498,7 +498,7 @@
         for t in range(T):
             indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device)
             dense_indice_t = (
-                (dense_indices.view(T, B, L))[t].long().view(-1).to(current_device)
+                (dense_indices.view(T, B, L))[t].view(-1).to(current_device)
             )
             selected_indices = torch.add(indice_t, t * original_E)[:E]
             index_remappings_array[selected_indices] = dense_indice_t
