Skip to content

Commit

Permalink
Interleaved 1F1B schedule supports loss computation
Browse files Browse the repository at this point in the history
ghstack-source-id: 719316b63c40ae990d71d131a9856cc5057d3e10
Pull Request resolved: #1063
  • Loading branch information
H-Huang committed Apr 16, 2024
1 parent 00f830c commit 6c78d48
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 31 deletions.
6 changes: 4 additions & 2 deletions pippy/ManualPipelineStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def stage_global_rank(peer_rank):
for chunk_id in range(self.chunks):
self.set_requires_grad[chunk_id] = False
if not self.is_first:
# We assume that we always receive from stage - 1
self.args_recv_info[chunk_id] = tuple(
[
RecvInfo(
Expand All @@ -247,6 +248,7 @@ def stage_global_rank(peer_rank):
# only need the rank that is being sent to
self.act_send_info: Dict[int, List] = {}
for idx in range(len(self.outputs)):
# We assume we always send to stage + 1
if not self.is_last:
self.act_send_info[idx] = [self.stage_index + 1]
else:
Expand All @@ -267,6 +269,8 @@ def _create_grad_recv_info(
) -> Tuple[RecvInfo, ...]:
grad_recv_info: Tuple[RecvInfo, ...] = ()
if not self.is_last:
# Receiving gradients from multiple sources is not supported
# hence we only take the first destination
grad_recv_info = tuple(
[
RecvInfo(
Expand All @@ -277,8 +281,6 @@ def _create_grad_recv_info(
for idx, dst_list in act_send_info.items()
]
)
else:
grad_recv_info = tuple()
return grad_recv_info

def init_p2p_neighbors(self):
Expand Down
71 changes: 43 additions & 28 deletions pippy/PipelineSchedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _maybe_get_loss(self, mb_index):
return self._internal_losses[mb_index]
elif len(self._internal_losses) != 0 and not valid_index:
raise RuntimeError(
f"Loss of microbatch {mb_index} is not available. "
f"Loss for microbatch {mb_index} is not available. "
f"Available losses for microbatches: {self._internal_losses}"
)
else:
Expand Down Expand Up @@ -653,24 +653,23 @@ def backward_stage_local_index(step):
% self.n_local_stages
)

# Internal loss container
internal_losses = []
fwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)
bwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)

# Delay send waits
sends_to_wait: List[dist.Work] = []

# TODO: share across schedules
def maybe_compute_loss(fwd_stage, output, mb_index):
if fwd_stage.is_last and self._loss_fn is not None:
loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index]
internal_losses.append(loss)
logger.debug(f"Loss of microbatch {mb_index}: {loss}")

for step in range(self.total_steps):
# warmup, forward only
if step < warmup_steps:
logger.debug(f"{forward_stage_local_index(step)=}")

fwd_stage = self._stages[forward_stage_local_index(step)]
mb_index = microbatch_index(step)
# assigns the current microbatch index and updates it for future steps
fwd_stage_mb_index[fwd_stage] = (
mb_index := fwd_stage_mb_index[fwd_stage]
) + 1

logger.debug(
f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
)
Expand All @@ -687,13 +686,27 @@ def maybe_compute_loss(fwd_stage, output, mb_index):
works = sorted_batch_isend_irecv(ops)
sends_to_wait.extend(works.values())

maybe_compute_loss(fwd_stage, output, mb_index)
self._maybe_compute_loss(
fwd_stage, output, target_mbs, mb_index
)

# 1f1b
elif warmup_steps <= step < warmup_steps + fwd_bwd_steps:
logger.debug(f"{forward_stage_local_index(step)=}")
logger.debug(f"{backward_stage_local_index(step)=}")

fwd_stage = self._stages[forward_stage_local_index(step)]
bwd_stage = self._stages[backward_stage_local_index(step)]

fwd_stage_mb_index[fwd_stage] = (
fwd_mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1

logger.debug(
f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {mb_index=}"
f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
)
with record_function(f"1F1B {step}"):
ops = fwd_stage.get_fwd_recv_ops()
Expand All @@ -702,33 +715,39 @@ def maybe_compute_loss(fwd_stage, output, mb_index):
for work in works.values():
work.wait()

output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index]
# fwd
output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
ops = fwd_stage.get_fwd_send_ops()
self._maybe_compute_loss(
fwd_stage, output, target_mbs, fwd_mb_index
)

maybe_compute_loss(fwd_stage, output, mb_index)

# TODO 1: give loss to backward.
# TODO 2: for us to know which loss to use, we need to know the backward mb index.
bwd_stage.backward_one_chunk()
# bwd
loss = self._maybe_get_loss(bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())

works = sorted_batch_isend_irecv(ops)
sends_to_wait.extend(works.values())

# cooldown
else:
bwd_stage = self._stages[backward_stage_local_index(step)]
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1

logger.debug(
f"{self.rank}: {step=}, {bwd_stage.stage_index=}, {mb_index=}"
f"{self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
)
with record_function(f"Cooldown (backward) {step}"):
ops = bwd_stage.get_bwd_recv_ops()
works = sorted_batch_isend_irecv(ops)
for work in works.values():
work.wait()

# TODO 1: give loss to backward.
# TODO 2: for us to know which loss to use, we need to know the backward mb index.
bwd_stage.backward_one_chunk()
loss = self._maybe_get_loss(bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)

ops = bwd_stage.get_bwd_send_ops()
works = sorted_batch_isend_irecv(ops)
Expand All @@ -739,8 +758,4 @@ def maybe_compute_loss(fwd_stage, output, mb_index):
work.wait()

# Return losses if there is a container passed in
if losses is not None:
# Clean external container first
losses.clear()
# Copy internal losses to external container
losses.extend(internal_losses)
self._update_losses(losses)
2 changes: 1 addition & 1 deletion test/test_pipeline_schedule_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def set_up_logging(rank, log_level):
type=str,
nargs="+",
choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b"],
default=["looped_bfs"],
default=["interleaved_1f1b"],
)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--stage_type", type=str, default="manual")
Expand Down

0 comments on commit 6c78d48

Please sign in to comment.