Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/theislab/ehrapy
Browse files Browse the repository at this point in the history
  • Loading branch information
Zethson committed Apr 11, 2024
2 parents d8780d7 + 9ccc36f commit 26ea3ed
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 28 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ repos:
hooks:
- id: prettier
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.4
rev: v0.3.5
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: detect-private-key
- id: check-ast
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 4 additions & 2 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret
tools.paga
```

### Group comparison
### Feature Ranking

```{eval-rst}
.. autosummary::
Expand All @@ -205,6 +205,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret
tools.rank_features_groups
tools.filter_rank_features_groups
tools.rank_features_supervised
```

### Dataset integration
Expand Down Expand Up @@ -358,7 +359,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'.
plot.paga_compare
```

### Group comparison
### Feature Ranking

```{eval-rst}
.. autosummary::
Expand All @@ -372,6 +373,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'.
plot.rank_features_groups_dotplot
plot.rank_features_groups_matrixplot
plot.rank_features_groups_tracksplot
plot.rank_features_supervised
```

### Survival Analysis
Expand Down
1 change: 1 addition & 0 deletions ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from ehrapy.plot._survival_analysis import kmf, ols
from ehrapy.plot._util import * # noqa: F403
from ehrapy.plot.causal_inference._dowhy import causal_effect
from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised
Empty file.
68 changes: 68 additions & 0 deletions ehrapy/plot/feature_ranking/_feature_importances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from anndata import AnnData
from matplotlib.axes import Axes


def rank_features_supervised(
adata: AnnData,
key: str = "feature_importances",
n_features: int = 10,
ax: Axes | None = None,
show: bool = True,
save: str | None = None,
**kwargs,
) -> Axes | None:
"""Plot features with greates absolute importances as a barplot.
Args:
adata: :class:`~anndata.AnnData` object storing the data. A key in adata.var should contain the feature
importances, calculated beforehand.
key: The key in adata.var to use for feature importances. Defaults to 'feature_importances'.
n_features: The number of features to plot. Defaults to 10.
ax: A matplotlib axes object to plot on. If `None`, a new figure will be created. Defaults to `None`.
show: If `True`, show the figure. If `False`, return the axes object. Defaults to `True`.
save: Path to save the figure. If `None`, the figure will not be saved. Defaults to `None`.
**kwargs: Additional arguments passed to `seaborn.barplot`.
Returns:
If `show == False` a `matplotlib.axes.Axes` object, else `None`.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> ep.pp.knn_impute(adata, n_neighbours=5)
>>> input_features = [
... feat for feat in adata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"}
... ]
>>> ep.tl.rank_features_supervised(adata, "tco2_first", "continuous", "rf", input_features=input_features)
>>> ep.pl.rank_features_supervised(adata)
.. image:: /_static/docstring_previews/feature_importances.png
"""
if key not in adata.var.keys():
raise ValueError(
f"Key {key} not found in adata.var. Make sure to calculate feature importances first with ep.tl.feature_importances."
)

df = pd.DataFrame({"importance": adata.var[key]}, index=adata.var_names)
df["absolute_importance"] = df["importance"].abs()
df = df.sort_values("absolute_importance", ascending=False)

if ax is None:
fig, ax = plt.subplots()
ax = sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h", ax=ax, **kwargs)
plt.ylabel("Feature")
plt.xlabel("Importance")
plt.tight_layout()

