uvm_to_device expose device as interface (pytorch#3030)
Summary:
Pull Request resolved: pytorch#3030

X-link: facebookresearch/FBGEMM#128

Sometimes the caller already has the device directly, so there is no need to create a dummy prototype tensor just to pass it.

Reviewed By: bixue2010

Differential Revision: D61740444

fbshipit-source-id: 67e8dc6b95eb03bd8ef86edf61fe662f9b832430
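
The shape of the change can be illustrated with a minimal, PyTorch-free sketch. The `Device` and `Tensor` structs below are hypothetical stand-ins (not `at::Device` / `at::Tensor`), and the shared pointer only mimics the "keep a reference to the original storage" behaviour; the real function re-wraps the UVM storage context. What it shows is the delegation: the prototype-taking function becomes a thin wrapper around the device-taking one.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-in for at::Device (assumption, illustration only).
struct Device {
  std::string type;
  int index;
  bool operator==(const Device& other) const {
    return type == other.type && index == other.index;
  }
};

// Hypothetical stand-in for at::Tensor: a shared storage handle plus a device.
struct Tensor {
  std::shared_ptr<int> storage;
  Device dev;
  const Device& device() const { return dev; }
};

// Device-taking variant: does the actual work. The storage is not copied;
// the result keeps a reference to the original storage.
Tensor uvm_to_device_d(const Tensor& self, const Device& device) {
  return Tensor{self.storage, device};
}

// Prototype-taking variant: only extracts the device and delegates.
Tensor uvm_to_device(const Tensor& self, const Tensor& prototype) {
  return uvm_to_device_d(self, prototype.device());
}
```

With this split, a caller that already holds a device calls `uvm_to_device_d` directly, while existing callers of `uvm_to_device` keep working unchanged.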
Bin Wen authored and facebook-github-bot committed Aug 24, 2024
1 parent d545293 commit 95a406c
Showing 3 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/include/fbgemm_gpu/cumem_utils.h
@@ -139,6 +139,7 @@ Tensor uvm_to_cpu(const Tensor& self);
 /// @return A new tensor that shares the same device and UVM storage with
 /// `prototype`.
 Tensor uvm_to_device(const Tensor& self, const Tensor& prototype);
+Tensor uvm_to_device_d(const Tensor& self, const at::Device& device);

 /// @ingroup cumem-utils
 ///
8 changes: 6 additions & 2 deletions fbgemm_gpu/src/memory_utils/memory_utils.cu
@@ -278,15 +278,19 @@ Tensor uvm_to_cpu(const Tensor& t) {
       .set_(std::move(storage), t.storage_offset(), t.sizes(), t.strides());
 }

-Tensor uvm_to_device(const Tensor& t, const Tensor& prototype) {
+Tensor uvm_to_device(const Tensor& self, const Tensor& prototype) {
+  auto device = prototype.device();
+  return uvm_to_device_d(self, device);
+}
+
+Tensor uvm_to_device_d(const Tensor& t, const at::Device& device) {
   TORCH_CHECK(is_uvm_tensor(t));
   // Don't copy the storage - just keep a reference to the original storage
   auto* tcontext =
       t.storage().data_ptr().cast_context<CUDAManagedIndirectContext>(
           &CUDAManagedIndirectContext::release);
   TORCH_CHECK(tcontext != nullptr)

-  auto device = prototype.device();
   auto* ocontext =
       tcontext->storage_.data_ptr().cast_context<CUDAManagedContext>(
           &CUDAManagedContext::release);
3 changes: 3 additions & 0 deletions fbgemm_gpu/src/memory_utils/memory_utils_ops.cu
@@ -18,6 +18,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "uvm_to_device(Tensor self, Tensor prototype) -> Tensor",
       TORCH_FN(uvm_to_device));
+  m.def(
+      "uvm_to_device_d(Tensor self, Device device) -> Tensor",
+      TORCH_FN(uvm_to_device_d));

   m.def(
       "cuda_mem_advise(Tensor t, int advice) -> ()",