Skip to content

Commit

Permalink
Fixed parameter scheduler bug with CosineAnnealingWarmRestarts (pyt…
Browse files Browse the repository at this point in the history
…orch#2938)

* remove codecov

* RankProcessFirst

* annotations

* from class to contextlib

* from class to contextlib and test

* del test file

* uniq folder for test

* refactor tests + new assert_test

* add to __all__, remove idist import

* Apply suggestions from code review

* Apply suggestions from code review

* Update tests/ignite/distributed/utils/test_native.py

* Added local arg and renamed function

* add proxy class

* annotation

* test, proxy class

* add optim

* name change

* test upd/ setter

* class fix

* Fixed mypy issues

* test upd

* Fixed failing test_lr_scheduler

---------

Co-authored-by: vfdev <[email protected]>
  • Loading branch information
AlexanderChaptykov and vfdev-5 authored May 23, 2023
1 parent a99ea7f commit e9e5b45
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 7 deletions.
68 changes: 63 additions & 5 deletions ignite/handlers/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

# https://github.com/pytorch/ignite/issues/2773
Expand Down Expand Up @@ -792,6 +792,61 @@ def simulate_values( # type: ignore[override]
return output


class _CosineAnnealingWarmRestarts:
def __init__(self, lr_scheduler: CosineAnnealingWarmRestarts):
self._lr_scheduler = lr_scheduler

@property
def last_epoch(self) -> int:
return self._lr_scheduler.last_epoch

@last_epoch.setter
def last_epoch(self, value: int) -> None:
self._lr_scheduler.last_epoch = value

@property
def optimizer(self) -> torch.optim.Optimizer:
return self._lr_scheduler.optimizer

def get_lr(self, epoch: Optional[int] = None) -> List[float]:
# TODO: Remove this workaround when pytorch has fixed wrong type hints:
# https://github.com/pytorch/pytorch/pull/102067
# Replace below T_mult -> self._lr_scheduler.T_mult
# Replace below eta_min -> self._lr_scheduler.eta_min
T_mult = cast(int, self._lr_scheduler.T_mult)
eta_min = cast(float, self._lr_scheduler.eta_min)

if epoch is None and self.last_epoch < 0:
epoch = 0
if epoch is None:
epoch = self.last_epoch + 1
self._lr_scheduler.T_cur = self._lr_scheduler.T_cur + 1
if self._lr_scheduler.T_cur >= self._lr_scheduler.T_i:
self._lr_scheduler.T_cur = self._lr_scheduler.T_cur - self._lr_scheduler.T_i
self._lr_scheduler.T_i = self._lr_scheduler.T_i * T_mult
else:
if epoch < 0:
raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
if epoch >= self._lr_scheduler.T_0:
if T_mult == 1:
self._lr_scheduler.T_cur = epoch % self._lr_scheduler.T_0
else:
n = int(math.log((epoch / self._lr_scheduler.T_0 * (T_mult - 1) + 1), T_mult))
self._lr_scheduler.T_cur = epoch - self._lr_scheduler.T_0 * (T_mult**n - 1) / (T_mult - 1)
self._lr_scheduler.T_i = self._lr_scheduler.T_0 * T_mult**n
else:
self._lr_scheduler.T_i = self._lr_scheduler.T_0
self._lr_scheduler.T_cur = epoch

self.last_epoch = math.floor(epoch)

return [
eta_min
+ (base_lr - eta_min) * (1 + math.cos(math.pi * self._lr_scheduler.T_cur / self._lr_scheduler.T_i)) / 2
for base_lr in self._lr_scheduler.base_lrs
]


class LRScheduler(ParamScheduler):
"""A wrapper class to call `torch.optim.lr_scheduler` objects as `ignite` handlers.
Expand Down Expand Up @@ -853,7 +908,10 @@ def __init__(
f"but given {type(lr_scheduler)}"
)

self.lr_scheduler = lr_scheduler
self.lr_scheduler: Union[PyTorchLRScheduler, _CosineAnnealingWarmRestarts] = lr_scheduler
if isinstance(lr_scheduler, CosineAnnealingWarmRestarts):
self.lr_scheduler = _CosineAnnealingWarmRestarts(lr_scheduler)

super(LRScheduler, self).__init__(
optimizer=self.lr_scheduler.optimizer,
param_name="lr",
Expand All @@ -863,7 +921,7 @@ def __init__(
warnings.warn(
"Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
"instead of Events.ITERATION_STARTED to make sure to use "
"the first lr value from the optimizer, otherwise it is will be skipped"
"the first lr value from the optimizer, otherwise it will be skipped"
)
self.lr_scheduler.last_epoch += 1

Expand All @@ -876,9 +934,9 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None
def get_param(self) -> Union[float, List[float]]:
"""Method to get current optimizer's parameter value"""
# Emulate context manager for pytorch>=1.4
self.lr_scheduler._get_lr_called_within_step = True # type: ignore[attr-defined]
self.lr_scheduler._get_lr_called_within_step = True # type: ignore[union-attr]
lr_list = cast(List[float], self.lr_scheduler.get_lr())
self.lr_scheduler._get_lr_called_within_step = False # type: ignore[attr-defined]
self.lr_scheduler._get_lr_called_within_step = False # type: ignore[union-attr]
if len(lr_list) == 1:
return lr_list[0]
else:
Expand Down
46 changes: 44 additions & 2 deletions tests/ignite/handlers/test_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest
import torch
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR, StepLR

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import (
Expand Down Expand Up @@ -650,7 +650,7 @@ def test_lr_scheduler(torch_lr_scheduler_cls, kwargs):
state_dict1 = scheduler1.state_dict()

torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it is will be skipped"):
with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it will be skipped"):
scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
state_dict2 = scheduler2.state_dict()

Expand Down Expand Up @@ -1362,3 +1362,45 @@ def test_reduce_lr_on_plateau_scheduler_asserts():
with pytest.raises(ValueError, match=r"Length of argument metric_values should be equal to num_events."):
metric_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
ReduceLROnPlateauScheduler.simulate_values(5, metric_values, 0.01)


@pytest.mark.parametrize("warmup_end_value", [0.23, None])
@pytest.mark.parametrize("T_0", [1, 12])
@pytest.mark.parametrize("T_mult", [1, 3])
def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
lr = 0.2
steps = 200
warm_steps = 50
warm_start = 0.023

def get_optim():
t1 = torch.zeros([1], requires_grad=True)
return torch.optim.SGD([t1], lr=lr)

def get_cos_shed():
return CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, verbose=False)

optimizer = get_optim()
scheduler = get_cos_shed()
cosine_lrs = []
for i in range(steps):
cosine_lrs.append(optimizer.param_groups[0]["lr"])
scheduler.step()

optimizer = get_optim()
scheduler = create_lr_scheduler_with_warmup(
get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
)

warm_lrs = []
real_warm_steps = warm_steps if warmup_end_value is not None else (warm_steps - 1)
for epoch in range(real_warm_steps + steps):
scheduler(None)
warm_lrs.append(optimizer.param_groups[0]["lr"])

if warmup_end_value is not None:
np.testing.assert_allclose(np.linspace(warm_start, warmup_end_value, warm_steps), warm_lrs[:warm_steps])
assert warm_lrs[real_warm_steps:] == cosine_lrs
else:
np.testing.assert_allclose(np.linspace(warm_start, lr, warm_steps), warm_lrs[:warm_steps])
assert warm_lrs[real_warm_steps:] == cosine_lrs

0 comments on commit e9e5b45

Please sign in to comment.