Skip to content

Commit

Permalink
[BUGFIX] Fix is_multigraph in the construction from scipy coo matrix (d…
Browse files Browse the repository at this point in the history
…mlc#1357)

* fix is_multigraph in from_coo.

* add tests for partition.

* fix.

* Revert "add tests for partition."

This reverts commit cb8c855.

* fix everywhere from_scipy_sparse_matrix is used.
  • Loading branch information
zheng-da authored Mar 22, 2020
1 parent 856de79 commit 10253a5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
8 changes: 6 additions & 2 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,14 +1805,18 @@ def _batcher(lst):
raise DGLError('Not all edges have attribute {}.'.format(attr))
self._edge_frame[attr] = _batcher(attr_dict[attr])

def from_scipy_sparse_matrix(self, spmat):
def from_scipy_sparse_matrix(self, spmat, multigraph=False):
""" Convert from scipy sparse matrix.
Parameters
----------
spmat : scipy sparse matrix
The graph's adjacency matrix
multigraph : bool, optional
Whether the graph would be a multigraph. If the input scipy sparse matrix is CSR,
this argument is ignored.
Examples
--------
>>> from scipy.sparse import coo_matrix
Expand All @@ -1824,7 +1828,7 @@ def from_scipy_sparse_matrix(self, spmat):
>>> g.from_scipy_sparse_matrix(a)
"""
self.clear()
self._graph = graph_index.from_scipy_sparse_matrix(spmat, self.is_readonly)
self._graph = graph_index.from_scipy_sparse_matrix(spmat, multigraph, self.is_readonly)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
Expand Down
10 changes: 7 additions & 3 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,12 +1139,15 @@ def from_networkx(nx_graph, readonly):
dst = utils.toindex(dst)
return from_coo(num_nodes, src, dst, is_multigraph, readonly)

def from_scipy_sparse_matrix(adj, readonly):
def from_scipy_sparse_matrix(adj, multigraph, readonly):
"""Convert from scipy sparse matrix.
Parameters
----------
adj : scipy sparse matrix
multigraph : bool
Whether the graph would be a multigraph. If none, the flag will be determined
by the data.
readonly : bool
True if the returned graph is readonly.
Expand All @@ -1156,8 +1159,9 @@ def from_scipy_sparse_matrix(adj, readonly):
if adj.getformat() != 'csr' or not readonly:
num_nodes = max(adj.shape[0], adj.shape[1])
adj_coo = adj.tocoo()
return from_coo(num_nodes, adj_coo.row, adj_coo.col, False, readonly)
return from_coo(num_nodes, adj_coo.row, adj_coo.col, multigraph, readonly)
else:
# If the input matrix is csr, it's guaranteed to be a simple graph.
return from_csr(adj.indptr, adj.indices, False, "out")

def from_edge_list(elist, is_multigraph, readonly):
Expand Down Expand Up @@ -1298,7 +1302,7 @@ def create_graph_index(graph_data, multigraph, readonly):
return from_edge_list(graph_data, multigraph, readonly)
elif isinstance(graph_data, scipy.sparse.spmatrix):
# scipy format
return from_scipy_sparse_matrix(graph_data, readonly)
return from_scipy_sparse_matrix(graph_data, multigraph, readonly)
else:
# networkx - any format
try:
Expand Down
3 changes: 2 additions & 1 deletion tests/compute/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def create_large_graph_index(num_nodes):
row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))
return from_scipy_sparse_matrix(spm, True)
# It's possible that we generate a multigraph.
return from_scipy_sparse_matrix(spm, True, True)

def get_nodeflow(g, node_ids, num_layers):
batch_size = len(node_ids)
Expand Down
3 changes: 2 additions & 1 deletion tests/graph_index/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def create_large_graph_index(num_nodes):
row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))
return from_scipy_sparse_matrix(spm, True)
# It's possible that we generate a multigraph.
return from_scipy_sparse_matrix(spm, True, True)

def test_node_subgraph_with_halo():
gi = create_large_graph_index(1000)
Expand Down

0 comments on commit 10253a5

Please sign in to comment.