[Bug Fix] Fix several sparse optimizer bugs (dmlc#2596)
* Fix pytorch TCP kvstore bug

* lint

* Fix

* upd

* Fix lint

* Fix

* trigger

* fix

Co-authored-by: Ubuntu <[email protected]>
classicsong and Ubuntu authored Feb 5, 2021
1 parent 2f71bc5 commit 23afe91
Showing 4 changed files with 160 additions and 115 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch/rgcn/README.md
@@ -68,7 +68,7 @@ python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='30,30' --batch-siz

OGBN-MAG without node-feats 42.79
```
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='30,30' --batch-size 1024 --n-hidden 128 --lr 0.01 --num-worker 4 --eval-batch-size 8 --low-mem --gpu 0,1,2,3 --dropout 0.7 --use-self-loop --n-bases 2 --n-epochs 3 --dgl-sparse --sparse-lr 0.0
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='30,30' --batch-size 1024 --n-hidden 128 --lr 0.01 --num-worker 4 --eval-batch-size 8 --low-mem --gpu 0,1,2,3 --dropout 0.7 --use-self-loop --n-bases 2 --n-epochs 3 --dgl-sparse --sparse-lr 0.08
```

Test-bd: P2-8xlarge
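The only change here is the sparse embedding learning rate, raised from 0.0 (which would leave the learnable node embeddings frozen) to 0.08. As a hedged illustration of where `--sparse-lr` typically ends up, the sketch below assumes the example wires it into a DGL sparse optimizer alongside a regular dense optimizer; the argument names and the `dgl.optim.SparseAdam` choice are assumptions, not taken from this diff.

```python
# Hypothetical wiring of --sparse-lr; the real entity_classify_mp.py may differ.
import torch as th
import dgl

def build_optimizers(model, node_embed, args):
    # Dense model parameters use the regular learning rate (--lr).
    dense_opt = th.optim.Adam(model.parameters(), lr=args.lr)
    # dgl.nn.NodeEmbedding parameters are updated by a DGL sparse optimizer
    # with its own learning rate (--sparse-lr); 0.0 would freeze the embeddings.
    sparse_opt = dgl.optim.SparseAdam([node_embed], lr=args.sparse_lr)
    return dense_opt, sparse_opt
```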
3 changes: 2 additions & 1 deletion examples/pytorch/rgcn/entity_classify_mp.py
@@ -324,6 +324,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
validation_time = 0
test_time = 0
last_val_acc = 0.0
do_test = False
if n_gpus > 1 and n_cpus - args.num_workers > 0:
th.set_num_threads(n_cpus-args.num_workers)
for epoch in range(args.n_epochs):
@@ -405,7 +406,7 @@ def collect_eval():
vend = time.time()
validation_time += (vend - vstart)

if (epoch + 1) > (args.n_epochs / 2) and do_test:
if epoch > 0 and do_test:
tstart = time.time()
if (queue is not None) or (proc_id == 0):
test_logits, test_seeds = evaluate(model, embed_layer, test_loader, node_feats)
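This change initializes `do_test` before the epoch loop and runs the test pass from the second epoch onward rather than only in the second half of training. The sketch below shows the gating pattern in isolation; it assumes (based on the nearby `last_val_acc` variable) that `do_test` is raised when validation accuracy improves, and the training/eval stubs are placeholders.

```python
# Sketch of the evaluation gating pattern; variable names follow the diff,
# everything else is a placeholder.
import random

def validate(epoch):            # placeholder validation pass
    return 0.5 + 0.01 * epoch + 0.01 * random.random()

def test():                     # placeholder test pass
    return 0.42

n_epochs = 5
last_val_acc = 0.0
do_test = False                 # the added initialization

for epoch in range(n_epochs):
    val_acc = validate(epoch)
    if val_acc > last_val_acc:  # assumption: re-test only when validation improves
        last_val_acc = val_acc
        do_test = True
    if epoch > 0 and do_test:   # fixed gate: was (epoch + 1) > (n_epochs / 2)
        print('epoch', epoch, 'test acc', test())
        do_test = False
```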
19 changes: 9 additions & 10 deletions python/dgl/nn/pytorch/sparse_emb.py
@@ -75,16 +75,15 @@ def __init__(self, num_embeddings, embedding_dim, name,
emb = create_shared_mem_array(name, (num_embeddings, embedding_dim), th.float32)
if init_func is not None:
emb = init_func(emb)
if rank == 0:
if world_size > 1:
# for multi-GPU training, set up a TCPStore for
# embedding status synchronization across GPU processes
if _STORE is None:
_STORE = th.distributed.TCPStore(
host_name, port, world_size, True, timedelta(seconds=30))
for _ in range(1, world_size):
# send embs
_STORE.set(name, name)
if rank == 0: # the master gpu process
# for multi-GPU training, set up a TCPStore for
# embedding status synchronization across GPU processes
if _STORE is None:
_STORE = th.distributed.TCPStore(
host_name, port, world_size, True, timedelta(seconds=30))
for _ in range(1, world_size):
# send embs
_STORE.set(name, name)
elif rank > 0:
# receive
if _STORE is None:
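The change above drops the inner `world_size > 1` guard so that rank 0 always creates the TCPStore and publishes one key per embedding when multiple GPU processes are used. A minimal standalone sketch of this rendezvous pattern follows; the host, port, and the worker-side client construction are assumptions for illustration.

```python
# Sketch of the TCPStore handshake between the master process (rank 0) and
# worker processes; host and port are placeholders.
from datetime import timedelta
import torch as th

def setup_store(rank, world_size, name, host_name='127.0.0.1', port=12345):
    if world_size <= 1:
        return None  # single-process training needs no synchronization
    if rank == 0:
        # rank 0 hosts the TCPStore server and announces the shared embedding
        store = th.distributed.TCPStore(
            host_name, port, world_size, True, timedelta(seconds=30))
        store.set(name, name)       # signal that the shared-memory array exists
    else:
        # workers connect as clients and block until the key appears
        store = th.distributed.TCPStore(
            host_name, port, world_size, False, timedelta(seconds=30))
        store.wait([name])
    return store
```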
251 changes: 148 additions & 103 deletions python/dgl/optim/pytorch/sparse_optim.py
@@ -1,6 +1,7 @@
"""Node embedding optimizers"""
import abc
from abc import abstractmethod
import gc
import torch as th

from ...utils import get_shared_mem_array, create_shared_mem_array
@@ -25,6 +26,34 @@ def __init__(self, params, lr):
self._world_size = None
self._shared_cache = {}
self._clean_grad = False
self._opt_meta = {}

for emb in params:
assert isinstance(emb, NodeEmbedding), \
'DGL SparseOptimizer only supports dgl.nn.NodeEmbedding'

if self._rank is None:
self._rank = emb.rank
self._world_size = emb.world_size
else:
assert self._rank == emb.rank, \
'MultiGPU rank for each embedding should be same.'
assert self._world_size == emb.world_size, \
'MultiGPU world_size for each embedding should be same.'

emb_name = emb.name
if self._rank == 0: # the master gpu process
opt_meta = create_shared_mem_array(emb_name+'_opt_meta', \
(self._world_size, self._world_size), th.int32).zero_()
if self._rank == 0:
emb.store.set(emb_name+'_opt_meta', emb_name)
self._opt_meta[emb_name] = opt_meta
elif self._rank > 0:
# receive
emb.store.wait([emb_name+'_opt_meta'])
opt_meta = get_shared_mem_array(emb_name+'_opt_meta', \
(self._world_size, self._world_size), th.int32)
self._opt_meta[emb_name] = opt_meta

def step(self):
''' The step function.
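The `_opt_meta` table introduced above is a `(world_size, world_size)` int32 matrix in shared memory: in `step()`, entry `[i][j]` is set to the number of gradient rows process `i` has staged for process `j`, so sizes no longer need to be passed as strings through the TCPStore. A minimal sketch of creating and attaching such a table, reusing the shared-memory helpers imported in this file; the `attach_opt_meta` wrapper itself is hypothetical, and `store` stands in for the embedding's TCPStore.

```python
# Sketch of a shared (world_size x world_size) int32 metadata table.
import torch as th
from dgl.utils import create_shared_mem_array, get_shared_mem_array

def attach_opt_meta(store, emb_name, rank, world_size):
    key = emb_name + '_opt_meta'
    if rank == 0:
        # the master allocates the table and announces it through the store
        opt_meta = create_shared_mem_array(
            key, (world_size, world_size), th.int32).zero_()
        store.set(key, emb_name)
    else:
        # workers wait for the announcement, then map the same array
        store.wait([key])
        opt_meta = get_shared_mem_array(key, (world_size, world_size), th.int32)
    return opt_meta

# In step(), rank r writes opt_meta[r][j] = rows staged for rank j; after a
# barrier, rank j reads opt_meta[r][j] to know how many rows to copy back.
```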
@@ -36,92 +65,127 @@ def step(self):
# We cache shared memory buffers in shared_emb.
shared_emb = {emb.name: ([], []) for emb in self._params}

# hold on to the released shared memory so other processes can munmap it first;
# otherwise the training will crash
shmem_ptr_holder = []

# Go through all sparse embeddings
for emb in self._params: # pylint: disable=too-many-nested-blocks
num_embeddings = emb.num_embeddings
emb_name = emb.name

# Each GPU process takes responsibility for updating a range of the sparse embedding,
# so the gradient update can be parallelized.
range_size = (num_embeddings + self._world_size - 1) // self._world_size \
if self._world_size > 0 else 0
for idx, data in emb._trace:
grad = data.grad.data
device = grad.device
idx_dtype = idx.dtype
grad_dtype = grad.dtype
grad_dim = grad.shape[1]

if self._world_size > 0:
if emb_name not in self._shared_cache:
self._shared_cache[emb_name] = {}

for i in range(self._world_size):
start = i * range_size
end = (i + 1) * range_size \
if (i + 1) * range_size < num_embeddings \
else num_embeddings
if i == 0:
mask = idx < end
elif i + 1 == self._world_size:
mask = idx >= start
else:
mask = th.logical_and((idx >= start), (idx < end))
idx_i = idx[mask]
grad_i = grad[mask]

if i == self._rank:
shared_emb[emb_name][0].append(idx_i)
shared_emb[emb_name][1].append(grad_i)
else:
# currently nccl does not support Alltoallv operation
# we need to use CPU shared memory to share gradient
# across processes
idx_i = idx_i.to(th.device('cpu'))
grad_i = grad_i.to(th.device('cpu'))
idx_shmem_name = 'idx_{}_{}_{}'.format(emb_name, self._rank, i)
grad_shmem_name = 'grad_{}_{}_{}'.format(emb_name, self._rank, i)

if idx_shmem_name not in self._shared_cache[emb_name] or \
self._shared_cache[emb_name][idx_shmem_name].shape[0] \
< idx_i.shape[0]:
# in case idx_i.shape[0] is 0
idx_shmem = create_shared_mem_array(idx_shmem_name, \
(idx_i.shape[0] * 2 + 2,), idx_dtype)
grad_shmem = create_shared_mem_array(grad_shmem_name, \
(idx_i.shape[0] * 2 + 2, grad_dim), grad_dtype)
self._shared_cache[emb_name][idx_shmem_name] = idx_shmem
self._shared_cache[emb_name][grad_shmem_name] = grad_shmem

self._shared_cache[emb_name][idx_shmem_name][:idx_i.shape[0]] \
= idx_i
self._shared_cache[emb_name][grad_shmem_name][:idx_i.shape[0]] \
= grad_i
emb.store.set(idx_shmem_name, str(idx_i.shape[0]))

# gather gradients from all other processes
for i in range(self._world_size):
if i != self._rank:
idx_shmem_name = 'idx_{}_{}_{}'.format(emb_name, i, self._rank)
grad_shmem_name = 'grad_{}_{}_{}'.format(emb_name, i, self._rank)
size = int(emb.store.get(idx_shmem_name))
if idx_shmem_name not in self._shared_cache[emb_name] or \
self._shared_cache[emb_name][idx_shmem_name].shape[0] < size:
idx_shmem = get_shared_mem_array(idx_shmem_name, \
(size * 2 + 2,), idx_dtype)
grad_shmem = get_shared_mem_array(grad_shmem_name, \
(size * 2 + 2, grad_dim), grad_dtype)
self._shared_cache[emb_name][idx_shmem_name] = idx_shmem
self._shared_cache[emb_name][grad_shmem_name] = grad_shmem
idx_i = self._shared_cache[emb_name][idx_shmem_name][:size]
grad_i = self._shared_cache[emb_name][grad_shmem_name][:size]
shared_emb[emb_name][0].append(idx_i.to(device,
non_blocking=True))
shared_emb[emb_name][1].append(grad_i.to(device,
non_blocking=True))
else:
shared_emb[emb_name][0].append(idx)
shared_emb[emb_name][1].append(grad)
# we need to combine gradients from multiple forward paths
idx = []
grad = []
for i, data in emb._trace:
idx.append(i)
grad.append(data.grad.data)
idx = th.cat(idx, dim=0)
grad = th.cat(grad, dim=0)

device = grad.device
idx_dtype = idx.dtype
grad_dtype = grad.dtype
grad_dim = grad.shape[1]
if self._world_size > 1:
if emb_name not in self._shared_cache:
self._shared_cache[emb_name] = {}

# Each training process takes responsibility for updating a range
# of node embeddings, so the gradient update can be parallelized.
# The overall progress includes:
# 1. In each training process:
# 1.a Deciding which process a node embedding belongs to according
# to the formula: process_id = node_idx mod num_of_process(N)
# 1.b Split the node index tensor and gradient tensor into N parts
# according to step 1.
# 1.c Write each node index sub-tensor and gradient sub-tensor into
# different DGL shared memory buffers.
# 2. Cross training process synchronization
# 3. In each training process:
# 3.a Collect node index sub-tensors and gradient sub-tensors
# 3.b Do gradient update
# 4. Done
idx_split = th.remainder(idx, self._world_size).long()
for i in range(self._world_size):
mask = idx_split == i
idx_i = idx[mask]
grad_i = grad[mask]

if i == self._rank:
shared_emb[emb_name][0].append(idx_i)
shared_emb[emb_name][1].append(grad_i)
else:
# currently nccl does not support Alltoallv operation
# we need to use CPU shared memory to share gradient
# across processes
idx_i = idx_i.to(th.device('cpu'))
grad_i = grad_i.to(th.device('cpu'))
idx_shmem_name = 'idx_{}_{}_{}'.format(emb_name, self._rank, i)
grad_shmem_name = 'grad_{}_{}_{}'.format(emb_name, self._rank, i)

# Create shared memory to hold temporary index and gradient tensor for
# cross-process send and recv.
if idx_shmem_name not in self._shared_cache[emb_name] or \
self._shared_cache[emb_name][idx_shmem_name].shape[0] \
< idx_i.shape[0]:

if idx_shmem_name in self._shared_cache[emb_name]:
shmem_ptr_holder.append(
self._shared_cache[emb_name][idx_shmem_name])
shmem_ptr_holder.append(
self._shared_cache[emb_name][grad_shmem_name])

# in case idx_i.shape[0] is 0
idx_shmem = create_shared_mem_array(idx_shmem_name, \
(idx_i.shape[0] * 2 + 2,), idx_dtype)
grad_shmem = create_shared_mem_array(grad_shmem_name, \
(idx_i.shape[0] * 2 + 2, grad_dim), grad_dtype)
self._shared_cache[emb_name][idx_shmem_name] = idx_shmem
self._shared_cache[emb_name][grad_shmem_name] = grad_shmem

# Fill shared memory with the temporary index tensor and gradient tensor
self._shared_cache[emb_name][idx_shmem_name][:idx_i.shape[0]] \
= idx_i
self._shared_cache[emb_name][grad_shmem_name][:idx_i.shape[0]] \
= grad_i
self._opt_meta[emb_name][self._rank][i] = idx_i.shape[0]
else:
shared_emb[emb_name][0].append(idx)
shared_emb[emb_name][1].append(grad)

# make sure the idx shape is passed to each process through opt_meta
if self._world_size > 1:
th.distributed.barrier()
for emb in self._params: # pylint: disable=too-many-nested-blocks
emb_name = emb.name
if self._world_size > 1:
# gather gradients from all other processes
for i in range(self._world_size):
if i != self._rank:
idx_shmem_name = 'idx_{}_{}_{}'.format(emb_name, i, self._rank)
grad_shmem_name = 'grad_{}_{}_{}'.format(emb_name, i, self._rank)
size = self._opt_meta[emb_name][i][self._rank]

# Retrieve the shared memory holding the temporary index and gradient
# tensors sent to the current training process
if idx_shmem_name not in self._shared_cache[emb_name] or \
self._shared_cache[emb_name][idx_shmem_name].shape[0] < size:
idx_shmem = get_shared_mem_array(idx_shmem_name, \
(size * 2 + 2,), idx_dtype)
grad_shmem = get_shared_mem_array(grad_shmem_name, \
(size * 2 + 2, grad_dim), grad_dtype)
self._shared_cache[emb_name][idx_shmem_name] = idx_shmem
self._shared_cache[emb_name][grad_shmem_name] = grad_shmem
# make sure the shared memory is released in the child process first.
# This will not be called frequently.
# TODO(xiangsx): provide an API to munmap shared memory directly
gc.collect()
idx_i = self._shared_cache[emb_name][idx_shmem_name][:size]
grad_i = self._shared_cache[emb_name][grad_shmem_name][:size]
shared_emb[emb_name][0].append(idx_i.to(device,
non_blocking=True))
shared_emb[emb_name][1].append(grad_i.to(device,
non_blocking=True))

if self._clean_grad:
# clean gradient track
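To summarize the rewritten `step()`: the old contiguous-range partition is replaced by the ownership rule `owner = node_idx mod world_size`, per-owner index/gradient slices are exchanged through cached CPU shared-memory buffers, and slice sizes travel through `_opt_meta` plus a barrier instead of TCPStore string keys. The self-contained sketch below shows only the ownership split, single process and no shared memory, to make the rule concrete.

```python
# Sketch of the remainder-based ownership split used by the new step():
# node i is updated by process (i % world_size).
import torch as th

def split_by_owner(idx, grad, world_size):
    owner = th.remainder(idx, world_size).long()
    parts = []
    for i in range(world_size):
        mask = owner == i
        parts.append((idx[mask], grad[mask]))   # slice destined for rank i
    return parts

idx = th.tensor([0, 3, 5, 6, 9, 10])
grad = th.randn(6, 4)
for rank, (idx_i, grad_i) in enumerate(split_by_owner(idx, grad, world_size=4)):
    print('rank', rank, 'updates rows', idx_i.tolist())
```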
@@ -205,21 +269,12 @@ def __init__(self, params, lr, eps=1e-10):
assert isinstance(emb, NodeEmbedding), \
'SparseAdagrad only supports dgl.nn.NodeEmbedding'

if self._rank is None:
self._rank = emb.rank
self._world_size = emb.world_size
else:
assert self._rank == emb.rank, \
'MultiGPU rank for each embedding should be same.'
assert self._world_size == emb.world_size, \
'MultiGPU world_size for each embedding should be same.'
if self._rank <= 0:
emb_name = emb.name
state = create_shared_mem_array(emb_name+'_state', \
emb.emb_tensor.shape, th.float32).zero_()
if self._rank == 0:
for _ in range(1, world_size):
# send embs
if self._world_size > 1:
emb.store.set(emb_name+'_opt', emb_name)
elif self._rank > 0:
# receive
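For reference, the `emb_name + '_state'` array allocated above has the same shape as the embedding and accumulates squared gradients per row. A minimal sketch of the row-sparse update such a state supports, using the textbook Adagrad rule rather than DGL's exact implementation, with unique indices assumed:

```python
# Textbook row-sparse Adagrad update; `emb` and `state` stand in for the
# shared-memory tensors created in __init__.
import torch as th

def sparse_adagrad_update(emb, state, idx, grad, lr=0.01, eps=1e-10):
    # idx is assumed to hold unique node ids (duplicates would need index_add_)
    state[idx] += grad * grad           # accumulate squared gradients per row
    std = state[idx].sqrt() + eps
    emb[idx] -= lr * grad / std

emb = th.zeros(10, 4)
state = th.zeros(10, 4)
idx = th.tensor([1, 3, 7])
grad = th.randn(3, 4)
sparse_adagrad_update(emb, state, idx, grad)
```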
@@ -318,14 +373,6 @@ def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):
assert isinstance(emb, NodeEmbedding), \
'SparseAdam only supports dgl.nn.NodeEmbedding'

if self._rank is None:
self._rank = emb.rank
self._world_size = emb.world_size
else:
assert self._rank == emb.rank, \
'MultiGPU rank for each embedding should be same.'
assert self._world_size == emb.world_size, \
'MultiGPU world_size for each embedding should be same.'
if self._rank <= 0:
emb_name = emb.name
state_step = create_shared_mem_array(emb_name+'_step', \
@@ -335,10 +382,8 @@ def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):
state_power = create_shared_mem_array(emb_name+'_power', \
emb.emb_tensor.shape, th.float32).zero_()
if self._rank == 0:
state = (state_step, state_mem, state_power)
emb_name = emb.name
for _ in range(1, self._world_size):
# send embs
if self._world_size > 1:
emb.store.set(emb_name+'_opt', emb_name)
elif self._rank > 0:
# receive
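Similarly, SparseAdam keeps three shared arrays per embedding, `_step`, `_mem`, and `_power`: a step counter plus the first- and second-moment estimates. A minimal sketch of the row-sparse update those buffers support, following the standard Adam formulas rather than DGL's exact code, with unique indices and a per-row step counter assumed:

```python
# Textbook row-sparse Adam update; the state names mirror the shared arrays
# created in __init__ (per-row step count, first moment, second moment).
import torch as th

def sparse_adam_update(emb, state_step, state_mem, state_power,
                       idx, grad, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
    beta1, beta2 = betas
    state_step[idx] += 1                          # per-row step count
    t = state_step[idx].float().unsqueeze(1)
    state_mem[idx] = beta1 * state_mem[idx] + (1 - beta1) * grad
    state_power[idx] = beta2 * state_power[idx] + (1 - beta2) * grad * grad
    m_hat = state_mem[idx] / (1 - beta1 ** t)     # bias-corrected first moment
    v_hat = state_power[idx] / (1 - beta2 ** t)   # bias-corrected second moment
    emb[idx] -= lr * m_hat / (v_hat.sqrt() + eps)

emb = th.zeros(10, 4)
step = th.zeros(10, dtype=th.long)
mem = th.zeros(10, 4)
power = th.zeros(10, 4)
sparse_adam_update(emb, step, mem, power, th.tensor([2, 5]), th.randn(2, 4))
```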
