Skip to content

Commit

Permalink
[GraphBolt][CUDA] Refactor Gather operation. (dmlc#7269)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Apr 6, 2024
1 parent 62aca92 commit d4a6f8a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 38 deletions.
14 changes: 14 additions & 0 deletions graphbolt/include/graphbolt/cuda_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,20 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
*/
torch::Tensor ExclusiveCumSum(torch::Tensor input);

/**
* @brief Computes the gather operation on a given input and index tensor.
*
* @param input The input tensor.
* @param index The index tensor.
* @param dtype The optional output dtype. If not given, inferred from the input
* tensor.
*
* @return The result of the input.gather(0, index).to(dtype) operation.
*/
torch::Tensor Gather(
torch::Tensor input, torch::Tensor index,
torch::optional<torch::ScalarType> dtype = torch::nullopt);

/**
* @brief Select rows from input tensor according to index tensor.
*
Expand Down
36 changes: 36 additions & 0 deletions graphbolt/src/cuda/gather.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/gather.cu
* @brief Gather operators implementation on CUDA.
*/
#include <thrust/gather.h>

#include "./common.h"

namespace graphbolt {
namespace ops {

torch::Tensor Gather(
torch::Tensor input, torch::Tensor index,
torch::optional<torch::ScalarType> dtype) {
if (!dtype.has_value()) dtype = input.scalar_type();
auto output = torch::empty(index.sizes(), index.options().dtype(*dtype));
AT_DISPATCH_INDEX_TYPES(
index.scalar_type(), "GatherIndexType", ([&] {
AT_DISPATCH_INTEGRAL_TYPES(
input.scalar_type(), "GatherInputType", ([&] {
using input_t = scalar_t;
AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] {
using output_t = scalar_t;
THRUST_CALL(
gather, index.data_ptr<index_t>(),
index.data_ptr<index_t>() + index.size(0),
input.data_ptr<input_t>(), output.data_ptr<output_t>());
}));
}));
}));
return output;
}

} // namespace ops
} // namespace graphbolt
41 changes: 3 additions & 38 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -500,44 +500,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
}

output_indices = torch::empty(
picked_eids.size(0),
picked_eids.options().dtype(indices.scalar_type()));

// Compute: output_indices = indices.gather(0, picked_eids);
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
using indices_t = index_t;
THRUST_CALL(
gather, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
indices.data_ptr<indices_t>(),
output_indices.data_ptr<indices_t>());
}));
output_indices = Gather(indices, picked_eids);
}));

auto index_type_per_edge_for_sampled_edges = [&] {
// The code behaves same as:
// output_type_per_edge = type_per_edge.gather(0, picked_eids);
// The reimplementation is required due to the torch equivalent does
// not work when type_per_edge is on pinned memory
auto types = type_per_edge.value();
auto output = torch::empty(
picked_eids.size(0), picked_eids.options().dtype(types.scalar_type()));
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
using indptr_t = index_t;
AT_DISPATCH_INTEGRAL_TYPES(
types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
THRUST_CALL(
gather, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
types.data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
}));
}));
return output;
};

torch::optional<torch::Tensor> output_type_per_edge;
torch::optional<torch::Tensor> edge_offsets;
if (type_per_edge && seed_offsets) {
Expand All @@ -547,7 +512,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// type_per_edge of sampled edges and determine the offsets of different
// sampled etypes and convert to fused hetero indptr representation.
if (fanouts.size() == 1) {
output_type_per_edge = index_type_per_edge_for_sampled_edges();
output_type_per_edge = Gather(*type_per_edge, picked_eids);
torch::Tensor output_in_degree, sliced_output_indptr;
sliced_output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0) - 1);
Expand Down Expand Up @@ -652,7 +617,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
if (type_per_edge)
output_type_per_edge = index_type_per_edge_for_sampled_edges();
output_type_per_edge = Gather(*type_per_edge, picked_eids);
}

torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
Expand Down

0 comments on commit d4a6f8a

Please sign in to comment.