Make is_sparse a property of MaskedTensor (pytorch#110725)
Fixes pytorch#104574

Since MaskedTensor is still a prototype, the BC-breaking nature of this change (is_sparse becomes an attribute rather than a method call) seems acceptable.

Locally tested (screenshot): https://github.com/pytorch/pytorch/assets/31798555/239e61ba-e0b9-4909-8c7a-0ce3869d7375
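For illustration, a minimal sketch of the user-facing difference (the example tensors are made up; masked_tensor is the prototype constructor from torch.masked):

import torch
from torch.masked import masked_tensor

data = torch.tensor([1., 2., 3., 4.])
mask = torch.tensor([True, False, True, False])
mt = masked_tensor(data, mask)

# Before this change: mt.is_sparse() was a method call.
# After this change: plain attribute access, matching torch.Tensor.is_sparse.
print(mt.is_sparse)  # False -- the underlying layout is strided, not COO/CSR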

Pull Request resolved: pytorch#110725
Approved by: https://github.com/cpuhrsch
janeyx99 authored and pytorchmergebot committed Oct 9, 2023
1 parent 6c8096e commit 2aa0ba3
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions torch/masked/maskedtensor/core.py
@@ -331,5 +331,6 @@ def is_sparse_csr(self):
         return self.layout == torch.sparse_csr
 
     # Update later to support more sparse layouts
+    @property
     def is_sparse(self):
         return self.is_sparse_coo() or self.is_sparse_csr()
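For reference, a standalone sketch of the layout check the new property combines, using a plain torch.Tensor in place of a MaskedTensor (the helper name is illustrative, not part of the patch):

import torch

def is_sparse_layout(t: torch.Tensor) -> bool:
    # Mirrors the property above: only sparse COO and CSR layouts count as sparse.
    return t.layout in (torch.sparse_coo, torch.sparse_csr)

print(is_sparse_layout(torch.eye(3)))                   # False (strided)
print(is_sparse_layout(torch.eye(3).to_sparse()))       # True  (sparse COO)
print(is_sparse_layout(torch.eye(3).to_sparse_csr()))   # True  (sparse CSR)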
6 changes: 3 additions & 3 deletions torch/masked/maskedtensor/reductions.py
@@ -46,7 +46,7 @@ def _torch_reduce_all(fn):
     def reduce_all(self):
         masked_fn = _get_masked_fn(fn)
         data = self.get_data()
-        mask = self.get_mask().values() if self.is_sparse() else self.get_mask()
+        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
         # When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the
         # element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts.
         # Therefore, this implementation calculates it using the strides.
@@ -67,7 +67,7 @@ def reduce_all(self):
             result_data = torch.sum(idx * stride)
 
         # we simply pass in the values for sparse COO/CSR tensors
-        elif self.is_sparse():
+        elif self.is_sparse:
             result_data = masked_fn(masked_tensor(data.values(), mask))
 
         else:
@@ -80,7 +80,7 @@ def reduce_all(self):
 
 def _torch_reduce_dim(fn):
     def reduce_dim(self, dim, keepdim=False, dtype=None):
-        if self.is_sparse():
+        if self.is_sparse:
             msg = (
                 f"The sparse version of {fn} is not implemented in reductions.\n"
                 "If you would like this operator to be supported, please file an issue for a feature request at "
