Skip to content

Commit

Permalink
Back out "Optimize lxu_cache_lookup." (pytorch#980)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#980

Original commit changeset: d8aa2d28c327

Original Phabricator Diff: D34491095 (pytorch@f1a891c)

Check https://fb.workplace.com/groups/1069285536500339/permalink/4960821917346662/

Reviewed By: chrisxcai

Differential Revision: D34854587

fbshipit-source-id: 59f86a161aeb7836904601c3c7a02243a4f973e6
  • Loading branch information
jianyuh authored and facebook-github-bot committed Mar 14, 2022
1 parent ce37ff6 commit abc7b74
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions fbgemm_gpu/src/split_table_batched_embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
"lfu_cache_populate_byte(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state) -> ()");
DISPATCH_TO_CUDA("lfu_cache_populate_byte", lfu_cache_populate_byte_cuda);
m.def(
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index) -> Tensor");
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1) -> Tensor");
DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda);
m.def(
"lxu_cache_flush(Tensor(a!) uvm_weights, Tensor cache_hash_size_cumsum, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, int total_D, Tensor(b!) lxu_cache_state, Tensor(c!) lxu_cache_weights, bool stochastic_rounding) -> ()");
Expand Down Expand Up @@ -153,7 +153,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lfu_cache_populate_byte(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state) -> ()");
DISPATCH_TO_CUDA("lfu_cache_populate_byte", lfu_cache_populate_byte_cuda);
m.def(
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index) -> Tensor");
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1) -> Tensor");
DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda);
m.def(
"lxu_cache_flush(Tensor(a!) uvm_weights, Tensor cache_hash_size_cumsum, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, int total_D, Tensor(b!) lxu_cache_state, Tensor(c!) lxu_cache_weights, bool stochastic_rounding) -> ()");
Expand Down

0 comments on commit abc7b74

Please sign in to comment.