Improve accuracy and speed of _get_meanvar
ivirshup committed Oct 8, 2019
1 parent e46f89b commit 135bed1
Showing 2 changed files with 161 additions and 20 deletions.
scanpy/preprocessing/_utils.py (125 additions, 20 deletions)

The first hunk mixes removed and added lines; removed lines are prefixed with - and added lines with +.

@@ -1,29 +1,134 @@
 import numpy as np
-from scipy.sparse import issparse
+from scipy import sparse
+import numba

 STANDARD_SCALER_FIXED = False


 def _get_mean_var(X):
-    # - using sklearn.StandardScaler throws an error related to
-    #   int to long conversion for very large matrices
-    # - using X.multiply is slower
-    if not STANDARD_SCALER_FIXED:
-        mean = X.mean(axis=0)
-        if issparse(X):
-            mean_sq = X.multiply(X).mean(axis=0)
-            mean = mean.A1
-            mean_sq = mean_sq.A1
-        else:
-            mean_sq = np.multiply(X, X).mean(axis=0)
-        # enforce R convention (unbiased estimator) for variance
-        var = (mean_sq - mean ** 2) * (X.shape[0] / (X.shape[0] - 1))
+    if sparse.issparse(X):
+        mean, var = sparse_mean_variance_axis(X, axis=0)
     else:
-        from sklearn.preprocessing import StandardScaler
-
-        scaler = StandardScaler(with_mean=False).partial_fit(X)
-        mean = scaler.mean_
-        # enforce R convention (unbiased estimator)
-        var = scaler.var_ * (X.shape[0] / (X.shape[0] - 1))
+        if STANDARD_SCALER_FIXED:
+            from sklearn.preprocessing import StandardScaler
+
+            scaler = StandardScaler(with_mean=False).partial_fit(X)
+            mean, var = scaler.mean_, scaler.var_
+        else:
+            mean = X.mean(axis=0, dtype=np.float64)
+            var = X.var(axis=0, dtype=np.float64)
+    # enforce R convention (unbiased estimator) for variance
+    var = var * (X.shape[0] / (X.shape[0] - 1))
     return mean, var
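A note on the trailing rescaling: both branches produce the biased (divide-by-n) variance, and the n / (n - 1) factor converts it to the unbiased estimator, matching R's var() and numpy's ddof=1. A minimal sketch of that equivalence (illustration only, not part of the diff):

import numpy as np

X = np.random.rand(50, 3)
n = X.shape[0]
biased = X.var(axis=0)             # numpy's default divides by n
unbiased = biased * (n / (n - 1))  # the rescaling applied above
assert np.allclose(unbiased, X.var(axis=0, ddof=1))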


The remaining changes to this file are pure additions, shown below without diff markers.

def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int):
    """
    Compute mean and variance along one axis of a sparse matrix.

    This code and the internal functions are based on sklearn's
    `sparsefuncs.mean_variance_axis`.

    Modifications:
    * allows choosing the output dtype, which can increase accuracy
      when calculating the mean and variance of 32-bit floats
    * doesn't currently implement support for null values, but could
    * uses numba rather than cython
    """
    assert axis in (0, 1)
    if isinstance(mtx, sparse.csr_matrix):
        if axis == 0:
            # csr is compressed along rows, so columns are the minor axis
            return sparse_mean_var_minor_axis(
                mtx.data, mtx.indices, mtx.shape[0], mtx.shape[1], np.float64
            )
        elif axis == 1:
            return sparse_mean_var_major_axis(
                mtx.data,
                mtx.indices,
                mtx.indptr,
                mtx.shape[0],
                mtx.shape[1],
                np.float64,
            )
    elif isinstance(mtx, sparse.csc_matrix):
        if axis == 0:
            # csc is compressed along columns, so columns are the major axis
            return sparse_mean_var_major_axis(
                mtx.data,
                mtx.indices,
                mtx.indptr,
                mtx.shape[1],
                mtx.shape[0],
                np.float64,
            )
        elif axis == 1:
            return sparse_mean_var_minor_axis(
                mtx.data, mtx.indices, mtx.shape[1], mtx.shape[0], np.float64
            )
    else:
        raise ValueError(
            "This function only works on sparse csr and csc matrices"
        )
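The shape-argument swap reflects how compressed sparse formats store data: a csr matrix is compressed along rows, so column statistics (axis=0) must scatter over `indices` (the minor-axis kernel), while a csc matrix can walk `indptr` column by column (the major-axis kernel). A usage sketch against a dense reference (not part of the commit; the kernels return the biased variance, so it is compared with numpy's default ddof=0):

import numpy as np
from scipy import sparse

mtx = sparse.random(200, 50, density=0.1, format="csr", dtype=np.float64)
means, variances = sparse_mean_variance_axis(mtx, axis=0)

dense = mtx.toarray()
assert np.allclose(means, dense.mean(axis=0))
assert np.allclose(variances, dense.var(axis=0))  # implicit zeros included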


@numba.njit(cache=True)
def sparse_mean_var_minor_axis(data, indices, major_len, minor_len, dtype):
    """
    Computes mean and variance along the minor axis of a sparse matrix.

    Given the data and indices arrays of a csr matrix, returns the mean
    and variance of each column.
    """
    non_zero = indices.shape[0]

    means = np.zeros(minor_len, dtype=dtype)
    variances = np.zeros_like(means, dtype=dtype)

    counts = np.zeros(minor_len, dtype=np.int64)

    # First pass: sum the explicit values of each column, then divide by
    # the full column length so implicit zeros count toward the mean.
    for i in range(non_zero):
        col_ind = indices[i]
        means[col_ind] += data[i]

    for i in range(minor_len):
        means[i] /= major_len

    # Second pass: accumulate squared deviations of the explicit values,
    # tracking how many explicit entries each column has.
    for i in range(non_zero):
        col_ind = indices[i]
        diff = data[i] - means[col_ind]
        variances[col_ind] += diff * diff
        counts[col_ind] += 1

    # Each implicit zero deviates from the mean by -mean, contributing mean**2.
    for i in range(minor_len):
        variances[i] += (major_len - counts[i]) * means[i] ** 2
        variances[i] /= major_len

    return means, variances
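The correction in the final loop folds the implicit zeros into a two-pass variance: a column with mean m and k explicit entries out of n has n - k stored zeros, each deviating from the mean by -m, together adding (n - k) * m**2 to the sum of squared deviations. A toy check of that identity (illustration only):

import numpy as np

col = np.array([0.0, 0.0, 2.0, 4.0])   # n = 4, k = 2 explicit values
n = len(col)
explicit = col[col != 0]
k, m = len(explicit), col.mean()

ssd = ((explicit - m) ** 2).sum() + (n - k) * m ** 2
assert np.isclose(ssd / n, col.var())  # biased variance matches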


@numba.njit(cache=True)
def sparse_mean_var_major_axis(
    data, indices, indptr, major_len, minor_len, dtype
):
    """
    Computes mean and variance along the major axis of a sparse matrix.

    Given the data, indices, and indptr arrays of a csr matrix, returns
    the mean and variance of each row.
    """
    means = np.zeros(major_len, dtype=dtype)
    variances = np.zeros_like(means, dtype=dtype)

    for i in range(major_len):
        startptr = indptr[i]
        endptr = indptr[i + 1]
        counts = endptr - startptr

        # Mean over the full row, implicit zeros included.
        for j in range(startptr, endptr):
            means[i] += data[j]
        means[i] /= minor_len

        for j in range(startptr, endptr):
            diff = data[j] - means[i]
            variances[i] += diff * diff

        # Fold in the (minor_len - counts) implicit zeros, each of which
        # deviates from the mean by -mean.
        variances[i] += (minor_len - counts) * means[i] ** 2
        variances[i] /= minor_len

    return means, variances
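Why this improves accuracy, not just speed: the previous path used the one-pass formula E[x**2] - E[x]**2, which suffers catastrophic cancellation in float32 when the mean is large relative to the variance; the new kernels accumulate in float64 and use the numerically stabler two-pass form. A standalone illustration of the cancellation (not from the commit):

import numpy as np

rng = np.random.default_rng(0)
x = (1e4 + rng.random(100_000)).astype(np.float32)  # mean ~1e4, var ~1/12

# Roughly the old one-pass computation, kept in float32:
one_pass = np.mean(x * x, dtype=np.float32) - np.mean(x, dtype=np.float32) ** 2
# Two-pass with float64 accumulation, as the new code does:
two_pass = np.var(x, dtype=np.float64)

print(one_pass)  # far off; can even come out negative
print(two_pass)  # ~0.0833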
scanpy/tests/test_preprocessing.py (36 additions, 0 deletions)

All of the following lines are additions.

@@ -24,6 +24,42 @@ def test_log1p_chunked():
    assert np.allclose(ad3.X, ad.X)


def test_mean_var_sparse():
    from sklearn.utils.sparsefuncs import mean_variance_axis

    csr64 = sp.random(10000, 1000, format="csr", dtype=np.float64)
    csc64 = csr64.tocsc()

    # Test that we're equivalent to sklearn for 64 bit
    for mtx in (csr64, csc64):
        scm, scv = sc.pp._utils._get_mean_var(mtx)
        skm, skv = mean_variance_axis(mtx, 0)
        skv *= mtx.shape[0] / (mtx.shape[0] - 1)  # match the R convention

        assert np.allclose(scm, skm)
        assert np.allclose(scv, skv)

    csr32 = csr64.astype(np.float32)
    csc32 = csc64.astype(np.float32)

    # Test whether ours is more accurate for 32 bit
    for mtx32, mtx64 in [(csc32, csc64), (csr32, csr64)]:
        scm32, scv32 = sc.pp._utils._get_mean_var(mtx32)
        scm64, scv64 = sc.pp._utils._get_mean_var(mtx64)
        skm32, skv32 = mean_variance_axis(mtx32, 0)
        skm64, skv64 = mean_variance_axis(mtx64, 0)
        skv32 *= mtx32.shape[0] / (mtx32.shape[0] - 1)
        skv64 *= mtx64.shape[0] / (mtx64.shape[0] - 1)

        # Mean absolute deviation of the 32-bit results from the 64-bit ones
        m_resid_sc = np.mean(np.abs(scm64 - scm32))
        m_resid_sk = np.mean(np.abs(skm64 - skm32))
        v_resid_sc = np.mean(np.abs(scv64 - scv32))
        v_resid_sk = np.mean(np.abs(skv64 - skv32))

        assert m_resid_sc < m_resid_sk
        assert v_resid_sc < v_resid_sk
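For a quick interactive check outside pytest, the private helper can be called directly; a minimal sketch (assumes scanpy and scipy are installed, and that `sc.pp._utils` remains the module path used in this test):

import numpy as np
import scipy.sparse as sp
import scanpy as sc

mtx = sp.random(1000, 100, format="csr", dtype=np.float32)
mean, var = sc.pp._utils._get_mean_var(mtx)
print(mean.dtype, var.dtype)  # float64: the kernels accumulate in 64-bit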


def test_normalize_per_cell():
    adata = AnnData(
        np.array([[1, 0], [3, 0], [5, 6]]))
