Skip to content

Commit

Permalink
Add tests to verify linearize_cache_indices. (pytorch#924)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#924

Before optimizing linearize_cache_indices, first create a set of unit tests to make it easier to verify the changes.

Reviewed By: jianyuh

Differential Revision: D34123789

fbshipit-source-id: 60717d854e7d807aa6bd70a4fea929f2f5672e1e
  • Loading branch information
jasonjk-park authored and facebook-github-bot committed Feb 10, 2022
1 parent a358709 commit e385d02
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3529,6 +3529,78 @@ def test_pickle(self) -> None:
pickled = pickle.dumps(tensor_queue)
unpickled = pickle.loads(pickled)

@unittest.skipIf(*gpu_unavailable)
def test_linearize_cache_indices(self) -> None:
    """Verify linearize_cache_indices against hand-computed expectations.

    Covers four table layouts: equal-sized tables, partially cached
    tables, varying per-sample pooling factors, and multiple features
    sharing one table.
    """
    # Raw per-sample indices shared by every scenario below.
    indices = torch.tensor(
        [10, 2, 3, 7, 1, 4, 5, 9, 2, 7, 6, 8, 5, 1, 0, 4],
        dtype=torch.int,
        device="cuda",
    )
    equal_offsets = torch.tensor(
        [0, 4, 8, 12, 16], dtype=torch.int, device="cuda"
    )
    varying_offsets = torch.tensor(
        [0, 1, 3, 6, 8, 10, 14, 15, 16], dtype=torch.int, device="cuda"
    )

    # Each scenario: (cache_hash_size_cumsum, offsets, expected output).
    scenarios = [
        # Equal sized tables.
        (
            [0, 12, 24, 36, 48],
            equal_offsets,
            [10, 2, 3, 7, 13, 16, 17, 21, 26, 31, 30, 32, 41, 37, 36, 40],
        ),
        # Partially cached tables (a -1 cumsum entry appears where a
        # table is not cached — NOTE(review): inferred from the original
        # comment; confirm against the op's implementation).
        (
            [0, 12, -1, 24, 36],
            equal_offsets,
            [10, 2, 3, 7, 13, 16, 17, 21, 36, 36, 36, 36, 29, 25, 24, 28],
        ),
        # Batched with varying pooling factor.
        (
            [0, 12, -1, 24, 36],
            varying_offsets,
            [10, 2, 3, 19, 13, 16, 17, 21, 36, 36, 36, 36, 36, 36, 24, 28],
        ),
        # Multiple features sharing the same table.
        (
            [0, 0, 12, 12, 24],
            varying_offsets,
            [10, 2, 3, 7, 1, 4, 5, 9, 14, 19, 18, 20, 17, 13, 12, 16],
        ),
    ]

    for cumsum, offsets, expected in scenarios:
        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
            torch.tensor(cumsum).cuda(), indices, offsets
        )
        self.assertTrue(
            torch.equal(
                linear_cache_indices.cpu(),
                torch.tensor(expected, dtype=torch.int),
            )
        )


# Allow running this test module directly (e.g. `python <file>.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
unittest.main()

0 comments on commit e385d02

Please sign in to comment.