if save:
plt.savefig(save, bbox_inches="tight")
if show:
plt.show()
return None
else:
return ax
18 changes: 9 additions & 9 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def explicit_impute(
replacement: The value to replace missing values with. If a dictionary is provided, the keys represent column
names and the values represent replacement values for those columns.
impute_empty_strings: If True, empty strings are also replaced. Defaults to True.
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 30.
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 70.
copy: If True, returns a modified copy of the original AnnData object. If False, modifies the object in place.
Returns:
Expand Down Expand Up @@ -133,7 +133,7 @@ def simple_impute(
adata: The annotated data matrix to impute missing values on.
var_names: A list of column names to apply imputation on (if None, impute all columns).
strategy: Imputation strategy to use. One of {'mean', 'median', 'most_frequent'}.
warning_threshold: Display a warning message if percentage of missing values exceeds this threshold. Defaults to 30.
warning_threshold: Display a warning message if percentage of missing values exceeds this threshold. Defaults to 70.
copy:Whether to return a copy of `adata` or modify it inplace. Defaults to False.
Returns:
Expand Down Expand Up @@ -214,7 +214,7 @@ def knn_impute(
n_neighbours: Number of neighbors to use when performing the imputation. Defaults to 5.
copy: Whether to perform the imputation on a copy of the original `AnnData` object.
If `True`, the original object remains unmodified. Defaults to `False`.
warning_threshold: Percentage of missing values above which a warning is issued. Defaults to 30.
warning_threshold: Percentage of missing values above which a warning is issued. Defaults to 70.
Returns:
An updated AnnData object with imputed values.
Expand Down Expand Up @@ -328,7 +328,7 @@ def miss_forest_impute(
n_estimators: The number of trees to fit for every missing variable. Has a big effect on the run time.
Decrease for faster computations. Defaults to 100.
random_state: The random seed for the initialization. Defaults to 0.
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 30 .
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 70 .
copy: Whether to return a copy or act in place. Defaults to False.
Returns:
Expand Down Expand Up @@ -464,7 +464,7 @@ def soft_impute(
adata: The AnnData object to impute missing values for.
var_names: A list of var names indicating which columns to impute (if None -> all columns).
copy: Whether to return a copy or act in place.
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 30 .
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 70 .
shrinkage_value : Value by which we shrink singular values on each iteration.
If omitted then the default value will be the maximum singular value of the initialized matrix (zeros for missing values) divided by 50.
convergence_threshold : Minimum ration difference between iterations (as a fraction of the Frobenius norm of the current solution) before stopping.
Expand Down Expand Up @@ -609,7 +609,7 @@ def iterative_svd_impute(
var_names: A list of var names indicating which columns to impute. If `None`, all columns will be imputed.
Defaults to None.
copy: Whether to return a copy of the AnnData object or act in place. Defaults to False.
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 30.
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 70.
rank: Rank of the SVD decomposition. Defaults to 10.
convergence_threshold: Convergence threshold for the iterative algorithm.
The algorithm stops when the relative difference in
Expand Down Expand Up @@ -754,7 +754,7 @@ def matrix_factorization_impute(
Args:
adata: The AnnData object to use MatrixFactorization on.
var_names: A list of var names indicating which columns to impute (if None -> all columns).
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 30 .
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 70 .
rank: Number of latent factors to use in the matrix factorization model.
It determines the size of the latent feature space that will be used to estimate the missing values.
A higher rank will allow for more complex relationships between the features, but it can also lead to overfitting.
Expand Down Expand Up @@ -890,7 +890,7 @@ def nuclear_norm_minimization_impute(
Args:
adata: The AnnData object to apply NuclearNormMinimization on.
var_names: Var names indicating which columns to impute (if None -> all columns).
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 30.
warning_threshold: Threshold of percentage of missing values to display a warning for. Defaults to 70.
require_symmetric_solution: Whether to add a symmetry constraint to the convex problem. Defaults to False.
min_value: Smallest possible imputed value. Defaults to None (no minimum value constraint).
max_value: Largest possible imputed value. Defaults to None (no maximum value constraint).
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def mice_forest_impute(
adata: The AnnData object containing the data to impute.
var_names: A list of variable names to impute. If None, impute all variables.
warning_threshold: Threshold of percentage of missing values to display a warning for.
Defaults to 30.
Defaults to 70.
save_all_iterations: Whether to save all imputed values from all iterations or just the latest.
Saving all iterations allows for additional plotting, but may take more memory. Defaults to True.
random_state: The random state ensures script reproducibility.
Expand Down
27 changes: 12 additions & 15 deletions ehrapy/preprocessing/_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,24 @@ def winsorize(
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.pp.winsorize(adata, ["bmi"])
"""
_validate_outlier_input(adata, obs_cols, vars)

if copy: # pragma: no cover
adata = adata.copy()

if limits is None:
limits = [0.01, 0.99]
obs_cols_set, vars_set = _validate_outlier_input(adata, obs_cols, vars)

if vars:
for var in vars:
adata[:, var].X = scipy.stats.mstats.winsorize(
np.array(adata[:, var].X), limits=limits, nan_policy="omit", **kwargs
)
if vars_set:
for var in vars_set:
data_array = np.array(adata[:, var].X, dtype=float)
winsorized_data = scipy.stats.mstats.winsorize(data_array, limits=limits, nan_policy="omit", **kwargs)
adata[:, var].X = winsorized_data

if obs_cols:
for col in obs_cols:
winsorized_array = scipy.stats.mstats.winsorize(adata.obs[col], limits=limits, nan_policy="omit", **kwargs)
adata.obs[col] = pd.Series(winsorized_array).values
if obs_cols_set:
for col in obs_cols_set:
obs_array = adata.obs[col].to_numpy(dtype=float)
winsorized_obs = scipy.stats.mstats.winsorize(obs_array, limits=limits, nan_policy="omit", **kwargs)
adata.obs[col] = pd.Series(winsorized_obs).values

if copy:
return adata
return adata if copy else None


def clip_quantile(
Expand Down
1 change: 1 addition & 0 deletions ehrapy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ehrapy.tools._scanpy_tl_api import * # noqa: F403
from ehrapy.tools.causal._dowhy import causal_inference
from ehrapy.tools.cohort_tracking._cohort_tracker import CohortTracker
from ehrapy.tools.feature_ranking._feature_importances import rank_features_supervised
from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups

try: # pragma: no cover
Expand Down
Loading

0 comments on commit 26ea3ed

Please sign in to comment.