From d1088de522023a0c84acf392ff915bf4d4c8a475 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Mon, 26 Apr 2021 15:25:51 -0700 Subject: [PATCH] Let RRef getValue() synchronize CUDA streams (#56895) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56895 PR #54932 fixes CUDA stream synchronization between RPC-created OwnerRRef and UserRRef when `to_here()` is invoked. However, there are two more gaps. 1. RRef value can be accessed on the owner directly through `local_value`, which bypasses the fix in #54932. 2. When RRef is created directly through RRef ctor instead of RPC, the OwnerRRef won't be able to correctly record CUDA events. This PR fixes 1 by letting current streams wait for RRef recorded CUDA events before returning the value in `RRef::getValue()`. For 2, more discussions is needed to decide whether we should add a `devices` argument to RRef ctor, or should RRef ctor inspect the given values. Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D27992775 Pulled By: mrshenli fbshipit-source-id: ed0e5bfbf715460208c85e46dd3317deef17f8fe --- torch/csrc/distributed/rpc/rref_impl.cpp | 17 +++++++- .../_internal/distributed/rpc/rpc_test.py | 43 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 65a99ae600242..abc214dc816f9 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -1,10 +1,12 @@ +#include + #include +#include #include #include #include #include #include -#include #include #include @@ -25,6 +27,16 @@ std::string getTypeStr(const c10::TypePtr& type) { return type->annotation_str(); } } + +void blockCurrentStreams(const std::vector& events) { + for (const c10::Event& event : events) { + c10::Device device{event.device_type(), event.device_index()}; + c10::Stream stream = + c10::impl::getDeviceGuardImpl(device.type())->getStream(device); + event.block(stream); + } +} + } // namespace namespace torch { @@ -239,6 +251,9 @@ const IValue& OwnerRRef::getValue() const { if (future_->hasError()) { (void)future_->value(); // Throws the error. } + // Before accessing the value in this RRef, current CUDA streams must wait + // for pending CUDA operations that create the value. + blockCurrentStreams(events_); return future_->constValue(); } diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 2578f20f24a59..03422af059230 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -5748,6 +5748,49 @@ def test_rref_forward_synchronization3(self): def test_rref_forward_synchronization4(self): self._test_rref_forward_synchronization("cuda:1", "cuda:1") + def _test_owner_rref_forward_synchronization(self, local_device, remote_device): + if self.rank == 0: + options = self.rpc_backend_options + options.set_device_map("w0", {local_device: remote_device}) + rpc.init_rpc( + "w0", + rank=0, + world_size=1, + rpc_backend_options=options + ) + + model = rpc.remote( + "w0", torch.nn.Linear, (2048, 20000) + ).remote().to(remote_device) + for _ in range(30): + data = torch.rand(2048, 2048).to(local_device) + output = model.rpc_sync().forward(data) + # FIXME: remove this when RRef ctor can record CUDA events + torch.cuda.current_stream(local_device).synchronize() + # to_here() internally calls localValue as the caller is + # the owner of the RRef. + v0 = rpc.RRef(output).remote().sum().to_here().item() + v1 = output.sum().item() + self.assertEqual(v0, v1) + + rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_owner_rref_forward_synchronization1(self): + self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_owner_rref_forward_synchronization2(self): + self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1") + + @skip_if_lt_x_gpu(2) + def test_owner_rref_forward_synchronization3(self): + self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0") + + @skip_if_lt_x_gpu(2) + def test_owner_rref_forward_synchronization4(self): + self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1") + @skip_if_lt_x_gpu(1) def test_devices_option_mismatch(self): with self.assertRaisesRegex(