[Opt Overlap] Implement as_functional_optim and create_functional_optim (pytorch#71604)

Summary:
Pull Request resolved: pytorch#71604

Implement two helper functions:
- as_functional_optim, which takes a torch.optim class type plus its constructor
  arguments and creates the corresponding functional optimizer.
- create_functional_optim, which takes a functional optimizer class type and
  constructs it directly. Note that as_functional_optim calls into
  create_functional_optim.

The former will be used in future PRs, as described in
pytorch#67570, to create a functional
optimizer from a traditional optimizer. The latter is used in
_OptimizerHookState to create a functional optimizer (see the usage sketch below).

Both new helper functions are covered by unit tests.
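
For illustration, a minimal sketch of how the two helpers relate. The optimizer
and hyperparameters here are assumptions for the example, not values from this PR:

import torch
from torch.distributed.optim import as_functional_optim, create_functional_optim
from torch.distributed.optim.functional_sgd import _FunctionalSGD

# Build a functional optimizer from a traditional optimizer class; internally
# this looks up _FunctionalSGD in functional_optim_map and constructs it.
f_sgd = as_functional_optim(torch.optim.SGD, 0.01, momentum=0.9)

# Equivalent direct construction from the functional class itself.
f_sgd = create_functional_optim(_FunctionalSGD, 0.01, momentum=0.9)

assert isinstance(f_sgd, _FunctionalSGD)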
ghstack-source-id: 147577170

Test Plan: CI

Reviewed By: cbalioglu

Differential Revision: D33688995

fbshipit-source-id: 8b2daafd1b914efa90877cc4313aa9a428546fc1
(cherry picked from commit 42fdae2)
rohan-varma authored and pytorchmergebot committed Jan 25, 2022
1 parent 5418176 commit f5a71ec
Showing 4 changed files with 59 additions and 23 deletions.
@@ -2,6 +2,7 @@
 
 import torch
 import torch.distributed as dist
+from torch.distributed.optim import create_functional_optim
 
 _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"
 
@@ -16,11 +17,10 @@ class _OptimizerHookState(object):
     def __init__(
         self, functional_optim_cls, *functional_optim_args, **functional_optim_kwargs
     ):
-        self.functional_optimizer = functional_optim_cls(
-            [],
+        self.functional_optimizer = create_functional_optim(
+            functional_optim_cls,
             *functional_optim_args,
             **functional_optim_kwargs,
-            _allow_empty_param_list=True,
         )
         self._check_valid_functional_optim()
 
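
The body of _check_valid_functional_optim is outside this diff; as a rough,
assumed sketch (not code from this commit), the check that the step_param
constant above supports could look like this:

# Hypothetical sketch, not part of this PR's diff: reject functional optimizers
# that do not expose the per-parameter "step_param" method the DDP
# optimizer-overlap hook relies on.
def _check_valid_functional_optim(self):
    if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
        raise ValueError(
            f"Class {type(self.functional_optimizer)} must implement method "
            f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}."
        )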
20 changes: 6 additions & 14 deletions torch/distributed/optim/__init__.py
@@ -7,6 +7,7 @@
 """
 import torch
 from torch import optim
+
 from .functional_adagrad import _FunctionalAdagrad
 from .functional_adam import _FunctionalAdam
 from .functional_adamw import _FunctionalAdamW
@@ -15,21 +16,12 @@
 from .functional_rmsprop import _FunctionalRMSprop
 from .functional_rprop import _FunctionalRprop
 from .functional_adamax import _FunctionalAdamax
+from .utils import (
+    functional_optim_map,
+    create_functional_optim,
+    as_functional_optim,
+)
 
-# dict to map a user passed in optimizer_class to a functional
-# optimizer class if we have already defined inside the
-# distributed.optim package, this is so that we hide the
-# functional optimizer to user and still provide the same API.
-functional_optim_map = {
-    optim.Adagrad: _FunctionalAdagrad,
-    optim.Adam: _FunctionalAdam,
-    optim.AdamW: _FunctionalAdamW,
-    optim.SGD: _FunctionalSGD,
-    optim.Adadelta: _FunctionalAdadelta,
-    optim.RMSprop: _FunctionalRMSprop,
-    optim.Rprop: _FunctionalRprop,
-    optim.Adamax: _FunctionalAdamax,
-}
 
 # DistributedOptimizer imports torch.distributed.rpc names, so gate availability
 # based on RPC being available.
41 changes: 41 additions & 0 deletions torch/distributed/optim/utils.py
@@ -0,0 +1,41 @@
+from typing import Type
+from torch import optim
+from .functional_adagrad import _FunctionalAdagrad
+from .functional_adam import _FunctionalAdam
+from .functional_adamw import _FunctionalAdamW
+from .functional_sgd import _FunctionalSGD
+from .functional_adadelta import _FunctionalAdadelta
+from .functional_rmsprop import _FunctionalRMSprop
+from .functional_rprop import _FunctionalRprop
+from .functional_adamax import _FunctionalAdamax
+
+# dict to map a user passed in optimizer_class to a functional
+# optimizer class if we have already defined inside the
+# distributed.optim package, this is so that we hide the
+# functional optimizer to user and still provide the same API.
+functional_optim_map = {
+    optim.Adagrad: _FunctionalAdagrad,
+    optim.Adam: _FunctionalAdam,
+    optim.AdamW: _FunctionalAdamW,
+    optim.SGD: _FunctionalSGD,
+    optim.Adadelta: _FunctionalAdadelta,
+    optim.RMSprop: _FunctionalRMSprop,
+    optim.Rprop: _FunctionalRprop,
+    optim.Adamax: _FunctionalAdamax,
+}
+
+def as_functional_optim(optim_cls: Type, *args, **kwargs):
+    try:
+        functional_cls = functional_optim_map[optim_cls]
+    except KeyError:
+        raise ValueError(f"Optimizer {optim_cls} does not have a functional counterpart!")
+
+    return create_functional_optim(functional_cls, *args, **kwargs)
+
+def create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
+    return functional_optim_cls(
+        [],
+        *args,
+        **kwargs,
+        _allow_empty_param_list=True,
+    )
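
A short usage sketch of the new module; the optimizer choices here are
illustrative assumptions, not values from this PR:

import torch
from torch.distributed.optim import as_functional_optim, functional_optim_map

# Only optimizer classes with an entry in functional_optim_map are supported.
print(sorted(cls.__name__ for cls in functional_optim_map))

# Anything else raises ValueError from as_functional_optim.
try:
    as_functional_optim(torch.optim.LBFGS, lr=1.0)
except ValueError as e:
    print(e)  # "... does not have a functional counterpart!"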
15 changes: 9 additions & 6 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -36,6 +36,11 @@
     _OptimizerHookState
 )
 
+from torch.distributed.optim import (
+    as_functional_optim,
+    functional_optim_map,
+)
+
 from torch.distributed.distributed_c10d import (
     get_world_size,
     _get_default_group,
@@ -76,8 +81,6 @@
     sandcastle_skip_if,
 )
 
-from torch.distributed.optim import functional_optim_map
-
 from torch.distributed.optim.functional_sgd import _FunctionalSGD
 from torch.distributed.optim.functional_adam import _FunctionalAdam
 from torch.distributed.optim.functional_adamw import _FunctionalAdamW
@@ -3954,12 +3957,13 @@ def _test_ddp_hook_with_optimizer_parity(
             # Register hook that runs allreduce + functional optimizer
             # step.
             allreduce_hook = default.allreduce_hook
+            mapping = {v: k for k, v in functional_optim_map.items()}
             if construct_from_functional:
-                f_opt = functional_optim_cls(
-                    [],
+                opt_cls = mapping[functional_optim_cls]
+                f_opt = as_functional_optim(
+                    opt_cls,
                     *functional_optim_args,
                     **functional_optim_kwargs,
-                    _allow_empty_param_list=True
                 )
                 opt_hook_state = _OptimizerHookState.from_functional_optim(
                     f_opt
@@ -3983,7 +3987,6 @@ def _test_ddp_hook_with_optimizer_parity(
                 static_graph=static_graph,
             )
 
-            mapping = {v: k for k, v in functional_optim_map.items()}
             optimizer_no_hook = mapping.get(functional_optim_cls)(
                 ddp_model_with_no_hook.parameters(),
                 *functional_optim_args,
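
Condensed, the construct_from_functional branch above does roughly the
following; the optimizer and learning rate are illustrative assumptions:

import torch
from torch.distributed.optim import as_functional_optim, functional_optim_map
from torch.distributed.optim.functional_adam import _FunctionalAdam

# The test inverts functional_optim_map so that, given a functional optimizer
# class, it can recover the traditional torch.optim class and exercise both
# construction paths from the same test parameters.
mapping = {v: k for k, v in functional_optim_map.items()}
opt_cls = mapping[_FunctionalAdam]          # torch.optim.Adam
f_opt = as_functional_optim(opt_cls, 1e-3)  # functional optimizer, empty param list
assert isinstance(f_opt, _FunctionalAdam)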
