Use all_gather_into_tensor and reduce_scatter_tensor (NVIDIA#1513)
* update distributed_fused_lamb

* update distributed_fused_lamb

* add test

* update apex transformer

* update test_distributed_fused_lamb

* update test_distributed_fused_lamb

* update test_distributed_fused_lamb

* apply suggested changes
Aidyn-A authored Oct 26, 2022
1 parent ee42e2d commit f680fab
Showing 5 changed files with 222 additions and 13 deletions.
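
For context on the change: the tensor-based collectives take and return a single flat buffer instead of a per-rank list of tensors, which is why the call sites in the diffs below can drop the shard lists. The sketch that follows is illustrative only (not code from this commit); it assumes an already-initialized NCCL process group and shows both call shapes plus the aliasing the diff installs for older PyTorch.

```python
# Illustrative only -- not code from this commit. Assumes an initialized NCCL
# process group (e.g. launched via torchrun).
import torch
import torch.distributed as dist

# On older PyTorch the flat-tensor collectives exist only under private names;
# the commit installs these aliases so call sites can use the public names.
if "all_gather_into_tensor" not in dir(dist):
    dist.all_gather_into_tensor = dist._all_gather_base
if "reduce_scatter_tensor" not in dir(dist):
    dist.reduce_scatter_tensor = dist._reduce_scatter_base


def gather_shard(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # List-based API: one output tensor per rank, collected into a Python list.
    parts = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(parts, shard)

    # Tensor-based API: a single flat output buffer holding every rank's shard,
    # so no per-tensor copies out of a list are needed.
    flat = torch.empty(world_size * shard.numel(),
                       dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(flat, shard)
    return flat
```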
77 changes: 71 additions & 6 deletions apex/contrib/optimizers/distributed_fused_lamb.py
@@ -1,5 +1,6 @@
import os
import math
import inspect
import torch
import importlib
import amp_C
@@ -156,8 +157,18 @@ def __init__(self, params,
# Master weight, moment, gradient buffers
self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None

import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
# Check if collectives have no_copy option
self._reduce_scatter_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
)
self._all_gather_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
)

if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base

self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
@@ -647,7 +658,23 @@ def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True,op=_make_nccl_premul_sum((scale,)))
if self._reduce_scatter_no_copy:
works[chunk_id] = torch.distributed.reduce_scatter(
output = self._fp16_g_chunks[block_id][chunk_id],
input_list = self._flat_grads_shards[block_id][chunk_id],
group = self._rs_pg[glob_chunk_id%self._num_rs_pg],
async_op = True,
no_copy = True,
op = _make_nccl_premul_sum((scale,)),
)
else:
works[chunk_id] = torch.distributed.reduce_scatter_tensor(
output = self._fp16_g_chunks[block_id][chunk_id],
input = self._flat_grads_blocks[block_id],
group = self._rs_pg[glob_chunk_id%self._num_rs_pg],
async_op = True,
op = _make_nccl_premul_sum((scale,)),
)

# Reduction across nodes for each rank
if self._num_groups > 1:
@@ -669,7 +696,21 @@ def _reduce_scatter_and_all_reduce(self, block_id):
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
if self._reduce_scatter_no_copy:
works[chunk_id] = torch.distributed.reduce_scatter(
output = self._fp16_g_chunks[block_id][chunk_id],
input_list = self._flat_grads_shards[block_id][chunk_id],
group = self._rs_pg[glob_chunk_id%self._num_rs_pg],
async_op = True,
no_copy = True,
)
else:
works[chunk_id] = torch.distributed.reduce_scatter_tensor(
output = self._fp16_g_chunks[block_id][chunk_id],
input = self._flat_grads_blocks[block_id],
group = self._rs_pg[glob_chunk_id%self._num_rs_pg],
async_op = True,
)

# Reduction across nodes for each rank
if self._num_groups > 1:
@@ -826,9 +867,33 @@ def _pipeline_step(self):
if not self._clip_after_ar:
for block in range(self._num_blocks):
for chunk in range(self._num_chunks):
torch.distributed.all_gather(self._new_params2_shards[block][chunk], self._fp16_p_chunks[block][chunk], group=self._ag_pg[0], no_copy=True)
if self._all_gather_no_copy:
torch.distributed.all_gather(
tensor_list = self._new_params2_shards[block][chunk],
tensor = self._fp16_p_chunks[block][chunk],
group = self._ag_pg[0],
no_copy = True,
)
else:
torch.distributed.all_gather_into_tensor(
output_tensor = self._new_params2_blocks[block],
input_tensor = self._fp16_p_chunks[block][chunk],
group = self._ag_pg[0],
)
else:
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
if self._all_gather_no_copy:
torch.distributed.all_gather(
tensor_list = self._new_params_mega_shards,
tensor = self._fp16_p,
group = self._ag_pg[0],
no_copy = True,
)
else:
torch.distributed.all_gather_into_tensor(
output_tensor = self._new_params,
input_tensor = self._fp16_p,
group = self._ag_pg[0],
)

def _flatten_grad_mt(self, scale):
if len(self._grads_fp16) > 0:
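The two branches above feed NCCL different layouts: the `no_copy` path hands `reduce_scatter`/`all_gather` the per-chunk shard lists, while the fallback hands `reduce_scatter_tensor`/`all_gather_into_tensor` the whole flat block and lets the collective do the chunking. A toy, self-contained sketch of that equivalence (function and variable names are illustrative, not the optimizer's; it assumes an initialized NCCL group and a block size divisible by the world size):

```python
import inspect

import torch
import torch.distributed as dist


def reduce_scatter_flat(flat_block: torch.Tensor, world_size: int) -> torch.Tensor:
    """Reduce a flat gradient block across ranks and keep only this rank's shard."""
    shard = torch.empty(flat_block.numel() // world_size,
                        dtype=flat_block.dtype, device=flat_block.device)
    if "no_copy" in inspect.getfullargspec(dist.reduce_scatter).args:
        # List-based path (NVIDIA PyTorch builds): pass shard views directly and
        # skip the internal copy into a staging buffer.
        dist.reduce_scatter(shard, list(flat_block.chunk(world_size)), no_copy=True)
    else:
        # Tensor-based path: the collective chunks the flat block itself.
        # (On older stock PyTorch this name is the alias installed earlier.)
        dist.reduce_scatter_tensor(shard, flat_block)
    return shard
```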
123 changes: 123 additions & 0 deletions apex/contrib/test/optimizers/test_distributed_fused_lamb.py
@@ -0,0 +1,123 @@
import os
import inspect
import torch
from torch.cuda.amp import GradScaler
from torch.testing._internal import common_utils
from apex.parallel.distributed import flat_dist_call
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase

def get_init_weights_func():
@torch.no_grad()
def init_weights(m):
if isinstance(m, torch.nn.Linear):
m.weight.fill_(1.0)
return init_weights

class ModelFoo(torch.nn.Module):
def __init__(self):
super(ModelFoo, self).__init__()
self.linear = torch.nn.Linear(128, 128, bias = False)
self.loss = torch.nn.MSELoss()

def forward(self, input_tensor, gt):
y = self.linear(input_tensor)
loss = self.loss(y, gt)
return loss

# A test for the distributed fused LAMB optimizer: run several iterations and check that the loss decreases.
# There are two instances of the same test because the optimizer decides which collective operations to use based on `world_size`:
# if torch.distributed.get_world_size() == torch.cuda.device_count(), it uses only `all_gather`;
# if torch.distributed.get_world_size() < torch.cuda.device_count(), it uses both `all_gather` and `reduce_scatter`.
class NcclDistributedFusedLAMB(NcclDistributedTestBase):
@property
def world_size(self) -> int:
return torch.cuda.device_count()

@common_utils.parametrize("no_copy", [False, True])
def test_distributed_fused_lamb(self, no_copy):
if no_copy and 'no_copy' not in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
self.skipTest("does not support no_copy")
if no_copy and 'no_copy' not in inspect.getfullargspec(torch.distributed.all_gather).args:
self.skipTest("does not support no_copy")

assert torch.distributed.is_initialized()
gpu_count = torch.distributed.get_world_size()

init_scale = 100
lr = torch.tensor(0.1).cuda()
grad_scaler = GradScaler(init_scale=init_scale, growth_interval=1000)

model = ModelFoo()
model = model.cuda().half()
model.apply(get_init_weights_func())

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

full_ar = gpu_count == torch.cuda.device_count()

# Aidyn-A: not sure which parameters are best for testing purposes,
# setting up whatever I think is appropriate.
optimizer = DistributedFusedLAMB(
optimizer_grouped_parameters,
lr=0.1,
betas=(0.9, 0.9),
eps=1e-6,
max_grad_norm=1.0,
overlap_reductions=False,
dwu_group_size=gpu_count,
dwu_num_blocks=1,
dwu_num_chunks=1,
dwu_num_rs_pg=1,
dwu_num_ar_pg=1,
dwu_num_ag_pg=1,
use_nvlamb=False,
clip_after_ar=False,
fused_norm=True,
fuse_scale=True,
full_ar=full_ar,
set_param_views_to_flat_buffer=False,
e5m2_allgather=False,
)
optimizer.set_global_scale(init_scale)

optimizer._reduce_scatter_no_copy = no_copy
optimizer._all_gather_no_copy = no_copy

flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )

x = torch.randn(4096, 128, dtype=torch.float16).cuda()
y = torch.randn(4096, 128, dtype=torch.float16).cuda()

losses = []
for _ in range(10):
loss = model(x, y)
optimizer._lazy_init_stage1()
grad_scaler.scale(loss).backward()
optimizer._lazy_init_stage2()
optimizer._lr = lr
optimizer.complete_reductions()
optimizer.set_global_scale(grad_scaler._get_scale_async())
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad(set_to_none=True)

losses.append(loss.item())

self.assertTrue(losses == sorted(losses, reverse=True))

common_utils.instantiate_parametrized_tests(NcclDistributedFusedLAMB)

class NcclDistributedFusedLAMB_partial_ar(NcclDistributedFusedLAMB):
@property
def world_size(self) -> int:
return max(torch.cuda.device_count()-1, 1)

if __name__ == "__main__":
common_utils.run_tests()
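
For reference on the test structure above: `common_utils.parametrize` plus `instantiate_parametrized_tests` expand each `no_copy` value into its own test case, and `run_tests()` lets the file be launched directly. A minimal standalone sketch of that pattern (the class and test names here are hypothetical, not from this file):

```python
from torch.testing._internal import common_utils


class ExampleParametrized(common_utils.TestCase):
    @common_utils.parametrize("no_copy", [False, True])
    def test_flag(self, no_copy):
        # Each parameter value is expanded into a separately named test case.
        self.assertIn(no_copy, (False, True))


common_utils.instantiate_parametrized_tests(ExampleParametrized)

if __name__ == "__main__":
    common_utils.run_tests()
```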

15 changes: 11 additions & 4 deletions apex/transformer/tensor_parallel/layers.py
@@ -48,6 +48,14 @@
from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base

_logger = get_transformer_logger(__name__)

@@ -65,7 +73,6 @@
"partition_stride": 1,
}


def param_is_not_tensor_parallel_duplicate(param: torch.Tensor) -> bool:
return (
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
@@ -302,7 +309,7 @@ def forward(
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(all_gather_buffer, input, group=get_tensor_model_parallel_group())
torch.distributed.all_gather_into_tensor(all_gather_buffer, input, group=get_tensor_model_parallel_group())
total_input = all_gather_buffer
else:
total_input = input
@@ -327,7 +334,7 @@ def backward(ctx, grad_output):
device=torch.cuda.current_device(),
requires_grad=False,
)
handle = torch.distributed._all_gather_base(
handle = torch.distributed.all_gather_into_tensor(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group(),
@@ -355,7 +362,7 @@ def backward(ctx, grad_output):
if ctx.sequence_parallel_enabled:
assert not ctx.async_grad_allreduce
sub_grad_input = torch.empty(input.shape, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False)
handle = torch.distributed._reduce_scatter_base(
handle = torch.distributed.reduce_scatter_tensor(
sub_grad_input,
grad_input,
group=get_tensor_model_parallel_group(),
12 changes: 10 additions & 2 deletions apex/transformer/tensor_parallel/mappings.py
@@ -19,6 +19,14 @@
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base

def _reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
@@ -103,7 +111,7 @@ def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
shape[0] *= world_size

output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._all_gather_base(
torch.distributed.all_gather_into_tensor(
output,
input_.contiguous(),
group=get_tensor_model_parallel_group()
@@ -122,7 +130,7 @@ def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
assert shape[0] % world_size == 0
shape[0] //= world_size
output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(
torch.distributed.reduce_scatter_tensor(
output,
input_.contiguous(),
group=get_tensor_model_parallel_group()
8 changes: 7 additions & 1 deletion apex/transformer/utils.py
@@ -3,6 +3,12 @@

from apex.transformer import parallel_state

# `all_gather_into_tensor` is the new name for `_all_gather_base`.
# It requires a recent version of PyTorch.
# The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base

def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
@@ -40,7 +46,7 @@ def gather_split_1d_tensor(tensor):
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(
torch.distributed.all_gather_into_tensor(
gathered,
tensor,
group=parallel_state.get_tensor_model_parallel_group()
