
Support broadcast_to on sparse COO tensors (pytorch#71073)
Summary:
Pull Request resolved: pytorch#71073

cc nikitaved pearu cpuhrsch

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D33645744

Pulled By: cpuhrsch

fbshipit-source-id: 4775c9636c4e868022a8c1bbfec93e351d1cf885
(cherry picked from commit 640f21e)
pearu authored and pytorchmergebot committed Jan 19, 2022
1 parent 9b9b878 commit 677fab6
Showing 4 changed files with 131 additions and 0 deletions.
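For orientation, a minimal usage sketch of the new behavior, assuming the op surface this PR adds (torch._sparse_broadcast_to) together with standard dense ops; this is an illustration, not code from the diff:

import torch

# Minimal sketch: torch._sparse_broadcast_to mirrors dense
# broadcast_to for sparse COO inputs.
d = torch.tensor([[0.0], [1.0]])             # dense 2x1
s = d.to_sparse()                            # sparse COO equivalent
out = torch._sparse_broadcast_to(s, (2, 3))  # broadcast the size-1 dim
assert torch.equal(out.to_dense(), d.broadcast_to((2, 3)))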
5 changes: 5 additions & 0 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -203,6 +203,11 @@ Tensor FunctionalInverses::values_inverse(const Tensor& base, const Tensor& muta
  return Tensor();
}

Tensor FunctionalInverses::_sparse_broadcast_to_inverse(const Tensor& base, const Tensor& mutated_view, at::IntArrayRef size) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::crow_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
76 changes: 76 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -87,6 +87,82 @@ Tensor& set_cpu_(Tensor& result) {
  return result;
}

Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
  TORCH_CHECK(self.is_sparse(), "input must be sparse tensor");
  int64_t sparse_extra_ndim = size.size() - self.dim();
  int64_t sparse_ndim = size.size() - self.dense_dim();
  TORCH_CHECK(sparse_extra_ndim >= 0, "input not broadcastable to size with smaller dimensionality");
  Tensor indices = self._indices();
  Tensor values = self._values();
  auto nnz = values.size(0);

  std::vector<int64_t> broadcast_sizes;
  std::vector<int64_t> broadcast_dense_sizes;
  std::vector<int64_t> broadcast_dims;
  std::vector<int64_t> unchanged_dims;
  broadcast_sizes.reserve(sparse_ndim);
  broadcast_dense_sizes.reserve(self.dense_dim() + 1);
  broadcast_dims.reserve(self.sparse_dim());
  unchanged_dims.reserve(self.sparse_dim());
  int64_t nnz_factor = 1;
  int64_t min_broadcast_dim = (sparse_extra_ndim > 0 ? 0 : -1);
  int64_t max_unchanged_dim = -1;
  for (int64_t i = 0; i < sparse_extra_ndim; i++) {
    auto d = size[i];
    nnz_factor *= d;
    broadcast_sizes.emplace_back(d);
  }
  for (int64_t i = 0; i < self.sparse_dim(); i++) {
    auto d = size[sparse_extra_ndim + i];
    if (self.size(i) != d) {
      TORCH_CHECK(self.size(i) == 1,
                  "The expanded size of the tensor (", size[sparse_extra_ndim + i], ") ",
                  "must match the existing size (", self.size(i), ")");
      nnz_factor *= d;
      broadcast_sizes.emplace_back(d);
      if (min_broadcast_dim == -1) {
        min_broadcast_dim = sparse_extra_ndim + i;
      }
      broadcast_dims.emplace_back(i);
    } else {
      unchanged_dims.emplace_back(i);
      max_unchanged_dim = sparse_extra_ndim + i;
    }
  }
  // broadcast_to preserves the is_coalesced property iff only the last
  // sparse dimensions are expanded. Possible expansion of dense
  // dimensions can be discarded as it does not affect the is_coalesced
  // property (see the Python sketch after this function).
  bool is_coalesced = self.dim() == 0 || (self.is_coalesced() && (max_unchanged_dim < min_broadcast_dim || min_broadcast_dim == -1));

  broadcast_dense_sizes.emplace_back(nnz);
  for (int64_t i = 0; i < self.dense_dim(); i++) {
    broadcast_dense_sizes.emplace_back(size[sparse_extra_ndim + self.sparse_dim() + i]);
  }

  std::vector<int64_t> new_indices_size{sparse_ndim, nnz * nnz_factor};
  std::vector<int64_t> new_values_size(values.sizes().vec());
  new_values_size[0] = new_indices_size[1];

  Tensor new_values = values.expand(broadcast_dense_sizes).repeat_interleave(nnz_factor, 0);
  Tensor new_indices = at::native::new_empty(indices, new_indices_size);
  if (broadcast_sizes.size() > 0) {
    // ones(broadcast_sizes).nonzero() is equivalent to the cartesian
    // product of arange(d) for each d in broadcast_sizes, but avoids
    // creating auxiliary arange tensors (see the Python sketch below)
    Tensor broadcast_indices = at::native::new_ones(indices, broadcast_sizes).nonzero().transpose(0, 1).tile(nnz);
    new_indices.narrow(0, 0, sparse_extra_ndim).copy_(broadcast_indices.narrow(0, 0, sparse_extra_ndim));
    for (size_t i = 0; i < broadcast_dims.size(); i++) {
      int64_t j = broadcast_dims[i];
      new_indices.select(0, sparse_extra_ndim + j).copy_(broadcast_indices.select(0, sparse_extra_ndim + i));
    }
  }
  for (int64_t j : unchanged_dims) {
    new_indices.select(0, sparse_extra_ndim + j).copy_(indices.select(0, j).repeat_interleave(nnz_factor));
  }
  return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced);
}
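A minimal Python sketch of the is_coalesced reasoning above, assuming only the public COO tensor API plus the op added in this PR:

import torch

# Broadcasting only the trailing sparse dimension keeps the indices in
# lexicographic order, so the coalesced flag survives; broadcasting a
# leading sparse dimension interleaves indices and the flag is dropped.
tail = torch._sparse_broadcast_to(torch.tensor([[1.0], [2.0]]).to_sparse(), (2, 3))
head = torch._sparse_broadcast_to(torch.tensor([[1.0, 2.0]]).to_sparse(), (3, 2))
print(tail.is_coalesced())  # True:  only the last sparse dim was expanded
print(head.is_coalesced())  # False: a leading sparse dim was expanded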

Tensor broadcast_to(const Tensor& self, IntArrayRef size) {
  return self.expand(size);
}
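The nonzero() trick in sparse_broadcast_to can be checked against an explicit cartesian product; a small sketch assuming only public tensor ops:

import torch

# ones(sizes).nonzero() enumerates every index tuple in lexicographic
# order, i.e. the cartesian product of arange(d) for each d in sizes,
# without materializing the individual arange tensors first.
broadcast_sizes = [2, 3]
via_nonzero = torch.ones(broadcast_sizes).nonzero().transpose(0, 1)
via_aranges = torch.cartesian_prod(
    *(torch.arange(d) for d in broadcast_sizes)).transpose(0, 1)
assert torch.equal(via_nonzero, via_aranges)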
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1070,6 +1070,11 @@
- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
  variants: function, method

- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
  variants: function
  dispatch:
    SparseCPU, SparseCUDA: sparse_broadcast_to

- func: cat(Tensor[] tensors, int dim=0) -> Tensor
  dispatch:
    CompositeExplicitAutograd: cat
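Since only sparse kernels are registered above, a dense input never reaches sparse_broadcast_to; a hedged sketch (the caught error type and its message are an assumption and may vary across builds):

import torch

# Dispatch sketch: only SparseCPU/SparseCUDA kernels exist for
# _sparse_broadcast_to, so a dense tensor is rejected at dispatch.
# The error type caught here is an assumption, not a documented API.
try:
    torch._sparse_broadcast_to(torch.ones(2, 1), (2, 3))
except (NotImplementedError, RuntimeError):
    print("no dense kernel registered for _sparse_broadcast_to")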
45 changes: 45 additions & 0 deletions test/test_sparse.py
@@ -3359,6 +3359,51 @@ def test_cpu_sparse_dense_mul(self, device):
        with self.assertRaisesRegex(NotImplementedError, "CUDA"):
            t23 * s

    @dtypes(torch.double, torch.cdouble)
    def test_full_broadcast_to(self, device, dtype):
        def can_broadcast(s0, s1):
            s0 = tuple(reversed(s0))
            s1 = tuple(reversed(s1))
            for i in range(len(s0)):
                if s0[i] != 1 and s0[i] != s1[i]:
                    return False
            return True
        sizes = (
            (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)
        )
        for s0, s1 in itertools.combinations(sizes, r=2):
            t = make_tensor(s0, device, dtype, low=-9, high=9)
            for sparse_dims in range(1, len(s0) + 1):
                s = t.to_sparse(sparse_dims)
                if can_broadcast(s0, s1):
                    t_res = torch.broadcast_to(t, s1)
                    s_res = torch._sparse_broadcast_to(s, s1)
                    torch._validate_sparse_coo_tensor_args(s_res._indices(), s_res._values(), s_res.shape)
                    if s_res.is_coalesced():
                        # ensure that is_coalesced is estimated correctly
                        self.assertEqual(s_res, torch.sparse_coo_tensor(s_res._indices(), s_res._values(), s_res.shape).coalesce())
                    self.assertEqual(s_res.to_dense(), t_res)
                else:
                    with self.assertRaisesRegex(RuntimeError,
                                                r"The expanded size of the tensor \(\d\) "
                                                r"must match the existing size \(\d\)"):
                        torch._sparse_broadcast_to(s, s1)

    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)
    def test_sparse_broadcast_to(self, device, dtype, coalesced):
        def test(sparse_dims, nnz, with_size, new_size):
            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
            y = self.safeToDense(x)
            x1 = torch._sparse_broadcast_to(x, new_size)
            y1 = y.broadcast_to(new_size)
            self.assertEqual(self.safeToDense(x1), y1)

        test(4, 6, [7, 3, 1, 3, 0], [7, 3, 4, 3, 0])
        test(4, 6, [7, 3, 1, 3, 0], [2, 7, 3, 1, 3, 0])
        test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
        test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])


class TestSparseOneOff(TestCase):
    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
