Skip to content

Commit

Permalink
[Feature] Support direct creation from CSR and CSC (dmlc#3045)
Browse files Browse the repository at this point in the history
* csr and csc creation

* fix

* fix

* fixes to adj transpose

* fine

* raise error if indptr did not match number of nodes

* fix

* huh?

* oh

Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
BarclayII and jermainewang authored Jun 25, 2021
1 parent 2f7ca41 commit acd21a6
Show file tree
Hide file tree
Showing 21 changed files with 364 additions and 192 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch/diffpool/model/dgl_layers/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def forward(self, g, h):
assign_tensor = torch.block_diag(*assign_tensor) # size = (sum_N, batch_size * N_a)

h = torch.matmul(torch.t(assign_tensor), feat)
adj = g.adjacency_matrix(transpose=False, ctx=device)
adj = g.adjacency_matrix(transpose=True, ctx=device)
adj_new = torch.sparse.mm(adj, assign_tensor)
adj_new = torch.mm(torch.t(assign_tensor), adj_new)

Expand Down
4 changes: 2 additions & 2 deletions python/dgl/backend/mxnet/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __init__(self, gidxA, gidxB, num_vtypes):

def forward(self, A_weights, B_weights):
gidxC, C_weights = _csrmm(self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, False, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
Expand Down Expand Up @@ -430,7 +430,7 @@ def __init__(self, gidxs):
def forward(self, *weights):
gidxC, C_weights = _csrsum(self.gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
0, True, 'csr')
0, False, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/backend/pytorch/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class CSRMM(th.autograd.Function):
@staticmethod
def forward(ctx, gidxA, A_weights, gidxB, B_weights, num_vtypes):
gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, False, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
ctx.backward_cache = gidxA, gidxB, gidxC
Expand All @@ -337,7 +337,7 @@ def forward(ctx, gidxs, *weights):
# PyTorch tensors must be explicit arguments of the forward function
gidxC, C_weights = _csrsum(gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
0, True, 'csr')
0, False, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
ctx.backward_cache = gidxs, gidxC
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/backend/tensorflow/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _lambda(x):

def csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes):
gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, False, 'csr')

def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
Expand All @@ -328,7 +328,7 @@ def _lambda(A_weights, B_weights):

def csrsum_real(gidxs, weights):
gidxC, C_weights = _csrsum(gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, False, 'csr')

def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
Expand Down
128 changes: 80 additions & 48 deletions python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,18 @@ def graph(data,
DGL calls this format "tuple of node-tensors". The tensors should have the same
data type of int32/int64 and device context (see below the descriptions of
:attr:`idtype` and :attr:`device`).
- ``(iterable[int], iterable[int])``: Similar to the tuple of node-tensors
format, but stores node IDs in two sequences (e.g. list, tuple, numpy.ndarray).
- ``('coo', (Tensor, Tensor))``: Same as ``(Tensor, Tensor)``.
- ``('csr', (Tensor, Tensor, Tensor))``: The three tensors form the CSR representation
of the graph's adjacency matrix. The first one is the row index pointer. The
second one is the column indices. The third one is the edge IDs, which can be empty
(i.e. with 0 elements) to represent consecutive integer IDs starting from 0.
- ``('csc', (Tensor, Tensor, Tensor))``: The three tensors form the CSC representation
of the graph's adjacency matrix. The first one is the column index pointer. The
second one is the row indices. The third one is the edge IDs, which can be empty
to represent consecutive integer IDs starting from 0.
The tensors can be replaced with any iterable of integers (e.g. list, tuple,
numpy.ndarray).
ntype : str, optional
Deprecated. To construct a graph with named node types, use :func:`dgl.heterograph`.
etype : str, optional
Expand Down Expand Up @@ -131,6 +141,14 @@ def graph(data,
>>> g = dgl.graph((src_ids, dst_ids), idtype=torch.int32, device='cuda:0')
Creating a graph with CSR representation:
>>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [])))
Creating the same graph with CSR representation and explicit edge IDs:
>>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [0, 1, 2])))
See Also
--------
from_scipy
Expand Down Expand Up @@ -158,16 +176,15 @@ def graph(data,
" Please refer to their API documents for more details.".format(
deprecated_kwargs.keys()))

u, v, urange, vrange = utils.graphdata2tensors(data, idtype)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(data, idtype)
if num_nodes is not None: # override the number of nodes
if num_nodes < max(urange, vrange):
raise DGLError('The num_nodes argument must be larger than the max ID in the data,'
' but got {} and {}.'.format(num_nodes, max(urange, vrange) - 1))
urange, vrange = num_nodes, num_nodes

g = create_from_edges(u, v, '_N', '_E', '_N', urange, vrange,
row_sorted=row_sorted, col_sorted=col_sorted,
validate=False)
g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange,
row_sorted=row_sorted, col_sorted=col_sorted)

return g.to(device)

