From 4b757d85b015f028e7a283f509cbb77d82476754 Mon Sep 17 00:00:00 2001 From: Philipp A Date: Fri, 22 Mar 2024 17:22:57 +0100 Subject: [PATCH] (feat): pre-processing functions for `dask` with sparse chunks (#2856) * (chore): add dask sparse chunks creation * (feat): add dask summation * (refactor): `materialize_as_ndarray` needs to operate on individual dask arrays * (feat): `filter_genes` and `filter_cells` * (feat): normalization * (fix): `log1p` tests working * (refactor): clean up writing test * (refactor): use `da.chunk.sum` * (fix): remove `Client` * (refactor): remove unnecessary `count_nonzero` * (fix): change expected fail on sparse normalization * (fix): update comment * (feat): `_get_mean_var` dask * (feat): clean up tests for what should/should not work * (refactor): `_compat.sum` to `_utils.elem_sum` * (chore): add `elem_sum` test * (refactor): `elem_sum` -> `axis_sum` * (feat): add `scale` support * (fix): maintain dtype * (chore): add back condition * (fix): use `sum` when needed * (chore): release notes * (fix): don't use `mean_func` name twice * (chore): revert sparse-chunks-in-dask * (chore): type hint * (chore): check `test_compare_to_upstream` * (chore): remove comment * (chore): allow passing `dtype` arg in `axis_sum` * (fix): revert fixture changes * (refactor): cleaner with `array_type` conversion before if-then * (chore): clarify hvg support * (chore): handle array types better * (chore): clean up `materialize_as_ndarray` * (chore): fix typing/dispatch problem in 3.9 * (chore): `list` type -> `Callable` * (feat): `row_divide` for better division handling * (fix): use `tuple` for `ARRAY_TYPEXXXX` * (refactor): `mean_func` -> `axis_mean` + types * (chore): remove unnecessary aggregation * (fix): raise `ValueError` for summing over more than one axis * (fix): grammar * (fix): better type hints * (revert): use old `test_normalize_total` since we have `csr` * (revert): extraneous diff * (fix): try `Union` * (chore): add column division ability * (chore): add scale test * (fix): duplicate in release note * (refactor): guard clause + comments * (chore): add `out` check for `dask` * (chore): add `divisor` type hints * (fix): remove some erroneous diffs * (chore): `axis_{sum,mean}` type hint fixes * (refactor): generalize to scaling * (chore): remove erroneous comment * (chore): remove non-public API * (fix): import from `sc._utils` * (fix): `inidices` -> `indices` * (fix): remove erroneous `axis_sum` calls * (fix): return statements for `axis_scale` * (refactor): return out of `axis_sum` if `X._meta` is `np.ndarray` * (chore): comment fix * (fix): use `normalize_total` in HVG test for true reproducibility * (refactor): separate out `out` test for dask * (fix): correct chunking/rechunking behavior * (chore): add guard clause for `sparse` `out != X != None` in scaling * (fix): guard clause condition * (fix): try finishing `|` typing for 3.9 * (fix): call `register` to allow unions? 
* (fix): clarify warning * (feat): test for `max_value`/`zero_center` combos * (fix): allow setting of `X` in `scale_array` * (chore): add tests for `normalize` correctness * (fix): refactor for pure dask in `median` * (refactor): add clarifying condition * (chore): skip warning computations + tests * (fix): actually skip computation in `normalize_total` condition * (fix): actually skip in `filter_genes` + tests * (fix): use all-in-one median implementation * (refactor): remove erroneous dask warnings * (chore): add note about `exclude_highly_expressed` * (feat): `axis_scale` -> `axis_mul_or_truediv` * (feat): `allow_divide_by_zero` * (chore): add notes + type hints * Have hvg compute earlier and only once * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (refactor): make codecov better by removing dead code/refactoring * (fix): `np.clip` in dask does not take min/max as `kwargs` * Update docs/release-notes/1.11.0.md Co-authored-by: Isaac Virshup * (chore): move release note * (chore): remove erroneous comment --------- Co-authored-by: ilan-gold Co-authored-by: Isaac Virshup Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/release-notes/1.10.0.md | 1 + scanpy/_utils/__init__.py | 223 +++++++++++++++++- scanpy/preprocessing/_distributed.py | 4 +- .../preprocessing/_highly_variable_genes.py | 6 +- scanpy/preprocessing/_normalization.py | 66 ++++-- scanpy/preprocessing/_simple.py | 77 ++++-- scanpy/preprocessing/_utils.py | 22 +- scanpy/testing/_pytest/params.py | 16 +- scanpy/tests/test_highly_variable_genes.py | 24 +- scanpy/tests/test_normalization.py | 49 +++- scanpy/tests/test_pca.py | 8 +- scanpy/tests/test_preprocessing.py | 121 +++++++++- scanpy/tests/test_utils.py | 129 +++++++++- 13 files changed, 651 insertions(+), 95 deletions(-) diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index fab39ca94f..52d6be2d9c 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -27,6 +27,7 @@ * `scanpy.pp.calculate_qc_metrics` now allows `qc_vars` to be passed as a string {pr}`2859` {smaller}`N Teyssier` * {func}`scanpy.tl.leiden` and {func}`scanpy.tl.louvain` now store clustering parameters in the key provided by the `key_added` parameter instead of always writing to (or overwriting) a default key {pr}`2864` {smaller}`J Fan` * {func}`scanpy.pp.scale` now clips `np.ndarray` also at `- max_value` for zero-centering {pr}`2913` {smaller}`S Dicks` +* Support sparse chunks in dask {func}`~scanpy.pp.scale`, {func}`~scanpy.pp.normalize_total` and {func}`~scanpy.pp.highly_variable_genes` (`seurat` and `cell-ranger` tested) {pr}`2856` {smaller}`ilan-gold` ```{rubric} Docs ``` diff --git a/scanpy/_utils/__init__.py b/scanpy/_utils/__init__.py index a2254ad0cd..5acbc3482f 100644 --- a/scanpy/_utils/__init__.py +++ b/scanpy/_utils/__init__.py @@ -15,9 +15,18 @@ from contextlib import contextmanager from enum import Enum from functools import partial, singledispatch, wraps +from operator import mul, truediv from textwrap import dedent from types import MethodType, ModuleType -from typing import TYPE_CHECKING, Any, Callable, Literal, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + TypeVar, + Union, + overload, +) from weakref import WeakSet import numpy as np @@ -542,6 +551,218 @@ def _elem_mul_dask(x: DaskArray, y: DaskArray) -> DaskArray: return da.map_blocks(elem_mul, x, y) +Scaling_T = TypeVar("Scaling_T", DaskArray, np.ndarray) + 
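# --- Editor's sketch (not part of the patch): how the `axis_mul_or_truediv`
# helper defined just below is meant to be called once the patch is applied.
# The toy matrix and per-row factors are illustrative assumptions only.
import numpy as np
from operator import truediv
from scipy import sparse
from scanpy._utils import axis_mul_or_truediv

X = sparse.csr_matrix(np.array([[1.0, 2.0], [3.0, 4.0]]))
row_factors = np.array([2.0, 4.0])  # one scaling factor per row (axis=0)
# Divide each row of X by its factor; `out=X` requests in-place modification,
# which the sparse and ndarray paths support (the dask path requires out=None).
X_scaled = axis_mul_or_truediv(X, row_factors, axis=0, op=truediv, out=X)
# X_scaled.toarray() -> [[0.5, 1.0], [0.75, 1.0]]
# --- end editor's sketch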
+ +def broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T: + divisor = np.ravel(divisor) + if axis: + return divisor[None, :] + return divisor[:, None] + + +def check_op(op): + if op not in {truediv, mul}: + raise ValueError(f"{op} not one of truediv or mul") + + +@singledispatch +def axis_mul_or_truediv( + X: sparse.spmatrix, + scaling_array, + axis: Literal[0, 1], + op: Callable[[Any, Any], Any], + *, + allow_divide_by_zero: bool = True, + out: sparse.spmatrix | None = None, +) -> sparse.spmatrix: + check_op(op) + if out is not None: + if X.data is not out.data: + raise ValueError( + "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling." + ) + if not allow_divide_by_zero and op is truediv: + scaling_array = scaling_array.copy() + (scaling_array == 0) + + row_scale = axis == 0 + column_scale = axis == 1 + if row_scale: + + def new_data_op(x): + return op(x.data, np.repeat(scaling_array, np.diff(x.indptr))) + + elif column_scale: + + def new_data_op(x): + return op(x.data, scaling_array.take(x.indices, mode="clip")) + + if X.format == "csr": + indices = X.indices + indptr = X.indptr + if out is not None: + X.data = new_data_op(X) + return X + return sparse.csr_matrix( + (new_data_op(X), indices.copy(), indptr.copy()), shape=X.shape + ) + transposed = X.T + return axis_mul_or_truediv( + transposed, + scaling_array, + op=op, + axis=1 - axis, + out=transposed, + allow_divide_by_zero=allow_divide_by_zero, + ).T + + +@axis_mul_or_truediv.register(np.ndarray) +def _( + X: np.ndarray, + scaling_array: np.ndarray, + axis: Literal[0, 1], + op: Callable[[Any, Any], Any], + *, + allow_divide_by_zero: bool = True, + out: np.ndarray | None = None, +) -> np.ndarray: + check_op(op) + scaling_array = broadcast_axis(scaling_array, axis) + if op is mul: + return np.multiply(X, scaling_array, out=out) + if not allow_divide_by_zero: + scaling_array = scaling_array.copy() + (scaling_array == 0) + return np.true_divide(X, scaling_array, out=out) + + +def make_axis_chunks( + X: DaskArray, axis: Literal[0, 1], pad=True +) -> tuple[tuple[int], tuple[int]]: + if axis == 0: + return (X.chunks[axis], (1,)) + return ((1,), X.chunks[axis]) + + +@axis_mul_or_truediv.register(DaskArray) +def _( + X: DaskArray, + scaling_array: Scaling_T, + axis: Literal[0, 1], + op: Callable[[Any, Any], Any], + *, + allow_divide_by_zero: bool = True, + out: None = None, +) -> DaskArray: + check_op(op) + if out is not None: + raise TypeError( + "`out` is not `None`. Do not do in-place modifications on dask arrays." + ) + + import dask.array as da + + scaling_array = broadcast_axis(scaling_array, axis) + row_scale = axis == 0 + column_scale = axis == 1 + + if isinstance(scaling_array, DaskArray): + if (row_scale and not X.chunksize[0] == scaling_array.chunksize[0]) or ( + column_scale + and ( + ( + len(scaling_array.chunksize) == 1 + and X.chunksize[1] != scaling_array.chunksize[0] + ) + or ( + len(scaling_array.chunksize) == 2 + and X.chunksize[1] != scaling_array.chunksize[1] + ) + ) + ): + warnings.warn("Rechunking scaling_array in user operation", UserWarning) + scaling_array = scaling_array.rechunk(make_axis_chunks(X, axis)) + else: + scaling_array = da.from_array( + scaling_array, + chunks=make_axis_chunks(X, axis), + ) + return da.map_blocks( + axis_mul_or_truediv, + X, + scaling_array, + axis, + op, + meta=X._meta, + out=out, + allow_divide_by_zero=allow_divide_by_zero, + ) + + +@overload +def axis_sum( + X: sparse.spmatrix, + *, + axis: tuple[Literal[0, 1], ...] 
| Literal[0, 1] | None = None, + dtype: np.typing.DTypeLike | None = None, +) -> np.matrix: ... + + +@singledispatch +def axis_sum( + X: np.ndarray, + *, + axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None, + dtype: np.typing.DTypeLike | None = None, +) -> np.ndarray: + return np.sum(X, axis=axis, dtype=dtype) + + +@axis_sum.register(DaskArray) +def _( + X: DaskArray, + *, + axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None, + dtype: np.typing.DTypeLike | None = None, +) -> DaskArray: + import dask.array as da + + if dtype is None: + dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object) + + if isinstance(X._meta, np.ndarray) and not isinstance(X._meta, np.matrix): + return X.sum(axis=axis, dtype=dtype) + + def sum_drop_keepdims(*args, **kwargs): + kwargs.pop("computing_meta", None) + # masked operations on sparse chunks produce np.matrix results, which raise the same API issues handled here + if isinstance(X._meta, (sparse.spmatrix, np.matrix)) or isinstance( + args[0], (sparse.spmatrix, np.matrix) + ): + kwargs.pop("keepdims", None) + axis = kwargs["axis"] + if isinstance(axis, tuple): + if len(axis) != 1: + raise ValueError( + f"`axis_sum` can only sum over one axis when the `axis` arg is provided, but got {axis} instead" + ) + kwargs["axis"] = axis[0] + # returns a np.matrix normally, which is undesirable + return np.array(np.sum(*args, dtype=dtype, **kwargs)) + + def aggregate_sum(*args, **kwargs): + return np.sum(args[0], dtype=dtype, **kwargs) + + return da.reduction( + X, + sum_drop_keepdims, + aggregate_sum, + axis=axis, + dtype=dtype, + meta=np.array([], dtype=dtype), + ) + + @singledispatch def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray: """Checks values of X to ensure it is count data""" diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index 195550de17..c653ffadc2 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -31,9 +31,11 @@ def materialize_as_ndarray( def materialize_as_ndarray( - a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...], + a: DaskArray | ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...], ) -> tuple[np.ndarray] | np.ndarray: """Compute distributed arrays and convert them to numpy ndarrays.""" + if isinstance(a, DaskArray): + return a.compute() if not isinstance(a, tuple): return np.asarray(a) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index dad4d06160..08470e3c76 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -267,7 +267,7 @@ def _highly_variable_genes_single_batch( else: X = np.expm1(X) - mean, var = _get_mean_var(X) + mean, var = materialize_as_ndarray(_get_mean_var(X)) # now actually compute the dispersion mean[mean == 0] = 1e-12 # set entries equal to zero to small value dispersion = var / mean @@ -277,9 +277,7 @@ def _highly_variable_genes_single_batch( mean = np.log1p(mean) # all of the following quantities are "per-gene" here - df = pd.DataFrame( - dict(zip(["means", "dispersions"], materialize_as_ndarray((mean, dispersion)))) - ) + df = pd.DataFrame(dict(zip(["means", "dispersions"], (mean, dispersion)))) df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins) disp_stats = _get_disp_stats(df, flavor) diff --git a/scanpy/preprocessing/_normalization.py b/scanpy/preprocessing/_normalization.py index ea50686e60..3dd4e0b2a6 100644 --- 
a/scanpy/preprocessing/_normalization.py +++ b/scanpy/preprocessing/_normalization.py @@ -1,17 +1,24 @@ from __future__ import annotations +from operator import truediv from typing import TYPE_CHECKING, Literal from warnings import warn import numpy as np from scipy.sparse import issparse -from sklearn.utils import sparsefuncs from .. import logging as logg from .._compat import DaskArray, old_positionals -from .._utils import view_to_actual +from .._utils import axis_mul_or_truediv, axis_sum, view_to_actual from ..get import _get_obs_rep, _set_obs_rep +try: + import dask + import dask.array as da +except ImportError: + da = None + dask = None + if TYPE_CHECKING: from collections.abc import Iterable @@ -22,21 +29,30 @@ def _normalize_data(X, counts, after=None, copy: bool = False): X = X.copy() if copy else X if issubclass(X.dtype.type, (int, np.integer)): X = X.astype(np.float32) # TODO: Check if float64 should be used - if isinstance(counts, DaskArray): - counts_greater_than_zero = counts[counts > 0].compute_chunk_sizes() - else: - counts_greater_than_zero = counts[counts > 0] + if after is None: + if isinstance(counts, DaskArray): + + def nonzero_median(x): + return np.ma.median(np.ma.masked_array(x, x == 0)).item() - after = np.median(counts_greater_than_zero, axis=0) if after is None else after - counts += counts == 0 + after = da.from_delayed( + dask.delayed(nonzero_median)(counts), + shape=(), + meta=counts._meta, + dtype=counts.dtype, + ) + else: + counts_greater_than_zero = counts[counts > 0] + after = np.median(counts_greater_than_zero, axis=0) counts = counts / after - if issparse(X): - sparsefuncs.inplace_row_scale(X, 1 / counts) - elif isinstance(counts, np.ndarray): - np.divide(X, counts[:, None], out=X) - else: - X = np.divide(X, counts[:, None]) # dask does not support kwarg "out" - return X + return axis_mul_or_truediv( + X, + counts, + op=truediv, + out=X if isinstance(X, np.ndarray) or issparse(X) else None, + allow_divide_by_zero=False, + axis=0, + ) @old_positionals( @@ -78,6 +94,11 @@ def normalize_total( Similar functions are used, for example, by Seurat [Satija15]_, Cell Ranger [Zheng17]_ or SPRING [Weinreb17]_. + .. note:: + When used with a :class:`~dask.array.Array` in `adata.X`, this function will trigger + blocking `.compute()` calls on the :class:`~dask.array.Array` if `exclude_highly_expressed` + is `True`, `layer_norm` is not `None`, or `key_added` is not `None`. + Params ------ adata @@ -92,7 +113,8 @@ normalization factor (size factor) for each cell. A gene is considered highly expressed if it has more than `max_fraction` of the total counts in at least one cell. The not-excluded genes will sum up to - `target_sum`. + `target_sum`. Providing this argument when `adata.X` is a :class:`~dask.array.Array` + will incur blocking `.compute()` calls on the array. 
max_fraction If `exclude_highly_expressed=True`, consider genes as highly expressed that have more counts than `max_fraction` of the original total counts @@ -187,27 +209,27 @@ gene_subset = None msg = "normalizing counts per cell" + + counts_per_cell = axis_sum(X, axis=1) if exclude_highly_expressed: - counts_per_cell = X.sum(1) # original counts per cell counts_per_cell = np.ravel(counts_per_cell) # at least one cell has more than max_fraction of counts per cell - gene_subset = (X > counts_per_cell[:, None] * max_fraction).sum(0) + gene_subset = axis_sum((X > counts_per_cell[:, None] * max_fraction), axis=0) gene_subset = np.asarray(np.ravel(gene_subset) == 0) msg += ( ". The following highly-expressed genes are not considered during " f"normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}" ) - counts_per_cell = X[:, gene_subset].sum(1) - else: - counts_per_cell = X.sum(1) + counts_per_cell = axis_sum(X[:, gene_subset], axis=1) + start = logg.info(msg) counts_per_cell = np.ravel(counts_per_cell) cell_subset = counts_per_cell > 0 - if not np.all(cell_subset): + if not isinstance(cell_subset, DaskArray) and not np.all(cell_subset): warn(UserWarning("Some cells have zero counts")) if inplace: diff --git a/scanpy/preprocessing/_simple.py b/scanpy/preprocessing/_simple.py index 2ab6f31370..a8099f52e7 100644 --- a/scanpy/preprocessing/_simple.py +++ b/scanpy/preprocessing/_simple.py @@ -7,6 +7,7 @@ import warnings from functools import singledispatch +from operator import truediv from typing import TYPE_CHECKING, Literal import numba @@ -18,11 +19,13 @@ from sklearn.utils import check_array, sparsefuncs from .. import logging as logg -from .._compat import old_positionals +from .._compat import DaskArray, old_positionals from .._settings import settings as sett from .._utils import ( AnyRandom, _check_array_function_arguments, + axis_mul_or_truediv, + axis_sum, renamed_arg, sanitize_anndata, view_to_actual, ) @@ -51,7 +54,7 @@ "min_counts", "min_genes", "max_counts", "max_genes", "inplace", "copy" ) def filter_cells( - data: AnnData | spmatrix | np.ndarray, + data: AnnData | spmatrix | np.ndarray | DaskArray, *, min_counts: int | None = None, min_genes: int | None = None, @@ -163,7 +166,7 @@ def filter_cells( X = data # proceed with processing the data matrix min_number = min_counts if min_genes is None else min_genes max_number = max_counts if max_genes is None else max_genes - number_per_cell = np.sum( + number_per_cell = axis_sum( X if min_genes is None and max_genes is None else X > 0, axis=1 ) if issparse(X): @@ -173,7 +176,7 @@ if max_number is not None: cell_subset = number_per_cell <= max_number - s = materialize_as_ndarray(np.sum(~cell_subset)) + s = axis_sum(~cell_subset) if s > 0: msg = f"filtered out {s} cells that have " if min_genes is not None or min_counts is not None: @@ -198,7 +201,7 @@ "min_counts", "min_cells", "max_counts", "max_cells", "inplace", "copy" ) def filter_genes( - data: AnnData | spmatrix | np.ndarray, + data: AnnData | spmatrix | np.ndarray | DaskArray, *, min_counts: int | None = None, min_cells: int | None = None, @@ -279,7 +282,7 @@ X = data # proceed with processing the data matrix min_number = min_counts if min_cells is None else min_cells max_number = max_counts if max_cells is None else max_cells - number_per_gene = np.sum( + number_per_gene = axis_sum( X if min_cells is None and max_cells is None else X > 0, axis=0 ) if issparse(X): @@ -289,7 +292,7 @@ def 
filter_genes( if max_number is not None: gene_subset = number_per_gene <= max_number - s = np.sum(~gene_subset) + s = axis_sum(~gene_subset) if s > 0: msg = f"filtered out {s} genes that are detected " if min_cells is not None or min_counts is not None: @@ -750,7 +753,7 @@ def _regress_out_chunk(data): @old_positionals("zero_center", "max_value", "copy", "layer", "obsm") @singledispatch def scale( - data: AnnData | spmatrix | np.ndarray, + data: AnnData | spmatrix | np.ndarray | DaskArray, *, zero_center: bool = True, max_value: float | None = None, @@ -758,7 +761,7 @@ def scale( layer: str | None = None, obsm: str | None = None, mask_obs: NDArray[np.bool_] | str | None = None, -) -> AnnData | spmatrix | np.ndarray | None: +) -> AnnData | spmatrix | np.ndarray | DaskArray | None: """\ Scale data to unit variance and zero mean. @@ -817,15 +820,23 @@ def scale( @scale.register(np.ndarray) +@scale.register(DaskArray) def scale_array( - X: np.ndarray, + X: np.ndarray | DaskArray, *, zero_center: bool = True, max_value: float | None = None, copy: bool = False, return_mean_std: bool = False, mask_obs: NDArray[np.bool_] | None = None, -) -> np.ndarray | tuple[np.ndarray, NDArray[np.float64], NDArray[np.float64]]: +) -> ( + np.ndarray + | DaskArray + | tuple[ + np.ndarray | DaskArray, NDArray[np.float64] | DaskArray, NDArray[np.float64] | DaskArray + ] +): if copy: X = X.copy() if mask_obs is not None: @@ -860,22 +871,40 @@ def scale_array( mean, var = _get_mean_var(X) std = np.sqrt(var) std[std == 0] = 1 - if issparse(X): - if zero_center: - raise ValueError("Cannot zero-center sparse matrix.") - sparsefuncs.inplace_column_scale(X, 1 / std) - else: - if zero_center: - X -= mean - X /= std + if zero_center: + if isinstance(X, DaskArray) and issparse(X._meta): + warnings.warn( + "zero-center being used with `DaskArray` sparse chunks. This can be bad if you have large chunks or intend to eventually read the whole data into memory.", + UserWarning, + ) + X -= mean + X = axis_mul_or_truediv( + X, + std, + op=truediv, + out=X if isinstance(X, np.ndarray) or issparse(X) else None, + axis=1, + ) # do the clipping if max_value is not None: logg.debug(f"... 
clipping at max_value {max_value}") - if zero_center: - X = np.clip(X, a_min=-max_value, a_max=max_value) + if isinstance(X, DaskArray) and issparse(X._meta): + + def clip_set(x): + x = x.copy() + x[x > max_value] = max_value + if zero_center: + x[x < -max_value] = -max_value + return x + + X = da.map_blocks(clip_set, X) else: - X[X > max_value] = max_value + if zero_center: + a_min, a_max = -max_value, max_value + X = np.clip(X, a_min, a_max) # dask does not accept these as kwargs + else: + X[X > max_value] = max_value if return_mean_std: return X, mean, std else: @@ -1084,7 +1113,7 @@ def _downsample_per_cell(X, counts_per_cell, random_state, replace): original_type = type(X) if not isspmatrix_csr(X): X = csr_matrix(X) - totals = np.ravel(X.sum(axis=1)) # Faster for csr matrix + totals = np.ravel(axis_sum(X, axis=1)) # Faster for csr matrix under_target = np.nonzero(totals > counts_per_cell)[0] rows = np.split(X.data, X.indptr[1:-1]) for rowidx in under_target: @@ -1100,7 +1129,7 @@ def _downsample_per_cell(X, counts_per_cell, random_state, replace): if original_type is not csr_matrix: # Put it back X = original_type(X) else: - totals = np.ravel(X.sum(axis=1)) + totals = np.ravel(axis_sum(X, axis=1)) under_target = np.nonzero(totals > counts_per_cell)[0] for rowidx in under_target: row = X[rowidx, :] diff --git a/scanpy/preprocessing/_utils.py b/scanpy/preprocessing/_utils.py index b9f6eb131c..4220503e6e 100644 --- a/scanpy/preprocessing/_utils.py +++ b/scanpy/preprocessing/_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import singledispatch from typing import TYPE_CHECKING, Literal import numba @@ -7,11 +8,26 @@ from scipy import sparse from sklearn.random_projection import sample_without_replacement -from .._utils import AnyRandom, _SupportedArray, elem_mul +from .._utils import AnyRandom, _SupportedArray, axis_sum, elem_mul if TYPE_CHECKING: from numpy.typing import NDArray + from .._compat import DaskArray + + +@singledispatch +def axis_mean( + X: DaskArray, *, axis: Literal[0, 1], dtype: np.typing.DTypeLike +) -> DaskArray: + total = axis_sum(X, axis=axis, dtype=dtype) + return total / X.shape[axis] + + +@axis_mean.register(np.ndarray) +def _(X: np.ndarray, *, axis: Literal[0, 1], dtype: np.typing.DTypeLike) -> np.ndarray: + return X.mean(axis=axis, dtype=dtype) + def _get_mean_var( X: _SupportedArray, *, axis: Literal[0, 1] = 0 @@ -19,8 +35,8 @@ def _get_mean_var( if isinstance(X, sparse.spmatrix): mean, var = sparse_mean_variance_axis(X, axis=axis) else: - mean = X.mean(axis=axis, dtype=np.float64) - mean_sq = elem_mul(X, X).mean(axis=axis, dtype=np.float64) + mean = axis_mean(X, axis=axis, dtype=np.float64) + mean_sq = axis_mean(elem_mul(X, X), axis=axis, dtype=np.float64) var = mean_sq - mean**2 # enforce R convention (unbiased estimator) for variance var *= X.shape[axis] / (X.shape[axis] - 1) diff --git a/scanpy/testing/_pytest/params.py b/scanpy/testing/_pytest/params.py index 655daf6c7a..6acd44c8a0 100644 --- a/scanpy/testing/_pytest/params.py +++ b/scanpy/testing/_pytest/params.py @@ -8,7 +8,10 @@ from anndata.tests.helpers import asarray from scipy import sparse -from .._helpers import as_dense_dask_array, as_sparse_dask_array +from .._helpers import ( + as_dense_dask_array, + as_sparse_dask_array, +) from .._pytest.marks import needs if TYPE_CHECKING: @@ -63,21 +66,16 @@ def param_with( at for (_, spsty), ats in MAP_ARRAY_TYPES.items() if spsty == "dense" for at in ats ) ARRAY_TYPES_SPARSE = tuple( - at for (_, spsty), ats in 
MAP_ARRAY_TYPES.items() if spsty == "dense" for at in ats + at for (_, spsty), ats in MAP_ARRAY_TYPES.items() if "sparse" in spsty for at in ats ) - -ARRAY_TYPES_SUPPORTED = tuple( +ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED = tuple( ( param_with(at, marks=[pytest.mark.xfail(reason="sparse-in-dask not supported")]) - if attrs == ("dask", "sparse") + if attrs[0] == "dask" and "sparse" in attrs[1] else at ) for attrs, ats in MAP_ARRAY_TYPES.items() for at in ats ) -""" -Sparse matrices in dask arrays aren’t officially supported upstream, -so add xfail to them. -""" ARRAY_TYPES = tuple(at for ats in MAP_ARRAY_TYPES.values() for at in ats) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 27b237ad5b..9ef3ab5945 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -2,7 +2,7 @@ from pathlib import Path from string import ascii_letters -from typing import Literal +from typing import Callable, Literal import numpy as np import pandas as pd @@ -15,7 +15,7 @@ from scanpy.testing._helpers import _check_check_values_warnings from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.marks import needs -from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED +from scanpy.testing._pytest.params import ARRAY_TYPES FILE = Path(__file__).parent / Path("_scripts/seurat_hvg.csv") FILE_V3 = Path(__file__).parent / Path("_scripts/seurat_hvg_v3.csv.gz") @@ -85,7 +85,7 @@ def test_no_batch_matches_batch(adata): @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_no_inplace(adata, array_type, batch_key): """Tests that, with `n_top_genes=None` the returned dataframe has the expected columns.""" adata.X = array_type(adata.X) @@ -338,12 +338,14 @@ def test_pearson_residuals_batch(pbmc3k_parametrized_small, subset, n_top_genes) ), ], ) -def test_compare_to_upstream( +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_compare_to_upstream( # noqa: PLR0917 request: pytest.FixtureRequest, func: Literal["hvg", "fgd"], flavor: Literal["seurat", "cell_ranger"], params: dict[str, float | int], ref_path: Path, + array_type: Callable, ): if func == "fgd" and flavor == "cell_ranger": msg = "The deprecated filter_genes_dispersion behaves differently with cell_ranger" @@ -352,9 +354,11 @@ def test_compare_to_upstream( pbmc = pbmc68k_reduced() pbmc.X = pbmc.raw.X + pbmc.X = array_type(pbmc.X) pbmc.var_names_make_unique() + sc.pp.filter_cells(pbmc, min_counts=1) + sc.pp.normalize_total(pbmc, target_sum=1e4) - sc.pp.normalize_per_cell(pbmc, counts_per_cell_after=1e4) if func == "hvg": sc.pp.log1p(pbmc) sc.pp.highly_variable_genes(pbmc, flavor=flavor, **params, inplace=True) @@ -386,8 +390,8 @@ def test_compare_to_upstream( np.testing.assert_allclose( hvg_info["dispersions_norm"], pbmc.var["dispersions_norm"], - rtol=2e-05, - atol=2e-05, + rtol=2e-05 if "dask" not in array_type.__name__ else 1e-4, + atol=2e-05 if "dask" not in array_type.__name__ else 1e-4, ) @@ -568,7 +572,7 @@ def test_cutoff_info(): @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"]) 
def test_subset_inplace_consistency(flavor, array_type, subset, inplace): @@ -609,7 +613,7 @@ def test_subset_inplace_consistency(flavor, array_type, subset, inplace): @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) @pytest.mark.parametrize( - "to_dask", [p for p in ARRAY_TYPES_SUPPORTED if "dask" in p.values[0].__name__] + "to_dask", [p for p in ARRAY_TYPES if "dask" in p.values[0].__name__] ) def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): adata.X = np.abs(adata.X).astype(int) @@ -632,4 +636,4 @@ def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): assert_index_equal(adata.var_names, output_mem.index, check_names=False) assert_index_equal(adata.var_names, output_dask.index, check_names=False) - assert_frame_equal(output_mem, output_dask) + assert_frame_equal(output_mem, output_dask, atol=1e-4) diff --git a/scanpy/tests/test_normalization.py b/scanpy/tests/test_normalization.py index ea7556db74..3582b51fed 100644 --- a/scanpy/tests/test_normalization.py +++ b/scanpy/tests/test_normalization.py @@ -8,8 +8,10 @@ from anndata.tests.helpers import assert_equal from scipy import sparse from scipy.sparse import csr_matrix +from sklearn.utils import issparse import scanpy as sc +from scanpy._utils import axis_sum from scanpy.testing._helpers import ( _check_check_values_warnings, check_rep_mutation, @@ -17,7 +19,7 @@ ) # TODO: Add support for sparse-in-dask -from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED +from scanpy.testing._pytest.params import ARRAY_TYPES if TYPE_CHECKING: from collections.abc import Callable @@ -26,21 +28,50 @@ X_frac = np.array([[1, 0, 1], [3, 0, 1], [5, 6, 1]]) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("dtype", ["float32", "int64"]) +@pytest.mark.parametrize("target_sum", [None, 1.0]) +@pytest.mark.parametrize("exclude_highly_expressed", [True, False]) +def test_normalize_matrix_types( + array_type, dtype, target_sum, exclude_highly_expressed +): + adata = sc.datasets.pbmc68k_reduced() + adata.X = (adata.raw.X).astype(dtype) + adata_casted = adata.copy() + adata_casted.X = array_type(adata_casted.raw.X).astype(dtype) + sc.pp.normalize_total( + adata, target_sum=target_sum, exclude_highly_expressed=exclude_highly_expressed + ) + sc.pp.normalize_total( + adata_casted, + target_sum=target_sum, + exclude_highly_expressed=exclude_highly_expressed, + ) + X = adata_casted.X + if "dask" in array_type.__name__: + X = X.compute() + if issparse(X): + X = X.todense() + if issparse(adata.X): + adata.X = adata.X.todense() + np.testing.assert_allclose(X, adata.X, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("dtype", ["float32", "int64"]) def test_normalize_total(array_type, dtype): adata = AnnData(array_type(X_total).astype(dtype)) sc.pp.normalize_total(adata, key_added="n_counts") - assert np.allclose(np.ravel(adata.X.sum(axis=1)), [3.0, 3.0, 3.0]) + assert np.allclose(np.ravel(axis_sum(adata.X, axis=1)), [3.0, 3.0, 3.0]) sc.pp.normalize_total(adata, target_sum=1, key_added="n_counts2") - assert np.allclose(np.ravel(adata.X.sum(axis=1)), [1.0, 1.0, 1.0]) + assert np.allclose(np.ravel(axis_sum(adata.X, axis=1)), [1.0, 1.0, 1.0]) adata = AnnData(array_type(X_frac).astype(dtype)) sc.pp.normalize_total(adata, exclude_highly_expressed=True, max_fraction=0.7) - assert 
np.allclose(np.ravel(adata.X[:, 1:3].sum(axis=1)), [1.0, 1.0, 1.0]) + assert np.allclose(np.ravel(axis_sum(adata.X[:, 1:3], axis=1)), [1.0, 1.0, 1.0]) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("dtype", ["float32", "int64"]) def test_normalize_total_rep(array_type, dtype): # Test that layer kwarg works @@ -49,17 +80,17 @@ def test_normalize_total_rep(array_type, dtype): check_rep_results(sc.pp.normalize_total, X, fields=["layer"]) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("dtype", ["float32", "int64"]) def test_normalize_total_layers(array_type, dtype): adata = AnnData(array_type(X_total).astype(dtype)) adata.layers["layer"] = adata.X.copy() with pytest.warns(FutureWarning, match=r".*layers.*deprecated"): sc.pp.normalize_total(adata, layers=["layer"]) - assert np.allclose(adata.layers["layer"].sum(axis=1), [3.0, 3.0, 3.0]) + assert np.allclose(axis_sum(adata.layers["layer"], axis=1), [3.0, 3.0, 3.0]) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("dtype", ["float32", "int64"]) def test_normalize_total_view(array_type, dtype): adata = AnnData(array_type(X_total).astype(dtype)) diff --git a/scanpy/tests/test_pca.py b/scanpy/tests/test_pca.py index c2de64c0b9..2ffc490db2 100644 --- a/scanpy/tests/test_pca.py +++ b/scanpy/tests/test_pca.py @@ -19,7 +19,11 @@ from scanpy.testing._helpers import as_dense_dask_array, as_sparse_dask_array from scanpy.testing._helpers.data import pbmc3k_normalized from scanpy.testing._pytest.marks import needs -from scanpy.testing._pytest.params import ARRAY_TYPES, ARRAY_TYPES_SUPPORTED, param_with +from scanpy.testing._pytest.params import ( + ARRAY_TYPES, + ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED, + param_with, +) A_list = np.array( [ @@ -59,7 +63,7 @@ @pytest.fixture( params=[ param_with(at, marks=[needs.dask_ml]) if "dask" in at.id else at - for at in ARRAY_TYPES_SUPPORTED + for at in ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED ] ) def array_type(request: pytest.FixtureRequest): diff --git a/scanpy/tests/test_preprocessing.py b/scanpy/tests/test_preprocessing.py index 5ae3b3e08f..d043ac490c 100644 --- a/scanpy/tests/test_preprocessing.py +++ b/scanpy/tests/test_preprocessing.py @@ -9,6 +9,7 @@ from anndata.tests.helpers import asarray, assert_equal from numpy.testing import assert_allclose from scipy import sparse as sp +from sklearn.utils import issparse import scanpy as sc from scanpy.testing._helpers import ( @@ -17,7 +18,7 @@ check_rep_results, ) from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced -from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED +from scanpy.testing._pytest.params import ARRAY_TYPES def test_log1p(tmp_path): @@ -59,8 +60,7 @@ def test_log1p_rep(count_matrix_format, base, dtype): check_rep_results(sc.pp.log1p, X, base=base) -# TODO: Add support for sparse-in-dask -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_mean_var(array_type): pbmc = pbmc3k() pbmc.X = array_type(pbmc.X) @@ -159,6 +159,43 @@ def test_subsample_copy_backed(tmp_path): sc.pp.subsample(adata_d, n_obs=40, copy=False) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("zero_center", [True, False]) +@pytest.mark.parametrize("max_value", [None, 1.0]) +def 
test_scale_matrix_types(array_type, zero_center, max_value): + adata = pbmc68k_reduced() + adata.X = adata.raw.X + adata_casted = adata.copy() + adata_casted.X = array_type(adata_casted.raw.X) + sc.pp.scale(adata, zero_center=zero_center, max_value=max_value) + sc.pp.scale(adata_casted, zero_center=zero_center, max_value=max_value) + X = adata_casted.X + if "dask" in array_type.__name__: + X = X.compute() + if issparse(X): + X = X.todense() + if issparse(adata.X): + adata.X = adata.X.todense() + assert_allclose(X, adata.X, rtol=1e-5, atol=1e-5) + + +ARRAY_TYPES_DASK_SPARSE = [ + a for a in ARRAY_TYPES if "sparse" in a.id and "dask" in a.id +] + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES_DASK_SPARSE) +def test_scale_zero_center_warns_dask_sparse(array_type): + adata = pbmc68k_reduced() + adata.X = adata.raw.X + adata_casted = adata.copy() + adata_casted.X = array_type(adata_casted.raw.X) + with pytest.warns(UserWarning, match="zero-center being used with `DaskArray`*"): + sc.pp.scale(adata_casted) + sc.pp.scale(adata) + assert_allclose(adata_casted.X, adata.X, rtol=1e-5, atol=1e-5) + + def test_scale(): adata = pbmc68k_reduced() adata.X = adata.raw.X @@ -407,3 +444,81 @@ def test_recipe_weinreb(): orig = adata.copy() sc.pp.recipe_weinreb17(adata, log=False, copy=True) assert_equal(orig, adata) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize( + "max_cells,max_counts,min_cells,min_counts", + [ + [100, None, None, None], + [None, 100, None, None], + [None, None, 20, None], + [None, None, None, 20], + ], +) +def test_filter_genes(array_type, max_cells, max_counts, min_cells, min_counts): + adata = pbmc68k_reduced() + adata.X = adata.raw.X + adata_casted = adata.copy() + adata_casted.X = array_type(adata_casted.raw.X) + sc.pp.filter_genes( + adata, + max_cells=max_cells, + max_counts=max_counts, + min_cells=min_cells, + min_counts=min_counts, + ) + sc.pp.filter_genes( + adata_casted, + max_cells=max_cells, + max_counts=max_counts, + min_cells=min_cells, + min_counts=min_counts, + ) + X = adata_casted.X + if "dask" in array_type.__name__: + X = X.compute() + if issparse(X): + X = X.todense() + if issparse(adata.X): + adata.X = adata.X.todense() + assert_allclose(X, adata.X, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize( + "max_genes,max_counts,min_genes,min_counts", + [ + [100, None, None, None], + [None, 100, None, None], + [None, None, 20, None], + [None, None, None, 20], + ], +) +def test_filter_cells(array_type, max_genes, max_counts, min_genes, min_counts): + adata = pbmc68k_reduced() + adata.X = adata.raw.X + adata_casted = adata.copy() + adata_casted.X = array_type(adata_casted.raw.X) + sc.pp.filter_cells( + adata, + max_genes=max_genes, + max_counts=max_counts, + min_genes=min_genes, + min_counts=min_counts, + ) + sc.pp.filter_cells( + adata_casted, + max_genes=max_genes, + max_counts=max_counts, + min_genes=min_genes, + min_counts=min_counts, + ) + X = adata_casted.X + if "dask" in array_type.__name__: + X = X.compute() + if issparse(X): + X = X.todense() + if issparse(adata.X): + adata.X = adata.X.todense() + assert_allclose(X, adata.X, rtol=1e-5, atol=1e-5) diff --git a/scanpy/tests/test_utils.py b/scanpy/tests/test_utils.py index e7bdb24f7d..533485107a 100644 --- a/scanpy/tests/test_utils.py +++ b/scanpy/tests/test_utils.py @@ -1,21 +1,29 @@ from __future__ import annotations +from operator import mul, truediv from types import ModuleType import numpy as np import pytest from 
anndata.tests.helpers import asarray -from scipy.sparse import csr_matrix +from scipy.sparse import csr_matrix, issparse from scanpy._compat import DaskArray from scanpy._utils import ( + axis_mul_or_truediv, + axis_sum, check_nonnegative_integers, descend_classes_and_funcs, elem_mul, is_constant, ) from scanpy.testing._pytest.marks import needs -from scanpy.testing._pytest.params import ARRAY_TYPES, ARRAY_TYPES_SUPPORTED +from scanpy.testing._pytest.params import ( + ARRAY_TYPES, + ARRAY_TYPES_DASK, + ARRAY_TYPES_SPARSE, + ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED, +) def test_descend_classes_and_funcs(): @@ -35,16 +43,123 @@ def test_descend_classes_and_funcs(): assert {a.A, a.b.B} == set(descend_classes_and_funcs(a, "a")) -# TODO: add support for dask-in-sparse -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +def test_axis_mul_or_truediv_badop(): + dividend = np.array([[0, 1.0, 1.0], [1.0, 0, 1.0]]) + divisor = np.array([0.1, 0.2]) + with pytest.raises(ValueError, match=".*not one of truediv or mul"): + axis_mul_or_truediv(dividend, divisor, op=np.add, axis=0) + + +def test_axis_mul_or_truediv_bad_out(): + dividend = csr_matrix(np.array([[0, 1.0, 1.0], [1.0, 0, 1.0]])) + divisor = np.array([0.1, 0.2]) + with pytest.raises(ValueError, match="`out` argument provided but not equal to X"): + axis_mul_or_truediv(dividend, divisor, op=truediv, out=dividend.copy(), axis=0) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("op", [truediv, mul]) +def test_scale_row(array_type, op): + dividend = array_type(asarray([[0, 1.0, 1.0], [1.0, 0, 1.0]])) + divisor = np.array([0.1, 0.2]) + if op is mul: + divisor = 1 / divisor + expd = np.array([[0, 10.0, 10.0], [5.0, 0, 5.0]]) + out = dividend if issparse(dividend) or isinstance(dividend, np.ndarray) else None + res = asarray(axis_mul_or_truediv(dividend, divisor, op=op, axis=0, out=out)) + np.testing.assert_array_equal(res, expd) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("op", [truediv, mul]) +def test_scale_column(array_type, op): + dividend = array_type(asarray([[0, 1.0, 2.0], [3.0, 0, 4.0]])) + divisor = np.array([0.1, 0.2, 0.5]) + if op is mul: + divisor = 1 / divisor + expd = np.array([[0, 5.0, 4.0], [30.0, 0, 8.0]]) + out = dividend if issparse(dividend) or isinstance(dividend, np.ndarray) else None + res = asarray(axis_mul_or_truediv(dividend, divisor, op=op, axis=1, out=out)) + np.testing.assert_array_equal(res, expd) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_divide_by_zero(array_type): + dividend = array_type(asarray([[0, 1.0, 2.0], [3.0, 0, 4.0]])) + divisor = np.array([0.1, 0.2, 0.0]) + expd = np.array([[0, 5.0, 2.0], [30.0, 0, 4.0]]) + res = asarray( + axis_mul_or_truediv( + dividend, divisor, op=truediv, axis=1, allow_divide_by_zero=False + ) + ) + np.testing.assert_array_equal(res, expd) + res = asarray( + axis_mul_or_truediv( + dividend, divisor, op=truediv, axis=1, allow_divide_by_zero=True + ) + ) + expd = np.array([[0, 5.0, np.inf], [30.0, 0, np.inf]]) + np.testing.assert_array_equal(res, expd) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SPARSE) +def test_scale_out_with_dask_or_sparse_raises(array_type): + dividend = array_type(asarray([[0, 1.0, 2.0], [3.0, 0, 4.0]])) + divisor = np.array([0.1, 0.2, 0.5]) + if isinstance(dividend, DaskArray): + with pytest.raises( + TypeError if "dask" in array_type.__name__ else ValueError, + match="`out`*", + ): + axis_mul_or_truediv(dividend, divisor, op=truediv, axis=1, 
out=dividend) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES_DASK) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("op", [truediv, mul]) +def test_scale_rechunk(array_type, axis, op): + import dask.array as da + + dividend = array_type( + asarray([[0, 1.0, 2.0], [3.0, 0, 4.0], [3.0, 0, 4.0]]) + ).rechunk(((3,), (3,))) + divisor = da.from_array(np.array([0.1, 0.2, 0.5]), chunks=(1,)) + if op is mul: + divisor = 1 / divisor + if axis == 1: + expd = np.array([[0, 5.0, 4.0], [30.0, 0, 8.0], [30.0, 0, 8.0]]) + else: + expd = np.array([[0, 10.0, 20.0], [15.0, 0, 20.0], [6.0, 0, 8.0]]) + out = dividend if issparse(dividend) or isinstance(dividend, np.ndarray) else None + with pytest.warns(UserWarning, match="Rechunking scaling_array*"): + res = asarray(axis_mul_or_truediv(dividend, divisor, op=op, axis=axis, out=out)) + np.testing.assert_array_equal(res, expd) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_elem_mul(array_type): - m1 = array_type([[0, 1, 1], [1, 0, 1]]) - m2 = array_type([[2, 2, 1], [3, 2, 0]]) + m1 = array_type(asarray([[0, 1, 1], [1, 0, 1]])) + m2 = array_type(asarray([[2, 2, 1], [3, 2, 0]])) expd = np.array([[0, 2, 1], [3, 0, 0]]) res = asarray(elem_mul(m1, m2)) np.testing.assert_array_equal(res, expd) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_axis_sum(array_type): + m1 = array_type(asarray([[0, 1, 1], [1, 0, 1]])) + expd_0 = np.array([1, 1, 2]) + expd_1 = np.array([2, 2]) + res_0 = asarray(axis_sum(m1, axis=0)) + res_1 = asarray(axis_sum(m1, axis=1)) + if "matrix" in array_type.__name__: # for sparse since dimension is kept + res_0 = res_0.ravel() + res_1 = res_1.ravel() + np.testing.assert_array_equal(res_0, expd_0) + np.testing.assert_array_equal(res_1, expd_1) + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize( ("array_value", "expected"), @@ -79,7 +194,7 @@ def test_check_nonnegative_integers(array_type, array_value, expected): # TODO: Make it work for sparse-in-dask -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED) def test_is_constant(array_type): constant_inds = [1, 3] A = np.arange(20).reshape(5, 4)
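# --- Editor's sketch (not part of the patch): an end-to-end use of the
# sparse-chunks-in-dask support this PR adds. The shapes, chunk sizes, and
# random data below are illustrative assumptions only.
import dask.array as da
import numpy as np
from anndata import AnnData
from scipy import sparse

import scanpy as sc

counts = sparse.random(
    200, 50, density=0.1, format="csr", dtype=np.float32, random_state=0
)
# Wrap the sparse matrix in a dask array whose chunks are csr_matrix blocks.
adata = AnnData(da.from_array(counts, chunks=(100, 50)))
# target_sum avoids the median computation, so the graph stays fully lazy.
sc.pp.normalize_total(adata, target_sum=1e4)
# zero_center=False keeps the chunks sparse and avoids the densification warning.
sc.pp.scale(adata, zero_center=False)
result = adata.X.compute()  # nothing is evaluated until this point
# --- end editor's sketch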