Skip to content

Commit

Permalink
fix interleaved 1f1b edge case
Browse files Browse the repository at this point in the history
ghstack-source-id: ad91f0b6a4b1ddcd26b56c770e53d871c51c7423
Pull Request resolved: #1081
  • Loading branch information
H-Huang committed Apr 22, 2024
1 parent 10956e4 commit 3f295b9
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions pippy/PipelineSchedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def __init__(
# TODO: is this limitation a must?
if n_microbatches % self.pp_group_size != 0:
raise ValueError(
f"Interleaved 1F1B schedule requires the number of microbatches ({self._n_microbatches}) \
f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \
to be a multiple of the number of pipeline ranks ({self.pp_group_size})."
)

Expand Down Expand Up @@ -626,6 +626,9 @@ def step_microbatches(
# increment warmup_steps by 2 for each hop away
warmup_steps = (self.n_local_stages - 1) * self.pp_group_size
warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank)
warmup_steps = min(
warmup_steps, self._n_microbatches * self.n_local_stages
)
fwd_bwd_steps = (
self.n_local_stages * self._n_microbatches
) - warmup_steps
Expand All @@ -648,18 +651,6 @@ def step_microbatches(
"""
)

def microbatch_index(step):
# Given the step index, find the corresponding microbatch index.

# equivalent to a triple nested loop like this ...
# for gpu in range(self.pp_group_size):
# for stage in self.stages:
# for microbatch_within_sequence:
# ...
return (step % self.pp_group_size) + self.pp_group_size * int(
step / (self.pp_group_size * self.n_local_stages)
)

def forward_stage_local_index(step):
return (step // self.pp_group_size) % self.n_local_stages

Expand Down

0 comments on commit 3f295b9

Please sign in to comment.