
Commit

follow PR comment, add for-loop training of PipelineParallelism on training_step
tree-park committed Oct 27, 2022
1 parent 57a080a commit 703ed7c
Showing 3 changed files with 51 additions and 29 deletions.
2 changes: 2 additions & 0 deletions oslo/transformers/oslo_init.py
@@ -74,6 +74,7 @@ def _values(*args):
"parallel_size": _type(int),
"params": {
"memory_computation_balance": _type(float),
"num_micro_batches": _type(int)
},
},
"expert_parallelism": {
@@ -190,6 +191,7 @@ class OsloTrainerConfig(Config):
"parallel_size": _type(int),
"params": {
"memory_computation_balance": _type(float),
"num_micro_batches": _type(int)
},
},
"expert_parallelism": {
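The schema change above means a user config can now declare the micro-batch count next to the existing pipeline-parallel settings. A minimal sketch of such a config fragment, using only the keys visible in this diff (the concrete values are illustrative, not defaults from the repository):

# Hedged sketch: a user-facing oslo_config fragment that would satisfy the schema above.
oslo_config = {
    "pipeline_parallelism": {
        "parallel_size": 2,                      # _type(int)
        "params": {
            "memory_computation_balance": 1.0,   # _type(float)
            "num_micro_batches": 4,              # _type(int), new in this commit
        },
    },
}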
55 changes: 33 additions & 22 deletions oslo/transformers/trainer.py
@@ -75,6 +75,7 @@
default_data_collator,
)
from oslo.torch.utils.checkpoint.activation_checkpointing import ActivationCheckpointing
from oslo.torch.nn.parallel.data_parallel._fsdp.sharded_grad_scaler import ShardedGradScaler
from oslo.transformers.training_args import TrainingArguments
from oslo.transformers.trainer_utils import OptimizerNames, log_dist

@@ -148,7 +149,6 @@ def __init__(
        self.parallel_context = None
        self.model_wrappers = []

        self.do_grad_scaling = False  # TODO FP16, BF16
        self.label_smoother = None  # TODO label_smoother

        if args.oslo_config:
@@ -191,7 +191,10 @@ def __init__(
"train_dataset does not implement __len__, max_steps has to be specified"
)

        # TODO Grad Scaler
        self.do_grad_scaling = False
        if args.fp16 or args.bf16:
            self.do_grad_scaling = True
            self.scaler = ShardedGradScaler()
        # TODO Label Smoother

        self.state = TrainerState(
@@ -412,15 +415,14 @@ def train(
                # TODO Gradient Clipping
                # Optimizer step
                optimizer_was_run = True
                # TODO do_grad_scaling
                # if self.do_grad_scaling:
                #     scale_before = self.scaler.get_scale()
                #     self.scaler.step(self.optimizer)
                #     self.scaler.update()
                #     scale_after = self.scaler.get_scale()
                #     optimizer_was_run = scale_before <= scale_after
                # else:
                self.optimizer.step()
                if self.do_grad_scaling:
                    scale_before = self.scaler.get_scale()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    scale_after = self.scaler.get_scale()
                    optimizer_was_run = scale_before <= scale_after
                else:
                    self.optimizer.step()

                if optimizer_was_run:
                    self.lr_scheduler.step()
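The scale_before <= scale_after comparison is the usual way to detect whether the scaler actually ran the optimizer: GradScaler.step() skips the update when it finds inf/NaN gradients, and update() then lowers the scale, so the learning-rate scheduler should only advance when the scale did not drop. For reference, the same idiom with PyTorch's stock torch.cuda.amp.GradScaler in a standalone loop (the model, optimizer, and data below are placeholders, not taken from this repository, and a CUDA device is assumed):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
scaler = GradScaler()

for _ in range(100):
    x = torch.randn(8, 10, device="cuda")
    y = torch.randn(8, 1, device="cuda")
    optimizer.zero_grad()
    with autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()

    scale_before = scaler.get_scale()
    scaler.step(optimizer)   # skipped internally if gradients contain inf/NaN
    scaler.update()          # lowers the scale after a skipped step
    optimizer_was_run = scale_before <= scaler.get_scale()
    if optimizer_was_run:
        scheduler.step()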
@@ -473,20 +475,28 @@ def training_step(
# log_dist(f"Before self._prepare_inputs: \n{inputs}", rank=-1)
inputs = self._prepare_inputs(inputs) # TODO Check
# log_dist(f"After self._prepare_inputs: \n{inputs}", rank=-1)
if self.args.oslo_config.pipeline_parallelism:
pp_loss = torch.tensor(0.0).to(self.args.device)
num_micro_batches = self.args.oslo_config.pipeline_parallelism["param"]["num_micro_batches"] if "num_micro_batches" in self.args.oslo_config.pipeline_parallelism["param"] else 1
for idx, out in enumerate(model(**inputs)):
loss = out.loss
loss = loss / num_micro_batches
loss.backward()
pp_loss += loss.detach().item()
return pp_loss
else:
with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

if self.args.gradient_accumulation_steps > 1:
# deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
loss = loss / self.args.gradient_accumulation_steps
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps

if self.do_grad_scaling:
self.scaler.scale(loss).backward()
else:
loss.backward()
if self.do_grad_scaling:
self.scaler.scale(loss).backward()
else:
loss.backward()

return loss.detach()
return loss.detach()
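Scaling each micro-batch loss by 1 / num_micro_batches before backward() makes the accumulated gradients identical to those of the mean loss over the micro-batches, and pp_loss ends up equal to that mean. A small self-contained check of this equivalence (plain tensors stand in for the pipeline-parallel model outputs; purely illustrative):

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
micro_batches = [torch.randn(4) for _ in range(3)]
num_micro_batches = len(micro_batches)

# Micro-batch style: backward on each loss scaled by 1 / num_micro_batches.
pp_loss = torch.tensor(0.0)
for x in micro_batches:
    loss = (w * x).sum() ** 2 / num_micro_batches
    loss.backward()
    pp_loss += loss.detach()
grad_micro = w.grad.clone()

# Reference: one backward on the mean loss gives the same gradient.
w.grad = None
mean_loss = sum(((w * x).sum() ** 2) for x in micro_batches) / num_micro_batches
mean_loss.backward()

assert torch.allclose(grad_micro, w.grad)
assert torch.isclose(pp_loss, mean_loss.detach())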

    def _wrap_model(self, model_wrappers: List, training: bool = True):
        if not training:
@@ -722,6 +732,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
        else:
            labels = None
        # log_dist(f"**inputs: {inputs}", rank=-1)

        outputs = model(**inputs)
        # # TODO: Save past state if it exists
        # # HF-TODO: this needs to be fixed and made cleaner later.
23 changes: 16 additions & 7 deletions oslo/transformers/training_args.py
@@ -102,6 +102,12 @@ class TrainingArguments:
            The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
            `"comet_ml"`, `"mlflow"`, `"tensorboard"` and `"wandb"`. Use `"all"` to report to all integrations
            installed, `"none"` for no integrations.
        bf16 (`bool`, *optional*, defaults to `False`):
            Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
            NVIDIA architecture. This is an experimental API and it may change.
        fp16 (`bool`, *optional*, defaults to `False`):
            Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
        """

output_dir: str = field(
@@ -185,13 +191,6 @@ class TrainingArguments:
default=500, metadata={"help": "Save checkpoint every X updates steps."}
)

# log_level: Optional[str] = field(
# default="passive",
# metadata={
# "help": "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug', 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the application set the level. Defaults to 'passive'.",
# "choices": trainer_log_levels.keys(),
# },
# )
seed: int = field(
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
@@ -239,6 +238,16 @@
"help": "The list of integrations to report the results and logs to."
},
)
bf16: bool = field(
default=False,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA architecture. This is an experimental API and it may change."
},
)
fp16: bool = field(
default=False,
metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
)
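With these two fields, mixed precision is requested directly from the training arguments; the trainer's __init__ change above then turns either flag into a ShardedGradScaler. A minimal usage sketch (only output_dir and the new flags are taken from this diff; whether the dataclass requires further arguments in practice is not shown here):

from oslo.transformers.training_args import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",  # illustrative path, not from the diff
    fp16=True,                   # or bf16=True on Ampere-or-newer GPUs
)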

def __post_init__(self):
# TODO set log level
