Skip to content

Commit

Permalink
Embeddings show nulls in continuous variables
Browse files Browse the repository at this point in the history
In embeddings this comes up as light gray points, in spatial plots these are transparent.
This required adding explicit handling of the color map, and making sure we accepted `nan`s in a few more places.
  • Loading branch information
ivirshup committed Aug 25, 2020
1 parent 785c59a commit 498b757
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
34 changes: 20 additions & 14 deletions scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections.abc as cabc
from copy import copy
from typing import Union, Optional, Sequence, Any, Mapping, List, Tuple, Callable

import numpy as np
Expand All @@ -9,6 +10,7 @@
from matplotlib.figure import Figure
from pandas.api.types import is_categorical_dtype
from matplotlib import pyplot as pl, colors
from matplotlib.cm import get_cmap
from matplotlib import rcParams
from matplotlib import patheffects
from matplotlib.colors import Colormap
Expand Down Expand Up @@ -68,6 +70,7 @@ def embedding(
library_id: str = None,
#
color_map: Union[Colormap, str, None] = None,
cmap: Union[Colormap, str, None] = None,
palette: Union[str, Sequence[str], Cycler, None] = None,
size: Union[float, Sequence[float], None] = None,
frameon: Optional[bool] = None,
Expand All @@ -88,6 +91,7 @@ def embedding(
save: Union[bool, str, None] = None,
ax: Optional[Axes] = None,
return_fig: Optional[bool] = None,
_missing_color="lightgray", # Keeping private for now
**kwargs,
) -> Union[Figure, Axes, None]:
"""\
Expand All @@ -108,8 +112,17 @@ def embedding(
"""

sanitize_anndata(adata)

# Setting up color map for continuous values
if color_map is not None:
kwargs['cmap'] = color_map
if cmap is not None:
raise ValueError("Cannot specify both `color_map` and `cmap`.")
else:
cmap = color_map
cmap = copy(get_cmap(cmap))
cmap.set_bad(_missing_color)
kwargs["cmap"] = cmap

if size is not None:
kwargs['s'] = size
if 'edgecolor' not in kwargs:
Expand All @@ -122,14 +135,6 @@ def embedding(
if isinstance(groups, str):
groups = [groups]

# Setting color for missing values
if library_id is None:
# Light gray for most cases
missing_color = colors.to_hex("lightgray", keep_alpha=True)
else:
# Clear for spatial
missing_color = colors.to_hex((0, 0, 0, 0), keep_alpha=True)

make_projection_available(projection)
args_3d = dict(projection='3d') if projection == '3d' else {}

Expand Down Expand Up @@ -247,14 +252,14 @@ def embedding(
color_source_vector,
groups=groups,
palette=palette,
missing_color=missing_color,
missing_color=_missing_color,
)

### Order points
order = slice(None)
if sort_order is True and value_to_plot is not None and categorical is False:
# Higher values plotted on top
order = np.argsort(color_vector, kind="stable")
# Higher values plotted on top, null values on bottom
order = np.argsort(-color_vector, kind="stable")[::-1]
elif sort_order and categorical and groups is not None:
# Left out groups go on bottom
order = np.argsort(color_source_vector.isin(groups), kind="stable")
Expand Down Expand Up @@ -323,7 +328,7 @@ def embedding(
size_spot = 70 * size

scatter = (
partial(ax.scatter, s=size)
partial(ax.scatter, s=size, plotnonfinite=True)
if library_id is None
else partial(circles, s=size_spot, ax=ax)
)
Expand Down Expand Up @@ -534,7 +539,7 @@ def my_vmax(color_vector): np.percentile(color_vector, p=80)
f"Please check the correct format for percentiles."
)
# interpret value of vmin/vmax as quantile with the following syntax 'p99.9'
v_value = np.percentile(color_vector, q=float(v_value[1:]))
v_value = np.nanpercentile(color_vector, q=float(v_value[1:]))
elif callable(v_value):
# interpret vmin/vmax as function
v_value = v_value(color_vector)
Expand Down Expand Up @@ -773,6 +778,7 @@ def spatial(
bw=bw,
library_id=library_id,
size=size,
_missing_color=(0, 0, 0, 0),
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ def circles(x, y, s, ax, marker=None, c='b', vmin=None, vmax=None, **kwargs):
patches = [Circle((x_, y_), s_) for x_, y_, s_ in zipped]
collection = PatchCollection(patches, **kwargs)
if isinstance(c, np.ndarray) and np.issubdtype(c.dtype, np.number):
collection.set_array(c)
collection.set_array(np.ma.masked_invalid(c))
collection.set_clim(vmin, vmax)
else:
collection.set_facecolor(c)
Expand Down

0 comments on commit 498b757

Please sign in to comment.