Skip to content

Commit

Permalink
Re-apply "Move merge_pooled_embeddings op into FBGEMM_GPU." (pytorch#689
Browse files Browse the repository at this point in the history
)

Summary: Pull Request resolved: pytorch#689

Reviewed By: rweyrauch

Differential Revision: D30761093

fbshipit-source-id: 51c3f31fa6ff708969daeeecd9b4bafa7fa84983
  • Loading branch information
jianyuh authored and facebook-github-bot committed Sep 12, 2021
1 parent fdc225d commit 98a7c14
Show file tree
Hide file tree
Showing 5 changed files with 551 additions and 0 deletions.
89 changes: 89 additions & 0 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#!/usr/bin/env python3

# pyre-unsafe

import click
import numpy as np
import tabulate
import torch

try:
torch.ops.load_library("fbgemm_gpu_py.so")
except Exception:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings")


@click.command()
@click.option("--num-ads", default=1024, type=int)
@click.option("--embedding-dimension", default=300, type=int)
@click.option("--ads-tables", default=400, type=int)
@click.option("--iters", default=10, type=int)
@click.option("--p2p_bw", is_flag=True, default=False)
@click.option("--dst-device", default=0, type=int)
def main(num_ads, embedding_dimension, ads_tables, iters, p2p_bw, dst_device) -> None:
torch.cuda.set_device(dst_device)
num_gpus = torch.cuda.device_count()
ad_ds = [embedding_dimension * ads_tables for _ in range(num_gpus)]
batch_indices = torch.zeros(num_ads).long().cuda()
pooled_ad_embeddings = [
torch.randn(
num_ads, ad_d, dtype=torch.float16, device=torch.device(f"cuda:{i}")
)
for i, ad_d in enumerate(ad_ds)
]

def benchmark_torch_function(iters: int, f, *args) -> float:
f(*args)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(iters):
f(*args)
end_event.record()
torch.cuda.synchronize()
return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
if p2p_bw:
print("Pairwise GPU Copy Bandwidth (GB/s)")
p2p_copy_bw = np.zeros((num_gpus, num_gpus))
for i in range(num_gpus):
for j in range(num_gpus):
with torch.cuda.device(i):
t = benchmark_torch_function(
iters,
lambda: pooled_ad_embeddings[i].copy_(pooled_ad_embeddings[j])
if i != j
else pooled_ad_embeddings[i].clone(),
)
p2p_copy_bw[i, j] = pooled_ad_embeddings[i].numel() * 2 / t / 1.0e9
table = tabulate.tabulate(
p2p_copy_bw,
headers=[f"GPU {i}" for i in range(num_gpus)],
tablefmt="fancy_grid",
floatfmt=".0f",
)
print(table)

streams = [torch.cuda.Stream(device=i) for i in range(num_gpus)]
import contextlib

with contextlib.ExitStack() as stack:
for stream in streams:
stack.enter_context(torch.cuda.stream(stream))

t = benchmark_torch_function(
iters,
lambda: torch.ops.fbgemm.merge_pooled_embeddings(
pooled_ad_embeddings, batch_indices
),
)
merged = torch.ops.fbgemm.merge_pooled_embeddings(
pooled_ad_embeddings, batch_indices
)
print(
f"Merge, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Num GPUs: {num_gpus}, Destination GPU: {dst_device} Output Size: {merged.numel() * 2 / 1.0e6:.2f}MB, BW: {merged.numel() * 2 / t / 1.0e9:.2f}GB/s, t: {t * 1.0e3:.2f}ms"
)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions fbgemm_gpu/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def build_extension(self, ext):
os.path.join(cur_dir, "src/sparse_ops_cpu.cpp"),
os.path.join(cur_dir, "src/sparse_ops_gpu.cpp"),
os.path.join(cur_dir, "src/sparse_ops.cu"),
os.path.join(cur_dir, "src/merge_pooled_embeddings_gpu.cpp"),
],
include_dirs=[
cur_dir,
Expand Down
44 changes: 44 additions & 0 deletions fbgemm_gpu/src/merge_pooled_embeddings_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/TensorOptions.h>
#include <torch/library.h>

using namespace at;

namespace at {

at::Tensor merge_pooled_embeddings_cpu(
std::vector<Tensor> ad_pooled_embeddings,
Tensor batch_indices) {
auto cat_host_0 = [&](const std::vector<at::Tensor>& ts) {
int64_t n = 0;
for (auto& t : ts) {
n += t.numel();
}
at::Tensor r;
if (n == 0) {
r = at::empty({n});
} else {
r = at::empty({n}, ts[0].options());
}
return at::cat_out(r, ts, 1); // concat the tensor list in dim = 1
};
return cat_host_0(ad_pooled_embeddings);
}

} // namespace at

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"merge_pooled_embeddings(Tensor[] ad_pooled_embeddings, Tensor batch_indices) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl("merge_pooled_embeddings", at::merge_pooled_embeddings_cpu);
}
Loading

0 comments on commit 98a7c14

Please sign in to comment.