Skip to content

Commit

Permalink
[Feature] Python interface for adjacency matrix summation and multipl…
Browse files Browse the repository at this point in the history
…ication (dmlc#2893)

* test commit

* fixes

* oops

* add docs

* lint

* why does it say I have a trailing whitespace

* oh ok

* fixes

* why there's an invalid argument error

* address comments

* fix

* address comments
  • Loading branch information
BarclayII authored May 17, 2021
1 parent 29fec7d commit 657c220
Show file tree
Hide file tree
Showing 18 changed files with 928 additions and 115 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/python/dgl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ Operators for generating new graphs by manipulating the structure of the existin
line_graph
khop_graph
metapath_reachable_graph
adj_product_graph
adj_sum_graph

.. _api-batch:

Expand Down
85 changes: 85 additions & 0 deletions python/dgl/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,91 @@ def scatter_add(x, idx, m):
"""
pass

def csrmm(A, A_weights, B, B_weights, num_vtypes):
    """Compute weighted adjacency matrix multiplication.

    This is a backend interface stub; each framework backend supplies the
    actual implementation.

    Notes
    -----
    Both A and B must allow creation of CSR representations, and must be simple graphs
    (i.e. having at most one edge between two nodes).

    The output unit graph has no format restriction.

    Parameters
    ----------
    A : HeteroGraphIndex
        The unit graph as left operand.
    A_weights : Tensor
        The edge weights of A.  Must be a 1D vector.
    B : HeteroGraphIndex
        The unit graph as right operand.
    B_weights : Tensor
        The edge weights of B.  Must be a 1D vector.
    num_vtypes : int
        The number of node types of the output graph.  Must be either 1 or 2.

    Returns
    -------
    HeteroGraphIndex
        The output unit graph.
    Tensor
        The output edge weights.
    """
    pass

def csrsum(gidxs, weights):
    """Compute weighted adjacency matrix summation.

    This is a backend interface stub; each framework backend supplies the
    actual implementation.

    Notes
    -----
    All unit graphs must allow creation of CSR representations, and must be simple graphs
    (i.e. having at most one edge between two nodes).

    The output unit graph has no format restriction.

    Parameters
    ----------
    gidxs : list[HeteroGraphIndex]
        The unit graphs.
    weights : list[Tensor]
        The edge weights of each graph.  Must be 1D vectors.

    Returns
    -------
    HeteroGraphIndex
        The output unit graph.
    Tensor
        The output edge weights.
    """
    pass

def csrmask(A, A_weights, B):
    """Retrieve the values in the weighted adjacency matrix of graph :attr:`A` at the
    non-zero positions of graph :attr:`B`'s adjacency matrix.

    In scipy, this is equivalent to ``A[B != 0]``.

    This is a backend interface stub; each framework backend supplies the
    actual implementation.

    Notes
    -----
    Both A and B must allow creation of CSR representations, and must be simple graphs
    (i.e. having at most one edge between two nodes).

    Parameters
    ----------
    A : HeteroGraphIndex
        The unit graph as left operand.
    A_weights : Tensor
        The edge weights of A.  Must be a 1D vector.
    B : HeteroGraphIndex
        The unit graph as right operand.

    Returns
    -------
    Tensor
        The output tensor.
    """
    pass


###############################################################################
# Other interfaces
Expand Down
90 changes: 89 additions & 1 deletion python/dgl/backend/mxnet/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
from ...heterograph_index import create_unitgraph_from_csr

__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add']
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']


def _scatter_nd(index, src, n_rows):
Expand Down Expand Up @@ -379,3 +382,88 @@ def backward(self, dy):
def scatter_add(x, idx, m):
    """Apply a fresh ``ScatterAdd(idx, m)`` autograd function to ``x``.

    A new function object is built for every call and invoked once on ``x``;
    its result is returned unchanged.
    """
    return ScatterAdd(idx, m)(x)


class CSRMM(mx.autograd.Function):
    """MXNet autograd function wrapping ``_csrmm``.

    The two operand graphs and the output node-type count are stored on the
    instance; only the edge weights are passed through ``forward``.
    """

    def __init__(self, gidxA, gidxB, num_vtypes):
        super().__init__()
        self.gidxA, self.gidxB = gidxA, gidxB
        self.num_vtypes = num_vtypes

    def forward(self, A_weights, B_weights):
        """Run ``_csrmm`` and return the CSR pieces of the result plus its weights.

        Returns ``(nrows, ncols, C_indptr, C_indices, C_eids, C_weights)`` where
        ``nrows`` and ``ncols`` are 1-element int64 NDArrays.
        """
        gidxC, C_weights = _csrmm(
            self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes)
        num_rows, num_cols, indptr, indices, eids = gidxC.adjacency_matrix_tensors(
            0, True, 'csr')
        # Note: the returned indptr, indices and eids tensors MUST be the same
        # as the underlying tensors of the created graph gidxC.
        self.backward_cache = gidxC
        self.save_for_backward(A_weights, B_weights)
        # The shape scalars are boxed into NDArrays here; the csrmm() wrapper
        # unboxes them again with asscalar().
        num_rows = nd.array([num_rows], dtype='int64')
        num_cols = nd.array([num_cols], dtype='int64')
        return num_rows, num_cols, indptr, indices, eids, C_weights

    def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        """Return the gradients w.r.t. ``A_weights`` and ``B_weights``.

        Only ``dC_weights`` is meaningful; the other incoming gradients are ignored.
        """
        gidxA, gidxB, gidxC = self.gidxA, self.gidxB, self.backward_cache
        A_weights, B_weights = self.saved_tensors
        # Gradient candidates: (C, dC) times reversed B for A, and
        # reversed A times (C, dC) for B.
        grad_gidxA, grad_A = _csrmm(
            gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
        grad_gidxB, grad_B = _csrmm(
            gidxA.reverse(), A_weights, gidxC, dC_weights, gidxB.number_of_ntypes())
        # Keep only the entries located at the existing edges of each operand.
        grad_A = _csrmask(grad_gidxA, grad_A, gidxA)
        grad_B = _csrmask(grad_gidxB, grad_B, gidxB)
        return grad_A, grad_B

def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
    """Weighted adjacency matrix multiplication through the ``CSRMM`` autograd function.

    The function's CSR outputs are reassembled into a unit graph allowing the
    "coo", "csr" and "csc" formats.

    Returns
    -------
    tuple
        ``(gidxC, C_weights)``: the output unit graph and its edge weights.
    """
    outputs = CSRMM(gidxA, gidxB, num_vtypes)(A_weights, B_weights)
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = outputs
    allowed_formats = ["coo", "csr", "csc"]
    gidxC = create_unitgraph_from_csr(
        num_vtypes, nrows.asscalar(), ncols.asscalar(),
        C_indptr, C_indices, C_eids, allowed_formats)
    return gidxC, C_weights

class CSRSum(mx.autograd.Function):
    """MXNet autograd function wrapping ``_csrsum``.

    The operand graphs are stored on the instance; the per-graph edge weights
    are the positional inputs of ``forward``.
    """

    def __init__(self, gidxs):
        super().__init__()
        self.gidxs = gidxs

    def forward(self, *weights):
        """Run ``_csrsum`` and return the CSR pieces of the result plus its weights.

        Returns ``(nrows, ncols, C_indptr, C_indices, C_eids, C_weights)`` where
        ``nrows`` and ``ncols`` are 1-element int64 NDArrays.
        """
        gidxC, C_weights = _csrsum(self.gidxs, weights)
        num_rows, num_cols, indptr, indices, eids = gidxC.adjacency_matrix_tensors(
            0, True, 'csr')
        # Note: the returned indptr, indices and eids tensors MUST be the same
        # as the underlying tensors of the created graph gidxC.
        self.backward_cache = gidxC
        # The shape scalars are boxed into NDArrays here; the csrsum() wrapper
        # unboxes them again with asscalar().
        num_rows = nd.array([num_rows], dtype='int64')
        num_cols = nd.array([num_cols], dtype='int64')
        return num_rows, num_cols, indptr, indices, eids, C_weights

    def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        """Return one gradient per input weight tensor.

        Only ``dC_weights`` is meaningful; the other incoming gradients are ignored.
        Each gradient is ``dC_weights`` read at the edge positions of the
        corresponding input graph.
        """
        gidxC = self.backward_cache
        grads = [csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs]
        return tuple(grads)

def csrsum(gidxs, weights):
    """Weighted adjacency matrix summation through the ``CSRSum`` autograd function.

    The function's CSR outputs are reassembled into a unit graph allowing the
    "coo", "csr" and "csc" formats.  The number of node types is taken from
    the first input graph.

    Returns
    -------
    tuple
        ``(gidxC, C_weights)``: the output unit graph and its edge weights.
    """
    outputs = CSRSum(gidxs)(*weights)
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = outputs
    num_vtypes = gidxs[0].number_of_ntypes()
    allowed_formats = ["coo", "csr", "csc"]
    gidxC = create_unitgraph_from_csr(
        num_vtypes, nrows.asscalar(), ncols.asscalar(),
        C_indptr, C_indices, C_eids, allowed_formats)
    return gidxC, C_weights

class CSRMask(mx.autograd.Function):
    """MXNet autograd function wrapping ``_csrmask``.

    Both graphs are stored on the instance; only ``A_weights`` flows through
    ``forward``.
    """

    def __init__(self, gidxA, gidxB):
        super().__init__()
        self.gidxA, self.gidxB = gidxA, gidxB

    def forward(self, A_weights):
        """Return the result of ``_csrmask`` for A's weights against graph B."""
        masked = _csrmask(self.gidxA, A_weights, self.gidxB)
        return masked

    def backward(self, dB_weights):
        """Return the gradient w.r.t. ``A_weights``: the same masking with A and B swapped."""
        grad_A = _csrmask(self.gidxB, dB_weights, self.gidxA)
        return grad_A

def csrmask(gidxA, A_weights, gidxB):
    """Apply a fresh ``CSRMask(gidxA, gidxB)`` autograd function to ``A_weights``."""
    return CSRMask(gidxA, gidxB)(A_weights)
79 changes: 78 additions & 1 deletion python/dgl/backend/pytorch/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...heterograph_index import create_unitgraph_from_csr

if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import custom_fwd, custom_bwd
Expand All @@ -24,7 +26,8 @@ def decorate_bwd(*args, **kwargs):
return bwd(*args, **kwargs)
return decorate_bwd

__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add']
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']


def _reduce_grad(grad, shape):
Expand Down Expand Up @@ -303,6 +306,62 @@ def backward(ctx, dy):
return dy[idx], None, None


class CSRMM(th.autograd.Function):
    """PyTorch autograd function wrapping ``_csrmm``."""

    @staticmethod
    def forward(ctx, gidxA, A_weights, gidxB, B_weights, num_vtypes):
        """Run ``_csrmm`` and return the CSR pieces of the result plus its weights.

        Returns ``(nrows, ncols, C_indptr, C_indices, C_eids, C_weights)`` with
        ``nrows`` and ``ncols`` wrapped via ``th.tensor``.
        """
        gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)
        num_rows, num_cols, indptr, indices, eids = gidxC.adjacency_matrix_tensors(
            0, True, 'csr')
        # Note: the returned indptr, indices and eids tensors MUST be the same
        # as the underlying tensors of the created graph gidxC.
        ctx.backward_cache = gidxA, gidxB, gidxC
        ctx.save_for_backward(A_weights, B_weights)
        return th.tensor(num_rows), th.tensor(num_cols), indptr, indices, eids, C_weights

    @staticmethod
    def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        """Return gradients aligned with ``forward``'s inputs.

        Only ``dC_weights`` is meaningful.  The graph and ``num_vtypes`` inputs
        receive ``None``.
        """
        gidxA, gidxB, gidxC = ctx.backward_cache
        A_weights, B_weights = ctx.saved_tensors
        # Gradient candidates: (C, dC) times reversed B for A, and
        # reversed A times (C, dC) for B.
        grad_gidxA, grad_A = csrmm(
            gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
        grad_gidxB, grad_B = csrmm(
            gidxA.reverse(), A_weights, gidxC, dC_weights, gidxB.number_of_ntypes())
        # Keep only the entries located at the existing edges of each operand.
        grad_A = csrmask(grad_gidxA, grad_A, gidxA)
        grad_B = csrmask(grad_gidxB, grad_B, gidxB)
        return None, grad_A, None, grad_B, None


class CSRSum(th.autograd.Function):
    """PyTorch autograd function wrapping ``_csrsum``."""

    @staticmethod
    def forward(ctx, gidxs, *weights):
        """Run ``_csrsum`` and return the CSR pieces of the result plus its weights.

        Returns ``(nrows, ncols, C_indptr, C_indices, C_eids, C_weights)`` with
        ``nrows`` and ``ncols`` wrapped via ``th.tensor``.
        """
        # PyTorch tensors must be explicit arguments of the forward function
        gidxC, C_weights = _csrsum(gidxs, weights)
        num_rows, num_cols, indptr, indices, eids = gidxC.adjacency_matrix_tensors(
            0, True, 'csr')
        # Note: the returned indptr, indices and eids tensors MUST be the same
        # as the underlying tensors of the created graph gidxC.
        ctx.backward_cache = gidxs, gidxC
        return th.tensor(num_rows), th.tensor(num_cols), indptr, indices, eids, C_weights

    @staticmethod
    def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        """Return ``None`` for ``gidxs`` followed by one gradient per weight tensor.

        Only ``dC_weights`` is meaningful; the other incoming gradients are ignored.
        """
        gidxs, gidxC = ctx.backward_cache
        weight_grads = [csrmask(gidxC, dC_weights, gidx) for gidx in gidxs]
        return (None,) + tuple(weight_grads)


class CSRMask(th.autograd.Function):
    """PyTorch autograd function wrapping ``_csrmask``."""

    @staticmethod
    def forward(ctx, gidxA, A_weights, gidxB):
        """Return the result of ``_csrmask`` for A's weights against graph B."""
        ctx.backward_cache = gidxA, gidxB
        masked = _csrmask(gidxA, A_weights, gidxB)
        return masked

    @staticmethod
    def backward(ctx, dB_weights):
        """Return the gradient for ``A_weights`` only; both graph inputs get ``None``."""
        gidxA, gidxB = ctx.backward_cache
        # Same masking with the roles of A and B swapped.
        grad_A = csrmask(gidxB, dB_weights, gidxA)
        return None, grad_A, None


def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
    """Forward all arguments, in order, to ``GSpMM.apply`` and return its result."""
    result = GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
    return result

Expand All @@ -320,3 +379,21 @@ def segment_reduce(op, x, offsets):

def scatter_add(x, idx, m):
    """Forward all arguments, in order, to ``ScatterAdd.apply`` and return its result."""
    result = ScatterAdd.apply(x, idx, m)
    return result

def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
    """Weighted adjacency matrix multiplication through ``CSRMM.apply``.

    The function's CSR outputs are reassembled into a unit graph allowing the
    "coo", "csr" and "csc" formats.

    Returns
    -------
    tuple
        ``(gidxC, C_weights)``: the output unit graph and its edge weights.
    """
    outputs = CSRMM.apply(gidxA, A_weights, gidxB, B_weights, num_vtypes)
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = outputs
    allowed_formats = ["coo", "csr", "csc"]
    gidxC = create_unitgraph_from_csr(
        num_vtypes, nrows.item(), ncols.item(),
        C_indptr, C_indices, C_eids, allowed_formats)
    return gidxC, C_weights

def csrsum(gidxs, weights):
    """Weighted adjacency matrix summation through ``CSRSum.apply``.

    The function's CSR outputs are reassembled into a unit graph allowing the
    "coo", "csr" and "csc" formats.  The number of node types is taken from
    the first input graph.

    Returns
    -------
    tuple
        ``(gidxC, C_weights)``: the output unit graph and its edge weights.
    """
    outputs = CSRSum.apply(gidxs, *weights)
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = outputs
    num_vtypes = gidxs[0].number_of_ntypes()
    allowed_formats = ["coo", "csr", "csc"]
    gidxC = create_unitgraph_from_csr(
        num_vtypes, nrows.item(), ncols.item(),
        C_indptr, C_indices, C_eids, allowed_formats)
    return gidxC, C_weights

def csrmask(gidxA, A_weights, gidxB):
    """Forward all arguments, in order, to ``CSRMask.apply`` and return its result."""
    masked = CSRMask.apply(gidxA, A_weights, gidxB)
    return masked
66 changes: 65 additions & 1 deletion python/dgl/backend/tensorflow/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...heterograph_index import create_unitgraph_from_csr

__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add']
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']


def _scatter_nd(index, src, n_rows):
Expand Down Expand Up @@ -295,3 +298,64 @@ def scatter_add(x, idx, m):
def _lambda(x):
return scatter_add_real(x, idx, m)
return _lambda(x)


def csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes):
    """Run ``_csrmm`` and build the matching gradient closure.

    Returns
    -------
    tuple
        ``(outputs, grad)`` where ``outputs`` is
        ``(nrows, ncols, C_indptr, C_indices, C_eids, C_weights)`` with the two
        shape scalars wrapped in ``tf.constant``, and ``grad`` maps the output
        gradients to ``(dA_weights, dB_weights)``.
    """
    gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)
    num_rows, num_cols, indptr, indices, eids = gidxC.adjacency_matrix_tensors(
        0, True, 'csr')

    def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        # Only the last argument is meaningful.
        # Gradient candidates: (C, dC) times reversed B for A, and
        # reversed A times (C, dC) for B.
        grad_gidxA, grad_A = _csrmm(
            gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
        grad_gidxB, grad_B = _csrmm(
            gidxA.reverse(), A_weights, gidxC, dC_weights, gidxB.number_of_ntypes())
        # Keep only the entries located at the existing edges of each operand.
        grad_A = _csrmask(grad_gidxA, grad_A, gidxA)
        grad_B = _csrmask(grad_gidxB, grad_B, gidxB)
        return grad_A, grad_B

    outputs = (tf.constant(num_rows), tf.constant(num_cols), indptr, indices, eids, C_weights)
    return outputs, grad

def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
    """Weighted adjacency matrix multiplication with a custom TensorFlow gradient.

    Only the two weight tensors are inputs of the ``tf.custom_gradient``
    function; the graphs and ``num_vtypes`` are captured by closure.  The CSR
    outputs are reassembled into a unit graph allowing the "coo", "csr" and
    "csc" formats.

    Returns
    -------
    tuple
        ``(gidxC, C_weights)``: the output unit graph and its edge weights.
    """
    @tf.custom_gradient
    def _lambda(A_weights, B_weights):
        return csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes)

    outputs = _lambda(A_weights, B_weights)
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = outputs
    allowed_formats = ["coo", "csr", "csc"]
    gidxC = create_unitgraph_from_csr(
        num_vtypes, nrows.numpy(), ncols.numpy(),
        C_indptr, C_indices, C_eids, allowed_formats)
    return gidxC, C_weights


