Skip to content

Commit

Permalink
[Perf] lazily create msg_index. (dmlc#563)
Browse files Browse the repository at this point in the history
* lazily create msg_index.

* update test.
  • Loading branch information
zheng-da authored May 24, 2019
1 parent de54891 commit 14af840
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 21 deletions.
20 changes: 14 additions & 6 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ def __init__(self,
self._edge_frame = edge_frame
# message indicator:
# if self._msg_index[eid] == 1, then edge eid has message
self._msg_index = utils.zero_index(size=self.number_of_edges())
self._msg_index = None
# message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame
Expand All @@ -921,6 +921,14 @@ def __init__(self,
self._apply_node_func = None
self._apply_edge_func = None

def _get_msg_index(self):
if self._msg_index is None:
self._msg_index = utils.zero_index(size=self.number_of_edges())
return self._msg_index

def _set_msg_index(self, index):
self._msg_index = index

def add_nodes(self, num, data=None):
"""Add multiple new nodes.
Expand Down Expand Up @@ -1026,7 +1034,8 @@ def add_edge(self, u, v, data=None):
else:
self._edge_frame.append(data)
# resize msg_index and msg_frame
self._msg_index = self._msg_index.append_zeros(1)
if self._msg_index is not None:
self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1)

def add_edges(self, u, v, data=None):
Expand Down Expand Up @@ -1086,7 +1095,8 @@ def add_edges(self, u, v, data=None):
else:
self._edge_frame.append(data)
# initialize feature placeholder for messages
self._msg_index = self._msg_index.append_zeros(num)
if self._msg_index is not None:
self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num)

def clear(self):
Expand All @@ -1111,7 +1121,7 @@ def clear(self):
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_index = utils.zero_index(0)
self._msg_index = None
self._msg_frame.clear()

def clear_cache(self):
Expand Down Expand Up @@ -1218,7 +1228,6 @@ def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())

# copy attributes
Expand Down Expand Up @@ -1285,7 +1294,6 @@ def from_scipy_sparse_matrix(self, spmat):
self._graph.from_scipy_sparse_matrix(spmat)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())

def node_attr_schemes(self):
Expand Down
8 changes: 4 additions & 4 deletions python/dgl/runtime/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def schedule_send(graph, u, v, eid, message_func):
msg = _gen_send(graph, var_nf, var_nf, var_ef, var_u, var_v, var_eid, message_func)
ir.WRITE_ROW_(var_mf, var_eid, msg)
# set message indicator to 1
graph._msg_index = graph._msg_index.set_items(eid, 1)
graph._set_msg_index(graph._get_msg_index().set_items(eid, 1))

def schedule_recv(graph,
recv_nodes,
Expand All @@ -80,7 +80,7 @@ def schedule_recv(graph,
"""
src, dst, eid = graph._graph.in_edges(recv_nodes)
if len(eid) > 0:
nonzero_idx = graph._msg_index.get_items(eid).nonzero()
nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx)
src = src.get_items(nonzero_idx)
dst = dst.get_items(nonzero_idx)
Expand All @@ -107,8 +107,8 @@ def schedule_recv(graph,
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
# set message indicator to 0
graph._msg_index = graph._msg_index.set_items(eid, 0)
if not graph._msg_index.has_nonzero():
graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
if not graph._get_msg_index().has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf'))

def schedule_snr(graph,
Expand Down
22 changes: 11 additions & 11 deletions tests/compute/test_multi_send_recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _fmsg(edges):
eid = g.edge_ids([0, 0, 0, 0, 0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 9, 9, 9, 9, 9])
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)

def test_multi_recv():
# basic recv test
Expand All @@ -80,20 +80,20 @@ def test_multi_recv():
g.send((u, v))
eid = g.edge_ids(u, v)
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)

u = [0]
v = [1, 2, 3]
g.send((u, v))
eid = g.edge_ids(u, v)
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)

h1 = g.ndata['h']

Expand All @@ -104,19 +104,19 @@ def test_multi_recv():
g.send((u, v))
eid = g.edge_ids(u, v)
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [4, 5, 6]
v = [9]
g.recv(v)
eid = g.edge_ids(u, v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [0]
v = [1, 2, 3]
g.recv(v)
eid = g.edge_ids(u, v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)

h2 = g.ndata['h']
assert F.allclose(h1, h2)
Expand Down Expand Up @@ -250,7 +250,7 @@ def _apply(nodes):
'h2': F.randn((2, D))})
g.send()
expected = F.ones((g.number_of_edges(),), dtype=F.int64)
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)

# add more edges
g.add_edges([0, 2], [2, 0], {'h1': F.randn((2, D))})
Expand Down Expand Up @@ -281,10 +281,10 @@ def test_recv_no_send():
g.send((1, 2), message_func)
expected = F.zeros((2,), dtype=F.int64)
expected[1] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(2, reduce_func)
expected[1] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)

def test_send_recv_after_conversion():
# test send and recv after converting from a graph with edges
Expand Down

0 comments on commit 14af840

Please sign in to comment.