Pipeline parallelism implementation with periodic full-pipeline syncs
Also includes the following changes for the inter-layer model-parallel implementation:
- Refactoring of model implementations
- Training loop changes to support inter-layer communication using `ring_exchange`
- New groups for inter-layer communication (see the sketch below)
- Checkpoint changes (checkpoint names now include the inter-layer model-parallel rank)
- New command-line arguments (`--intra-layer-model-parallel-size`, `--inter-layer-model-parallel-size`, `--use-pipelining`)
deepakn94 committed Nov 12, 2020
1 parent 28cd66e commit 7abd3e9
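
To make the new communication topology easier to picture, here is a minimal, hypothetical sketch of how inter-layer (pipeline) process groups could be built with `torch.distributed.new_group`. The rank layout and function name are illustrative assumptions, not the `megatron/mpu` code added by this commit; the `ring_exchange` primitive mentioned above is not sketched here.

```python
# Illustrative sketch only (assumed rank layout, not the committed mpu code).
# Consecutive ranks form an intra-layer group; ranks spaced intra_layer_size
# apart inside one model replica form an inter-layer (pipeline) group.
import torch.distributed as dist

def build_inter_layer_group(world_size, intra_layer_size, inter_layer_size):
    """Return the inter-layer (pipeline) process group containing this rank."""
    rank = dist.get_rank()
    ranks_per_replica = intra_layer_size * inter_layer_size
    my_group = None
    for start in range(0, world_size, ranks_per_replica):
        for offset in range(intra_layer_size):
            ranks = list(range(start + offset,
                               start + ranks_per_replica,
                               intra_layer_size))
            # new_group must be called collectively, by every rank,
            # for every group, with identical arguments.
            group = dist.new_group(ranks)
            if rank in ranks:
                my_group = group
    return my_group
```

Activations and gradients would then be exchanged between adjacent ranks of such a group; the commit routes that exchange through a `ring_exchange`-based path in the training loop.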
Showing 48 changed files with 1,412 additions and 612 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -1,4 +1,4 @@
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, intra-layer-model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.

Using our GPT-2 model we achieve a perplexity of 10.8 on the WikiText-103 dataset (improving SOTA from 15.8) and an accuracy of 66.5% on the LAMBADA dataset. For BERT training, we swapped the position of the layer normalization and the residual connection in the model architecture (similar to the GPT-2 architecture), which allowed the models to continue to improve as they were scaled up. Our BERT model with 3.9 billion parameters reaches a loss of 1.16, SQuAD 2.0 F1-score of 91.7, and RACE accuracy of 90.9%.

@@ -218,7 +218,7 @@ These scripts use the PyTorch distributed launcher for distributed training. As

The two tiers of parallelism are data and model parallelism. First, we facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model parallel sizes. For example, for the 8.3 billion parameters model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time.

Second, we developed a simple and efficient intra-layer model parallel approach. To use model parallelism, add the `--model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--model-parallel-size` is 1, which will not implement model parallelism.
Second, we developed a simple and efficient intra-layer model parallel approach. To use model parallelism, add the `--intra-layer-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--intra-layer-model-parallel-size` is 1, which will not implement model parallelism.

Other than these minor changes, the distributed training is identical to the training on a single GPU.

@@ -245,7 +245,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_bert.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--model-parallel-size $MP_SIZE \
--intra-layer-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>

@@ -269,7 +269,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--model-parallel-size $MP_SIZE \
--intra-layer-model-parallel-size $MP_SIZE \
--DDP-impl torch

</pre>
@@ -362,14 +362,14 @@ We provide several command line arguments, detailed in the scripts listed below,
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.

<pre>
MODEL_PARALLEL_SIZE=2
INTRA_LAYER_MODEL_PARALLEL_SIZE=2

VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m

WORLD_SIZE=$MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
WORLD_SIZE=$INTRA_LAYER_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--model-parallel-size $MODEL_PARALLEL_SIZE \
--intra-layer-model-parallel-size $INTRA_LAYER_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
2 changes: 1 addition & 1 deletion examples/evaluate_zeroshot_gpt2.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--load $CHECKPOINT \
--model-parallel-size 1 \
--intra-layer-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
2 changes: 1 addition & 1 deletion examples/finetune_mnli_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 5 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--model-parallel-size 1 \
--intra-layer-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
2 changes: 1 addition & 1 deletion examples/finetune_race_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 3 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--model-parallel-size 1 \
--intra-layer-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
2 changes: 1 addition & 1 deletion examples/generate_text.sh
@@ -5,7 +5,7 @@ VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt

python tools/generate_samples_gpt2.py \
--model-parallel-size 1 \
--intra-layer-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load $CHECKPOINT_PATH \
6 changes: 3 additions & 3 deletions examples/merge_mp_bert.sh
@@ -1,13 +1,13 @@
#!/bin/bash

MODEL_PARALLEL_SIZE=2
INTRA_LAYER_MODEL_PARALLEL_SIZE=2

VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m

WORLD_SIZE=$MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
WORLD_SIZE=$INTRA_LAYER_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--model-parallel-size $MODEL_PARALLEL_SIZE \
--intra-layer-model-parallel-size $INTRA_LAYER_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
2 changes: 1 addition & 1 deletion examples/pretrain_bert_distributed.sh
@@ -15,7 +15,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--model-parallel-size 1 \
--intra-layer-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
11 changes: 10 additions & 1 deletion megatron/__init__.py
@@ -33,9 +33,18 @@
from .initialize import initialize_megatron

def print_rank_0(message):
"""If distributed is initialized print only on rank 0."""
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)

def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(message, flush=True)
else:
print(message, flush=True)
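
A small illustrative call site (not part of this diff) showing why `print_rank_last` is useful: with inter-layer model parallelism the loss is only materialized on the last pipeline stage, so progress logging naturally moves there. The values below are placeholders.

```python
# Hypothetical logging call on the last pipeline stage.
iteration, lm_loss = 100, 3.21  # placeholder values for illustration
print_rank_last(' iteration {:8d} | lm loss {:.6E}'.format(iteration, lm_loss))
```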
24 changes: 16 additions & 8 deletions megatron/arguments.py
@@ -54,10 +54,14 @@ def parse_args(extra_args_provider=None, defaults={},
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
args.model_parallel_size = min(args.model_parallel_size, args.world_size)
args.intra_layer_model_parallel_size = min(
args.intra_layer_model_parallel_size, args.world_size)
args.inter_layer_model_parallel_size = min(
args.inter_layer_model_parallel_size,
(args.world_size // args.intra_layer_model_parallel_size))
if args.rank == 0:
print('using world size: {} and model-parallel size: {} '.format(
args.world_size, args.model_parallel_size))
print('using world size: {} and intra-layer-model-parallel size: {} '.format(
args.world_size, args.intra_layer_model_parallel_size))

# Fp16 loss scaling.
args.dynamic_loss_scale = False
@@ -192,7 +196,7 @@ def _add_regularization_args(parser):
group = parser.add_argument_group(title='regularization')

group.add_argument('--attention-dropout', type=float, default=0.1,
help='Post attention dropout ptobability.')
help='Post attention dropout probability.')
group.add_argument('--hidden-dropout', type=float, default=0.1,
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
@@ -358,10 +362,14 @@ def _add_mixed_precision_args(parser):


def _add_distributed_args(parser):
group = parser.add_argument_group(title='mixed precision')

group.add_argument('--model-parallel-size', type=int, default=1,
help='Size of the model parallel.')
group = parser.add_argument_group(title='distributed')

group.add_argument('--intra-layer-model-parallel-size', type=int, default=1,
help='Degree of intra-layer model parallelism.')
group.add_argument('--inter-layer-model-parallel-size', type=int, default=1,
help='Degree of inter-layer model parallelism.')
group.add_argument('--use-pipelining', action='store_true',
help='Use pipelining to increase throughput of inter-layer model parallelism')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
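
A small worked example (hypothetical sizes, illustrative variable names) of how the clamping logic in `parse_args` above partitions `world_size` across the three parallel dimensions:

```python
# Worked example with assumed sizes, mirroring the min() clamping above.
world_size = 16
intra_layer_model_parallel_size = 2   # --intra-layer-model-parallel-size
inter_layer_model_parallel_size = 4   # --inter-layer-model-parallel-size

intra = min(intra_layer_model_parallel_size, world_size)            # 2
inter = min(inter_layer_model_parallel_size, world_size // intra)   # 4
data_parallel_size = world_size // (intra * inter)                  # 2
assert intra * inter * data_parallel_size == world_size
```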
16 changes: 11 additions & 5 deletions megatron/checkpointing.py
@@ -59,7 +59,7 @@ def _compare(arg_name):
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
_compare('model_parallel_size')
_compare('intra_layer_model_parallel_size')


def ensure_directory_exists(filename):
@@ -70,16 +70,22 @@ def ensure_directory_exists(filename):


def get_checkpoint_name(checkpoints_path, iteration,
release=False, mp_rank=None):
release=False):
"""A unified checkpoint name."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the intra-layer and inter-layer MP rank.
if mpu.get_inter_layer_model_parallel_world_size() == 1:
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_intra_layer_model_parallel_rank()),
'model_optim_rng.pt')
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_model_parallel_rank() if mp_rank is None
else mp_rank),
'mp_rank_{:02d}_{:03d}'.format(
mpu.get_intra_layer_model_parallel_rank(),
mpu.get_inter_layer_model_parallel_rank()),
'model_optim_rng.pt')


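
For concreteness, here is how the two naming schemes in `get_checkpoint_name` above render; the iteration number and ranks are assumed values for illustration.

```python
# Illustration only: the format strings from get_checkpoint_name above,
# evaluated for iteration 5000, intra-layer rank 1, inter-layer rank 3.
directory = 'iter_{:07d}'.format(5000)                # 'iter_0005000'
no_pipeline = 'mp_rank_{:02d}'.format(1)              # 'mp_rank_01' (inter-layer world size == 1)
with_pipeline = 'mp_rank_{:02d}_{:03d}'.format(1, 3)  # 'mp_rank_01_003'
```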
6 changes: 4 additions & 2 deletions megatron/data/bert_dataset.py
@@ -153,8 +153,10 @@ def get_samples_mapping_(indexed_dataset,
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_inter_layer_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_intra_layer_model_parallel_group()))

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
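
The two all-reduces above act as a data-loading barrier whose final count should equal `world_size // intra_layer_model_parallel_size`. A worked example with assumed group sizes:

```python
# Worked example (assumed sizes) for the barrier count checked above.
world_size, intra_layer, inter_layer = 16, 2, 2
data_parallel = world_size // (intra_layer * inter_layer)    # 4

counts = 1
counts *= data_parallel   # all-reduce over the data-parallel group -> 4
counts *= inter_layer     # all-reduce over the inter-layer group   -> 8
assert counts == world_size // intra_layer                   # 16 // 2 == 8
```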
6 changes: 4 additions & 2 deletions megatron/data/gpt2_dataset.py
@@ -204,8 +204,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_inter_layer_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_intra_layer_model_parallel_group()))

# Load mappings.
start_time = time.time()
2 changes: 1 addition & 1 deletion megatron/data/test/test_indexed_dataset.py
@@ -112,7 +112,7 @@ def main():
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1
args.intra_layer_model_parallel_size = 1

if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
24 changes: 20 additions & 4 deletions megatron/fp16/fp16.py
@@ -26,6 +26,7 @@
import amp_C

from megatron.module import MegatronModule
from megatron import mpu

FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
@@ -71,7 +72,19 @@ def __init__(self, module):
self.add_module('module', module.half())

def forward(self, *inputs, **kwargs):
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
convert_inputs = True
convert_outputs = True
if mpu.get_inter_layer_model_parallel_world_size() > 1:
if not mpu.is_inter_layer_first_stage():
convert_inputs = False
if not mpu.is_inter_layer_last_stage():
convert_outputs = False
if convert_inputs:
inputs = fp32_to_fp16(inputs)
outputs = self.module(*inputs, **kwargs)
if convert_outputs:
outputs = fp16_to_fp32(outputs)
return outputs

def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
@@ -214,7 +227,7 @@ def __init__(self,
master_param = param.detach().clone().float()
master_param.requires_grad = True
# Copy the model parallel flag.
master_param.model_parallel = param.model_parallel
master_param.intra_layer_model_parallel = param.intra_layer_model_parallel
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
# Reset existing state dict key to the new master param.
@@ -512,7 +525,8 @@ def wrapped_closure():

return retval

def backward(self, loss, update_master_grads=True, retain_graph=False):
def backward(self, output_tensor, update_master_grads=True, retain_graph=False,
output_tensor_grad=None):
"""
:attr:`backward` performs the following conceptual steps:
@@ -570,7 +584,9 @@ def backward(self, loss, update_master_grads=True, retain_graph=False):
# a loss scale that works. After you find a loss scale that works, do a final dummy
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid
# discarding the iteration, but probably wouldn't improve overall efficiency.
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
# Convert output_tensor to float if it's the loss, otherwise stay in half precision.
self.loss_scaler.backward(output_tensor, retain_graph=retain_graph,
output_tensor_grad=output_tensor_grad)
if update_master_grads:
self.update_master_grads()

20 changes: 14 additions & 6 deletions megatron/fp16/loss_scaler.py
@@ -68,9 +68,13 @@ def scale_gradient(self, module, grad_in, grad_out):
self.loss_scale)
return grad_in

def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
if output_tensor_grad is None:
scaled_output_tensor = output_tensor * self.loss_scale
else:
scaled_output_tensor = output_tensor
torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
retain_graph=retain_graph)


class DynamicLossScaler:
@@ -196,9 +200,13 @@ def scale_gradient(self, module, grad_in, grad_out):
self.loss_scale)
return grad_in

def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
if output_tensor_grad is None:
scaled_output_tensor = output_tensor * self.loss_scale
else:
scaled_output_tensor = output_tensor
torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
retain_graph=retain_graph)


##############################################################
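
To show how the reworked `backward`/`output_tensor_grad` path above would be driven from a training loop, here is a hypothetical call site (not part of this diff); `optimizer`, `loss`, `output_tensor`, and `grad_from_next_stage` are assumed names.

```python
# Hypothetical call site for the new backward() signature.
if mpu.is_inter_layer_last_stage():
    # The last stage holds the loss; the loss scaler multiplies it by loss_scale.
    optimizer.backward(loss)
else:
    # Earlier stages backprop from the activation they sent downstream,
    # seeded with the gradient received back from the next stage.
    optimizer.backward(output_tensor, output_tensor_grad=grad_from_next_stage)
```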