Use parallel version of SGD optimizer
Summary:
Pull Request resolved: facebookresearch#4633

X-link: facebookresearch/d2go#406

Profiling showed that training could benefit from using the parallel (multi-tensor) version of the SGD optimizer.
Passing `foreach=True` enables it.
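
For illustration only (not part of this commit), a minimal standalone sketch of what the foreach SGD path looks like in plain PyTorch; it assumes torch >= 1.12, and the toy model and hyperparameters below are placeholders, not detectron2 defaults:

import torch

# Assumes torch >= 1.12, where torch.optim.SGD accepts the `foreach` flag.
model = torch.nn.Linear(10, 2)  # toy model, illustrative only
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.02,           # illustrative hyperparameters
    momentum=0.9,
    weight_decay=1e-4,
    foreach=True,      # use the multi-tensor ("parallel") SGD implementation
)

loss = model(torch.randn(4, 10)).sum()  # dummy forward pass
loss.backward()
optimizer.step()       # `foreach` only changes how the update is applied internally
optimizer.zero_grad()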

Reviewed By: tglik

Differential Revision: D40798214

fbshipit-source-id: aa098d1fbbece0862bc9343df761765b0c3b15da
Francisc Bungiu authored and facebook-github-bot committed Nov 7, 2022
1 parent 2b98c27 commit f755c49
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions detectron2/solver/build.py
@@ -13,6 +13,7 @@
)

from detectron2.config import CfgNode
+from detectron2.utils.env import TORCH_VERSION

from .lr_scheduler import LRMultiplier, WarmupParamScheduler

@@ -126,13 +127,16 @@ def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
    )
-    return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
-        params,
-        lr=cfg.SOLVER.BASE_LR,
-        momentum=cfg.SOLVER.MOMENTUM,
-        nesterov=cfg.SOLVER.NESTEROV,
-        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
-    )
+    sgd_args = {
+        "params": params,
+        "lr": cfg.SOLVER.BASE_LR,
+        "momentum": cfg.SOLVER.MOMENTUM,
+        "nesterov": cfg.SOLVER.NESTEROV,
+        "weight_decay": cfg.SOLVER.WEIGHT_DECAY,
+    }
+    if TORCH_VERSION >= (1, 12):
+        sgd_args["foreach"] = True
+    return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(**sgd_args)


def get_default_optimizer_params(
