Skip to content

Commit

Permalink
speed improvements, fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Raison authored and soumith committed Apr 18, 2017
1 parent 655c225 commit 701e631
Show file tree
Hide file tree
Showing 26 changed files with 979 additions and 733 deletions.
28 changes: 25 additions & 3 deletions test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ def is_iterable(obj):
class TestCase(unittest.TestCase):
precision = 1e-5

def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
max_err = 0
self.assertEqual(x.size(), y.size())
for index in iter_indices(x):
max_err = max(max_err, abs(x[index] - y[index]))
self.assertLessEqual(max_err, prec, message)

def assertEqual(self, x, y, prec=None, message=''):
if prec is None:
prec = self.precision
Expand All @@ -132,10 +139,19 @@ def assertTensorsEqual(a, b):
if a.numel() > 0:
b = b.type_as(a)
b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
max_err = (a - b).abs().max()
# check that NaNs are in the same locations
nan_mask = a != a
self.assertTrue(torch.equal(nan_mask, b != b))
diff = a - b
diff[nan_mask] = 0
if diff.is_signed():
diff = diff.abs()
max_err = diff.max()
self.assertLessEqual(max_err, prec, message)
self.assertEqual(x.is_sparse, y.is_sparse, message)
if x.is_sparse:
x = x.clone().contiguous()
y = y.clone().contiguous()
assertTensorsEqual(x.indices(), y.indices())
assertTensorsEqual(x.values(), y.values())
else:
Expand Down Expand Up @@ -167,8 +183,14 @@ def assertNotEqual(self, x, y, prec=None, message=''):
self.assertGreater(x.numel(), 0)
y = y.type_as(x)
y = y.cuda(device=x.get_device()) if x.is_cuda else y.cpu()
max_err = (x - y).abs().max()
self.assertGreaterEqual(max_err, prec, message)
nan_mask = x != x
if torch.equal(nan_mask, y != y):
diff = x - y
if diff.is_signed():
diff = diff.abs()
diff[nan_mask] = 0
max_err = diff.max()
self.assertGreaterEqual(max_err, prec, message)
elif type(x) == str and type(y) == str:
super(TestCase, self).assertNotEqual(x, y)
elif is_iterable(x) and is_iterable(y):
Expand Down
3 changes: 1 addition & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2615,8 +2615,7 @@ def add_test(test):
constructor=lambda: nn.Embedding(4, 3, sparse=True),
input=Variable(torch.randperm(2).repeat(1, 2)),
jacobian_input=False,
fullname='Embedding_sparse',
test_cuda=False,
fullname='Embedding_sparse'
),
dict(
constructor=lambda: nn.FractionalMaxPool2d(
Expand Down
Loading

0 comments on commit 701e631

Please sign in to comment.