forked from scverse/scanpy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve accuracy and speed of _get_meanvar
- Loading branch information
Showing
2 changed files
with
161 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,134 @@ | ||
import numpy as np | ||
from scipy.sparse import issparse | ||
|
||
from scipy import sparse | ||
import numba | ||
|
||
STANDARD_SCALER_FIXED = False | ||
|
||
|
||
def _get_mean_var(X): | ||
# - using sklearn.StandardScaler throws an error related to | ||
# int to long trafo for very large matrices | ||
# - using X.multiply is slower | ||
if not STANDARD_SCALER_FIXED: | ||
mean = X.mean(axis=0) | ||
if issparse(X): | ||
mean_sq = X.multiply(X).mean(axis=0) | ||
mean = mean.A1 | ||
mean_sq = mean_sq.A1 | ||
else: | ||
mean_sq = np.multiply(X, X).mean(axis=0) | ||
# enforece R convention (unbiased estimator) for variance | ||
var = (mean_sq - mean ** 2) * (X.shape[0] / (X.shape[0] - 1)) | ||
if sparse.issparse(X): | ||
mean, var = sparse_mean_variance_axis(X, axis=0) | ||
else: | ||
from sklearn.preprocessing import StandardScaler | ||
if STANDARD_SCALER_FIXED: | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
scaler = StandardScaler(with_mean=False).partial_fit(X) | ||
mean = scaler.mean_ | ||
# enforce R convention (unbiased estimator) | ||
var = scaler.var_ * (X.shape[0] / (X.shape[0] - 1)) | ||
scaler = StandardScaler(with_mean=False).partial_fit(X) | ||
mean, var = scaler.mean_, scaler.var_ | ||
else: | ||
mean = X.mean(axis=0, dtype=np.float64) | ||
var = X.var(axis=0, dtype=np.float64) | ||
# enforce R convention (unbiased estimator) for variance | ||
var = var * (X.shape[0] / (X.shape[0] - 1)) | ||
return mean, var | ||
|
||
|
||
def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int): | ||
""" | ||
This code and internal functions are based on sklearns | ||
`sparsefuncs.mean_variance_axis`. | ||
Modifications: | ||
* allow deciding on the output type, which can increase accuracy when calculating the mean and variance of 32bit floats. | ||
* This doesn't currently implement support for null values, but could. | ||
* Uses numba not cython | ||
""" | ||
assert axis in (0, 1) | ||
if isinstance(mtx, sparse.csr_matrix): | ||
if axis == 0: | ||
return sparse_mean_var_minor_axis( | ||
mtx.data, mtx.indices, mtx.shape[0], mtx.shape[1], np.float64 | ||
) | ||
elif axis == 1: | ||
return sparse_mean_var_major_axis( | ||
mtx.data, | ||
mtx.indices, | ||
mtx.indptr, | ||
mtx.shape[0], | ||
mtx.shape[1], | ||
np.float64, | ||
) | ||
elif isinstance(mtx, sparse.csc_matrix): | ||
if axis == 0: | ||
return sparse_mean_var_major_axis( | ||
mtx.data, | ||
mtx.indices, | ||
mtx.indptr, | ||
mtx.shape[1], | ||
mtx.shape[0], | ||
np.float64, | ||
) | ||
elif axis == 1: | ||
return sparse_mean_var_minor_axis( | ||
mtx.data, mtx.indices, mtx.shape[1], mtx.shape[0], np.float64 | ||
) | ||
else: | ||
raise ValueError( | ||
"This function only works on sparse csr and csc matrices" | ||
) | ||
|
||
|
||
@numba.njit(cache=True) | ||
def sparse_mean_var_minor_axis(data, indices, major_len, minor_len, dtype): | ||
""" | ||
Computes mean and variance for a sparse matrix for the minor axis. | ||
Given arrays for a csr matrix, returns the means and variances for each | ||
column back. | ||
""" | ||
non_zero = indices.shape[0] | ||
|
||
means = np.zeros(minor_len, dtype=dtype) | ||
variances = np.zeros_like(means, dtype=dtype) | ||
|
||
counts = np.zeros(minor_len, dtype=np.int64) | ||
|
||
for i in range(non_zero): | ||
col_ind = indices[i] | ||
means[col_ind] += data[i] | ||
|
||
for i in range(minor_len): | ||
means[i] /= major_len | ||
|
||
for i in range(non_zero): | ||
col_ind = indices[i] | ||
diff = data[i] - means[col_ind] | ||
variances[col_ind] += diff * diff | ||
counts[col_ind] += 1 | ||
|
||
for i in range(minor_len): | ||
variances[i] += (major_len - counts[i]) * means[i] ** 2 | ||
variances[i] /= major_len | ||
|
||
return means, variances | ||
|
||
|
||
@numba.njit(cache=True) | ||
def sparse_mean_var_major_axis( | ||
data, indices, indptr, major_len, minor_len, dtype | ||
): | ||
""" | ||
Computes mean and variance for a sparse array for the major axis. | ||
Given arrays for a csr matrix, returns the means and variances for each | ||
row back. | ||
""" | ||
means = np.zeros(major_len, dtype=dtype) | ||
variances = np.zeros_like(means, dtype=dtype) | ||
|
||
for i in range(major_len): | ||
startptr = indptr[i] | ||
endptr = indptr[i + 1] | ||
counts = endptr - startptr | ||
|
||
for j in range(startptr, endptr): | ||
means[i] += data[j] | ||
means[i] /= minor_len | ||
|
||
for j in range(startptr, endptr): | ||
diff = data[j] - means[i] | ||
variances[i] += diff * diff | ||
|
||
variances[i] += (minor_len - counts) * means[i] ** 2 | ||
variances[i] /= minor_len | ||
|
||
return means, variances |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters