diff --git a/megatron/core/distributed/grad_buffer.py b/megatron/core/distributed/grad_buffer.py
index 949bc9468c..0b41d3a2bc 100644
--- a/megatron/core/distributed/grad_buffer.py
+++ b/megatron/core/distributed/grad_buffer.py
@@ -210,8 +210,10 @@ def __init__(
         self.param_to_bucket = {}  # Param -> bucket mapping.
         self.param_index_map = {}  # Param -> location in buffer mapping (used in dist. optimizer).
 
-        def _pad_if_needed(data_index: int):
-            """Pads data indices if using distributed optimizer (to ensure uniform sharding)."""
+        def _pad_if_needed(data_index: int) -> int:
+            """
+            Pads data indices if using distributed optimizer (to ensure uniform sharding).
+            """
             if use_distributed_optimizer:
                 return (
                     int(math.ceil(data_index / self.data_parallel_world_size))
@@ -229,14 +231,22 @@ def _pad_if_needed(data_index: int):
         per_bucket_numel_unpadded = []
         bucket_id = 0
 
-        def _create_new_bucket(data_end_index: int):
+        def _create_new_bucket(data_end_index: int) -> int:
+            """
+            Create the bucket_id'th bucket with collected bucket_params, starting at
+            bucket_data_start_index.
+            """
             nonlocal bucket_data_start_index, bucket_params, bucket_id
             per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index)
             data_end_index = _pad_if_needed(data_end_index)
+            # Update bucket metadata.
             self.bucket_indices.append((bucket_data_start_index, data_end_index))
             bucket_data_start_index = data_end_index
+            # Re-set bucket_params and increment bucket_id for next bucket.
             bucket_params = set()
             bucket_id += 1
+            # Return the potentially padded data_end_index.
+            return data_end_index
 
         for param in params[::-1]:
             # Iterate through parameters in reverse order to roughly follow backprop order,
@@ -247,17 +257,22 @@ def _create_new_bucket(data_end_index: int):
             data_end_index = data_start_index + this_numel
 
             def _does_param_require_new_bucket(param):
-                # Split shared embedding parameters into separate bucket if using distributed
-                # optimizer that makes use of reduce-scatters instead of all-reduces.
-                # This ensures that the first and last pipeline stage partition optimizer state
-                # for the shared embedding parameters the same way across DP replicas, allowing
-                # the DP reduce-scatter to be before the embedding all-reduce.
+                """
+                Split shared embedding parameters into separate bucket if using distributed
+                optimizer that makes use of reduce-scatters instead of all-reduces.
+                This ensures that the first and last pipeline stage partition optimizer state
+                for the shared embedding parameters the same way across DP replicas, allowing
+                the DP reduce-scatter to be before the embedding all-reduce.
+                """
                 return getattr(param, "shared_embedding", False) and self.use_distributed_optimizer
 
             # Create bucket with already collected parameters if current param needs its own bucket.
             if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
                 # We are creating a bucket for the already accumulated parameters, whose params
                 # end at the current data_start_index.
+                if use_distributed_optimizer:
+                    # data_start_index should already be padded.
+                    assert data_start_index % self.data_parallel_world_size == 0
                 _create_new_bucket(data_start_index)
 
             self.param_index_map[param] = (
@@ -273,12 +288,12 @@ def _does_param_require_new_bucket(param):
                 bucket_size is not None
                 and (data_end_index - bucket_data_start_index) >= bucket_size
             ) or _does_param_require_new_bucket(param):
-                _create_new_bucket(data_end_index)
+                data_end_index = _create_new_bucket(data_end_index)
             data_start_index = data_end_index
 
         # Add remaining params to a new bucket.
         if len(bucket_params) > 0:
-            _create_new_bucket(data_end_index)
+            data_end_index = _create_new_bucket(data_end_index)
 
         # Next, create underlying storage for buffer (with numel elements that includes
         # padding as necessary).
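For reference, a minimal standalone sketch (not part of the patch) of the rounding-up arithmetic that _pad_if_needed performs and whose result _create_new_bucket now returns; pad_to_multiple is a hypothetical helper written here purely for illustration:

import math


def pad_to_multiple(data_index: int, data_parallel_world_size: int) -> int:
    """Round data_index up to the nearest multiple of data_parallel_world_size."""
    return int(math.ceil(data_index / data_parallel_world_size)) * data_parallel_world_size


# With 8 data-parallel ranks, an index already on an 8-element boundary is unchanged,
# while any other index is padded up so every rank shards an equal slice.
assert pad_to_multiple(1000, 8) == 1000
assert pad_to_multiple(1003, 8) == 1008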