Skip to content

Commit

Permalink
[BUG] Fixing networkx conversion with 0 edges (dmlc#94)
Browse files Browse the repository at this point in the history
* fixing networkx conversion with 0 edges

* fixes as required

* remove obsolete comment
  • Loading branch information
BarclayII authored Oct 22, 2018
1 parent 00add9f commit 70d4758
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
15 changes: 10 additions & 5 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def to_networkx(self):
"""
src, dst, eid = self.edges()
ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
ret.add_nodes_from(range(self.number_of_nodes()))
for u, v, id in zip(src, dst, eid):
ret.add_edge(u, v, id=id)
return ret
Expand All @@ -548,16 +549,20 @@ def from_networkx(self, nx_graph):

num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes)
has_edge_id = 'id' in next(iter(nx_graph.edges))

if nx_graph.number_of_edges() == 0:
return

# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1]
if has_edge_id:
num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64)
for e, attr in nx_graph.edges.items:
# MultiDiGraph returns a triplet in e while DiGraph returns a pair
for u, v, attr in nx_graph.edges(data=True):
eid = attr['id']
src[eid] = e[0]
dst[eid] = e[1]
src[eid] = u
dst[eid] = v
else:
src = []
dst = []
Expand Down
21 changes: 21 additions & 0 deletions tests/graph_index/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@ def test_nx():
assert 0 in gi.edge_id(0, 1)
assert 1 in gi.edge_id(0, 1)

nxg = nx.DiGraph()
nxg.add_nodes_from(range(3))
gi = create_graph_index(nxg)
assert gi.number_of_nodes() == 3
assert gi.number_of_edges() == 0

gi = create_graph_index()
gi.add_nodes(3)
nxg = gi.to_networkx()
assert len(nxg.nodes) == 3
assert len(nxg.edges) == 0

nxg = nx.DiGraph()
nxg.add_edge(0, 1, id=0)
nxg.add_edge(1, 2, id=1)
gi = create_graph_index(nxg)
assert 0 in gi.edge_id(0, 1)
assert 1 in gi.edge_id(1, 2)
assert gi.number_of_edges() == 2
assert gi.number_of_nodes() == 3

def test_predsucc():
gi = create_graph_index(multigraph=True)

Expand Down

0 comments on commit 70d4758

Please sign in to comment.