[Opt Overlap] Implement as_functional_optim and create_functional_optim (pytorch#71604)

Summary:
Pull Request resolved: pytorch#71604

Implement two helper functions:
- as_functional_optim, which takes a torch.optim class type and arguments and creates the corresponding functional optimizer.
- create_functional_optim, which takes the functional optimizer class type and constructs it.

Note that as_functional_optim calls into create_functional_optim. The former will be used in future PRs, as described in pytorch#67570, to create a functional optimizer from a traditional optimizer. The latter is used in _OptimizerHookState to create a functional optimizer.

Both new helper functions are covered by unit tests.

ghstack-source-id: 147577170
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D33688995
fbshipit-source-id: 8b2daafd1b914efa90877cc4313aa9a428546fc1
(cherry picked from commit 42fdae2)
1 parent 5418176, commit f5a71ec
Showing 4 changed files with 59 additions and 23 deletions.
@@ -0,0 +1,41 @@
from typing import Type
from torch import optim
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamw import _FunctionalAdamW
from .functional_sgd import _FunctionalSGD
from .functional_adadelta import _FunctionalAdadelta
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_adamax import _FunctionalAdamax


# Dict mapping a user-passed optimizer class to the corresponding functional
# optimizer class already defined inside the distributed.optim package. This
# hides the functional optimizers from the user while providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}


def as_functional_optim(optim_cls: Type, *args, **kwargs):
    # Look up the functional counterpart of a traditional torch.optim class
    # and construct it with the given hyperparameters.
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        )

    return create_functional_optim(functional_cls, *args, **kwargs)


def create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    # Construct the functional optimizer with an empty parameter list;
    # _allow_empty_param_list=True is required since functional optimizers
    # otherwise reject an empty params argument.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )
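For reference, a minimal usage sketch of the two helpers (not part of this commit). It assumes the new module above is importable as torch.distributed.optim.utils and that _FunctionalSGD accepts the usual SGD hyperparameters (lr, momentum); neither detail is shown in this view.

from torch import optim
# Assumed import paths; the file path of the new module is not shown in this diff.
from torch.distributed.optim.utils import as_functional_optim, create_functional_optim
from torch.distributed.optim.functional_sgd import _FunctionalSGD

# Build a functional SGD optimizer from the traditional optimizer class.
# Positional and keyword args are forwarded to the functional constructor.
functional_sgd = as_functional_optim(optim.SGD, 0.01, momentum=0.9)

# Equivalent direct construction from the functional class.
functional_sgd_direct = create_functional_optim(_FunctionalSGD, 0.01, momentum=0.9)

# An optimizer without a functional counterpart raises ValueError.
try:
    as_functional_optim(optim.LBFGS, 1.0)
except ValueError as err:
    print(err)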