Skip to content

Commit

Permalink
sparse_mean_variance_axis now uses all cores (scverse#3015)
Browse files Browse the repository at this point in the history
Co-authored-by: Philipp A <[email protected]>
  • Loading branch information
Intron7 and flying-sheep authored Apr 23, 2024
1 parent ee8505b commit a70582e
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 59 deletions.
1 change: 1 addition & 0 deletions benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"scanpy": [""],
"python-igraph": [""],
// "psutil": [""]
"pooch": [""],
},

// Combinations of libraries/python versions can be excluded/included
Expand Down
31 changes: 28 additions & 3 deletions benchmarks/benchmarks/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from typing import TYPE_CHECKING

import scanpy as sc
from scanpy.preprocessing._utils import _get_mean_var

from .utils import pbmc68k_reduced
from .utils import lung93k, pbmc3k, pbmc68k_reduced

if TYPE_CHECKING:
from anndata import AnnData
Expand All @@ -18,9 +19,22 @@
adata: AnnData


def setup():
def setup(*params: str):
"""Setup all tests.
The SparseDenseSuite below defines a parameter,
the other tests none.
"""
global adata
adata = pbmc68k_reduced()

if len(params) == 0 or params[0] == "pbmc68k_reduced":
adata = pbmc68k_reduced()
elif params[0] == "pbmc3k":
adata = pbmc3k()
elif params[0] == "lung93k":
adata = lung93k()
else:
raise ValueError(f"Unknown dataset {params[0]}")


def time_calculate_qc_metrics():
Expand Down Expand Up @@ -99,3 +113,14 @@ def time_scale():

def peakmem_scale():
    """Run ``sc.pp.scale`` on the module-level ``adata`` with ``max_value=10``."""
    sc.pp.scale(adata, max_value=10)


class SparseDenseSuite:
    """Benchmarks of ``_get_mean_var`` parametrized over dataset names.

    Each benchmark method accepts the parameter values positionally and
    ignores them; the matrix is read from the module-level ``adata``.
    """

    params = ["pbmc68k_reduced", "pbmc3k", "lung93k"]
    param_names = ["dataset"]

    def time_mean_var(self, *_dataset):
        """Call ``_get_mean_var`` on ``adata.X``."""
        _get_mean_var(adata.X)

    def peakmem_mean_var(self, *_dataset):
        """Call ``_get_mean_var`` on ``adata.X``."""
        _get_mean_var(adata.X)
37 changes: 31 additions & 6 deletions benchmarks/benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
from __future__ import annotations

from functools import cache
from typing import TYPE_CHECKING

import pooch

import scanpy as sc

if TYPE_CHECKING:
from anndata import AnnData

@cache
def _pbmc68k_reduced() -> AnnData:
    """Load ``sc.datasets.pbmc68k_reduced()``; ``@cache`` keeps the result for later calls."""
    return sc.datasets.pbmc68k_reduced()


def pbmc68k_reduced() -> AnnData:
    """Return a copy of the cached ``pbmc68k_reduced`` dataset."""
    return _pbmc68k_reduced().copy()


@cache
def _pbmc3k() -> AnnData:
    """Load ``sc.datasets.pbmc3k()``; ``@cache`` keeps the result for later calls."""
    return sc.datasets.pbmc3k()


def pbmc3k() -> AnnData:
    """Return a copy of the cached ``pbmc3k`` dataset."""
    return _pbmc3k().copy()


@cache
def _lung93k() -> AnnData:
    """Fetch the lung93k ``.h5ad`` file via ``pooch`` and read it.

    ``pooch.retrieve`` is given the expected MD5 hash of the download;
    ``@cache`` keeps the loaded object for later calls.
    """
    path = pooch.retrieve(
        url="https://figshare.com/ndownloader/files/45788454",
        known_hash="md5:4f28af5ff226052443e7e0b39f3f9212",
    )
    return sc.read_h5ad(path)


def lung93k() -> AnnData:
    """Return a copy of the cached ``lung93k`` dataset."""
    return _lung93k().copy()
1 change: 1 addition & 0 deletions docs/release-notes/1.10.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@

```{rubric} Performance
```
* `sparse_mean_variance_axis` now uses all cores for the calculations {pr}`3015` {smaller}`S Dicks`
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ source = [".", "**/site-packages"]
exclude_also = [
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
# https://github.com/numba/numba/issues/4268
"@numba.njit.*",
]

