Skip to content

Commit

Permalink
implement plot_trajectory
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Nov 7, 2023
1 parent bbe9fb7 commit 4a04527
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ Release Notes
* require `python>=3.8`
* implement CI for testing
* fixes for edge cases discoverd through extended testing
* implement `plot_trajectory` function to show trajectory on the umap

### Version 1.3.1
* implemented `palantir.plot.plot_stats` to plot arbitray cell-wise statistics as x-/y-positions.
Expand Down
180 changes: 179 additions & 1 deletion src/palantir/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.preprocessing import StandardScaler
from scipy.stats import gaussian_kde
import scanpy as sc
import mellon

import matplotlib
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -290,7 +291,6 @@ def plot_tsne_by_cell_sizes(data, tsne, fig=None, ax=None, vmin=None, vmax=None)
return fig, ax



def plot_gene_expression(
data: pd.DataFrame,
tsne: pd.DataFrame,
Expand Down Expand Up @@ -1873,3 +1873,181 @@ def gene_score_histogram(
)

return fig


def plot_trajectory(
ad: sc.AnnData,
branch: str,
ax: Optional[plt.Axes] = None,
pseudo_time_key: str = "palantir_pseudotime",
masks_key: str = "branch_masks",
embedding_basis: str = "X_umap",
cell_color: str = "branch_selection",
smoothness: float = 1.0,
n_arrows: int = 5,
arrowprops: Optional[dict] = dict(),
scanpy_kwargs: Optional[dict] = dict(),
figsize: Tuple[float, float] = (5, 5),
**kwargs,
):
"""
Plot a trajectory on the UMAP embedding of single-cell data.
Parameters
----------
ad : sc.AnnData
Annotated data matrix. Pseudotime and fate probabilities should be stored under provided keys.
branch : str
Branch/fate to plot the trajectory for.
ax : matplotlib.axes.Axes, optional
Matplotlib axes object to plot on. If None, a new figure is created.
pseudo_time_key : str, optional
Key for pseudotime data in `ad.obs`. Defaults to 'palantir_pseudotime'.
masks_key : str, optional
Key for branch cell selection masks in `ad.obsm`. Defaults to 'branch_masks'.
embedding_basis : str, optional
Key for UMAP embedding in `ad.obsm`. Defaults to 'X_umap'.
cell_color : str or None, optional
Coloring strategy for UMAP plot. 'branch_selection' highlights cells in the branch.
If None, no coloring is applied. Defaults to 'branch_selection'.
smoothness : float, optional
Smoothness of fitted trajectory. Higher value means smoother. Defaults to 1.
n_arrows : int, optional
Number of arrows to plot. Defaults to 5.
arrowprops : dict, optional
Properties for the arrowstyle. If None, defaults to black arrow with lw=1.
scanpy_kwargs : dict, optional
Keyword arguments for the scanpy.pl.emebdding function to plot the cells.
figsize : Tuple[float, float], optional
Size of the plot in inches, as (width, height). Defaults to (5, 5).
**kwargs
Extra keyword arguments are passed to the plot function for trajectory lines.
Returns
-------
matplotlib.axes.Axes
The axes object with the trajectory plot.
"""
if pseudo_time_key not in ad.obs:
raise KeyError(f"{pseudo_time_key} not found in ad.obs")

fate_mask, fate_mask_names = _validate_obsm_key(ad, masks_key)

if embedding_basis not in ad.obsm:
raise KeyError(f"{embedding_basis} not found in ad.obsm")

if branch not in fate_mask_names:
raise (
f"Specified branch name '{branch}' is not in the available set of "
+ ", ".join(fate_mask_names)
)

default_kwargs = {"color": "black"}
default_kwargs.update(kwargs)

pt = ad.obs[pseudo_time_key]
umap = ad.obsm[embedding_basis]

if ax is None:
_, ax = plt.subplots(1, 1, figsize=figsize)

mask = fate_mask[branch].astype(bool)

pseudotime = pt[mask]
pseudotime_grid = np.linspace(np.min(pseudotime), np.max(pseudotime), 200)
ls = (
smoothness
* np.sqrt(np.sum((np.max(umap, axis=0) - np.min(umap, axis=0)) ** 2))
/ 20
)
umap_est = mellon.FunctionEstimator(ls=ls, sigma=ls, n_landmarks=50)
umap_trajectory = umap_est.fit_predict(pseudotime, umap[mask, :], pseudotime_grid)

# plot UMAP
if cell_color == "branch_selection":
ax.scatter(
umap[~mask, 0],
umap[~mask, 1],
c=config.DESELECTED_COLOR,
label="Other Cells",
)
ax.scatter(
umap[mask, 0],
umap[mask, 1],
c=config.SELECTED_COLOR,
label="Selected Cells",
)
elif cell_color is not None:
b = embedding_basis[2:] if embedding_basis.startswith("X_") else embedding_basis
sc.pl.embedding(ad, b, color=cell_color, ax=ax, show=False, **scanpy_kwargs)

# plot trajectory
_plot_arrows(
umap_trajectory[:, 0],
umap_trajectory[:, 1],
n=n_arrows,
ax=ax,
arrowprops=arrowprops,
**default_kwargs,
)

ax.set_xticks([])
ax.set_yticks([])
ax.set_title(f"Branch: {branch}")
ax.axis("off")

return ax


def _plot_arrows(x, y, n=5, ax=None, arrowprops=dict(), **kwargs):
"""
Helper function to plot arrows on a trajectory line.
Parameters
----------
x, y : array-like
Coordinates of the trajectory points.
n : int, optional
Number of arrows to plot. Defaults to 5.
ax : matplotlib.axes.Axes, optional
Matplotlib axes object to plot on. If None, a new figure is created.
arrowprops : dict, optional
Properties for the arrowstyle. If None, defaults to black arrow with lw=1.
**kwargs
Extra keyword arguments are passed to the plot function.
Returns
-------
None
"""
if ax is None:
fig, ax = plt.subplots()

default_kwargs = {"color": "black"}
default_kwargs.update(kwargs)

ax.plot(x, y, **default_kwargs)

if n <= 0:
return ax

default_arrowprops = dict(arrowstyle="->", lw=1)
default_arrowprops["color"] = default_kwargs.get("color", "black")
default_arrowprops.update(arrowprops)

# Calculate the length of each subsection
total_points = len(x)
section_length = total_points // n

for i in range(n):
idx = total_points - i * section_length
if idx < 2:
break
# Add arrowhead at the last point of the subsection on ax
ax.annotate(
"",
xy=(x[idx - 1], y[idx - 1]),
xytext=(x[idx - 2], y[idx - 2]),
arrowprops=default_arrowprops,
)
return ax

0 comments on commit 4a04527

Please sign in to comment.