Skip to content

Commit

Permalink
Add unit test coverage for array based pruning method for TBE (pytorc…
Browse files Browse the repository at this point in the history
…h#1016)

Summary:
Pull Request resolved: pytorch#1016

We are missing unit test coverage for the array-based pruning method. This diff adds that coverage.

Reviewed By: jasonjk-park, houseroad

Differential Revision: D35170062

fbshipit-source-id: e5fc82e66f1d6ff4312bc9e80e01774c75075fe9
  • Loading branch information
jianyuh authored and facebook-github-bot committed Mar 29, 2022
1 parent bde0c5f commit fe32509
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 38 deletions.
4 changes: 2 additions & 2 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2036,12 +2036,12 @@ def pruned_array( # noqa C901
)
index_remappings_offsets = torch.empty(T + 1, dtype=torch.int32, device="cuda")
index_remappings_offsets[0] = 0
dense_indicies = torch.tensor(range(E), dtype=torch.int32, device="cuda")
dense_indices = torch.tensor(range(E), dtype=torch.int32, device="cuda")
for t in range(T):
selected_indices = torch.add(
torch.randperm(original_E, device="cuda"), t * original_E
)[:E]
index_remappings[selected_indices] = dense_indicies
index_remappings[selected_indices] = dense_indices
index_remappings_offsets[t + 1] = index_remappings_offsets[t] + original_E

requests = generate_requests(
Expand Down
127 changes: 91 additions & 36 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3431,11 +3431,12 @@ def test_nbit_forward_uvm_cache(
torch.testing.assert_close(output, output_ref, equal_nan=True)

@given(
T=st.integers(min_value=1, max_value=10),
B=st.integers(min_value=1, max_value=64),
L=st.integers(min_value=0, max_value=64),
T=st.integers(min_value=1, max_value=5),
B=st.integers(min_value=1, max_value=8),
L=st.integers(min_value=0, max_value=8),
use_cpu=st.booleans() if gpu_available else st.just(True),
use_cpu_hashtable=st.booleans(),
use_array_for_index_remapping=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_pruning(
Expand All @@ -3445,38 +3446,42 @@ def test_pruning(
L: int,
use_cpu: bool,
use_cpu_hashtable: bool,
use_array_for_index_remapping: bool,
) -> None:
E = int(1000)
LOAD_FACTOR = 0.8
pruning_ratio = 0.5

capacities = [int(B * L / LOAD_FACTOR) + 1 for _ in range(T)]
original_E = int(E / (1.0 - pruning_ratio))

# Enforce the size of original_E/B/L to get the unique indices
assume(original_E > B * L)

current_device = "cpu" if use_cpu else torch.cuda.current_device()

if use_cpu_hashtable:
assume(use_cpu)
indices = torch.randint(low=0, high=np.iinfo(np.int32).max - 1, size=(T, B, L))

indices = torch.randint(low=0, high=original_E, size=(T, B, L))
for t in range(T):
while (
torch.unique(
indices[t], return_counts=False, return_inverse=False
).numel()
!= indices[t].numel()
):
indices[t] = torch.randint(
low=0, high=np.iinfo(np.int32).max, size=(B, L)
)
indices[t] = torch.randint(low=0, high=original_E, size=(B, L))

indices = indices.view(-1).int()
dense_indices = (
torch.randint(low=0, high=int(1e5), size=(T, B, L)).view(-1).int()
)
dense_indices = torch.randint(low=0, high=E, size=(T, B, L)).view(-1).int()
offsets = torch.tensor([L * b_t for b_t in range(B * T + 1)]).int()

def next_power_of_2(x: int) -> int:
return 1 if x == 0 else 2 ** (x - 1).bit_length()

LOAD_FACTOR = 0.8
capacities = [int(B * L / LOAD_FACTOR) + 1 for _ in range(T)]

# Initialize and insert Hashmap index remapping based data structure
hash_table = torch.empty(
(sum(capacities), 2),
dtype=torch.int32,
)
# initialize
hash_table[:, :] = -1
hash_table_offsets = torch.tensor([0] + np.cumsum(capacities).tolist()).long()

Expand All @@ -3488,35 +3493,85 @@ def next_power_of_2(x: int) -> int:
ht = torch.classes.fb.PrunedMapCPU()
ht.insert(indices, dense_indices, offsets, T)

if not use_cpu:
(indices, dense_indices, offsets, hash_table, hash_table_offsets) = (
indices.cuda(),
dense_indices.cuda(),
offsets.cuda(),
hash_table.cuda(),
hash_table_offsets.cuda(),
# Initialize and insert Array index remapping based data structure
index_remappings_array = torch.tensor(
[-1] * original_E * T, dtype=torch.int32, device=current_device
)
index_remappings_array_offsets = torch.empty(
T + 1, dtype=torch.int64, device=current_device
)
index_remappings_array_offsets[0] = 0
for t in range(T):
indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device)
dense_indice_t = (
(dense_indices.view(T, B, L))[t].view(-1).to(current_device)
)
if use_cpu_hashtable:
dense_indices_ = ht.lookup(indices, offsets)
else:
dense_indices_ = torch.ops.fbgemm.pruned_hashmap_lookup(
indices, offsets, hash_table, hash_table_offsets
selected_indices = torch.add(indice_t, t * original_E)[:E]
index_remappings_array[selected_indices] = dense_indice_t
index_remappings_array_offsets[t + 1] = (
index_remappings_array_offsets[t] + original_E
)

torch.testing.assert_close(dense_indices, dense_indices_)

# now, use a value that does not exist in the original set of indices
# and so should be pruned out.
indices[:] = np.iinfo(np.int32).max
# Move data when using device
if not use_cpu:
(
indices,
dense_indices,
offsets,
hash_table,
hash_table_offsets,
index_remappings_array,
index_remappings_array_offsets,
) = (
indices.to(current_device),
dense_indices.to(current_device),
offsets.to(current_device),
hash_table.to(current_device),
hash_table_offsets.to(current_device),
index_remappings_array.to(current_device),
index_remappings_array_offsets.to(current_device),
)

# Lookup
if use_cpu_hashtable:
dense_indices_ = ht.lookup(indices, offsets)
else:
elif not use_array_for_index_remapping: # hashmap based pruning
dense_indices_ = torch.ops.fbgemm.pruned_hashmap_lookup(
indices, offsets, hash_table, hash_table_offsets
)
else: # array based pruning
dense_indices_ = torch.ops.fbgemm.pruned_array_lookup(
indices,
offsets,
index_remappings_array,
index_remappings_array_offsets,
)

torch.testing.assert_close(dense_indices.clone().fill_(-1), dense_indices_)
# Validate the lookup result
torch.testing.assert_close(dense_indices, dense_indices_)

# With array-based pruning, arbitrarily large indices would fall outside
# the remapping array, so the pruned-out check below is limited to the
# non-array paths. We rely on the bounds checker to keep indices in range.
if not use_array_for_index_remapping:
# now, use a value that does not exist in the original set of indices
# and so should be pruned out.
indices[:] = np.iinfo(np.int32).max

if use_cpu_hashtable:
dense_indices_ = ht.lookup(indices, offsets)
elif not use_array_for_index_remapping: # hashmap based pruning
dense_indices_ = torch.ops.fbgemm.pruned_hashmap_lookup(
indices, offsets, hash_table, hash_table_offsets
)
else: # array based pruning
dense_indices_ = torch.ops.fbgemm.pruned_array_lookup(
indices,
offsets,
index_remappings_array,
index_remappings_array_offsets,
)
torch.testing.assert_close(dense_indices.clone().fill_(-1), dense_indices_)

@given(
L=st.integers(min_value=0, max_value=16),
Expand Down

0 comments on commit fe32509

Please sign in to comment.