-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/theislab/ehrapy
- Loading branch information
Showing
11 changed files
with
337 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import TYPE_CHECKING | ||
|
||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
import seaborn as sns | ||
from anndata import AnnData | ||
from matplotlib.axes import Axes | ||
|
||
|
||
def rank_features_supervised( | ||
adata: AnnData, | ||
key: str = "feature_importances", | ||
n_features: int = 10, | ||
ax: Axes | None = None, | ||
show: bool = True, | ||
save: str | None = None, | ||
**kwargs, | ||
) -> Axes | None: | ||
"""Plot features with greates absolute importances as a barplot. | ||
Args: | ||
adata: :class:`~anndata.AnnData` object storing the data. A key in adata.var should contain the feature | ||
importances, calculated beforehand. | ||
key: The key in adata.var to use for feature importances. Defaults to 'feature_importances'. | ||
n_features: The number of features to plot. Defaults to 10. | ||
ax: A matplotlib axes object to plot on. If `None`, a new figure will be created. Defaults to `None`. | ||
show: If `True`, show the figure. If `False`, return the axes object. Defaults to `True`. | ||
save: Path to save the figure. If `None`, the figure will not be saved. Defaults to `None`. | ||
**kwargs: Additional arguments passed to `seaborn.barplot`. | ||
Returns: | ||
If `show == False` a `matplotlib.axes.Axes` object, else `None`. | ||
Examples: | ||
>>> import ehrapy as ep | ||
>>> adata = ep.dt.mimic_2(encoded=False) | ||
>>> ep.pp.knn_impute(adata, n_neighbours=5) | ||
>>> input_features = [ | ||
... feat for feat in adata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"} | ||
... ] | ||
>>> ep.tl.rank_features_supervised(adata, "tco2_first", "continuous", "rf", input_features=input_features) | ||
>>> ep.pl.rank_features_supervised(adata) | ||
.. image:: /_static/docstring_previews/feature_importances.png | ||
""" | ||
if key not in adata.var.keys(): | ||
raise ValueError( | ||
f"Key {key} not found in adata.var. Make sure to calculate feature importances first with ep.tl.feature_importances." | ||
) | ||
|
||
df = pd.DataFrame({"importance": adata.var[key]}, index=adata.var_names) | ||
df["absolute_importance"] = df["importance"].abs() | ||
df = df.sort_values("absolute_importance", ascending=False) | ||
|
||
if ax is None: | ||
fig, ax = plt.subplots() | ||
ax = sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h", ax=ax, **kwargs) | ||
plt.ylabel("Feature") | ||
plt.xlabel("Importance") | ||
plt.tight_layout() | ||
|
||
if save: | ||
plt.savefig(save, bbox_inches="tight") | ||
if show: | ||
plt.show() | ||
return None | ||
else: | ||
return ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.