forked from NVIDIA/apex
Commit
Merging in fused adam optimizer, additional DDP features tested in 18.10 (NVIDIA#60)

* test passes
* notes
* Using C++-side flatten and unflatten functions
* Adding csrc
* Persistent synchronization event so it doesn't need to be created and destroyed each time
* Interop with parameter flattening in SSD
* Added deterministic option to imagenet main.py
* Adding options to split gradient averaging and allreduce in pure fp32
* Fixing allreduce_maybe_retain call
* Fixing allreduce_fallback
* Also sync active_i_buckets from rank 0
* Making retain_allreduce_buffers compatible with/orthogonal to delay_allreduce=True|False
* Correcting syntax error, now all seems to work with SSD
* Optional cpp extension build
* Add mixed precision adam optimizer (NVIDIA#59)
* Add FusedAdam Optimizer to Apex that places all the math into a cuda kernel.
* Added fixes to fused_adam to get it to work with network.
* wip work on python interface for adam with options
* fix dispatch for halfs, add python options to handle optional half gradients and params
* cleanup, get rid of grid-stride loop
Showing 10 changed files with 433 additions and 38 deletions.
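The commit message above mentions an optional cpp extension build for the new csrc sources. As a rough sketch (not part of this diff), such a CUDA extension is typically wired up through torch.utils.cpp_extension; the source file names and setup.py layout below are assumptions for illustration only:

# Hypothetical build sketch (not part of this commit); source paths are assumed.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='fused_adam_cuda',
    ext_modules=[
        CUDAExtension(
            name='fused_adam_cuda',
            sources=['csrc/fused_adam_cuda.cpp',        # assumed path of the C++ interface below
                     'csrc/fused_adam_cuda_kernel.cu'], # assumed path of the CUDA kernel below
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)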
@@ -3,3 +3,4 @@
from . import fp16_utils
from . import parallel
from . import amp
from . import optimizers
@@ -0,0 +1 @@
from .fused_adam import FusedAdam
@@ -0,0 +1,28 @@
#include <torch/torch.h>

// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// C++ interface
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode) {
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);
    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("adam", &adam, "Adam optimized CUDA implementation.");
}
@@ -0,0 +1,119 @@
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>

typedef enum {
    ADAM_MODE_0 = 0, // eps under square root
    ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;

template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T* __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode) {

    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    for (int j = i; j < tsize; j += totThreads) {
        T scaled_grad = g[j] / grad_scale;
        m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
        v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(v[j] + eps);
        else // Mode 1
            denom = sqrtf(v[j]) + eps;
        p[j] = p[j] - (step_size * m[j] / denom);
        if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
    }
}

void fused_adam_cuda(
        at::Tensor & p,
        at::Tensor & p_copy,
        at::Tensor & m,
        at::Tensor & v,
        at::Tensor & g,
        float lr,
        float beta1,
        float beta2,
        float eps,
        float grad_scale,
        int step,
        int mode) {

    // Get tensor size
    int tsize = p.numel();
    // Determine #threads and #blocks
    const int threadsPerBlock = 512;
    const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
    // Constants
    const float bias_correction1 = 1 - std::pow(beta1, step);
    const float bias_correction2 = 1 - std::pow(beta2, step);
    const float step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.type().scalarType() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
        // dispatch is done on the gradient type
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            adam_cuda_kernel<accscalar_t, scalar_t><<<blocks, threadsPerBlock, 0, stream>>>(
                    p.data<accscalar_t>(),
                    p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                    m.data<accscalar_t>(),
                    v.data<accscalar_t>(),
                    g.data<scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    tsize,
                    (adamMode_t) mode);
        }));
    } else {
        AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
            adam_cuda_kernel<scalar_t, scalar_t><<<blocks, threadsPerBlock, 0, stream>>>(
                    p.data<scalar_t>(),
                    NULL, // don't output p_copy for fp32, it's a wasted write
                    m.data<scalar_t>(),
                    v.data<scalar_t>(),
                    g.data<scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    tsize,
                    (adamMode_t) mode);
        }));
    }
    THCudaCheck(cudaGetLastError());
}
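For readability, the per-element arithmetic implemented by adam_cuda_kernel, together with the host-side bias correction and step-size computation above, can be written out as an unfused PyTorch sketch. This is an aid for checking the kernel, not part of the commit:

# Reference (unfused) restatement of the update performed above, for comparison only.
import math
import torch

def adam_reference(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode):
    # Host side: bias-corrected step size.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1

    # Kernel side: unscale gradients, update moments, apply the step.
    scaled_grad = g / grad_scale
    m.mul_(beta1).add_(scaled_grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(scaled_grad, scaled_grad, value=1 - beta2)

    if mode == 0:   # ADAM_MODE_0: eps under the square root
        denom = (v + eps).sqrt()
    else:           # ADAM_MODE_1: eps outside the square root
        denom = v.sqrt() + eps
    p.sub_(step_size * m / denom)
    return p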
@@ -0,0 +1,90 @@
import torch
import fused_adam_cuda

class FusedAdam(torch.optim.Adam):

    """Implements Adam algorithm.
    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        eps_inside_sqrt (boolean, optional): if True, eps is added to the
            second-moment estimate inside the square root instead of outside
            it (default: False)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, eps_inside_sqrt=False):
        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        super(FusedAdam, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
        self.eps_mode = 0 if eps_inside_sqrt else 1

    def step(self, closure=None, grads=None, output_params=None, scale=1.):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list, optional): gradients to use instead of the parameters'
                own ``.grad`` attributes (e.g. fp16 gradients from a mixed
                precision wrapper).
            output_params (list, optional): tensors that receive a copy of the
                updated parameters (e.g. the fp16 model parameters).
            scale (float, optional): factor by which the gradients were scaled;
                gradients are divided by this value before the update.
        """
        loss = None
        if closure is not None:
            loss = closure()
        if grads is not None:
            assert len(self.param_groups) == 1, "mixed precision optimizer works for a single group only"
        for group in self.param_groups:
            if grads is None:
                grads = [None] * len(group['params'])
            if output_params is None:
                output_params = [None] * len(group['params'])
            for p, grad, output_param in zip(group['params'], grads, output_params):
                # Note: p.grad should never be set for correct operation of the
                # mixed precision optimizer, which sometimes sends None gradients.
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
                fused_adam_cuda.adam(p.data,
                                     out_p,
                                     exp_avg,
                                     exp_avg_sq,
                                     grad,
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     scale,
                                     state['step'],
                                     self.eps_mode)
        return loss
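A short usage sketch for the class above. The first form mirrors plain torch.optim.Adam; the second shows the mixed-precision path, where step() receives explicit fp16 gradients, fp16 output parameters, and the loss scale. The model, sizes, and loss-scale value are illustrative assumptions, not taken from this commit:

# Usage sketch (model, sizes, and loss scale are assumptions).
import torch
from apex.optimizers import FusedAdam

# 1) Ordinary fp32 training: behaves like torch.optim.Adam.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
loss = model(torch.randn(32, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# 2) Mixed precision: the optimizer holds fp32 master weights; step() unscales the
#    fp16 gradients and writes the updated values back into the fp16 model params.
fp16_model = torch.nn.Linear(1024, 1024).cuda().half()
fp16_params = list(fp16_model.parameters())
master_params = [p.detach().clone().float() for p in fp16_params]
optimizer = FusedAdam(master_params, lr=1e-3)

loss_scale = 128.0
out = fp16_model(torch.randn(32, 1024, device='cuda').half()).float().sum()
(out * loss_scale).backward()

optimizer.step(grads=[p.grad for p in fp16_params],
               output_params=[p.data for p in fp16_params],
               scale=loss_scale)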