Simplify the gradient clipping code. (pytorch#4896)
datumbox authored Nov 12, 2021
1 parent f676f94 commit 1de53be
Showing 2 changed files with 2 additions and 10 deletions.
4 changes: 2 additions & 2 deletions references/classification/train.py
@@ -40,13 +40,13 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                 scaler.unscale_(optimizer)
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
             loss.backward()
             if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
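For context, a minimal, self-contained sketch of the mixed-precision clipping pattern the hunk above touches. The model, data_loader, and max_norm names are illustrative stand-ins (max_norm plays the role of args.clip_grad_norm); the actual setup lives in references/classification/train.py.

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # GradScaler/autocast become no-ops when disabled

model = nn.Linear(10, 2).to(device)  # stand-in model, not the reference models
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
max_norm = 1.0  # stands in for args.clip_grad_norm

# tiny synthetic loader so the snippet runs on its own
data_loader = [(torch.randn(8, 10, device=device),
                torch.randint(0, 2, (8,), device=device)) for _ in range(3)]

for images, targets in data_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()
    # Gradients are still multiplied by the loss scale here; unscale them
    # in place so clip_grad_norm_ sees their true magnitudes.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # scaler.step skips the parameter update if the unscaled grads contain inf/NaN.
    scaler.step(optimizer)
    scaler.update()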
8 changes: 0 additions & 8 deletions references/classification/utils.py
@@ -409,11 +409,3 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
-
-
-def get_optimizer_params(optimizer):
-    """Generator to iterate over all parameters in the optimizer param_groups."""
-
-    for group in optimizer.param_groups:
-        for p in group["params"]:
-            yield p

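The removed helper iterated the optimizer's param_groups; the simplification relies on the optimizer being built over model.parameters(), in which case both iterate the very same tensors. A small sketch of that equivalence, using a throwaway toy model rather than anything from the reference code:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))  # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# What the removed get_optimizer_params(optimizer) yielded ...
from_optimizer = [p for group in optimizer.param_groups for p in group["params"]]
# ... versus what the simplified call clips directly.
from_model = list(model.parameters())

assert len(from_optimizer) == len(from_model)
assert all(a is b for a, b in zip(from_optimizer, from_model))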