
Commit

we need to sync for mxnet. (dmlc#648)
zheng-da authored Jun 12, 2019
1 parent d706298 commit 16ec2a8
Showing 4 changed files with 28 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/dgl/backend/backend.py
@@ -937,3 +937,12 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
# ----------------
# These are not related to tensors. Some of them are temporary workarounds that
# should be included in DGL in the future.

def sync():
    """Synchronize computation.

    In DL frameworks such as MXNet and TensorFlow, computation in operators
    is executed asynchronously. This function synchronizes the computation and
    makes sure that all of it has completed after this call returns.
    """
    pass
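
The function above is only a placeholder in the backend-agnostic API; each backend provides its own implementation below, and callers reach it through the loaded backend module. The following is a minimal sketch, not part of this commit, of how such a hook is typically used, assuming the active backend is imported as F (as in python/dgl/contrib/graph_store.py further down) and that the timed helper is hypothetical:

import time

import dgl.backend as F  # assumed to expose sync() through the backend loader

def timed(fn, *args, **kwargs):
    # Hypothetical helper: time fn, forcing any asynchronously scheduled
    # work to finish before the clock is stopped.
    start = time.time()
    out = fn(*args, **kwargs)
    F.sync()  # no-op on synchronous backends, blocks until completion otherwise
    return out, time.time() - start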
9 changes: 9 additions & 0 deletions python/dgl/backend/mxnet/tensor.py
@@ -393,3 +393,12 @@ def _reduce_grad(grad, shape):
    reduce_idx += 1  # skip batch dim
    grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)
    return grad.reshape(shape)

def sync():
    """Synchronize computation.

    In DL frameworks such as MXNet and TensorFlow, computation in operators
    is executed asynchronously. This function synchronizes the computation and
    makes sure that all of it has completed after this call returns.
    """
    mx.nd.waitall()
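
For context (not part of the patch), mx.nd.waitall() is needed because MXNet only enqueues operators and returns control immediately; without waiting, a measurement captures enqueue time rather than compute time. A small illustration:

import time

import mxnet as mx

x = mx.nd.random.uniform(shape=(2000, 2000))
start = time.time()
y = mx.nd.dot(x, x)   # returns almost immediately; the kernel runs asynchronously
enqueued = time.time() - start
mx.nd.waitall()       # what sync() calls: block until every pending operator finishes
finished = time.time() - start
print('enqueue: %.4f s, total: %.4f s' % (enqueued, finished))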
4 changes: 4 additions & 0 deletions python/dgl/backend/pytorch/tensor.py
@@ -308,3 +308,7 @@ def _reduce_grad(grad, shape):
    reduce_idx += 1  # skip batch dim
    grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)
    return grad.view(shape)

def sync():
    # PyTorch performs computation synchronously, so no synchronization is needed.
    pass
6 changes: 6 additions & 0 deletions python/dgl/contrib/graph_store.py
@@ -363,6 +363,7 @@ def init_ndata(init, ndata_name, shape, dtype):
init = self._init_manager.deserialize(init)
data = init(shape, dtype, _get_ndata_path(graph_name, ndata_name))
self._graph.ndata[ndata_name] = data
F.sync()
return 0

# RPC command: initialize edge embedding in the server.
@@ -375,6 +376,7 @@ def init_edata(init, edata_name, shape, dtype):
assert self._graph.number_of_edges() == shape[0]
init = self._init_manager.deserialize(init)
data = init(shape, dtype, _get_edata_path(graph_name, edata_name))
F.sync()
self._graph.edata[edata_name] = data
return 0

@@ -636,6 +638,10 @@ def _sync_barrier(self, timeout=None):
timeout: int
time out in seconds.
"""
# Before entering the barrier, we need to make sure all computation in the local
# process has completed.
F.sync()

# Here I manually implement a multi-processing barrier with RPC.
# It busy-waits on RPC calls. Whenever all_enter is called, there is
# a context switch, so it doesn't burn CPU too badly.
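
To make the control flow described in these comments concrete, here is a minimal sketch of a busy-wait RPC barrier. The names barrier_enter and all_workers_entered are hypothetical stand-ins for the store's actual RPC calls, and the poll interval is arbitrary; this is an illustration, not the repository's implementation:

import time

import dgl.backend as F  # assumed to expose the sync() added in this commit

def sync_barrier_sketch(rpc_client, timeout=None):
    # Illustrative only: flush local work, announce arrival, then poll the
    # server until every worker has entered the barrier.
    F.sync()                    # finish all local asynchronous computation first
    rpc_client.barrier_enter()  # hypothetical RPC: register this worker at the barrier
    start = time.time()
    while not rpc_client.all_workers_entered():  # hypothetical RPC; each call yields to the server
        if timeout is not None and time.time() - start > timeout:
            raise TimeoutError('barrier timed out after %s seconds' % timeout)
        time.sleep(0.1)         # avoid burning the CPU while waiting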
