Skip to content

Commit

Permalink
zipf w/o replacement faster in split_table_batched_embeddings_benchma…
Browse files Browse the repository at this point in the history
…rk.py (pytorch#1170)

Summary:
Pull Request resolved: pytorch#1170

This was too slow, which made benchmarking TBE annoying.

Reviewed By: jianyuh

Differential Revision: D37339105

fbshipit-source-id: 2d1050e5cad2c639f443ccc9305942cca3ae44a9
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Jun 23, 2022
1 parent 243010f commit 8e9b829
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 13 deletions.
16 changes: 3 additions & 13 deletions fbgemm_gpu/bench/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,9 @@ def generate_requests(
# oversample and then remove duplicates to obtain sampling without
# replacement
all_indices = (np.random.zipf(a=alpha, size=(iters, T, B, 3 * L)) - 1) % E
for index_tuple in itertools.product(range(iters), range(T), range(B)):
# sample without replacement from
# https://stats.stackexchange.com/questions/20590/how-do-i-sample-without-replacement-using-a-sampling-with-replacement-function
r = set()
for x in all_indices[index_tuple]:
if x not in r:
r.add(x)
if len(r) == L:
break
assert (len(r)) == L, "too skewed distribution (alpha too big)"
all_indices[index_tuple][:L] = list(r)
# shuffle indices so we don't have unintended spatial locality
all_indices = torch.as_tensor(all_indices[:, :, :, :L])
all_indices = torch.ops.fbgemm.bottom_unique_k_per_row(
torch.as_tensor(all_indices), L
)
rng = default_rng()
permutation = torch.as_tensor(
rng.choice(E, size=all_indices.max().item() + 1, replace=False)
Expand Down
39 changes: 39 additions & 0 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2325,6 +2325,39 @@ Tensor index_select_dim0(
c10::optional<int64_t> /*consecutive_range_length*/) {
return at::index_select(input, 0, indices);
}

// Along dim=-1, collect k unique elements per row (scanning left to right and
// stopping once k distinct values are found), emitted in ascending order.
// Errors out if a row has fewer than k unique elements. One-off helper for
// zipf sampling without replacement in bench_utils.py.
Tensor bottom_unique_k_per_row(const Tensor& input, const int64_t k) {
  auto num_cols = input.size(-1);
  // Flatten all leading dims so each row can be processed independently.
  Tensor input_reshaped = input.reshape({-1, num_cols});
  auto input_accessor = input_reshaped.accessor<int64_t, 2>();

  // Create output tensor: one row of k unique values per input row.
  // Use int64_t to avoid truncating large row counts.
  const int64_t num_rows = input_reshaped.size(0);
  Tensor output = at::empty({num_rows, k}, input.options());
  auto output_accessor = output.accessor<int64_t, 2>();

  for (const auto i : c10::irange(num_rows)) {
    // std::set both deduplicates and sorts, so the output row is ascending.
    std::set<int64_t> s;
    for (const auto j : c10::irange(num_cols)) {
      s.insert(input_accessor[i][j]);
      if (s.size() == static_cast<size_t>(k)) {
        // Early exit: k distinct values collected for this row.
        break;
      }
    }
    TORCH_CHECK(
        s.size() == static_cast<size_t>(k),
        "too skewed distribution (alpha too big)");
    int64_t j = 0;
    for (const int64_t x : s) {
      output_accessor[i][j] = x;
      ++j;
    }
  }

  // Restore the original leading shape; only the last dim shrinks to k.
  auto output_shape = input.sizes().vec();
  output_shape[output_shape.size() - 1] = k;
  return output.reshape(output_shape);
}
} // namespace

} // namespace fbgemm_gpu
Expand Down Expand Up @@ -2393,6 +2426,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"index_select_dim0(Tensor input, Tensor indices, int? consecutive_range_start=0, int? consecutive_range_length=0) -> Tensor");
m.def(
"jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]");
// This is a one-off op to be used in bench_utils.py for zipf generation w/o
// replacement. Along dim=-1, find the smallest k unique elements. If the
// number of unique elements is less than k, it errors out.
m.def("bottom_unique_k_per_row(Tensor input, int k) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
Expand Down Expand Up @@ -2454,4 +2491,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
fbgemm_gpu::permute_sequence_embeddings_cpu);
DISPATCH_TO_CPU("pack_segments", fbgemm_gpu::pack_segments_cpu);
DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
DISPATCH_TO_CPU(
"bottom_unique_k_per_row", fbgemm_gpu::bottom_unique_k_per_row);
}
33 changes: 33 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import random
import unittest
from itertools import accumulate
Expand Down Expand Up @@ -1513,6 +1514,38 @@ def test_index_select_dim0(

torch.autograd.gradcheck(torch.ops.fbgemm.index_select_dim0, gradcheck_args)

# pyre-ignore [56]
@given(
    T=st.integers(1, 5),
    B=st.integers(1, 5),
    L=st.integers(1, 5),
)
@settings(max_examples=20, deadline=None)
def test_bottom_unique_k_per_row(
    self,
    T: int,
    B: int,
    L: int,
) -> None:
    """Compare fbgemm.bottom_unique_k_per_row against a pure-Python reference."""
    E = 1000000
    # Oversampled zipf draws (3 * L per row) folded into [0, E).
    all_indices = (np.random.zipf(a=1.15, size=(T, B, 3 * L)) - 1) % E
    all_indices_deduped = torch.ops.fbgemm.bottom_unique_k_per_row(
        torch.as_tensor(all_indices), L
    )
    for index_tuple in itertools.product(range(T), range(B)):
        # sample without replacement from
        # https://stats.stackexchange.com/questions/20590/how-do-i-sample-without-replacement-using-a-sampling-with-replacement-function
        seen = set()
        for candidate in all_indices[index_tuple]:
            seen.add(candidate)
            if len(seen) == L:
                break
        assert len(seen) == L, "too skewed distribution (alpha too big)"
        # The op emits each row's unique values in ascending order.
        all_indices[index_tuple][:L] = sorted(seen)
    all_indices_deduped_ref = torch.as_tensor(all_indices[:, :, :L])
    torch.testing.assert_close(all_indices_deduped, all_indices_deduped_ref)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8e9b829

Please sign in to comment.