Seryilmaz/fuse norm into scale (NVIDIA#1149)
* fuse norm into scale

* add fused norm into dlamb

Co-authored-by: Sukru Eryilmaz <[email protected]>
seryilmaz and Sukru Eryilmaz authored Sep 1, 2021
1 parent 6af09dd commit 4d190db
Showing 4 changed files with 365 additions and 13 deletions.
42 changes: 29 additions & 13 deletions apex/contrib/optimizers/distributed_fused_lamb.py
@@ -86,7 +86,7 @@ def __init__(self, params,
                  adam_w_mode=True, use_nvlamb=False,
                  step_supports_amp_scaling=True, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
-                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
+                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
                  e5m2_allgather=False, verbose=False, clip_after_ar=True):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
@@ -120,7 +120,7 @@ def __init__(self, params,
         self._verbose = verbose
         self._clip_after_ar = clip_after_ar
         self._L2_grad_norm = None
-
+        self._fused_norm = fused_norm
         self._current_process_group = c10d._get_default_group()
         self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
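The two hunks above only thread the new `fused_norm` flag through the constructor and store it. A minimal usage sketch, assuming torch.distributed is already initialized; the model and hyperparameters below are placeholders, not taken from this commit:

# Illustrative sketch only: `model` and the hyperparameter values are assumptions.
# fused_norm=True asks the optimizer to obtain the L2 gradient norm from the fused
# scaling kernel instead of a separate norm() pass (see the hunks that follow).
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    fused_norm=True)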
@@ -525,7 +525,8 @@ def _pipeline_block_reductions(self, block_id):
             # Compute L2 grad norm
             self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
+                if not self._fused_norm:
+                    self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
             torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
 
             # Apply clipping & pre-reduction scaling on grads
@@ -645,19 +646,34 @@ def _pipeline_step(self):
     def _flatten_grad_mt(self, scale):
         if len(self._grads_fp16) > 0:
             self._overflow_buf.zero_()
-            multi_tensor_applier(
-                amp_C.multi_tensor_scale,
-                self._overflow_buf,
-                list(zip(*self._grads_fp16)),
-                scale)
+            if not self._fused_norm:
+                multi_tensor_applier(
+                    amp_C.multi_tensor_scale,
+                    self._overflow_buf,
+                    list(zip(*self._grads_fp16)),
+                    scale)
+            else:
+                self._L2_grad_norm = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm_scale,
+                    self._overflow_buf,
+                    list(zip(*self._grads_fp16)),
+                    scale, False)[0].float()
+
             self._grads_fp16 = []
         if len(self._grads_fp32) > 0:
             self._overflow_buf.zero_()
-            multi_tensor_applier(
-                amp_C.multi_tensor_scale,
-                self._overflow_buf,
-                list(zip(*self._grads_fp32)),
-                scale)
+            if not self._fused_norm:
+                multi_tensor_applier(
+                    amp_C.multi_tensor_scale,
+                    self._overflow_buf,
+                    list(zip(*self._grads_fp32)),
+                    scale)
+            else:
+                self._L2_grad_norm = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm_scale,
+                    self._overflow_buf,
+                    list(zip(*self._grads_fp32)),
+                    scale, False)[0].float()
             self._grads_fp32 = []
 
     def _do_overlapped_reduction(self, param_i, param):
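The hunk above is the core of the change: when `fused_norm` is set, the gradient-flattening pass asks a single kernel, `multi_tensor_l2norm_scale`, for both the scaled gradients and their L2 norm, instead of running `multi_tensor_scale` and then a separate norm. A standalone sketch of the two paths; the tensor-list names are assumptions, and it presumes the fused kernel reports the norm of the scaled values, matching the `self._flat_grads.norm(...)` call it replaces:

# Standalone sketch, not part of the commit. `grads_in` / `grads_out` stand in for
# the (source, destination) tensor pairs that the optimizer builds internally.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

def scale_and_norm(grads_in, grads_out, scale, fused=True):
    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
    if fused:
        # One pass: grads_out[i] = scale * grads_in[i], and the same kernel reduces
        # the global L2 norm, returned as the first tuple element (the trailing
        # False disables per-tensor norms).
        norm, _ = multi_tensor_applier(
            amp_C.multi_tensor_l2norm_scale,
            overflow_buf,
            [grads_in, grads_out],
            scale, False)
        return norm.float()
    # Unfused reference: scale in one kernel, then a second pass for the norm.
    multi_tensor_applier(
        amp_C.multi_tensor_scale,
        overflow_buf,
        [grads_in, grads_out],
        scale)
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        overflow_buf,
        [grads_out],
        False)
    return norm.float()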
9 changes: 9 additions & 0 deletions csrc/amp_C_frontend.cpp
@@ -33,6 +33,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::optional<bool> per_tensor_python);
 
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  float scale,
+  at::optional<bool> per_tensor_python);
+
 void multi_tensor_lamb_stage1_cuda(
   int chunk_size,
   at::Tensor noop_flag,
@@ -121,6 +128,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "out = a*x + b*y for a list of contiguous tensors");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
+  m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
+        "Computes L2 norm for a list of contiguous tensors and does scaling");
   m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
         "Computes update part of LAMB optimizer");
   m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
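Because the pybind binding above only exists in builds of apex that include this commit, code outside the optimizer may want to feature-test for it before choosing the fused path; a small sketch, not from the commit:

# Sketch: fall back to the unfused multi_tensor_scale / multi_tensor_l2norm kernels
# when the extension was built without multi_tensor_l2norm_scale.
import amp_C

HAS_FUSED_L2NORM_SCALE = hasattr(amp_C, "multi_tensor_l2norm_scale")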