Let RRef getValue() synchronize CUDA streams (pytorch#56895)
Summary:
Pull Request resolved: pytorch#56895

PR pytorch#54932 fixes CUDA stream synchronization between RPC-created
OwnerRRef and UserRRef when `to_here()` is invoked. However, there
are two more gaps.

1. The RRef value can be accessed on the owner directly through
    `local_value`, which bypasses the fix in pytorch#54932.
2. When an RRef is created directly through the RRef ctor instead of
    through RPC, the OwnerRRef cannot correctly record CUDA events.

This PR fixes 1 by letting the current streams wait for the RRef's
recorded CUDA events before returning the value in `OwnerRRef::getValue()`.

For 2, more discussion is needed to decide whether we should add
a `devices` argument to the RRef ctor, or whether the ctor should
inspect the given values.
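
For illustration, below is a minimal sketch (not part of this patch) of the owner-side access pattern described in gap 1, loosely mirroring the test added in rpc_test.py. The single worker `w0` matches the test; the `produce` helper and the env:// rendezvous variables are assumptions made only for the sketch.

```python
import os

import torch
import torch.distributed.rpc as rpc

# Rendezvous assumptions for a single-process sketch.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")


def produce():
    # Runs on the owner; the returned CUDA tensor may still have kernels in flight.
    a = torch.rand(1024, 1024, device="cuda:0")
    return a.mm(a)


def main():
    options = rpc.TensorPipeRpcBackendOptions()
    options.set_device_map("w0", {"cuda:0": "cuda:0"})
    rpc.init_rpc("w0", rank=0, world_size=1, rpc_backend_options=options)

    rref = rpc.remote("w0", produce)
    # On the owner, to_here() resolves to the local value; before this change
    # the read could race with the kernels launched by produce(). With this
    # change, OwnerRRef::getValue() first blocks the current streams on the
    # RRef's recorded CUDA events.
    v = rref.to_here().sum().item()

    rpc.shutdown()
    return v
```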

Test Plan: Imported from OSS

Reviewed By: lw

Differential Revision: D27992775

Pulled By: mrshenli

fbshipit-source-id: ed0e5bfbf715460208c85e46dd3317deef17f8fe
mrshenli authored and facebook-github-bot committed Apr 26, 2021
1 parent e1a7ec3 commit d1088de
Showing 2 changed files with 59 additions and 1 deletion.
17 changes: 16 additions & 1 deletion torch/csrc/distributed/rpc/rref_impl.cpp
@@ -1,10 +1,12 @@
#include <torch/csrc/distributed/rpc/rref_impl.h>

#include <ATen/record_function.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

@@ -25,6 +27,16 @@ std::string getTypeStr(const c10::TypePtr& type) {
      return type->annotation_str();
  }
}

void blockCurrentStreams(const std::vector<c10::Event>& events) {
  for (const c10::Event& event : events) {
    c10::Device device{event.device_type(), event.device_index()};
    c10::Stream stream =
        c10::impl::getDeviceGuardImpl(device.type())->getStream(device);
    event.block(stream);
  }
}

} // namespace

namespace torch {
@@ -239,6 +251,9 @@ const IValue& OwnerRRef::getValue() const {
  if (future_->hasError()) {
    (void)future_->value(); // Throws the error.
  }
  // Before accessing the value in this RRef, current CUDA streams must wait
  // for pending CUDA operations that create the value.
  blockCurrentStreams(events_);
  return future_->constValue();
}

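A rough Python analogue of the new helper (an illustration only; the actual change is the C++ above): each device's current stream waits on the event recorded for the value on that device. The `(device, event)` pairing is an assumption made for the sketch, since the C++ `c10::Event` carries its device itself.

```python
import torch


def block_current_streams(device_events):
    # device_events: iterable of (torch.device, torch.cuda.Event) pairs, where
    # each event was recorded on the stream that produced the value on that device.
    for device, event in device_events:
        # Make the device's current stream wait for the recorded event, so later
        # work issued on that stream observes the completed value.
        torch.cuda.current_stream(device).wait_event(event)
```
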
43 changes: 43 additions & 0 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -5748,6 +5748,49 @@ def test_rref_forward_synchronization3(self):
    def test_rref_forward_synchronization4(self):
        self._test_rref_forward_synchronization("cuda:1", "cuda:1")

    def _test_owner_rref_forward_synchronization(self, local_device, remote_device):
        if self.rank == 0:
            options = self.rpc_backend_options
            options.set_device_map("w0", {local_device: remote_device})
            rpc.init_rpc(
                "w0",
                rank=0,
                world_size=1,
                rpc_backend_options=options
            )

            model = rpc.remote(
                "w0", torch.nn.Linear, (2048, 20000)
            ).remote().to(remote_device)
            for _ in range(30):
                data = torch.rand(2048, 2048).to(local_device)
                output = model.rpc_sync().forward(data)
                # FIXME: remove this when RRef ctor can record CUDA events
                torch.cuda.current_stream(local_device).synchronize()
                # to_here() internally calls localValue as the caller is
                # the owner of the RRef.
                v0 = rpc.RRef(output).remote().sum().to_here().item()
                v1 = output.sum().item()
                self.assertEqual(v0, v1)

            rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_owner_rref_forward_synchronization1(self):
        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization2(self):
        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization3(self):
        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0")

    @skip_if_lt_x_gpu(2)
    def test_owner_rref_forward_synchronization4(self):
        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1")

    @skip_if_lt_x_gpu(1)
    def test_devices_option_mismatch(self):
        with self.assertRaisesRegex(