def csrsum_real(gidxs, weights):
    """Run ``_csrsum`` and build the matching gradient closure.

    Returns
    -------
    tuple
        ``(outputs, grad)`` where ``outputs`` is
        ``(nrows, ncols, C_indptr, C_indices, C_eids, C_weights)`` with the two
        shape scalars wrapped in ``tf.constant``, and ``grad`` maps the output
        gradients to a tuple with one gradient per input graph.
    """
    gidxC, C_weights = _csrsum(gidxs, weights)
    num_rows, num_cols, indptr, indices, eids = gidxC.adjacency_matrix_tensors(
        0, True, 'csr')

    def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
        # Only the last argument is meaningful.
        weight_grads = [_csrmask(gidxC, dC_weights, gidx) for gidx in gidxs]
        return tuple(weight_grads)

    outputs = (tf.constant(num_rows), tf.constant(num_cols), indptr, indices, eids, C_weights)
    return outputs, grad

def csrsum(gidxs, weights):
    """Weighted adjacency matrix summation with a custom TensorFlow gradient.

    The weight tensors are the inputs of the ``tf.custom_gradient`` function;
    the graphs are captured by closure.  The CSR outputs are reassembled into a
    unit graph allowing the "coo", "csr" and "csc" formats, with the number of
    node types taken from the first input graph.

    Returns
    -------
    tuple
        ``(gidxC, C_weights)``: the output unit graph and its edge weights.
    """
    @tf.custom_gradient
    def _lambda(*weights):
        return csrsum_real(gidxs, weights)

    outputs = _lambda(*weights)
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = outputs
    num_vtypes = gidxs[0].number_of_ntypes()
    allowed_formats = ["coo", "csr", "csc"]
    gidxC = create_unitgraph_from_csr(
        num_vtypes, nrows.numpy(), ncols.numpy(),
        C_indptr, C_indices, C_eids, allowed_formats)
    return gidxC, C_weights


def csrmask_real(gidxA, A_weights, gidxB):
    """Run ``_csrmask`` and build the matching gradient closure.

    Returns
    -------
    tuple
        ``(B_weights, grad)`` where ``grad`` maps the output gradient back to a
        gradient for ``A_weights`` by masking with A and B swapped.
    """
    def grad(dB_weights):
        return _csrmask(gidxB, dB_weights, gidxA)

    masked = _csrmask(gidxA, A_weights, gidxB)
    return masked, grad

def csrmask(gidxA, A_weights, gidxB):
    """Mask A's weights against graph B with a custom TensorFlow gradient.

    Only ``A_weights`` is an input of the ``tf.custom_gradient`` function; both
    graphs are captured by closure.
    """
    @tf.custom_gradient
    def _lambda(A_weights):
        return csrmask_real(gidxA, A_weights, gidxB)

    masked = _lambda(A_weights)
    return masked
Loading

0 comments on commit 657c220

Please sign in to comment.