Commit
Fix pytorch#1447: sparse_mask doesn't make sense with uncoalesced tensors (pytorch#1458)

* Make sparseMask error if mask is uncoalesced.

Fixes pytorch#1447.

Signed-off-by: Edward Z. Yang <[email protected]>

* Add test for sparse adagrad.

Previously, the sparse codepath was not exercised at all; this commit
adds a very simple test case "sparse Rosenbrock"; the idea is to do
Rosenbrock but then knock out one of the dimensions so that the
tensor is sparse.

Signed-off-by: Edward Z. Yang <[email protected]>
ezyang authored and soumith committed May 3, 2017
1 parent 4ec0435 commit 80c0a87
Showing 4 changed files with 61 additions and 10 deletions.
46 changes: 46 additions & 0 deletions test/test_optim.py
@@ -5,6 +5,7 @@
import torch.optim as optim
import torch.legacy.optim as old_optim
from torch.autograd import Variable
from torch import sparse

from common import TestCase, run_tests

@@ -58,6 +59,46 @@ def eval():

self.assertLessEqual(params.data.dist(solution), initial_dist)

def _test_rosenbrock_sparse(self, constructor):
params_t = torch.Tensor([1.5, 1.5])

params = Variable(torch.Tensor([1.5, 1.5]), requires_grad=True)
params_c = Variable(torch.Tensor([1.5, 1.5]), requires_grad=True)
optimizer = constructor([params])
optimizer_c = constructor([params_c])

solution = torch.Tensor([1, 1])
initial_dist = params.data.dist(solution)

def eval(params, sparse_grad, w):
optimizer.zero_grad()
loss = rosenbrock(params)
loss.backward()
# params.grad.data now has the computed gradient. Turn
# it into a sparse tensor
params.grad.data.copy_(drosenbrock(params.data))
# This is really goofy
if w:
i = torch.LongTensor([[0]])
v = torch.DoubleTensor([params.grad.data[0]])
else:
i = torch.LongTensor([[1]])
v = torch.DoubleTensor([params.grad.data[1]])
x = sparse.DoubleTensor(i, v, torch.Size([2]))
if sparse_grad:
params.grad.data = x
else:
params.grad.data = x.to_dense()
return loss

for i in range(2000):
w = torch.rand(1)[0] > 0.5
optimizer.step(functools.partial(eval, params, True, w))
optimizer_c.step(functools.partial(eval, params_c, False, w))
self.assertEqual(params.data, params_c.data)

self.assertLessEqual(params.data.dist(solution), initial_dist)

def _test_basic_cases_template(self, weight, bias, input, constructor):
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
@@ -236,6 +277,11 @@ def test_adagrad(self):
lr=1e-1)
)

def test_adagrad_sparse(self):
self._test_rosenbrock_sparse(
lambda params: optim.Adagrad(params, lr=1e-1)
)

def test_adamax(self):
self._test_rosenbrock(
lambda params: optim.Adamax(params, lr=1e-1),
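For context on the sparse code path this test now exercises: with a sparse gradient, Adagrad only needs to touch the positions that actually appear in the gradient, both when accumulating the squared-gradient state and when updating the parameters. A minimal hand-written sketch of that per-step update (plain Python; the function name and eps constant are made up for illustration, and this is not the optim.Adagrad source):

def adagrad_sparse_step(param, sparse_grad, state_sum, lr, eps=1e-10):
    # param and state_sum are dense lists of floats of equal length;
    # sparse_grad is a list of (index, value) pairs, i.e. a 1-D sparse gradient.
    for idx, g in sparse_grad:
        state_sum[idx] += g * g                               # accumulate g^2 only where the gradient is nonzero
        param[idx] -= lr * g / (state_sum[idx] ** 0.5 + eps)  # step scaled by 1 / sqrt(accumulated sum)

# With the test's setup, the "knocked out" Rosenbrock gradient touches only one of the
# two indices per step, e.g. adagrad_sparse_step(params, [(1, g1)], state, lr=0.1)
# leaves params[0] and state[0] untouched.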
23 changes: 13 additions & 10 deletions test/test_sparse.py
@@ -492,19 +492,19 @@ def _test_sparse_mask_shape(self, shape_i, shape_v=None):

def _test_sparse_mask_fixed(self):
i = self.IndexTensor([
- [1, 3, 3, 0, 4],
- [2, 1, 1, 2, 3],
+ [1, 3, 0, 4],
+ [2, 1, 2, 3],
])
- v = self.ValueTensor([1, 2, 3, 4, 5])
- x = self.SparseTensor(i, v, torch.Size([5, 4]))
+ v = self.ValueTensor([1, 2, 3, 4])
+ x = self.SparseTensor(i, v, torch.Size([5, 4])).coalesce()
dense = self.ValueTensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20],
])
- exp_v = self.ValueTensor([7, 14, 14, 3, 20])
+ exp_v = self.ValueTensor([7, 14, 3, 20])
res = dense._sparse_mask(x)
expected = self.SparseTensor(i, exp_v, torch.Size([5, 4]))
self.assertEqual(res, expected)
@@ -519,11 +519,14 @@ def test_sparse_mask(self):

def _test_sparse_mask_hybrid_fixed(self):
i = self.IndexTensor([
- [1, 3, 3, 0, 4],
- [2, 1, 1, 2, 3],
+ [1, 3, 0, 4],
+ [2, 1, 2, 3],
])
- v = self.ValueTensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
- x = self.SparseTensor(i, v, torch.Size([5, 4, 2]))
+ v = self.ValueTensor([[1, 2], [2, 3], [3, 4], [4, 5]])
+ # TODO: This is also testing that, if coalesce is a no-op,
+ # the indices don't get permuted. I don't know if we actually
+ # want to give this invariant.
+ x = self.SparseTensor(i, v, torch.Size([5, 4, 2])).coalesce()
dense = self.ValueTensor([
[[1, 3], [2, 2], [3, 3], [4, 2]],
[[5, 7], [6, 7], [7, 9], [8, 9]],
@@ -532,7 +535,7 @@ def _test_sparse_mask_hybrid_fixed(self):
[[17, 7], [18, 2], [19, 7], [20, 1]],
])
res = dense._sparse_mask(x)
- exp_v = self.ValueTensor([[7, 9], [14, 1], [14, 1], [3, 3], [20, 1]])
+ exp_v = self.ValueTensor([[7, 9], [14, 1], [3, 3], [20, 1]])
expected = self.SparseTensor(i, exp_v, torch.Size([5, 4, 2]))
self.assertEqual(res, expected)

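The fixtures above also pin down what _sparse_mask is supposed to compute: the result reuses the mask's indices (and hence its sparsity pattern) and takes its values from the dense operand at exactly those positions; in the 2-D case, dense[1][2] == 7, dense[3][1] == 14, dense[0][2] == 3 and dense[4][3] == 20, matching exp_v. A throwaway reference version for the 2-D, non-hybrid case (made-up name, plain nested lists, no error or coalescing checks):

def sparse_mask_reference(dense, mask_indices):
    # mask_indices is a pair of equal-length lists: the row indices and the
    # column indices of the mask's stored entries. The output keeps those
    # indices and reads its values out of `dense` at the same positions.
    rows, cols = mask_indices
    values = [dense[r][c] for r, c in zip(rows, cols)]
    return (rows, cols), values

# For the fixture above:
# sparse_mask_reference(dense_rows, ([1, 3, 0, 4], [2, 1, 2, 3]))
# returns values [7, 14, 3, 20].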
1 change: 1 addition & 0 deletions torch/lib/THCS/generic/THCSTensor.c
@@ -420,6 +420,7 @@ int THCSTensor_(checkGPU)(THCState *state, unsigned int nSparseTensors, unsigned
}

void THCTensor_(sparseMask)(THCState *state, THCSTensor *r_, THCTensor *t, THCSTensor *mask) {
THArgCheck(mask->coalesced, 2, "mask is uncoalesced");
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 2, 3, r_, mask, t));
if(!THCSTensor_(isSameSizeAsDense)(state, mask, t)) {
THError("sparseMask operands have incompatible sizes");
1 change: 1 addition & 0 deletions torch/lib/THS/generic/THSTensor.c
@@ -465,6 +465,7 @@ THSTensor *THSTensor_(newCoalesce)(THSTensor *self) {
}

void THTensor_(sparseMask)(THSTensor *r_, THTensor *t, THSTensor *mask) {
THArgCheck(mask->coalesced, 2, "mask is uncoalesced");
THSTensor_(resizeAs)(r_, mask);
if (mask->nnz == 0) {
THSTensor_(zero)(r_);
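As for the new THArgCheck itself: an uncoalesced sparse tensor can list the same position more than once (the old test fixture did exactly that, repeating index (3, 1)), so a sparse_mask built against it would duplicate entries and report a misleading nnz. A quick illustration using the same constructors as the tests above; the commented-out call is what the check now rejects, and the behavior notes are inferred from this change rather than captured output:

import torch
from torch import sparse

i = torch.LongTensor([[1, 3, 3, 0, 4],
                      [2, 1, 1, 2, 3]])               # position (3, 1) listed twice
v = torch.DoubleTensor([1, 2, 3, 4, 5])
mask = sparse.DoubleTensor(i, v, torch.Size([5, 4]))  # uncoalesced: 5 stored entries, 4 distinct positions

dense = torch.DoubleTensor(5, 4).fill_(1)

# dense._sparse_mask(mask)                 # now errors with "mask is uncoalesced"
res = dense._sparse_mask(mask.coalesce())  # fine: duplicates are merged before masking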
