Skip to content

Commit

Permalink
[BUG] Fix remove_edges crashing with empty edge ID tensors (dmlc#1384)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Mar 22, 2020
1 parent a072140 commit d27b485
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
11 changes: 8 additions & 3 deletions python/dgl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,15 +922,20 @@ def remove_edges(g, edge_ids):
"Graph has more than one edge type; specify a dict for edge_id instead.")
edge_ids = {g.canonical_etypes[0]: edge_ids}

edge_ids_nd = [None] * len(g.etypes)
edge_ids_nd = [nd.null()] * len(g.etypes)
for key, value in edge_ids.items():
edge_ids_nd[g.get_etype_id(key)] = F.zerocopy_to_dgl_ndarray(value)
new_graph_index, induced_eids_nd = _CAPI_DGLRemoveEdges(g._graph, edge_ids_nd)

new_graph = DGLHeteroGraph(new_graph_index, g.ntypes, g.etypes)
for i, canonical_etype in enumerate(g.canonical_etypes):
new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray(
induced_eids_nd[i].data)
data = induced_eids_nd[i].data
if len(data) == 0:
# Empty means that no edges are removed and edges are not shuffled.
new_graph.edges[canonical_etype].data[EID] = F.arange(
0, g.number_of_edges(canonical_etype))
else:
new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray(data)

return new_graph

Expand Down
10 changes: 10 additions & 0 deletions tests/compute/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,16 @@ def check(g1, etype, g, edges_removed):
check(g2, 'AB', g, [3])
check(g2, 'BA', g, [1])

g3 = dgl.remove_edges(g, {'AA': F.tensor([]), 'AB': F.tensor([3]), 'BA': F.tensor([1])})
check(g3, 'AA', g, [])
check(g3, 'AB', g, [3])
check(g3, 'BA', g, [1])

g4 = dgl.remove_edges(g, {'AB': F.tensor([3])})
check(g4, 'AA', g, [])
check(g4, 'AB', g, [3])
check(g4, 'BA', g, [])

if __name__ == '__main__':
test_line_graph()
test_no_backtracking()
Expand Down

0 comments on commit d27b485

Please sign in to comment.