Improve lxu_cache_lookup microbenchmark to match realistic scenario. (pytorch#937)

Summary:
Pull Request resolved: pytorch#937

Add prefetching before the cache lookup to match the realistic scenario (this did not change performance much).
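Concretely, the benchmark now drives a real `IntNBitTableBatchedEmbeddingBagsCodegen` module (`tbe` below) and warms its cache before the timed lookup. A condensed sketch of the flow the diff below establishes (`create_request` is a helper defined in this benchmark file):

```python
# Warm the LXU cache once before timing, as a real execution flow would.
indices, offsets = create_request(
    num_tables, num_embeddings, batch, avg_pooling_factor
)
tbe.prefetch(indices, offsets)

# The timed lookup then runs against the module's own warmed cache state,
# rather than a hand-built, all-empty lxu_cache_state tensor.
linearized_indices = torch.ops.fbgemm.linearize_cache_indices(
    tbe.cache_hash_size_cumsum, indices, offsets
)
locations = torch.ops.fbgemm.lxu_cache_lookup(
    linearized_indices, tbe.lxu_cache_state
)
```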

Print the cache miss ratio as part of the benchmark results.
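`lxu_cache_lookup` returns `-1` as the location for any index that misses in the cache, so the miss ratio falls out of one extra (untimed) lookup over the same inputs; this is exactly the computation the diff below adds:

```python
locations = torch.ops.fbgemm.lxu_cache_lookup(linearized_indices, tbe.lxu_cache_state)
num_misses = torch.sum(locations == -1)  # -1 marks a cache miss
miss_ratio = num_misses.item() / locations.numel()
```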

Reviewed By: jspark1105, jianyuh

Differential Revision: D34228009

fbshipit-source-id: 5f6b6c44838412035bc7870b671f967e018c78fb
jasonjk-park authored and facebook-github-bot committed Feb 16, 2022
1 parent f2b8228 commit 5c3c84a
Showing 1 changed file with 23 additions and 12 deletions.
fbgemm_gpu/bench/split_embeddings_cache_benchmark.py

@@ -261,34 +261,45 @@ def lxu_cache_lookup(
     cache_load_factor: float,
 ) -> None:
     num_embeddings: int = 1000000
-    cache_hash_size_cumsum = create_table_offsets(
-        num_tables, cached_tables_ratio, num_embeddings
+    embedding_dims: int = 128
+
+    embedding_specs = create_embedding_specs(
+        num_tables, cached_tables_ratio, num_embeddings, embedding_dims
     )
 
+    tbe: nn.Module = IntNBitTableBatchedEmbeddingBagsCodegen(
+        embedding_specs, cache_load_factor=cache_load_factor
+    )
+    tbe.fill_random_weights()
+
+    # Imitate execution flow by performing prefetching once.
     indices, offsets = create_request(
         num_tables, num_embeddings, batch, avg_pooling_factor
     )
+    tbe.prefetch(indices, offsets)
 
     linearized_indices = torch.ops.fbgemm.linearize_cache_indices(
-        cache_hash_size_cumsum, indices, offsets
+        tbe.cache_hash_size_cumsum, indices, offsets
     )
 
-    lxu_cache_state: Tensor = torch.empty(
-        math.ceil(cache_hash_size_cumsum[-1] * cache_load_factor / ASSOC),
-        ASSOC,
-        device="cuda",
-        dtype=torch.int64,
-    ).fill_(-1)
-
     t_ms = benchmark_same_input(
         iters,
         lambda linearized_indices, lxu_cache_state: torch.ops.fbgemm.lxu_cache_lookup(
             linearized_indices, lxu_cache_state
         ),
         linearized_indices,
-        lxu_cache_state,
+        tbe.lxu_cache_state,
     )
 
+    # Run once again to obtain cache miss ratio.
+    locations = torch.ops.fbgemm.lxu_cache_lookup(
+        linearized_indices, tbe.lxu_cache_state
+    )
+    num_misses = torch.sum(locations == -1)
     logging.info(
-        f"Across {iters} runs, T: {num_tables}, Cached T: {get_num_cached_tables(num_tables, cached_tables_ratio)}, BS: {batch}, cache_load_factor: {cache_load_factor}, {t_ms * 1.0e3:.0f}us"
+        f"Across {iters} runs, T: {num_tables}, Cached T: {get_num_cached_tables(num_tables, cached_tables_ratio)}, "
+        f"BS: {batch}, cache_load_factor: {cache_load_factor}, {t_ms * 1.0e3:.0f}us, "
+        f"cache miss: {num_misses.item() / locations.numel() * 100}%"
     )
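The timing helper `benchmark_same_input` is defined elsewhere in this benchmark file; from the call site above it takes an iteration count, a callable, and the tensors to pass on every call, and returns the time per call in milliseconds (the log line multiplies by 1.0e3 to report microseconds). A hypothetical stand-in, assuming CUDA-event timing, to illustrate the shape of the contract (not the actual implementation):

```python
import torch

def benchmark_same_input_sketch(iters, fn, *args):
    # Hypothetical stand-in for benchmark_same_input; the real helper is
    # defined elsewhere in split_embeddings_cache_benchmark.py and may differ.
    fn(*args)  # warm-up call so one-time costs are excluded
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call
```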

