Skip to content

Commit

Permalink
fix batched heterograph serializations (dmlc#1794)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Jul 13, 2020
1 parent 200340a commit c13903b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/dgl/batched_heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,14 @@ def to(self, ctx, **kwargs): # pylint: disable=invalid-name
batch_num_nodes=self._batch_num_nodes,
batch_num_edges=self._batch_num_edges)

def __getstate__(self):
state = super().__getstate__()
return state, self._batch_size, self._batch_num_nodes, self._batch_num_edges

def __setstate__(self, state):
state, self._batch_size, self._batch_num_nodes, self._batch_num_edges = state
super().__setstate__(state)

def unbatch_hetero(graph):
"""Return the list of heterographs in this batch.
Expand Down
35 changes: 35 additions & 0 deletions tests/compute/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def _assert_is_identical_batchedgraph(bg1, bg2):
assert bg1.batch_num_nodes == bg2.batch_num_nodes
assert bg1.batch_num_edges == bg2.batch_num_edges

def _assert_is_identical_batchedhetero(bg1, bg2):
_assert_is_identical_hetero(bg1, bg2)
for ntype in bg1.ntypes:
assert bg1.batch_num_nodes(ntype) == bg2.batch_num_nodes(ntype)
for canonical_etype in bg1.canonical_etypes:
assert bg1.batch_num_edges(canonical_etype) == bg2.batch_num_edges(canonical_etype)

def _assert_is_identical_index(i1, i2):
assert i1.slice_data() == i2.slice_data()
assert F.array_equal(i1.tousertensor(), i2.tousertensor())
Expand Down Expand Up @@ -258,6 +265,33 @@ def test_pickling_heterograph():
new_g = _reconstruct_pickle(g)
_assert_is_identical_hetero(g, new_g)

def test_pickling_batched_heterograph():
# copied from test_heterograph.create_test_heterograph()
plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
wishes_nx = nx.DiGraph()
wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
wishes_nx.add_edge('u0', 'g1', id=0)
wishes_nx.add_edge('u2', 'g0', id=1)

follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
plays_g = dgl.bipartite(plays_spmat, 'user', 'plays', 'game')
wishes_g = dgl.bipartite(wishes_nx, 'user', 'wishes', 'game')
develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])
g2 = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])

g.nodes['user'].data['u_h'] = F.randn((3, 4))
g.nodes['game'].data['g_h'] = F.randn((2, 5))
g.edges['plays'].data['p_h'] = F.randn((4, 6))
g2.nodes['user'].data['u_h'] = F.randn((3, 4))
g2.nodes['game'].data['g_h'] = F.randn((2, 5))
g2.edges['plays'].data['p_h'] = F.randn((4, 6))

bg = dgl.batch_hetero([g, g2])
new_bg = _reconstruct_pickle(bg)
_assert_is_identical_batchedhetero(bg, new_bg)

@unittest.skipIf(dgl.backend.backend_name != "pytorch", reason="Only test for pytorch format file")
def test_pickling_heterograph_index_compatibility():
plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
Expand Down Expand Up @@ -287,3 +321,4 @@ def test_pickling_heterograph_index_compatibility():
test_pickling_nodeflow()
test_pickling_batched_graph()
test_pickling_heterograph()
test_pickling_batched_heterograph()

0 comments on commit c13903b

Please sign in to comment.