Refactor linearize_cache_indices test (pytorch#2389)
Summary:
Pull Request resolved: pytorch#2389

As title

Reviewed By: q10

Differential Revision: D54528393

fbshipit-source-id: 8017f52c71bd6430f643c2af20d6d054c7125630
sryap authored and facebook-github-bot committed Mar 14, 2024
1 parent 4608c6e commit 234d13d
Showing 1 changed file with 52 additions and 89 deletions.
141 changes: 52 additions & 89 deletions fbgemm_gpu/test/tbe/cache/linearize_cache_indices_test.py
@@ -29,6 +29,37 @@

@optests.generate_opcheck_tests(fast=True)
class LinearizeCacheIndicesTest(unittest.TestCase):
    def execute_linearize_cache_indices_ref(
        self,
        hash_size_cumsum: torch.Tensor,
        indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        T = hash_size_cumsum.numel() - 1
        B = (offsets.numel() - 1) // T
        # Move offsets to CPU
        offsets_ = offsets.cpu().tolist()
        # Sentinel value for uncached tables and pruned indices
        max_offset = hash_size_cumsum[-1].to(indices.dtype)
        # Output buffer, initialized from the input indices
        linear_cache_indices = indices.detach().clone()
        for t in range(T):
            hash_size_offset = hash_size_cumsum[t]
            # Get slicing indices
            indices_start = offsets_[t * B]
            indices_end = offsets_[(t + 1) * B]
            if hash_size_offset >= 0:
                # Add the hash size offset if the table is cached
                linear_cache_indices[indices_start:indices_end] += hash_size_offset
            else:
                # Map indices of uncached tables to max_offset
                linear_cache_indices[indices_start:indices_end] = max_offset
        # Overwrite pruned (negative) indices with max_offset
        pruned_pos = (indices < 0).nonzero(as_tuple=True)
        if pruned_pos[0].numel() > 0:
            linear_cache_indices[pruned_pos] = max_offset
        return linear_cache_indices
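
For orientation, here is a minimal sketch of the linearization the reference above describes. The table sizes, indices, and offsets below are hypothetical (not the test's data), and it assumes the open-source fbgemm_gpu package is installed and a CUDA device is available:

import torch
import fbgemm_gpu  # noqa: F401  (loads the torch.ops.fbgemm operators)

# Two tables of 12 rows each -> cumulative hash sizes [0, 12, 24]; batch size B = 2.
hash_size_cumsum = torch.tensor([0, 12, 24]).cuda()
indices = torch.tensor([3, 7, -1, 5, 2, 9], dtype=torch.int, device="cuda")
offsets = torch.tensor([0, 2, 3, 4, 6], dtype=torch.int, device="cuda")

out = torch.ops.fbgemm.linearize_cache_indices(hash_size_cumsum, indices, offsets)
# Per the reference semantics: table 0's slice keeps its values, table 1's slice
# is shifted by 12, and the pruned index (-1) maps to the sentinel max_offset = 24,
# so the expected result is [3, 7, 24, 17, 14, 21].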

    @unittest.skipIf(*gpu_unavailable)
    def test_linearize_cache_indices(self) -> None:
        indices = torch.tensor(
@@ -46,95 +77,27 @@ def test_linearize_cache_indices(self) -> None:
            [0, 1, 3, 6, 8, 10, 14, 15, 16], dtype=torch.int, device="cuda"
        )

        # Testing equal sized tables.
        cache_hash_size_cumsum_0 = torch.tensor([0, 12, 24, 36, 48]).cuda()
        linear_cache_indices_0 = torch.ops.fbgemm.linearize_cache_indices(
            cache_hash_size_cumsum_0, indices, equal_offsets
        )
        self.assertTrue(
            torch.equal(
                linear_cache_indices_0.cpu(),
                torch.tensor(
                    [10, 2, 3, 7, 13, 16, 17, 21, 26, 31, 30, 32, 41, 37, 36, 40],
                    dtype=torch.int,
                ),
            )
        )

        # Testing partially cached tables.
        cache_hash_size_cumsum_1 = torch.tensor([0, 12, -1, 24, 36]).cuda()
        linear_cache_indices_1 = torch.ops.fbgemm.linearize_cache_indices(
            cache_hash_size_cumsum_1, indices, equal_offsets
        )
        self.assertTrue(
            torch.equal(
                linear_cache_indices_1.cpu(),
                torch.tensor(
                    [10, 2, 3, 7, 13, 16, 17, 21, 36, 36, 36, 36, 29, 25, 24, 28],
                    dtype=torch.int,
                ),
            )
        )

        # Testing batched with varying pooling factor.
        cache_hash_size_cumsum_2 = torch.tensor([0, 12, -1, 24, 36]).cuda()
        linear_cache_indices_2 = torch.ops.fbgemm.linearize_cache_indices(
            cache_hash_size_cumsum_2, indices, varying_offsets
        )
        self.assertTrue(
            torch.equal(
                linear_cache_indices_2.cpu(),
                torch.tensor(
                    [10, 2, 3, 19, 13, 16, 17, 21, 36, 36, 36, 36, 36, 36, 24, 28],
                    dtype=torch.int,
                ),
            )
        )

        # Testing when multiple features share the same table.
        cache_hash_size_cumsum_3 = torch.tensor([0, 0, 12, 12, 24]).cuda()
        linear_cache_indices_3 = torch.ops.fbgemm.linearize_cache_indices(
            cache_hash_size_cumsum_3, indices, varying_offsets
        )
        self.assertTrue(
            torch.equal(
                linear_cache_indices_3.cpu(),
                torch.tensor(
                    [10, 2, 3, 7, 1, 4, 5, 9, 14, 19, 18, 20, 17, 13, 12, 16],
                    dtype=torch.int,
                ),
            )
        )

        # Testing equal sized tables + pruned indices
        cache_hash_size_cumsum_4 = torch.tensor([0, 12, 24, 36, 48]).cuda()
        linear_cache_indices_4 = torch.ops.fbgemm.linearize_cache_indices(
            cache_hash_size_cumsum_4, pruned_indices, equal_offsets
        )
        self.assertTrue(
            torch.equal(
                linear_cache_indices_4.cpu(),
                torch.tensor(
                    [10, 48, 3, 7, 13, 16, 48, 21, 26, 48, 30, 32, 41, 37, 48, 40],
                    dtype=torch.int,
                ),
            )
        )

        # Testing batched with varying pooling factor + pruned indices
        cache_hash_size_cumsum_5 = torch.tensor([0, 12, -1, 24, 36]).cuda()
        linear_cache_indices_5 = torch.ops.fbgemm.linearize_cache_indices(
            cache_hash_size_cumsum_5, pruned_indices, varying_offsets
        )
        self.assertTrue(
            torch.equal(
                linear_cache_indices_5.cpu(),
                torch.tensor(
                    [10, 36, 3, 19, 13, 16, 36, 21, 36, 36, 36, 36, 36, 36, 36, 28],
                    dtype=torch.int,
                ),
            )
        )
        test_args = [
            # Testing equal sized tables
            ([0, 12, 24, 36, 48], indices, equal_offsets),
            # Testing partially cached tables
            ([0, 12, -1, 24, 36], indices, equal_offsets),
            # Testing batched with varying pooling factor
            ([0, 12, -1, 24, 36], indices, varying_offsets),
            # Testing when multiple features share the same table
            ([0, 0, 12, 12, 24], indices, varying_offsets),
            # Testing equal sized tables + pruned indices
            ([0, 12, 24, 36, 48], pruned_indices, equal_offsets),
            # Testing batched with varying pooling factor + pruned indices
            ([0, 12, -1, 24, 36], pruned_indices, varying_offsets),
        ]

        for hash_size_cumsum_list, indices, offsets in test_args:
            hash_size_cumsum = torch.tensor(hash_size_cumsum_list).cuda()
            args = (hash_size_cumsum, indices, offsets)
            output_test = torch.ops.fbgemm.linearize_cache_indices(*args)
            output_ref = self.execute_linearize_cache_indices_ref(*args)
            self.assertTrue(torch.equal(output_test, output_ref))
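
A nice property of the parametrized form is that a new scenario only requires appending a (hash_size_cumsum, indices, offsets) tuple and letting the CPU reference compute the expected output, rather than hard-coding another expected tensor. A hypothetical extra entry, not part of this commit, could look like:

# Hypothetical additional scenario: alternating cached/uncached tables.
# Tables 0 and 2 are cached (cumulative offsets 0 and 12), tables 1 and 3 are
# uncached (-1), so their indices should all collapse to the sentinel value 24.
test_args.append(([0, -1, 12, -1, 24], indices, equal_offsets))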

    @unittest.skipIf(*gpu_unavailable)
    def test_linearize_cache_indices_from_row_idx(self) -> None: