[Performance][Optimizer] Enable using UVA and FP16 with SparseAdam Optimizer (dmlc#3885)

* Add uva by default to embedding

* More updates

* Update optimizer

* Add new uva functions

* Expose new pinned memory function

* Add unit tests

* Update formatting

* Fix unit test

* Handle auto UVA case when training is on CPU

* Allow per-embedding decisions for whether to use UVA

* Address sparse_optim.py comments

* Remove unused templates

* Update unit test

* Use dgl allocate memory for pinning

* allow automatically unpin

* workaround for d2h copy with a different dtype

* fix linting

* update error message

* update copyright

Co-authored-by: Xin Yao <[email protected]>
Co-authored-by: Minjie Wang <[email protected]>
3 people authored Jun 24, 2022
1 parent 548c85f commit 020f024
Showing 7 changed files with 347 additions and 56 deletions.
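In short: a `dgl.nn.NodeEmbedding` whose weight lives on the CPU can now keep its Adam state (`mem` and `power`) pinned in host memory and stored in FP16, so GPU trainers read and update that state over UVA instead of paging through the host. A minimal usage sketch, assuming the constructor signature introduced in this commit (sizes are illustrative):

```python
import torch as th
import dgl

# 100k nodes with 16-dim learnable embeddings, stored on the CPU
emb = dgl.nn.NodeEmbedding(100000, 16, 'node_emb')

# use_uva=None (the default) defers the decision to the first step:
# UVA is enabled only if gradients arrive on a GPU while the state
# is on the CPU. dtype picks the storage precision of the state.
optimizer = dgl.optim.SparseAdam([emb], lr=0.01,
                                 use_uva=None, dtype=th.float16)
```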
139 changes: 105 additions & 34 deletions python/dgl/optim/pytorch/sparse_optim.py
@@ -3,7 +3,9 @@
from abc import abstractmethod
import torch as th

from ...utils import get_shared_mem_array, create_shared_mem_array
from ...utils import get_shared_mem_array, create_shared_mem_array, \
pin_memory_inplace, gather_pinned_tensor_rows, \
scatter_pinned_tensor_rows
from ...nn.pytorch import NodeEmbedding
from ...cuda import nccl
from ...partition import NDArrayPartition
@@ -434,7 +436,7 @@ def setup(self, params):
state = th.empty(
emb.weight.shape,
dtype=th.float32,
device=eth.device('cpu')).zero_()
device=th.device('cpu')).zero_()
elif self._rank == 0:
state = create_shared_mem_array(emb_name+'_state', \
emb.weight.shape, th.float32).zero_()
@@ -519,6 +521,16 @@ class SparseAdam(SparseGradOptimizer):
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-8
use_uva : bool, Optional
Whether to use pinned memory for storing 'mem' and 'power' parameters,
when the embedding is stored on the CPU. This will improve training
speed, but will require locking a large number of virtual memory pages.
For embeddings which are stored in GPU memory, this setting will have
no effect.
Default: True if the gradients are generated on the GPU, and False
if the gradients are on the CPU.
dtype : torch.dtype, Optional
The type to store optimizer state with. Default: th.float32.
Examples
--------
@@ -534,66 +546,89 @@ class SparseAdam(SparseGradOptimizer):
... loss.backward()
... optimizer.step()
'''
def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):
def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08, \
use_uva=None, dtype=th.float32):
super(SparseAdam, self).__init__(params, lr)
self._lr = lr
self._beta1 = betas[0]
self._beta2 = betas[1]
self._eps = eps
self._use_uva = use_uva
self._nd_handle = {}
self._is_using_uva = {}
assert dtype in [th.float16, th.float32], \
"Unsupported dtype {}. Valid choices are th.float16 " \
"and th.float32".format(dtype)
self._dtype = dtype

def _setup_uva(self, name, mem, power):
self._is_using_uva[name] = True
mem_nd = pin_memory_inplace(mem)
power_nd = pin_memory_inplace(power)
self._nd_handle[name] = [mem_nd, power_nd]

def setup(self, params):
# We need to register a state sum for each embedding in the kvstore.
for emb in params:
assert isinstance(emb, NodeEmbedding), \
'SparseAdam only supports dgl.nn.NodeEmbedding'
emb_name = emb.name
self._is_using_uva[emb_name] = self._use_uva
if th.device(emb.emb_tensor.device) == th.device('cpu'):
# if our embedding is on the CPU, our state also has to be
if self._rank < 0:
state_step = th.empty(
(emb.weight.shape[0],),
dtype=th.float32,
dtype=th.int32,
device=th.device('cpu')).zero_()
state_mem = th.empty(
emb.weight.shape,
dtype=th.float32,
dtype=self._dtype,
device=th.device('cpu')).zero_()
state_power = th.empty(
emb.weight.shape,
dtype=th.float32,
dtype=self._dtype,
device=th.device('cpu')).zero_()
elif self._rank == 0:
state_step = create_shared_mem_array(emb_name+'_step', \
(emb.weight.shape[0],), th.float32).zero_()
(emb.weight.shape[0],), th.int32).zero_()
state_mem = create_shared_mem_array(emb_name+'_mem', \
emb.weight.shape, th.float32).zero_()
emb.weight.shape, self._dtype).zero_()
state_power = create_shared_mem_array(emb_name+'_power', \
emb.weight.shape, th.float32).zero_()
emb.weight.shape, self._dtype).zero_()

if self._world_size > 1:
emb.store.set(emb_name+'_opt', emb_name)
elif self._rank > 0:
# receive
emb.store.wait([emb_name+'_opt'])
state_step = get_shared_mem_array(emb_name+'_step', \
(emb.weight.shape[0],), th.float32)
(emb.weight.shape[0],), th.int32)
state_mem = get_shared_mem_array(emb_name+'_mem', \
emb.weight.shape, th.float32)
emb.weight.shape, self._dtype)
state_power = get_shared_mem_array(emb_name+'_power', \
emb.weight.shape, th.float32)
emb.weight.shape, self._dtype)

if self._is_using_uva[emb_name]:
# use_uva was explicitly set to True; otherwise we wait
# until the first step to decide
self._setup_uva(emb_name, state_mem, state_power)
else:
# make sure we don't use UVA when data is on the GPU
self._is_using_uva[emb_name] = False

# distributed state on gpu
state_step = th.empty(
[emb.emb_tensor.shape[0]],
dtype=th.float32,
dtype=th.int32,
device=emb.emb_tensor.device).zero_()
state_mem = th.empty(
emb.emb_tensor.shape,
dtype=th.float32,
dtype=self._dtype,
device=emb.emb_tensor.device).zero_()
state_power = th.empty(
emb.emb_tensor.shape,
dtype=th.float32,
dtype=self._dtype,
device=emb.emb_tensor.device).zero_()
state = (state_step, state_mem, state_power)
emb.set_optm_state(state)
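Two things change in the state that `setup` builds per embedding: the per-row step counter is now `int32` rather than `float32`, and the `mem`/`power` moments are allocated in the user-chosen `dtype`. Condensed into a standalone sketch (single-process, CPU-resident case; `make_adam_state` is an illustrative name, not part of the diff):

```python
import torch as th

def make_adam_state(weight, dtype=th.float16):
    # one integer step counter per embedding row (was float32 before)
    step = th.zeros((weight.shape[0],), dtype=th.int32)
    # first and second moments, optionally at reduced precision
    mem = th.zeros(weight.shape, dtype=dtype)
    power = th.zeros(weight.shape, dtype=dtype)
    return step, mem, power

step, mem, power = make_adam_state(th.randn(1000, 16))
assert step.dtype == th.int32 and mem.dtype == th.float16
```

When UVA is in effect, `setup` additionally pins `mem` and `power` in place via `pin_memory_inplace` and keeps the returned handles alive in `self._nd_handle`.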
@@ -613,20 +648,34 @@ def update(self, idx, grad, emb):
Sparse embedding to update.
"""
with th.no_grad():
beta1 = self._beta1
beta2 = self._beta2
eps = self._eps

clr = self._lr
state_step, state_mem, state_power = emb.optm_state
exec_dtype = grad.dtype
exec_dev = grad.device
state_dev = state_step.device

# whether or not we need to transfer data from the GPU to the CPU
# while updating the weights
is_d2h = state_dev.type == 'cpu' and exec_dev.type == 'cuda'

# only perform async copies cpu -> gpu, or gpu -> gpu, but block
# when copying to the cpu, so as to ensure the copy is finished
# before operating on the data on the cpu
state_block = state_dev == th.device('cpu') and exec_dev != state_dev
state_block = is_d2h

if self._is_using_uva[emb.name] is None and is_d2h:
# we should use UVA going forward
self._setup_uva(emb.name, state_mem, state_power)
elif self._is_using_uva[emb.name] is None:
# we shouldn't use UVA going forward
self._is_using_uva[emb.name] = False

use_uva = self._is_using_uva[emb.name]

beta1 = self._beta1
beta2 = self._beta2
eps = self._eps

clr = self._lr
# There can be duplicated indices due to sampling.
# Thus unique them here and average the gradient here.
grad_indices, inverse, cnt = th.unique(idx,
return_inverse=True,
return_counts=True)
@@ -635,8 +684,16 @@
state_idx = grad_indices.to(state_dev)
state_step[state_idx] += 1
state_step = state_step[state_idx].to(exec_dev)
orig_mem = state_mem[state_idx].to(exec_dev)
orig_power = state_power[state_idx].to(exec_dev)

if use_uva:
orig_mem = gather_pinned_tensor_rows(state_mem, grad_indices)
orig_power = gather_pinned_tensor_rows(state_power, grad_indices)
else:
orig_mem = state_mem[state_idx].to(exec_dev)
orig_power = state_power[state_idx].to(exec_dev)
# convert to exec dtype
orig_mem = orig_mem.to(dtype=exec_dtype)
orig_power = orig_power.to(dtype=exec_dtype)

grad_values = th.zeros((grad_indices.shape[0], grad.shape[1]), device=exec_dev)
grad_values.index_add_(0, inverse, grad)
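Because neighbor sampling can emit the same node id several times in a mini-batch, `update` first collapses duplicates: `th.unique` yields the distinct ids plus an inverse map, `index_add_` sums gradient rows that share an id, and the sums are then divided by the counts. A self-contained sketch of that step:

```python
import torch as th

idx = th.tensor([3, 1, 3, 7])       # node ids, with a duplicate
grad = th.randn(4, 16)              # one gradient row per lookup

grad_indices, inverse, cnt = th.unique(
    idx, return_inverse=True, return_counts=True)
grad_values = th.zeros((grad_indices.shape[0], grad.shape[1]))
grad_values.index_add_(0, inverse, grad)      # sum duplicate rows ...
grad_values = grad_values / cnt.unsqueeze(1)  # ... then average them
```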
@@ -647,12 +704,23 @@

update_mem = beta1 * orig_mem + (1.-beta1) * grad_mem
update_power = beta2 * orig_power + (1.-beta2) * grad_power
update_mem_dst = update_mem.to(state_dev, non_blocking=True)
update_power_dst = update_power.to(state_dev, non_blocking=True)
if state_block:
# use events to try and overlap CPU and GPU as much as possible
update_event = th.cuda.Event()
update_event.record()

if use_uva:
scatter_pinned_tensor_rows(state_mem, \
grad_indices, \
update_mem.to(dtype=self._dtype))
scatter_pinned_tensor_rows(state_power, \
grad_indices, \
update_power.to(dtype=self._dtype))
else:
update_mem_dst = update_mem.to(dtype=self._dtype).to(
state_dev, non_blocking=True)
update_power_dst = update_power.to(dtype=self._dtype).to(
state_dev, non_blocking=True)
if state_block:
# use events to try and overlap CPU and GPU as much as possible
update_event = th.cuda.Event()
update_event.record()

update_mem_corr = update_mem / (1. - th.pow(th.tensor(beta1, device=exec_dev),
state_step)).unsqueeze(1)
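For reference, these lines compute the standard bias-corrected Adam update; `mem` plays the role of the first moment $m$, `power` the second moment $v$, and $t$ comes from the per-row step counter:

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, &
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^{2},\\
\hat{m}_t &= \frac{m_t}{1-\beta_1^{t}}, &
\hat{v}_t &= \frac{v_t}{1-\beta_2^{t}},\\
\theta_t &= \theta_{t-1} - \mathrm{lr}\cdot\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}. &&
\end{aligned}
```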
@@ -664,11 +732,14 @@
if state_block:
std_event = th.cuda.Event()
std_event.record()
# wait for our transfers from exec_dev to state_dev to finish
# before we can use them
update_event.wait()
state_mem[state_idx] = update_mem_dst
state_power[state_idx] = update_power_dst

if not use_uva:
if state_block:
# wait for our transfers from exec_dev to state_dev to finish
# before we can use them
update_event.wait()
state_mem[state_idx] = update_mem_dst
state_power[state_idx] = update_power_dst

if state_block:
# wait for the transfer of std_values to finish before we
# use them
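Note the order of operations in the non-UVA path above: per the commit's "workaround for d2h copy with a different dtype" bullet, the moments are cast to the storage dtype on the GPU first, and only then copied to the CPU asynchronously, rather than fusing the cast and the device transfer into a single `.to()` call. The shape of the pattern, as a sketch:

```python
import torch as th

update = th.randn(8, 16, device='cuda')       # computed in fp32
update_half = update.to(dtype=th.float16)     # 1) cast on the GPU
update_dst = update_half.to('cpu', non_blocking=True)  # 2) async D2H copy
th.cuda.synchronize()                         # ensure the copy has landed
```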
17 changes: 17 additions & 0 deletions python/dgl/utils/pin_memory.py
@@ -49,4 +49,21 @@ def gather_pinned_tensor_rows(tensor, rows):
"""
return F.from_dgl_nd(_CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows)))

def scatter_pinned_tensor_rows(dest, rows, source):
"""Directly scatter rows from a GPU tensor given an indices array on CUDA devices,
to a pinned tensor on the CPU.
Parameters
----------
dest : Tensor
The tensor on the CPU to scatter rows to. Must be in pinned memory.
rows : Tensor
The rows to scatter. Must be a CUDA tensor with unique entries.
source : Tensor
The tensor on the GPU to scatter rows from.
"""
_CAPI_DGLIndexScatterGPUToCPU(F.to_dgl_nd(dest), F.to_dgl_nd(rows),
F.to_dgl_nd(source))


_init_api("dgl.ndarray.uvm", __name__)
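Together with `pin_memory_inplace` and `gather_pinned_tensor_rows` (already in this file), this completes a read/write path between GPU kernels and pinned CPU tensors. A usage sketch (requires a CUDA build of DGL; shapes are illustrative):

```python
import torch as th
from dgl.utils import (pin_memory_inplace, gather_pinned_tensor_rows,
                       scatter_pinned_tensor_rows)

state = th.zeros(1000, 16)               # CPU-resident optimizer state
nd_handle = pin_memory_inplace(state)    # pin in place; keep handle alive

rows = th.tensor([1, 5, 42], device='cuda')
gathered = gather_pinned_tensor_rows(state, rows)        # 3 x 16, on GPU
scatter_pinned_tensor_rows(state, rows, gathered + 1.0)  # write back
```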
43 changes: 41 additions & 2 deletions src/array/cuda/array_index_select.cuh
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2021 by Contributors
* \file array/cpu/array_index_select.cuh
* Copyright (c) 2021-2022 by Contributors
* \file array/cuda/array_index_select.cuh
* \brief Array index select GPU kernel implementation
*/

@@ -50,6 +50,45 @@ __global__ void IndexSelectMultiKernel(
}
}

template <typename DType, typename IdType>
__global__ void IndexScatterSingleKernel(const DType* array,
const IdType* index,
const int64_t length,
const int64_t arr_len,
DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
assert(index[tx] >= 0 && index[tx] < arr_len);
out[index[tx]] = array[tx];
tx += stride_x;
}
}

template <typename DType, typename IdType>
__global__ void IndexScatterMultiKernel(
const DType* const array,
const int64_t num_feat,
const IdType* const index,
const int64_t length,
const int64_t arr_len,
DType* const out) {
int64_t in_row = blockIdx.x*blockDim.y+threadIdx.y;

const int64_t stride = blockDim.y*gridDim.x;

while (in_row < length) {
int64_t col = threadIdx.x;
const int64_t out_row = index[in_row];
assert(out_row >= 0 && out_row < arr_len);
while (col < num_feat) {
out[out_row*num_feat+col] = array[in_row*num_feat+col];
col += blockDim.x;
}
in_row += stride;
}
}

} // namespace impl
} // namespace aten
} // namespace dgl
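In tensor terms, `IndexScatterSingleKernel` and `IndexScatterMultiKernel` both compute a row-wise scatter, `out[index[i]] = array[i]`; the multi-feature variant assigns threads along the feature dimension (`threadIdx.x`) and rows along `threadIdx.y`. The same semantics expressed in PyTorch, for reference:

```python
import torch as th

array = th.arange(12., device='cuda').reshape(3, 4)  # source rows
index = th.tensor([5, 0, 2], device='cuda')          # unique destination rows
out = th.zeros(6, 4, device='cuda')

out[index] = array   # what the kernels compute, one source row per index
```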