Skip to content

Commit

Permalink
[Bugfix] Fix number of nodes mismatch in to_homo() conversion (dmlc#874)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Sep 22, 2019
1 parent bf8bb58 commit f9c0217
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,11 @@ def to_homo(G):
eids = []
ntype_ids = []
nids = []
total_num_nodes = 0

for ntype_id, ntype in enumerate(G.ntypes):
num_nodes = G.number_of_nodes(ntype)
total_num_nodes += num_nodes
ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, F.cpu()))
nids.append(F.arange(0, num_nodes))

Expand All @@ -445,7 +447,7 @@ def to_homo(G):
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
eids.append(F.arange(0, num_edges))

retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)))
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), card=total_num_nodes)
retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.ndata[NID] = F.cat(nids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0)
Expand Down
5 changes: 5 additions & 0 deletions tests/compute/test_heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,11 @@ def test_convert():
assert hg.number_of_edges(('user', 'watches', 'movie')) == 1
assert len(hg.etypes) == 2

# hetero_to_homo test case 2
hg = dgl.bipartite([(0, 0), (1, 1)], card=(2, 3))
g = dgl.to_homo(hg)
assert g.number_of_nodes() == 5

def test_subgraph():
g = create_test_heterograph()
x = F.randn((3, 5))
Expand Down

0 comments on commit f9c0217

Please sign in to comment.