Privatize step_microbatches (#1113)
Whether this API is used by any users is unknown.
Privatize it for now; it may be re-opened in the future.
kwen2501 authored May 9, 2024
1 parent 4308e2c commit 10dad72
Showing 4 changed files with 24 additions and 24 deletions.
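
In short, the rename demotes the microbatch-level entry point to an internal hook and leaves `step` as the supported user-facing call. A minimal before/after sketch, assuming a constructed pipeline `stage` and full-batch tensors `x` and `y` (placeholders, not part of this diff); the constructor and keyword names are taken from the tests changed below:

```python
# Sketch only: `stage`, `num_microbatches`, `x`, and `y` are assumed to be
# set up elsewhere; the constructor call mirrors the tests in this commit.
from pippy.PipelineSchedule import Schedule1F1B

schedule = Schedule1F1B(stage, num_microbatches)

# Before this commit, callers could drive pre-split microbatches directly:
#     schedule.step_microbatches(arg_mbs=microbatches)

# After it, that entry point is private (`_step_microbatches`); user code
# passes a whole batch to the public `step`, which splits it into
# microbatches internally:
losses = []
output = schedule.step(x, target=y, losses=losses)
```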
18 changes: 9 additions & 9 deletions pippy/PipelineSchedule.py
@@ -80,7 +80,7 @@ def _update_losses(self, stages, losses):
         self._internal_losses.clear()

     @abstractmethod
-    def step_microbatches(
+    def _step_microbatches(
         self,
         arg_mbs: Optional[List] = None,
         kwarg_mbs: Optional[List] = None,
@@ -226,7 +226,7 @@ class PipelineScheduleSingle(PipelineSchedule):
     """
     Base class for single-stage schedules.
     Implements the `step` method.
-    Derived classes should implement `step_microbatches`.
+    Derived classes should implement `_step_microbatches`.
     """

     def __init__(
@@ -267,7 +267,7 @@ def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
         targets_split = None

         # Run microbatches
-        self.step_microbatches(args_split, kwargs_split, targets_split, losses)
+        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

         # Return merged results per original format
         if self._stage.is_last:
@@ -277,7 +277,7 @@ def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):


 class ScheduleGPipe(PipelineScheduleSingle):
-    def step_microbatches(
+    def _step_microbatches(
         self,
         arg_mbs: Optional[List] = None,
         kwarg_mbs: Optional[List] = None,
@@ -356,7 +356,7 @@ def step_microbatches(


 class Schedule1F1B(PipelineScheduleSingle):
-    def step_microbatches(
+    def _step_microbatches(
         self,
         arg_mbs: Optional[List] = None,
         kwarg_mbs: Optional[List] = None,
@@ -453,7 +453,7 @@ class PipelineScheduleMulti(PipelineSchedule):
     """
     Base class for multi-stage schedules.
     Implements the `step` method.
-    Derived classes should implement `step_microbatches`.
+    Derived classes should implement `_step_microbatches`.
     """

     def __init__(
@@ -504,7 +504,7 @@ def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
         targets_split = None

         # Run microbatches
-        self.step_microbatches(args_split, kwargs_split, targets_split, losses)
+        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

         # Return merged results per original format
         for stage in self._stages:
@@ -515,7 +515,7 @@ def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):


 class ScheduleLoopedBFS(PipelineScheduleMulti):
-    def step_microbatches(
+    def _step_microbatches(
         self,
         arg_mbs: Optional[List] = None,
         kwarg_mbs: Optional[List] = None,
@@ -596,7 +596,7 @@ def __init__(
         self.n_local_stages = len(stages)
         self.rank = stages[0].group_rank

-    def step_microbatches(
+    def _step_microbatches(
         self,
         arg_mbs: Optional[List] = None,
         kwarg_mbs: Optional[List] = None,
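
Taken together, these hunks show the template-method structure the privatization reinforces: the base classes own the public `step`, which splits the inputs and delegates to the now-private `_step_microbatches` hook, and each concrete schedule overrides only that hook. Below is a minimal sketch of a conforming subclass; it is illustrative only, not code from this repository, and the `target_mbs`/`losses` parameters are inferred from how the tests call the method:

```python
# Illustrative subclass sketch (not from this repo): the hook that
# PipelineScheduleSingle.step invokes under its new private name.
from typing import List, Optional

from pippy.PipelineSchedule import PipelineScheduleSingle


class ToySequentialSchedule(PipelineScheduleSingle):
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        # Naive policy: run each microbatch in arrival order. Real
        # schedules (GPipe, 1F1B, looped BFS) interleave forward and
        # backward passes according to their pipeline policy.
        for i, args in enumerate(arg_mbs or []):
            kwargs = kwarg_mbs[i] if kwarg_mbs else {}
            ...  # stage forward/backward and communication elided
```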
4 changes: 2 additions & 2 deletions test/test_composability.py
@@ -191,7 +191,7 @@ def test_manual_pipeline_with_manual_allreduce(self):
             n_microbatches=num_microbatches,
         )
         microbatches = [input1.clone() for _ in range(8)]
-        pipeline_schedule.step_microbatches(arg_mbs=microbatches)
+        pipeline_schedule._step_microbatches(arg_mbs=microbatches)
         print(f"{self.rank} finished pipeline step")

         # all reduce
@@ -322,7 +322,7 @@ def build_stage(stage_idx, num_stages):
         else:
             raise RuntimeError(f"unsupported schedule {schedule_name}")

-        pipeline_schedule.step_microbatches(
+        pipeline_schedule._step_microbatches(
             arg_mbs=input_mb, target_mbs=input_mb
         )
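
The pattern under test here, pipeline parallelism composed with hand-rolled data parallelism, is: run the schedule over the microbatches, then all-reduce the accumulated gradients across data-parallel replicas. A hedged sketch under assumed names (`stage_module` for the local stage's module and `dp_group` for the data-parallel process group, neither defined in this diff):

```python
# Sketch of the "pipeline step + manual allreduce" composition tested above.
# `pipeline_schedule` and `microbatches` come from the test; `stage_module`
# and `dp_group` are hypothetical placeholder names.
import torch.distributed as dist

pipeline_schedule._step_microbatches(arg_mbs=microbatches)

# Average gradients across data-parallel replicas by hand, doing what a
# DDP wrapper would otherwise do automatically.
for param in stage_module.parameters():
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
        param.grad /= dist.get_world_size(group=dp_group)
```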
20 changes: 10 additions & 10 deletions test/test_pipeline_schedule.py
@@ -327,7 +327,7 @@ def test_1f1b(self):
         ]

         schedule = Schedule1F1B(stage, num_microbatches)
-        schedule.step_microbatches(microbatches)
+        schedule._step_microbatches(microbatches)
         dist.barrier()

     @skip_if_lt_x_gpu(4)
@@ -358,7 +358,7 @@ def test_interleaved_1f1b(self):
             stages,
             num_microbatches,
         )
-        schedule.step_microbatches(microbatches)
+        schedule._step_microbatches(microbatches)

         # num local pipeline stages == world_size
         num_microbatches = 8
@@ -377,14 +377,14 @@ def test_interleaved_1f1b(self):
             stages,
             num_microbatches,
         )
-        schedule.step_microbatches(microbatches)
+        schedule._step_microbatches(microbatches)

         # differing microbatch size
         num_microbatches = 64
         microbatches = [
             torch.randn_like(microbatch) for _ in range(num_microbatches)
         ]
-        schedule.step_microbatches(microbatches)
+        schedule._step_microbatches(microbatches)

     def test_interleaved_1f1b_negative(self):
         device = torch.device("cpu")
@@ -423,15 +423,15 @@ def test_interleaved_1f1b_negative(self):
         microbatches = [
             torch.randn_like(microbatch) for _ in range(num_microbatches)
         ]
-        schedule.step_microbatches(microbatches)
+        schedule._step_microbatches(microbatches)

         # invalid microbatch values
         with self.assertRaises(ValueError):
             num_microbatches = 5
             microbatches = [
                 torch.randn_like(microbatch) for _ in range(num_microbatches)
             ]
-            schedule.step_microbatches(microbatches)
+            schedule._step_microbatches(microbatches)

     @skip_if_lt_x_gpu(4)
     def test_interleaved_1f1b_with_model_sleep(self):
@@ -500,25 +500,25 @@ def test_check_inputs(self):
         # invalid input length
         with self.assertRaises(ValueError):
             invalid_microbatches = [(i,) for i in range(7)]
-            schedule.step_microbatches(invalid_microbatches)
+            schedule._step_microbatches(invalid_microbatches)

         # invalid input shapes
         with self.assertRaises(ValueError):
             invalid_microbatches = [(torch.ones(8, 4, 8))]
-            schedule.step_microbatches(invalid_microbatches)
+            schedule._step_microbatches(invalid_microbatches)

         # invalid input type
         with self.assertRaises(TypeError):
             invalid_microbatches = torch.ones(8, 4, 8)
-            schedule.step_microbatches(invalid_microbatches)
+            schedule._step_microbatches(invalid_microbatches)

         # invalid loss
         with self.assertRaises(TypeError):
             loss = 1
             microbatches = [
                 torch.randn_like(microbatch) for _ in range(num_microbatches)
             ]
-            schedule.step_microbatches(microbatches, loss=loss)
+            schedule._step_microbatches(microbatches, loss=loss)


 class UtilTest(unittest.TestCase):
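
The `test_check_inputs` hunk doubles as documentation of the validation contract `_step_microbatches` enforces: the list length must match the schedule's configured microbatch count (ValueError), entries must be tensor tuples of consistent shape (ValueError), and the argument must be a list rather than a bare tensor (TypeError). A condensed sketch of how two of those failures surface, with `schedule` assumed to be configured for 8 microbatches as in the test:

```python
# Condensed from test_check_inputs; `schedule` setup is elided and assumed
# to expect 8 microbatches.
import torch

try:
    schedule._step_microbatches([(i,) for i in range(7)])  # 7 != 8
except ValueError:
    pass  # wrong number of microbatches

try:
    schedule._step_microbatches(torch.ones(8, 4, 8))  # not a list
except TypeError:
    pass  # input must be a list of microbatches
```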
6 changes: 3 additions & 3 deletions test/test_pipeline_schedule_e2e.py
@@ -241,14 +241,14 @@ def rank_print(msg):
     ) as _torch_profiler:
         with record_function(schedule):
             if rank == 0:
-                my_schedule.step_microbatches(microbatches)
+                my_schedule._step_microbatches(microbatches)
             elif rank == world_size - 1:
                 losses = []
-                output = my_schedule.step_microbatches(
+                output = my_schedule._step_microbatches(
                     target_mbs=target_mbs, losses=losses
                 )
             else:
-                my_schedule.step_microbatches()
+                my_schedule._step_microbatches()
     logger.info(f"====== Rank {rank} finished {schedule} ======")
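
Finally, the e2e hunk records the per-rank calling convention for a manually driven pipeline: the first rank feeds input microbatches, the last rank supplies targets and collects per-microbatch losses, and interior ranks pass nothing. Condensed below with setup elided; after this commit this is internal API, which tests may call directly but user code should reach through the public `step`:

```python
# Per-rank driving pattern, condensed from the e2e test above.
if rank == 0:
    my_schedule._step_microbatches(microbatches)  # first rank: feed inputs
elif rank == world_size - 1:
    losses = []  # populated with one loss per microbatch
    output = my_schedule._step_microbatches(
        target_mbs=target_mbs, losses=losses
    )
else:
    my_schedule._step_microbatches()  # interior ranks: relay only
```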
