Add momentum and centered options to RMSProp (pytorch#810)
* add momentum and centered options

Add two options:
- momentum (like SGD's momentum)
- centered RMSprop, as in Graves 2013 (https://arxiv.org/abs/1308.0850): the gradient is normalized by a running estimate of its variance

A short pseudocode sketch of the resulting update rule follows this commit list.

* some PEP8

* bug in default

* bug2

* sign mistake

* alloc of momentum & centered only if needed

* add link to docstring

* some pep8 on docstring

* implement __setstate__() for backward compatibility

* correct grammar mistake

* multiply by lr when adding delta to params

* rename momentum variables

* change __init__ params order
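
For reference, here is a minimal pseudocode sketch of the per-parameter update rule these two options produce. It mirrors the state buffers in the diff below (square_avg, grad_avg, momentum_buffer) but treats the parameter as a scalar; the helper name rmsprop_update is hypothetical and not part of the commit.

    from math import sqrt

    def rmsprop_update(p, grad, square_avg, grad_avg, buf,
                       lr=1e-2, alpha=0.99, eps=1e-8, momentum=0, centered=False):
        # One RMSprop step for a single scalar parameter p; returns the new value
        # of p together with the updated state buffers.
        square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
        if centered:
            # centered variant: normalize by a running estimate of the gradient's variance
            grad_avg = alpha * grad_avg + (1 - alpha) * grad
            avg = sqrt(square_avg - grad_avg ** 2) + eps
        else:
            # plain variant: normalize by a running estimate of the second moment
            avg = sqrt(square_avg) + eps
        if momentum > 0:
            # SGD-style momentum on the normalized gradient
            buf = momentum * buf + grad / avg
            p = p - lr * buf
        else:
            p = p - lr * grad / avg
        return p, square_avg, grad_avg, buf

Note that when momentum > 0, lr multiplies the accumulated buffer rather than the raw normalized gradient, matching the "multiply by lr when adding delta to params" fix in the commit list above.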
edouardelasalles authored and apaszke committed Mar 9, 2017
1 parent a462edd commit b1c2714
Showing 1 changed file with 34 additions and 4 deletions.
torch/optim/rmsprop.py (38 changes: 34 additions & 4 deletions)
@@ -4,21 +4,35 @@
 class RMSprop(Optimizer):
     """Implements RMSprop algorithm.
     Proposed by G. Hinton in his `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
+    The centered version first appears in `Generating Sequences
+    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
         lr (float, optional): learning rate (default: 1e-2)
+        momentum (float, optional): momentum factor (default: 0)
         alpha (float, optional): smoothing constant (default: 0.99)
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
+        centered (bool, optional) : if True, compute the centered RMSProp,
+            the gradient is normalized by an estimation of its variance
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
     """
 
-    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0):
-        defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay)
+    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
         super(RMSprop, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super(RMSprop, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('momentum', 0)
+            group.setdefault('centered', False)
+
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -41,6 +55,10 @@ def step(self, closure=None):
                 if len(state) == 0:
                     state['step'] = 0
                     state['square_avg'] = grad.new().resize_as_(grad).zero_()
+                    if group['momentum'] > 0:
+                        state['momentum_buffer'] = grad.new().resize_as_(grad).zero_()
+                    if group['centered']:
+                        state['grad_avg'] = grad.new().resize_as_(grad).zero_()
 
                 square_avg = state['square_avg']
                 alpha = group['alpha']
@@ -51,7 +69,19 @@
                     grad = grad.add(group['weight_decay'], p.data)
 
                 square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
-                avg = square_avg.sqrt().add_(group['eps'])
-                p.data.addcdiv_(-group['lr'], grad, avg)
+
+                if group['centered']:
+                    grad_avg = state['grad_avg']
+                    grad_avg.mul_(alpha).add_(1 - alpha, grad)
+                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps'])
+                else:
+                    avg = square_avg.sqrt().add_(group['eps'])
+
+                if group['momentum'] > 0:
+                    buf = state['momentum_buffer']
+                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
+                    p.data.add_(-group['lr'], buf)
+                else:
+                    p.data.addcdiv_(-group['lr'], grad, avg)
 
         return loss
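
For completeness, a usage sketch of the new constructor arguments (assumptions: the placeholder model, input and loss below are illustrative only; inputs are wrapped in Variable as PyTorch required at the time):

    import torch
    import torch.optim as optim
    from torch.autograd import Variable

    model = torch.nn.Linear(10, 2)  # placeholder model
    optimizer = optim.RMSprop(model.parameters(), lr=1e-2, alpha=0.99, eps=1e-8,
                              weight_decay=0, momentum=0.9, centered=True)

    optimizer.zero_grad()
    loss = model(Variable(torch.randn(4, 10))).pow(2).sum()  # placeholder loss
    loss.backward()
    optimizer.step()  # applies the momentum + centered RMSprop update from the diff above

Loading an optimizer state saved before this change still works: __setstate__ fills in the 'momentum' and 'centered' defaults for old parameter groups.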
