Fix accelerate init (coqui-ai#116)
erogol authored Jul 22, 2023
1 parent c5a6783 commit bce8b12
Showing 1 changed file with 10 additions and 14 deletions.
trainer/trainer.py

@@ -620,32 +620,28 @@ def init_accelerate(model, optimizer, training_dataloader, scheduler, grad_accum
         _precision = "bf16"
     accelerator = Accelerator(gradient_accumulation_steps=grad_accum_steps, mixed_precision=_precision)
     if isinstance(model, torch.nn.Module):
-        model = accelerator.prepare(model)
+        model = accelerator.prepare_model(model)
 
-    if isinstance(optimizer, torch.optim.Optimizer):
-        optimizer = accelerator.prepare(optimizer)
-    elif isinstance(optimizer, dict):
+    if isinstance(optimizer, dict):
         for key, optim in optimizer.items():
-            optimizer[key] = accelerator.prepare(optim)
+            optimizer[key] = accelerator.prepare_optimizer(optim)
     elif isinstance(optimizer, list):
         for i, optim in enumerate(optimizer):
-            optimizer[i] = accelerator.prepare(optim)
+            optimizer[i] = accelerator.prepare_optimizer(optim)
     elif optimizer is not None:
-        raise ValueError("Optimizer must be a dict, list or torch.optim.Optimizer")
+        optimizer = accelerator.prepare_optimizer(optimizer)
 
     if isinstance(training_dataloader, torch.utils.data.DataLoader):
-        training_dataloader = accelerator.prepare(training_dataloader)
+        training_dataloader = accelerator.prepare_data_loader(training_dataloader)
 
-    if isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler):  # pylint:disable=protected-access
-        scheduler = accelerator.prepare(scheduler)
-    elif isinstance(scheduler, dict):
+    if isinstance(scheduler, dict):
         for key, sched in scheduler.items():
-            scheduler[key] = accelerator.prepare(sched)
+            scheduler[key] = accelerator.prepare_scheduler(sched)
     elif isinstance(scheduler, list):
         for i, sched in enumerate(scheduler):
-            scheduler[i] = accelerator.prepare(sched)
+            scheduler[i] = accelerator.prepare_scheduler(sched)
     elif scheduler is not None:
-        raise ValueError("Scheduler must be a dict, list or torch.optim.lr_scheduler._LRScheduler")
+        scheduler = accelerator.prepare_scheduler(scheduler)
 
     return model, optimizer, training_dataloader, scheduler, accelerator
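The commit replaces the generic accelerator.prepare() calls with the type-specific Accelerator methods (prepare_model, prepare_optimizer, prepare_data_loader, prepare_scheduler) and, instead of raising on optimizers or schedulers that are neither dicts nor lists, now prepares them directly. A minimal sketch of how the updated helper might be called from a training loop, assuming grad_accum_steps is accepted as a keyword (the hunk header above is truncated) and that any remaining parameters of init_accelerate are optional; the model, data, and loss below are placeholders, not part of this commit:

import torch
from trainer.trainer import init_accelerate  # module path taken from this diff

# Placeholder training objects; any torch-compatible equivalents work the same way.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model, optimizer, loader, scheduler, accelerator = init_accelerate(
    model, optimizer, loader, scheduler, grad_accum_steps=2
)

model.train()
for batch, labels in loader:
    # accumulate() honors the gradient_accumulation_steps passed to Accelerator
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(batch), labels)
        accelerator.backward(loss)  # used in place of loss.backward() under accelerate
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()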
