Skip to content

Commit

Permalink
[Bug] Fix khop graph (dmlc#1433)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update
  • Loading branch information
mufeili authored Apr 9, 2020
1 parent 88c3448 commit bd1e48a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
14 changes: 13 additions & 1 deletion python/dgl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,18 @@ def khop_graph(g, k):
Examples
--------
Below gives an easy example:
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1], [1, 2])
>>> g_2 = dgl.transform.khop_graph(g, 2)
>>> print(g_2.edges())
(tensor([0]), tensor([2]))
A more complicated example:
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
Expand All @@ -234,7 +246,7 @@ def khop_graph(g, k):
edata_schemes={})
"""
n = g.number_of_nodes()
adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k
adj_k = g.adjacency_matrix_scipy(transpose=True, return_edge_ids=False) ** k
adj_k = adj_k.tocoo()
multiplicity = adj_k.data
row = np.repeat(adj_k.row, multiplicity)
Expand Down
33 changes: 20 additions & 13 deletions tests/compute/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,27 @@ def _test(in_readonly, out_readonly):
def test_khop_graph():
N = 20
feat = F.randn((N, 5))
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for k in range(4):
g_k = dgl.khop_graph(g, k)
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use k-hop graph to do message passing for one time.
g_k.ndata['h'] = feat
g_k.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_1 = g_k.ndata.pop('h')
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)

def _test(g):
for k in range(4):
g_k = dgl.khop_graph(g, k)
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use k-hop graph to do message passing for one time.
g_k.ndata['h'] = feat
g_k.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_1 = g_k.ndata.pop('h')
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)

# Test for random undirected graphs
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
_test(g)
# Test for random directed graphs
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3, directed=True))
_test(g)

def test_khop_adj():
N = 20
Expand Down

0 comments on commit bd1e48a

Please sign in to comment.