Skip to content

Commit

Permalink
Fix graph metrics when some variables are constant (scverse#1891)
Browse files Browse the repository at this point in the history
* Initial test

* Move _utils to directory, add is_constant

* Fix gearys C and morans I for constant values in X

* Some docs on why this fix exists
  • Loading branch information
ivirshup authored Jun 22, 2021
1 parent 0ffa787 commit b70dee1
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 21 deletions.
17 changes: 11 additions & 6 deletions scanpy/_utils.py → scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Utility functions and classes
This file largely consists of the old _utils.py file. Over time, these functions
should be moved of this file.
"""
import sys
import inspect
Expand All @@ -19,9 +22,11 @@
from textwrap import dedent
from packaging import version

from ._settings import settings
from ._compat import Literal
from . import logging as logg
from .._settings import settings
from .._compat import Literal
from .. import logging as logg

from .compute.is_constant import is_constant


class Empty(Enum):
Expand All @@ -37,12 +42,12 @@ class Empty(Enum):


def check_versions():
from ._compat import pkg_version
from .._compat import pkg_version

umap_version = pkg_version("umap-learn")

if version.parse(anndata_version) < version.parse('0.6.10'):
from . import __version__
from .. import __version__

raise ImportError(
f'Scanpy {__version__} needs anndata version >=0.6.10, '
Expand Down Expand Up @@ -635,7 +640,7 @@ def subsample_n(
def check_presence_download(filename: Path, backup_url):
"""Check if file is present otherwise download."""
if not filename.is_file():
from .readwrite import _download
from ..readwrite import _download

_download(backup_url, filename)

Expand Down
111 changes: 111 additions & 0 deletions scanpy/_utils/compute/is_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from functools import singledispatch
from numbers import Integral

import numpy as np
from numba import njit
from scipy import sparse


@singledispatch
def is_constant(a, axis=None) -> np.ndarray:
"""
Check whether values in array are constant.
Params
------
a
Array to check
axis
Axis to reduce over.
Returns
-------
Boolean array, True values were constant.
Example
-------
>>> a = np.array([[0, 1], [0, 0]])
>>> a
array([[0, 1],
[0, 0]])
>>> is_constant(a)
False
>>> is_constant(a, axis=0)
array([ False, True])
>>> is_constant(a, axis=1)
array([ True, False])
"""
raise NotImplementedError()


@is_constant.register(np.ndarray)
def _(a, axis=None):
# Should eventually support nd, not now.
if axis is None:
return np.array_equal(a, a.flat[0])
if not isinstance(axis, Integral):
raise TypeError("axis must be integer or None.")
assert axis in (0, 1)
if axis == 0:
return _is_constant_rows(a.T)
elif axis == 1:
return _is_constant_rows(a)


def _is_constant_rows(a):
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
return (a == b).all(axis=1)


@is_constant.register(sparse.csr_matrix)
def _(a, axis=None):
if axis is None:
if len(a.data) == np.multiply(*a.shape):
return is_constant(a.data)
else:
return (a.data == 0).all()
if not isinstance(axis, Integral):
raise TypeError("axis must be integer or None.")
assert axis in (0, 1)
if axis == 1:
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape)
elif axis == 0:
a = a.T.tocsr()
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape)


@njit
def _is_constant_csr_rows(data, indices, indptr, shape):
N = len(indptr) - 1
result = np.ones(N, dtype=np.bool_)
for i in range(N):
start = indptr[i]
stop = indptr[i + 1]
if stop - start == shape[1]:
val = data[start]
else:
val = 0
for j in range(start, stop):
if data[j] != val:
result[i] = False
break
return result


@is_constant.register(sparse.csc_matrix)
def _(a, axis=None):
if axis is None:
if len(a.data) == np.multiply(*a.shape):
return is_constant(a.data)
else:
return (a.data == 0).all()
if not isinstance(axis, Integral):
raise TypeError("axis must be integer or None.")
assert axis in (0, 1)
if axis == 0:
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape[::-1])
elif axis == 1:
a = a.T.tocsc()
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape[::-1])
44 changes: 38 additions & 6 deletions scanpy/metrics/_gearys_c.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import singledispatch
from typing import Optional, Union
import warnings


from anndata import AnnData
Expand Down Expand Up @@ -272,27 +273,58 @@ def _(val):
return val.to_numpy()


def _check_vals(vals):
"""\
Checks that values wont cause issues in computation.
Returns new set of vals, and indexer to put values back into result.
For details on why this is neccesary, see:
https://github.com/theislab/scanpy/issues/1806
"""
from scanpy._utils import is_constant

full_result = np.empty(vals.shape[0], dtype=np.float64)
full_result.fill(np.nan)
idxer = ~is_constant(vals, axis=1)
if idxer.all():
idxer = slice(None)
else:
warnings.warn(
UserWarning(
f"{len(idxer) - idxer.sum()} variables were constant, will return nan for these.",
)
)
return vals[idxer], idxer, full_result


@gearys_c.register(sparse.csr_matrix)
def _gearys_c(g, vals) -> np.ndarray:
assert g.shape[0] == g.shape[1], "`g` should be a square adjacency matrix"
vals = _resolve_vals(vals)
g_data = g.data.astype(np.float_, copy=False)
if isinstance(vals, sparse.csr_matrix):
assert g.shape[0] == vals.shape[1]
return _gearys_c_mtx_csr(
new_vals, idxer, full_result = _check_vals(vals)
result = _gearys_c_mtx_csr(
g_data,
g.indices,
g.indptr,
vals.data.astype(np.float_, copy=False),
vals.indices,
vals.indptr,
vals.shape,
new_vals.data.astype(np.float_, copy=False),
new_vals.indices,
new_vals.indptr,
new_vals.shape,
)
full_result[idxer] = result
return full_result
elif isinstance(vals, np.ndarray) and vals.ndim == 1:
assert g.shape[0] == vals.shape[0]
return _gearys_c_vec(g_data, g.indices, g.indptr, vals)
elif isinstance(vals, np.ndarray) and vals.ndim == 2:
assert g.shape[0] == vals.shape[1]
return _gearys_c_mtx(g_data, g.indices, g.indptr, vals)
new_vals, idxer, full_result = _check_vals(vals)
result = _gearys_c_mtx(g_data, g.indices, g.indptr, new_vals)
full_result[idxer] = result
return full_result
else:
raise NotImplementedError()
22 changes: 14 additions & 8 deletions scanpy/metrics/_morans_i.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numba import njit, prange

from scanpy.get import _get_obs_rep
from scanpy.metrics._gearys_c import _resolve_vals
from scanpy.metrics._gearys_c import _resolve_vals, _check_vals


@singledispatch
Expand Down Expand Up @@ -228,25 +228,31 @@ def _morans_i(g, vals) -> np.ndarray:
g_data = g.data.astype(np.float_, copy=False)
if isinstance(vals, sparse.csr_matrix):
assert g.shape[0] == vals.shape[1]
return _morans_i_mtx_csr(
new_vals, idxer, full_result = _check_vals(vals)
result = _morans_i_mtx_csr(
g_data,
g.indices,
g.indptr,
vals.data.astype(np.float_, copy=False),
vals.indices,
vals.indptr,
vals.shape,
new_vals.data.astype(np.float_, copy=False),
new_vals.indices,
new_vals.indptr,
new_vals.shape,
)
full_result[idxer] = result
return full_result
elif isinstance(vals, np.ndarray) and vals.ndim == 1:
assert g.shape[0] == vals.shape[0]
return _morans_i_vec(g_data, g.indices, g.indptr, vals)
elif isinstance(vals, np.ndarray) and vals.ndim == 2:
assert g.shape[0] == vals.shape[1]
return _morans_i_mtx(
new_vals, idxer, full_result = _check_vals(vals)
result = _morans_i_mtx(
g_data,
g.indices,
g.indptr,
vals.astype(np.float_, copy=False),
new_vals.astype(np.float_, copy=False),
)
full_result[idxer] = result
return full_result
else:
raise NotImplementedError()
46 changes: 46 additions & 0 deletions scanpy/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from operator import eq
from string import ascii_letters
import warnings

import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse

from anndata.tests.helpers import asarray
import pytest


def test_gearys_c_consistency():
pbmc = sc.datasets.pbmc68k_reduced()
Expand Down Expand Up @@ -121,6 +125,48 @@ def test_morans_i_correctness():
assert sc.metrics.morans_i(adata, vals=connected) == 1.0


@pytest.mark.parametrize("metric", [sc.metrics.gearys_c, sc.metrics.morans_i])
@pytest.mark.parametrize(
'array_type',
[asarray, sparse.csr_matrix, sparse.csc_matrix],
ids=lambda x: x.__name__,
)
def test_graph_metrics_w_constant_values(metric, array_type):
# https://github.com/theislab/scanpy/issues/1806
pbmc = sc.datasets.pbmc68k_reduced()
XT = array_type(pbmc.raw.X.T.copy())
g = pbmc.obsp["connectivities"].copy()

const_inds = np.random.choice(XT.shape[0], 10, replace=False)
with warnings.catch_warnings():
warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning)
XT_zero_vals = XT.copy()
XT_zero_vals[const_inds, :] = 0
XT_const_vals = XT.copy()
XT_const_vals[const_inds, :] = 42

results_full = metric(g, XT)
# TODO: Check for warnings
with pytest.warns(
UserWarning, match=r"10 variables were constant, will return nan for these"
):
results_const_zeros = metric(g, XT_zero_vals)
with pytest.warns(
UserWarning, match=r"10 variables were constant, will return nan for these"
):
results_const_vals = metric(g, XT_const_vals)

assert not np.isnan(results_full).any()
np.testing.assert_array_equal(results_const_zeros, results_const_vals)
np.testing.assert_array_equal(np.nan, results_const_zeros[const_inds])
np.testing.assert_array_equal(np.nan, results_const_vals[const_inds])

non_const_mask = ~np.isin(np.arange(XT.shape[0]), const_inds)
np.testing.assert_array_equal(
results_full[non_const_mask], results_const_zeros[non_const_mask]
)


def test_confusion_matrix():
mtx = sc.metrics.confusion_matrix(["a", "b"], ["c", "d"], normalize=False)
assert mtx.loc["a", "c"] == 1
Expand Down
30 changes: 29 additions & 1 deletion scanpy/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from types import ModuleType
from scipy.sparse import csr_matrix
from scipy.sparse import csr_matrix, csc_matrix
import numpy as np

from scanpy._utils import descend_classes_and_funcs, check_nonnegative_integers

from anndata.tests.helpers import assert_equal, asarray
import pytest


def test_descend_classes_and_funcs():
# create module hierarchy
Expand Down Expand Up @@ -37,3 +40,28 @@ def test_check_nonnegative_integers():

X_ = csr_matrix(X_)
assert check_nonnegative_integers(X_) is False


@pytest.mark.parametrize(
'array_type', [asarray, csr_matrix, csc_matrix], ids=lambda x: x.__name__
)
def test_is_constant(array_type):
from scanpy._utils import is_constant

constant_inds = [1, 3]
A = np.arange(20).reshape(5, 4)
A[constant_inds, :] = 10
A = array_type(A)
AT = array_type(A.T)

assert not is_constant(A)
assert not np.any(is_constant(A, axis=0))
np.testing.assert_array_equal(
[False, True, False, True, False], is_constant(A, axis=1)
)

assert not is_constant(AT)
assert not np.any(is_constant(AT, axis=1))
np.testing.assert_array_equal(
[False, True, False, True, False], is_constant(AT, axis=0)
)

0 comments on commit b70dee1

Please sign in to comment.