Move pipeline parallel functionality into core with associated changes.
jaredcasper committed Mar 23, 2023
1 parent 0b44909 commit 3c92fa9
Showing 28 changed files with 776 additions and 332 deletions.
3 changes: 2 additions & 1 deletion examples/detxoify_lm/finetune_gpt.py
@@ -17,7 +17,8 @@
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
+from megatron.core.enums import ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
7 changes: 7 additions & 0 deletions megatron/core/enums.py
@@ -0,0 +1,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import enum

class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2
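For orientation, a small usage sketch (not part of the commit): after this change the enum is imported from megatron.core.enums rather than megatron.model, with the members taken from the file above.

# Usage sketch (assumed, not from the commit): ModelType now lives in
# megatron.core.enums instead of megatron.model.
from megatron.core.enums import ModelType

model_type = ModelType.encoder_or_decoder        # single-stack models such as GPT or BERT
assert model_type is not ModelType.encoder_and_decoder  # T5-style encoder-decoder models

Training scripts such as finetune_gpt.py typically pass a value like this through to pretrain so the framework knows whether the model is a single-stack (encoder-only or decoder-only) model or an encoder-decoder model.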
49 changes: 41 additions & 8 deletions megatron/core/parallel_state.py
Original file line number Diff line number Diff line change
@@ -58,12 +58,40 @@ def initialize_model_parallel(
Initialize model data parallel groups.
Arguments:
-        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
-        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
-        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
-                                              pipeline).
-        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
-                                            rank in pipeline with split point.
+        tensor_model_parallel_size (int, default = 1):
+            The number of GPUs to split individual tensors across.
+        pipeline_model_parallel_size (int, default = 1):
+            The number of tensor parallel GPU groups to split the
+            Transformer layers across. For example, if
+            tensor_model_parallel_size is 4 and
+            pipeline_model_parallel_size is 2, the model will be split
+            into 2 groups of 4 GPUs.
+        virtual_pipeline_model_parallel_size (int, optional):
+            The number of stages that each pipeline group will have,
+            interleaving as necessary. If None, no interleaving is
+            performed. For example, if tensor_model_parallel_size is 1,
+            pipeline_model_parallel_size is 4,
+            virtual_pipeline_model_parallel_size is 2, and there are
+            16 transformer layers in the model, the model will be
+            split into 8 stages with two layers each and each GPU
+            would get 2 stages as such (layer number starting with 1):
+            GPU 0: [1, 2] [9, 10]
+            GPU 1: [3, 4] [11, 12]
+            GPU 2: [5, 6] [13, 14]
+            GPU 3: [7, 8] [15, 16]
+        pipeline_model_parallel_split_rank (int, optional):
+            For models with both an encoder and decoder, the rank in
+            pipeline to switch between encoder and decoder (i.e. the
+            first rank of the decoder). This allows the user to set
+            the pipeline parallel size of the encoder and decoder
+            independently. For example, if
+            pipeline_model_parallel_size is 8 and
+            pipeline_model_parallel_split_rank is 3, then ranks 0-2
+            will be the encoder and ranks 3-7 will be the decoder.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
@@ -298,8 +326,8 @@ def set_pipeline_model_parallel_rank(rank):

def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
-    global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-    _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank


def get_tensor_model_parallel_rank():
@@ -318,6 +346,11 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


def get_pipeline_model_parallel_split_rank():
    """Return pipeline model parallel split rank."""
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK


def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
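To make the expanded initialize_model_parallel arguments documented above concrete, a minimal launch sketch; this is not part of the commit, the sizes are illustrative, and it assumes a 4-GPU node launched with torchrun (which sets RANK, WORLD_SIZE, and LOCAL_RANK).

# Illustrative sketch (not from the commit): tensor_model_parallel_size=1,
# pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2,
# matching the interleaving example in the docstring above.
import os
import torch
from megatron.core import parallel_state

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=4,
    virtual_pipeline_model_parallel_size=2,
)
print("tensor rank:", parallel_state.get_tensor_model_parallel_rank(),
      "pipeline rank:", parallel_state.get_pipeline_model_parallel_rank())

With a 16-layer model this is the configuration whose layer placement the docstring spells out: GPU 0 holds layers [1, 2] and [9, 10], and so on.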
1 change: 1 addition & 0 deletions megatron/core/pipeline_parallel/__init__.py
@@ -0,0 +1 @@
from .schedules import get_forward_backward_func
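The re-export above is the intended public entry point for the schedules: callers ask for the schedule that matches the current parallel state rather than importing a specific one. A hedged sketch, assuming (as in later versions of megatron.core) that get_forward_backward_func takes no arguments; the signature of the returned schedule function is not shown in this excerpt.

# Pick the pipeline schedule that matches the current parallel state:
# no pipelining, 1F1B pipelining, or interleaved pipelining when a
# virtual pipeline size has been configured.
from megatron.core.pipeline_parallel import get_forward_backward_func

forward_backward_func = get_forward_backward_func()
# forward_backward_func is then called with the forward step function, data
# iterator(s), and model chunk(s); see megatron/core/pipeline_parallel/schedules.py
# at this commit for its exact parameters.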
