Skip to content

Commit

Permalink
Remove unused pyre-ignore in TBE tests (pytorch#2162)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2162

As title

Reviewed By: q10

Differential Revision: D51600941

fbshipit-source-id: 514d6b52cbf2d25c7ed1a87d12390cdd49e5374e
  • Loading branch information
sryap authored and facebook-github-bot committed Nov 28, 2023
1 parent b40f419 commit 934d881
Showing 1 changed file with 2 additions and 27 deletions.
29 changes: 2 additions & 27 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4370,14 +4370,9 @@ def execute_nbit_forward_( # noqa C901
# Initialize and insert Array index remapping based data structure
index_remappings_array = []
for t in range(T):
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device)
dense_indice_t = (
(dense_indices.view(T, B, L))[t].view(-1)
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
.to(current_device)
(dense_indices.view(T, B, L))[t].view(-1).to(current_device)
)
index_remappings_array_t = torch.tensor(
[-1] * original_E,
Expand Down Expand Up @@ -5143,19 +5138,13 @@ def test_pruning(
index_remappings_array_offsets = torch.empty(
T + 1,
dtype=torch.int64,
# pyre-fixme[6]: For 3rd param expected `Union[None, str, device]` but
# got `Union[int, str]`.
device=current_device,
)
index_remappings_array_offsets[0] = 0
for t in range(T):
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, str]`.
indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device)
dense_indice_t = (
(dense_indices.view(T, B, L))[t].view(-1)
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
.to(current_device)
(dense_indices.view(T, B, L))[t].view(-1).to(current_device)
)
selected_indices = torch.add(indice_t, t * original_E)[:E]
index_remappings_array[selected_indices] = dense_indice_t
Expand All @@ -5174,26 +5163,12 @@ def test_pruning(
index_remappings_array,
index_remappings_array_offsets,
) = (
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
indices.to(current_device),
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
dense_indices.to(current_device),
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
offsets.to(current_device),
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
hash_table.to(current_device),
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
hash_table_offsets.to(current_device),
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
index_remappings_array.to(current_device),
# pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int,
# str]`.
index_remappings_array_offsets.to(current_device),
)

Expand Down

0 comments on commit 934d881

Please sign in to comment.