Skip to content

Commit

Permalink
generate_axis function added
Browse files Browse the repository at this point in the history
  • Loading branch information
jwrth committed Mar 1, 2023
1 parent d0bb6ec commit eb41957
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xdbit_funcs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
from . import images as im
from . import utils as utils
from . import datasets as data


1 change: 1 addition & 0 deletions xdbit_funcs/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .adata import *
from .preprocessing import *
from .external import *
from .generate_axis import *
152 changes: 152 additions & 0 deletions xdbit_funcs/tools/generate_axis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from scipy.interpolate import splprep, splev, UnivariateSpline
from sklearn.decomposition import PCA
from sklearn.preprocessing import minmax_scale
from anndata import AnnData
import matplotlib.pyplot as plt
from ..tools import get_nrows_maxcols, standard_preprocessing, extract_groups
from ..readwrite import save_and_show_figure
import gc
from tqdm import tqdm
import numpy as np
import pandas as pd

gene_to_sort = "Cyp2e1"
groupby = 'id'
umap_key = 'X_umap'
max_cols = 6
k_spline = 3
s = 5000
#s = None

def generate_axis(
adata: AnnData,
gene_to_sort: str,
groupby: str,
umap_key: str = 'X_umap',
obs_key: str = 'umap_spline',
max_cols: int = 6,
k_spline: int = 3,
s: int = 5000,
plot: bool = True,
savepath: bool = None,
save_only: bool = False,
**kwargs
):

groups = adata.obs[groupby].unique()
# check whether to calculate the UMAP per group or use existing UMAP calculation saved in adata.obsm[umap_key]
if umap_key is None:
# calculate UMAP separately per group
umap_per_group = True

obsm_umap = {}
for group in tqdm(groups):
#print("Processing group {}...".format(group))
# do preprocessing for one group
ad = extract_groups(adata, groupby=groupby, groups=group)
ad.uns['log1p']['base'] = None
adpp = standard_preprocessing(ad,
hvg_n_top_genes=2000,
do_lognorm=False,
dim_reduction=True,
umap=True, tsne=False,
verbose=False)
# collect results
obsm_umap[group] = adpp.obsm['X_umap']

# free RAM
del adpp
gc.collect()
else:
assert umap_key in adata.obsm, "`umap_key` not in `adata.obsm`."
umap_per_group = False
obsm_umap = None

# decide whether to generate a plot or not
if plot:
n_plots, nrows, ncols = get_nrows_maxcols(groups, max_cols=max_cols)
fig, axs = plt.subplots(nrows, ncols, figsize=(8*ncols, 6*nrows))

if len(axs.shape) > 1:
axs = axs.ravel()

# start processing
intpol = {}
for i, group in enumerate(groups):
# generate selection mask
mask = adata.obs[groupby] == group

# extract points, indices and expression of gene
if umap_per_group:
pts = obsm_umap[group]
else:
pts = adata.obsm[umap_key][mask]

idxs = adata.obs_names[mask]
expr = adata.X[:, adata.var_names.get_loc(gene_to_sort)][mask]

# perform PCA to rotate datapoints and always have the longest data axis as x
pca = PCA(n_components=2)
pcs = pca.fit_transform(pts)

# sort points and indices by x values
sortmask = pcs[:, 0].argsort()
pcs = pcs[sortmask]
idxs = idxs[sortmask]
expr = expr[sortmask]

# extact coordinates
x = pcs[:, 0]
y = pcs[:, 1]

# adjust x values according to expression of selected gene
slope = np.polyfit(x, expr, 1)[0]
slope = -1 if slope < 0 else 1
x *= slope

# sort again by x
sortmask = np.argsort(x)
y = y[sortmask]
expr = expr[sortmask]
idxs = idxs[sortmask]
x.sort()

# calculate spline
spl = UnivariateSpline(x, y, k=k_spline, s=s)
ys = spl(x)

# create array of spline points
Xs = np.array([x, ys]).T

# collect results
intpol[group] = pd.DataFrame(Xs, index=idxs, columns=['x', 'y'])

if plot:
axs[i].scatter(x, y, c=expr)
axs[i].plot(x, ys, 'r', lw=3)
axs[i].set_title(group)

if plot:
save_and_show_figure(savepath=savepath, fig=fig, save_only=save_only, **kwargs)

x_new = {}
for mid, Xs in intpol.items():
# calculate distance of consecutive points
d = np.diff(Xs.values, axis=0)
segdists = np.hypot(d[:,0], d[:,1])

# calculate cumulative sum of distances and add 0 at beginning
cumsum = np.insert(np.cumsum(segdists), 0, 0)

# min max scaling
cumsum = minmax_scale(cumsum)

x_new[mid] = pd.Series(cumsum, index=Xs.index)

# concatenate results and reshape index
x_new = pd.concat(x_new)

x_new.index = x_new.index.droplevel(0)

# add data to obs
adata.obs[obs_key] = x_new

0 comments on commit eb41957

Please sign in to comment.