[tool.ruff.format]
Expand Down
103 changes: 53 additions & 50 deletions scanpy/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,78 +65,81 @@ def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int):
if axis == ax_minor:
return sparse_mean_var_major_axis(
mtx.data,
mtx.indices,
mtx.indptr,
major_len=shape[0],
minor_len=shape[1],
dtype=np.float64,
n_threads=numba.get_num_threads(),
)
else:
return sparse_mean_var_minor_axis(mtx.data, mtx.indices, *shape, np.float64)
return sparse_mean_var_minor_axis(
mtx.data,
mtx.indices,
mtx.indptr,
major_len=shape[0],
minor_len=shape[1],
n_threads=numba.get_num_threads(),
)


@numba.njit(cache=True, parallel=True)
def sparse_mean_var_minor_axis(
    data, indices, indptr, *, major_len, minor_len, n_threads
):
    """
    Computes mean and variance for a sparse matrix for the minor axis.

    Given arrays for a csr matrix, returns the means and variances for each
    column back.

    Parameters
    ----------
    data, indices, indptr
        The three arrays of the compressed sparse matrix.
    major_len
        Length of the major axis; the divisor applied to the per-column sums.
    minor_len
        Length of the minor axis; the length of both returned arrays.
    n_threads
        Number of row partitions, each accumulated into its own buffer row.

    Returns
    -------
    ``(means, variances)``, two float64 arrays of length ``minor_len``.
    The variances are divided by ``major_len`` (no degrees-of-freedom
    correction).
    """
    rows = len(indptr) - 1
    # One accumulator row per partition, so parallel iterations never write
    # to the same element.
    sums_minor = np.zeros((n_threads, minor_len))
    squared_sums_minor = np.zeros((n_threads, minor_len))
    means = np.zeros(minor_len)
    variances = np.zeros(minor_len)
    for i in numba.prange(n_threads):
        # Partition ``i`` handles rows i, i + n_threads, i + 2 * n_threads, ...
        for r in range(i, rows, n_threads):
            for j in range(indptr[r], indptr[r + 1]):
                minor_index = indices[j]
                # Entries pointing past the declared minor axis are ignored.
                if minor_index >= minor_len:
                    continue
                # Accumulate in float64, like the major-axis kernel, so that
                # squaring narrow dtypes does not lose precision.
                value = np.float64(data[j])
                sums_minor[i, minor_index] += value
                squared_sums_minor[i, minor_index] += value * value
    for c in numba.prange(minor_len):
        mean = sums_minor[:, c].sum() / major_len
        means[c] = mean
        variances[c] = squared_sums_minor[:, c].sum() / major_len - mean * mean
    # E[x^2] - E[x]^2 can come out slightly negative through rounding;
    # clamp to zero (``np.maximum`` still propagates NaN).
    return means, np.maximum(variances, 0.0)


@numba.njit(cache=True, parallel=True)
def sparse_mean_var_major_axis(data, indptr, *, major_len, minor_len, n_threads):
    """
    Computes mean and variance for a sparse array for the major axis.

    Given arrays for a csr matrix, returns the means and variances for each
    row back.

    Parameters
    ----------
    data, indptr
        Value and row-pointer arrays of the compressed sparse matrix.
    major_len
        Length of the major axis; the length of both returned arrays.
    minor_len
        Length of the minor axis; the divisor applied to the per-row sums.
    n_threads
        Number of row partitions processed by the parallel loop.

    Returns
    -------
    ``(means, variances)``, two float64 arrays of length ``major_len``.
    The variances are divided by ``minor_len`` (no degrees-of-freedom
    correction).
    """
    rows = len(indptr) - 1
    means = np.zeros(major_len)
    variances = np.zeros_like(means)

    for i in numba.prange(n_threads):
        # Partition ``i`` handles rows i, i + n_threads, i + 2 * n_threads, ...
        # Every row is written by exactly one partition.
        for r in range(i, rows, n_threads):
            sum_major = 0.0
            squared_sum_major = 0.0
            for j in range(indptr[r], indptr[r + 1]):
                value = np.float64(data[j])
                sum_major += value
                squared_sum_major += value * value
            # Implicit zeros add nothing to either sum, so dividing by the
            # full minor length accounts for them.
            mean = sum_major / minor_len
            means[r] = mean
            variances[r] = squared_sum_major / minor_len - mean * mean
    # E[x^2] - E[x]^2 can come out slightly negative through rounding;
    # clamp to zero (``np.maximum`` still propagates NaN).
    return means, np.maximum(variances, 0.0)


Expand Down

0 comments on commit a70582e

Please sign in to comment.