reduce_to_one (pytorch#1571)
Summary:
Pull Request resolved: pytorch#1571

reduce_to_one for row-wise sharding in inference
Similar approach to all_to_one, but the source does not need to wait for the target to be ready to guard against potential WAR and WAW dependency violations, because this reduce_to_one implementation creates a new destination tensor.
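A minimal usage sketch of the new op (a sketch only: it assumes a multi-GPU host with fbgemm_gpu installed so the op is registered, and the shard shapes and device layout are illustrative, mirroring the new unit test in this diff):

import torch
import fbgemm_gpu  # noqa: F401  # loads the fbgemm ops, including sum_reduce_to_one

# One row-wise partial sum per GPU (illustrative shapes/devices). Keep CPU
# copies for the reference result, since the op may accumulate in place into
# tensors that already live on the target or intermediate devices.
num_gpus = torch.cuda.device_count()
cpu_shards = [torch.randn(10, 20) for _ in range(num_gpus)]
gpu_shards = [s.to(f"cuda:{i}") for i, s in enumerate(cpu_shards)]

target = torch.device("cuda:0")
# Sums all shards into a single tensor on the target device.
reduced = torch.ops.fbgemm.sum_reduce_to_one(gpu_shards, target)

# Reference semantics: stack the shards and sum over the shard dimension.
expected = torch.stack(cpu_shards).sum(dim=0)
torch.testing.assert_close(reduced.cpu(), expected)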

Reviewed By: xing-liu, jianyuh

Differential Revision: D34263436

fbshipit-source-id: 7b1630b395311cfd6fef124113436f87f51a6fba
jspark1105 authored and facebook-github-bot committed Apr 4, 2023
1 parent 0064f56 commit 595adad
Showing 3 changed files with 221 additions and 2 deletions.
11 changes: 11 additions & 0 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -230,6 +230,7 @@ def print_p2p_bandwidth(

def benchmark(
all_to_one_only: bool,
sum_reduce_to_one_only: bool,
num_ads: int,
embedding_dimension: int,
ads_tables: int,
@@ -306,6 +307,10 @@ def pool_func_with_quantization(
return torch.ops.fbgemm.all_to_one_device(
pooled_ad_embeddings, batch_indices.device
)
elif sum_reduce_to_one_only:
return torch.ops.fbgemm.sum_reduce_to_one(
pooled_ad_embeddings, batch_indices.device
)
else:
return torch.ops.fbgemm.merge_pooled_embeddings(
embedding_results, batch_indices.size(0), batch_indices.device
@@ -368,6 +373,8 @@ def pool_func_with_quantization(
skip_dequantization,
data_type,
)
if all_to_one_only:
merged = torch.stack(merged)
t, _ = benchmark_torch_function(
pool_func_with_quantization,
(
@@ -419,6 +426,7 @@ def pool_func_with_quantization(

@click.command()
@click.option("--all-to-one-only", is_flag=True, default=False)
@click.option("--sum-reduce-to-one-only", is_flag=True, default=False)
@click.option("--num_ads", default=1024, type=int)
@click.option("--embedding_dimension", default=300, type=int)
@click.option("--ads_tables", default=100, type=int)
@@ -446,6 +454,7 @@ def pool_func_with_quantization(
@click.option("--sweep", is_flag=True, default=False)
def main(
all_to_one_only: bool,
sum_reduce_to_one_only: bool,
num_ads: int,
embedding_dimension: int,
ads_tables: int,
@@ -487,6 +496,7 @@ def handler(signum, frame):
try:
result = benchmark(
all_to_one_only,
sum_reduce_to_one_only,
num_ads,
embedding_dimension,
ads_tables,
@@ -510,6 +520,7 @@ def handler(signum, frame):

result = benchmark(
all_to_one_only,
sum_reduce_to_one_only,
num_ads,
embedding_dimension,
ads_tables,
185 changes: 185 additions & 0 deletions fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp
@@ -307,6 +307,178 @@ void all_to_one(
AT_CUDA_CHECK(cudaGetLastError());
}

Tensor sum_reduce_to_one(
std::vector<Tensor> input_tensors,
at::Device target_device) {
auto num_gpus = at::cuda::getNumGPUs();
std::vector<at::cuda::CUDAEvent> copy_completion_events(num_gpus);

// Local reduction for tensors residing on the same GPU.
// If a tensor is already on the target device, use it as the output tensor.
Tensor output_tensor;
for (const auto i : c10::irange(input_tensors.size())) {
auto& ten = input_tensors[i];
if (!ten.has_storage()) {
continue;
}
TENSOR_ON_CUDA_GPU(ten);
if (ten.device() == target_device && !output_tensor.has_storage()) {
output_tensor = ten;
}
for (auto j = i + 1; j < input_tensors.size(); ++j) {
if (input_tensors[j].has_storage() &&
ten.device() == input_tensors[j].device()) {
ten.add_(input_tensors[j]);
// Replace with a dummy tensor without storage to mark it as reduced away.
input_tensors[j] = Tensor();
}
}
}

// First copy from GPUs that are in 2-hop distance to their intermediate
// GPUs.
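// intermediate_nodes(src, dst) returns the GPU to route through for a
// 2-hop copy, or -1 when no intermediate hop is needed (those tensors are
// copied directly to the target device in the final hop below).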
static auto intermediate_nodes =
get_intermediate_node(fbgemm_gpu::get_nvlink_matrix());
std::vector<Tensor> copied_tensors(input_tensors.size());
for (const auto i : c10::irange(input_tensors.size())) {
auto& src = input_tensors[i];
if (!src.has_storage()) {
continue;
}
auto intermediate_node =
intermediate_nodes(src.get_device(), target_device.index());
if (intermediate_node == -1) {
continue;
}
auto intermediate_device = at::Device(at::kCUDA, intermediate_node);
Tensor dst = at::empty_like(src, intermediate_device);

// This is a cross-device copy on the src and dst current streams.
// Unlike the all_to_one case, we don't need to wait for dst to be ready
// (to avoid write-after-write and write-after-read dependencies), because
// dst is a freshly allocated temporary tensor.

at::cuda::CUDAGuard device_guard(src.device());
// on source device, launch memcpy.
AT_CUDA_CHECK(cudaMemcpy2DAsync(
dst.data_ptr(),
dst.stride(0) * dst.element_size(),
src.data_ptr(),
src.stride(0) * src.element_size(),
src.size(1) * src.element_size(),
src.size(0),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream(src.get_device())));
copied_tensors[i] = dst;
}

// Wait for cross-device copies to complete, then reduce
for (const auto device_id : c10::irange(num_gpus)) {
auto intermediate_node =
intermediate_nodes(device_id, target_device.index());
if (intermediate_node == -1) {
continue;
}
auto intermediate_device = at::Device(at::kCUDA, intermediate_node);

auto src_device = at::Device(at::kCUDA, device_id);
// Still on src_device, record stream event
at::cuda::CUDAGuard device_guard(src_device);
at::cuda::CUDAStream copy_stream =
at::cuda::getCurrentCUDAStream(device_id);

auto& src_ready = copy_completion_events[device_id];
src_ready.record(copy_stream);

device_guard.set_device(intermediate_device);
src_ready.block(at::cuda::getCurrentCUDAStream(intermediate_node));

// Find any tensor in the intermediate GPU to reduce to.
Tensor ten_at_intermediate_node;
for (const auto i : c10::irange(input_tensors.size())) {
if (input_tensors[i].has_storage() &&
input_tensors[i].device() == intermediate_device) {
ten_at_intermediate_node = input_tensors[i];
break;
}
}

for (const auto i : c10::irange(copied_tensors.size())) {
auto& ten = copied_tensors[i];
if (!ten.has_storage() || ten.device() != intermediate_device ||
!input_tensors[i].has_storage() ||
input_tensors[i].device() != src_device) {
continue;
}
if (ten_at_intermediate_node.has_storage()) {
ten_at_intermediate_node.add_(ten);
input_tensors[i] = Tensor();
} else {
// No tensor to reduce to, so we just replace input_tensors[i] with
// the version copied to the intermediate GPU.
input_tensors[i] = ten;
}
}
}

// Final hop.
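// Copy any tensors that are still off the target device into freshly
// allocated buffers on the target device; the accumulation into
// output_tensor happens in the next loop, after the copy-completion
// events have been recorded and waited on.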
for (const auto i : c10::irange(input_tensors.size())) {
auto& src = input_tensors[i];
if (!src.has_storage() || src.device() == target_device) {
continue;
}

Tensor dst = at::empty_like(src, target_device);

at::cuda::CUDAGuard device_guard(src.device());
AT_CUDA_CHECK(cudaMemcpy2DAsync(
dst.data_ptr(),
dst.stride(0) * dst.element_size(),
src.data_ptr(),
src.stride(0) * src.element_size(),
src.size(1) * src.element_size(),
src.size(0),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream(src.get_device())));
copied_tensors[i] = dst;
}

// Wait for cross-device copies to complete, then reduce
for (const auto device_id : c10::irange(num_gpus)) {
if (device_id != target_device.index()) {
auto src_device = at::Device(at::kCUDA, device_id);
// Still on src_device, record stream event
at::cuda::CUDAGuard device_guard(src_device);
at::cuda::CUDAStream copy_stream =
at::cuda::getCurrentCUDAStream(device_id);

auto& src_ready = copy_completion_events[device_id];
src_ready.record(copy_stream);

device_guard.set_device(target_device);
src_ready.block(at::cuda::getCurrentCUDAStream(target_device.index()));

for (const auto i : c10::irange(input_tensors.size())) {
auto& src = input_tensors[i];
if (!src.has_storage() || src.device() != src_device) {
continue;
}

if (output_tensor.has_storage()) {
output_tensor.add_(copied_tensors[i]);
} else {
// Very first reduction at the target device is just a shallow copy.
output_tensor = copied_tensors[i];
}
}
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return output_tensor;
}

Tensor cat_dim_2d(
std::vector<Tensor>& tensors,
int64_t uncat_dim_size,
@@ -415,6 +587,16 @@ std::vector<Tensor> all_to_one_device(
return output_tensors;
}

Tensor sum_reduce_to_one_device(
std::vector<Tensor> input_tensors,
at::Device target_device) {
TORCH_CHECK(input_tensors.size() > 0, "reducing no tensor is undefined");

init_p2p_access();

return sum_reduce_to_one(input_tensors, target_device);
};

} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -425,4 +607,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"all_to_one_device(Tensor[] input_tensors, Device target_device) -> Tensor[]");
DISPATCH_TO_CUDA("all_to_one_device", fbgemm_gpu::all_to_one_device);
m.def(
"sum_reduce_to_one(Tensor[] input_tensors, Device target_device) -> Tensor");
DISPATCH_TO_CUDA("sum_reduce_to_one", fbgemm_gpu::sum_reduce_to_one_device);
}
27 changes: 25 additions & 2 deletions fbgemm_gpu/test/merge_pooled_embeddings_test.py
@@ -96,7 +96,6 @@ def ref(pooled_ad_embeddings, batch_indices):
@given(
num_inputs=st.integers(min_value=1, max_value=10),
num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
non_default_stream=st.booleans(),
r=st.randoms(use_true_random=False),
)
# Can instantiate 8 contexts which takes a long time.
@@ -105,7 +104,6 @@ def test_all_to_one_device(
self,
num_inputs,
num_gpus,
non_default_stream,
r,
) -> None:
dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
@@ -131,6 +129,31 @@ def test_merge_pooled_embeddings_cpu_with_different_target_device(self) -> None:
self.assertFalse(output_meta.is_cpu)
self.assertTrue(output_meta.is_meta)

@given(
num_inputs=st.integers(min_value=1, max_value=10),
num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
r=st.randoms(use_true_random=False),
)
# Can instantiate 8 contexts which takes a long time.
@settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
def test_sum_reduce_to_one(
self,
num_inputs,
num_gpus,
r,
) -> None:
dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
with torch.cuda.device(dst_device):
inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
cuda_inputs = [
input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs)
]
cuda_output = torch.ops.fbgemm.sum_reduce_to_one(cuda_inputs, dst_device)
self.assertEqual(cuda_output.device, dst_device)
torch.testing.assert_close(
cuda_output.cpu(), torch.stack(inputs).sum(dim=0)
)


if __name__ == "__main__":
unittest.main()
