[WIP][Kernel] Set the built-in reduce result of zero-degree nodes to 0 in C (dmlc#2017)

* test idea

* cuda kernels

* lint and fixes

* lint

* change to another strategy

* use infinity

* fix

Co-authored-by: Zihao Ye <[email protected]>
BarclayII and yzh119 authored Aug 14, 2020
1 parent de2e608 commit 63e2ba2
Showing 7 changed files with 54 additions and 25 deletions.
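For orientation, a minimal sketch of the user-visible behavior this commit targets, assuming the PyTorch backend and the standard DGL message-passing API (`dgl.graph`, `update_all`, built-in functions); the example itself is not part of the diff. After this change, nodes with no incoming edges get 0 from max/min reducers instead of an infinity sentinel.

```python
import torch
import dgl
import dgl.function as fn

# Nodes 0 and 2 have no incoming edges; node 1 receives from both of them.
g = dgl.graph(([0, 2], [1, 1]), num_nodes=3)
g.ndata['h'] = torch.tensor([[1.], [2.], [3.]])

# The C kernel accumulates max with -infinity as the identity, and the
# Python wrapper then replaces the leftover infinities with zero.
g.update_all(fn.copy_u('h', 'm'), fn.max('m', 'hmax'))
print(g.ndata['hmax'])  # expected: [[0.], [3.], [0.]]
```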
15 changes: 15 additions & 0 deletions python/dgl/backend/backend.py
@@ -1077,6 +1077,21 @@ def clamp(data, min_val, max_val):
"""
pass

def replace_inf_with_zero(x):
"""Returns a new tensor replacing infinity and negative infinity with zeros.
Parameters
----------
x : Tensor
The input
Returns
-------
Tensor
The result
"""
pass

###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
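A minimal NumPy sketch of the contract this new backend hook is expected to satisfy (illustration only; the framework-specific implementations follow below):

```python
import numpy as np

def replace_inf_with_zero_reference(x):
    # Both +inf and -inf become 0; finite values pass through unchanged.
    return np.where(np.isinf(x), 0, x)

x = np.array([1.0, np.inf, -np.inf, 2.5])
print(replace_inf_with_zero_reference(x))  # [1.  0.  0.  2.5]
```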
3 changes: 3 additions & 0 deletions python/dgl/backend/mxnet/tensor.py
@@ -325,6 +325,9 @@ def clone(input):
def clamp(data, min_val, max_val):
    return nd.clip(data, min_val, max_val)

def replace_inf_with_zero(x):
    return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)

def unique(input):
    # TODO: fallback to numpy is unfortunate
    tmp = input.asnumpy()
3 changes: 3 additions & 0 deletions python/dgl/backend/pytorch/tensor.py
@@ -268,6 +268,9 @@ def clone(input):
def clamp(data, min_val, max_val):
    return th.clamp(data, min_val, max_val)

def replace_inf_with_zero(x):
    return th.masked_fill(x, th.isinf(x), 0)

def unique(input):
    if input.dtype == th.bool:
        input = input.type(th.int8)
3 changes: 3 additions & 0 deletions python/dgl/backend/tensorflow/tensor.py
@@ -397,6 +397,9 @@ def clone(input):
def clamp(data, min_val, max_val):
    return tf.clip_by_value(data, min_val, max_val)

def replace_inf_with_zero(x):
    return tf.where(tf.abs(x) == np.inf, 0, x)

def unique(input):
    return tf.unique(input).y

39 changes: 20 additions & 19 deletions python/dgl/ops/spmm.py
@@ -1,8 +1,7 @@
"""dgl spmm operator module."""
import sys

from ..base import dgl_warning
from ..backend import gspmm as gspmm_internal, backend_name
from ..backend import gspmm as gspmm_internal
from .. import backend as F

__all__ = ['gspmm']
@@ -59,35 +58,36 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
        new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]
        lhs_data = F.reshape(lhs_data, new_lhs_shape)
        rhs_data = F.reshape(rhs_data, new_rhs_shape)
    # With max and min reducers infinity will be returned for zero degree nodes
    ret = gspmm_internal(g._graph, op,
                         'sum' if reduce_op == 'mean' else reduce_op,
                         lhs_data, rhs_data)

    # assign zero features for zero degree nodes.
    deg = g.in_degrees()
    min_deg = F.as_scalar(F.min(deg, dim=0))
    if min_deg == 0:
        non_zero_nids = F.nonzero_1d(deg == 0)
        if backend_name == 'pytorch':
            ret[non_zero_nids] = 0.
        else:
            dtype = F.dtype(ret)
            ctx = F.context(ret)
            ret = F.scatter_row(ret, non_zero_nids,
                                F.zeros((len(non_zero_nids),) + F.shape(ret)[1:], dtype, ctx))
    ret = F.replace_inf_with_zero(ret)

    # divide in degrees for mean reducer.
    if reduce_op == 'mean':
        ret_shape = F.shape(ret)
        if min_deg == 0:
            dgl_warning('Zero-degree nodes encountered in mean reducer. Setting the mean to 0.')
        deg = g.in_degrees()
        deg = F.astype(F.clamp(deg, 1, g.number_of_edges()), F.dtype(ret))
        deg_shape = (ret_shape[0],) + (1,) * (len(ret_shape) - 1)
        return ret / F.reshape(deg, deg_shape)
    else:
        return ret
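To make the mean branch above concrete, a small NumPy sketch (stand-in arrays, not DGL code) of why the degrees are clamped with `F.clamp(deg, 1, g.number_of_edges())` before dividing: zero-degree rows come back from the sum kernel as 0, and clamping keeps the division well defined (0/1 = 0 rather than 0/0).

```python
import numpy as np

# Per-node sums from the 'sum' kernel and in-degrees; node 2 is isolated.
summed = np.array([[4.0], [9.0], [0.0]])
deg = np.array([2, 3, 0])

# Mirror F.clamp(deg, 1, g.number_of_edges()) so the division is well defined.
safe_deg = np.clip(deg, 1, None).astype(summed.dtype).reshape(-1, 1)
print(summed / safe_deg)  # [[2.], [3.], [0.]] -- the isolated node stays at 0
```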


def _attach_zerodeg_note(docstring, reducer):
    note1 = """
    The {} function will return zero for nodes with no incoming messages.""".format(reducer)
    note2 = """
    This is implemented by replacing all {} values to zero.
    """.format("infinity" if reducer == "min" else "negative infinity")

    docstring = docstring + note1
    if reducer in ('min', 'max'):
        docstring = docstring + note2
    return docstring
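Roughly how the helper above is meant to be used (the import path assumes the module layout in this diff; `_attach_zerodeg_note` is a private helper):

```python
# Illustrative call; _attach_zerodeg_note is a private helper in dgl.ops.spmm.
from dgl.ops.spmm import _attach_zerodeg_note

doc = _attach_zerodeg_note("Generalized SpMM function with max reducer.", "max")
print(doc)
# The base docstring, followed by the two notes: the max function returns zero
# for nodes with no incoming messages, implemented by replacing all negative
# infinity values with zero.
```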


def _gen_spmm_func(binary_op, reduce_op):
    name = "u_{}_e_{}".format(binary_op, reduce_op)
    docstring = """Generalized SpMM function.
@@ -120,6 +120,7 @@ def _gen_spmm_func(binary_op, reduce_op):
    https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
    for more details about the NumPy broadcasting semantics.
    """.format(binary_op, reduce_op)
    docstring = _attach_zerodeg_note(docstring, reduce_op)

    def func(g, x, y):
        return gspmm(g, binary_op, reduce_op, x, y)
@@ -139,7 +140,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op):
"copy_u": "source node",
"copy_e": "edge"
}
docstring = lambda binary_op: """Generalized SpMM function. {}
docstring = lambda binary_op: _attach_zerodeg_note("""Generalized SpMM function. {}
Then aggregates the message by {} on destination nodes.
Parameters
@@ -160,7 +161,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op):
""".format(
binary_str[binary_op],
reduce_op,
x_str[binary_op])
x_str[binary_op]), reduce_op)

def func(g, x):
if binary_op == 'copy_u':
12 changes: 8 additions & 4 deletions src/array/cpu/spmm.h
@@ -104,8 +104,10 @@ void SpMMSumCoo(
      const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim + lhs_add : nullptr;
      const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
      const DType val = Op::Call(lhs_off, rhs_off);
      if (val != 0) {
#pragma omp atomic
        out_off[k] += val;
      out_off[k] += val;
      }
    }
  }
}
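A side note on the `if (val != 0)` guard added above: since adding zero leaves the accumulator unchanged, skipping zero-valued contributions (and their `#pragma omp atomic` updates) does not alter the summed result; a tiny Python sketch of that check:

```python
import random

values = [random.choice([0.0, 1.5, -2.0]) for _ in range(1000)]

full_sum = sum(values)
skip_zero_sum = sum(v for v in values if v != 0)

# Skipping the zero-valued terms leaves the sum unchanged, which is why the
# atomic add in SpMMSumCoo can be guarded by `val != 0`.
assert full_sum == skip_zero_sum
```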
@@ -123,8 +125,9 @@ void SpMMSumCoo(
* \param arge Arg-Min/Max on edges. which refers the source node indices
* correspond to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* \note it uses node parallel strategy, different threads are responsible
* \note It uses node parallel strategy, different threads are responsible
* for the computation of different nodes.
* \note The result will contain infinity for zero-degree nodes.
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsr(
@@ -194,6 +197,7 @@ void SpMMCmpCsr(
* \note it uses node parallel strategy, different threads are responsible
* for the computation of different nodes. To avoid possible data hazard,
* we use atomic operators in the reduction phase.
* \note The result will contain infinity for zero-degree nodes.
*/
template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCoo(
@@ -315,7 +319,7 @@ template <typename DType> constexpr bool CopyRhs<DType>::use_rhs;
//////////////////////////////// Reduce operators on CPU ////////////////////////////////
template <typename DType>
struct Max {
  static constexpr DType zero = std::numeric_limits<DType>::lowest();
  static constexpr DType zero = -std::numeric_limits<DType>::infinity();
  // return true if accum should be replaced
  inline static DType Call(DType accum, DType val) {
    return accum < val;
@@ -325,7 +329,7 @@ template <typename DType> constexpr DType Max<DType>::zero;

template <typename DType>
struct Min {
  static constexpr DType zero = std::numeric_limits<DType>::max();
  static constexpr DType zero = std::numeric_limits<DType>::infinity();
  // return true if accum should be replaced
  inline static DType Call(DType accum, DType val) {
    return accum > val;
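The CPU reducers above (and the CUDA ones below) switch their initial value from `std::numeric_limits<DType>::lowest()`/`max()` to minus/plus infinity, the identity of max/min, so a row that receives no messages is left at ±infinity, which the Python-side `replace_inf_with_zero` can then detect and reset to 0. A small Python sketch of the same idea:

```python
import math

def reduce_max(values):
    # Start from the identity of max, as the C++ Max functor now does.
    acc = -math.inf
    for v in values:
        acc = max(acc, v)
    return acc

print(reduce_max([1.0, 3.0, 2.0]))                    # 3.0
untouched = reduce_max([])                            # -inf: no incoming messages
print(0.0 if math.isinf(untouched) else untouched)    # 0.0 after the inf -> 0 pass
```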
4 changes: 2 additions & 2 deletions src/array/cuda/functor.cuh
@@ -146,7 +146,7 @@ template <typename Idx,
          typename DType,
          bool atomic=false>
struct Max {
  static constexpr DType zero = std::numeric_limits<DType>::lowest();
  static constexpr DType zero = -std::numeric_limits<DType>::infinity();
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
@@ -183,7 +183,7 @@ template <typename Idx,
          typename DType,
          bool atomic=false>
struct Min {
  static constexpr DType zero = std::numeric_limits<DType>::max();
  static constexpr DType zero = std::numeric_limits<DType>::infinity();
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