Expand Down Expand Up @@ -226,8 +243,18 @@ def heterograph(data_dict,
this format "tuple of node-tensors". The tensors should have the same data type,
which must be either int32 or int64. They should also have the same device context
(see below the descriptions of :attr:`idtype` and :attr:`device`).
- ``(iterable[int], iterable[int])``: Similar to the tuple of node-tensors
format, but stores node IDs in two sequences (e.g. list, tuple, numpy.ndarray).
- ``('coo', (Tensor, Tensor))``: Same as ``(Tensor, Tensor)``.
- ``('csr', (Tensor, Tensor, Tensor))``: The three tensors form the CSR representation
of the graph's adjacency matrix. The first one is the row index pointer. The
second one is the column indices. The third one is the edge IDs, which can be empty
(i.e. with 0 elements) to represent consecutive integer IDs starting from 0.
- ``('csc', (Tensor, Tensor, Tensor))``: The three tensors form the CSC representation
of the graph's adjacency matrix. The first one is the column index pointer. The
second one is the row indices. The third one is the edge IDs, which can be empty
(i.e. with 0 elements) to represent consecutive integer IDs starting from 0.
The tensors can be replaced with any iterable of integers (e.g. list, tuple,
numpy.ndarray).
num_nodes_dict : dict[str, int], optional
The number of nodes for some node types, which is a dictionary mapping a node type
:math:`T` to the number of :math:`T`-typed nodes. If not given for a node type
Expand Down Expand Up @@ -320,8 +347,9 @@ def heterograph(data_dict,
raise DGLError("dgl.heterograph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead.")
is_bipartite = (sty != dty)
u, v, urange, vrange = utils.graphdata2tensors(data, idtype, bipartite=is_bipartite)
node_tensor_dict[(sty, ety, dty)] = (u, v)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
data, idtype, bipartite=is_bipartite)
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
if need_infer:
num_nodes_dict[sty] = max(num_nodes_dict[sty], urange)
num_nodes_dict[dty] = max(num_nodes_dict[dty], vrange)
Expand All @@ -340,8 +368,8 @@ def heterograph(data_dict,
num_nodes_per_type = utils.toindex([num_nodes_dict[ntype] for ntype in ntypes], "int64")
rel_graphs = []
for srctype, etype, dsttype in relations:
src, dst = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(src, dst, srctype, etype, dsttype,
sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(sparse_fmt, arrays, srctype, etype, dsttype,
num_nodes_dict[srctype], num_nodes_dict[dsttype])
rel_graphs.append(g)

Expand All @@ -368,8 +396,18 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
this format "tuple of node-tensors". The tensors should have the same data type,
which must be either int32 or int64. They should also have the same device context
(see below the descriptions of :attr:`idtype` and :attr:`device`).
- ``(iterable[int], iterable[int])``: Similar to the tuple of node-tensors
format, but stores node IDs in two sequences (e.g. list, tuple, numpy.ndarray).
- ``('coo', (Tensor, Tensor))``: Same as ``(Tensor, Tensor)``.
- ``('csr', (Tensor, Tensor, Tensor))``: The three tensors form the CSR representation
of the graph's adjacency matrix. The first one is the row index pointer. The
second one is the column indices. The third one is the edge IDs, which can be empty
to represent consecutive integer IDs starting from 0.
- ``('csc', (Tensor, Tensor, Tensor))``: The three tensors form the CSC representation
of the graph's adjacency matrix. The first one is the column index pointer. The
second one is the row indices. The third one is the edge IDs, which can be empty
to represent consecutive integer IDs starting from 0.
The tensors can be replaced with any iterable of integers (e.g. list, tuple,
numpy.ndarray).
If you would like to create a MFG with a single source node type, a single destination
node type, and a single edge type, then you can pass in the graph data directly
Expand Down Expand Up @@ -489,8 +527,9 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
# Convert all data to node tensors first
node_tensor_dict = {}
for (sty, ety, dty), data in data_dict.items():
u, v, urange, vrange = utils.graphdata2tensors(data, idtype, bipartite=True)
node_tensor_dict[(sty, ety, dty)] = (u, v)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
data, idtype, bipartite=True)
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
if need_infer:
num_src_nodes[sty] = max(num_src_nodes[sty], urange)
num_dst_nodes[dty] = max(num_dst_nodes[dty], vrange)
Expand Down Expand Up @@ -525,8 +564,8 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
meta_edges_src.append(srctype_dict[srctype])
meta_edges_dst.append(dsttype_dict[dsttype])
etypes.append(etype)
src, dst = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(src, dst, 'SRC/' + srctype, etype, 'DST/' + dsttype,
sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(sparse_fmt, arrays, 'SRC/' + srctype, etype, 'DST/' + dsttype,
num_src_nodes[srctype], num_dst_nodes[dsttype])
rel_graphs.append(g)

Expand Down Expand Up @@ -1041,8 +1080,8 @@ def from_scipy(sp_mat,
raise DGLError('Expect the number of rows to be the same as the number of columns for '
'sp_mat, got {:d} and {:d}.'.format(num_rows, num_cols))

u, v, urange, vrange = utils.graphdata2tensors(sp_mat, idtype)
g = create_from_edges(u, v, '_N', '_E', '_N', urange, vrange)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(sp_mat, idtype)
g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange)
if eweight_name is not None:
g.edata[eweight_name] = F.tensor(sp_mat.data)
return g.to(device)
Expand Down Expand Up @@ -1135,9 +1174,8 @@ def bipartite_from_scipy(sp_mat,
heterograph
bipartite_from_networkx
"""
# Sanity check
u, v, urange, vrange = utils.graphdata2tensors(sp_mat, idtype, bipartite=True)
g = create_from_edges(u, v, utype, etype, vtype, urange, vrange)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(sp_mat, idtype, bipartite=True)
g = create_from_edges(sparse_fmt, arrays, utype, etype, vtype, urange, vrange)
if eweight_name is not None:
g.edata[eweight_name] = F.tensor(sp_mat.data)
return g.to(device)
Expand Down Expand Up @@ -1255,10 +1293,10 @@ def from_networkx(nx_graph,
if not nx_graph.is_directed():
nx_graph = nx_graph.to_directed()

u, v, urange, vrange = utils.graphdata2tensors(
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
nx_graph, idtype, edge_id_attr_name=edge_id_attr_name)

g = create_from_edges(u, v, '_N', '_E', '_N', urange, vrange)
g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange)

# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
Expand Down Expand Up @@ -1450,12 +1488,12 @@ def bipartite_from_networkx(nx_graph,
bottom_map = {n : i for i, n in enumerate(bottom_nodes)}

# Get the node tensors and the number of nodes
u, v, urange, vrange = utils.graphdata2tensors(
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
nx_graph, idtype, bipartite=True,
edge_id_attr_name=edge_id_attr_name,
top_map=top_map, bottom_map=bottom_map)

g = create_from_edges(u, v, utype, etype, vtype, urange, vrange)
g = create_from_edges(sparse_fmt, arrays, utype, etype, vtype, urange, vrange)

# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
Expand Down Expand Up @@ -1586,10 +1624,9 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
# Internal APIs
############################################################

def create_from_edges(u, v,
def create_from_edges(sparse_fmt, arrays,
utype, etype, vtype,
urange, vrange,
validate=True,
row_sorted=False,
col_sorted=False):
"""Internal function to create a typed graph from sparse adjacency matrix arrays (COO, CSR or CSC).
Expand All @@ -1598,10 +1635,10 @@ def create_from_edges(u, v,
Parameters
----------
u : Tensor
Source node IDs.
v : Tensor
Dest node IDs.
sparse_fmt : str
The sparse adjacency matrix format.
arrays : tuple[Tensor]
The sparse adjacency matrix arrays.
utype : str
Source node type name.
etype : str
Expand All @@ -1614,8 +1651,6 @@ def create_from_edges(u, v,
vrange : int, optional
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
row_sorted : bool, optional
Whether or not the rows of the COO are in ascending order.
col_sorted : bool, optional
Expand All @@ -1627,24 +1662,21 @@ def create_from_edges(u, v,
-------
DGLHeteroGraph
"""
if validate:
if urange is not None and len(u) > 0 and \
urange <= F.as_scalar(F.max(u, dim=0)):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
urange, F.as_scalar(F.max(u, dim=0))))
if vrange is not None and len(v) > 0 and \
vrange <= F.as_scalar(F.max(v, dim=0)):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
vrange, F.as_scalar(F.max(v, dim=0))))

if utype == vtype:
num_ntypes = 1
else:
num_ntypes = 2

hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'],
row_sorted, col_sorted)
if sparse_fmt == 'coo':
u, v = arrays
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'],
row_sorted, col_sorted)
else: # 'csr' or 'csc'
indptr, indices, eids = arrays
hgidx = heterograph_index.create_unitgraph_from_csr(
num_ntypes, urange, vrange, indptr, indices, eids, ['coo', 'csr', 'csc'],
sparse_fmt == 'csc')
if utype == vtype:
return DGLHeteroGraph(hgidx, [utype], [etype])
else:
Expand Down
6 changes: 3 additions & 3 deletions python/dgl/data/fraud.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from .utils import save_graphs, load_graphs, _get_dgl_url
from ..convert import heterograph
from ..utils import graphdata2tensors
from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F

Expand Down Expand Up @@ -106,8 +105,9 @@ def process(self):

graph_data = {}
for relation in self.relations[self.name]:
u, v, _, _ = graphdata2tensors(data[relation])
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (u, v)
adj = data[relation].tocoo()
row, col = adj.row, adj.col
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)
g = heterograph(graph_data)

g.ndata['feature'] = F.tensor(node_features)
Expand Down
Loading

0 comments on commit acd21a6

Please sign in to comment.