[Feature] Add CUDA support for min and max reducer in heterogeneous API for unary message functions (dmlc#3566)

* CUDA support max/min reducer on forward pass

* docstring

* condensed UpdateGradMinMax_hetero

* reorganized UpdateGradMinMax_hetero

* CUDA kernels for max/min reducer

* variable name

* lint check

* changed CUDA 2D thread mapping to 1D

* removed legacy cusparse for min/max reducer

* git CI issue

* restarting git CI

* adding namespace std

Co-authored-by: Israt Nisa <[email protected]>
Co-authored-by: Quan (Andy) Gan <[email protected]>
3 people authored Dec 16, 2021
1 parent dd762a1 commit 70a499e
Showing 7 changed files with 294 additions and 68 deletions.
69 changes: 22 additions & 47 deletions src/array/cpu/segment_reduce.h
@@ -105,6 +105,7 @@ void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
}

/*!
* \brief CPU kernel to update gradients for reduce op max/min
* \param graph The input heterogeneous graph.
* \param op The binary operator, could be `copy_u` or `copy_e`.
* \param list_feat List of the input tensors.
@@ -117,61 +118,35 @@ void UpdateGradMinMax_hetero(HeteroGraphPtr graph,
const std::string& op,
const std::vector<NDArray>& list_feat,
const std::vector<NDArray>& list_idx,
const std::vector<NDArray>& list_idx_ntypes,
const std::vector<NDArray>& list_idx_types,
std::vector<NDArray>* list_out) {
if (op == "copy_lhs") {
std::vector<std::vector<dgl_id_t>> dst_src_ntids(graph->NumVertexTypes(),
if (op == "copy_lhs" || op == "copy_rhs") {
std::vector<std::vector<dgl_id_t>> src_dst_ntypes(graph->NumVertexTypes(),
std::vector<dgl_id_t>());

for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t dst_id = pair.first; // graph is reversed
const dgl_id_t src_id = pair.second;
dst_src_ntids[dst_id].push_back(src_id); // can have duplicates. Use Hashtable to optimize.
}
std::vector<bool> updated(graph->NumVertexTypes());
for (int dst_id = 0; dst_id < dst_src_ntids.size(); ++dst_id) {
std::fill(updated.begin(), updated.end(), false);
for (int j = 0; j < dst_src_ntids[dst_id].size(); ++j) {
int src_id = dst_src_ntids[dst_id][j];
if (updated[src_id]) continue;
const DType* feat_data = list_feat[dst_id].Ptr<DType>();
const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
DType* out_data = (*list_out)[src_id].Ptr<DType>();
int dim = 1;
for (int i = 1; i < (*list_out)[src_id]->ndim; ++i)
dim *= (*list_out)[src_id]->shape[i];
int n = list_feat[dst_id]->shape[0];
#pragma omp parallel for
for (int i = 0; i < n; ++i) {
for (int k = 0; k < dim; ++k) {
if (src_id == idx_ntype_data[i * dim + k]) {
const int write_row = idx_data[i * dim + k];
#pragma omp atomic
out_data[write_row * dim + k] += feat_data[i * dim + k]; // feat = dZ
}
}
}
updated[src_id] = true;
}
}
} else if (op == "copy_rhs") {
for (dgl_type_t etid = 0; etid < graph->NumEdgeTypes(); ++etid) {
auto pair = graph->meta_graph()->FindEdge(etid);
const dgl_id_t dst_id = pair.first; // graph is reversed
const dgl_id_t src_id = pair.second;
const DType* feat_data = list_feat[dst_id].Ptr<DType>();
const IdType* idx_data = list_idx[dst_id].Ptr<IdType>();
const IdType* idx_ntype_data = list_idx_ntypes[dst_id].Ptr<IdType>();
DType* out_data = (*list_out)[etid].Ptr<DType>();
const dgl_id_t dst_ntype = pair.first; // graph is reversed
const dgl_id_t src_ntype = pair.second;
auto same_src_dst_ntype = std::find(std::begin(src_dst_ntypes[dst_ntype]),
std::end(src_dst_ntypes[dst_ntype]), src_ntype);
// if op is "copy_lhs", relation type with same src and dst node type will be updated once
if (op == "copy_lhs" && same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
continue;
src_dst_ntypes[dst_ntype].push_back(src_ntype);
const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();
const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();
int type = (op == "copy_lhs") ? src_ntype : etype;
DType* out_data = (*list_out)[type].Ptr<DType>();
int dim = 1;
for (int i = 1; i < (*list_out)[etid]->ndim; ++i)
dim *= (*list_out)[etid]->shape[i];
int n = list_feat[dst_id]->shape[0];
for (int i = 1; i < (*list_out)[type]->ndim; ++i)
dim *= (*list_out)[type]->shape[i];
int n = list_feat[dst_ntype]->shape[0];
#pragma omp parallel for
for (int i = 0; i < n; ++i) {
for (int k = 0; k < dim; ++k) {
if (etid == idx_ntype_data[i * dim + k]) {
if (type == idx_type_data[i * dim + k]) {
const int write_row = idx_data[i * dim + k];
#pragma omp atomic
out_data[write_row * dim + k] += feat_data[i * dim + k]; // feat = dZ
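The hunk above replaces the separate `copy_lhs` and `copy_rhs` branches with a single loop over edge types: for each relation it selects the gradient target (the source-node-type tensor for `copy_lhs`, the edge-type tensor for `copy_rhs`) and scatter-adds rows of the incoming gradient (feat = dZ) into it wherever the recorded arg-type matches. A minimal scalar sketch of that scatter step, with hypothetical names and plain std::vector buffers (not the DGL source):

#include <cstdint>
#include <vector>

// For one relation: grad[i][k] flows to row arg_idx[i][k] of the output,
// but only where arg_type[i][k] equals this relation's type id. This is
// the same update the OpenMP loop above performs with atomics.
void ScatterGradForType(const std::vector<float>& grad,        // [n, dim]
                        const std::vector<int64_t>& arg_idx,   // [n, dim]
                        const std::vector<int64_t>& arg_type,  // [n, dim]
                        int64_t n, int64_t dim, int64_t type,
                        std::vector<float>* out) {             // [m, dim]
  for (int64_t i = 0; i < n; ++i) {
    for (int64_t k = 0; k < dim; ++k) {
      if (arg_type[i * dim + k] == type) {
        const int64_t write_row = arg_idx[i * dim + k];
        (*out)[write_row * dim + k] += grad[i * dim + k];
      }
    }
  }
}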
17 changes: 12 additions & 5 deletions src/array/cpu/spmm.h
@@ -319,12 +319,19 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
* \param argu Arg-Min/Max on source nodes, which refers the source node indices
* correspond to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max
* reducer. \param arge Arg-Min/Max on edges. which refers the source node
* indices correspond to the minimum/maximum values of reduction result on
* reducer.
* \param arge Arg-Min/Max on edges, which refers the source node
* indices correspond to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max
* reducer. \note It uses node parallel strategy, different threads are
* responsible for the computation of different nodes. \note The result will
* contain infinity for zero-degree nodes.
* reducer.
* \param argu_ntype Node type of the arg-Min/Max on source nodes, which refers the
* source node types correspond to the minimum/maximum values of reduction result
* on destination nodes. It's useful in computing gradients of Min/Max reducer.
* \param arge_etype Edge type of the arg-Min/Max on edges, which refers the source
* node indices correspond to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* \param src_type Node type of the source nodes of an etype
* \param etype Edge type
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
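The new argu_ntype/arge_etype outputs record, per destination element, which node type or edge type produced the winning value, so the heterogeneous backward pass can route gradients to the right tensor. A scalar reference of what the forward max reduction computes for a single relation with op = `copy_lhs` (a sketch under assumed CSR inputs and float features, not DGL's SpMMCmpCsrHetero):

#include <cstdint>
#include <vector>

// For each destination row, scan its incident edges and keep the running
// max of the source feature, remembering which source node, edge, source
// node type, and edge type produced it. `out` is assumed pre-filled with
// -infinity (the max reducer's identity); use `<` instead of `>` for min.
void SpmmCopyLhsMaxRef(
    const std::vector<int64_t>& indptr,   // CSR row pointers (dst-major)
    const std::vector<int64_t>& indices,  // source node ids per edge
    const std::vector<int64_t>& eids,     // edge ids per edge
    const std::vector<float>& ufeat,      // [num_src, dim] source features
    int64_t dim, int64_t src_type, int64_t etype,
    std::vector<float>* out,              // [num_dst, dim]
    std::vector<int64_t>* argu, std::vector<int64_t>* arge,
    std::vector<int64_t>* argu_ntype, std::vector<int64_t>* arge_etype) {
  const int64_t num_dst = static_cast<int64_t>(indptr.size()) - 1;
  for (int64_t dst = 0; dst < num_dst; ++dst) {
    for (int64_t e = indptr[dst]; e < indptr[dst + 1]; ++e) {
      const int64_t src = indices[e];
      for (int64_t k = 0; k < dim; ++k) {
        const float val = ufeat[src * dim + k];
        if (val > (*out)[dst * dim + k]) {
          (*out)[dst * dim + k] = val;
          (*argu)[dst * dim + k] = src;
          (*arge)[dst * dim + k] = eids[e];
          (*argu_ntype)[dst * dim + k] = src_type;
          (*arge_etype)[dst * dim + k] = etype;
        }
      }
    }
  }
}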
2 changes: 1 addition & 1 deletion src/array/cuda/segment_reduce.cu
@@ -58,7 +58,7 @@ void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
SWITCH_BITS(bits, DType, {
LOG(FATAL) << "Not implemented. Please use CPU version.";
cuda::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
});
}

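The CUDA stub that previously logged "Not implemented" now forwards to the new device implementation; SWITCH_BITS binds DType to the concrete floating-point type for the requested bit width before making the call. A hypothetical dispatch macro of the same shape, shown only to illustrate the pattern (DGL's actual SWITCH_BITS may differ):

#include <stdexcept>

// Bind DType to a concrete type based on `bits`, then run the statement
// block with DType in scope. Illustrative only.
#define EXAMPLE_SWITCH_BITS(bits, DType, ...)              \
  do {                                                     \
    if ((bits) == 32) {                                    \
      typedef float DType;                                 \
      { __VA_ARGS__ }                                      \
    } else if ((bits) == 64) {                             \
      typedef double DType;                                \
      { __VA_ARGS__ }                                      \
    } else {                                               \
      throw std::runtime_error("unsupported bit width");   \
    }                                                      \
  } while (0)

// Usage sketch: EXAMPLE_SWITCH_BITS(bits, DType, { RunKernel<DType>(args); });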
80 changes: 80 additions & 0 deletions src/array/cuda/segment_reduce.cuh
@@ -63,6 +63,34 @@ __global__ void ScatterAddKernel(
}
}

/*!
* \brief CUDA kernel to update gradients for reduce op max/min
* \note each WARP (group of 32 threads) is responsible for adding a row in
* feature tensor to a target row in output tensor.
*/

template <typename IdType, typename DType>
__global__ void UpdateGradMinMaxHeteroKernel(
const DType *feat, const IdType *idx, const IdType *idx_type, DType *out,
int64_t n, int64_t dim, int type) {
unsigned int tId = threadIdx.x;
unsigned int laneId = tId & 31;
unsigned int gId = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int warpId = gId >> 5;
unsigned int warp_size = 32;
unsigned int row = warpId;

while (row < n) {
for(unsigned int col = laneId; col < dim; col += warp_size) {
if (type == idx_type[row * dim + col]) {
const int write_row = idx[row * dim + col];
cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + col]);
}
}
row += blockDim.x * gridDim.x;
}
}
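
A quick check of the index arithmetic above: gId >> 5 and tId & 31 split a global thread id into a warp id (one warp per feature row) and a lane id (one lane per column, striding by warp_size = 32). Plain host-side arithmetic illustrating the mapping for one example thread:

#include <cstdio>

int main() {
  const unsigned blockDim_x = 128;  // threads per block, as launched below
  const unsigned blockIdx_x = 3, threadIdx_x = 70;
  const unsigned tId = threadIdx_x;
  const unsigned laneId = tId & 31;                            // 70 % 32 = 6
  const unsigned gId = blockIdx_x * blockDim_x + threadIdx_x;  // 3*128 + 70 = 454
  const unsigned warpId = gId >> 5;                            // 454 / 32 = 14
  // Thread 70 of block 3 is lane 6 of warp 14: it works on feature row 14,
  // touching columns 6, 38, 70, ... of that row.
  printf("warp %u, lane %u\n", warpId, laneId);
  return 0;
}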

/*!
* \brief CUDA kernel of backward phase in segment min/max.
* \note each blockthread is responsible for writing a row in the
@@ -155,6 +183,58 @@ void ScatterAdd(
n, dim);
}

/*!
* \brief CUDA implementation to update gradients for reduce op max/min
* \param graph The input heterogeneous graph.
* \param op The binary operator, could be `copy_u` or `copy_e`.
* \param list_feat List of the input tensors.
* \param list_idx List of the indices tensors.
* \param list_idx_etype List of the node- or edge-type tensors.
* \param list_out List of the output tensors.
*/
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& graph,
const std::string& op,
const std::vector<NDArray>& list_feat,
const std::vector<NDArray>& list_idx,
const std::vector<NDArray>& list_idx_types,
std::vector<NDArray>* list_out) {
if (op == "copy_lhs" || op == "copy_rhs") {
std::vector<std::vector<dgl_id_t>> src_dst_ntypes(graph->NumVertexTypes(),
std::vector<dgl_id_t>());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t dst_ntype = pair.first; // graph is reversed
const dgl_id_t src_ntype = pair.second;
auto same_src_dst_ntype = std::find(std::begin(src_dst_ntypes[dst_ntype]),
std::end(src_dst_ntypes[dst_ntype]), src_ntype);
// if op is "copy_lhs", relation type with same src and dst node type will be updated once
if (op == "copy_lhs" && same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
continue;
src_dst_ntypes[dst_ntype].push_back(src_ntype);
const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();
const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();
int type = (op == "copy_lhs") ? src_ntype : etype;
DType* out_data = (*list_out)[type].Ptr<DType>();
int dim = 1;
for (int i = 1; i < (*list_out)[type]->ndim; ++i)
dim *= (*list_out)[type]->shape[i];
int n = list_feat[dst_ntype]->shape[0];
auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int th_per_row = 32;
const int ntx = 128;
const int nbx = FindNumBlocks<'x'>((n * th_per_row + ntx - 1) / ntx);
const dim3 nblks(nbx);
const dim3 nthrs(ntx);
CUDA_KERNEL_CALL((UpdateGradMinMaxHeteroKernel<IdType, DType>),
nblks, nthrs, 0, thr_entry->stream,
feat_data, idx_data, idx_type_data,
out_data, n, dim, type);
}
}
}
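
A back-of-envelope check of the launch configuration above, assuming FindNumBlocks does not cap the grid: with th_per_row = 32 (one warp per row) and ntx = 128 threads per block, each block covers four rows, so the grid needs roughly n / 4 blocks.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t n = 1000;     // example number of rows in list_feat[dst_ntype]
  const int th_per_row = 32;  // one warp per row
  const int ntx = 128;        // threads per block
  // Same formula as the host code above, before any capping:
  const int64_t nbx = (n * th_per_row + ntx - 1) / ntx;  // ceil(32000/128) = 250
  const int64_t warps = nbx * (ntx / 32);                // 250 * 4 = 1000 warps
  printf("%lld blocks -> %lld warps for %lld rows\n",
         static_cast<long long>(nbx), static_cast<long long>(warps),
         static_cast<long long>(n));
  return 0;
}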

/*!
* \brief CUDA implementation of backward phase of Segment Reduce with Min/Max reducer.
* \note math equation: out[arg[i, k], k] = feat[i, k]
50 changes: 47 additions & 3 deletions src/array/cuda/spmm.cu
@@ -526,7 +526,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
std::vector<DType*> trans_out((*vec_out).size(), NULL);

bool use_legacy_cusparsemm =
(CUDART_VERSION < 11000) &&
(CUDART_VERSION < 11000) && (reduce == "sum") &&
// legacy cuSPARSE does not care about NNZ, hence the argument "false".
((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
(op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
@@ -542,7 +542,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
trans_out[ntype] = out;
}
}

// Check shape of ufeat for all relation type and compute feature size
int64_t x_length = 1;
for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
@@ -565,6 +564,30 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
x_length *= ufeat->shape[i];
}
}
// TODO(Israt): Can python do the following initializations while creating the tensors?
if (reduce == "max" || reduce == "min") {
const int64_t dim = bcast.out_len;
std::vector<bool> updated((*vec_out).size(), false);
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
if (reduce == "max")
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero());
else // min
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero());
const dgl_type_t dst_id = out_ntids[etype];
if (!updated[dst_id]) {
updated[dst_id] = true;
if (op == "copy_lhs") {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
_Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
}
if (op == "copy_rhs") {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
_Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
}
}
}
}

auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
@@ -606,7 +629,28 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
});
}
// TODO(Israt): Add support for max/min reducer
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else {
LOG(FATAL) << "Not implemented";
}
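Two details of the spmm.cu changes are worth spelling out. The initialization block added above first fills the output with the reducer's identity (e.g. -infinity for max over floating-point features, +infinity for min, matching the note that zero-degree nodes end up with infinity) and fills the argu_ntype/arge_etype tensors with -1 as a "no contributor yet" sentinel. The max and min dispatch branches then differ only in the Cmp template argument handed to SpMMCmpCsrHetero. A hypothetical comparator pair illustrating that pattern (not DGL's cuda::reduce types, and assuming a floating-point DType):

#include <cstdint>
#include <limits>

// The two dispatch branches above differ only in which policy they pass
// as the Cmp template argument. Illustrative only.
template <typename DType>
struct MaxPolicy {
  // Identity value: any real feature value beats it on the first compare.
  static DType zero() { return -std::numeric_limits<DType>::infinity(); }
  static bool BetterThan(DType cand, DType cur) { return cand > cur; }
};

template <typename DType>
struct MinPolicy {
  static DType zero() { return std::numeric_limits<DType>::infinity(); }
  static bool BetterThan(DType cand, DType cur) { return cand < cur; }
};

// Usage sketch: a reducer templated on Cmp updates the running value and
// remembers which source row produced it (the "arg" outputs).
template <typename DType, typename Cmp>
void Update(DType cand, int64_t src, DType* cur, int64_t* arg) {
  if (Cmp::BetterThan(cand, *cur)) {
    *cur = cand;
    *arg = src;
  }
}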
