bug fix for get_data_parallel_src_rank
kvareddy committed May 26, 2022
1 parent 739cb43 commit 9ad1944
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions megatron/mpu/initialize.py
@@ -54,6 +54,12 @@
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
 
+# A list of global ranks for each data parallel group to ease calculation of the source
+# rank when broadcasting weights from src to all other data parallel ranks
+_DATA_PARALLEL_GLOBAL_RANKS = None
+
+
+
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
     return _DATA_PARALLEL_GROUP is None
@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
 
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
+    global _DATA_PARALLEL_GLOBAL_RANKS
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+                _DATA_PARALLEL_GLOBAL_RANKS = ranks
 
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank():
 
 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
-    in the tensor model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    data_parallel_size = get_data_parallel_world_size()
-    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
-    return global_rank % num_data_parallel_groups
+    in the data parallel group."""
+    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
+        "Data parallel group is not initialized"
+    return _DATA_PARALLEL_GLOBAL_RANKS[0]
 
 
 def get_pipeline_model_parallel_first_rank():
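The crux of the fix: the old implementation derived the source rank with a modulo over the number of data-parallel groups, which only matches the group layout when there is no pipeline parallelism, while the new implementation reads the first entry of the saved group ranks. Below is a minimal sketch (not part of the commit) that mirrors the group construction loop in initialize_model_parallel with hypothetical sizes to show where the two disagree.

    # Sketch only: reproduce the data-parallel group layout and compare the old
    # and new source-rank calculations. All sizes here are hypothetical.
    world_size = 16
    tensor_model_parallel_size = 2
    pipeline_model_parallel_size = 2
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)               # 4
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size  # 8

    # Same nesting as the loop in initialize_model_parallel.
    data_parallel_groups = []
    for i in range(pipeline_model_parallel_size):
        start = i * num_pipeline_model_parallel_groups
        end = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            data_parallel_groups.append(
                list(range(start + j, end, tensor_model_parallel_size)))
    # data_parallel_groups == [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]

    rank = 9
    old_src = rank % (world_size // data_parallel_size)              # 1, not even in rank 9's group
    new_src = next(g for g in data_parallel_groups if rank in g)[0]  # 9, the group's first rank
    print(old_src, new_src)                                          # 1 9

With pipeline_model_parallel_size = 1 the two values coincide, which may be why the modulo version went unnoticed until now.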

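For the use case named in the new comment ("broadcasting weights from src to all other data parallel ranks"), a hypothetical usage sketch follows; the helper name broadcast_from_data_parallel_src is illustrative and not part of this change.

    # Sketch only: broadcast a parameter tensor from the data-parallel source rank
    # to the other ranks in the same data-parallel group. Assumes torch.distributed
    # and the mpu groups have already been initialized.
    import torch
    from megatron.mpu.initialize import (get_data_parallel_group,
                                         get_data_parallel_src_rank)

    def broadcast_from_data_parallel_src(tensor):
        # The fixed get_data_parallel_src_rank() returns the first *global* rank of
        # this rank's data-parallel group, so it is a valid `src` for the broadcast.
        torch.distributed.broadcast(tensor,
                                    src=get_data_parallel_src_rank(),
                                    group=get_data_parallel_group())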