State parameter scheduler (pytorch#2090)
* pytorch#1913 Add scheduling of any parameter to extend the scope of ParamScheduler

* code style fixes

* update doc

* fix docstring warnings

* Rename Any* classes to State* classes.
Remove the previously introduced OptimizerParam class to ensure backward compatibility.

* Remove unused import

* fix docstring
rename test file

* Introduce BaseParamScheduler class for State* and Optimizer* parameter schedulers.
Add examples.

* Naming changes.

* fix flake8 errors

* fix docstring / parametrize tests

* naming changes

* parametrize tests

* fix flake8

* try to remove lines in pytest configs

* Update ignite/handlers/state_param_scheduler.py

Co-authored-by: vfdev <[email protected]>

* Rename LinearState to PwLinearState (implementation based on the PiecewiseLinear ParamScheduler class)

* Update ignite/handlers/state_param_scheduler.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/handlers/state_param_scheduler.py

Co-authored-by: vfdev <[email protected]>

* Renaming: PwLinearStateScheduler to PiecewiseLinearStateScheduler

* add test for state_dict and docstring examples

* Update ignite/handlers/state_param_scheduler.py

Co-authored-by: Sylvain Desroziers <[email protected]>

* improve docstring / change lambda_fn to lambda_obj for LambdaStateScheduler / add tests

* rm duplicated test

* fix code fmt

* add test that LambdaState object must be callable

* add test on asserts

* Apply suggestions from code review

* autopep8 fix

* Apply suggestions from code review

Co-authored-by: vfdev <[email protected]>
Co-authored-by: Sylvain Desroziers <[email protected]>
Co-authored-by: vfdev-5 <[email protected]>
4 people authored Oct 11, 2021
1 parent 89b530e commit 81e13e1
Showing 6 changed files with 963 additions and 91 deletions.
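
For orientation, here is a minimal usage sketch of the new state parameter schedulers introduced by this commit. It assumes a milestones_values argument and an attach() method analogous to the existing PiecewiseLinear optimizer scheduler; neither appears in the hunks shown below, so treat the exact signature as illustrative only.

    from ignite.engine import Engine, Events
    from ignite.handlers.state_param_scheduler import PiecewiseLinearStateScheduler

    def train_step(engine, batch):
        # placeholder training step for the sketch
        pass

    trainer = Engine(train_step)

    # Schedule an arbitrary engine state attribute ("alpha" is a made-up name here),
    # ramping it linearly from 0.0 at event 0 to 1.0 at event 10.
    alpha_scheduler = PiecewiseLinearStateScheduler(
        param_name="alpha",
        milestones_values=[(0, 0.0), (10, 1.0)],  # assumed kwarg, mirroring PiecewiseLinear
        save_history=True,
    )
    alpha_scheduler.attach(trainer, Events.EPOCH_COMPLETED)  # assumed attach() helper

    # After attaching, trainer.state.alpha would be updated at each EPOCH_COMPLETED event,
    # and values logged to trainer.state.param_history because save_history=True.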
18 changes: 18 additions & 0 deletions docs/source/handlers.rst
@@ -31,6 +31,7 @@ Complete list of handlers

checkpoint.BaseSaveHandler
param_scheduler.ParamScheduler
state_param_scheduler.StateParamScheduler

.. _param-scheduler-label:

@@ -43,6 +44,7 @@ Parameter scheduler
:nosignatures:
:toctree: generated

BaseParamScheduler
ConcatScheduler
CosineAnnealingScheduler
CyclicalScheduler
@@ -53,6 +55,22 @@
PiecewiseLinear
create_lr_scheduler_with_warmup

State Parameter scheduler
-------------------------

.. currentmodule:: ignite.handlers.state_param_scheduler

.. autosummary::
:nosignatures:
:toctree: generated

StateParamScheduler
LambdaStateScheduler
PiecewiseLinearStateScheduler
ExpStateScheduler
StepStateScheduler
MultiStepStateScheduler

More on parameter scheduling
----------------------------

16 changes: 16 additions & 0 deletions ignite/handlers/__init__.py
@@ -7,6 +7,7 @@
from ignite.handlers.ema_handler import EMAHandler
from ignite.handlers.lr_finder import FastaiLRFinder
from ignite.handlers.param_scheduler import (
BaseParamScheduler,
ConcatScheduler,
CosineAnnealingScheduler,
CyclicalScheduler,
@@ -17,6 +18,14 @@
PiecewiseLinear,
create_lr_scheduler_with_warmup,
)
from ignite.handlers.state_param_scheduler import (
ExpStateScheduler,
LambdaStateScheduler,
MultiStepStateScheduler,
PiecewiseLinearStateScheduler,
StateParamScheduler,
StepStateScheduler,
)
from ignite.handlers.stores import EpochOutputStore
from ignite.handlers.terminate_on_nan import TerminateOnNan
from ignite.handlers.time_limit import TimeLimit
@@ -46,6 +55,13 @@
"EMAHandler",
"BasicTimeProfiler",
"HandlersTimeProfiler",
"BaseParamScheduler",
"StateParamScheduler",
"LambdaStateScheduler",
"PiecewiseLinearStateScheduler",
"ExpStateScheduler",
"StepStateScheduler",
"MultiStepStateScheduler",
]


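
As the import and __all__ additions above indicate, the new schedulers become importable directly from ignite.handlers as well as from ignite.handlers.state_param_scheduler. A quick sanity-check sketch of the two equivalent import paths:

    # Both paths should resolve to the same class after this change.
    from ignite.handlers import PiecewiseLinearStateScheduler
    from ignite.handlers.state_param_scheduler import (
        PiecewiseLinearStateScheduler as PiecewiseLinearStateScheduler2,
    )

    assert PiecewiseLinearStateScheduler is PiecewiseLinearStateScheduler2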
215 changes: 125 additions & 90 deletions ignite/handlers/param_scheduler.py
@@ -15,83 +15,26 @@
from ignite.engine import Engine


class ParamScheduler(metaclass=ABCMeta):
"""An abstract class for updating an optimizer's parameter value during
class BaseParamScheduler(metaclass=ABCMeta):
r"""An abstract class for updating an engine state or optimizer's parameter value during
training.
Args:
optimizer: torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name: name of optimizer's parameter to update.
param_name: name of engine state or optimizer's parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
param_group_index: optimizer's parameters group to use
Note:
Parameter scheduler works independently of the internal state of the attached optimizer.
More precisely, whatever the state of the optimizer (newly created or used by another scheduler) the scheduler
sets defined absolute values.
.. versionadded:: 0.5.0
.. versionadded:: 0.4.5
"""

def __init__(
self,
optimizer: Optimizer,
param_name: str,
save_history: bool = False,
param_group_index: Optional[int] = None,
self, param_name: str, save_history: bool = False,
):

if not (
isinstance(optimizer, Optimizer)
or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
):
raise TypeError(
"Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
f"but given {type(optimizer)}"
)

self.optimizer = optimizer
self.param_group_index = param_group_index
self.param_name = param_name
self.event_index = 0
self._save_history = save_history
self._state_attrs = ["event_index", "param_name", "save_history", "param_group_index"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:

value = self.get_param()

if isinstance(value, list):
if len(value) != len(self.optimizer_param_groups):
raise ValueError(
"size of value is different than optimizer_param_groups "
f"{len(value)} != {len(self.optimizer_param_groups)}"
)

for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value[i]
else:
for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value

if name is None:
name = self.param_name

if self.save_history and engine:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
setattr(engine.state, "param_history", {})
engine.state.param_history.setdefault(name, []) # type: ignore[attr-defined]
values = [pg[self.param_name] for pg in self.optimizer_param_groups]
engine.state.param_history[name].append(values) # type: ignore[attr-defined]
self.event_index += 1

@property
def optimizer_param_groups(self) -> List[Dict[str, Any]]:
if self.param_group_index is None:
return self.optimizer.param_groups
return [self.optimizer.param_groups[self.param_group_index]]
self._state_attrs = ["event_index", "param_name", "save_history"]

@property
def save_history(self) -> bool:
@@ -102,11 +45,11 @@ def save_history(self, value: bool) -> None:
self._save_history = value

def state_dict(self) -> Dict[str, Any]:
"""Returns a dictionary containing a whole state of ParamScheduler.
"""Returns a dictionary containing a whole state of BaseParamScheduler.
Returns:
dict:
a dictionary containing a whole state of ParamScheduler
a dictionary containing a whole state of BaseParamScheduler
"""
destination = OrderedDict()
for name in self._state_attrs:
@@ -118,7 +61,7 @@ def state_dict(self) -> Dict[str, Any]:
return destination

def load_state_dict(self, state_dict: Mapping) -> None:
"""Copies parameters from :attr:`state_dict` into this ParamScheduler.
"""Copies parameters from :attr:`state_dict` into this BaseParamScheduler.
Args:
state_dict: a dict containing parameters.
@@ -140,14 +83,15 @@ def load_state_dict(self, state_dict: Mapping) -> None:

@abstractmethod
def get_param(self) -> Union[List[float], float]:
"""Method to get current optimizer's parameter values
"""Method to get current parameter values
Returns:
list of params, or scalar param
"""
pass

@classmethod
@abstractmethod
def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]:
"""Method to simulate scheduled values during `num_events` events.
@@ -157,29 +101,8 @@ def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[
Returns:
event_index, value
Examples:
.. code-block:: python
lr_values = np.array(LinearCyclicalScheduler.simulate_values(num_events=50, param_name='lr',
start_value=1e-1, end_value=1e-3,
cycle_size=10))
plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
"""
keys_to_remove = ["optimizer", "save_history"]
for key in keys_to_remove:
if key in scheduler_kwargs:
del scheduler_kwargs[key]
values = []
scheduler = cls(optimizer=_get_fake_optimizer(), save_history=False, **scheduler_kwargs)
for i in range(num_events):
scheduler(engine=None)
values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
return values
pass

@classmethod
def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any:
@@ -208,7 +131,7 @@ def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any:
start_value=1e-1, end_value=1e-3, cycle_size=10))
"""
try:
import matplotlib.pylab as plt
import matplotlib.pyplot as plt
except ImportError:
raise RuntimeError(
"This method requires matplotlib to be installed. "
@@ -223,6 +146,118 @@ def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any:
return ax


class ParamScheduler(BaseParamScheduler):
"""An abstract class for updating an optimizer's parameter value during
training.
Args:
optimizer: torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name: name of optimizer's parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
param_group_index: optimizer's parameters group to use
Note:
Parameter scheduler works independently of the internal state of the attached optimizer.
More precisely, whatever the state of the optimizer (newly created or used by another scheduler) the scheduler
sets defined absolute values.
"""

def __init__(
self,
optimizer: Optimizer,
param_name: str,
save_history: bool = False,
param_group_index: Optional[int] = None,
):
super(ParamScheduler, self).__init__(param_name, save_history)
if not (
isinstance(optimizer, Optimizer)
or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
):
raise TypeError(
"Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
f"but given {type(optimizer)}"
)

self.optimizer = optimizer
self.param_group_index = param_group_index
self._state_attrs += ["param_group_index"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:

value = self.get_param()

if isinstance(value, list):
if len(value) != len(self.optimizer_param_groups):
raise ValueError(
"size of value is different than optimizer_param_groups "
f"{len(value)} != {len(self.optimizer_param_groups)}"
)

for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value[i]
else:
for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value

if name is None:
name = self.param_name

if self.save_history and engine:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
setattr(engine.state, "param_history", {})
engine.state.param_history.setdefault(name, []) # type: ignore[attr-defined]
values = [pg[self.param_name] for pg in self.optimizer_param_groups]
engine.state.param_history[name].append(values) # type: ignore[attr-defined]
self.event_index += 1

@property
def optimizer_param_groups(self) -> List[Dict[str, Any]]:
if self.param_group_index is None:
return self.optimizer.param_groups
return [self.optimizer.param_groups[self.param_group_index]]

@classmethod
def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]:
"""Method to simulate scheduled values during `num_events` events.
Args:
num_events: number of events during the simulation.
scheduler_kwargs: parameter scheduler configuration kwargs.
Returns:
event_index, value
Examples:
.. code-block:: python
lr_values = np.array(LinearCyclicalScheduler.simulate_values(num_events=50, param_name='lr',
start_value=1e-1, end_value=1e-3,
cycle_size=10))
plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
"""
keys_to_remove = ["optimizer", "save_history"]
for key in keys_to_remove:
if key in scheduler_kwargs:
del scheduler_kwargs[key]
values = []
scheduler = cls(optimizer=_get_fake_optimizer(), save_history=False, **scheduler_kwargs)
for i in range(num_events):
scheduler(engine=None)
values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
return values


class CyclicalScheduler(ParamScheduler):
"""An abstract class for updating an optimizer's parameter value over a
cycle of some size.
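
The state_dict / load_state_dict machinery that moves into BaseParamScheduler above serializes the attributes listed in _state_attrs (event_index, param_name, save_history, plus param_group_index for optimizer schedulers). A small round-trip sketch of that mechanism using the existing LinearCyclicalScheduler, illustrative only:

    import torch
    from ignite.handlers import LinearCyclicalScheduler

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    scheduler = LinearCyclicalScheduler(
        optimizer, "lr", start_value=0.1, end_value=0.001, cycle_size=10
    )

    # Advance a few events (calling the scheduler directly, as simulate_values does above).
    for _ in range(3):
        scheduler(engine=None)

    saved = scheduler.state_dict()

    # Restore the saved state into a freshly constructed scheduler.
    restored = LinearCyclicalScheduler(
        optimizer, "lr", start_value=0.1, end_value=0.001, cycle_size=10
    )
    restored.load_state_dict(saved)
    assert restored.event_index == scheduler.event_index == 3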
