
Prefix values/indices/sparse_mask/nnz with underscore (pytorch#1457)
As discussed in pytorch#1441.

I also added some docs giving clear guidance about coalescing in
sparse tensors.

Signed-off-by: Edward Z. Yang <[email protected]>
ezyang authored and soumith committed May 3, 2017
1 parent f273377 commit 743e489
Showing 7 changed files with 89 additions and 59 deletions.
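
In short, ``indices``/``values``/``nnz``/``sparse_mask`` become
``_indices``/``_values``/``_nnz``/``_sparse_mask``. A minimal sketch of the
new spellings (hypothetical usage, not part of the diff below; the constructor
is the same one the tests below exercise):

    import torch

    i = torch.LongTensor([[0, 1, 1],
                          [2, 0, 2]])
    v = torch.FloatTensor([3, 4, 5])
    x = torch.sparse.FloatTensor(i, v, torch.Size([2, 3]))

    # formerly x.indices(), x.values(), x.nnz()
    print(x._indices())
    print(x._values())
    print(x._nnz())
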
57 changes: 42 additions & 15 deletions docs/source/sparse.rst
@@ -12,9 +12,7 @@
efficiently store and process tensors for which the majority of elements
are zeros.

A sparse tensor is represented as a pair of dense tensors: a tensor
-which contains the actual values :class:`torch.sparse.values`, and a
-tensor which contains the coordinates of those values
-:class:`torch.sparse.indices`. A sparse tensor can be constructed
+of values and a tensor of indices. A sparse tensor can be constructed
by providing these two tensors, as well as the size of the sparse tensor
(which cannot be inferred from these tensors!)

@@ -46,34 +44,59 @@
An empty sparse tensor can be constructed by specifying its size:
and values:
[torch.FloatTensor with no dimension]

Sparse tensors can have duplicate entries for an index; such a tensor is
called non-coalesced. Duplicate entries are summed together when
coalescing (or converting to another representation). Some operations
(for example, :func:`torch.FloatTensor.add`) produce duplicate entries;
if you repeatedly perform these operations, you should coalesce your
sparse tensors to prevent them from growing too large.
.. note::

Our sparse tensor format permits *uncoalesced* sparse tensors, where
there may be duplicate coordinates in the indices; in this case,
the interpretation is that the value at that index is the sum of all
duplicate value entries. Uncoalesced tensors permit us to implement
certain operators more efficiently.

For the most part, you shouldn't have to care whether or not a
sparse tensor is coalesced, as most operations will work
identically given a coalesced or uncoalesced sparse tensor.
However, there are two cases in which you may need to care.

First, if you repeatedly perform an operation that can produce
duplicate entries (e.g., :func:`torch.sparse.FloatTensor.add`), you
should occasionally coalesce your sparse tensors to prevent
them from growing too large.
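
For illustration, a hypothetical snippet (not part of this change) showing
stored entries accumulating until a coalesce:

    import torch

    i = torch.LongTensor([[0, 1], [2, 0]])
    v = torch.FloatTensor([3, 4])
    x = torch.sparse.FloatTensor(i, v, torch.Size([2, 3]))

    y = x
    for _ in range(3):
        y = y.add(x)       # may leave duplicate coordinates in place
    print(y._nnz())        # can be larger than x._nnz()
    y = y.coalesce()       # duplicates are summed together
    print(y._nnz())        # back to the number of distinct coordinates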

Second, some operators will produce different values depending on
whether or not the sparse tensor is coalesced (e.g.,
:func:`torch.sparse.FloatTensor._values` and
:func:`torch.sparse.FloatTensor._indices`, as well as
:func:`torch.Tensor._sparse_mask`). These operators are
prefixed by an underscore to indicate that they reveal internal
implementation details and should be used with care, since code
that works with coalesced sparse tensors may not work with
uncoalesced sparse tensors; generally speaking, it is safest
to explicitly coalesce before working with these operators.

For example, suppose that we wanted to implement an operator
by operating directly on :func:`torch.sparse.FloatTensor._values`.
Multiplication by a scalar can be implemented in the obvious way,
as multiplication distributes over addition; however, square root
cannot be implemented directly, since ``sqrt(a + b) != sqrt(a) +
sqrt(b)`` (which is what would be computed if you were given an
uncoalesced tensor).
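
A hypothetical snippet (not part of this change) making the square-root
pitfall concrete:

    import torch

    # two entries at the same coordinate (0, 2); the logical value is 1 + 3 = 4
    i = torch.LongTensor([[0, 0],
                          [2, 2]])
    v = torch.FloatTensor([1.0, 3.0])
    x = torch.sparse.FloatTensor(i, v, torch.Size([3, 3]))

    print(x._values().sqrt())             # per-entry sqrt: [1.0, 1.73...] -- wrong
    print(x.coalesce()._values().sqrt())  # sqrt of the summed value: [2.0] -- intended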

.. class:: FloatTensor()

.. method:: add
.. method:: add_
.. method:: clone
.. method:: contiguous
.. method:: dim
.. method:: div
.. method:: div_
.. method:: get_device
.. method:: hspmm
.. method:: indices
.. method:: is_contiguous
.. method:: mm
.. method:: mul
.. method:: mul_
.. method:: nnz
.. method:: resizeAs_
.. method:: size
.. method:: spadd
.. method:: sparse_mask
.. method:: spmm
.. method:: sspaddmm
.. method:: sspmm
@@ -83,5 +106,9 @@
.. method:: toDense
.. method:: transpose
.. method:: transpose_
.. method:: values
.. method:: zero_
.. method:: coalesce
.. method:: is_coalesced
.. method:: _indices
.. method:: _values
.. method:: _nnz
16 changes: 8 additions & 8 deletions test/common.py
@@ -129,7 +129,7 @@ def safeCoalesce(self, t):
tc = t.coalesce()

value_map = {}
for idx, val in zip(t.indices().t(), t.values()):
for idx, val in zip(t._indices().t(), t._values()):
idx_tup = tuple(idx)
if idx_tup in value_map:
value_map[idx_tup] += val
@@ -138,16 +138,16 @@ def safeCoalesce(self, t):

new_indices = sorted(list(value_map.keys()))
new_values = [value_map[idx] for idx in new_indices]
if t.values().ndimension() < 2:
new_values = t.values().new(new_values)
if t._values().ndimension() < 2:
new_values = t._values().new(new_values)
else:
new_values = torch.stack(new_values)

new_indices = t.indices().new(new_indices).t()
new_indices = t._indices().new(new_indices).t()
tg = t.new(new_indices, new_values, t.size())

self.assertEqual(tc.indices(), tg.indices())
self.assertEqual(tc.values(), tg.values())
self.assertEqual(tc._indices(), tg._indices())
self.assertEqual(tc._values(), tg._values())

return tg

@@ -178,8 +178,8 @@ def assertTensorsEqual(a, b):
if x.is_sparse:
x = self.safeCoalesce(x)
y = self.safeCoalesce(y)
assertTensorsEqual(x.indices(), y.indices())
assertTensorsEqual(x.values(), y.values())
assertTensorsEqual(x._indices(), y._indices())
assertTensorsEqual(x._values(), y._values())
else:
assertTensorsEqual(x, y)
elif type(x) == str and type(y) == str:
52 changes: 26 additions & 26 deletions test/test_sparse.py
@@ -79,11 +79,11 @@ def assert_uncoalesced(self, x):
# field overwritten to a tensor of ones, coalesce it, and then
# check if any value entries are > 1 (which indicates that the
# original was uncoalesced.)
i = x.indices().clone()
v = x.values().clone().fill_(1)
i = x._indices().clone()
v = x._values().clone().fill_(1)
y = torch.sparse.DoubleTensor(i, v, x.size())
z = self.safeCoalesce(y)
assert (z.values() > 1).sum() > 0
assert (z._values() > 1).sum() > 0

def randn(self, *args, **kwargs):
"""
@@ -95,21 +95,21 @@ def randn(self, *args, **kwargs):
def test_basic(self):
x, i, v = self._gen_sparse(3, 10, 100)

self.assertEqual(i, x.indices())
self.assertEqual(v, x.values())
self.assertEqual(i, x._indices())
self.assertEqual(v, x._values())

x, i, v = self._gen_sparse(3, 10, [100, 100, 100])
self.assertEqual(i, x.indices())
self.assertEqual(v, x.values())
self.assertEqual(i, x._indices())
self.assertEqual(v, x._values())
self.assertEqual(x.ndimension(), 3)
self.assertEqual(x.coalesce().nnz(), 10)
self.assertEqual(x.coalesce()._nnz(), 10)
for i in range(3):
self.assertEqual(x.size(i), 100)

# Make sure we can access empty indices / values
x = self.SparseTensor()
self.assertEqual(x.indices().numel(), 0)
self.assertEqual(x.values().numel(), 0)
self.assertEqual(x._indices().numel(), 0)
self.assertEqual(x._values().numel(), 0)

def test_to_dense(self):
i = self.IndexTensor([
@@ -179,8 +179,8 @@ def test_contig(self):
])
exp_v = self.ValueTensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7])
x = self.safeCoalesce(x)
self.assertEqual(exp_i, x.indices())
self.assertEqual(exp_v, x.values())
self.assertEqual(exp_i, x._indices())
self.assertEqual(exp_v, x._values())

i = self.IndexTensor([
[2, 0, 2, 1],
@@ -197,8 +197,8 @@ def test_contig(self):
exp_v = self.ValueTensor([2, 1, 3, 4])

x = self.safeCoalesce(x)
self.assertEqual(exp_i, x.indices())
self.assertEqual(exp_v, x.values())
self.assertEqual(exp_i, x._indices())
self.assertEqual(exp_v, x._values())

# Duplicate indices
i = self.IndexTensor([
@@ -216,8 +216,8 @@ def test_contig(self):
exp_v = self.ValueTensor([6, 4])

x = self.safeCoalesce(x)
self.assertEqual(exp_i, x.indices())
self.assertEqual(exp_v, x.values())
self.assertEqual(exp_i, x._indices())
self.assertEqual(exp_v, x._values())

def test_contig_hybrid(self):
i = self.IndexTensor([
@@ -238,8 +238,8 @@ def test_contig_hybrid(self):
[3, 4], [5, 6], [9, 10], [8, 9], [7, 8],
])
x = self.safeCoalesce(x)
self.assertEqual(exp_i, x.indices())
self.assertEqual(exp_v, x.values())
self.assertEqual(exp_i, x._indices())
self.assertEqual(exp_v, x._values())

i = self.IndexTensor([
[2, 0, 2, 1],
@@ -256,8 +256,8 @@ def test_contig_hybrid(self):
exp_v = self.ValueTensor([[2, 2, 2], [1, 1, 1], [3, 3, 3], [4, 4, 4]])

x = self.safeCoalesce(x)
self.assertEqual(exp_i, x.indices())
self.assertEqual(exp_v, x.values())
self.assertEqual(exp_i, x._indices())
self.assertEqual(exp_v, x._values())

# Duplicate indices
i = self.IndexTensor([
@@ -275,8 +275,8 @@ def test_contig_hybrid(self):
exp_v = self.ValueTensor([[6, 4, 5], [4, 3, 4]])

x = self.safeCoalesce(x)
self.assertEqual(exp_i, x.indices())
self.assertEqual(exp_v, x.values())
self.assertEqual(exp_i, x._indices())
self.assertEqual(exp_v, x._values())

def test_transpose(self):
x = self._gen_sparse(4, 20, 5)[0]
@@ -463,8 +463,8 @@ def _test_basic_ops_shape(self, shape_i, shape_v=None):
self.assertTrue(y.is_coalesced())
self.assertEqual(x1, y)
# check that coalesce is out of place
y.values().add_(1)
self.assertEqual(z.values() + 1, y.values())
y._values().add_(1)
self.assertEqual(z._values() + 1, y._values())

def test_basic_ops(self):
self._test_basic_ops_shape([5, 6])
@@ -505,7 +505,7 @@ def _test_sparse_mask_fixed(self):
[17, 18, 19, 20],
])
exp_v = self.ValueTensor([7, 14, 14, 3, 20])
res = dense.sparse_mask(x)
res = dense._sparse_mask(x)
expected = self.SparseTensor(i, exp_v, torch.Size([5, 4]))
self.assertEqual(res, expected)

@@ -531,7 +531,7 @@ def _test_sparse_mask_hybrid_fixed(self):
[[13, 5], [14, 1], [15, 1], [16, 6]],
[[17, 7], [18, 2], [19, 7], [20, 1]],
])
res = dense.sparse_mask(x)
res = dense._sparse_mask(x)
exp_v = self.ValueTensor([[7, 9], [14, 1], [14, 1], [3, 3], [20, 1]])
expected = self.SparseTensor(i, exp_v, torch.Size([5, 4, 2]))
self.assertEqual(res, expected)
8 changes: 4 additions & 4 deletions torch/_utils.py
@@ -27,8 +27,8 @@ def _type(self, new_type=None, async=False):
raise RuntimeError("Cannot cast sparse tensor to dense tensor")
new_type_name = new_type.__module__ + '.' + new_type.__name__
new_values_type_name = new_type_name.replace('.sparse', '')
new_values = self.values().type(new_values_type_name, async)
return new_type(self.indices(), new_values, self.size())
new_values = self._values().type(new_values_type_name, async)
return new_type(self._indices(), new_values, self.size())
if new_type.is_sparse:
raise RuntimeError("Cannot cast dense tensor to sparse tensor")
return new_type(self.size()).copy_(self, async)
@@ -57,8 +57,8 @@ def _cuda(self, device=None, async=False):
with torch.cuda.device(device):
if self.is_sparse:
new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
indices = self.indices().cuda(device, async)
values = self.values().cuda(device, async)
indices = self._indices().cuda(device, async)
values = self._values().cuda(device, async)
return new_type(indices, values, self.size())
else:
new_type = getattr(torch.cuda, self.__class__.__name__)
5 changes: 4 additions & 1 deletion torch/csrc/generic/methods/SparseTensor.cwrap
@@ -55,6 +55,7 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)

[[
name: nnz
python_name: _nnz
sparse: yes
return: long
arguments:
@@ -72,6 +73,7 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)

[[
name: indices
python_name: _indices
sparse: yes
cname: newIndices
return: THIndexTensor*
@@ -81,6 +83,7 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)

[[
name: values
python_name: _values
sparse: yes
cname: newValues
return: THTensor*
@@ -414,7 +417,7 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
]]

[[
name: sparse_mask
name: _sparse_mask
cname: sparseMask
return: argument 0
arguments:
8 changes: 4 additions & 4 deletions torch/optim/adagrad.py
@@ -65,8 +65,8 @@ def step(self, closure=None):

if p.grad.data.is_sparse:
grad = grad.coalesce() # the update is non-linear so indices must be unique
grad_indices = grad.indices()
grad_values = grad.values()
grad_indices = grad._indices()
grad_values = grad._values()
size = torch.Size([x for x in grad.size()])

def make_sparse(values):
@@ -75,8 +75,8 @@ def make_sparse(values):
return constructor()
return constructor(grad_indices, values, size)
state['sum'].add_(make_sparse(grad_values.pow(2)))
std = state['sum'].sparse_mask(grad)
std_values = std.values().sqrt_().add_(1e-10)
std = state['sum']._sparse_mask(grad)
std_values = std._values().sqrt_().add_(1e-10)
p.data.add_(-clr, make_sparse(grad_values / std_values))
else:
state['sum'].addcmul_(1, grad, grad)
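
The ``grad.coalesce()`` call above is needed for the same reason the new docs
describe: squaring does not distribute over duplicate entries, so it must be
applied to the summed value at each index. A toy illustration (hypothetical,
not part of adagrad.py):

    import torch

    i = torch.LongTensor([[0, 0]])          # two entries at the same index
    v = torch.DoubleTensor([1.0, 3.0])
    g = torch.sparse.DoubleTensor(i, v, torch.Size([2]))

    print(g._values().pow(2).sum())             # 1 + 9 = 10
    print(g.coalesce()._values().pow(2).sum())  # (1 + 3) ** 2 = 16, matches the dense update
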
2 changes: 1 addition & 1 deletion torch/sparse/__init__.py
@@ -128,7 +128,7 @@ def __ixor__(self, other):

def __str__(self):
return '{} with indices:\n{}and values:\n{}'.format(
self.__class__.__name__, self.indices(), self.values())
self.__class__.__name__, self._indices(), self._values())


class DoubleTensor(_SparseBase, _C.SparseDoubleTensorBase, _TensorBase):
