Speedup scrublet (scverse#3044)
Intron7 authored May 14, 2024
1 parent c26480e commit fb79fed
Showing 4 changed files with 23 additions and 31 deletions.
docs/release-notes/1.10.2.md (2 changes: 2 additions, 0 deletions)

@@ -19,4 +19,6 @@
 
 ```{rubric} Performance
 ```
+
 * `sparse_mean_variance_axis` now uses all cores for the calculations {pr}`3015` {smaller}`S Dicks`
+* Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
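For reference, the sped-up function is scanpy's standard doublet-detection entry point. A minimal usage sketch follows; the pbmc3k dataset and the gene filter are illustrative choices, not part of this commit:

```python
import scanpy as sc

adata = sc.datasets.pbmc3k()
sc.pp.filter_genes(adata, min_cells=3)

# Runs the Scrublet pipeline that this commit speeds up internally
# (HVG selection on a temporary layer, faster per-gene mean/variance).
sc.pp.scrublet(adata, random_state=0)

# Scrublet annotates each cell with a score and a boolean call.
print(adata.obs[["doublet_score", "predicted_doublet"]].head())
```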
scanpy/preprocessing/_scrublet/__init__.py (11 changes: 6 additions, 5 deletions)

@@ -199,10 +199,11 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
     pp.normalize_total(ad_obs)
 
     # HVG process needs log'd data.
-
-    logged = pp.log1p(ad_obs, copy=True)
-    pp.highly_variable_genes(logged)
-    ad_obs = ad_obs[:, logged.var["highly_variable"]].copy()
+    ad_obs.layers["log1p"] = ad_obs.X.copy()
+    pp.log1p(ad_obs, layer="log1p")
+    pp.highly_variable_genes(ad_obs, layer="log1p")
+    del ad_obs.layers["log1p"]
+    ad_obs = ad_obs[:, ad_obs.var["highly_variable"]].copy()
 
     # Simulate the doublets based on the raw expressions from the normalised
     # and filtered object.
@@ -214,7 +215,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling,
         random_seed=random_state,
     )
-
+    del ad_obs.layers["raw"]
     if log_transform:
         pp.log1p(ad_obs)
         pp.log1p(ad_sim)
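The change above drops the `pp.log1p(ad_obs, copy=True)` copy of the whole AnnData and instead stages the log-transformed values in a temporary layer that is deleted once the highly variable genes are chosen. A rough sketch of the same pattern against the public API, with an illustrative dataset choice:

```python
import scanpy as sc

adata = sc.datasets.pbmc3k()
sc.pp.filter_genes(adata, min_cells=3)
sc.pp.normalize_total(adata)

# Log-transform into a scratch layer; X keeps the normalized counts.
adata.layers["log1p"] = adata.X.copy()
sc.pp.log1p(adata, layer="log1p")

# HVG selection reads the layer, so no second AnnData object is created.
sc.pp.highly_variable_genes(adata, layer="log1p")
del adata.layers["log1p"]

adata = adata[:, adata.var["highly_variable"]].copy()
```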
scanpy/preprocessing/_scrublet/pipeline.py (11 changes: 7 additions, 4 deletions)

@@ -5,7 +5,9 @@
 import numpy as np
 from scipy import sparse
 
-from .sparse_utils import sparse_multiply, sparse_var, sparse_zscore
+from scanpy.preprocessing._utils import _get_mean_var
+
+from .sparse_utils import sparse_multiply, sparse_zscore
 
 if TYPE_CHECKING:
     from ..._utils import AnyRandom
@@ -20,7 +22,8 @@ def mean_center(self: Scrublet) -> None:
 
 
 def normalize_variance(self: Scrublet) -> None:
-    gene_stdevs = np.sqrt(sparse_var(self._counts_obs_norm, axis=0))
+    _, gene_vars = _get_mean_var(self._counts_obs_norm, axis=0)
+    gene_stdevs = np.sqrt(gene_vars)
     self._counts_obs_norm = sparse_multiply(self._counts_obs_norm.T, 1 / gene_stdevs).T
     if self._counts_sim_norm is not None:
         self._counts_sim_norm = sparse_multiply(
@@ -29,8 +32,8 @@ def normalize_variance(self: Scrublet) -> None:
 
 
 def zscore(self: Scrublet) -> None:
-    gene_means = self._counts_obs_norm.mean(0)
-    gene_stdevs = np.sqrt(sparse_var(self._counts_obs_norm, axis=0))
+    gene_means, gene_vars = _get_mean_var(self._counts_obs_norm, axis=0)
+    gene_stdevs = np.sqrt(gene_vars)
     self._counts_obs_norm = sparse_zscore(
         self._counts_obs_norm, gene_mean=gene_means, gene_stdev=gene_stdevs
     )
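The removed `sparse_var` helper squared a full copy of the sparse matrix to evaluate E[x^2] - E[x]^2, so every call paid for an extra copy; the new code gets the per-gene mean and variance from a single helper call. Below is a self-contained sketch of a one-pass mean/variance over CSC data; it is a hypothetical stand-in for illustration, not scanpy's internal `_get_mean_var`, which may normalize the variance differently:

```python
import numpy as np
from scipy import sparse


def mean_var_axis0(E: sparse.csc_matrix) -> tuple[np.ndarray, np.ndarray]:
    """Per-column (per-gene) mean and population variance in one pass over the nonzeros."""
    n_rows, n_cols = E.shape
    mean = np.zeros(n_cols)
    sq_mean = np.zeros(n_cols)
    for j in range(n_cols):
        col = E.data[E.indptr[j] : E.indptr[j + 1]]  # nonzero entries of column j
        mean[j] = col.sum() / n_rows  # implicit zeros contribute nothing to the sum
        sq_mean[j] = (col * col).sum() / n_rows
    return mean, sq_mean - mean**2


E = sparse.random(50, 10, density=0.3, format="csc", random_state=0)
mean, var = mean_var_axis0(E)
assert np.allclose(mean, E.toarray().mean(axis=0))
assert np.allclose(var, E.toarray().var(axis=0))
```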
scanpy/preprocessing/_scrublet/sparse_utils.py (30 changes: 8 additions, 22 deletions)

@@ -1,40 +1,28 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING
 
 import numpy as np
 from scipy import sparse
 
+from scanpy.preprocessing._utils import _get_mean_var
+
 from ..._utils import AnyRandom, get_random_state
 
 if TYPE_CHECKING:
     from numpy.typing import NDArray
 
 
-def sparse_var(
-    E: sparse.csr_matrix | sparse.csc_matrix,
-    *,
-    axis: Literal[0, 1],
-) -> NDArray[np.float64]:
-    """variance across the specified axis"""
-
-    mean_gene: NDArray[np.float64] = E.mean(axis=axis).A.squeeze()
-    tmp: sparse.csc_matrix | sparse.csr_matrix = E.copy()
-    tmp.data **= 2
-    return tmp.mean(axis=axis).A.squeeze() - mean_gene**2
-
-
 def sparse_multiply(
     E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64],
     a: float | int | NDArray[np.float64],
 ) -> sparse.csr_matrix | sparse.csc_matrix:
     """multiply each row of E by a scalar"""
 
     nrow = E.shape[0]
-    w = sparse.lil_matrix((nrow, nrow))
-    w.setdiag(a)
+    w = sparse.dia_matrix((a, 0), shape=(nrow, nrow), dtype=a.dtype)
     r = w @ E
-    if isinstance(r, (np.matrix, np.ndarray)):
+    if isinstance(r, np.ndarray):
         return sparse.csc_matrix(r)
     return r
 
@@ -46,11 +34,9 @@ def sparse_zscore(
     gene_stdev: NDArray[np.float64] | None = None,
 ) -> sparse.csr_matrix | sparse.csc_matrix:
     """z-score normalize each column of E"""
-
-    if gene_mean is None:
-        gene_mean = E.mean(0)
-    if gene_stdev is None:
-        gene_stdev = np.sqrt(sparse_var(E, axis=0))
+    if gene_mean is None or gene_stdev is None:
+        gene_mean, gene_stdevs = _get_mean_var(E, axis=0)
+        gene_stdev = np.sqrt(gene_stdevs)
     return sparse_multiply(np.asarray((E - gene_mean).T), 1 / gene_stdev).T
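In `sparse_multiply`, the per-row scaling matrix is now built directly in DIA format from the 1-D scale vector instead of allocating an LIL matrix and writing its diagonal entry by entry; the product is identical, only the construction cost changes. A small equivalence check with arbitrary matrix sizes:

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
E = sparse.random(1000, 200, density=0.1, format="csr", random_state=0)
a = rng.uniform(0.5, 2.0, size=E.shape[0])  # one scale factor per row

# Old construction: allocate an LIL matrix, then fill its diagonal.
w_lil = sparse.lil_matrix((E.shape[0], E.shape[0]))
w_lil.setdiag(a)

# New construction: treat the 1-D array as the main diagonal of a DIA matrix.
w_dia = sparse.dia_matrix((a, 0), shape=(E.shape[0], E.shape[0]), dtype=a.dtype)

# Both left-multiplications scale row i of E by a[i].
assert np.allclose((w_lil @ E).toarray(), (w_dia @ E).toarray())
assert np.allclose((w_dia @ E).toarray(), E.toarray() * a[:, None])
```

The DIA form simply wraps the existing array as a diagonal, so there is no per-element insertion step before the sparse matrix product.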
