Skip to content

Commit

Permalink
[CUDA][Kernel] A bunch of int64 kernels for COO and CSR (dmlc#1883)
Browse files Browse the repository at this point in the history
* COO sort

* COOToCSR

* CSR2COO

* CSRSort; CSRTranspose

* pass all CSR tests

* lint

* remove int32 conversion

* fix tensorflow nn tests

* turn on CI

* fix

* address comments
  • Loading branch information
jermainewang authored Jul 30, 2020
1 parent 5b515cf commit f4608c2
Show file tree
Hide file tree
Showing 29 changed files with 703 additions and 121 deletions.
4 changes: 1 addition & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,7 @@ pipeline {
stages {
stage("Unit test") {
steps {
// TODO(minjie): tmp disabled
//unit_test_linux("tensorflow", "gpu")
sh "echo skipped"
unit_test_linux("tensorflow", "gpu")
}
}
}
Expand Down
14 changes: 14 additions & 0 deletions include/dgl/aten/array_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,19 @@ IdArray Add(IdArray lhs, IdArray rhs);
IdArray Sub(IdArray lhs, IdArray rhs);
IdArray Mul(IdArray lhs, IdArray rhs);
IdArray Div(IdArray lhs, IdArray rhs);
IdArray Mod(IdArray lhs, IdArray rhs);

IdArray Add(IdArray lhs, int64_t rhs);
IdArray Sub(IdArray lhs, int64_t rhs);
IdArray Mul(IdArray lhs, int64_t rhs);
IdArray Div(IdArray lhs, int64_t rhs);
IdArray Mod(IdArray lhs, int64_t rhs);

IdArray Add(int64_t lhs, IdArray rhs);
IdArray Sub(int64_t lhs, IdArray rhs);
IdArray Mul(int64_t lhs, IdArray rhs);
IdArray Div(int64_t lhs, IdArray rhs);
IdArray Mod(int64_t lhs, IdArray rhs);

IdArray Neg(IdArray array);

Expand Down Expand Up @@ -304,6 +307,17 @@ IdArray CumSum(IdArray array, bool prepend_zero = false);
*/
IdArray NonZero(NDArray array);

/*!
* \brief Sort the ID vector in ascending order.
*
* It performs both sort and arg_sort (returning the sorted index). The sorted index
* is always in int64.
*
* \param array Input array.
* \return A pair of arrays: sorted values and sorted index to the original position.
*/
std::pair<IdArray, IdArray> Sort(IdArray array);

/*!
* \brief Return a string that prints out some debug information.
*/
Expand Down
4 changes: 4 additions & 0 deletions include/dgl/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,14 +603,18 @@ dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator + (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator + (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator * (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator % (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& array);

dgl::runtime::NDArray operator > (const dgl::runtime::NDArray& a1,
Expand Down
10 changes: 2 additions & 8 deletions python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,7 @@ def graph(data,
g = create_from_edges(u, v, ntype, etype, ntype, urange, vrange,
validate, formats=formats)

if device is None:
return utils.to_int32_graph_if_on_gpu(g)
else:
return g.to(device)
return g.to(device)

def bipartite(data,
utype='_U', etype='_E', vtype='_V',
Expand Down Expand Up @@ -300,10 +297,7 @@ def bipartite(data,
u, v, utype, etype, vtype, urange, vrange, validate,
formats=formats)

if device is None:
return utils.to_int32_graph_if_on_gpu(g)
else:
return g.to(device)
return g.to(device)

def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
"""Create a heterograph from graphs representing connections of each relation.
Expand Down
4 changes: 1 addition & 3 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4450,7 +4450,7 @@ def to(self, device, **kwargs): # pylint: disable=invalid-name
device(type='cpu')
"""
if device is None or self.device == device:
return utils.to_int32_graph_if_on_gpu(self)
return self

ret = copy.copy(self)

Expand Down Expand Up @@ -4481,8 +4481,6 @@ def to(self, device, **kwargs): # pylint: disable=invalid-name
for k, num in self._batch_num_edges.items()}
ret._batch_num_edges = new_bne

ret = utils.to_int32_graph_if_on_gpu(ret)

return ret

def cpu(self):
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/tensorflow/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def edge_softmax_real(graph, score, eids=ALL):
"""Edge Softmax function"""
if not is_all(eids):
graph = graph.edge_subgraph(tf.cast(eids, graph.idtype))
graph = graph.edge_subgraph(tf.cast(eids, graph.idtype), preserve_nodes=True)
gidx = graph._graph
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
Expand Down
14 changes: 1 addition & 13 deletions python/dgl/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
# pylint: disable=invalid-name
from __future__ import absolute_import, division

from ..base import DGLError, dgl_warning
from ..base import DGLError
from .. import backend as F
from .internal import to_dgl_context

def prepare_tensor(g, data, name):
"""Convert the data to ID tensor and check its ID type and context.
Expand Down Expand Up @@ -129,14 +128,3 @@ def check_all_same_schema(feat_dict_list, keys, name):
' and feature size, but got\n\t{} {}\nand\n\t{} {}.'.format(
name, k, F.dtype(t1), F.shape(t1)[1:],
F.dtype(t2), F.shape(t2)[1:]))

def to_int32_graph_if_on_gpu(g):
"""Convert to int32 graph if the input graph is on GPU."""
# device_type 2 is an internal code for GPU
if to_dgl_context(g.device).device_type == 2 and g.idtype == F.int64:
dgl_warning('Automatically cast a GPU int64 graph to int32.\n'
' To suppress the warning, call DGLGraph.int() first\n'
' or specify the ``device`` argument when creating the graph.')
return g.int()
else:
return g
7 changes: 7 additions & 0 deletions src/array/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ struct Div {
}
};

struct Mod {
template <typename T>
static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {
return t1 % t2;
}
};

struct GT {
template <typename T>
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
Expand Down
14 changes: 14 additions & 0 deletions src/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,20 @@ IdArray NonZero(NDArray array) {
return ret;
}

std::pair<IdArray, IdArray> Sort(IdArray array) {
if (array.NumElements() == 0) {
IdArray idx = NewIdArray(0, array->ctx, 64);
return std::make_pair(array, idx);
}
std::pair<IdArray, IdArray> ret;
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", {
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
ret = impl::Sort<XPU, IdType>(array);
});
});
return ret;
}

std::string ToDebugString(NDArray array) {
std::ostringstream oss;
NDArray a = array.CopyTo(DLContext{kDLCPU, 0});
Expand Down
12 changes: 12 additions & 0 deletions src/array/array_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ BINARY_ELEMENT_OP(Add, Add)
BINARY_ELEMENT_OP(Sub, Sub)
BINARY_ELEMENT_OP(Mul, Mul)
BINARY_ELEMENT_OP(Div, Div)
BINARY_ELEMENT_OP(Mod, Mod)
BINARY_ELEMENT_OP(GT, GT)
BINARY_ELEMENT_OP(LT, LT)
BINARY_ELEMENT_OP(GE, GE)
Expand All @@ -81,6 +82,7 @@ BINARY_ELEMENT_OP_L(Add, Add)
BINARY_ELEMENT_OP_L(Sub, Sub)
BINARY_ELEMENT_OP_L(Mul, Mul)
BINARY_ELEMENT_OP_L(Div, Div)
BINARY_ELEMENT_OP_L(Mod, Mod)
BINARY_ELEMENT_OP_L(GT, GT)
BINARY_ELEMENT_OP_L(LT, LT)
BINARY_ELEMENT_OP_L(GE, GE)
Expand All @@ -92,6 +94,7 @@ BINARY_ELEMENT_OP_R(Add, Add)
BINARY_ELEMENT_OP_R(Sub, Sub)
BINARY_ELEMENT_OP_R(Mul, Mul)
BINARY_ELEMENT_OP_R(Div, Div)
BINARY_ELEMENT_OP_R(Mod, Mod)
BINARY_ELEMENT_OP_R(GT, GT)
BINARY_ELEMENT_OP_R(LT, LT)
BINARY_ELEMENT_OP_R(GE, GE)
Expand All @@ -117,6 +120,9 @@ NDArray operator * (const NDArray& lhs, const NDArray& rhs) {
NDArray operator / (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator % (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Mod(lhs, rhs);
}
NDArray operator + (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Add(lhs, rhs);
}
Expand All @@ -129,6 +135,9 @@ NDArray operator * (const NDArray& lhs, int64_t rhs) {
NDArray operator / (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator % (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Mod(lhs, rhs);
}
NDArray operator + (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Add(lhs, rhs);
}
Expand All @@ -141,6 +150,9 @@ NDArray operator * (int64_t lhs, const NDArray& rhs) {
NDArray operator / (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator % (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Mod(lhs, rhs);
}
NDArray operator - (const NDArray& array) {
return dgl::aten::Neg(array);
}
Expand Down
3 changes: 3 additions & 0 deletions src/array/array_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ DType IndexSelect(NDArray array, int64_t index);
template <DLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr);

template <DLDeviceType XPU, typename DType>
std::pair<IdArray, IdArray> Sort(IdArray array);

template <DLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices);

Expand Down
6 changes: 6 additions & 0 deletions src/array/cpu/array_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
Expand All @@ -70,6 +71,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
Expand All @@ -94,6 +96,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
Expand All @@ -104,6 +107,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
Expand All @@ -128,6 +132,7 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
Expand All @@ -138,6 +143,7 @@ template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
Expand Down
Loading

0 comments on commit f4608c2

Please sign in to comment.