Merging in fused adam optimizer, additional DDP features tested in 18.10 (NVIDIA#60)

* test passes

* notes

* Using C++-side flatten and unflatten functions

* Adding csrc

* Persistent synchronization event so it doesn't need to be created and destroyed each time

* Interop with parameter flattening in SSD

* Added deterministic option to imagenet main.py

* Adding options to split gradient averaging and allreduce in pure fp32

* Fixing allreduce_maybe_retain call

* Fixing allreduce_fallback

* Also sync active_i_buckets from rank 0

* Making retain_allreduce_buffers compatible with/orthogonal to delay_allreduce=True|False

* Correcting syntax error, now all seems to work with SSD

* Optional cpp extension build

* Add mixed precision adam optimizer (NVIDIA#59)

* Add FusedAdam Optimizer to Apex that places all the math into a cuda kernel.

* Added fixes to fused_adam to get it to work with network.

* wip work on python interface for adam with options

* fix dispatch for halfs, add python options to handle optional half gradients and params

* cleanup, get rid of grid-stride loop
mcarilli authored Oct 29, 2018
1 parent 81eef1e commit e0bc5d6
Showing 10 changed files with 433 additions and 38 deletions.
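
Several of the commit messages above refer to DDP-side options (retain_allreduce_buffers, delay_allreduce, pure-fp32 allreduce) whose files are not reproduced in the excerpt below. Purely as a hedged sketch, and assuming those names are constructor flags of apex.parallel.DistributedDataParallel (this excerpt does not show that signature), usage might look like:

import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

# Assumes the usual distributed launcher environment (RANK, WORLD_SIZE, MASTER_ADDR, ...).
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()

# delay_allreduce and retain_allreduce_buffers are taken from the commit
# messages above; treating them as keyword arguments here is an assumption,
# since the DDP changes themselves are not part of this excerpt.
model = DDP(model, delay_allreduce=True, retain_allreduce_buffers=True)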
1 change: 1 addition & 0 deletions apex/__init__.py
@@ -3,3 +3,4 @@
from . import fp16_utils
from . import parallel
from . import amp
from . import optimizers
1 change: 1 addition & 0 deletions apex/optimizers/__init__.py
@@ -0,0 +1 @@
from .fused_adam import FusedAdam
28 changes: 28 additions & 0 deletions apex/optimizers/csrc/fused_adam_cuda.cpp
@@ -0,0 +1,28 @@
#include <torch/torch.h>

// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// C++ interface
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode) {
CHECK_INPUT(p)
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation.");
}
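
The "Optional cpp extension build" change is not shown in this excerpt. As a rough sketch only, a binding like the one above is usually compiled with PyTorch's cpp_extension machinery along these lines (the module name is taken from the import fused_adam_cuda in fused_adam.py below; everything else is illustrative):

# setup.py sketch -- not the actual build change from this commit.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="fused_adam_cuda",
    ext_modules=[
        CUDAExtension(
            name="fused_adam_cuda",
            sources=["apex/optimizers/csrc/fused_adam_cuda.cpp",
                     "apex/optimizers/csrc/fused_adam_cuda_kernel.cu"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)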
119 changes: 119 additions & 0 deletions apex/optimizers/csrc/fused_adam_cuda_kernel.cu
@@ -0,0 +1,119 @@
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>

typedef enum{
    ADAM_MODE_0 = 0, // eps under square root
    ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;

template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T * __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode) {

    //Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;

    for (int j = i; j < tsize; j+=totThreads) {
        T scaled_grad = g[j]/grad_scale;
        m[j] = b1*m[j] + (1-b1)*scaled_grad;
        v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(v[j] + eps);
        else // Mode 1
            denom = sqrtf(v[j]) + eps;
        p[j] = p[j] - (step_size*m[j]/denom);
        if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
    }
}

void fused_adam_cuda(
    at::Tensor & p,
    at::Tensor & p_copy,
    at::Tensor & m,
    at::Tensor & v,
    at::Tensor & g,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode) {

    //Get tensor size
    int tsize = p.numel();
    //Determine #threads and #blocks
    const int threadsPerBlock = 512;
    const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
    //Constants
    const float bias_correction1 = 1 - std::pow(beta1, step);
    const float bias_correction2 = 1 - std::pow(beta2, step);
    const float step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.type().scalarType() == at::ScalarType::Half) {
        //all other values should be fp32 for half gradients
        AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
                p.data<accscalar_t>(),
                p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                m.data<accscalar_t>(),
                v.data<accscalar_t>(),
                g.data<scalar_t>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode);
        }));
    } else {
        AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
            adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
                p.data<scalar_t>(),
                NULL, //don't output p_copy for fp32, it's wasted write
                m.data<scalar_t>(),
                v.data<scalar_t>(),
                g.data<scalar_t>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode);
        }));
    }
    THCudaCheck(cudaGetLastError());
}
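
Per element j, the kernel above computes the standard Adam update, with the bias corrections folded into the host-side step_size and the loss-scaling factor divided out of the raw gradient:

\[
\hat{g} = g_j / \text{grad\_scale}, \qquad
m_j \leftarrow \beta_1 m_j + (1 - \beta_1)\hat{g}, \qquad
v_j \leftarrow \beta_2 v_j + (1 - \beta_2)\hat{g}^2
\]
\[
\text{step\_size} = \mathrm{lr}\cdot\frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}, \qquad
p_j \leftarrow p_j - \text{step\_size}\cdot\frac{m_j}{d_j}, \qquad
d_j = \begin{cases} \sqrt{v_j + \epsilon} & \text{mode 0 (eps under the square root)} \\ \sqrt{v_j} + \epsilon & \text{mode 1 (eps outside)} \end{cases}
\]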
90 changes: 90 additions & 0 deletions apex/optimizers/fused_adam.py
@@ -0,0 +1,90 @@
import torch
import fused_adam_cuda

class FusedAdam(torch.optim.Adam):

    """Implements Adam algorithm.
    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        eps_inside_sqrt (boolean, optional): if True, add eps to the second-moment
            estimate before taking its square root, i.e. use sqrt(v + eps) as the
            denominator instead of sqrt(v) + eps (default: False)
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, eps_inside_sqrt=False):
        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        super(FusedAdam, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
        # Mode passed to the CUDA kernel: 0 -> eps under the square root, 1 -> eps outside.
        self.eps_mode = 0 if eps_inside_sqrt else 1

    def step(self, closure=None, grads=None, output_params=None, scale=1.):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list, optional): explicit gradients to use instead of p.grad,
                as supplied by a mixed precision wrapper.
            output_params (list, optional): if given, the updated parameters are
                also written into these (typically half precision) tensors.
            scale (float, optional): factor by which the gradients are divided
                before the update (default: 1.).
        """
        loss = None
        if closure is not None:
            loss = closure()
        if grads is not None:
            assert len(self.param_groups) == 1, "mixed precision optimizer works for a single group only"
        for group in self.param_groups:
            if grads is None:
                grads = [None]*len(group['params'])
            if output_params is None:
                output_params = [None]*len(group['params'])
            for p, grad, output_param in zip(group['params'], grads, output_params):
                # note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                # An empty tensor tells the kernel not to write a separate copy of the updated parameters.
                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
                fused_adam_cuda.adam(p.data,
                                     out_p,
                                     exp_avg,
                                     exp_avg_sq,
                                     grad,
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     scale,
                                     state['step'],
                                     self.eps_mode)
        return loss
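
A minimal fp32 usage sketch of the optimizer above (the grads, output_params, and scale arguments are hooks for mixed precision wrappers and can be left at their defaults; the model and shapes here are arbitrary):

import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(512, 512).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

x = torch.randn(32, 512, device="cuda")
loss = model(x).pow(2).mean()

optimizer.zero_grad()
loss.backward()
# With no explicit grads/output_params, step() falls back to p.grad for each
# parameter and skips writing a half precision parameter copy.
optimizer.step()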
