[DDP] Support step_param for AdamW (pytorch#63382)
Summary:
Pull Request resolved: pytorch#63382

Per title: add a step_param() implementation to _FunctionalAdamW so that DDP's communication-hook-with-optimizer path can apply AdamW updates per parameter as gradients become ready, mirroring the existing _FunctionalAdam support.
ghstack-source-id: 135966156

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D30255446

fbshipit-source-id: e6ffbf339db0bc5b4702d02b74a462309df07c75
rohan-varma authored and facebook-github-bot committed Aug 18, 2021
1 parent cd5e9dc commit 5b8862a
Showing 4 changed files with 106 additions and 4 deletions.
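
For context: step_param() applies a complete AdamW update to a single parameter, which is what allows DDP to run the optimizer from inside its communication hook as each gradient arrives, rather than in a separate step() pass after backward. Below is a minimal sketch of driving it by hand outside DDP; it is illustrative only and assumes _FunctionalAdamW's constructor mirrors torch.optim.AdamW's hyperparameters, as the tests in this commit suggest.

    # Illustrative sketch, not code from this commit.
    import torch
    import torch.nn as nn
    from torch.distributed.optim.functional_adamw import _FunctionalAdamW  # unavailable on Windows

    model = nn.Linear(4, 4)
    params = list(model.parameters())
    # Constructor hyperparameters assumed to mirror torch.optim.AdamW.
    opt = _FunctionalAdamW(params, lr=1e-2, betas=(0.9, 0.99), eps=1e-6)

    model(torch.randn(2, 4)).sum().backward()

    # step_param updates one parameter at a time; DDP invokes it per parameter as
    # gradient buckets finish allreduce, instead of calling step() once at the end.
    for p in params:
        opt.step_param(p, p.grad)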
18 changes: 17 additions & 1 deletion test/distributed/test_c10d_nccl.py
@@ -53,9 +53,11 @@
 if not IS_WINDOWS:
     from torch.distributed.optim.functional_sgd import _FunctionalSGD
     from torch.distributed.optim.functional_adam import _FunctionalAdam
+    from torch.distributed.optim.functional_adamw import _FunctionalAdamW
     _SUPPORTED_OPTIM_MAPPING = {
         _FunctionalSGD: torch.optim.SGD,
-        _FunctionalAdam: torch.optim.Adam
+        _FunctionalAdam: torch.optim.Adam,
+        _FunctionalAdamW: torch.optim.AdamW,
     }
 
 if TEST_WITH_TSAN:
@@ -1737,6 +1739,20 @@ def test_hook_then_sgd_nccl_grad_as_bucket_view(self):
             gradient_as_bucket_view=True
         )
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_hook_then_adamw_nccl(self):
+        adamw_lr = 1e-2
+        adamw_betas = (0.9, 0.99)
+        adamw_eps = 1e-6
+        self._test_hook_then_optimizer(
+            _FunctionalAdamW,
+            adamw_lr,
+            betas=adamw_betas,
+            eps=adamw_eps,
+            gradient_as_bucket_view=True
+        )
+
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_hook_then_adam_nccl(self):
13 changes: 11 additions & 2 deletions test/test_functional_optim.py
@@ -3,15 +3,17 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.optim import SGD, Adam
+from torch.optim import SGD, Adam, AdamW
 from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
 
 if not IS_WINDOWS:
     from torch.distributed.optim.functional_sgd import _FunctionalSGD
     from torch.distributed.optim.functional_adam import _FunctionalAdam
+    from torch.distributed.optim.functional_adamw import _FunctionalAdamW
     _SUPPORTED_OPTIM_MAPPING = {
         SGD: _FunctionalSGD,
-        Adam: _FunctionalAdam
+        Adam: _FunctionalAdam,
+        AdamW: _FunctionalAdamW,
     }
 
 
@@ -102,6 +104,13 @@ def test_functional_optim_parity_sgd(self):
     def test_functional_optim_parity_adam(self):
         self._test_functional_optim_parity(Adam, 1e-2, betas=(0.9, 0.999), eps=1e-6)
 
+    @unittest.skipIf(
+        IS_WINDOWS,
+        "Functional optimizer not support on windows, see https://github.com/pytorch/pytorch/issues/62137",
+    )
+    def test_functional_optim_parity_adam_w(self):
+        self._test_functional_optim_parity(AdamW, 1e-2, betas=(0.9, 0.999), eps=1e-6)
+
 
 if __name__ == "__main__":
     run_tests()
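
The parity test added above presumably asserts that the functional wrapper and the regular optimizer produce identical updates. A conceptual sketch of that check (not the test's actual code; it assumes _FunctionalAdamW accepts the same lr/betas/eps/weight_decay keywords as torch.optim.AdamW):

    # Conceptual parity check: identical models and gradients should yield
    # identical parameters under torch.optim.AdamW and _FunctionalAdamW.
    import copy
    import torch
    import torch.nn as nn
    from torch.distributed.optim.functional_adamw import _FunctionalAdamW

    m1 = nn.Linear(3, 3)
    m2 = copy.deepcopy(m1)
    kwargs = dict(lr=1e-2, betas=(0.9, 0.999), eps=1e-6, weight_decay=1e-2)
    opt = torch.optim.AdamW(m1.parameters(), **kwargs)
    fopt = _FunctionalAdamW(list(m2.parameters()), **kwargs)

    x = torch.randn(4, 3)
    m1(x).sum().backward()
    m2(x).sum().backward()

    opt.step()
    # _FunctionalAdamW.step() consumes the flat list of gradients in parameter order.
    fopt.step([p.grad for p in m2.parameters()])

    for p1, p2 in zip(m1.parameters(), m2.parameters()):
        assert torch.allclose(p1, p2)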
49 changes: 49 additions & 0 deletions torch/distributed/optim/functional_adamw.py
@@ -53,6 +53,55 @@ def __init__(
         # param group as it's not a common use case.
         self.param_group = {"params": params}
 
+    def step_param(self, param: Tensor, grad: Optional[Tensor]):
+        params_with_grad = []
+        grads = []
+        exp_avgs = []
+        exp_avg_sqs = []
+        max_exp_avg_sqs = []
+        state_steps: List[int] = []
+        if grad is not None:
+            params_with_grad.append(param)
+            grads.append(grad)
+        # Lazy state initialization
+        if param not in self.state:
+            self.state[param] = {}
+            state = self.state[param]
+            state['step'] = torch.tensor(0.0)
+            # Exponential moving average of gradient values
+            state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+            # Exponential moving average of squared gradient values
+            state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+            if self.amsgrad:
+                # Maintains max of all exp. moving avg. of sq. grad. values
+                state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+
+        state = self.state[param]
+
+        exp_avgs.append(state['exp_avg'])
+        exp_avg_sqs.append(state['exp_avg_sq'])
+
+        if self.amsgrad:
+            max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+        # update the steps for each param group update
+        state['step'] += 1
+        # record the step after step update
+        state_steps.append(state['step'].item())
+        with torch.no_grad():
+            F.adamw(params_with_grad,
+                    grads,
+                    exp_avgs,
+                    exp_avg_sqs,
+                    max_exp_avg_sqs,
+                    state_steps,
+                    amsgrad=self.amsgrad,
+                    beta1=self.defaults['beta1'],
+                    beta2=self.defaults['beta2'],
+                    lr=self.defaults['lr'],
+                    weight_decay=self.defaults['weight_decay'],
+                    eps=self.defaults['eps'])
+
     def step(self, gradients: List[Optional[Tensor]]):
         params = self.param_group['params']
         params_with_grad = []
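For reference, the F.adamw call above implements standard AdamW with decoupled weight decay. A single-tensor sketch of that update (illustrative; the function and variable names here are not part of the PyTorch API):

    # One AdamW step for a single tensor (no amsgrad). Weight decay is applied to
    # the parameter directly rather than folded into the gradient as in Adam.
    import torch

    def adamw_step(p, g, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps, weight_decay):
        step += 1
        p.mul_(1 - lr * weight_decay)                            # decoupled weight decay
        exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)             # first moment estimate
        exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # second moment estimate
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
        p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
        return step

    # Tiny usage example with dummy tensors:
    p, g = torch.ones(3), torch.full((3,), 0.1)
    m, v = torch.zeros(3), torch.zeros(3)
    step = adamw_step(p, g, m, v, step=0, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=1e-2)
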
30 changes: 29 additions & 1 deletion torch/testing/_internal/distributed/distributed_test.py
@@ -70,9 +70,11 @@
 import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
 from torch.distributed.optim.functional_sgd import _FunctionalSGD
 from torch.distributed.optim.functional_adam import _FunctionalAdam
+from torch.distributed.optim.functional_adamw import _FunctionalAdamW
 _SUPPORTED_OPTIM_MAPPING = {
     _FunctionalSGD: torch.optim.SGD,
-    _FunctionalAdam: torch.optim.Adam
+    _FunctionalAdam: torch.optim.Adam,
+    _FunctionalAdamW: torch.optim.AdamW,
 }
 
 from torch.utils.data.distributed import DistributedSampler
@@ -3999,6 +4001,32 @@ def _test_ddp_hook_with_optimizer_parity(
            )
            dist.barrier()
 
+        @sandcastle_skip_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        @sandcastle_skip_if(
+            IS_WINDOWS,
+            "FunctionalAdam not yet supported with Windows, see https://github.com/pytorch/pytorch/issues/62137"
+        )
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_ddp_hook_with_optimizer_parity_adamw(self):
+            for grad_as_bucket_view, static_graph in itertools.product(
+                [True, False], [True, False]
+            ):
+                adamw_lr = 1e-2
+                adamw_betas = (0.9, 0.99)
+                adamw_eps = 1e-6
+                self._test_ddp_hook_with_optimizer_parity(
+                    grad_as_bucket_view,
+                    static_graph,
+                    _FunctionalAdamW,
+                    adamw_lr,
+                    betas=adamw_betas,
+                    eps=adamw_eps,
+                )
+
        @sandcastle_skip_if(
            BACKEND != "nccl" and BACKEND != "gloo",
            "Only Nccl & Gloo backend support DistributedDataParallel",