diff --git a/Jenkinsfile b/Jenkinsfile
index 13b3f661f407..935a149e52e6 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -232,9 +232,7 @@ pipeline {
         stages {
           stage("Unit test") {
             steps {
-              // TODO(minjie): tmp disabled
-              //unit_test_linux("tensorflow", "gpu")
-              sh "echo skipped"
+              unit_test_linux("tensorflow", "gpu")
             }
           }
         }
diff --git a/include/dgl/aten/array_ops.h b/include/dgl/aten/array_ops.h
index 137d3b280ed7..73604d08a38d 100644
--- a/include/dgl/aten/array_ops.h
+++ b/include/dgl/aten/array_ops.h
@@ -90,16 +90,19 @@ IdArray Add(IdArray lhs, IdArray rhs);
 IdArray Sub(IdArray lhs, IdArray rhs);
 IdArray Mul(IdArray lhs, IdArray rhs);
 IdArray Div(IdArray lhs, IdArray rhs);
+IdArray Mod(IdArray lhs, IdArray rhs);
 
 IdArray Add(IdArray lhs, int64_t rhs);
 IdArray Sub(IdArray lhs, int64_t rhs);
 IdArray Mul(IdArray lhs, int64_t rhs);
 IdArray Div(IdArray lhs, int64_t rhs);
+IdArray Mod(IdArray lhs, int64_t rhs);
 
 IdArray Add(int64_t lhs, IdArray rhs);
 IdArray Sub(int64_t lhs, IdArray rhs);
 IdArray Mul(int64_t lhs, IdArray rhs);
 IdArray Div(int64_t lhs, IdArray rhs);
+IdArray Mod(int64_t lhs, IdArray rhs);
 
 IdArray Neg(IdArray array);
 
@@ -304,6 +307,17 @@ IdArray CumSum(IdArray array, bool prepend_zero = false);
  */
 IdArray NonZero(NDArray array);
 
+/*!
+ * \brief Sort the ID vector in ascending order.
+ *
+ * It performs both sort and arg_sort (returning the sorted index). The sorted
+ * index is always in int64.
+ *
+ * \param array Input array.
+ * \return A pair of arrays: sorted values and sorted index to the original position.
+ */
+std::pair<IdArray, IdArray> Sort(IdArray array);
+
 /*!
  * \brief Return a string that prints out some debug information.
  */
diff --git a/include/dgl/runtime/ndarray.h b/include/dgl/runtime/ndarray.h
index 5815b9b2ec0b..53fb8cd8279f 100644
--- a/include/dgl/runtime/ndarray.h
+++ b/include/dgl/runtime/ndarray.h
@@ -603,14 +603,18 @@ dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1,
                                   const dgl::runtime::NDArray& a2);
 dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1,
                                   const dgl::runtime::NDArray& a2);
+dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1,
+                                  const dgl::runtime::NDArray& a2);
 dgl::runtime::NDArray operator + (const dgl::runtime::NDArray& a1, int64_t rhs);
 dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& a1, int64_t rhs);
 dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1, int64_t rhs);
 dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1, int64_t rhs);
+dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1, int64_t rhs);
 dgl::runtime::NDArray operator + (int64_t lhs, const dgl::runtime::NDArray& a2);
 dgl::runtime::NDArray operator - (int64_t lhs, const dgl::runtime::NDArray& a2);
 dgl::runtime::NDArray operator * (int64_t lhs, const dgl::runtime::NDArray& a2);
 dgl::runtime::NDArray operator / (int64_t lhs, const dgl::runtime::NDArray& a2);
+dgl::runtime::NDArray operator % (int64_t lhs, const dgl::runtime::NDArray& a2);
 dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& array);
 
 dgl::runtime::NDArray operator > (const dgl::runtime::NDArray& a1,
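The header changes above add an element-wise `Mod` to the `IdArray` arithmetic family and a combined sort/arg-sort entry point. A minimal usage sketch against these declarations (the setup values are illustrative; `VecToIdArray` is the helper the C++ tests later in this patch use):

```cpp
#include <dgl/array.h>
#include <tuple>
#include <vector>

using namespace dgl;

void SortAndMod() {
  DLContext cpu{kDLCPU, 0};
  IdArray ids = aten::VecToIdArray(std::vector<int64_t>({8, 6, 7, 5}), 64, cpu);
  IdArray sorted, order;
  std::tie(sorted, order) = aten::Sort(ids);  // sorted = [5, 6, 7, 8]
                                              // order (always int64) = [3, 1, 2, 0]
  IdArray rem = ids % 5;                      // element-wise modulo = [3, 1, 2, 0]
}
```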
diff --git a/python/dgl/convert.py b/python/dgl/convert.py
index fc582448b37f..547f0e59aab5 100644
--- a/python/dgl/convert.py
+++ b/python/dgl/convert.py
@@ -148,10 +148,7 @@ def graph(data,
     g = create_from_edges(u, v, ntype, etype, ntype, urange, vrange,
                           validate, formats=formats)
-    if device is None:
-        return utils.to_int32_graph_if_on_gpu(g)
-    else:
-        return g.to(device)
+    return g.to(device)
 
 def bipartite(data, utype='_U', etype='_E', vtype='_V',
@@ -300,10 +297,7 @@ def bipartite(data,
         u, v, utype, etype, vtype, urange, vrange, validate,
         formats=formats)
-    if device is None:
-        return utils.to_int32_graph_if_on_gpu(g)
-    else:
-        return g.to(device)
+    return g.to(device)
 
 def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
     """Create a heterograph from graphs representing connections of each relation.
diff --git a/python/dgl/heterograph.py b/python/dgl/heterograph.py
index 743994a85ac8..9b40e76b2635 100644
--- a/python/dgl/heterograph.py
+++ b/python/dgl/heterograph.py
@@ -4450,7 +4450,7 @@ def to(self, device, **kwargs):  # pylint: disable=invalid-name
         device(type='cpu')
         """
         if device is None or self.device == device:
-            return utils.to_int32_graph_if_on_gpu(self)
+            return self
 
         ret = copy.copy(self)
 
@@ -4481,8 +4481,6 @@ def to(self, device, **kwargs):  # pylint: disable=invalid-name
                    for k, num in self._batch_num_edges.items()}
         ret._batch_num_edges = new_bne
 
-        ret = utils.to_int32_graph_if_on_gpu(ret)
-
         return ret
 
     def cpu(self):
diff --git a/python/dgl/nn/tensorflow/softmax.py b/python/dgl/nn/tensorflow/softmax.py
index b6bde6387ca0..c970ecd7c9a2 100644
--- a/python/dgl/nn/tensorflow/softmax.py
+++ b/python/dgl/nn/tensorflow/softmax.py
@@ -11,7 +11,7 @@ def edge_softmax_real(graph, score, eids=ALL):
     """Edge Softmax function"""
     if not is_all(eids):
-        graph = graph.edge_subgraph(tf.cast(eids, graph.idtype))
+        graph = graph.edge_subgraph(tf.cast(eids, graph.idtype), preserve_nodes=True)
     gidx = graph._graph
     score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
     score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
diff --git a/python/dgl/utils/checks.py b/python/dgl/utils/checks.py
index 4fd79a47eb57..60bb98831e52 100644
--- a/python/dgl/utils/checks.py
+++ b/python/dgl/utils/checks.py
@@ -2,9 +2,8 @@
 # pylint: disable=invalid-name
 from __future__ import absolute_import, division
 
-from ..base import DGLError, dgl_warning
+from ..base import DGLError
 from .. import backend as F
-from .internal import to_dgl_context
 
 def prepare_tensor(g, data, name):
     """Convert the data to ID tensor and check its ID type and context.
@@ -129,14 +128,3 @@ def check_all_same_schema(feat_dict_list, keys, name):
                 ' and feature size, but got\n\t{} {}\nand\n\t{} {}.'.format(
                     name, k, F.dtype(t1), F.shape(t1)[1:],
                     F.dtype(t2), F.shape(t2)[1:]))
-
-def to_int32_graph_if_on_gpu(g):
-    """Convert to int32 graph if the input graph is on GPU."""
-    # device_type 2 is an internal code for GPU
-    if to_dgl_context(g.device).device_type == 2 and g.idtype == F.int64:
-        dgl_warning('Automatically cast a GPU int64 graph to int32.\n'
-                    ' To suppress the warning, call DGLGraph.int() first\n'
-                    ' or specify the ``device`` argument when creating the graph.')
-        return g.int()
-    else:
-        return g
diff --git a/src/array/arith.h b/src/array/arith.h
index 3471ea80492f..163857744947 100644
--- a/src/array/arith.h
+++ b/src/array/arith.h
@@ -46,6 +46,13 @@ struct Div {
   }
 };
 
+struct Mod {
+  template <typename T>
+  static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {
+    return t1 % t2;
+  }
+};
+
 struct GT {
   template <typename T>
   static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
diff --git a/src/array/array.cc b/src/array/array.cc
index f774d9467f8d..03ca8fcfedbc 100644
--- a/src/array/array.cc
+++ b/src/array/array.cc
@@ -287,6 +287,20 @@ IdArray NonZero(NDArray array) {
   return ret;
 }
 
+std::pair<IdArray, IdArray> Sort(IdArray array) {
+  if (array.NumElements() == 0) {
+    IdArray idx = NewIdArray(0, array->ctx, 64);
+    return std::make_pair(array, idx);
+  }
+  std::pair<IdArray, IdArray> ret;
+  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", {
+    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
+      ret = impl::Sort<XPU, IdType>(array);
+    });
+  });
+  return ret;
+}
+
 std::string ToDebugString(NDArray array) {
   std::ostringstream oss;
   NDArray a = array.CopyTo(DLContext{kDLCPU, 0});
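The `ATEN_XPU_SWITCH_CUDA`/`ATEN_ID_TYPE_SWITCH` pair dispatches the one untyped entry point to the device- and dtype-specialized `impl::Sort<XPU, IdType>`. Conceptually the two macros select among four instantiations, roughly like the hand-expanded sketch below (illustrative only; the real macros also emit errors for unsupported devices/dtypes and compile the GPU branch only when CUDA is enabled):

```cpp
// Hand-expanded sketch of the double dispatch performed by the macros.
std::pair<IdArray, IdArray> SortDispatch(IdArray array) {
  std::pair<IdArray, IdArray> ret;
  if (array->ctx.device_type == kDLCPU) {
    if (array->dtype.bits == 32)
      ret = impl::Sort<kDLCPU, int32_t>(array);
    else
      ret = impl::Sort<kDLCPU, int64_t>(array);
  } else {  // kDLGPU; only reachable in CUDA builds
    if (array->dtype.bits == 32)
      ret = impl::Sort<kDLGPU, int32_t>(array);
    else
      ret = impl::Sort<kDLGPU, int64_t>(array);
  }
  return ret;
}
```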
diff --git a/src/array/array_arith.cc b/src/array/array_arith.cc
index fd2447fad91d..6dcbb025e620 100644
--- a/src/array/array_arith.cc
+++ b/src/array/array_arith.cc
@@ -70,6 +70,7 @@ BINARY_ELEMENT_OP(Add, Add)
 BINARY_ELEMENT_OP(Sub, Sub)
 BINARY_ELEMENT_OP(Mul, Mul)
 BINARY_ELEMENT_OP(Div, Div)
+BINARY_ELEMENT_OP(Mod, Mod)
 BINARY_ELEMENT_OP(GT, GT)
 BINARY_ELEMENT_OP(LT, LT)
 BINARY_ELEMENT_OP(GE, GE)
@@ -81,6 +82,7 @@ BINARY_ELEMENT_OP_L(Add, Add)
 BINARY_ELEMENT_OP_L(Sub, Sub)
 BINARY_ELEMENT_OP_L(Mul, Mul)
 BINARY_ELEMENT_OP_L(Div, Div)
+BINARY_ELEMENT_OP_L(Mod, Mod)
 BINARY_ELEMENT_OP_L(GT, GT)
 BINARY_ELEMENT_OP_L(LT, LT)
 BINARY_ELEMENT_OP_L(GE, GE)
@@ -92,6 +94,7 @@ BINARY_ELEMENT_OP_R(Add, Add)
 BINARY_ELEMENT_OP_R(Sub, Sub)
 BINARY_ELEMENT_OP_R(Mul, Mul)
 BINARY_ELEMENT_OP_R(Div, Div)
+BINARY_ELEMENT_OP_R(Mod, Mod)
 BINARY_ELEMENT_OP_R(GT, GT)
 BINARY_ELEMENT_OP_R(LT, LT)
 BINARY_ELEMENT_OP_R(GE, GE)
@@ -117,6 +120,9 @@ NDArray operator * (const NDArray& lhs, const NDArray& rhs) {
 NDArray operator / (const NDArray& lhs, const NDArray& rhs) {
   return dgl::aten::Div(lhs, rhs);
 }
+NDArray operator % (const NDArray& lhs, const NDArray& rhs) {
+  return dgl::aten::Mod(lhs, rhs);
+}
 NDArray operator + (const NDArray& lhs, int64_t rhs) {
   return dgl::aten::Add(lhs, rhs);
 }
@@ -129,6 +135,9 @@ NDArray operator * (const NDArray& lhs, int64_t rhs) {
 NDArray operator / (const NDArray& lhs, int64_t rhs) {
   return dgl::aten::Div(lhs, rhs);
 }
+NDArray operator % (const NDArray& lhs, int64_t rhs) {
+  return dgl::aten::Mod(lhs, rhs);
+}
 NDArray operator + (int64_t lhs, const NDArray& rhs) {
   return dgl::aten::Add(lhs, rhs);
 }
@@ -141,6 +150,9 @@ NDArray operator * (int64_t lhs, const NDArray& rhs) {
 NDArray operator / (int64_t lhs, const NDArray& rhs) {
   return dgl::aten::Div(lhs, rhs);
 }
+NDArray operator % (int64_t lhs, const NDArray& rhs) {
+  return dgl::aten::Mod(lhs, rhs);
+}
 NDArray operator - (const NDArray& array) {
   return dgl::aten::Neg(array);
 }
diff --git a/src/array/array_op.h b/src/array/array_op.h
index cdaab999424a..ae02714e361b 100644
--- a/src/array/array_op.h
+++ b/src/array/array_op.h
@@ -46,6 +46,9 @@ DType IndexSelect(NDArray array, int64_t index);
 template <DLDeviceType XPU, typename IdType>
 IdArray NonZero(BoolArray bool_arr);
 
+template <DLDeviceType XPU, typename IdType>
+std::pair<IdArray, IdArray> Sort(IdArray array);
+
 template <DLDeviceType XPU, typename DType>
 NDArray Scatter(NDArray array, IdArray indices);
diff --git a/src/array/cpu/array_op_impl.cc b/src/array/cpu/array_op_impl.cc
index fedf68b5bb6d..37d806c10d71 100644
--- a/src/array/cpu/array_op_impl.cc
+++ b/src/array/cpu/array_op_impl.cc
@@ -60,6 +60,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
@@ -70,6 +71,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
@@ -94,6 +96,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
@@ -104,6 +107,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
@@ -128,6 +132,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
@@ -138,6 +143,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
diff --git a/src/array/cpu/array_sort.cc b/src/array/cpu/array_sort.cc
new file mode 100644
index 000000000000..4b43596996c4
--- /dev/null
+++ b/src/array/cpu/array_sort.cc
@@ -0,0 +1,184 @@
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file array/cpu/array_sort.cc
+ * \brief Array sort CPU implementation
+ */
+#include <dgl/array.h>
+#ifdef PARALLEL_ALGORITHMS
+#include <parallel/algorithm>
+#endif
+#include <algorithm>
+#include <iterator>
+
+namespace {
+
+template <typename V1, typename V2>
+struct PairRef {
+  PairRef() = delete;
+  PairRef(const PairRef& other) = default;
+  PairRef(PairRef&& other) = default;
+  PairRef(V1 *const r, V2 *const c)
+    : row(r), col(c) {}
+
+  PairRef& operator=(const PairRef& other) {
+    *row = *other.row;
+    *col = *other.col;
+    return *this;
+  }
+  PairRef& operator=(const std::pair<V1, V2>& val) {
+    *row = std::get<0>(val);
+    *col = std::get<1>(val);
+    return *this;
+  }
+
+  operator std::pair<V1, V2>() const {
+    return std::make_pair(*row, *col);
+  }
+
+  void Swap(const PairRef& other) const {
+    std::swap(*row, *other.row);
+    std::swap(*col, *other.col);
+  }
+
+  V1 *row;
+  V2 *col;
+};
+
+using std::swap;
+template <typename V1, typename V2>
+void swap(const PairRef<V1, V2>& r1, const PairRef<V1, V2>& r2) {
+  r1.Swap(r2);
+}
+
+template <typename V1, typename V2>
+struct PairIterator : public std::iterator<std::random_access_iterator_tag,
+                                           std::pair<V1, V2>,
+                                           std::ptrdiff_t,
+                                           std::pair<V1, V2>*,
+                                           PairRef<V1, V2>> {
+  PairIterator() = default;
+  PairIterator(const PairIterator& other) = default;
+  PairIterator(PairIterator&& other) = default;
+  PairIterator(V1 *r, V2 *c): row(r), col(c) {}
+
+  PairIterator& operator=(const PairIterator& other) = default;
+  PairIterator& operator=(PairIterator&& other) = default;
+  ~PairIterator() = default;
+
+  bool operator==(const PairIterator& other) const {
+    return row == other.row;
+  }
+
+  bool operator!=(const PairIterator& other) const {
+    return row != other.row;
+  }
+
+  bool operator<(const PairIterator& other) const {
+    return row < other.row;
+  }
+
+  bool operator>(const PairIterator& other) const {
+    return row > other.row;
+  }
+
+  bool operator<=(const PairIterator& other) const {
+    return row <= other.row;
+  }
+
+  bool operator>=(const PairIterator& other) const {
+    return row >= other.row;
+  }
+
+  PairIterator& operator+=(const std::ptrdiff_t& movement) {
+    row += movement;
+    col += movement;
+    return *this;
+  }
+
+  PairIterator& operator-=(const std::ptrdiff_t& movement) {
+    row -= movement;
+    col -= movement;
+    return *this;
+  }
+
+  PairIterator& operator++() {
+    return operator+=(1);
+  }
+
+  PairIterator& operator--() {
+    return operator-=(1);
+  }
+
+  PairIterator operator++(int) {
+    PairIterator ret(*this);
+    operator++();
+    return ret;
+  }
+
+  PairIterator operator--(int) {
+    PairIterator ret(*this);
+    operator--();
+    return ret;
+  }
+
+  PairIterator operator+(const std::ptrdiff_t& movement) const {
+    PairIterator ret(*this);
+    ret += movement;
+    return ret;
+  }
+
+  PairIterator operator-(const std::ptrdiff_t& movement) const {
+    PairIterator ret(*this);
+    ret -= movement;
+    return ret;
+  }
+
+  std::ptrdiff_t operator-(const PairIterator& other) const {
+    return row - other.row;
+  }
+
+  PairRef<V1, V2> operator*() const {
+    return PairRef<V1, V2>(row, col);
+  }
+  PairRef<V1, V2> operator*() {
+    return PairRef<V1, V2>(row, col);
+  }
+
+  V1 *row;
+  V2 *col;
+};
+
+}  // namespace
+
+namespace dgl {
+using runtime::NDArray;
+namespace aten {
+namespace impl {
+
+template <DLDeviceType XPU, typename IdType>
+std::pair<IdArray, IdArray> Sort(IdArray array) {
+  const int64_t nitem = array->shape[0];
+  IdArray val = array.Clone();
+  IdArray idx = aten::Range(0, nitem, 64, array->ctx);
+  IdType* val_data = val.Ptr<IdType>();
+  int64_t* idx_data = idx.Ptr<int64_t>();
+  typedef std::pair<IdType, int64_t> Pair;
+#ifdef PARALLEL_ALGORITHMS
+  __gnu_parallel::sort(
+#else
+  std::sort(
+#endif
+    PairIterator<IdType, int64_t>(val_data, idx_data),
+    PairIterator<IdType, int64_t>(val_data, idx_data) + nitem,
+    [] (const Pair& a, const Pair& b) {
+      return std::get<0>(a) < std::get<0>(b);
+    });
+  return std::make_pair(val, idx);
+}
+
+template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray);
+template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray);
+
+}  // namespace impl
+}  // namespace aten
+}  // namespace dgl
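`PairIterator` acts as a zip iterator: `std::sort` permutes the value buffer and the index buffer in lockstep without materializing an array of pairs. The same result can be obtained, at the cost of an extra gather pass, by arg-sorting an index array first; a self-contained reference version for comparison (standalone sketch, not part of the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Reference arg-sort: returns (sorted values, original positions).
template <typename IdType>
std::pair<std::vector<IdType>, std::vector<int64_t>>
SortWithIndex(const std::vector<IdType>& vals) {
  std::vector<int64_t> idx(vals.size());
  std::iota(idx.begin(), idx.end(), 0);  // 0, 1, 2, ...
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int64_t a, int64_t b) { return vals[a] < vals[b]; });
  std::vector<IdType> sorted(vals.size());
  for (size_t i = 0; i < idx.size(); ++i)  // gather pass
    sorted[i] = vals[idx[i]];
  return {std::move(sorted), std::move(idx)};
}
```

The zip-iterator version trades this extra allocation and gather for the custom iterator boilerplate, and it also lets `__gnu_parallel::sort` be dropped in unchanged.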
diff --git a/src/array/cuda/array_op_impl.cu b/src/array/cuda/array_op_impl.cu
index a34ea46c31e6..4a378d6bcb5c 100644
--- a/src/array/cuda/array_op_impl.cu
+++ b/src/array/cuda/array_op_impl.cu
@@ -45,6 +45,7 @@ template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
@@ -55,6 +56,7 @@ template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
@@ -92,6 +94,7 @@ template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
+template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
@@ -102,6 +105,7 @@ template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
+template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
@@ -140,6 +144,7 @@ template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
@@ -150,6 +155,7 @@ template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
+template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
 template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
diff --git a/src/array/cuda/array_sort.cu b/src/array/cuda/array_sort.cu
new file mode 100644
index 000000000000..be3a9939b7e5
--- /dev/null
+++ b/src/array/cuda/array_sort.cu
@@ -0,0 +1,50 @@
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file array/cuda/array_sort.cu
+ * \brief Array sort GPU implementation
+ */
+#include <dgl/array.h>
+#include <cub/cub.cuh>
+#include "../../runtime/cuda/cuda_common.h"
+#include "./utils.h"
+
+namespace dgl {
+using runtime::NDArray;
+namespace aten {
+namespace impl {
+
+template <DLDeviceType XPU, typename IdType>
+std::pair<IdArray, IdArray> Sort(IdArray array) {
+  const auto& ctx = array->ctx;
+  auto device = runtime::DeviceAPI::Get(ctx);
+  const int64_t nitems = array->shape[0];
+  IdArray orig_idx = Range(0, nitems, 64, ctx);
+  IdArray sorted_array = NewIdArray(nitems, ctx, array->dtype.bits);
+  IdArray sorted_idx = NewIdArray(nitems, ctx, 64);
+
+  const IdType* keys_in = array.Ptr<IdType>();
+  const int64_t* values_in = orig_idx.Ptr<int64_t>();
+  IdType* keys_out = sorted_array.Ptr<IdType>();
+  int64_t* values_out = sorted_idx.Ptr<int64_t>();
+
+  // Allocate workspace
+  size_t workspace_size = 0;
+  cub::DeviceRadixSort::SortPairs(nullptr, workspace_size,
+      keys_in, keys_out, values_in, values_out, nitems);
+  void* workspace = device->AllocWorkspace(ctx, workspace_size);
+
+  // Compute
+  cub::DeviceRadixSort::SortPairs(workspace, workspace_size,
+      keys_in, keys_out, values_in, values_out, nitems);
+
+  device->FreeWorkspace(ctx, workspace);
+
+  return std::make_pair(sorted_array, sorted_idx);
+}
+
+template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray);
+template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray);
+
+}  // namespace impl
+}  // namespace aten
+}  // namespace dgl
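Every CUB call in this patch follows the same two-phase protocol: a first call with a null temp-storage pointer only writes the required `workspace_size`, and the second identical call does the work. A minimal standalone illustration of the pattern (raw `cudaMalloc` instead of DGL's workspace pool; error handling omitted):

```cpp
#include <cub/cub.cuh>

// Sorts (key, value) pairs of length n that already live on the device.
void RadixSortPairs(const int64_t* keys_in, int64_t* keys_out,
                    const int64_t* vals_in, int64_t* vals_out, int n) {
  void* temp = nullptr;
  size_t temp_bytes = 0;
  // Phase 1: null temp storage => only query the required workspace size.
  cub::DeviceRadixSort::SortPairs(temp, temp_bytes,
                                  keys_in, keys_out, vals_in, vals_out, n);
  cudaMalloc(&temp, temp_bytes);
  // Phase 2: the same call with real storage performs the sort.
  cub::DeviceRadixSort::SortPairs(temp, temp_bytes,
                                  keys_in, keys_out, vals_in, vals_out, n);
  cudaFree(temp);
}
```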
diff --git a/src/array/cuda/coo2csr.cc b/src/array/cuda/coo2csr.cc
deleted file mode 100644
index 5e3a1f5129bd..000000000000
--- a/src/array/cuda/coo2csr.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-/*!
- *  Copyright (c) 2020 by Contributors
- * \file array/cuda/coo2csr.cc
- * \brief COO2CSR
- */
-#include <dgl/array.h>
-#include "../../runtime/cuda/cuda_common.h"
-
-namespace dgl {
-
-using runtime::NDArray;
-
-namespace aten {
-namespace impl {
-
-template <DLDeviceType XPU, typename IdType>
-CSRMatrix COOToCSR(COOMatrix coo) {
-  CHECK(sizeof(IdType) == 4) << "CUDA COOToCSR does not support int64.";
-  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
-  // allocate cusparse handle if needed
-  if (!thr_entry->cusparse_handle) {
-    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
-  }
-  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
-
-  bool row_sorted = coo.row_sorted;
-  bool col_sorted = coo.col_sorted;
-  if (!row_sorted) {
-    // It is possible that the flag is simply not set (default value is false),
-    // so we still perform a linear scan to check the flag.
-    std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
-  }
-  if (!row_sorted) {
-    coo = COOSort(coo);
-  }
-
-  const int64_t nnz = coo.row->shape[0];
-  // TODO(minjie): Many of our current implementation assumes that CSR must have
-  //   a data array. This is a temporary workaround. Remove this after:
-  //   - The old immutable graph implementation is deprecated.
-  //   - The old binary reduce kernel is deprecated.
-  if (!COOHasData(coo))
-    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);
-
-  NDArray indptr = aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits);
-  int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
-  CUSPARSE_CALL(cusparseXcoo2csr(
-      thr_entry->cusparse_handle,
-      coo.row.Ptr<int32_t>(),
-      nnz,
-      coo.num_rows,
-      indptr_ptr,
-      CUSPARSE_INDEX_BASE_ZERO));
-
-  return CSRMatrix(coo.num_rows, coo.num_cols,
-                   indptr, coo.col, coo.data, col_sorted);
-}
-
-template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
-template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
-
-
-}  // namespace impl
-}  // namespace aten
-}  // namespace dgl
diff --git a/src/array/cuda/coo2csr.cu b/src/array/cuda/coo2csr.cu
new file mode 100644
index 000000000000..55890d36f7a4
--- /dev/null
+++ b/src/array/cuda/coo2csr.cu
@@ -0,0 +1,148 @@
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file array/cuda/coo2csr.cu
+ * \brief COO2CSR
+ */
+#include <dgl/array.h>
+#include "../../runtime/cuda/cuda_common.h"
+#include "./utils.h"
+
+namespace dgl {
+
+using runtime::NDArray;
+
+namespace aten {
+namespace impl {
+
+template <DLDeviceType XPU, typename IdType>
+CSRMatrix COOToCSR(COOMatrix coo) {
+  LOG(FATAL) << "Unreachable code.";
+  return {};
+}
+
+template <>
+CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
+  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  // allocate cusparse handle if needed
+  if (!thr_entry->cusparse_handle) {
+    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
+  }
+  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
+
+  bool row_sorted = coo.row_sorted;
+  bool col_sorted = coo.col_sorted;
+  if (!row_sorted) {
+    // It is possible that the flag is simply not set (default value is false),
+    // so we still perform a linear scan to check the flag.
+    std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
+  }
+  if (!row_sorted) {
+    coo = COOSort(coo);
+    col_sorted = coo.col_sorted;
+  }
+
+  const int64_t nnz = coo.row->shape[0];
+  // TODO(minjie): Many of our current implementation assumes that CSR must have
+  //   a data array. This is a temporary workaround. Remove this after:
+  //   - The old immutable graph implementation is deprecated.
+  //   - The old binary reduce kernel is deprecated.
+  if (!COOHasData(coo))
+    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);
+
+  NDArray indptr = aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits);
+  int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
+  CUSPARSE_CALL(cusparseXcoo2csr(
+      thr_entry->cusparse_handle,
+      coo.row.Ptr<int32_t>(),
+      nnz,
+      coo.num_rows,
+      indptr_ptr,
+      CUSPARSE_INDEX_BASE_ZERO));
+
+  return CSRMatrix(coo.num_rows, coo.num_cols,
+                   indptr, coo.col, coo.data, col_sorted);
+}
+
+/*!
+ * \brief Search for the insertion positions of each needle in the hay.
+ *
+ * The hay is a list of sorted elements and the result is the insertion position
+ * of each needle so that the insertion still gives sorted order.
+ *
+ * It essentially performs a binary search to find the upper bound of each
+ * needle element.
+ *
+ * For example:
+ * hay = [0, 0, 1, 2, 2]
+ * needle = [0, 1, 2, 3]
+ * then,
+ * out = [2, 3, 5, 5]
+ */
+template <typename IdType>
+__global__ void _SortedSearchKernelUpperBound(
+    const IdType* hay, int64_t hay_size,
+    const IdType* needles, int64_t num_needles,
+    IdType* pos) {
+  int tx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int stride_x = gridDim.x * blockDim.x;
+  while (tx < num_needles) {
+    const IdType ele = needles[tx];
+    // binary search
+    IdType lo = 0, hi = hay_size;
+    while (lo < hi) {
+      IdType mid = (lo + hi) >> 1;
+      if (hay[mid] <= ele) {
+        lo = mid + 1;
+      } else {
+        hi = mid;
+      }
+    }
+    pos[tx] = lo;
+    tx += stride_x;
+  }
+}
+
+template <>
+CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
+  const auto& ctx = coo.row->ctx;
+  const auto nbits = coo.row->dtype.bits;
+  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  bool row_sorted = coo.row_sorted;
+  bool col_sorted = coo.col_sorted;
+  if (!row_sorted) {
+    // It is possible that the flag is simply not set (default value is false),
+    // so we still perform a linear scan to check the flag.
+    std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
+  }
+  if (!row_sorted) {
+    coo = COOSort(coo);
+    col_sorted = coo.col_sorted;
+  }
+
+  const int64_t nnz = coo.row->shape[0];
+  // TODO(minjie): Many of our current implementation assumes that CSR must have
+  //   a data array. This is a temporary workaround. Remove this after:
+  //   - The old immutable graph implementation is deprecated.
+  //   - The old binary reduce kernel is deprecated.
+  if (!COOHasData(coo))
+    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);
+
+  IdArray rowids = Range(0, coo.num_rows, nbits, ctx);
+  const int nt = cuda::FindNumThreads(coo.num_rows);
+  const int nb = (coo.num_rows + nt - 1) / nt;
+  IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
+  _SortedSearchKernelUpperBound<<<nb, nt, 0, thr_entry->stream>>>(
+      coo.row.Ptr<int64_t>(), nnz,
+      rowids.Ptr<int64_t>(), coo.num_rows,
+      indptr.Ptr<int64_t>() + 1);
+
+  return CSRMatrix(coo.num_rows, coo.num_cols,
+                   indptr, coo.col, coo.data, col_sorted);
+}
+
+template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
+template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
+
+}  // namespace impl
+}  // namespace aten
+}  // namespace dgl
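Because `coo.row` is sorted at this point, `indptr[r + 1]` is just the upper bound of row ID `r` in the row array, so the prefix offsets come directly from the binary searches and no atomic histogram is needed. A host-side equivalent of what the kernel computes (standalone reference, not part of the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// CSR row pointer from a row-sorted COO row array (host reference).
std::vector<int64_t> RowPtrFromSortedRows(const std::vector<int64_t>& row,
                                          int64_t num_rows) {
  std::vector<int64_t> indptr(num_rows + 1, 0);
  for (int64_t r = 0; r < num_rows; ++r)
    // First position after the last occurrence of r == count of entries in rows 0..r.
    indptr[r + 1] = std::upper_bound(row.begin(), row.end(), r) - row.begin();
  return indptr;  // e.g. row = [0, 0, 1, 2, 2] -> indptr = [0, 2, 3, 5]
}
```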
diff --git a/src/array/cuda/coo_sort.cu b/src/array/cuda/coo_sort.cu
index 8de9c5ed8bc6..2a4573f73cee 100644
--- a/src/array/cuda/coo_sort.cu
+++ b/src/array/cuda/coo_sort.cu
@@ -18,10 +18,14 @@ namespace impl {
 
 template <DLDeviceType XPU, typename IdType>
 void COOSort_(COOMatrix* coo, bool sort_column) {
+  LOG(FATAL) << "Unreachable code.";
+}
+
+template <>
+void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column) {
   // TODO(minjie): Current implementation is based on cusparse which only supports
   //   int32_t. To support int64_t, we could use the Radix sort algorithm provided
   //   by CUB.
-  CHECK(sizeof(IdType) == 4) << "CUDA COOSort does not support int64.";
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(coo->row->ctx);
   // allocate cusparse handle if needed
@@ -63,7 +67,7 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
   if (sort_column) {
     // First create a row indptr array and then call csrsort
     int32_t* indptr = static_cast<int32_t*>(
-        device->AllocWorkspace(row->ctx, (coo->num_rows + 1) * sizeof(IdType)));
+        device->AllocWorkspace(row->ctx, (coo->num_rows + 1) * sizeof(int32_t)));
     CUSPARSE_CALL(cusparseXcoo2csr(
         thr_entry->cusparse_handle,
         row_ptr,
@@ -101,6 +105,20 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
   coo->col_sorted = sort_column;
 }
 
+template <>
+void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column) {
+  // Always sort the COO to be both row and column sorted.
+  IdArray pos = coo->row * coo->num_cols + coo->col;
+  const auto& sorted = Sort(pos);
+  coo->row = sorted.first / coo->num_cols;
+  coo->col = sorted.first % coo->num_cols;
+  if (aten::COOHasData(*coo))
+    coo->data = IndexSelect(coo->data, sorted.second);
+  else
+    coo->data = AsNumBits(sorted.second, coo->row->dtype.bits);
+  coo->row_sorted = coo->col_sorted = true;
+}
+
 template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
 template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
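The int64 path reduces the 2-D sort to a single-key sort by linearizing each coordinate to `row * num_cols + col`; after sorting, row and column are recovered with the new array `/` and `%` operators, which is exactly why `Mod` was added. The trick in scalar form (standalone host sketch; note the key can overflow when `row * num_cols` exceeds the key type's range):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Sort COO coordinates by (row, col) via one linearized key (host reference).
void SortCoordinates(std::vector<int64_t>* row, std::vector<int64_t>* col,
                     int64_t num_cols) {
  std::vector<int64_t> key(row->size());
  for (size_t i = 0; i < key.size(); ++i)
    key[i] = (*row)[i] * num_cols + (*col)[i];  // unique because col < num_cols
  std::sort(key.begin(), key.end());
  for (size_t i = 0; i < key.size(); ++i) {
    (*row)[i] = key[i] / num_cols;  // recover row
    (*col)[i] = key[i] % num_cols;  // recover col
  }
}
```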
diff --git a/src/array/cuda/csr2coo.cc b/src/array/cuda/csr2coo.cu
similarity index 53%
rename from src/array/cuda/csr2coo.cc
rename to src/array/cuda/csr2coo.cu
index 90fdd7f30361..6e2c084523a5 100644
--- a/src/array/cuda/csr2coo.cc
+++ b/src/array/cuda/csr2coo.cu
@@ -5,6 +5,7 @@
  */
 #include <dgl/array.h>
 #include "../../runtime/cuda/cuda_common.h"
+#include "./utils.h"
 
 namespace dgl {
 
@@ -15,7 +16,12 @@ namespace impl {
 
 template <DLDeviceType XPU, typename IdType>
 COOMatrix CSRToCOO(CSRMatrix csr) {
-  CHECK(sizeof(IdType) == 4) << "CUDA CSRToCOO does not support int64.";
+  LOG(FATAL) << "Unreachable code.";
+  return {};
+}
+
+template <>
+COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
@@ -41,12 +47,72 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
                    true, csr.sorted);
 }
 
+/*!
+ * \brief Repeat elements.
+ * \param val Value to repeat
+ * \param repeats Number of repeats for each value
+ * \param pos The position of the output buffer to write the value.
+ * \param out Output buffer.
+ * \param length Number of values
+ *
+ * For example:
+ * val = [3, 0, 1]
+ * repeats = [1, 0, 2]
+ * pos = [0, 1, 1]  # write to output buffer position 0, 1, 1
+ * then,
+ * out = [3, 1, 1]
+ */
+template <typename DType, typename IdType>
+__global__ void _RepeatKernel(
+    const DType* val, const IdType* repeats, const IdType* pos,
+    DType* out, int64_t length) {
+  int tx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int stride_x = gridDim.x * blockDim.x;
+  while (tx < length) {
+    IdType off = pos[tx];
+    const IdType rep = repeats[tx];
+    const DType v = val[tx];
+    for (IdType i = 0; i < rep; ++i) {
+      out[off + i] = v;
+    }
+    tx += stride_x;
+  }
+}
+
+template <>
+COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
+  const auto& ctx = csr.indptr->ctx;
+  const int64_t nnz = csr.indices->shape[0];
+  const auto nbits = csr.indptr->dtype.bits;
+  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  IdArray rowids = Range(0, csr.num_rows, nbits, ctx);
+  IdArray row_nnz = CSRGetRowNNZ(csr, rowids);
+  IdArray ret_row = NewIdArray(nnz, ctx, nbits);
+
+  const int nt = cuda::FindNumThreads(csr.num_rows);
+  const int nb = (csr.num_rows + nt - 1) / nt;
+  _RepeatKernel<<<nb, nt, 0, thr_entry->stream>>>(
+      rowids.Ptr<int64_t>(), row_nnz.Ptr<int64_t>(),
+      csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(),
+      csr.num_rows);
+
+  return COOMatrix(csr.num_rows, csr.num_cols,
+                   ret_row, csr.indices, csr.data,
+                   true, csr.sorted);
+}
+
 template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
 template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
 
 template <DLDeviceType XPU, typename IdType>
 COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
-  COOMatrix coo = CSRToCOO<XPU, IdType>(csr);
+  LOG(FATAL) << "Unreachable code.";
+  return {};
+}
+
+template <>
+COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
+  COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
   if (aten::IsNullArray(coo.data))
     return coo;
 
@@ -85,6 +151,26 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
   // The row and column field have already been reordered according
   // to data, thus the data field will be deprecated.
   coo.data = aten::NullArray();
+  coo.row_sorted = false;
+  coo.col_sorted = false;
+  return coo;
+}
+
+template <>
+COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
+  COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
+  if (aten::IsNullArray(coo.data))
+    return coo;
+  const auto& sorted = Sort(coo.data);
+
+  coo.row = IndexSelect(coo.row, sorted.second);
+  coo.col = IndexSelect(coo.col, sorted.second);
+
+  // The row and column field have already been reordered according
+  // to data, thus the data field will be deprecated.
+  coo.data = aten::NullArray();
+  coo.row_sorted = false;
+  coo.col_sorted = false;
   return coo;
 }
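CSR to COO for int64 is an expansion problem: each row ID `r` is written `indptr[r+1] - indptr[r]` times starting at offset `indptr[r]`, which is how `_RepeatKernel` is invoked here (repeat counts from `CSRGetRowNNZ`, write positions from `indptr`). A host-side equivalent (standalone reference, not part of the patch):

```cpp
#include <cstdint>
#include <vector>

// Expand a CSR row pointer into per-edge row IDs (host reference).
std::vector<int64_t> RowIdsFromIndptr(const std::vector<int64_t>& indptr) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  std::vector<int64_t> row(indptr.back());
  for (int64_t r = 0; r < num_rows; ++r)
    for (int64_t i = indptr[r]; i < indptr[r + 1]; ++i)
      row[i] = r;   // repeat r (indptr[r+1] - indptr[r]) times
  return row;       // e.g. indptr = [0, 2, 3, 5] -> row = [0, 0, 1, 2, 2]
}
```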
diff --git a/src/array/cuda/csr_sort.cu b/src/array/cuda/csr_sort.cu
index 18aea3b65e9e..a2f4172916e2 100644
--- a/src/array/cuda/csr_sort.cu
+++ b/src/array/cuda/csr_sort.cu
@@ -1,9 +1,10 @@
 /*!
  *  Copyright (c) 2020 by Contributors
  * \file array/cuda/csr_sort.cc
- * \brief Sort COO index
+ * \brief Sort CSR index
  */
 #include <dgl/array.h>
+#include <cub/cub.cuh>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./utils.h"
 
@@ -56,7 +57,11 @@ template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
 
 template <DLDeviceType XPU, typename IdType>
 void CSRSort_(CSRMatrix* csr) {
-  CHECK(sizeof(IdType) == 4) << "CUDA CSRSort_ does not support int64.";
+  LOG(FATAL) << "Unreachable code.";
+}
+
+template <>
+void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
   // allocate cusparse handle if needed
@@ -100,6 +105,43 @@ void CSRSort_(CSRMatrix* csr) {
   device->FreeWorkspace(ctx, workspace);
 }
 
+template <>
+void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
+  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
+  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
+
+  const auto& ctx = csr->indptr->ctx;
+  const int64_t nnz = csr->indices->shape[0];
+  const auto nbits = csr->indptr->dtype.bits;
+  if (!aten::CSRHasData(*csr))
+    csr->data = aten::Range(0, nnz, nbits, ctx);
+
+  IdArray new_indices = csr->indices.Clone();
+  IdArray new_data = csr->data.Clone();
+
+  const int64_t* offsets = csr->indptr.Ptr<int64_t>();
+  const int64_t* key_in = csr->indices.Ptr<int64_t>();
+  int64_t* key_out = new_indices.Ptr<int64_t>();
+  const int64_t* value_in = csr->data.Ptr<int64_t>();
+  int64_t* value_out = new_data.Ptr<int64_t>();
+
+  // Allocate workspace
+  size_t workspace_size = 0;
+  cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
+      key_in, key_out, value_in, value_out,
+      nnz, csr->num_rows, offsets, offsets + 1);
+  void* workspace = device->AllocWorkspace(ctx, workspace_size);
+
+  // Compute
+  cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
+      key_in, key_out, value_in, value_out,
+      nnz, csr->num_rows, offsets, offsets + 1);
+
+  device->FreeWorkspace(ctx, workspace);
+
+  csr->sorted = true;
+  csr->indices = new_indices;
+  csr->data = new_data;
+}
+
 template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
 template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
diff --git a/src/array/cuda/csr_transpose.cc b/src/array/cuda/csr_transpose.cc
index 9045bd6a6790..ab23207b615d 100644
--- a/src/array/cuda/csr_transpose.cc
+++ b/src/array/cuda/csr_transpose.cc
@@ -15,7 +15,12 @@ namespace impl {
 
 template <DLDeviceType XPU, typename IdType>
 CSRMatrix CSRTranspose(CSRMatrix csr) {
-  CHECK(sizeof(IdType) == 4) << "CUDA CSR2CSC does not support int64.";
+  LOG(FATAL) << "Unreachable code.";
+  return {};
+}
+
+template <>
+CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   // allocate cusparse handle if needed
   if (!thr_entry->cusparse_handle) {
@@ -82,6 +87,11 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
                    false);
 }
 
+template <>
+CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) {
+  return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
+}
+
 template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
 template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
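`cub::DeviceSegmentedRadixSort` sorts every row's `[indptr[i], indptr[i+1])` slice independently in one launch, carrying `data` along as values so edge IDs stay aligned with their column indices. What it computes, row by row, expressed on the host (standalone reference only):

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Per-row reference of the segmented key/value sort used by CSRSort_.
void SortCSRRows(const std::vector<int64_t>& indptr,
                 std::vector<int64_t>* indices, std::vector<int64_t>* data) {
  for (size_t r = 0; r + 1 < indptr.size(); ++r) {
    const int64_t beg = indptr[r], end = indptr[r + 1];
    std::vector<int64_t> perm(end - beg);
    std::iota(perm.begin(), perm.end(), beg);       // positions within the row
    std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
      return (*indices)[a] < (*indices)[b];         // keys: column indices
    });
    std::vector<int64_t> ind(end - beg), dat(end - beg);
    for (int64_t i = 0; i < end - beg; ++i) {
      ind[i] = (*indices)[perm[i]];
      dat[i] = (*data)[perm[i]];                    // values follow their keys
    }
    std::copy(ind.begin(), ind.end(), indices->begin() + beg);
    std::copy(dat.begin(), dat.end(), data->begin() + beg);
  }
}
```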
diff --git a/src/array/cuda/spmat_op_impl_csr.cu b/src/array/cuda/spmat_op_impl_csr.cu
index fb281bfa7525..9370b12294d7 100644
--- a/src/array/cuda/spmat_op_impl_csr.cu
+++ b/src/array/cuda/spmat_op_impl_csr.cu
@@ -375,7 +375,9 @@ __global__ void _SegmentMaskKernel(
  * of each needle so that the insertion still gives sorted order.
  *
  * It essentially perform binary search to find lower bound for each needle
- * elements.
+ * elements. Requires the largest element in the hay to be larger than the
+ * given needle elements. Commonly used in searching for row IDs of a given
+ * set of coordinates.
  */
 template <typename IdType>
 __global__ void _SortedSearchKernel(
@@ -435,7 +437,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col) {
   IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits);
   const int nt2 = cuda::FindNumThreads(idx->shape[0]);
   const int nb2 = (idx->shape[0] + nt - 1) / nt;
-  _SortedSearchKernel<<<nb, nt, 0, thr_entry->stream>>>(
+  _SortedSearchKernel<<<nb2, nt2, 0, thr_entry->stream>>>(
       csr.indptr.Ptr<IdType>(), csr.num_rows,
       idx.Ptr<IdType>(), idx->shape[0],
       ret_row.Ptr<IdType>());
diff --git a/src/graph/heterograph.h b/src/graph/heterograph.h
index b5c0ceb9637c..02657c807282 100644
--- a/src/graph/heterograph.h
+++ b/src/graph/heterograph.h
@@ -15,6 +15,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <utility>
 #include "./unit_graph.h"
 #include "shared_mem_manager.h"
 
diff --git a/tests/compute/utils.py b/tests/compute/utils.py
index 3e96429b0e12..c685b0ee8149 100644
--- a/tests/compute/utils.py
+++ b/tests/compute/utils.py
@@ -5,7 +5,6 @@
     parametrize_dtype = pytest.mark.parametrize("idtype", [F.int32, F.int64])
 else:
-    # only test int32 on GPU because many graph operators are not supported for int64.
-    parametrize_dtype = pytest.mark.parametrize("idtype", [F.int32])
+    parametrize_dtype = pytest.mark.parametrize("idtype", [F.int32, F.int64])
 
 def check_fail(fn, *args, **kwargs):
     try:
diff --git a/tests/cpp/test_aten.cc b/tests/cpp/test_aten.cc
index 4d5e7391ee82..c942cfa88cdd 100644
--- a/tests/cpp/test_aten.cc
+++ b/tests/cpp/test_aten.cc
@@ -111,6 +111,10 @@ void _TestArith(DLContext ctx) {
   c = c.CopyTo(CPU);
   for (int i = 0; i < N; ++i)
     ASSERT_EQ(Ptr<IDX>(c)[i], 10);
+  c = (-a) % b;
+  c = c.CopyTo(CPU);
+  for (int i = 0; i < N; ++i)
+    ASSERT_EQ(Ptr<IDX>(c)[i], 3);
 
   const int val = -3;
   c = aten::Add(a, val);
@@ -129,6 +133,11 @@ void _TestArith(DLContext ctx) {
   c = c.CopyTo(CPU);
   for (int i = 0; i < N; ++i)
     ASSERT_EQ(Ptr<IDX>(c)[i], 3);
+  c = b % 3;
+  c = c.CopyTo(CPU);
+  for (int i = 0; i < N; ++i)
+    ASSERT_EQ(Ptr<IDX>(c)[i], 1);
+
   c = aten::Add(val, b);
   c = c.CopyTo(CPU);
   for (int i = 0; i < N; ++i)
@@ -145,6 +154,10 @@ void _TestArith(DLContext ctx) {
   c = c.CopyTo(CPU);
   for (int i = 0; i < N; ++i)
     ASSERT_EQ(Ptr<IDX>(c)[i], 0);
+  c = 3 % b;
+  c = c.CopyTo(CPU);
+  for (int i = 0; i < N; ++i)
+    ASSERT_EQ(Ptr<IDX>(c)[i], 3);
 
   a = aten::Range(0, N, sizeof(IDX)*8, ctx);
   c = a < 50;
@@ -179,7 +192,7 @@ void _TestArith(DLContext ctx) {
 }
 
-TEST(ArrayTest, TestArith) {
+TEST(ArrayTest, Arith) {
   _TestArith<int32_t>(CPU);
   _TestArith<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
@@ -1327,17 +1340,17 @@ void _TestLineGraphCOO(DLContext ctx) {
    * [0, 0, 0, 0, 0, 0]]
    */
   IdArray a_row =
-    aten::VecToIdArray(std::vector<IdType>({0, 1, 1, 2, 2, 3}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({0, 1, 1, 2, 2, 3}), sizeof(IdType)*8, ctx);
   IdArray a_col =
-    aten::VecToIdArray(std::vector<IdType>({2, 0, 2, 0, 1, 3}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({2, 0, 2, 0, 1, 3}), sizeof(IdType)*8, ctx);
   IdArray b_row =
-    aten::VecToIdArray(std::vector<IdType>({0, 1, 2, 4}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({0, 1, 2, 4}), sizeof(IdType)*8, ctx);
   IdArray b_col =
-    aten::VecToIdArray(std::vector<IdType>({4, 0, 3, 1}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({4, 0, 3, 1}), sizeof(IdType)*8, ctx);
   IdArray c_row =
-    aten::VecToIdArray(std::vector<IdType>({0, 0, 1, 2, 2, 3, 4, 4}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({0, 0, 1, 2, 2, 3, 4, 4}), sizeof(IdType)*8, ctx);
   IdArray c_col =
-    aten::VecToIdArray(std::vector<IdType>({3, 4, 0, 3, 4, 0, 1, 2}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({3, 4, 0, 3, 4, 0, 1, 2}), sizeof(IdType)*8, ctx);
   const aten::COOMatrix &coo_a =
     aten::COOMatrix(
       4,
@@ -1365,15 +1378,15 @@ void _TestLineGraphCOO(DLContext ctx) {
   ASSERT_FALSE(l_coo2.col_sorted);
 
   IdArray a_data =
-    aten::VecToIdArray(std::vector<IdType>({4, 5, 0, 1, 2, 3}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({4, 5, 0, 1, 2, 3}), sizeof(IdType)*8, ctx);
   b_row =
-    aten::VecToIdArray(std::vector<IdType>({4, 5, 0, 2}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({4, 5, 0, 2}), sizeof(IdType)*8, ctx);
   b_col =
-    aten::VecToIdArray(std::vector<IdType>({2, 4, 1, 5}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({2, 4, 1, 5}), sizeof(IdType)*8, ctx);
   c_row =
-    aten::VecToIdArray(std::vector<IdType>({4, 4, 5, 0, 0, 1, 2, 2}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({4, 4, 5, 0, 0, 1, 2, 2}), sizeof(IdType)*8, ctx);
   c_col =
-    aten::VecToIdArray(std::vector<IdType>({1, 2, 4, 1, 2, 4, 5, 0}), sizeof(IdType)*8, CTX);
+    aten::VecToIdArray(std::vector<IdType>({1, 2, 4, 1, 2, 4, 5, 0}), sizeof(IdType)*8, ctx);
   const aten::COOMatrix &coo_ad =
     aten::COOMatrix(
       4, 4,
@@ -1403,3 +1416,44 @@ TEST(LineGraphTest, LineGraphCOO) {
   _TestLineGraphCOO<int32_t>(CPU);
   _TestLineGraphCOO<int64_t>(CPU);
 }
+
+template <typename IDX>
+void _TestSort(DLContext ctx) {
+  // case 1
+  IdArray a =
+    aten::VecToIdArray(std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}), sizeof(IDX)*8, ctx);
+  IdArray sorted_a =
+    aten::VecToIdArray(std::vector<IDX>({0, 3, 5, 6, 7, 8, 9}), sizeof(IDX)*8, ctx);
+  IdArray sorted_idx =
+    aten::VecToIdArray(std::vector<int64_t>({5, 4, 3, 1, 2, 0, 6}), 64, ctx);
+
+  IdArray sorted, idx;
+  std::tie(sorted, idx) = aten::Sort(a);
+  ASSERT_TRUE(ArrayEQ<IDX>(sorted, sorted_a));
+  ASSERT_TRUE(ArrayEQ<int64_t>(idx, sorted_idx));
+
+  // case 2: empty array
+  a = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
+  sorted_a = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
+  sorted_idx = aten::VecToIdArray(std::vector<int64_t>({}), 64, ctx);
+  std::tie(sorted, idx) = aten::Sort(a);
+  ASSERT_TRUE(ArrayEQ<IDX>(sorted, sorted_a));
+  ASSERT_TRUE(ArrayEQ<int64_t>(idx, sorted_idx));
+
+  // case 3: array with one element
+  a = aten::VecToIdArray(std::vector<IDX>({2}), sizeof(IDX)*8, ctx);
+  sorted_a = aten::VecToIdArray(std::vector<IDX>({2}), sizeof(IDX)*8, ctx);
+  sorted_idx = aten::VecToIdArray(std::vector<int64_t>({0}), 64, ctx);
+  std::tie(sorted, idx) = aten::Sort(a);
+  ASSERT_TRUE(ArrayEQ<IDX>(sorted, sorted_a));
+  ASSERT_TRUE(ArrayEQ<int64_t>(idx, sorted_idx));
+}
+
+TEST(ArrayTest, Sort) {
+  _TestSort<int32_t>(CPU);
+  _TestSort<int64_t>(CPU);
+#ifdef DGL_USE_CUDA
+  _TestSort<int32_t>(GPU);
+  _TestSort<int64_t>(GPU);
+#endif
+}
diff --git a/tests/cpp/test_spmat_coo.cc b/tests/cpp/test_spmat_coo.cc
index 45684840a5d6..402fed6a2cd4 100644
--- a/tests/cpp/test_spmat_coo.cc
+++ b/tests/cpp/test_spmat_coo.cc
@@ -180,6 +180,7 @@ TEST(SpmatTest, COOToCSR) {
   _TestCOOToCSR<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
   _TestCOOToCSR<int32_t>(GPU);
+  _TestCOOToCSR<int64_t>(GPU);
 #endif
 }
 
@@ -265,6 +266,7 @@ TEST(SpmatTest, COOSort) {
   _TestCOOSort<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
   _TestCOOSort<int32_t>(GPU);
+  _TestCOOSort<int64_t>(GPU);
 #endif
 }
diff --git a/tests/cpp/test_spmat_csr.cc b/tests/cpp/test_spmat_csr.cc
index e4259e93929b..467aa36c433f 100644
--- a/tests/cpp/test_spmat_csr.cc
+++ b/tests/cpp/test_spmat_csr.cc
@@ -241,6 +241,7 @@ TEST(SpmatTest, CSRGetData) {
   _TestCSRGetData<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
   _TestCSRGetData<int32_t>(GPU);
+  _TestCSRGetData<int64_t>(GPU);
 #endif
 }
 
@@ -287,11 +288,12 @@ void _TestCSRTranspose(DLContext ctx) {
   ASSERT_TRUE(ArrayEQ<IdType>(csr_t.data, td));
 }
 
-TEST(SpmatTest, TestCSRTranspose) {
+TEST(SpmatTest, CSRTranspose) {
   _TestCSRTranspose<int32_t>(CPU);
   _TestCSRTranspose<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
   _TestCSRTranspose<int32_t>(GPU);
+  _TestCSRTranspose<int64_t>(GPU);
 #endif
 }
 
@@ -335,6 +337,7 @@ TEST(SpmatTest, CSRToCOO) {
   _TestCSRToCOO<int64_t>(CPU);
 #if DGL_USE_CUDA
   _TestCSRToCOO<int32_t>(GPU);
+  _TestCSRToCOO<int64_t>(GPU);
 #endif
 }
 
@@ -441,6 +444,7 @@ TEST(SpmatTest, CSRSliceMatrix) {
   _TestCSRSliceMatrix<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
   _TestCSRSliceMatrix<int32_t>(GPU);
+  _TestCSRSliceMatrix<int64_t>(GPU);
 #endif
 }
 
@@ -457,6 +461,7 @@ TEST(SpmatTest, CSRHasDuplicate) {
   _TestCSRHasDuplicate<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
   _TestCSRHasDuplicate<int32_t>(GPU);
+  _TestCSRHasDuplicate<int64_t>(GPU);
 #endif
 }
 
@@ -480,6 +485,7 @@ TEST(SpmatTest, CSRSort) {
   _TestCSRSort<int64_t>(CPU);
 #ifdef DGL_USE_CUDA
   _TestCSRSort<int32_t>(GPU);
+  _TestCSRSort<int64_t>(GPU);
 #endif
 }
diff --git a/tests/tensorflow/test_nn.py b/tests/tensorflow/test_nn.py
index 430cb8d2e672..ae4455ecfe80 100644
--- a/tests/tensorflow/test_nn.py
+++ b/tests/tensorflow/test_nn.py
@@ -247,7 +247,7 @@ def test_partial_edge_softmax():
     grad = F.randn((300, 1))
     import numpy as np
     eids = np.random.choice(900, 300, replace=False).astype('int64')
-    eids = F.zerocopy_from_numpy(eids)
+    eids = F.tensor(eids)
     # compute partial edge softmax
     with tf.GradientTape() as tape:
         tape.watch(score)
@@ -255,7 +255,7 @@ def test_partial_edge_softmax():
     grads = tape.gradient(y_1, [score])
     grad_1 = grads[0]
     # compute edge softmax on edge subgraph
-    subg = g.edge_subgraph(eids)
+    subg = g.edge_subgraph(eids, preserve_nodes=True)
     with tf.GradientTape() as tape:
         tape.watch(score)
         y_2 = nn.edge_softmax(subg, score)
@@ -348,8 +348,8 @@ def test_rgcn():
     rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
     rgc_basis_low.weight = rgc_basis.weight
     rgc_basis_low.w_comp = rgc_basis.w_comp
-    h = tf.constant(np.random.randint(0, I, (100,)))
-    r = tf.constant(etype)
+    h = tf.constant(np.random.randint(0, I, (100,))) * 1
+    r = tf.constant(etype) * 1
     h_new = rgc_basis(g, h, r)
     h_new_low = rgc_basis_low(g, h, r)
     assert list(h_new.shape) == [100, O]
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index deb4c5187e20..3780fd721899 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -5,7 +5,6 @@
     parametrize_dtype = pytest.mark.parametrize("idtype", [F.int32, F.int64])
 else:
-    # only test int32 on GPU because many graph operators are not supported for int64.
-    parametrize_dtype = pytest.mark.parametrize("idtype", [F.int32])
+    parametrize_dtype = pytest.mark.parametrize("idtype", [F.int32, F.int64])
 
 from .checks import *
 from .graph_cases import get_cases
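One caveat worth noting about the new `Mod` functor: C++'s `%` truncates toward zero, so its result for a negative dividend differs from Python's floor modulo. The test fixtures above appear to use `a` filled with a negative value and `b` with 7, and deliberately take `(-a) % b` so the dividend is non-negative. A quick standalone illustration of the difference (not part of the patch):

```cpp
#include <cstdio>

int main() {
  // C++ '%' truncates toward zero; Python's '%' floors.
  std::printf("%d\n", -10 % 7);        // -3 in C++ (Python: -10 % 7 == 4)
  std::printf("%d\n", 10 % 7);         // 3, same in both languages
  // Floor-modulo emulation when Python-compatible semantics are needed:
  auto pymod = [](int a, int b) { return ((a % b) + b) % b; };
  std::printf("%d\n", pymod(-10, 7));  // 4
  return 0;
}
```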