From b70dee15c5b31ed9f56d3f5b74fd96e73e4443d4 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Tue, 22 Jun 2021 20:27:00 +1000 Subject: [PATCH] Fix graph metrics when some variables are constant (#1891) * Initial test * Move _utils to directory, add is_constant * Fix gearys C and morans I for constant values in X * Some docs on why this fix exists --- scanpy/{_utils.py => _utils/__init__.py} | 17 ++-- scanpy/_utils/compute/is_constant.py | 111 +++++++++++++++++++++++ scanpy/metrics/_gearys_c.py | 44 +++++++-- scanpy/metrics/_morans_i.py | 22 +++-- scanpy/tests/test_metrics.py | 46 ++++++++++ scanpy/tests/test_utils.py | 30 +++++- 6 files changed, 249 insertions(+), 21 deletions(-) rename scanpy/{_utils.py => _utils/__init__.py} (98%) create mode 100644 scanpy/_utils/compute/is_constant.py diff --git a/scanpy/_utils.py b/scanpy/_utils/__init__.py similarity index 98% rename from scanpy/_utils.py rename to scanpy/_utils/__init__.py index 101bab7712..105ca8802a 100644 --- a/scanpy/_utils.py +++ b/scanpy/_utils/__init__.py @@ -1,4 +1,7 @@ """Utility functions and classes + +This file largely consists of the old _utils.py file. Over time, these functions +should be moved out of this file. """ import sys import inspect @@ -19,9 +22,11 @@ from textwrap import dedent from packaging import version -from ._settings import settings -from ._compat import Literal -from . import logging as logg +from .._settings import settings +from .._compat import Literal +from .. import logging as logg + +from .compute.is_constant import is_constant class Empty(Enum): @@ -37,12 +42,12 @@ class Empty(Enum): def check_versions(): - from ._compat import pkg_version + from .._compat import pkg_version umap_version = pkg_version("umap-learn") if version.parse(anndata_version) < version.parse('0.6.10'): - from . import __version__ + from .. 
import __version__ raise ImportError( f'Scanpy {__version__} needs anndata version >=0.6.10, ' @@ -635,7 +640,7 @@ def subsample_n( def check_presence_download(filename: Path, backup_url): """Check if file is present otherwise download.""" if not filename.is_file(): - from .readwrite import _download + from ..readwrite import _download _download(backup_url, filename) diff --git a/scanpy/_utils/compute/is_constant.py b/scanpy/_utils/compute/is_constant.py new file mode 100644 index 0000000000..eff57c74c8 --- /dev/null +++ b/scanpy/_utils/compute/is_constant.py @@ -0,0 +1,111 @@ +from functools import singledispatch +from numbers import Integral + +import numpy as np +from numba import njit +from scipy import sparse + + +@singledispatch +def is_constant(a, axis=None) -> np.ndarray: + """ + Check whether values in array are constant. + + Params + ------ + a + Array to check + axis + Axis to reduce over. + + + Returns + ------- + Boolean array, True where values were constant. + + Example + ------- + + >>> a = np.array([[0, 1], [0, 0]]) + >>> a + array([[0, 1], + [0, 0]]) + >>> is_constant(a) + False + >>> is_constant(a, axis=0) + array([ True, False]) + >>> is_constant(a, axis=1) + array([False, True]) + """ + raise NotImplementedError() + + +@is_constant.register(np.ndarray) +def _(a, axis=None): + # Should eventually support nd, not now. 
+ if axis is None: + return np.array_equal(a, a.flat[0]) + if not isinstance(axis, Integral): + raise TypeError("axis must be integer or None.") + assert axis in (0, 1) + if axis == 0: + return _is_constant_rows(a.T) + elif axis == 1: + return _is_constant_rows(a) + + +def _is_constant_rows(a): + b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape) + return (a == b).all(axis=1) + + +@is_constant.register(sparse.csr_matrix) +def _(a, axis=None): + if axis is None: + if len(a.data) == np.multiply(*a.shape): + return is_constant(a.data) + else: + return (a.data == 0).all() + if not isinstance(axis, Integral): + raise TypeError("axis must be integer or None.") + assert axis in (0, 1) + if axis == 1: + return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape) + elif axis == 0: + a = a.T.tocsr() + return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape) + + +@njit +def _is_constant_csr_rows(data, indices, indptr, shape): + N = len(indptr) - 1 + result = np.ones(N, dtype=np.bool_) + for i in range(N): + start = indptr[i] + stop = indptr[i + 1] + if stop - start == shape[1]: + val = data[start] + else: + val = 0 + for j in range(start, stop): + if data[j] != val: + result[i] = False + break + return result + + +@is_constant.register(sparse.csc_matrix) +def _(a, axis=None): + if axis is None: + if len(a.data) == np.multiply(*a.shape): + return is_constant(a.data) + else: + return (a.data == 0).all() + if not isinstance(axis, Integral): + raise TypeError("axis must be integer or None.") + assert axis in (0, 1) + if axis == 0: + return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape[::-1]) + elif axis == 1: + a = a.T.tocsc() + return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape[::-1]) diff --git a/scanpy/metrics/_gearys_c.py b/scanpy/metrics/_gearys_c.py index 5dfb035944..f274e531eb 100644 --- a/scanpy/metrics/_gearys_c.py +++ b/scanpy/metrics/_gearys_c.py @@ -1,5 +1,6 @@ from functools import singledispatch from typing import 
Optional, Union +import warnings from anndata import AnnData @@ -272,6 +273,31 @@ def _(val): return val.to_numpy() +def _check_vals(vals): + """\ + Checks that values won't cause issues in computation. + + Returns new set of vals, and indexer to put values back into result. + + For details on why this is necessary, see: + https://github.com/theislab/scanpy/issues/1806 + """ + from scanpy._utils import is_constant + + full_result = np.empty(vals.shape[0], dtype=np.float64) + full_result.fill(np.nan) + idxer = ~is_constant(vals, axis=1) + if idxer.all(): + idxer = slice(None) + else: + warnings.warn( + UserWarning( + f"{len(idxer) - idxer.sum()} variables were constant, will return nan for these.", + ) + ) + return vals[idxer], idxer, full_result + + @gearys_c.register(sparse.csr_matrix) def _gearys_c(g, vals) -> np.ndarray: assert g.shape[0] == g.shape[1], "`g` should be a square adjacency matrix" @@ -279,20 +305,26 @@ def _gearys_c(g, vals) -> np.ndarray: g_data = g.data.astype(np.float_, copy=False) if isinstance(vals, sparse.csr_matrix): assert g.shape[0] == vals.shape[1] - return _gearys_c_mtx_csr( + new_vals, idxer, full_result = _check_vals(vals) + result = _gearys_c_mtx_csr( g_data, g.indices, g.indptr, - vals.data.astype(np.float_, copy=False), - vals.indices, - vals.indptr, - vals.shape, + new_vals.data.astype(np.float_, copy=False), + new_vals.indices, + new_vals.indptr, + new_vals.shape, ) + full_result[idxer] = result + return full_result elif isinstance(vals, np.ndarray) and vals.ndim == 1: assert g.shape[0] == vals.shape[0] return _gearys_c_vec(g_data, g.indices, g.indptr, vals) elif isinstance(vals, np.ndarray) and vals.ndim == 2: assert g.shape[0] == vals.shape[1] - return _gearys_c_mtx(g_data, g.indices, g.indptr, vals) + new_vals, idxer, full_result = _check_vals(vals) + result = _gearys_c_mtx(g_data, g.indices, g.indptr, new_vals) + full_result[idxer] = result + return full_result else: raise NotImplementedError() diff --git 
a/scanpy/metrics/_morans_i.py b/scanpy/metrics/_morans_i.py index 5e712aa840..316118a2af 100644 --- a/scanpy/metrics/_morans_i.py +++ b/scanpy/metrics/_morans_i.py @@ -8,7 +8,7 @@ from numba import njit, prange from scanpy.get import _get_obs_rep -from scanpy.metrics._gearys_c import _resolve_vals +from scanpy.metrics._gearys_c import _resolve_vals, _check_vals @singledispatch @@ -228,25 +228,31 @@ def _morans_i(g, vals) -> np.ndarray: g_data = g.data.astype(np.float_, copy=False) if isinstance(vals, sparse.csr_matrix): assert g.shape[0] == vals.shape[1] - return _morans_i_mtx_csr( + new_vals, idxer, full_result = _check_vals(vals) + result = _morans_i_mtx_csr( g_data, g.indices, g.indptr, - vals.data.astype(np.float_, copy=False), - vals.indices, - vals.indptr, - vals.shape, + new_vals.data.astype(np.float_, copy=False), + new_vals.indices, + new_vals.indptr, + new_vals.shape, ) + full_result[idxer] = result + return full_result elif isinstance(vals, np.ndarray) and vals.ndim == 1: assert g.shape[0] == vals.shape[0] return _morans_i_vec(g_data, g.indices, g.indptr, vals) elif isinstance(vals, np.ndarray) and vals.ndim == 2: assert g.shape[0] == vals.shape[1] - return _morans_i_mtx( + new_vals, idxer, full_result = _check_vals(vals) + result = _morans_i_mtx( g_data, g.indices, g.indptr, - vals.astype(np.float_, copy=False), + new_vals.astype(np.float_, copy=False), ) + full_result[idxer] = result + return full_result else: raise NotImplementedError() diff --git a/scanpy/tests/test_metrics.py b/scanpy/tests/test_metrics.py index ba2ce86e28..025d788abb 100644 --- a/scanpy/tests/test_metrics.py +++ b/scanpy/tests/test_metrics.py @@ -1,11 +1,15 @@ from operator import eq from string import ascii_letters +import warnings import numpy as np import pandas as pd import scanpy as sc from scipy import sparse +from anndata.tests.helpers import asarray +import pytest + def test_gearys_c_consistency(): pbmc = sc.datasets.pbmc68k_reduced() @@ -121,6 +125,48 @@ def 
test_morans_i_correctness(): assert sc.metrics.morans_i(adata, vals=connected) == 1.0 +@pytest.mark.parametrize("metric", [sc.metrics.gearys_c, sc.metrics.morans_i]) +@pytest.mark.parametrize( + 'array_type', + [asarray, sparse.csr_matrix, sparse.csc_matrix], + ids=lambda x: x.__name__, +) +def test_graph_metrics_w_constant_values(metric, array_type): + # https://github.com/theislab/scanpy/issues/1806 + pbmc = sc.datasets.pbmc68k_reduced() + XT = array_type(pbmc.raw.X.T.copy()) + g = pbmc.obsp["connectivities"].copy() + + const_inds = np.random.choice(XT.shape[0], 10, replace=False) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning) + XT_zero_vals = XT.copy() + XT_zero_vals[const_inds, :] = 0 + XT_const_vals = XT.copy() + XT_const_vals[const_inds, :] = 42 + + results_full = metric(g, XT) + # TODO: Check for warnings + with pytest.warns( + UserWarning, match=r"10 variables were constant, will return nan for these" + ): + results_const_zeros = metric(g, XT_zero_vals) + with pytest.warns( + UserWarning, match=r"10 variables were constant, will return nan for these" + ): + results_const_vals = metric(g, XT_const_vals) + + assert not np.isnan(results_full).any() + np.testing.assert_array_equal(results_const_zeros, results_const_vals) + np.testing.assert_array_equal(np.nan, results_const_zeros[const_inds]) + np.testing.assert_array_equal(np.nan, results_const_vals[const_inds]) + + non_const_mask = ~np.isin(np.arange(XT.shape[0]), const_inds) + np.testing.assert_array_equal( + results_full[non_const_mask], results_const_zeros[non_const_mask] + ) + + def test_confusion_matrix(): mtx = sc.metrics.confusion_matrix(["a", "b"], ["c", "d"], normalize=False) assert mtx.loc["a", "c"] == 1 diff --git a/scanpy/tests/test_utils.py b/scanpy/tests/test_utils.py index 9fb28e8332..8940c3a372 100644 --- a/scanpy/tests/test_utils.py +++ b/scanpy/tests/test_utils.py @@ -1,9 +1,12 @@ from types import ModuleType -from scipy.sparse import 
csr_matrix +from scipy.sparse import csr_matrix, csc_matrix import numpy as np from scanpy._utils import descend_classes_and_funcs, check_nonnegative_integers +from anndata.tests.helpers import assert_equal, asarray +import pytest + def test_descend_classes_and_funcs(): # create module hierarchy @@ -37,3 +40,28 @@ def test_check_nonnegative_integers(): X_ = csr_matrix(X_) assert check_nonnegative_integers(X_) is False + + +@pytest.mark.parametrize( + 'array_type', [asarray, csr_matrix, csc_matrix], ids=lambda x: x.__name__ +) +def test_is_constant(array_type): + from scanpy._utils import is_constant + + constant_inds = [1, 3] + A = np.arange(20).reshape(5, 4) + A[constant_inds, :] = 10 + A = array_type(A) + AT = array_type(A.T) + + assert not is_constant(A) + assert not np.any(is_constant(A, axis=0)) + np.testing.assert_array_equal( + [False, True, False, True, False], is_constant(A, axis=1) + ) + + assert not is_constant(AT) + assert not np.any(is_constant(AT, axis=1)) + np.testing.assert_array_equal( + [False, True, False, True, False], is_constant(AT, axis=0) + )