Skip to content

Commit

Permalink
[RLlib] Add option to use torch.lr_scheduler classes for learning r…
Browse files Browse the repository at this point in the history
…ate schedules. (ray-project#47453)
  • Loading branch information
simonsays1980 authored Sep 3, 2024
1 parent e1ed103 commit e147e31
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
14 changes: 14 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ def __init__(self, algo_class: Optional[type] = None):

# `self.experimental()`
self._torch_grad_scaler_class = None
self._torch_lr_scheduler_classes = None
self._tf_policy_handles_more_than_one_loss = False
self._disable_preprocessor_api = False
self._disable_action_flattening = False
Expand Down Expand Up @@ -3231,6 +3232,9 @@ def experimental(
self,
*,
_torch_grad_scaler_class: Optional[Type] = NotProvided,
_torch_lr_scheduler_classes: Optional[
Union[List[Type], Dict[ModuleID, Type]]
] = NotProvided,
_tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided,
_disable_preprocessor_api: Optional[bool] = NotProvided,
_disable_action_flattening: Optional[bool] = NotProvided,
Expand All @@ -3252,6 +3256,14 @@ def experimental(
and step the given optimizer.
`update()` to update the scaler after an optimizer step (for example to
adjust the scale factor).
_torch_lr_scheduler_classes: A list of `torch.lr_scheduler.LRScheduler`
(see here for more details
https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate)
classes or a dictionary mapping module IDs to such a list of respective
scheduler classes. Multiple scheduler classes can be applied in sequence
and will be stepped in the same sequence as defined here. Note, most
learning rate schedulers need arguments to be configured, i.e. you need
to partially initialize the schedulers in the list(s).
_tf_policy_handles_more_than_one_loss: Experimental flag.
If True, TFPolicy will handle more than one loss/optimizer.
Set this to True, if you would like to return more than
Expand Down Expand Up @@ -3295,6 +3307,8 @@ def experimental(
)
if _torch_grad_scaler_class is not NotProvided:
self._torch_grad_scaler_class = _torch_grad_scaler_class
if _torch_lr_scheduler_classes is not NotProvided:
self._torch_lr_scheduler_classes = _torch_lr_scheduler_classes

return self

Expand Down
28 changes: 28 additions & 0 deletions rllib/core/learner/torch/torch_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def __init__(self, **kwargs):
self._grad_scalers = defaultdict(
lambda: self.config._torch_grad_scaler_class()
)
self._lr_schedulers = {}
self._lr_scheduler_classes = None
if self.config._torch_lr_scheduler_classes:
self._lr_scheduler_classes = self.config._torch_lr_scheduler_classes

@OverrideToImplementCustomLogic
@override(Learner)
Expand Down Expand Up @@ -202,6 +206,25 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
for module_id, optimizer_names in self._module_optimizers.items():
for optimizer_name in optimizer_names:
optim = self.get_optimizer(module_id, optimizer_name)
# If we have learning rate schedulers for a module add them, if
# necessary.
if self._lr_scheduler_classes is not None:
if module_id not in self._lr_schedulers:
# Set for each module and optimizer a scheduler.
self._lr_schedulers[module_id] = {optimizer_name: []}
# If the classes are in a dictionary each module might have
# a different set of schedulers.
if isinstance(self._lr_scheduler_classes, dict):
scheduler_classes = self._lr_scheduler_classes[module_id]
# Else, each module has the same learning rate schedulers.
else:
scheduler_classes = self._lr_scheduler_classes
# Initialize and add the schedulers.
for scheduler_class in scheduler_classes:
self._lr_schedulers[module_id][optimizer_name].append(
scheduler_class(optim)
)

# Step through the scaler (unscales gradients, if applicable).
if self._grad_scalers is not None:
scaler = self._grad_scalers[module_id]
Expand Down Expand Up @@ -234,6 +257,11 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
"`False`."
)

# If the module uses learning rate schedulers, step them here.
if module_id in self._lr_schedulers:
for scheduler in self._lr_schedulers[module_id][optimizer_name]:
scheduler.step()

@override(Learner)
def _get_optimizer_state(self) -> StateDict:
return {
Expand Down

0 comments on commit e147e31

Please sign in to comment.