Sparse Adam optimizer for sparse gradients (#3137)
* sparse adam

* Favor dense addition over sparse_mask
ssnl authored and soumith committed Nov 6, 2017
1 parent c2626f6 commit f76d6c0
Showing 11 changed files with 153 additions and 25 deletions.
25 changes: 19 additions & 6 deletions test/test_optim.py
@@ -61,13 +61,14 @@ def eval():
 
         self.assertLessEqual(params.data.dist(solution), initial_dist)
 
-    def _test_rosenbrock_sparse(self, constructor):
+    def _test_rosenbrock_sparse(self, constructor, sparse_only=False):
         params_t = torch.Tensor([1.5, 1.5])
 
-        params = Variable(torch.Tensor([1.5, 1.5]), requires_grad=True)
-        params_c = Variable(torch.Tensor([1.5, 1.5]), requires_grad=True)
+        params = Variable(params_t, requires_grad=True)
         optimizer = constructor([params])
-        optimizer_c = constructor([params_c])
+        if not sparse_only:
+            params_c = Variable(params_t.clone(), requires_grad=True)
+            optimizer_c = constructor([params_c])
 
         solution = torch.Tensor([1, 1])
         initial_dist = params.data.dist(solution)
@@ -99,8 +100,9 @@ def eval(params, sparse_grad, w):
             # Do cyclic coordinate descent
             w = i % 2
             optimizer.step(functools.partial(eval, params, True, w))
-            optimizer_c.step(functools.partial(eval, params_c, False, w))
-            self.assertEqual(params.data, params_c.data)
+            if not sparse_only:
+                optimizer_c.step(functools.partial(eval, params_c, False, w))
+                self.assertEqual(params.data, params_c.data)
 
         self.assertLessEqual(params.data.dist(solution), initial_dist)

@@ -229,6 +231,11 @@ def test_sgd(self):
                 lr=1e-3)
         )
 
+    def test_sgd_sparse(self):
+        self._test_rosenbrock_sparse(
+            lambda params: optim.SGD(params, lr=5e-3)
+        )
+
     def test_adam(self):
         self._test_rosenbrock(
             lambda params: optim.Adam(params, lr=1e-2),
@@ -247,6 +254,12 @@ def test_adam(self):
                 lr=1e-3)
         )
 
+    def test_sparse_adam(self):
+        self._test_rosenbrock_sparse(
+            lambda params: optim.SparseAdam(params, lr=4e-2),
+            True
+        )
+
     def test_adadelta(self):
         self._test_rosenbrock(
             lambda params: optim.Adadelta(params),
1 change: 1 addition & 0 deletions torch/optim/__init__.py
@@ -8,6 +8,7 @@
 from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
+from .sparse_adam import SparseAdam
 from .adamax import Adamax
 from .asgd import ASGD
 from .sgd import SGD
8 changes: 6 additions & 2 deletions torch/optim/adadelta.py
@@ -1,3 +1,5 @@
+import torch
+
 from .optimizer import Optimizer
 
 
@@ -40,13 +42,15 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adadelta does not support sparse gradients')
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = grad.new().resize_as_(grad).zero_()
-                    state['acc_delta'] = grad.new().resize_as_(grad).zero_()
+                    state['square_avg'] = torch.zeros_like(p.data)
+                    state['acc_delta'] = torch.zeros_like(p.data)
 
                 square_avg, acc_delta = state['square_avg'], state['acc_delta']
                 rho, eps = group['rho'], group['eps']
13 changes: 6 additions & 7 deletions torch/optim/adagrad.py
@@ -1,5 +1,4 @@
 import torch
-
 from .optimizer import Optimizer
 
 
@@ -28,7 +27,7 @@ def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0):
             for p in group['params']:
                 state = self.state[p]
                 state['step'] = 0
-                state['sum'] = p.data.new().resize_as_(p.data).zero_()
+                state['sum'] = torch.zeros_like(p.data)
 
     def share_memory(self):
         for group in self.param_groups:
@@ -59,21 +58,21 @@ def step(self, closure=None):
 
                 if group['weight_decay'] != 0:
                     if p.grad.data.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients ")
+                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                     grad = grad.add(group['weight_decay'], p.data)
 
                 clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay'])
 
-                if p.grad.data.is_sparse:
+                if grad.is_sparse:
                     grad = grad.coalesce()  # the update is non-linear so indices must be unique
                     grad_indices = grad._indices()
                     grad_values = grad._values()
-                    size = torch.Size([x for x in grad.size()])
+                    size = grad.size()
 
                     def make_sparse(values):
-                        constructor = type(p.grad.data)
+                        constructor = grad.new
                         if grad_indices.dim() == 0 or values.dim() == 0:
-                            return constructor()
+                            return constructor().resize_as_(grad)
                         return constructor(grad_indices, values, size)
                     state['sum'].add_(make_sparse(grad_values.pow(2)))
                     std = state['sum']._sparse_mask(grad)
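
Note (an illustration, not part of this commit): the make_sparse helper above is the core pattern for sparse updates — reuse the gradient's coalesced indices, swap in new values, and add the resulting sparse tensor into a dense state buffer. A minimal standalone sketch of the same idea, written against the present-day torch.sparse_coo_tensor constructor rather than the legacy grad.new call used in the diff:

    import torch

    # Sparse gradient with only a few non-zero entries, as a sparse embedding would produce.
    grad = torch.tensor([0.5, 0.0, -1.2, 0.0, 0.0]).to_sparse().coalesce()
    grad_indices = grad.indices()   # unique after coalesce()
    grad_values = grad.values()

    # Build an update that reuses the gradient's sparsity pattern but carries new values,
    # here the squared gradient values as in Adagrad's state['sum'] accumulation.
    update = torch.sparse_coo_tensor(grad_indices, grad_values.pow(2), grad.size())

    state_sum = torch.zeros(5)      # dense per-parameter state
    state_sum.add_(update)          # dense += sparse touches only the listed indices
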
8 changes: 6 additions & 2 deletions torch/optim/adam.py
@@ -1,4 +1,5 @@
 import math
+import torch
 from .optimizer import Optimizer
 
 
@@ -43,15 +44,18 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
+                    state['exp_avg'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
6 changes: 4 additions & 2 deletions torch/optim/adamax.py
@@ -41,13 +41,15 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adamax does not support sparse gradients')
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
-                    state['exp_inf'] = grad.new().resize_as_(grad).zero_()
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_inf'] = torch.zeros_like(p.data)
 
                 exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
                 beta1, beta2 = group['betas']
5 changes: 4 additions & 1 deletion torch/optim/asgd.py
@@ -1,4 +1,5 @@
 import math
+import torch
 from .optimizer import Optimizer
 
 
@@ -42,14 +43,16 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('ASGD does not support sparse gradients')
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     state['eta'] = group['lr']
                     state['mu'] = 1
-                    state['ax'] = grad.new().resize_as_(grad).zero_()
+                    state['ax'] = torch.zeros_like(p.data)
 
                 state['step'] += 1
 
9 changes: 6 additions & 3 deletions torch/optim/rmsprop.py
@@ -1,3 +1,4 @@
+import torch
 from .optimizer import Optimizer
 
 
@@ -50,16 +51,18 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('RMSprop does not support sparse gradients')
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = grad.new().resize_as_(grad).zero_()
+                    state['square_avg'] = torch.zeros_like(p.data)
                     if group['momentum'] > 0:
-                        state['momentum_buffer'] = grad.new().resize_as_(grad).zero_()
+                        state['momentum_buffer'] = torch.zeros_like(p.data)
                     if group['centered']:
-                        state['grad_avg'] = grad.new().resize_as_(grad).zero_()
+                        state['grad_avg'] = torch.zeros_like(p.data)
 
                 square_avg = state['square_avg']
                 alpha = group['alpha']
5 changes: 4 additions & 1 deletion torch/optim/rprop.py
@@ -1,4 +1,5 @@
 import math
+import torch
 from .optimizer import Optimizer
 
 
@@ -36,12 +37,14 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Rprop does not support sparse gradients')
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['prev'] = grad.new().resize_as_(grad).zero_()
+                    state['prev'] = torch.zeros_like(p.data)
                     state['step_size'] = grad.new().resize_as_(grad).fill_(group['lr'])
 
                 etaminus, etaplus = group['etas']
3 changes: 2 additions & 1 deletion torch/optim/sgd.py
@@ -1,3 +1,4 @@
+import torch
 from .optimizer import Optimizer, required
 
 
@@ -86,7 +87,7 @@ def step(self, closure=None):
                 if momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
-                        buf = param_state['momentum_buffer'] = p.data.new().resize_as_(p.data).zero_()
+                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                         buf.mul_(momentum).add_(d_p)
                     else:
                         buf = param_state['momentum_buffer']
95 changes: 95 additions & 0 deletions torch/optim/sparse_adam.py
@@ -0,0 +1,95 @@
+import math
+import torch
+from .optimizer import Optimizer
+
+
+class SparseAdam(Optimizer):
+    """Implements lazy version of Adam algorithm suitable for sparse tensors.
+
+    In this variant, only moments that show up in the gradient get updated, and
+    only those portions of the gradient get applied to the parameters.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
+        defaults = dict(lr=lr, betas=betas, eps=eps)
+        super(SparseAdam, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if not grad.is_sparse:
+                    raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                state['step'] += 1
+
+                grad = grad.coalesce()  # the update is non-linear so indices must be unique
+                grad_indices = grad._indices()
+                grad_values = grad._values()
+                size = grad.size()
+
+                def make_sparse(values):
+                    constructor = grad.new
+                    if grad_indices.dim() == 0 or values.dim() == 0:
+                        return constructor().resize_as_(grad)
+                    return constructor(grad_indices, values, size)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                # Decay the first and second moment running average coefficient
+                #      old <- b * old + (1 - b) * new
+                # <==> old += (1 - b) * (new - old)
+                old_exp_avg_values = exp_avg._sparse_mask(grad)._values()
+                exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
+                exp_avg.add_(make_sparse(exp_avg_update_values))
+                old_exp_avg_sq_values = exp_avg_sq._sparse_mask(grad)._values()
+                exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
+                exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
+
+                # Dense addition again is intended, avoiding another _sparse_mask
+                numer = exp_avg_update_values.add_(old_exp_avg_values)
+                denom = exp_avg_sq_update_values.add_(old_exp_avg_sq_values).sqrt_().add_(group['eps'])
+                del exp_avg_update_values, exp_avg_sq_update_values
+
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.add_(make_sparse(-step_size * numer.div_(denom)))
+
+        return loss
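
Usage sketch (not part of the commit): SparseAdam targets parameters that receive sparse gradients, most commonly an nn.Embedding created with sparse=True; step() above rejects dense gradients, which is also why the new test passes sparse_only=True and skips the dense reference optimizer. A minimal example, written in the current tensor API rather than the Variable-based style of this era:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    # Embedding configured to emit sparse gradients for its weight.
    embedding = nn.Embedding(num_embeddings=1000, embedding_dim=16, sparse=True)
    optimizer = optim.SparseAdam(embedding.parameters(), lr=1e-3)

    indices = torch.randint(0, 1000, (32,))
    loss = embedding(indices).sum()

    optimizer.zero_grad()
    loss.backward()      # embedding.weight.grad is a sparse tensor
    optimizer.step()     # only the rows referenced by `indices` are updated

On the second commit-message bullet: after the lazy moment update, the numerator and denominator are reassembled by adding the freshly computed update values back onto the old masked values as dense value tensors, rather than calling _sparse_mask on the updated state a second time — the "Dense addition again is intended" comment marks that choice.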
