Skip to content

Commit

Permalink
BUG: tighten condition for bsr eliminate_zeros fast path (scipy#9690)
Browse files Browse the repository at this point in the history
* BUG: tighten condition for bsr eliminate_zeros fast path

* move test as requested by CJ
  • Loading branch information
ExpHP authored and perimosocordiae committed Jan 21, 2019
1 parent 9053491 commit bc9d407
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
7 changes: 4 additions & 3 deletions scipy/sparse/bsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,16 +542,17 @@ def transpose(self, axes=None, copy=False):

def eliminate_zeros(self):
"""Remove zero elements in-place."""

if not self.nnz:
return # nothing to do

R,C = self.blocksize
M,N = self.shape

mask = (self.data != 0).reshape(-1,R*C).sum(axis=1) # nonzero blocks

nonzero_blocks = mask.nonzero()[0]

if len(nonzero_blocks) == 0:
return # nothing to do

self.data[:len(nonzero_blocks)] = self.data[nonzero_blocks]

# modifies self.indptr and self.indices *in place*
Expand Down
24 changes: 24 additions & 0 deletions scipy/sparse/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4239,6 +4239,30 @@ def test_eliminate_zeros(self):
assert_array_equal(asp.nnz, 3*4)
assert_array_equal(asp.todense(),bsp.todense())

# github issue #9687
def test_eliminate_zeros_all_zero(self):
np.random.seed(0)
m = bsr_matrix(np.random.random((12, 12)), blocksize=(2, 3))

# eliminate some blocks, but not all
m.data[m.data <= 0.9] = 0
m.eliminate_zeros()
assert_equal(m.nnz, 66)
assert_array_equal(m.data.shape, (11, 2, 3))

# eliminate all remaining blocks
m.data[m.data <= 1.0] = 0
m.eliminate_zeros()
assert_equal(m.nnz, 0)
assert_array_equal(m.data.shape, (0, 2, 3))
assert_array_equal(m.todense(), np.zeros((12,12)))

# test fast path
m.eliminate_zeros()
assert_equal(m.nnz, 0)
assert_array_equal(m.data.shape, (0, 2, 3))
assert_array_equal(m.todense(), np.zeros((12,12)))

def test_bsr_matvec(self):
A = bsr_matrix(arange(2*3*4*5).reshape(2*4,3*5), blocksize=(4,5))
x = arange(A.shape[1]).reshape(-1,1)
Expand Down

0 comments on commit bc9d407

Please sign in to comment.