Commit

Merge branch 'combat' of https://github.com/Marius1311/scanpy into Marius1311-combat

falexwolf committed Jan 6, 2019
2 parents a3958f8 + 8668f66 commit f71b73e
Showing 5 changed files with 336 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
@@ -41,6 +41,9 @@ scripts/paths.py
scripts/test_private.sh
scripts/tgdyn.py
scripts/tgdyn_simple.py
.spyproject/
scanpy/.spyproject/



# always-ignore extensions
14 changes: 13 additions & 1 deletion docs/references.rst
@@ -79,6 +79,10 @@ References
*ForceAtlas2, a Continuous Graph Layout Algorithm for Handy Network Visualization Designed for the Gephi Software*
`PLOS One <https://doi.org/10.1371/journal.pone.0098679>`__.
.. [Johnson07] Johnson, Li & Rabinovic (2007),
*Adjusting batch effects in microarray expression data using empirical Bayes methods*,
`Biostatistics <https://doi.org/10.1093/biostatistics/kxj037>`__.
.. [Kang18] Kang *et al.* (2018),
*Python Implementation of MNN correct*,
`GitHub <https://github.com/chriscainx/mnnpy>`__.
@@ -94,6 +98,10 @@ References
.. [Lambiotte09] Lambiotte *et al.* (2009)
*Laplacian Dynamics and Multiscale Modular Structure in Networks*
`arXiv <https://arxiv.org/abs/0812.1770>`__.
.. [Leek12] Leek *et al.* (2012),
*sva: Surrogate Variable Analysis. R package*
`Bioconductor <https://doi.org/10.18129/B9.bioc.sva>`__.
.. [Levine15] Levine *et al.* (2015),
*Data-Driven Phenotypic Dissection of AML Reveals Progenitor--like Cells that Correlate with Prognosis*,
@@ -139,7 +147,11 @@ References
.. [Park18] Park *et al.* (2018),
*Fast Batch Alignment of Single Cell Transcriptomes Unifies Multiple Mouse Cell Atlases into an Integrated Landscape*
`bioRxiv <https://doi.org/10.1101/397042>`__.
.. [Pedersen12] Pedersen (2012),
*Python implementation of ComBat*
`GitHub <https://github.com/brentp/combat.py>`__.
.. [Pedregosa11] Pedregosa *et al.* (2011),
*Scikit-learn: Machine Learning in Python*,
`JMLR <http://www.jmlr.org/papers/v12/pedregosa11a.html>`__.
2 changes: 2 additions & 0 deletions scanpy/api/pp.py
@@ -9,3 +9,5 @@
from ..preprocessing._dca import dca
from ..preprocessing._magic import magic
from ..neighbors import neighbors
from ..preprocessing.combat import combat

273 changes: 273 additions & 0 deletions scanpy/preprocessing/combat.py
@@ -0,0 +1,273 @@
import numpy as np
from scipy.sparse import issparse
import pandas as pd
import sys
from numpy import linalg as la
import patsy
import numba

def design_mat(model, batch_levels):
"""
Computes a simple design matrix.
At the moment, only the categorical annotation passed with the 'key' argument
to the combat function is included.
Parameters
--------
model : pd.DataFrame
Contains the batch annotation
batch_levels : list
Levels of the batch annotation
Returns
--------
design : pd.DataFrame
The design matrix for the regression problem
"""

design = patsy.dmatrix("~ 0 + C(batch, levels={})".format(batch_levels),
model, return_type="dataframe")
model = model.drop(["batch"], axis=1)
sys.stderr.write("found %i batches\n" % design.shape[1])
other_cols = list(model.columns)
factor_matrix = model[other_cols]
design = pd.concat((design, factor_matrix), axis=1)
sys.stderr.write("found %i categorical variables:" % len(other_cols))
sys.stderr.write("\t" + ", ".join(other_cols) + '\n')

return design
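
# A minimal illustrative sketch (not part of the original commit): what design_mat
# produces for a toy two-batch annotation. The helper name and the toy values are
# made up for this example.
def _example_design_mat():
    model = pd.DataFrame({'batch': ['a', 'a', 'b', 'b']})
    design = design_mat(model, batch_levels=['a', 'b'])
    # one indicator column per batch, i.e. a one-hot encoding of batch membership
    assert design.shape == (4, 2)
    assert (design.sum(axis=1) == 1).all()
    return design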


def stand_data(model, data):
"""
Standardizes the data per gene.
The aim here is to make mean and variance comparable across batches.
Parameters
--------
model : pd.DataFrame
Contains the batch annotation
data : pd.DataFrame
Contains the data
Returns
--------
s_data : pd.DataFrame
Standardized Data
design : pd.DataFrame
Batch assignment as one-hot encodings
var_pooled : np.array
Pooled variance per gene
stand_mean : np.array
Gene-wise mean
"""

# compute the design matrix
batch_items = model.groupby("batch").groups.items()
batch_levels = [k for k, v in batch_items]
batch_info = [v for k, v in batch_items]
n_batch = len(batch_info)
n_batches = np.array([len(v) for v in batch_info])
n_array = float(sum(n_batches))
drop_cols = [cname for cname, inter in ((model == 1).all()).iteritems() if inter]
model = model[[c for c in model.columns if c not in drop_cols]]
design = design_mat(model, batch_levels)

# compute pooled variance estimator
B_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
grand_mean = np.dot((n_batches / n_array).T, B_hat[:n_batch,:])
var_pooled = (data - np.dot(design, B_hat).T)**2
var_pooled = np.dot(var_pooled, np.ones((int(n_array), 1)) / int(n_array))

# Compute the means
if np.sum(var_pooled == 0) > 0:
print('Found {} genes with zero variance.'.\
format(np.sum(var_pooled == 0)))
stand_mean = np.dot(grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, int(n_array))))
tmp = np.array(design.copy())
tmp[:, :n_batch] = 0
stand_mean += np.dot(tmp, B_hat).T

# need to be a bit careful with the zero variance genes
# just set the zero variance genes to zero in the standardized data
s_data = np.where(var_pooled == 0, 0,
((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, int(n_array))))))
s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)

return s_data, design, var_pooled, stand_mean
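
# A minimal illustrative sketch (not part of the original commit): after
# standardization, s_data = (data - stand_mean) / sqrt(var_pooled), so every
# gene has (approximately) zero mean. The helper name and the random toy data
# are made up for this example.
def _example_stand_data():
    rng = np.random.RandomState(0)
    data = pd.DataFrame(rng.normal(size=(5, 20)))  # 5 genes x 20 cells
    model = pd.DataFrame({'batch': ['a'] * 10 + ['b'] * 10})
    s_data, design, var_pooled, stand_mean = stand_data(model, data)
    assert np.allclose(s_data.mean(axis=1), 0)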


def combat(adata, key='batch', inplace=True):
"""
ComBat function for batch effect correction [Johnson07]_ [Leek12]_.
Corrects for batch effects by fitting linear models and gains statistical power
via an empirical Bayes (EB) framework in which information is borrowed across genes. This uses the
implementation of `ComBat <https://github.com/brentp/combat.py>`__ [Pedersen12]_.
Parameters
----------
adata : :class:`~anndata.AnnData`
Annotated data matrix
key: `str`, optional (default: `"batch"`)
Key to a categorical annotation from adata.obs that will be used for batch effect removal
inplace: bool, optional (default: `True`)
Whether to replace adata.X with the corrected data or to return it
Returns
-------
Depending on the value of `inplace`, either replaces `adata.X` with the corrected
data or returns the corrected data matrix.
"""

# check the input
if key not in adata.obs.keys():
raise ValueError('Could not find the key \'{}\' in adata.obs'.format(key))

# only works on dense matrices so far
if issparse(adata.X):
X = adata.X.A.T
else:
X = adata.X.T
data = pd.DataFrame(data=X, index=adata.var_names,
columns=adata.obs_names)

# construct a pandas series of the batch annotation
batch = pd.Series(adata.obs[key])
model = pd.DataFrame({'batch': batch})
batch_items = model.groupby("batch").groups.items()
batch_info = [v for k, v in batch_items]
n_batch = len(batch_info)
n_batches = np.array([len(v) for v in batch_info])
n_array = float(sum(n_batches))

# standardize across genes using a pooled variance estimator
sys.stderr.write("Standardizing Data across genes.\n")
s_data, design, var_pooled, stand_mean = stand_data(model, data)

# fitting the parameters on the standardized data
sys.stderr.write("Fitting L/S model and finding priors\n")
batch_design = design[design.columns[:n_batch]]
# first estimate of the additive batch effect
gamma_hat = np.dot(np.dot(la.inv(np.dot(batch_design.T, batch_design)), batch_design.T), s_data.T)
delta_hat = []

# first estimate for the multiplicative batch effect
for i, batch_idxs in enumerate(batch_info):
delta_hat.append(s_data[batch_idxs].var(axis=1))

# empirically fix the prior hyperparameters
gamma_bar = gamma_hat.mean(axis=1)
t2 = gamma_hat.var(axis=1)
# a_prior and b_prior are the priors on lambda and theta from Johnson et al. (2007)
a_prior = list(map(aprior, delta_hat))
b_prior = list(map(bprior, delta_hat))

sys.stderr.write("Finding parametric adjustments\n")
# gamma star and delta star will be our empirical Bayes (EB) estimators
# for the additive and multiplicative batch effect per batch and gene
gamma_star, delta_star = [], []
for i, batch_idxs in enumerate(batch_info):
# temp stores our estimates for the batch effect parameters.
# temp[0] is the additive batch effect
# temp[1] is the multiplicative batch effect
temp = _it_sol(s_data[batch_idxs].values, gamma_hat[i],
delta_hat[i].values, gamma_bar[i], t2[i], a_prior[i], b_prior[i])

gamma_star.append(temp[0])
delta_star.append(temp[1])

sys.stderr.write("Adjusting data\n")
bayesdata = s_data
gamma_star = np.array(gamma_star)
delta_star = np.array(delta_star)

# we now apply the parametric adjustment to the standardized data from above
# loop over all batches in the data
for j, batch_idxs in enumerate(batch_info):

# we subtract the additive batch effect and rescale by the multiplicative
# batch effect; the pooled standard deviation and the overall gene-wise
# mean are added back after this loop
dsq = np.sqrt(delta_star[j,:])
dsq = dsq.reshape((len(dsq), 1))
denom = np.dot(dsq, np.ones((1, n_batches[j])))
numer = np.array(bayesdata[batch_idxs] - np.dot(batch_design.loc[batch_idxs], gamma_star).T)
bayesdata[batch_idxs] = numer / denom

vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean

# put back into the adata object or return
if inplace:
adata.X = bayesdata.values.transpose()
else:
return bayesdata.values.transpose()

@numba.jit
def _it_sol(s_data, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001):
"""
Iteratively compute the conditional posterior means for gamma and delta.
gamma is an estimator for the additive batch effect, delta is an estimator
for the multiplicative batch effect. We use an EB framework to estimate these
two. Analytical expressions exist for both parameters, which however depend on
each other. We therefore iteratively evaluate these two expressions until
convergence is reached.
Parameters
--------
s_data : np.array
    Contains the standardized data
g_hat : np.array
    Initial guess for gamma
d_hat : np.array
    Initial guess for delta
g_bar, t2, a, b : float
    Hyperparameters
conv: float, optional (default: `0.0001`)
    Convergence criterion
Returns
--------
adjust: tuple
    Contains the estimated values for gamma and delta
"""

n = (1 - np.isnan(s_data)).sum(axis=1)
g_old = g_hat.copy()
d_old = d_hat.copy()

change = 1
count = 0

# we place a normally distributed prior on gamma and an inverse gamma prior on delta
# in the loop, gamma and delta are updated alternately, as they depend on each other; we iterate until convergence
while change > conv:
# conditional posterior mean of gamma, given the current estimate of delta
g_new = (t2*n*g_hat + d_old*g_bar) / (t2*n + d_old)
sum2 = s_data - g_new.reshape((g_new.shape[0], 1)) @ np.ones((1, s_data.shape[1]))
sum2 = sum2 ** 2
sum2 = sum2.sum(axis = 1)
# conditional posterior mean of delta, given the current estimate of gamma
d_new = (0.5*sum2 + b) / (n/2.0 + a-1.0)

change = max((abs(g_new - g_old) / g_old).max(), (abs(d_new - d_old) / d_old).max())
g_old = g_new #.copy()
d_old = d_new #.copy()
count = count + 1

adjust = (g_new, d_new)
return adjust


def aprior(delta_hat):
# moment-matched estimate of the shape hyperparameter of the inverse gamma prior on delta
m = delta_hat.mean()
s2 = delta_hat.var()
return (2*s2 + m**2) / s2


def bprior(delta_hat):
# moment-matched estimate of the scale hyperparameter of the inverse gamma prior on delta
m = delta_hat.mean()
s2 = delta_hat.var()
return (m*s2 + m**3) / s2
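
As a quick usage illustration for the new preprocessing function, the following is a minimal sketch (not part of the commit), assuming a scanpy installation that includes this change; it uses the simulated `blobs` dataset and its 'blobs' annotation as the batch key, mirroring the tests below.

import scanpy.api as sc

# simulated dataset carrying a categorical annotation 'blobs' in adata.obs
adata = sc.datasets.blobs()

# return the corrected matrix without touching adata.X
corrected = sc.pp.combat(adata, key='blobs', inplace=False)

# or correct adata.X in place (the default behaviour)
sc.pp.combat(adata, key='blobs')
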
45 changes: 45 additions & 0 deletions scanpy/tests/test_combat.py
@@ -0,0 +1,45 @@
import numpy as np
import pandas as pd
import scanpy.api as sc
from sklearn.metrics import silhouette_score
from scanpy.preprocessing.combat import stand_data


def test_norm():
# this test trivially checks whether mean normalisation worked

# load in data
adata = sc.datasets.blobs()
key = 'blobs'
data = pd.DataFrame(data=adata.X.T, index=adata.var_names,
columns=adata.obs_names)

# construct a pandas series of the batch annotation
batch = pd.Series(adata.obs[key])
model = pd.DataFrame({'batch': batch})

# standardize the data
s_data, design, var_pooled, stand_mean = stand_data(model, data)

assert np.allclose(s_data.mean(axis = 1), np.zeros(s_data.shape[0]))


def test_silhouette():
# this test checks whether combat can align data from several Gaussians
# it checks this by computing the silhouette coefficient in a PCA embedding

# load in data
adata = sc.datasets.blobs()

# apply combat
sc.pp.combat(adata, 'blobs')

# compute pca
sc.tl.pca(adata)
X_pca = adata.obsm['X_pca']

# compute silhouette coefficient in pca
sh = silhouette_score(X_pca[:, :2], adata.obs['blobs'].values )

assert sh < 0.1
