From d4a6f8a03620eb13f35f9bf52551c6f012467498 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Sat, 6 Apr 2024 11:45:58 -0400 Subject: [PATCH] [GraphBolt][CUDA] Refactor `Gather` operation. (#7269) --- graphbolt/include/graphbolt/cuda_ops.h | 14 +++++++++ graphbolt/src/cuda/gather.cu | 36 ++++++++++++++++++++++ graphbolt/src/cuda/neighbor_sampler.cu | 41 ++------------------------ 3 files changed, 53 insertions(+), 38 deletions(-) create mode 100644 graphbolt/src/cuda/gather.cu diff --git a/graphbolt/include/graphbolt/cuda_ops.h b/graphbolt/include/graphbolt/cuda_ops.h index 83493e9f728f..4befacb84698 100644 --- a/graphbolt/include/graphbolt/cuda_ops.h +++ b/graphbolt/include/graphbolt/cuda_ops.h @@ -149,6 +149,20 @@ std::tuple SliceCSCIndptrHetero( */ torch::Tensor ExclusiveCumSum(torch::Tensor input); +/** + * @brief Computes the gather operation on a given input and index tensor. + * + * @param input The input tensor. + * @param index The index tensor. + * @param dtype The optional output dtype. If not given, inferred from the input + * tensor. + * + * @return The result of the input.gather(0, index).to(dtype) operation. + */ +torch::Tensor Gather( + torch::Tensor input, torch::Tensor index, + torch::optional<torch::ScalarType> dtype = torch::nullopt); + /** + * @brief Select rows from input tensor according to index tensor. + * diff --git a/graphbolt/src/cuda/gather.cu b/graphbolt/src/cuda/gather.cu new file mode 100644 index 000000000000..a9b7f35e81e9 --- /dev/null +++ b/graphbolt/src/cuda/gather.cu @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek) + * @file cuda/gather.cu + * @brief Gather operators implementation on CUDA. 
+ */ +#include <thrust/gather.h> + +#include "./common.h" + +namespace graphbolt { +namespace ops { + +torch::Tensor Gather( + torch::Tensor input, torch::Tensor index, + torch::optional<torch::ScalarType> dtype) { + if (!dtype.has_value()) dtype = input.scalar_type(); + auto output = torch::empty(index.sizes(), index.options().dtype(*dtype)); + AT_DISPATCH_INDEX_TYPES( + index.scalar_type(), "GatherIndexType", ([&] { + AT_DISPATCH_INTEGRAL_TYPES( + input.scalar_type(), "GatherInputType", ([&] { + using input_t = scalar_t; + AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] { + using output_t = scalar_t; + THRUST_CALL( + gather, index.data_ptr<index_t>(), + index.data_ptr<index_t>() + index.size(0), + input.data_ptr<input_t>(), output.data_ptr<output_t>()); + })); + })); + })); + return output; +} + +} // namespace ops +} // namespace graphbolt diff --git a/graphbolt/src/cuda/neighbor_sampler.cu b/graphbolt/src/cuda/neighbor_sampler.cu index ff2985d7dafc..4cd4b2820a5c 100644 --- a/graphbolt/src/cuda/neighbor_sampler.cu +++ b/graphbolt/src/cuda/neighbor_sampler.cu @@ -500,44 +500,9 @@ c10::intrusive_ptr SampleNeighbors( } } - output_indices = torch::empty( - picked_eids.size(0), - picked_eids.options().dtype(indices.scalar_type())); - - // Compute: output_indices = indices.gather(0, picked_eids); - AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] { - using indices_t = index_t; - THRUST_CALL( - gather, picked_eids.data_ptr<indptr_t>(), - picked_eids.data_ptr<indptr_t>() + picked_eids.size(0), - indices.data_ptr<indices_t>(), - output_indices.data_ptr<indices_t>()); - })); + output_indices = Gather(indices, picked_eids); })); - auto index_type_per_edge_for_sampled_edges = [&] { - // The code behaves same as: - // output_type_per_edge = type_per_edge.gather(0, picked_eids); - // The reimplementation is required due to the torch equivalent does - // not work when type_per_edge is on pinned memory - auto types = type_per_edge.value(); - auto output = torch::empty( - picked_eids.size(0), picked_eids.options().dtype(types.scalar_type())); 
- AT_DISPATCH_INDEX_TYPES( - indptr.scalar_type(), "SampleNeighborsIndptr", ([&] { - using indptr_t = index_t; - AT_DISPATCH_INTEGRAL_TYPES( - types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] { - THRUST_CALL( - gather, picked_eids.data_ptr<indptr_t>(), - picked_eids.data_ptr<indptr_t>() + picked_eids.size(0), - types.data_ptr<scalar_t>(), output.data_ptr<scalar_t>()); - })); - })); - return output; - }; - torch::optional<torch::Tensor> output_type_per_edge; torch::optional<torch::Tensor> edge_offsets; if (type_per_edge && seed_offsets) { @@ -547,7 +512,7 @@ c10::intrusive_ptr SampleNeighbors( // type_per_edge of sampled edges and determine the offsets of different // sampled etypes and convert to fused hetero indptr representation. if (fanouts.size() == 1) { - output_type_per_edge = index_type_per_edge_for_sampled_edges(); + output_type_per_edge = Gather(*type_per_edge, picked_eids); torch::Tensor output_in_degree, sliced_output_indptr; sliced_output_indptr = output_indptr.slice(0, 0, output_indptr.size(0) - 1); @@ -652,7 +617,7 @@ c10::intrusive_ptr SampleNeighbors( output_indptr = output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size()); if (type_per_edge) - output_type_per_edge = index_type_per_edge_for_sampled_edges(); + output_type_per_edge = Gather(*type_per_edge, picked_eids); } torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;