Skip to content

Commit

Permalink
[Fix] Fix to_homo for graph with zero nodes ntype (dmlc#3011)
Browse files Browse the repository at this point in the history
* fix dmlc#2870

* lint

* fix
  • Loading branch information
VoVAllen authored Jun 14, 2021
1 parent 8b64ae5 commit 17141dd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
15 changes: 8 additions & 7 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5985,13 +5985,14 @@ def combine_frames(frames, ids, col_names=None):
schemes = {key: frames[ids[0]].schemes[key] for key in col_names}
for frame_id in ids:
frame = frames[frame_id]
for key, scheme in list(schemes.items()):
if key in frame.schemes:
if frame.schemes[key] != scheme:
raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
(key, frame.schemes[key], scheme))
else:
del schemes[key]
if frame.num_rows != 0:
for key, scheme in list(schemes.items()):
if key in frame.schemes:
if frame.schemes[key] != scheme:
raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
(key, frame.schemes[key], scheme))
else:
del schemes[key]

if len(schemes) == 0:
return None
Expand Down
13 changes: 13 additions & 0 deletions tests/compute/test_heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,19 @@ def test_convert(idtype):
assert hg.device == g.device
assert g.number_of_nodes() == 5

@unittest.skipIf(F._default_context_str == 'gpu', reason="Test on cpu is enough")
@parametrize_dtype
def test_to_homo_zero_nodes(idtype):
# Fix gihub issue #2870
g = dgl.heterograph({
('A', 'AB', 'B'): (np.random.randint(0, 200, (1000,)), np.random.randint(0, 200, (1000,))),
('B', 'BA', 'A'): (np.random.randint(0, 200, (1000,)), np.random.randint(0, 200, (1000,))),
}, num_nodes_dict={'A': 200, 'B': 200, 'C': 0}, idtype=idtype)
g.nodes['A'].data['x'] = F.randn((200, 3))
g.nodes['B'].data['x'] = F.randn((200, 3))
gg = dgl.to_homogeneous(g, ['x'])
assert 'x' in gg.ndata

@parametrize_dtype
def test_to_homo2(idtype):
# test the result homogeneous graph has nodes and edges sorted by their types
Expand Down

0 comments on commit 17141dd

Please sign in to comment.