diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index 20420fc24..949b35e2c 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -580,7 +580,7 @@ def __init__( # TODO: is this limitation a must? if n_microbatches % self.pp_group_size != 0: raise ValueError( - f"Interleaved 1F1B schedule requires the number of microbatches ({self._n_microbatches}) \ + f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \ to be a multiple of the number of pipeline ranks ({self.pp_group_size})." ) @@ -626,6 +626,9 @@ def step_microbatches( # increment warmup_steps by 2 for each hop away warmup_steps = (self.n_local_stages - 1) * self.pp_group_size warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank) + warmup_steps = min( + warmup_steps, self._n_microbatches * self.n_local_stages + ) fwd_bwd_steps = ( self.n_local_stages * self._n_microbatches ) - warmup_steps @@ -648,18 +651,6 @@ def step_microbatches( """ ) - def microbatch_index(step): - # Given the step index, find the corresponding microbatch index. - - # equivalent to a triple nested loop like this ... - # for gpu in range(self.pp_group_size): - # for stage in self.stages: - # for microbatch_within_sequence: - # ... - return (step % self.pp_group_size) + self.pp_group_size * int( - step / (self.pp_group_size * self.n_local_stages) - ) - def forward_stage_local_index(step): return (step // self.pp_group_size) % self.n_local_stages