[Bug fix] [Feature] added option for batching empty data (dmlc#2527)
* added option for batching empty data, fixes dmlc#2526

* added option for batching empty data, fixes dmlc#2526

* decreased line lengths

* removed trailing whitespace

* fixed wrong feature name

* now default behavior when all graphs are empty

Co-authored-by: Minjie Wang <[email protected]>
noncomputable and jermainewang authored Jan 14, 2021
1 parent 4d89b54 commit 0778766
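
In practice, the change means that batching a set of graphs that are all empty now keeps their (empty) feature keys instead of silently dropping them. A minimal sketch of that behavior, assuming the post-commit dgl.batch and a PyTorch backend; the feature name "nh" and the (0, 4) shape are only illustrative:

    import dgl
    import torch

    # Two graphs with zero nodes and zero edges, each carrying an empty feature.
    g1 = dgl.graph(([], []))
    g1.ndata["nh"] = torch.empty((0, 4))
    g2 = dgl.graph(([], []))
    g2.ndata["nh"] = torch.empty((0, 4))

    bg = dgl.batch([g1, g2])
    assert "nh" in bg.ndata                # kept, because every input graph is empty
    assert bg.ndata["nh"].shape == (0, 4)  # still an empty feature tensor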
Showing 3 changed files with 30 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/dgl/batch.py
@@ -11,7 +11,8 @@

 __all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero']

-def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
+def batch(graphs, ndata=ALL, edata=ALL, *,
+          node_attrs=None, edge_attrs=None):
     r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
     graph computation.
@@ -191,9 +192,10 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
     # Batch node feature
     if ndata is not None:
         for ntype_id, ntype in zip(ntype_ids, ntypes):
+            all_empty = all(g._graph.number_of_nodes(ntype_id) == 0 for g in graphs)
             frames = [
                 g._node_frames[ntype_id] for g in graphs
-                if g._graph.number_of_nodes(ntype_id) > 0]
+                if g._graph.number_of_nodes(ntype_id) > 0 or all_empty]
             # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
             # we allow empty graphs to have no features during batching.
             ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype))
@@ -202,9 +204,10 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
     # Batch edge feature
     if edata is not None:
         for etype_id, etype in zip(relation_ids, relations):
+            all_empty = all(g._graph.number_of_edges(etype_id) == 0 for g in graphs)
             frames = [
                 g._edge_frames[etype_id] for g in graphs
-                if g._graph.number_of_edges(etype_id) > 0]
+                if g._graph.number_of_edges(etype_id) > 0 or all_empty]
             # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
             # we allow empty graphs to have no features during batching.
             ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype))
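
The rule added in both hunks is easier to see in isolation: per-graph feature frames of empty graphs are skipped as before, unless every graph is empty for that node/edge type, in which case all frames are kept so the (empty) feature schema survives batching. A standalone sketch of just that selection logic, not of DGL internals; the helper name select_frames is hypothetical:

    def select_frames(frames, sizes):
        # frames: per-graph feature dicts; sizes: matching node/edge counts.
        all_empty = all(n == 0 for n in sizes)
        return [f for f, n in zip(frames, sizes) if n > 0 or all_empty]

    # All graphs empty: every frame is kept, so its keys can still be batched.
    assert select_frames([{"nh": []}, {"nh": []}], [0, 0]) == [{"nh": []}, {"nh": []}]
    # Mixed case: empty graphs are still skipped, matching the old behavior.
    assert select_frames([{"nh": [1.0]}, {}], [1, 0]) == [{"nh": [1.0]}]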
12 changes: 12 additions & 0 deletions tests/compute/test_batched_graph.py
@@ -207,6 +207,18 @@ def test_batch_no_edge(idtype):
     g3.add_nodes(1) # no edges
     g = dgl.batch([g1, g3, g2]) # should not throw an error

+@parametrize_dtype
+def test_batch_keeps_empty_data(idtype):
+    g1 = dgl.graph(([], [])).astype(idtype).to(F.ctx())
+    g1.ndata["nh"] = F.tensor([])
+    g1.edata["eh"] = F.tensor([])
+    g2 = dgl.graph(([], [])).astype(idtype).to(F.ctx())
+    g2.ndata["nh"] = F.tensor([])
+    g2.edata["eh"] = F.tensor([])
+    g = dgl.batch([g1, g2])
+    assert "nh" in g.ndata
+    assert "eh" in g.edata
+
 def _get_subgraph_batch_info(keys, induced_indices_arr, batch_num_objs):
     """Internal function to compute batch information for subgraphs.
     Parameters
12 changes: 12 additions & 0 deletions tests/compute/test_batched_heterograph.py
@@ -321,6 +321,18 @@ def test_unbatch2(idtype):
     check_graph_equal(g2, gg2)
     check_graph_equal(g3, gg3)

+@parametrize_dtype
+def test_batch_keeps_empty_data(idtype):
+    g1 = dgl.heterograph({("a", "to", "a"): ([], [])}).astype(idtype).to(F.ctx())
+    g1.nodes["a"].data["nh"] = F.tensor([])
+    g1.edges[("a", "to", "a")].data["eh"] = F.tensor([])
+    g2 = dgl.heterograph({("a", "to", "a"): ([], [])}).astype(idtype).to(F.ctx())
+    g2.nodes["a"].data["nh"] = F.tensor([])
+    g2.edges[("a", "to", "a")].data["eh"] = F.tensor([])
+    g = dgl.batch([g1, g2])
+    assert "nh" in g.nodes["a"].data
+    assert "eh" in g.edges[("a", "to", "a")].data
+
 if __name__ == '__main__':
     #test_topology('int32')
     #test_batching_batched('int32')
