From 657c220de7c515f9dfede94b574a56d5685d7238 Mon Sep 17 00:00:00 2001 From: "Quan (Andy) Gan" Date: Mon, 17 May 2021 19:47:02 +0800 Subject: [PATCH] [Feature] Python interface for adjacency matrix summation and multiplication (#2893) * test commit * fixes * oops * add docs * lint * why does it say I have a trailing whitespace * oh ok * fixes * why there's an invalid argument error * address comments * fix * address comments --- docs/source/api/python/dgl.rst | 2 + python/dgl/backend/backend.py | 85 +++++++++ python/dgl/backend/mxnet/sparse.py | 90 ++++++++- python/dgl/backend/pytorch/sparse.py | 79 +++++++- python/dgl/backend/tensorflow/sparse.py | 66 ++++++- python/dgl/convert.py | 18 +- python/dgl/heterograph_index.py | 131 +++++++++++-- python/dgl/sparse.py | 6 +- python/dgl/transform.py | 241 +++++++++++++++++++++++- src/array/array_op.h | 11 +- src/array/cpu/csr_get_data.cc | 15 +- src/array/cpu/csr_mm.cc | 4 +- src/array/cpu/csr_sum.cc | 4 +- src/array/cuda/csr_get_data.cu | 15 +- src/array/cuda/csr_mm.cu | 12 +- src/array/cuda/csr_sum.cu | 33 +++- src/graph/unit_graph.cc | 6 +- tests/compute/test_csrmm.py | 225 ++++++++++++++++++---- 18 files changed, 928 insertions(+), 115 deletions(-) diff --git a/docs/source/api/python/dgl.rst b/docs/source/api/python/dgl.rst index b8ca7e266ae6..c255b513bd86 100644 --- a/docs/source/api/python/dgl.rst +++ b/docs/source/api/python/dgl.rst @@ -74,6 +74,8 @@ Operators for generating new graphs by manipulating the structure of the existin line_graph khop_graph metapath_reachable_graph + adj_product_graph + adj_sum_graph .. _api-batch: diff --git a/python/dgl/backend/backend.py b/python/dgl/backend/backend.py index f5333592bc7f..257189f21ebe 100644 --- a/python/dgl/backend/backend.py +++ b/python/dgl/backend/backend.py @@ -1574,6 +1574,91 @@ def scatter_add(x, idx, m): """ pass +def csrmm(A, A_weights, B, B_weights, num_vtypes): + """Compute weighted adjacency matrix multiplication. + + Notes + ----- + Both A and B must allow creation of CSR representations, and must be simple graphs + (i.e. having at most one edge between two nodes). + + The output unit graph has no format restriction. + + Parameters + ---------- + A : HeteroGraphIndex + The unit graph as left operand. + A_weights : Tensor + The edge weights of A. Must be a 1D vector. + B : HeteroGraphIndex + The unit graph as right operand. + B_weights : Tensor + The edge weights of B. Must be a 1D vector. + num_vtypes : int + The number of node types of the output graph. Must be either 1 or 2. + + Returns + ------- + HeteroGraphIndex + The output unit graph. + Tensor + The output edge weights. + """ + pass + +def csrsum(gidxs, weights): + """Compute weighted adjacency matrix summation. + + Notes + ----- + All unit graphs must allow creation of CSR representations, and must be simple graphs + (i.e. having at most one edge between two nodes). + + The output unit graph has no format restriction. + + Parameters + ---------- + gidxs : list[HeteroGraphIndex] + The unit graphs. + weights : list[Tensor] + The edge weights of each graph. Must be 1D vectors. + + Returns + ------- + HeteroGraphIndex + The output unit graph. + Tensor + The output edge weights. + """ + pass + +def csrmask(A, A_weights, B): + """Retrieve the values in the weighted adjacency matrix of graph :attr:`A` at the + non-zero positions of graph :attr:`B`'s adjacency matrix. + + In scipy, this is equivalent to ``A[B != 0]``. + + Notes + ----- + Both A and B must allow creation of CSR representations, and must be simple graphs + (i.e. having at most one edge between two nodes). + + Parameters + ---------- + A : HeteroGraphIndex + The unit graph as left operand. + A_weights : Tensor + The edge weights of A. Must be a 1D vector. + B : HeteroGraphIndex + The unit graph as right operand. + + Returns + ------- + Tensor + The output tensor. + """ + pass + ############################################################################### # Other interfaces diff --git a/python/dgl/backend/mxnet/sparse.py b/python/dgl/backend/mxnet/sparse.py index 92e3f998bacb..79a3dec428e5 100644 --- a/python/dgl/backend/mxnet/sparse.py +++ b/python/dgl/backend/mxnet/sparse.py @@ -2,10 +2,13 @@ import numpy as np from mxnet import nd from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add +from ...sparse import _csrmm, _csrsum, _csrmask from ...base import dgl_warning, is_all, ALL from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx +from ...heterograph_index import create_unitgraph_from_csr -__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add'] +__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add', + 'csrmm', 'csrsum', 'csrmask'] def _scatter_nd(index, src, n_rows): @@ -379,3 +382,88 @@ def backward(self, dy): def scatter_add(x, idx, m): scatter_add_op = ScatterAdd(idx, m) return scatter_add_op(x) + + +class CSRMM(mx.autograd.Function): + def __init__(self, gidxA, gidxB, num_vtypes): + super().__init__() + self.gidxA = gidxA + self.gidxB = gidxB + self.num_vtypes = num_vtypes + + def forward(self, A_weights, B_weights): + gidxC, C_weights = _csrmm(self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes) + nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr') + # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same + # as the underlying tensors of the created graph gidxC. + self.backward_cache = gidxC + self.save_for_backward(A_weights, B_weights) + nrows = nd.array([nrows], dtype='int64') + ncols = nd.array([ncols], dtype='int64') + return nrows, ncols, C_indptr, C_indices, C_eids, C_weights + + def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): + # Only the last argument is meaningful. + gidxC = self.backward_cache + A_weights, B_weights = self.saved_tensors + dgidxA, dA_weights = _csrmm( + gidxC, dC_weights, self.gidxB.reverse(), B_weights, self.gidxA.number_of_ntypes()) + dgidxB, dB_weights = _csrmm( + self.gidxA.reverse(), A_weights, gidxC, dC_weights, self.gidxB.number_of_ntypes()) + dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA) + dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB) + return dA_weights, dB_weights + +def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes): + op = CSRMM(gidxA, gidxB, num_vtypes) + nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(A_weights, B_weights) + gidxC = create_unitgraph_from_csr( + num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids, + ["coo", "csr", "csc"]) + return gidxC, C_weights + +class CSRSum(mx.autograd.Function): + def __init__(self, gidxs): + super().__init__() + self.gidxs = gidxs + + def forward(self, *weights): + gidxC, C_weights = _csrsum(self.gidxs, weights) + nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors( + 0, True, 'csr') + # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same + # as the underlying tensors of the created graph gidxC. + self.backward_cache = gidxC + nrows = nd.array([nrows], dtype='int64') + ncols = nd.array([ncols], dtype='int64') + return nrows, ncols, C_indptr, C_indices, C_eids, C_weights + + def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): + # Only the last argument is meaningful. + gidxC = self.backward_cache + return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs) + +def csrsum(gidxs, weights): + op = CSRSum(gidxs) + nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights) + num_vtypes = gidxs[0].number_of_ntypes() + gidxC = create_unitgraph_from_csr( + num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids, + ["coo", "csr", "csc"]) + return gidxC, C_weights + +class CSRMask(mx.autograd.Function): + def __init__(self, gidxA, gidxB): + super().__init__() + self.gidxA = gidxA + self.gidxB = gidxB + + def forward(self, A_weights): + return _csrmask(self.gidxA, A_weights, self.gidxB) + + def backward(self, dB_weights): + return _csrmask(self.gidxB, dB_weights, self.gidxA) + +def csrmask(gidxA, A_weights, gidxB): + op = CSRMask(gidxA, gidxB) + return op(A_weights) diff --git a/python/dgl/backend/pytorch/sparse.py b/python/dgl/backend/pytorch/sparse.py index 254296c3a970..a00097d1f0d2 100644 --- a/python/dgl/backend/pytorch/sparse.py +++ b/python/dgl/backend/pytorch/sparse.py @@ -2,6 +2,8 @@ from distutils.version import LooseVersion from ...base import is_all, ALL from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add +from ...sparse import _csrmm, _csrsum, _csrmask +from ...heterograph_index import create_unitgraph_from_csr if LooseVersion(th.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import custom_fwd, custom_bwd @@ -24,7 +26,8 @@ def decorate_bwd(*args, **kwargs): return bwd(*args, **kwargs) return decorate_bwd -__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add'] +__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add', + 'csrmm', 'csrsum', 'csrmask'] def _reduce_grad(grad, shape): @@ -303,6 +306,62 @@ def backward(ctx, dy): return dy[idx], None, None +class CSRMM(th.autograd.Function): + @staticmethod + def forward(ctx, gidxA, A_weights, gidxB, B_weights, num_vtypes): + gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes) + nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr') + # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same + # as the underlying tensors of the created graph gidxC. + ctx.backward_cache = gidxA, gidxB, gidxC + ctx.save_for_backward(A_weights, B_weights) + return th.tensor(nrows), th.tensor(ncols), C_indptr, C_indices, C_eids, C_weights + + @staticmethod + def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): + # Only the last argument is meaningful. + gidxA, gidxB, gidxC = ctx.backward_cache + A_weights, B_weights = ctx.saved_tensors + dgidxA, dA_weights = csrmm( + gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes()) + dgidxB, dB_weights = csrmm( + gidxA.reverse(), A_weights, gidxC, dC_weights, gidxB.number_of_ntypes()) + dA_weights = csrmask(dgidxA, dA_weights, gidxA) + dB_weights = csrmask(dgidxB, dB_weights, gidxB) + return None, dA_weights, None, dB_weights, None + + +class CSRSum(th.autograd.Function): + @staticmethod + def forward(ctx, gidxs, *weights): + # PyTorch tensors must be explicit arguments of the forward function + gidxC, C_weights = _csrsum(gidxs, weights) + nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors( + 0, True, 'csr') + # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same + # as the underlying tensors of the created graph gidxC. + ctx.backward_cache = gidxs, gidxC + return th.tensor(nrows), th.tensor(ncols), C_indptr, C_indices, C_eids, C_weights + + @staticmethod + def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): + # Only the last argument is meaningful. + gidxs, gidxC = ctx.backward_cache + return (None,) + tuple(csrmask(gidxC, dC_weights, gidx) for gidx in gidxs) + + +class CSRMask(th.autograd.Function): + @staticmethod + def forward(ctx, gidxA, A_weights, gidxB): + ctx.backward_cache = gidxA, gidxB + return _csrmask(gidxA, A_weights, gidxB) + + @staticmethod + def backward(ctx, dB_weights): + gidxA, gidxB = ctx.backward_cache + return None, csrmask(gidxB, dB_weights, gidxA), None + + def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data) @@ -320,3 +379,21 @@ def segment_reduce(op, x, offsets): def scatter_add(x, idx, m): return ScatterAdd.apply(x, idx, m) + +def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes): + nrows, ncols, C_indptr, C_indices, C_eids, C_weights = \ + CSRMM.apply(gidxA, A_weights, gidxB, B_weights, num_vtypes) + gidxC = create_unitgraph_from_csr( + num_vtypes, nrows.item(), ncols.item(), C_indptr, C_indices, C_eids, + ["coo", "csr", "csc"]) + return gidxC, C_weights + +def csrsum(gidxs, weights): + nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRSum.apply(gidxs, *weights) + gidxC = create_unitgraph_from_csr( + gidxs[0].number_of_ntypes(), nrows.item(), ncols.item(), C_indptr, C_indices, C_eids, + ["coo", "csr", "csc"]) + return gidxC, C_weights + +def csrmask(gidxA, A_weights, gidxB): + return CSRMask.apply(gidxA, A_weights, gidxB) diff --git a/python/dgl/backend/tensorflow/sparse.py b/python/dgl/backend/tensorflow/sparse.py index b3c8bc9b2215..e573495601e2 100644 --- a/python/dgl/backend/tensorflow/sparse.py +++ b/python/dgl/backend/tensorflow/sparse.py @@ -3,8 +3,11 @@ from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy from ...base import is_all, ALL from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add +from ...sparse import _csrmm, _csrsum, _csrmask +from ...heterograph_index import create_unitgraph_from_csr -__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add'] +__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add', + 'csrmm', 'csrsum', 'csrmask'] def _scatter_nd(index, src, n_rows): @@ -295,3 +298,64 @@ def scatter_add(x, idx, m): def _lambda(x): return scatter_add_real(x, idx, m) return _lambda(x) + + +def csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes): + gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes) + nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr') + + def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): + # Only the last argument is meaningful. + dgidxA, dA_weights = _csrmm( + gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes()) + dgidxB, dB_weights = _csrmm( + gidxA.reverse(), A_weights, gidxC, dC_weights, gidxB.number_of_ntypes()) + dA_weights = _csrmask(dgidxA, dA_weights, gidxA) + dB_weights = _csrmask(dgidxB, dB_weights, gidxB) + return dA_weights, dB_weights + return (tf.constant(nrows), tf.constant(ncols), C_indptr, C_indices, C_eids, C_weights), grad + +def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes): + @tf.custom_gradient + def _lambda(A_weights, B_weights): + return csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes) + nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(A_weights, B_weights) + gidxC = create_unitgraph_from_csr( + num_vtypes, nrows.numpy(), ncols.numpy(), C_indptr, C_indices, C_eids, + ["coo", "csr", "csc"]) + return gidxC, C_weights + + +def csrsum_real(gidxs, weights): + gidxC, C_weights = _csrsum(gidxs, weights) + nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr') + + def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): + # Only the last argument is meaningful. + return tuple(_csrmask(gidxC, dC_weights, gidx) for gidx in gidxs) + return (tf.constant(nrows), tf.constant(ncols), C_indptr, C_indices, C_eids, C_weights), grad + +def csrsum(gidxs, weights): + @tf.custom_gradient + def _lambda(*weights): + return csrsum_real(gidxs, weights) + nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(*weights) + num_vtypes = gidxs[0].number_of_ntypes() + gidxC = create_unitgraph_from_csr( + num_vtypes, nrows.numpy(), ncols.numpy(), C_indptr, C_indices, C_eids, + ["coo", "csr", "csc"]) + return gidxC, C_weights + + +def csrmask_real(gidxA, A_weights, gidxB): + B_weights = _csrmask(gidxA, A_weights, gidxB) + + def grad(dB_weights): + return _csrmask(gidxB, dB_weights, gidxA) + return B_weights, grad + +def csrmask(gidxA, A_weights, gidxB): + @tf.custom_gradient + def _lambda(A_weights): + return csrmask_real(gidxA, A_weights, gidxB) + return _lambda(A_weights) diff --git a/python/dgl/convert.py b/python/dgl/convert.py index 344e3478ebfe..20bc60bee8a5 100644 --- a/python/dgl/convert.py +++ b/python/dgl/convert.py @@ -335,30 +335,16 @@ def heterograph(data_dict, ' the max ID in the data, but got {} and {}.'.format( dty, num_nodes_dict[dty], vrange - 1)) # Create the graph - - # Sort the ntypes and relation tuples to have a deterministic order for the same set - # of type names. - ntypes = list(sorted(num_nodes_dict.keys())) - relations = list(sorted(node_tensor_dict.keys())) - + metagraph, ntypes, etypes, relations = heterograph_index.create_metagraph_index( + num_nodes_dict.keys(), node_tensor_dict.keys()) num_nodes_per_type = utils.toindex([num_nodes_dict[ntype] for ntype in ntypes], "int64") - ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)} - - meta_edges_src = [] - meta_edges_dst = [] - etypes = [] rel_graphs = [] for srctype, etype, dsttype in relations: - meta_edges_src.append(ntype_dict[srctype]) - meta_edges_dst.append(ntype_dict[dsttype]) - etypes.append(etype) src, dst = node_tensor_dict[(srctype, etype, dsttype)] g = create_from_edges(src, dst, srctype, etype, dsttype, num_nodes_dict[srctype], num_nodes_dict[dsttype]) rel_graphs.append(g) - # metagraph is DGLGraph, currently still using int64 as index dtype - metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True) # create graph index hgidx = heterograph_index.create_heterograph_from_relations( metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type) diff --git a/python/dgl/heterograph_index.py b/python/dgl/heterograph_index.py index 9bf91cf826f8..c90257e532f5 100644 --- a/python/dgl/heterograph_index.py +++ b/python/dgl/heterograph_index.py @@ -8,6 +8,7 @@ from ._ffi.object import register_object, ObjectBase from ._ffi.function import _init_api from .base import DGLError, dgl_warning +from .graph_index import from_coo from . import backend as F from . import utils @@ -649,6 +650,60 @@ def adjacency_matrix(self, etype, transpose, ctx): else: raise Exception("unknown format") + def adjacency_matrix_tensors(self, etype, transpose, fmt): + """Return the adjacency matrix as a triplet of tensors. + + By default, a row of returned adjacency matrix represents the destination + of an edge and the column represents the source. + + When transpose is True, a row represents the source and a column represents + a destination. + + Parameters + ---------- + etype : int + Edge type + transpose : bool + A flag to transpose the returned adjacency matrix. + fmt : str + Indicates the format of returned adjacency matrix. + + Returns + ------- + tuple[int, int, Tensor, Tensor] or tuple[int, int, Tensor, Tensor, Tensor] + The number of rows and columns, followed by the adjacency matrix tensors + whose data type and device are the same as those of the graph. + + If :attr:`fmt` is ``'coo'``, then the triplet will be + the row array and column array of the COO representation. + + If :attr:`fmt` is ``'csr'``, then the triplet will be + the index pointer array (``indptr``), indices array, and data array + of the CSR representation. The data array will contain the edge ID for + each entry of the adjacency matrix. If the data array is empty, then it is + equivalent to a consecutive array from zero to the number of edges minus one. + """ + if not isinstance(transpose, bool): + raise DGLError('Expect bool value for "transpose" arg,' + ' but got %s.' % (type(transpose))) + + rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt) + srctype, dsttype = self.metagraph.find_edge(etype) + nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype) + ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype) + nnz = self.number_of_edges(etype) + if fmt == "csr": + indptr = F.from_dgl_nd(rst(0)) + indices = F.from_dgl_nd(rst(1)) + data = F.from_dgl_nd(rst(2)) + return nrows, ncols, indptr, indices, data + elif fmt == 'coo': + idx = F.from_dgl_nd(rst(0)) + row, col = F.reshape(idx, (2, nnz)) + return nrows, ncols, row, col + else: + raise ValueError("unknown format") + def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None): """Return the scipy adjacency matrix representation of this graph. @@ -674,10 +729,6 @@ def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None): scipy.sparse.spmatrix The scipy representation of adjacency matrix. """ - if not isinstance(transpose, bool): - raise DGLError('Expect bool value for "transpose" arg,' - ' but got %s.' % (type(transpose))) - if return_edge_ids is None: dgl_warning( "Adjacency matrix by default currently returns edge IDs." @@ -687,26 +738,30 @@ def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None): FutureWarning) return_edge_ids = True - rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt) - srctype, dsttype = self.metagraph.find_edge(etype) - nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype) - ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype) - nnz = self.number_of_edges(etype) - if fmt == "csr": - indptr = utils.toindex(rst(0), self.dtype).tonumpy() - indices = utils.toindex(rst(1), self.dtype).tonumpy() - data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices) + if fmt == 'csr': + nrows, ncols, indptr, indices, data = \ + self.adjacency_matrix_tensors(etype, transpose, fmt) + indptr = F.asnumpy(indptr) + indices = F.asnumpy(indices) + data = F.asnumpy(data) + # Check if edge ID is omitted if return_edge_ids and data.shape[0] == 0: - data = np.arange(nnz) + data = np.arange(self.number_of_edges(etype)) + else: + data = np.ones_like(indices) + return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols)) elif fmt == 'coo': - idx = utils.toindex(rst(0), self.dtype).tonumpy() - row, col = np.reshape(idx, (2, nnz)) - data = np.arange(0, nnz) if return_edge_ids else np.ones_like(row) + nrows, ncols, row, col = \ + self.adjacency_matrix_tensors(etype, transpose, fmt) + row = F.asnumpy(row) + col = F.asnumpy(col) + data = np.arange(self.number_of_edges(etype)) if return_edge_ids \ + else np.ones_like(row) return scipy.sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols)) else: - raise Exception("unknown format") + raise ValueError("unknown format") def incidence_matrix(self, etype, typestr, ctx): """Return the incidence matrix representation of this graph. @@ -972,6 +1027,46 @@ def induced_edges(self): # Creators ################################################################# +def create_metagraph_index(ntypes, canonical_etypes): + """Return a GraphIndex instance for a metagraph given the node types and canonical + edge types. + + This function will reorder the node types and canonical edge types. + + Parameters + ---------- + ntypes : Iterable[str] + The node types. + canonical_etypes : Iterable[tuple[str, str, str]] + The canonical edge types. + + Returns + ------- + GraphIndex + The index object for metagraph. + list[str] + The reordered node types for each node in the metagraph. + list[str] + The reordered edge types for each edge in the metagraph. + list[tuple[str, str, str]] + The reordered canonical edge types for each edge in the metagraph. + """ + # Sort the ntypes and relation tuples to have a deterministic order for the same set + # of type names. + ntypes = list(sorted(ntypes)) + relations = list(sorted(canonical_etypes)) + ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)} + meta_edges_src = [] + meta_edges_dst = [] + etypes = [] + for srctype, etype, dsttype in relations: + meta_edges_src.append(ntype_dict[srctype]) + meta_edges_dst.append(ntype_dict[dsttype]) + etypes.append(etype) + # metagraph is DGLGraph, currently still using int64 as index dtype + metagraph = from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True) + return metagraph, ntypes, etypes, relations + def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col, formats, row_sorted=False, col_sorted=False): """Create a unitgraph graph index from COO format diff --git a/python/dgl/sparse.py b/python/dgl/sparse.py index 3034bb6881a3..96593cf8acb0 100644 --- a/python/dgl/sparse.py +++ b/python/dgl/sparse.py @@ -366,7 +366,7 @@ def _bwd_segment_cmp(feat, arg, m): to_dgl_nd_for_write(out)) return out -def csrmm(A, A_weights, B, B_weights, num_vtypes): +def _csrmm(A, A_weights, B, B_weights, num_vtypes): """Return a graph whose adjacency matrix is the sparse matrix multiplication of those of two given graphs. @@ -397,7 +397,7 @@ def csrmm(A, A_weights, B, B_weights, num_vtypes): A, F.to_dgl_nd(A_weights), B, F.to_dgl_nd(B_weights), num_vtypes) return C, F.from_dgl_nd(C_weights) -def csrsum(As, A_weights): +def _csrsum(As, A_weights): """Return a graph whose adjacency matrix is the sparse matrix summation of the given list of graphs. @@ -421,7 +421,7 @@ def csrsum(As, A_weights): C, C_weights = _CAPI_DGLCSRSum(As, [F.to_dgl_nd(w) for w in A_weights]) return C, F.from_dgl_nd(C_weights) -def csrmask(A, A_weights, B): +def _csrmask(A, A_weights, B): """Return the weights of A at the locations identical to the sparsity pattern of B. diff --git a/python/dgl/transform.py b/python/dgl/transform.py index f343c0dbabcd..37d22b87bf46 100644 --- a/python/dgl/transform.py +++ b/python/dgl/transform.py @@ -9,6 +9,7 @@ from .base import dgl_warning, DGLError from . import convert from .heterograph import DGLHeteroGraph, DGLBlock +from .heterograph_index import create_metagraph_index, create_heterograph_from_relations from .frame import Frame from . import ndarray as nd from . import backend as F @@ -46,7 +47,9 @@ 'metis_partition_assignment', 'partition_graph_with_halo', 'metis_partition', - 'as_heterograph'] + 'as_heterograph', + 'adj_product_graph', + 'adj_sum_graph'] def pairwise_squared_distance(x): @@ -2223,6 +2226,242 @@ def to_simple(g, DGLHeteroGraph.to_simple = utils.alias_func(to_simple) +def _unitgraph_less_than_int32(g): + """Check if a graph with only one edge type has more than 2 ** 31 - 1 + nodes or edges. + """ + num_edges = g.num_edges() + num_nodes = max(g.num_nodes(g.ntypes[0]), g.num_nodes(g.ntypes[-1])) + return max(num_nodes, num_edges) <= (1 << 31) - 1 + +def adj_product_graph(A, B, weight_name, etype='_E'): + r"""Create a weighted graph whose adjacency matrix is the product of + the adjacency matrices of the given two graphs. + + Namely, given two weighted graphs :attr:`A` and :attr:`B`, whose rows + represent source nodes and columns represent destination nodes, this function + returns a new graph whose weighted adjacency matrix is + :math:`\mathrm{adj}(A) \times \mathrm{adj}(B)`. + + The two graphs must be simple graphs, and must have only one edge type. + Moreover, the number of nodes of the destination node type of :attr:`A` must + be the same as the number of nodes of the source node type of :attr:`B`. + + The source node type of the returned graph will be the same as the source + node type of graph :attr:`A`. The destination node type of the returned + graph will be the same as the destination node type of graph :attr:`B`. + If the two node types are the same, the returned graph will be homogeneous. + Otherwise, it will be a bipartite graph. + + Unlike ``scipy``, if an edge in the result graph has zero weight, it will + not be removed from the graph. + + Notes + ----- + This function works on both CPU and GPU. For GPU, the number of nodes and + edges must be less than the maximum of ``int32`` (i.e. ``2 ** 31 - 1``) due + to restriction of cuSPARSE. + + The edge weights returned by this function is differentiable w.r.t. the + input edge weights. + + If the graph format is restricted, both graphs must have CSR available. + + Parameters + ---------- + A : DGLGraph + The graph as left operand. + B : DGLGraph + The graph as right operand. + weight_name : str + The feature name of edge weight of both graphs. + + The corresponding edge feature must be scalar. + etype : str, optional + The edge type of the returned graph. + + Returns + ------- + DGLGraph + The new graph. The edge weight of the returned graph will have the + same feature name as :attr:`weight_name`. + + Examples + -------- + The following shows weighted adjacency matrix multiplication between two + bipartite graphs. You can also perform this between two homogeneous + graphs, or one homogeneous graph and one bipartite graph, as long as the + numbers of nodes of the same type match. + + >>> A = dgl.heterograph({ + ... ('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])}, + ... num_nodes_dict={'A': 3, 'B': 4}) + >>> B = dgl.heterograph({ + ... ('B', 'BA', 'A'): ([0, 3, 2, 1, 3, 3], [1, 2, 0, 2, 1, 0])}, + ... num_nodes_dict={'A': 3, 'B': 4}) + >>> A.edata['w'] = torch.randn(6).requires_grad_() + >>> B.edata['w'] = torch.randn(6).requires_grad_() + + If your graph is a multigraph, you will need to call :func:`dgl.to_simple` + to convert it into a simple graph first. + + >>> A = dgl.to_simple(A) + >>> B = dgl.to_simple(B) + + >>> C = dgl.adj_product_graph(A, B, 'w') + >>> C.edges() + (tensor([0, 0, 1, 2, 2, 2]), tensor([0, 1, 0, 0, 2, 1])) + + >>> C.edata['w'] + tensor([0.6906, 0.2002, 0.0591, 0.3672, 0.1066, 0.1328], + grad_fn=) + + Note that this function is differentiable: + + >>> C.edata['w'].sum().backward() + >>> A.edata['w'].grad + tensor([0.7153, 0.2775, 0.7141, 0.7141, 0.7153, 0.7153]) + + >>> B.edata['w'].grad + tensor([0.4664, 0.0000, 1.5614, 0.3840, 0.0000, 0.0000]) + + If the source node type of the left operand is the same as the destination + node type of the right operand, this function returns a homogeneous graph: + + >>> C.ntypes + ['A'] + + Otherwise, it returns a bipartite graph instead: + + >>> A = dgl.heterograph({ + ... ('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])}, + ... num_nodes_dict={'A': 3, 'B': 4}) + >>> B = dgl.heterograph({ + ... ('B', 'BC', 'C'): ([0, 3, 2, 1, 3, 3], [1, 2, 0, 2, 1, 0])}, + ... num_nodes_dict={'C': 3, 'B': 4}) + >>> A.edata['w'] = torch.randn(6).requires_grad_() + >>> B.edata['w'] = torch.randn(6).requires_grad_() + >>> C = dgl.adj_product_graph(A, B, 'w') + >>> C.ntypes + ['A', 'C'] + """ + srctype, _, _ = A.canonical_etypes[0] + _, _, dsttype = B.canonical_etypes[0] + num_vtypes = 1 if srctype == dsttype else 2 + ntypes = [srctype] if num_vtypes == 1 else [srctype, dsttype] + + if A.device != F.cpu(): + if not (_unitgraph_less_than_int32(A) and _unitgraph_less_than_int32(B)): + raise ValueError( + 'For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.') + + C_gidx, C_weights = F.csrmm( + A._graph, A.edata[weight_name], B._graph, B.edata[weight_name], num_vtypes) + num_nodes_dict = {srctype: A.num_nodes(srctype), dsttype: B.num_nodes(dsttype)} + C_metagraph, ntypes, etypes, _ = \ + create_metagraph_index(ntypes, [(srctype, etype, dsttype)]) + num_nodes_per_type = [num_nodes_dict[ntype] for ntype in ntypes] + C_gidx = create_heterograph_from_relations( + C_metagraph, [C_gidx], utils.toindex(num_nodes_per_type)) + + C = DGLHeteroGraph(C_gidx, ntypes, etypes) + C.edata[weight_name] = C_weights + return C + +def adj_sum_graph(graphs, weight_name): + r"""Create a weighted graph whose adjacency matrix is the sum of the + adjacency matrices of the given graphs, whose rows represent source nodes + and columns represent destination nodes. + + All the graphs must be simple graphs, and must have only one edge type. + They also must have the same metagraph, i.e. have the same source node type + and the same destination node type. Moreover, the number of nodes for every + graph must also be the same. + + The metagraph of the returned graph will be the same as the input graphs. + + Unlike ``scipy``, if an edge in the result graph has zero weight, it will + not be removed from the graph. + + Notes + ----- + This function works on both CPU and GPU. For GPU, the number of nodes and + edges must be less than the maximum of ``int32`` (i.e. ``2 ** 31 - 1``) due + to restriction of cuSPARSE. + + The edge weights returned by this function is differentiable w.r.t. the + input edge weights. + + If the graph format is restricted, both graphs must have CSR available. + + Parameters + ---------- + graphs : list[DGLGraph] + The list of graphs. Must have at least one element. + weight_name : str + The feature name of edge weight of both graphs. + + The corresponding edge feature must be scalar. + + Returns + ------- + DGLGraph + The new graph. The edge weight of the returned graph will have the + same feature name as :attr:`weight_name`. + + Examples + -------- + The following shows weighted adjacency matrix summation between two + bipartite graphs. You can also perform this between homogeneous graphs. + + >>> A = dgl.heterograph( + ... {('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])}, + ... num_nodes_dict={'A': 3, 'B': 4}) + >>> B = dgl.heterograph( + ... {('A', 'AB', 'B'): ([1, 2, 0, 2, 1, 0], [0, 3, 2, 1, 3, 3])}, + ... num_nodes_dict={'A': 3, 'B': 4}) + >>> A.edata['w'] = torch.randn(6).requires_grad_() + >>> B.edata['w'] = torch.randn(6).requires_grad_() + + If your graph is a multigraph, you will need to call :func:`dgl.to_simple` + to convert it into a simple graph first. + + >>> A = dgl.to_simple(A) + >>> B = dgl.to_simple(B) + + >>> C = dgl.adj_sum_graph([A, B], 'w') + >>> C.edges() + (tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 2]), + tensor([0, 2, 3, 2, 0, 3, 0, 1, 2, 3])) + + Note that this function is differentiable: + + >>> C.edata['w'].sum().backward() + >>> A.edata['w'].grad + tensor([1., 1., 1., 1., 1., 1.]) + + >>> B.edata['w'].grad + tensor([1., 1., 1., 1., 1., 1.]) + """ + if len(graphs) == 0: + raise ValueError('The list of graphs must not be empty.') + + if graphs[0].device != F.cpu(): + if not all(_unitgraph_less_than_int32(A) for A in graphs): + raise ValueError( + 'For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.') + metagraph = graphs[0]._graph.metagraph + num_nodes = utils.toindex( + [graphs[0]._graph.number_of_nodes(i) for i in range(graphs[0]._graph.number_of_ntypes())]) + weights = [A.edata[weight_name] for A in graphs] + gidxs = [A._graph for A in graphs] + C_gidx, C_weights = F.csrsum(gidxs, weights) + C_gidx = create_heterograph_from_relations(metagraph, [C_gidx], num_nodes) + + C = DGLHeteroGraph(C_gidx, graphs[0].ntypes, graphs[0].etypes) + C.edata[weight_name] = C_weights + return C + def as_heterograph(g, ntype='_U', etype='_E'): # pylint: disable=unused-argument """Convert a DGLGraph to a DGLHeteroGraph with one node and edge type. diff --git a/src/array/array_op.h b/src/array/array_op.h index bfedc7b757ee..40a83e2421bf 100644 --- a/src/array/array_op.h +++ b/src/array/array_op.h @@ -104,12 +104,19 @@ bool CSRIsSorted(CSRMatrix csr); template runtime::NDArray CSRGetData( - CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, + CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids, runtime::NDArray weights, DType filler); +template +runtime::NDArray CSRGetData( + CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, + runtime::NDArray weights, DType filler) { + return CSRGetData(csr, rows, cols, false, weights, filler); +} + template NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { - return CSRGetData(csr, rows, cols, NullArray(rows->dtype), -1); + return CSRGetData(csr, rows, cols, true, NullArray(rows->dtype), -1); } template diff --git a/src/array/cpu/csr_get_data.cc b/src/array/cpu/csr_get_data.cc index 27d8c76c12ee..41365aec1e68 100644 --- a/src/array/cpu/csr_get_data.cc +++ b/src/array/cpu/csr_get_data.cc @@ -39,7 +39,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data, template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) { + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) { const int64_t rowlen = rows->shape[0]; const int64_t collen = cols->shape[0]; @@ -56,7 +56,6 @@ NDArray CSRGetData( const IdType* data = CSRHasData(csr)? static_cast(csr.data->data) : nullptr; const int64_t retlen = std::max(rowlen, collen); - bool return_eids = IsNullArray(weights); const DType* weight_data = return_eids ? nullptr : weights.Ptr(); if (return_eids) BUG_IF_FAIL(DLDataTypeTraits::dtype == rows->dtype) << @@ -105,19 +104,19 @@ NDArray CSRGetData( } template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); // For CSRGetData(CSRMatrix, NDArray, NDArray) template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int32_t filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int64_t filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler); } // namespace impl } // namespace aten diff --git a/src/array/cpu/csr_mm.cc b/src/array/cpu/csr_mm.cc index a95942735c58..44aec76cffad 100644 --- a/src/array/cpu/csr_mm.cc +++ b/src/array/cpu/csr_mm.cc @@ -127,7 +127,9 @@ std::pair CSRMM( B_indptr, B_indices, B_eids, B_data, C_indptr_data, C_indices_data, C_weights_data, M); - return {CSRMatrix(M, P, C_indptr, C_indices), C_weights}; + return { + CSRMatrix(M, P, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)), + C_weights}; } template std::pair CSRMM( diff --git a/src/array/cpu/csr_sum.cc b/src/array/cpu/csr_sum.cc index 5c7211a0082b..9189da6a7620 100644 --- a/src/array/cpu/csr_sum.cc +++ b/src/array/cpu/csr_sum.cc @@ -124,7 +124,9 @@ std::pair CSRSum( A_indptr, A_indices, A_eids, A_data, C_indptr_data, C_indices_data, C_weights_data, M); - return {CSRMatrix(M, N, C_indptr, C_indices), C_weights}; + return { + CSRMatrix(M, N, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)), + C_weights}; } template std::pair CSRSum( diff --git a/src/array/cuda/csr_get_data.cu b/src/array/cuda/csr_get_data.cu index 9b654bba7672..2db4bc075b57 100644 --- a/src/array/cuda/csr_get_data.cu +++ b/src/array/cuda/csr_get_data.cu @@ -19,7 +19,7 @@ namespace impl { template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) { + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) { const int64_t rowlen = rows->shape[0]; const int64_t collen = cols->shape[0]; @@ -37,7 +37,6 @@ NDArray CSRGetData( auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); const int nt = cuda::FindNumThreads(rstlen); const int nb = (rstlen + nt - 1) / nt; - bool return_eids = IsNullArray(weights); if (return_eids) BUG_IF_FAIL(DLDataTypeTraits::dtype == rows->dtype) << "DType does not match row's dtype."; @@ -54,19 +53,19 @@ NDArray CSRGetData( } template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); // For CSRGetData(CSRMatrix, NDArray, NDArray) template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int32_t filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler); template NDArray CSRGetData( - CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int64_t filler); + CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler); } // namespace impl } // namespace aten diff --git a/src/array/cuda/csr_mm.cu b/src/array/cuda/csr_mm.cu index f5c3341000fd..0a59ca8b50aa 100644 --- a/src/array/cuda/csr_mm.cu +++ b/src/array/cuda/csr_mm.cu @@ -118,7 +118,10 @@ std::pair CusparseSpgemm( CUSPARSE_CALL(cusparseDestroySpMat(matA)); CUSPARSE_CALL(cusparseDestroySpMat(matB)); CUSPARSE_CALL(cusparseDestroySpMat(matC)); - return {CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns), dC_weights}; + return { + CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns, + NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)), + dC_weights}; } #else // __CUDACC_VER_MAJOR__ != 11 @@ -197,7 +200,9 @@ std::pair CusparseSpgemm( CUSPARSE_CALL(cusparseDestroyMatDescr(matC)); CUSPARSE_CALL(cusparseDestroyMatDescr(matD)); - return {CSRMatrix(m, k, C_indptr, C_indices), C_weights}; + return { + CSRMatrix(m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)), + C_weights}; } #endif // __CUDACC_VER_MAJOR__ == 11 @@ -240,7 +245,8 @@ std::pair CSRMM( if (cast) { CSRMatrix C = result.first; return { - CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64)), + CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64), + AsNumBits(C.data, 64)), result.second}; } else { return result; diff --git a/src/array/cuda/csr_sum.cu b/src/array/cuda/csr_sum.cu index 6e4d436e003e..c1b5e2ee2a7d 100644 --- a/src/array/cuda/csr_sum.cu +++ b/src/array/cuda/csr_sum.cu @@ -48,7 +48,7 @@ std::pair CusparseCsrgeam2( cusparseSetPointerMode(thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST); size_t workspace_size = 0; /* prepare output C */ - IdArray dC_csrOffsets = IdArray::Empty({A.num_rows+1}, A.indptr->dtype, ctx); + IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx); IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr(); IdArray dC_columns; NDArray dC_weights; @@ -97,7 +97,9 @@ std::pair CusparseCsrgeam2( CUSPARSE_CALL(cusparseDestroyMatDescr(matA)); CUSPARSE_CALL(cusparseDestroyMatDescr(matB)); CUSPARSE_CALL(cusparseDestroyMatDescr(matC)); - return {CSRMatrix(A.num_rows, A.num_cols, dC_csrOffsets, dC_columns), + return { + CSRMatrix(A.num_rows, A.num_cols, dC_csrOffsets, dC_columns, + NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true), dC_weights}; } } // namespace cusparse @@ -112,22 +114,31 @@ std::pair CSRSum( // Cast 64 bit indices to 32 bit std::vector newAs; + newAs.reserve(n); bool cast = false; if (As[0].indptr->dtype.bits == 64) { - newAs.reserve(n); for (int i = 0; i < n; ++i) newAs.emplace_back( As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32), AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32)); cast = true; + } else { + for (int i = 0; i < n; ++i) + newAs.push_back(As[i]); + } + + // cuSPARSE csrgeam2 requires the CSR to be sorted. + // TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure how to do it. + for (int i = 0; i < n; ++i) { + if (!newAs[i].sorted) + newAs[i] = CSRSort(newAs[i]); } - const std::vector &As_ref = cast ? newAs : As; // Reorder weights if A[i] has edge IDs std::vector A_weights_reordered(n); for (int i = 0; i < n; ++i) { - if (CSRHasData(As[i])) - A_weights_reordered[i] = IndexSelect(A_weights[i], As[i].data); + if (CSRHasData(newAs[i])) + A_weights_reordered[i] = IndexSelect(A_weights[i], newAs[i].data); else A_weights_reordered[i] = A_weights[i]; } @@ -135,18 +146,20 @@ std::pair CSRSum( // Loop and sum auto result = std::make_pair( CSRMatrix( - As_ref[0].num_rows, As_ref[0].num_cols, - As_ref[0].indptr, As_ref[0].indices), + newAs[0].num_rows, newAs[0].num_cols, + newAs[0].indptr, newAs[0].indices, + NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)), A_weights_reordered[0]); // Weights already reordered so we don't need As[0].data for (int64_t i = 1; i < n; ++i) result = cusparse::CusparseCsrgeam2( - result.first, result.second, As_ref[i], A_weights_reordered[i]); + result.first, result.second, newAs[i], A_weights_reordered[i]); // Cast 32 bit indices back to 64 bit if necessary if (cast) { CSRMatrix C = result.first; return { - CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64)), + CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64), + AsNumBits(C.data, 64), true), result.second}; } else { return result; diff --git a/src/graph/unit_graph.cc b/src/graph/unit_graph.cc index 3e9adbc0ab0b..8865a17d36bb 100644 --- a/src/graph/unit_graph.cc +++ b/src/graph/unit_graph.cc @@ -453,9 +453,9 @@ class UnitGraph::CSR : public BaseHeteroGraph { : BaseHeteroGraph(metagraph) { CHECK(aten::IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indices)); - CHECK(aten::IsValidIdArray(edge_ids)); - CHECK_EQ(indices->shape[0], edge_ids->shape[0]) - << "indices and edge id arrays should have the same length"; + if (aten::IsValidIdArray(edge_ids)) + CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids)) + << "edge id arrays should have the same length as indices if not empty"; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; } diff --git a/tests/compute/test_csrmm.py b/tests/compute/test_csrmm.py index 4dbf1787e76c..904b469263e4 100644 --- a/tests/compute/test_csrmm.py +++ b/tests/compute/test_csrmm.py @@ -1,5 +1,6 @@ import numpy as np import scipy.sparse as ssp +import pytest import dgl from utils import parametrize_dtype import backend as F @@ -11,51 +12,199 @@ def _random_simple_graph(idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, et a = ssp.csr_matrix((val, (src, dst)), shape=(M, N)) a.sum_duplicates() a = a.tocoo() + # shuffle edges + perm = np.random.permutation(a.nnz) + row = a.row[perm] + col = a.col[perm] + val = a.data[perm] + a = ssp.csr_matrix((val, (row, col)), shape=(M, N)) + A = dgl.heterograph( - {('A', 'AB', 'B'): ( - F.copy_to(F.tensor(a.row, dtype=idtype), ctx), - F.copy_to(F.tensor(a.col, dtype=idtype), ctx))}, - num_nodes_dict={'A': a.shape[0], 'B': a.shape[1]}) - A.edata['w'] = F.copy_to(F.tensor(a.data, dtype=dtype), ctx) + {(srctype, etype, dsttype): ( + F.copy_to(F.tensor(row, dtype=idtype), ctx), + F.copy_to(F.tensor(col, dtype=idtype), ctx))}, + num_nodes_dict={srctype: a.shape[0], dsttype: a.shape[1]}) + A.edata['w'] = F.copy_to(F.tensor(val, dtype=dtype), ctx) return a, A @parametrize_dtype -def test_csrmm(idtype): - for dtype in [F.float32, F.float64]: - a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') - b, B = _random_simple_graph(idtype, dtype, F.ctx(), 600, 700, 9000, 'B', 'C', 'BC') - C, C_weights = dgl.sparse.csrmm(A._graph, A.edata['w'], B._graph, B.edata['w'], 2) - C_adj = C.adjacency_matrix_scipy(0, True, 'csr') - C_adj.data = F.asnumpy(C_weights) - C_adj = F.tensor(C_adj.todense(), dtype=dtype) - c = F.tensor((a * b).todense(), dtype=dtype) - assert F.allclose(C_adj, c) +@pytest.mark.parametrize('dtype', [F.float32, F.float64]) +def test_csrmm(idtype, dtype): + a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') + b, B = _random_simple_graph(idtype, dtype, F.ctx(), 600, 700, 9000, 'B', 'C', 'BC') + C, C_weights = dgl.sparse._csrmm(A._graph, A.edata['w'], B._graph, B.edata['w'], 2) + C_adj = C.adjacency_matrix_scipy(0, True, 'csr') + C_adj.data = F.asnumpy(C_weights) + C_adj = F.tensor(C_adj.todense(), dtype=dtype) + c = F.tensor((a * b).todense(), dtype=dtype) + assert F.allclose(C_adj, c) + +@parametrize_dtype +@pytest.mark.parametrize('dtype', [F.float32, F.float64]) +@pytest.mark.parametrize('num_vtypes', [1, 2]) +def test_csrmm_backward(idtype, dtype, num_vtypes): + a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB') + b, B = _random_simple_graph(idtype, dtype, F.ctx(), 4, 3, 6, 'B', 'A' if num_vtypes == 1 else 'C', 'BA') + A_row, A_col = A.edges(order='eid') + B_row, B_col = B.edges(order='eid') + A_row = F.asnumpy(A_row) + A_col = F.asnumpy(A_col) + B_row = F.asnumpy(B_row) + B_col = F.asnumpy(B_col) + a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype)) + b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype)) + + A.edata['w'] = F.attach_grad(A.edata['w']) + B.edata['w'] = F.attach_grad(B.edata['w']) + + with F.record_grad(): + C = dgl.adj_product_graph(A, B, 'w') + assert len(C.ntypes) == num_vtypes + assert len(C.etypes) == 1 + C_dense = np.zeros((3, 3)) + C_row, C_col = C.edges(order='eid') + C_row = F.asnumpy(C_row) + C_col = F.asnumpy(C_col) + C_dense[C_row, C_col] = F.asnumpy(C.edata['w']) + c_dense = F.matmul(a_dense, b_dense) + assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4) + + F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense)) + a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col] + b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col] + A_spspmm_grad = F.asnumpy(F.grad(A.edata['w'])) + B_spspmm_grad = F.asnumpy(F.grad(B.edata['w'])) + assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4) + assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4) + +@parametrize_dtype +@pytest.mark.parametrize('dtype', [F.float32, F.float64]) +def test_csrsum(idtype, dtype): + a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') + b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') + C, C_weights = dgl.sparse._csrsum([A._graph, B._graph], [A.edata['w'], B.edata['w']]) + C_adj = C.adjacency_matrix_scipy(0, True, 'csr') + C_adj.data = F.asnumpy(C_weights) + C_adj = F.tensor(C_adj.todense(), dtype=dtype) + c = F.tensor((a + b).todense(), dtype=dtype) + assert F.allclose(C_adj, c) + +@parametrize_dtype +@pytest.mark.parametrize('dtype', [F.float32, F.float64]) +@pytest.mark.parametrize('nelems', [1, 2]) +def test_csrsum_backward(idtype, dtype, nelems): + a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB') + b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB') + A_row, A_col = A.edges(order='eid') + B_row, B_col = B.edges(order='eid') + A_row = F.asnumpy(A_row) + A_col = F.asnumpy(A_col) + B_row = F.asnumpy(B_row) + B_col = F.asnumpy(B_col) + a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype)) + b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype)) + + A.edata['w'] = F.attach_grad(A.edata['w']) + B.edata['w'] = F.attach_grad(B.edata['w']) + + with F.record_grad(): + if nelems == 2: + # Test for two element case + C = dgl.adj_sum_graph([A, B], 'w') + assert C.canonical_etypes == A.canonical_etypes + C_dense = np.zeros((3, 4)) + C_row, C_col = C.edges(order='eid') + C_row = F.asnumpy(C_row) + C_col = F.asnumpy(C_col) + C_dense[C_row, C_col] = F.asnumpy(C.edata['w']) + c_dense = a_dense + b_dense + assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4) + + F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense)) + a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col] + b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col] + A_spspmm_grad = F.asnumpy(F.grad(A.edata['w'])) + B_spspmm_grad = F.asnumpy(F.grad(B.edata['w'])) + assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4) + assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4) + elif nelems == 1: + # Test for single element case + C = dgl.adj_sum_graph([A], 'w') + assert C.canonical_etypes == A.canonical_etypes + C_dense = np.zeros((3, 4)) + C_row, C_col = C.edges(order='eid') + C_row = F.asnumpy(C_row) + C_col = F.asnumpy(C_col) + C_dense[C_row, C_col] = F.asnumpy(C.edata['w']) + c_dense = a_dense + assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4) + + F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense)) + a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col] + A_spspmm_grad = F.asnumpy(F.grad(A.edata['w'])) + assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4) @parametrize_dtype -def test_csrsum(idtype): - for dtype in [F.float32, F.float64]: - a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') - b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') - C, C_weights = dgl.sparse.csrsum([A._graph, B._graph], [A.edata['w'], B.edata['w']]) - C_adj = C.adjacency_matrix_scipy(0, True, 'csr') - C_adj.data = F.asnumpy(C_weights) - C_adj = F.tensor(C_adj.todense(), dtype=dtype) - c = F.tensor((a + b).todense(), dtype=dtype) - assert F.allclose(C_adj, c) +@pytest.mark.parametrize('dtype', [F.float32, F.float64]) +@pytest.mark.parametrize('A_nnz', [9000, 0]) +@pytest.mark.parametrize('B_nnz', [9000, 0]) +def test_csrmask(idtype, dtype, A_nnz, B_nnz): + a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, A_nnz, 'A', 'B', 'AB') + b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, B_nnz, 'A', 'B', 'AB') + C = dgl.sparse._csrmask(A._graph, A.edata['w'], B._graph) + B_row, B_col = B.edges(order='eid') + B_row = F.asnumpy(B_row) + B_col = F.asnumpy(B_col) + c = F.tensor(a.todense()[B_row, B_col], dtype) + assert F.allclose(C, c) @parametrize_dtype -def test_csrmask(idtype): - for dtype in [F.float32, F.float64]: - a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') - b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') - C = dgl.sparse.csrmask(A._graph, A.edata['w'], B._graph) - c = F.tensor(a.tocsr()[b != 0], dtype) - assert F.allclose(C, c) +@pytest.mark.parametrize('dtype', [F.float32, F.float64]) +def test_csrmask_backward(idtype, dtype): + a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB') + b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB') + A_row, A_col = A.edges(order='eid') + B_row, B_col = B.edges(order='eid') + A_row = F.asnumpy(A_row) + A_col = F.asnumpy(A_col) + B_row = F.asnumpy(B_row) + B_col = F.asnumpy(B_col) + a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype)) + + A.edata['w'] = F.attach_grad(A.edata['w']) + + with F.record_grad(): + # Test for two element case + C1 = F.csrmask(A._graph, A.edata['w'], B._graph) + if dgl.backend.backend_name == 'tensorflow': + import tensorflow as tf + C2 = tf.gather_nd(a_dense, tf.stack([B_row, B_col], 1)) + else: + C2 = a_dense[B_row, B_col] + assert F.allclose(C1, C2, rtol=1e-4, atol=1e-4) + + F.backward(F.reduce_sum(C1) + F.reduce_sum(C2)) + a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col] + A_spspmm_grad = F.asnumpy(F.grad(A.edata['w'])) + assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4) + if __name__ == '__main__': - test_csrmm(F.int32) - test_csrmm(F.int64) - test_csrsum(F.int32) - test_csrsum(F.int64) - test_csrmask(F.int32) - test_csrmask(F.int64) + test_csrmm(F.int32, F.float32) + test_csrmm(F.int64, F.float32) + test_csrsum(F.int32, F.float32) + test_csrsum(F.int64, F.float32) + test_csrmask(F.int32, F.float32, 9000, 9000) + test_csrmask(F.int64, F.float32, 9000, 0) + test_csrmask(F.int32, F.float32, 0, 9000) + test_csrmask(F.int64, F.float32, 0, 0) + test_csrmm_backward(F.int32, F.float32, 1) + test_csrmm_backward(F.int64, F.float32, 1) + test_csrmm_backward(F.int32, F.float32, 2) + test_csrmm_backward(F.int64, F.float32, 2) + test_csrsum_backward(F.int32, F.float32, 1) + test_csrsum_backward(F.int64, F.float32, 1) + test_csrsum_backward(F.int32, F.float32, 2) + test_csrsum_backward(F.int64, F.float32, 2) + test_csrmask_backward(F.int32, F.float32) + test_csrmask_backward(F.int64, F.float32)