Skip to content

Commit

Permalink
Merge pull request scverse#628 from fidelram/multi_panel_density
Browse files Browse the repository at this point in the history
Multi-panel embedded density
  • Loading branch information
fidelram authored May 6, 2019
2 parents 73f1e6a + beb4c87 commit 1a14cf1
Showing 1 changed file with 107 additions and 35 deletions.
142 changes: 107 additions & 35 deletions scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from collections import abc
import numpy as np
import pandas as pd
from scipy.sparse import issparse
from matplotlib import pyplot as pl
from matplotlib import rcParams, cm, colors
from anndata import AnnData
from typing import Union, Optional
from typing import Union, Optional, List

from .. import _utils as utils
from ...utils import doc_params, sanitize_anndata
from ... import logging as logg
from .._anndata import scatter, ranking
from .._utils import timeseries, timeseries_subplot, timeseries_as_heatmap
from .._docs import doc_scatter_bulk, doc_show_save_ax
from .._docs import doc_scatter_bulk, doc_show_save_ax, _doc_scatter_panels
from .scatterplots import pca, plot_scatter
from matplotlib.colors import Colormap

Expand Down Expand Up @@ -642,19 +643,23 @@ def sim(adata, tmax_realization=None, as_heatmap=False, shuffle=False,
utils.savefig_or_show('sim_shuffled', save=save, show=show)


@doc_params(show_save_ax=doc_show_save_ax)
@doc_params(show_save_ax=doc_show_save_ax, panels=_doc_scatter_panels)
def embedding_density(
adata: AnnData,
basis: str,
key: str,
*,
group: Optional[str] = None,
group: Optional[Union[str, List[str], None]] = 'all',
color_map: Union[Colormap, str] = 'YlOrRd',
bg_dotsize: Optional[int] = 80,
fg_dotsize: Optional[int] = 180,
vmax: Optional[int] = 1,
vmin: Optional[int] = 0,
ncols: Optional[int] = 4,
hspace: Optional[float] = 0.25,
wspace: Optional[None] = None,
save: Union[bool, str, None] = None,
show: Optional[bool] = None,
**kwargs
):
"""\
Expand All @@ -677,7 +682,10 @@ def embedding_density(
Name of the `.obs` covariate that contains the density estimates
group
The category in the categorical observation annotation to be plotted.
For example, 'G1' in the cell cycle 'phase' covariate.
For example, 'G1' in the cell cycle 'phase' covariate. If all categories
are to be plotted use group='all' (default), If multiple categories
want to be plotted use a list (e.g.: ['G1', 'S']. If the overall density
wants to be ploted set group to 'None'.
color_map
Matplolib color map to use for density plotting.
bg_dotsize
Expand All @@ -688,17 +696,21 @@ def embedding_density(
Density that corresponds to color bar maximum.
vmin
Density that corresponds to color bar minimum.
{panels}
{show_save_ax}
Examples
--------
>>> adata = sc.datasets.pbmc68k_reduced()
>>> sc.tl.umap(adata)
>>> sc.tl.embedding_density(adata, basis='umap', groupby='phase')
Plot all categories be default
>>> sc.pl.embedding_density(adata, basis='umap', key='umap_density_phase')
Plot selected categories
>>> sc.pl.embedding_density(adata, basis='umap', key='umap_density_phase',
... group='G1')
>>> sc.pl.embedding_density(adata, basis='umap', key='umap_density_phase',
... group='S')
... group=['G1', 'S'])
"""
sanitize_anndata(adata)

Expand Down Expand Up @@ -726,51 +738,111 @@ def embedding_density(
components = adata.uns[key+'_params']['components']
groupby = adata.uns[key+'_params']['covariate']

# turn group into a list if needed
if group == 'all':
if groupby is None:
group = None
else:
group = list(adata.obs[groupby].cat.categories)
elif isinstance(group, str):
group = [group]

if (group is None) and (groupby is not None):
raise ValueError('Densities were calculated over an `.obs` covariate. '
'Please specify a group from this covariate to plot.')

if (group is not None) and (group not in adata.obs[groupby].cat.categories):
raise ValueError('Please specify a group from the `.obs` category over which the density '
'was calculated.')
if (group is not None) and (groupby is None):
logg.warn('value of \'group\' is ignored because densities were not calculated for an `.obs` covariate.')
group = None

if (np.min(adata.obs[key]) < 0) or (np.max(adata.obs[key]) > 1):
raise ValueError('Densities should be scaled between 0 and 1.')

# Define plotting data
dens_values = -np.ones(adata.n_obs)
dot_sizes = np.ones(adata.n_obs)*bg_dotsize

if group is not None:
group_mask = (adata.obs[groupby] == group)
dens_values[group_mask] = adata.obs[key][group_mask]
dot_sizes[group_mask] = np.ones(sum(group_mask))*fg_dotsize

else:
dens_values = adata.obs[key]
dot_sizes = np.ones(adata.n_obs)*fg_dotsize
if wspace is None:
# try to set a wspace that is not too large or too small given the
# current figure size
wspace = 0.75 / rcParams['figure.figsize'][0] + 0.02

# Make the color map
if isinstance(color_map, str):
cmap = cm.get_cmap(color_map)
else:
cmap = color_map

#norm = colors.Normalize(vmin=-1, vmax=1)
adata_vis = adata.copy()
adata_vis.obs['Density'] = dens_values

norm = colors.Normalize(vmin=vmin, vmax=vmax)
cmap.set_over('black')
cmap.set_under('lightgray')
# a name to store the density values is needed. To avoid
# overwriting a user name a new random name is created
while True:
density_col_name = '_tmp_embedding_density_column_{}_'.format(np.random.randint(1000, 10000))
if density_col_name not in adata.obs.columns:
break

# if group is set, then plot it using multiple panels (even if only one group is set)
if group is not None and isinstance(group, abc.Sequence):
from matplotlib import gridspec
# set up the figure
num_panels = len(group)
n_panels_x = min(ncols, num_panels)
n_panels_y = np.ceil(num_panels / n_panels_x).astype(int)
# each panel will have the size of rcParams['figure.figsize']
fig = pl.figure(figsize=(n_panels_x * rcParams['figure.figsize'][0] * (1 + wspace),
n_panels_y * rcParams['figure.figsize'][1]))
left = 0.2 / n_panels_x
bottom = 0.13 / n_panels_y
gs = gridspec.GridSpec(
nrows=n_panels_y, ncols=n_panels_x,
left=left, right=1 - (n_panels_x - 1) * left - 0.01 / n_panels_x,
bottom=bottom, top=1 - (n_panels_y - 1) * bottom - 0.1 / n_panels_y,
hspace=hspace, wspace=wspace,
)

axs = []
for count, group_name in enumerate(group):
if group_name not in adata.obs[groupby].cat.categories:
raise ValueError('Please specify a group from the `.obs` category over which the density '
'was calculated. Invalid group name: {}'.format(group_name))

ax = pl.subplot(gs[count])
# Define plotting data
dot_sizes = np.ones(adata.n_obs) * bg_dotsize
group_mask = (adata.obs[groupby] == group_name)
dens_values = -np.ones(adata.n_obs)
dens_values[group_mask] = adata.obs[key][group_mask]
adata.obs[density_col_name] = dens_values
dot_sizes[group_mask] = np.ones(sum(group_mask)) * fg_dotsize

if 'title' not in kwargs:
title = group_name
else:
title = kwargs.pop('title')
ax=plot_scatter(adata, basis, components=components, color=density_col_name,
color_map=cmap, norm=norm, size=dot_sizes, vmax=vmax,
vmin=vmin, save=False, title=title, ax=ax, show=False, **kwargs)
axs.append(ax)

# Ensure title is blank as default
if 'title' not in kwargs:
title=""
ax = axs
else:
title = kwargs.pop('title')
dens_values = adata.obs[key]
dot_sizes = np.ones(adata.n_obs)*fg_dotsize

adata.obs[density_col_name] = dens_values

# Ensure title is blank as default
if 'title' not in kwargs:
title = group if group is not None else ""
else:
title = kwargs.pop('title')

# Plot the graph
ax = plot_scatter(adata, basis, components=components, color=density_col_name,
color_map=cmap, norm=norm, size=dot_sizes, vmax=vmax,
vmin=vmin, save=False, show=False, title=title, **kwargs)

# remove temporary column name
adata.obs = adata.obs.drop(columns=[density_col_name])

# Plot the graph
return plot_scatter(adata_vis, basis, components=components, color='Density',
color_map=cmap, norm=norm, size=dot_sizes, vmax=vmax,
vmin=vmin, save=save, title=title, **kwargs)
utils.savefig_or_show(key + "_", show=show, save=save)
if show is False:
return ax

0 comments on commit 1a14cf1

Please sign in to comment.