Skip to content

Commit

Permalink
Merge pull request scverse#1828 from maximz/rapids-neighbors-metrics
Browse files Browse the repository at this point in the history
Bugfix in RAPIDS usage for neighbors(), and support for additional distance metrics
  • Loading branch information
Zethson authored Jan 6, 2022
2 parents 60353cc + b6bd872 commit c2ae5b9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.8.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Fixed finding variables with ``use_raw=True`` and ``basis=None`` in :func:`scanpy.pl.scatter` :pr:`2027` :small:`E Rice`
- Fixed :func:`scanpy.external.pp.scrublet` to address :issue:`1957` :smaller:`FlMai` and ensure raw counts are used for simulation
- Functions in :mod:`scanpy.datasets` no longer throw `OldFormatWarnings` when using `anndata` `0.8` :pr:`2096` :small:`I Virshup`
- Fixed use of :func:`scanpy.pp.neighbors` with ``method='rapids'``: RAPIDS cuML no longer returns squared Euclidean distances, so the kNN distance matrix is no longer square-rooted. :pr:`1828` :small:`M Zaslavsky`

.. rubric:: Performance

Expand Down
21 changes: 12 additions & 9 deletions scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ def compute_neighbors_umap(
return knn_indices, knn_dists, forest


def compute_neighbors_rapids(X: np.ndarray, n_neighbors: int):
def compute_neighbors_rapids(
X: np.ndarray, n_neighbors: int, metric: _Metric = 'euclidean'
):
"""Compute nearest neighbors using RAPIDS cuml.
Parameters
Expand All @@ -324,18 +326,21 @@ def compute_neighbors_rapids(X: np.ndarray, n_neighbors: int):
The data to compute nearest neighbors for.
n_neighbors
The number of neighbors to use.
metric
The metric to use to compute distances in high dimensional space.
This string must match a valid predefined metric in RAPIDS cuml.
Returns
-------
**knn_indices**, **knn_dists** : np.arrays of shape (n_observations, n_neighbors)
"""
from cuml.neighbors import NearestNeighbors

nn = NearestNeighbors(n_neighbors=n_neighbors)
nn = NearestNeighbors(n_neighbors=n_neighbors, metric=metric)
X_contiguous = np.ascontiguousarray(X, dtype=np.float32)
nn.fit(X_contiguous)
knn_distsq, knn_indices = nn.kneighbors(X_contiguous)
return knn_indices, np.sqrt(knn_distsq) # cuml uses sqeuclidean metric so take sqrt
knn_dist, knn_indices = nn.kneighbors(X_contiguous)
return knn_indices, knn_dist


def _get_sparse_matrix_from_indices_distances_umap(
Expand Down Expand Up @@ -755,10 +760,6 @@ def compute_neighbors(
logg.warning(f'n_obs too small: adjusting to `n_neighbors = {n_neighbors}`')
if method == 'umap' and not knn:
raise ValueError('`method = \'umap\' only with `knn = True`.')
if method == 'rapids' and metric != 'euclidean':
raise ValueError(
"`method` 'rapids' only supports the 'euclidean' `metric`."
)
if method not in {'umap', 'gauss', 'rapids'}:
raise ValueError("`method` needs to be 'umap', 'gauss', or 'rapids'.")
if self._adata.shape[0] >= 10000 and not knn:
Expand All @@ -782,7 +783,9 @@ def compute_neighbors(
else:
self._distances = _distances
elif method == 'rapids':
knn_indices, knn_distances = compute_neighbors_rapids(X, n_neighbors)
knn_indices, knn_distances = compute_neighbors_rapids(
X, n_neighbors, metric=metric
)
else:
# non-euclidean case and approx nearest neighbors
if X.shape[0] < 4096:
Expand Down

0 comments on commit c2ae5b9

Please sign in to comment.