merge_pooled_embedding support merging on dim 0 (#939)
Summary: Pull Request resolved: #939

Reviewed By: xing-liu, jspark1105, jasonjk-park

Differential Revision: D34334371

fbshipit-source-id: a0ba36a5ce2b86fe17d8464971ebe80fa5a716c0
YazhiGao authored and facebook-github-bot committed Feb 19, 2022
1 parent 5134475 commit 8e7a826
Showing 3 changed files with 39 additions and 23 deletions.
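
The updated operator keeps its old behavior by default (cat_dim=1, concatenating embedding columns for a fixed batch) and adds cat_dim=0 for stacking along the batch dimension; the second argument is now the size of the dimension that is not concatenated. A minimal usage sketch, assuming a CUDA build of fbgemm_gpu is importable; tensor shapes and the device index are illustrative, not taken from this PR:

import torch
import fbgemm_gpu  # assumed: importing this module registers the torch.ops.fbgemm operators

# Two pooled-embedding tensors that agree on the un-concatenated dimension.
a = torch.randn(4, 8, device="cuda:0")  # 4 samples, 8-dim embedding
b = torch.randn(4, 8, device="cuda:0")

# Default cat_dim=1: stack embedding columns; batch size 4 is the uncat size -> (4, 16).
out_cols = torch.ops.fbgemm.merge_pooled_embeddings([a, b], 4, torch.device("cuda:0"), 1)

# New cat_dim=0: stack along the batch; embedding dim 8 is the uncat size -> (8, 8).
out_rows = torch.ops.fbgemm.merge_pooled_embeddings([a, b], 8, torch.device("cuda:0"), 0)

print(out_cols.shape, out_rows.shape)  # torch.Size([4, 16]) torch.Size([8, 8])
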
7 changes: 4 additions & 3 deletions fbgemm_gpu/src/merge_pooled_embeddings_cpu.cpp
@@ -15,8 +15,9 @@ namespace fbgemm_gpu {
 
 Tensor merge_pooled_embeddings_cpu(
     std::vector<Tensor> pooled_embeddings,
-    int64_t batch_size,
-    at::Device target_device) {
+    int64_t /*uncat_dim_size*/,
+    at::Device target_device,
+    int64_t cat_dim = 1) {
   auto cat_host_0 = [&](const std::vector<Tensor>& ts) {
     int64_t n = 0;
     for (auto& t : ts) {
@@ -29,7 +30,7 @@ Tensor merge_pooled_embeddings_cpu(
       r = at::empty({n}, ts[0].options());
     }
     r.resize_(0);
-    return at::cat_out(r, ts, 1); // concat the tensor list in dim = 1
+    return at::cat_out(r, ts, cat_dim); // concat the tensor list along cat_dim
   };
   return cat_host_0(pooled_embeddings);
 }
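
The CPU path simply forwards the new cat_dim to at::cat_out, so its result matches a plain torch.cat along that dimension; uncat_dim_size is accepted for interface parity but unused (hence the commented-out parameter name). A rough Python equivalent of this CPU behavior, for illustration only:

import torch

def merge_pooled_embeddings_cpu_ref(tensors, uncat_dim_size, cat_dim=1):
    # uncat_dim_size is ignored here, mirroring the /*uncat_dim_size*/ parameter above.
    return torch.cat(tensors, dim=cat_dim)

a, b = torch.randn(4, 8), torch.randn(4, 8)
assert merge_pooled_embeddings_cpu_ref([a, b], 4, cat_dim=1).shape == (4, 16)
assert merge_pooled_embeddings_cpu_ref([a, b], 8, cat_dim=0).shape == (8, 8)
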
43 changes: 27 additions & 16 deletions fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp
@@ -303,27 +303,37 @@ void all_to_one(
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-Tensor cat_dim_1(
-    std::vector<Tensor> tensors,
-    int batch_size,
-    at::Device output_device) {
+Tensor cat_dim_2d(
+    std::vector<Tensor>& tensors,
+    int64_t uncat_dim_size,
+    at::Device output_device,
+    int64_t cat_dim = 1) {
+  // only support 2d tensor concatenation.
+  TORCH_CHECK(cat_dim >= 0 && cat_dim <= 1);
   if (tensors.size() == 0) {
     return at::empty({0}, at::TensorOptions().device(output_device));
   }
-  int64_t total_dim_1 = 0;
+  int64_t total_cat_dim = 0;
   std::vector<int64_t> cumulative_dims;
   cumulative_dims.push_back(0);
   for (const auto& t : tensors) {
     TORCH_CHECK(t.dim() == 2);
-    TORCH_CHECK(t.size(0) == batch_size);
-    total_dim_1 += t.size(-1);
-    cumulative_dims.push_back(total_dim_1);
+    // only support two-dimension tensors.
+    TORCH_CHECK(t.size(1 - cat_dim) == uncat_dim_size);
+    total_cat_dim += t.size(cat_dim);
+    cumulative_dims.push_back(total_cat_dim);
   }
 
   auto* prop = at::cuda::getCurrentDeviceProperties();
-  auto output = at::empty(
-      {batch_size, total_dim_1},
-      tensors.front().options().device(output_device));
+  // default shape for concatenating on dim 1
+  std::vector<int64_t> output_shape;
+  if (cat_dim == 0) {
+    output_shape = {total_cat_dim, uncat_dim_size};
+  } else {
+    output_shape = {uncat_dim_size, total_cat_dim};
+  }
+  auto output =
+      at::empty(output_shape, tensors.front().options().device(output_device));
   TORCH_CHECK(
       output.stride(0) * output.element_size() <=
       static_cast<int64_t>(prop->memPitch));
@@ -332,7 +342,7 @@ Tensor cat_dim_1(
 
   for (const auto i : c10::irange(tensors.size())) {
     output_tensors.push_back(
-        output.slice(1, cumulative_dims[i], cumulative_dims[i + 1]));
+        output.slice(cat_dim, cumulative_dims[i], cumulative_dims[i + 1]));
   }
   all_to_one(
       tensors, output_tensors, output_device, /* skip_if_same_device */ false);
@@ -366,13 +376,14 @@ namespace fbgemm_gpu {
 
 Tensor merge_pooled_embeddings(
     std::vector<Tensor> pooled_embeddings,
-    int64_t batch_size,
-    at::Device target_device) {
+    int64_t uncat_dim_size,
+    at::Device target_device,
+    int64_t cat_dim = 1) {
   init_p2p_access();
   at::cuda::CUDAGuard g(target_device);
 
   TORCH_CHECK(!pooled_embeddings.empty());
-  return cat_dim_1(pooled_embeddings, batch_size, target_device);
+  return cat_dim_2d(pooled_embeddings, uncat_dim_size, target_device, cat_dim);
 }
 
 std::vector<Tensor> all_to_one_device(
@@ -403,7 +414,7 @@ std::vector<Tensor> all_to_one_device(
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "merge_pooled_embeddings(Tensor[] pooled_embeddings, int batch_size, Device target_device) -> Tensor");
+      "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor");
   DISPATCH_TO_CUDA(
       "merge_pooled_embeddings", fbgemm_gpu::merge_pooled_embeddings);
   m.def(
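
In the GPU path, cat_dim_2d now derives the output shape and per-input slices from cat_dim instead of hard-coding dim 1: the concatenated extent is the sum of each input's cat_dim size, and every input must agree on the other dimension (uncat_dim_size). A small Python sketch of that bookkeeping, illustrative only; the actual kernel allocates on the target device and copies each input into its slice via all_to_one:

import torch

def cat_dim_2d_ref(tensors, uncat_dim_size, cat_dim=1):
    assert cat_dim in (0, 1)  # only 2D concatenation is supported
    cumulative, total = [0], 0
    for t in tensors:
        assert t.dim() == 2 and t.size(1 - cat_dim) == uncat_dim_size
        total += t.size(cat_dim)
        cumulative.append(total)
    shape = (total, uncat_dim_size) if cat_dim == 0 else (uncat_dim_size, total)
    out = torch.empty(shape, dtype=tensors[0].dtype)
    for t, lo, hi in zip(tensors, cumulative, cumulative[1:]):
        out.narrow(cat_dim, lo, hi - lo).copy_(t)  # stands in for the all_to_one copies
    return out

parts = [torch.randn(3, 8), torch.randn(5, 8)]
assert cat_dim_2d_ref(parts, 8, cat_dim=0).shape == (8, 8)
assert cat_dim_2d_ref([p.t() for p in parts], 8, cat_dim=1).shape == (8, 8)
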
12 changes: 8 additions & 4 deletions fbgemm_gpu/test/merge_pooled_embeddings_test.py
@@ -40,6 +40,7 @@ class MergePooledEmbeddingsTest(unittest.TestCase):
         num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
         non_default_stream=st.booleans(),
         r=st.randoms(use_true_random=False),
+        dim=st.integers(min_value=0, max_value=1),
     )
     # Can instantiate 8 contexts which takes a long time.
     @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
@@ -51,6 +52,7 @@ def test_merge(
         num_gpus,
         non_default_stream,
         r,
+        dim,
     ) -> None:
         dst_device = r.randint(0, num_gpus - 1)
         torch.cuda.set_device(dst_device)
@@ -67,23 +69,25 @@
         streams = [torch.cuda.Stream(device=i) for i in range(num_gpus)]
         import contextlib
 
+        uncat_size = batch_indices.size(0) if dim == 1 else ad_ds[0]
+
         with contextlib.ExitStack() as stack:
             if non_default_stream:
                 for stream in streams:
                     stack.enter_context(torch.cuda.stream(stream))
             output = torch.ops.fbgemm.merge_pooled_embeddings(
-                pooled_ad_embeddings, batch_indices.size(0), batch_indices.device
+                pooled_ad_embeddings, uncat_size, batch_indices.device, dim
             )
 
         def ref(pooled_ad_embeddings, batch_indices):
-            return torch.cat([p.cpu() for p in pooled_ad_embeddings], dim=1)
+            return torch.cat([p.cpu() for p in pooled_ad_embeddings], dim=dim)
 
         output_ref = ref(pooled_ad_embeddings, batch_indices)
 
         output_cpu = torch.ops.fbgemm.merge_pooled_embeddings(
             [pe.cpu() for pe in pooled_ad_embeddings],
-            batch_indices.size(0),
+            uncat_size,
             batch_indices.cpu().device,
+            dim,
         )
         self.assertEqual(output.device, torch.device(f"cuda:{dst_device}"))
         torch.testing.assert_allclose(output_ref, output.cpu())
