
Commit

Merge branch 'dist_optimizer_bugfix' into 'main'
Bugfix: Make sure data_end_index is padded when creating new buckets

See merge request ADLR/megatron-lm!1140
deepakn94 committed Feb 24, 2024
2 parents 6d14c7e + a67ffda commit ad53b1e
Showing 1 changed file with 25 additions and 10 deletions.
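
The fix hinges on `_pad_if_needed`, which rounds a buffer index up so a bucket's extent divides evenly across data-parallel ranks, and on `_create_new_bucket` now returning that padded index to its caller. Below is a minimal sketch of the rounding arithmetic, not the Megatron-LM code itself; the world size of 4 is a hypothetical example, and the sketch assumes the standard round-up-to-a-multiple pattern that the (truncated) hunk suggests.

    import math

    DATA_PARALLEL_WORLD_SIZE = 4  # hypothetical example value, not from the commit

    def pad_if_needed(data_index: int) -> int:
        # Round the index up to the nearest multiple of the data-parallel world
        # size so every rank shards the same number of elements.
        return int(math.ceil(data_index / DATA_PARALLEL_WORLD_SIZE)) * DATA_PARALLEL_WORLD_SIZE

    assert pad_if_needed(10) == 12  # 10 elements pad out to 12
    assert pad_if_needed(12) == 12  # already a multiple, unchanged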
35 changes: 25 additions & 10 deletions megatron/core/distributed/grad_buffer.py
@@ -210,8 +210,10 @@ def __init__(
         self.param_to_bucket = {}  # Param -> bucket mapping.
         self.param_index_map = {}  # Param -> location in buffer mapping (used in dist. optimizer).
 
-        def _pad_if_needed(data_index: int):
-            """Pads data indices if using distributed optimizer (to ensure uniform sharding)."""
+        def _pad_if_needed(data_index: int) -> int:
+            """
+            Pads data indices if using distributed optimizer (to ensure uniform sharding).
+            """
             if use_distributed_optimizer:
                 return (
                     int(math.ceil(data_index / self.data_parallel_world_size))
@@ -229,14 +231,22 @@ def _pad_if_needed(data_index: int):
         per_bucket_numel_unpadded = []
         bucket_id = 0
 
-        def _create_new_bucket(data_end_index: int):
+        def _create_new_bucket(data_end_index: int) -> int:
+            """
+            Create the bucket_id'th bucket with collected bucket_params, starting at
+            bucket_data_start_index.
+            """
             nonlocal bucket_data_start_index, bucket_params, bucket_id
             per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index)
+            data_end_index = _pad_if_needed(data_end_index)
+            # Update bucket metadata.
             self.bucket_indices.append((bucket_data_start_index, data_end_index))
             bucket_data_start_index = data_end_index
             # Re-set bucket_params and increment bucket_id for next bucket.
             bucket_params = set()
             bucket_id += 1
+            # Return the potentially padded data_end_index.
+            return data_end_index
 
         for param in params[::-1]:
             # Iterate through parameters in reverse order to roughly follow backprop order,
@@ -247,17 +257,22 @@ def _create_new_bucket(data_end_index: int):
             data_end_index = data_start_index + this_numel
 
             def _does_param_require_new_bucket(param):
-                # Split shared embedding parameters into separate bucket if using distributed
-                # optimizer that makes use of reduce-scatters instead of all-reduces.
-                # This ensures that the first and last pipeline stage partition optimizer state
-                # for the shared embedding parameters the same way across DP replicas, allowing
-                # the DP reduce-scatter to be before the embedding all-reduce.
+                """
+                Split shared embedding parameters into separate bucket if using distributed
+                optimizer that makes use of reduce-scatters instead of all-reduces.
+                This ensures that the first and last pipeline stage partition optimizer state
+                for the shared embedding parameters the same way across DP replicas, allowing
+                the DP reduce-scatter to be before the embedding all-reduce.
+                """
                 return getattr(param, "shared_embedding", False) and self.use_distributed_optimizer
 
             # Create bucket with already collected parameters if current param needs its own bucket.
             if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
                 # We are creating a bucket for the already accumulated parameters, whose params
                 # end at the current data_start_index.
+                if use_distributed_optimizer:
+                    # data_start_index should already be padded.
+                    assert data_start_index % self.data_parallel_world_size == 0
                 _create_new_bucket(data_start_index)
 
             self.param_index_map[param] = (
@@ -273,12 +288,12 @@ def _does_param_require_new_bucket(param):
                 bucket_size is not None
                 and (data_end_index - bucket_data_start_index) >= bucket_size
             ) or _does_param_require_new_bucket(param):
-                _create_new_bucket(data_end_index)
+                data_end_index = _create_new_bucket(data_end_index)
             data_start_index = data_end_index
 
         # Add remaining params to a new bucket.
         if len(bucket_params) > 0:
-            _create_new_bucket(data_end_index)
+            data_end_index = _create_new_bucket(data_end_index)
 
         # Next, create underlying storage for buffer (with numel elements that includes
         # padding as necessary).
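
To see why `_create_new_bucket` now returns the (potentially padded) end index, the toy sketch below rebuilds the bucket bookkeeping outside Megatron-LM; the names and the world size of 4 are illustrative, not the library's API. The point is that the caller continues from the padded boundary, so consecutive buckets stay aligned to multiples of the data-parallel world size.

    import math

    WORLD_SIZE = 4  # hypothetical data-parallel world size
    bucket_indices = []  # (start, end) per bucket, mirroring self.bucket_indices

    def pad(index: int) -> int:
        return int(math.ceil(index / WORLD_SIZE)) * WORLD_SIZE

    def create_new_bucket(start_index: int, end_index: int) -> int:
        # Pad the end index, record the bucket, and hand the padded value back
        # so the caller starts the next bucket on an aligned boundary.
        padded_end = pad(end_index)
        bucket_indices.append((start_index, padded_end))
        return padded_end

    end = create_new_bucket(0, 10)         # first bucket: (0, 12)
    end = create_new_bucket(end, end + 6)  # second bucket: (12, 20)
    print(bucket_indices)                  # [(0, 12), (12, 20)]

    # If the caller ignored the returned value and kept the unpadded index 10,
    # the next bucket would start inside the first bucket's padding region
    # (10..12); that misalignment is what this commit guards against.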
