Skip to content

Commit

Permalink
Add support for FP32 master weights in FusedAdam optimizer (NVIDIA#1623)
Browse files Browse the repository at this point in the history
and use float32 in the math kernel when parameters are either float16 or bfloat16

* Cherry pick changes to ConvScaleBiasReLU fusion

* Fix testbench

* Add missing conv_cscale_cbias_relu_forward

* Fix bug in setOperationGraph

* Remove manual cuDNN heuristics knobs

* Use torch.testing.assert_close for tensor comparison

* Return at::Tensor instead of vector, add debug msg

* Start making changes

* Changes

* Update

* Probably all necessary changes, need to test compilation

* Fix bug

* Fix bug

* Fix bug

* Change implementation to separately maintain master weights

* Update test

* Fix test

* Update test

* Fix potential issue with gradient unscaling

* Write out unscaled gradients

* Debugging test

* Add static casts

* Test

* Test

* Revert test

* Add debugging prints

* Fix bug

* Make m and v FP32

* Fix compilation bug

* m and v

* Revert test

* Remove debug prints

* Remove print

* Remove assert

* Cleanup

* Cleanup

* Update test

* Remove float conversions for m and v

* Fix typo

---------

Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
  • Loading branch information
4 people authored Mar 22, 2023
1 parent d8643ef commit 6952004
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 23 deletions.
23 changes: 17 additions & 6 deletions apex/optimizers/fused_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,19 @@ class FusedAdam(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
weight_decay=0., amsgrad=False, set_grad_none=True,
capturable=False):
capturable=False, use_master=False):

if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
if use_master and not capturable:
raise RuntimeError('FusedAdam should be capturable to utilize master weights')
# If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay)
self.use_master = use_master
if self.use_master:
self.master_params = [p.data.clone().detach().float() for p in params]
super(FusedAdam, self).__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
Expand All @@ -91,6 +96,7 @@ def __init__(self, params, lr=1e-3, bias_correction=True,
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_adam = amp_C.multi_tensor_adam
self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master
else:
raise RuntimeError('apex.optimizers.FusedAdam requires cuda extensions')

Expand Down Expand Up @@ -133,8 +139,9 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
g_16, p_16, m_16, v_16 = [], [], [], []
g_bf, p_bf, m_bf, v_bf = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
p_16_master = []

for p in group['params']:
for pi, p in enumerate(group['params']):
if p.grad is None:
continue
if p.grad.data.is_sparse:
Expand All @@ -144,11 +151,13 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p.data).float()
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p.data).float()

if p.dtype == torch.float16:
if self.use_master:
p_16_master.append(self.master_params[pi])
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state['exp_avg'])
Expand Down Expand Up @@ -186,9 +195,11 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
inv_scale = torch.ones((1,), device=device)

if len(g_16) > 0:
multi_tensor_applier(self.multi_tensor_adam_capturable,
multi_tensor_applier(self.multi_tensor_adam_capturable_master if self.use_master
else self.multi_tensor_adam_capturable,
self._dummy_overflow_buf,
[g_16, p_16, m_16, v_16],
[g_16, p_16, m_16, v_16, p_16_master] if self.use_master
else [g_16, p_16, m_16, v_16],
group['lr'],
beta1,
beta2,
Expand Down
16 changes: 16 additions & 0 deletions csrc/amp_C_frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ void multi_tensor_adam_capturable_cuda(
const float weight_decay,
at::Tensor inv_scale);

void multi_tensor_adam_capturable_master_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr,
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
const int mode,
const int bias_correction,
const float weight_decay,
at::Tensor inv_scale);

void multi_tensor_adagrad_cuda(
int chunk_size,
at::Tensor noop_flag,
Expand Down Expand Up @@ -178,6 +192,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support and LR scheduling");
m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support, LR scheduling and FP32 master weights");
m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
Expand Down
187 changes: 171 additions & 16 deletions csrc/multi_tensor_adam.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ typedef enum{

using MATH_T = float;

template<typename T>
template<typename T, typename FULL_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
Expand Down Expand Up @@ -54,10 +54,10 @@ struct AdamFunctor
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;

T* m = (T*)tl.addresses[2][tensor_loc];
FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;

T* v = (T*)tl.addresses[3][tensor_loc];
FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;

n -= chunk_idx*chunk_size;
Expand Down Expand Up @@ -126,7 +126,7 @@ struct AdamFunctor
}
};

template<typename T>
template<typename T, typename FULL_T>
struct AdamCapturableFunctor
{
__device__ __forceinline__ void operator()(
Expand Down Expand Up @@ -166,10 +166,10 @@ struct AdamCapturableFunctor
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;

T* m = (T*)tl.addresses[2][tensor_loc];
FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;

T* v = (T*)tl.addresses[3][tensor_loc];
FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;

n -= chunk_idx*chunk_size;
Expand All @@ -189,11 +189,10 @@ struct AdamCapturableFunctor
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
g[i] = g[i] * (*inv_scale);
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
r_g[ii] = static_cast<MATH_T>(g[i]) * (*inv_scale);
r_p[ii] = static_cast<MATH_T>(p[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
Expand Down Expand Up @@ -230,9 +229,127 @@ struct AdamCapturableFunctor
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
g[i] = static_cast<T>(r_g[ii]);
p[i] = static_cast<T>(r_p[ii]);
m[i] = static_cast<T>(r_m[ii]);
v[i] = static_cast<T>(r_v[ii]);
}
}
}
}
};

template<typename T, typename FULL_T>
struct AdamCapturableMasterFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<5>& tl,
const float beta1,
const float beta2,
const int* step,
const int bias_correction,
const float epsilon,
const float* lr,
adamMode_t mode,
const float decay,
const float* inv_scale)
{
if(*noop_gmem == 1)
return;

float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (bias_correction == 1) {
beta1_correction = 1 - pow(beta1, *step);
beta2_correction = 1 - pow(beta2, *step);
}

int tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;

int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;

T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;

FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;

FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;

FULL_T* p_master = (FULL_T*)tl.addresses[4][tensor_loc];
p_master += chunk_idx*chunk_size;

n -= chunk_idx*chunk_size;

// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = static_cast<MATH_T>(g[i]) * (*inv_scale);
r_p[ii] = static_cast<MATH_T>(p_master[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if(mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (*lr * update);
}
else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (*lr * update);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
g[i] = static_cast<T>(r_g[ii]);
p[i] = static_cast<T>(r_p[ii]);
p_master[i] = static_cast<FULL_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
Expand Down Expand Up @@ -269,7 +386,7 @@ void multi_tensor_adam_cuda(
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
AdamFunctor<scalar_t_0, float>(),
beta1,
beta2,
bias_correction1,
Expand Down Expand Up @@ -306,7 +423,45 @@ void multi_tensor_adam_capturable_cuda(
chunk_size,
noop_flag,
tensor_lists,
AdamCapturableFunctor<scalar_t_0>(),
AdamCapturableFunctor<scalar_t_0, float>(),
beta1,
beta2,
step.data_ptr<int>(),
bias_correction,
epsilon,
lr.data_ptr<float>(),
(adamMode_t) mode,
weight_decay,
inv_scale.data_ptr<float>()); )

AT_CUDA_CHECK(cudaGetLastError());

}

void multi_tensor_adam_capturable_master_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr,
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
const int mode,
const int bias_correction,
const float weight_decay,
at::Tensor inv_scale)
{
using namespace at;

DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamCapturableMasterFunctor<scalar_t_0, float>(),
beta1,
beta2,
step.data_ptr<int>(),
Expand Down
Loading

0 comments on commit 6952004

Please sign in to comment.