[Bug fix] Use shared memory for grad sync when NCCL is not available as the PyTorch distributed backend. (dmlc#3034)

* Use shared memory for grad sync when NCCL is not available as the PyTorch distributed backend.

Fix small bugs and update unit tests

* Fix bug

* update test

* update test

* Fix unit test

* Fix unit test

* Fix test

* Fix

* simple update

Co-authored-by: Ubuntu <[email protected]>
classicsong and Ubuntu authored Jun 24, 2021
1 parent 31f4483 commit 2f7ca41
Showing 2 changed files with 283 additions and 39 deletions.
47 changes: 36 additions & 11 deletions python/dgl/optim/pytorch/sparse_optim.py
@@ -70,7 +70,10 @@ def step(self):
         else:
             assert not self._device, \
                 "All gradients must be on the same device"
-        if self._device:
+
+        # distributed backend use nccl
+        if self._device and \
+            (not th.distributed.is_initialized() or th.distributed.get_backend() == 'nccl'):
             # device is only set if the grads are on a GPU
             self._comm_setup()
         else:
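The new guard only takes the communicator path when the gradients sit on a GPU and either no process group has been initialized or the backend is NCCL; any other backend falls back to shared-memory synchronization. A minimal sketch of the same decision, with a hypothetical helper name:

import torch as th

def grads_need_shared_mem_sync(grads_on_gpu):
    # Hypothetical helper mirroring the new condition in step():
    # use the GPU communicator only when grads are on a GPU and the
    # backend is NCCL (or torch.distributed was never initialized).
    use_nccl = grads_on_gpu and \
        (not th.distributed.is_initialized() or th.distributed.get_backend() == 'nccl')
    return not use_nccl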
@@ -284,9 +287,11 @@ def _shared_step(self):
                     # The overall buffer cost will be smaller than three times
                     # the maximum memory requirement for sharing gradients.
                     buffer_size = 128 if idx_i.shape[0] < 128 else idx_i.shape[0] * 2
-                    idx_shmem = create_shared_mem_array(idx_shmem_name, \
+                    idx_shmem = create_shared_mem_array(
+                        '{}_{}'.format(idx_shmem_name, buffer_size), \
                         (buffer_size,), idx_dtype)
-                    grad_shmem = create_shared_mem_array(grad_shmem_name, \
+                    grad_shmem = create_shared_mem_array(
+                        '{}_{}'.format(grad_shmem_name, buffer_size), \
                         (buffer_size, grad_dim), grad_dtype)
                     self._shared_cache[emb_name][idx_shmem_name] = idx_shmem
                     self._shared_cache[emb_name][grad_shmem_name] = grad_shmem
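Because the shared-memory segment name now carries the buffer size, a step that needs a larger buffer allocates a fresh segment instead of writing past the end of a cached, smaller one. A rough sketch of the sizing and naming rule used above (the helper and example names are illustrative):

def shared_buffer_name(base_name, num_entries):
    # Keep a 128-entry floor and grow by 2x so nearby batch sizes reuse
    # the same segment; the size suffix also keys the cache entry.
    buffer_size = 128 if num_entries < 128 else num_entries * 2
    return '{}_{}'.format(base_name, buffer_size), buffer_size

# e.g. shared_buffer_name('emb_idx_0_1', 300) -> ('emb_idx_0_1_600', 600)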
@@ -321,9 +326,11 @@ def _shared_step(self):
                 if idx_shmem_name not in self._shared_cache[emb_name] or \
                     self._shared_cache[emb_name][idx_shmem_name].shape[0] < size:
                     buffer_size = 128 if size < 128 else size * 2
-                    idx_shmem = get_shared_mem_array(idx_shmem_name, \
+                    idx_shmem = get_shared_mem_array(
+                        '{}_{}'.format(idx_shmem_name, buffer_size), \
                         (buffer_size,), idx_dtype)
-                    grad_shmem = get_shared_mem_array(grad_shmem_name, \
+                    grad_shmem = get_shared_mem_array(
+                        '{}_{}'.format(grad_shmem_name, buffer_size), \
                         (buffer_size, grad_dim), grad_dtype)
                     self._shared_cache[emb_name][idx_shmem_name] = idx_shmem
                     self._shared_cache[emb_name][grad_shmem_name] = grad_shmem
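The consuming trainer attaches to the segment by the same size-suffixed name rather than creating it. A hedged sketch of the create/attach pairing; the helper function is illustrative, the index dtype is assumed, and the dgl.utils import path is an assumption (the patched file reaches these helpers via a relative import):

import torch as th
from dgl.utils import create_shared_mem_array, get_shared_mem_array  # import path assumed

def open_idx_buffer(name, buffer_size, is_owner):
    # Hypothetical helper: the owning trainer creates the named segment,
    # every other trainer attaches with the same name and shape once the
    # owner has published it.
    shape, dtype = (buffer_size,), th.int64  # idx dtype assumed to be int64
    if is_owner:
        return create_shared_mem_array(name, shape, dtype)
    return get_shared_mem_array(name, shape, dtype)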
@@ -424,10 +431,15 @@ def setup(self, params):
             emb_name = emb.name
             if th.device(emb.emb_tensor.device) == th.device('cpu'):
                 # if our embedding is on the CPU, our state also has to be
-                if self._rank <= 0:
+                if self._rank < 0:
+                    state = th.empty(
+                        emb.weight.shape,
+                        dtype=th.float32,
+                        device=th.device('cpu')).zero_()
+                elif self._rank == 0:
                     state = create_shared_mem_array(emb_name+'_state', \
                         emb.weight.shape, th.float32).zero_()
-                if self._rank == 0:
+
                     if self._world_size > 1:
                         emb.store.set(emb_name+'_opt', emb_name)
                 elif self._rank > 0:
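The Adagrad state setup now distinguishes three cases: rank < 0 (no distributed training) keeps a private CPU tensor, rank 0 creates the shared-memory state, and the rank > 0 branch (not shown in full here) attaches to it. A minimal sketch of that pattern; alloc_adagrad_state is a hypothetical helper, the '_state' suffix follows the diff, and the dgl.utils import path is assumed:

import torch as th
from dgl.utils import create_shared_mem_array, get_shared_mem_array  # import path assumed

def alloc_adagrad_state(emb_name, shape, rank):
    if rank < 0:
        # Standalone training: a private zeroed CPU tensor is enough.
        return th.zeros(shape, dtype=th.float32, device='cpu')
    if rank == 0:
        # Trainer 0 owns the shared segment every process will update.
        return create_shared_mem_array(emb_name + '_state', shape, th.float32).zero_()
    # Other trainers attach once rank 0 has published the segment.
    return get_shared_mem_array(emb_name + '_state', shape, th.float32)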
@@ -538,14 +550,27 @@ def setup(self, params):
             emb_name = emb.name
             if th.device(emb.emb_tensor.device) == th.device('cpu'):
                 # if our embedding is on the CPU, our state also has to be
-                if self._rank <= 0:
+                if self._rank < 0:
+                    state_step = th.empty(
+                        (emb.weight.shape[0],),
+                        dtype=th.float32,
+                        device=th.device('cpu')).zero_()
+                    state_mem = th.empty(
+                        emb.weight.shape,
+                        dtype=th.float32,
+                        device=th.device('cpu')).zero_()
+                    state_power = th.empty(
+                        emb.weight.shape,
+                        dtype=th.float32,
+                        device=th.device('cpu')).zero_()
+                elif self._rank == 0:
                     state_step = create_shared_mem_array(emb_name+'_step', \
                         (emb.weight.shape[0],), th.float32).zero_()
                     state_mem = create_shared_mem_array(emb_name+'_mem', \
                         emb.weight.shape, th.float32).zero_()
                     state_power = create_shared_mem_array(emb_name+'_power', \
                         emb.weight.shape, th.float32).zero_()
-                if self._rank == 0:
+
                     if self._world_size > 1:
                         emb.store.set(emb_name+'_opt', emb_name)
                 elif self._rank > 0:
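SparseAdam keeps three per-embedding buffers: a per-row step count ('_step'), the first moment ('_mem'), and the second moment ('_power'). A sketch of a standard Adam-style row update over those three buffers, assuming unique indices; this is the textbook formula as an illustration, not the exact DGL kernel:

import torch as th

def adam_row_update(weight, grad, idx, state_step, state_mem, state_power,
                    lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    # idx must be unique here; duplicates are deduplicated/averaged upstream.
    state_step[idx] += 1
    state_mem[idx] = beta1 * state_mem[idx] + (1 - beta1) * grad
    state_power[idx] = beta2 * state_power[idx] + (1 - beta2) * grad * grad
    t = state_step[idx].unsqueeze(-1)
    m_hat = state_mem[idx] / (1 - beta1 ** t)        # bias-corrected 1st moment
    v_hat = state_power[idx] / (1 - beta2 ** t)      # bias-corrected 2nd moment
    weight[idx] -= lr * m_hat / (v_hat.sqrt() + eps)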
@@ -601,8 +626,8 @@ def update(self, idx, grad, emb):
         # only perform async copies cpu -> gpu, or gpu-> gpu, but block
         # when copying to the cpu, so as to ensure the copy is finished
         # before operating on the data on the cpu
-        state_nonblock = state_dev != th.device('cpu')
-        exec_nonblock = exec_dev != th.device('cpu')
+        state_nonblock = False # state_dev != th.device('cpu')
+        exec_nonblock = False # exec_dev != th.device('cpu')

         # There can be duplicated indices due to sampling.
         # Thus unique them here and average the gradient here.
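Pinning these flags to False makes copies back to the CPU synchronous: with non_blocking=True a device-to-host copy can return before the data has landed, so the CPU-side optimizer state update could read a half-transferred buffer. A short, standalone illustration of the two directions (assumes a CUDA device is present):

import torch as th

if th.cuda.is_available():
    grad_cpu = th.randn(4, 16)
    # cpu -> gpu: the copy may stay asynchronous; later kernels on the
    # same stream still see the complete data.
    grad_gpu = grad_cpu.to('cuda', non_blocking=True)
    # gpu -> cpu: block so the host never reads a partially copied buffer.
    grad_back = grad_gpu.to('cpu', non_blocking=False)
    assert th.allclose(grad_cpu, grad_back)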
