From fab48f695ab8187c7d2b54c683c8256046ce7011 Mon Sep 17 00:00:00 2001 From: Olexa Bilaniuk <obilaniu@gmail.com> Date: Thu, 26 Jul 2018 02:38:10 -0400 Subject: [PATCH] Add nauka.utils.torch.optim.setLR() utility. Useful to set optimizer LRs without having to rely on torch.optim.lr_scheduler-derived objects. --- src/nauka/utils/torch/optim/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/nauka/utils/torch/optim/__init__.py b/src/nauka/utils/torch/optim/__init__.py index 3061f1e..441b112 100644 --- a/src/nauka/utils/torch/optim/__init__.py +++ b/src/nauka/utils/torch/optim/__init__.py @@ -18,3 +18,14 @@ def fromSpec(params, spec, **kwargs): **kwargs) else: raise NotImplementedError("Optimizer "+spec.name+" not implemented!") + +def setLR(optimizer, lr): + if isinstance(lr, dict): + for paramGroup in optimizer.param_groups: + paramGroup["lr"] = lr[paramGroup] + elif isinstance(lr, list): + for paramGroup, lritem in zip(optimizer.param_groups, lr): + paramGroup["lr"] = lritem + else: + for paramGroup in optimizer.param_groups: + paramGroup["lr"] = lr