Commit 9ca748b

added tests and restructured

Marius1311 committed Dec 18, 2018
1 parent 5126de5 commit 9ca748b

Showing 3 changed files with 238 additions and 61 deletions.
14 changes: 13 additions & 1 deletion docs/references.rst
@@ -79,6 +79,10 @@ References
*ForceAtlas2, a Continuous Graph Layout Algorithm for Handy Network Visualization Designed for the Gephi Software*,
`PLOS One <https://doi.org/10.1371/journal.pone.0098679>`__.
.. [Johnson07] Johnson, Li & Rabinovic (2007),
*Adjusting batch effects in microarray expression data using empirical Bayes methods*,
`Biostatistics <https://doi.org/10.1093/biostatistics/kxj037>`__.
.. [Kang18] Kang *et al.* (2018),
*Python Implementation of MNN correct*,
`GitHub <https://github.com/chriscainx/mnnpy>`__.
@@ -94,6 +98,10 @@ References
.. [Lambiotte09] Lambiotte *et al.* (2009),
*Laplacian Dynamics and Multiscale Modular Structure in Networks*,
`arXiv <https://arxiv.org/abs/0812.1770>`__.
.. [Leek12] Leek *et al.* (2012),
*sva: Surrogate Variable Analysis. R package*,
`Bioconductor <https://doi.org/10.18129/B9.bioc.sva>`__.
.. [Levine15] Levine *et al.* (2015),
*Data-Driven Phenotypic Dissection of AML Reveals Progenitor-like Cells that Correlate with Prognosis*,
@@ -139,7 +147,11 @@ References
.. [Park18] Park *et al.* (2018),
*Fast Batch Alignment of Single Cell Transcriptomes Unifies Multiple Mouse Cell Atlases into an Integrated Landscape*,
`bioRxiv <https://doi.org/10.1101/397042>`__.
.. [Pedersen12] Pedersen (2012),
*Python implementation of ComBat*,
`GitHub <https://github.com/brentp/combat.py>`__.
.. [Pedregosa11] Pedregosa *et al.* (2011),
*Scikit-learn: Machine Learning in Python*,
`JMLR <http://www.jmlr.org/papers/v12/pedregosa11a.html>`__.
240 changes: 180 additions & 60 deletions scanpy/preprocessing/combat.py
@@ -4,99 +4,179 @@
import sys
from numpy import linalg as la
import patsy

def design_mat(model, batch_levels):
"""
Computes a simple design matrix.
At the moment, only includes the categorical annotations passed with the 'key' argument
to the combat function
Parameters
--------
model : pd.DataFrame
Contains the batch annotation
batch_levels : list
Levels of the batch annotation
Returns
--------
design : pd.DataFrame
The design matrix for the regression problem
"""

design = patsy.dmatrix("~ 0 + C(batch, levels={})".format(batch_levels),
model, return_type="dataframe")
model = model.drop(["batch"], axis=1)
sys.stderr.write("found %i batches\n" % design.shape[1])
    other_cols = list(model.columns)
    factor_matrix = model[other_cols]
design = pd.concat((design, factor_matrix), axis=1)
sys.stderr.write("found %i categorical variables:" % len(other_cols))
sys.stderr.write("\t" + ", ".join(other_cols) + '\n')

return design
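
# For illustration (not part of the module): with a model frame whose 'batch'
# column is ['a', 'a', 'b'] and batch_levels ['a', 'b'], the design matrix
# returned above is a one-hot encoding with one column per batch (column
# names abbreviated here), plus any remaining categorical columns of the model:
#
#        batch_a  batch_b
#     0      1.0      0.0
#     1      1.0      0.0
#     2      0.0      1.0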

def stand_data(model, data):
"""
Standardizes the data per gene.
The aim here is to make mean and variance be comparable across batches.
Parameters
----------
adata : AnnData object
key: str
key to a categorical annotation from adata.obs that will be used for batch effect removal
copy: bool
wether to update the adata object or to copy it
--------
model : pd.DataFrame
Contains the batch annotation
data : pd.DataFrame
Contains the Data
Returns
-------
Depending on the value of copy, either returns an updated AnnData object or modifies the passed one
--------
s_data : pd.DataFrame
Standardized Data
design : pd.DataFrame
Batch assignment as one-hot encodings
var_pooled : np.array
Pooled variance per gene
stand_mean : np.array
Gene-wise mean
"""

# compute the design matrix
batch_items = model.groupby("batch").groups.items()
batch_levels = [k for k, v in batch_items]
batch_info = [v for k, v in batch_items]
n_batch = len(batch_info)
n_batches = np.array([len(v) for v in batch_info])
n_array = float(sum(n_batches))

    # drop intercept
    drop_cols = [cname for cname, inter in ((model == 1).all()).iteritems() if inter]
    model = model[[c for c in model.columns if c not in drop_cols]]
design = design_mat(model, batch_levels)

# compute pooled variance estimator
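    # B_hat is the OLS solution of regressing the data onto the design matrix,
    # B_hat = (D^T D)^{-1} D^T Y^T; its first n_batch rows hold the per-batch
    # coefficient estimates (the batch means when no further covariates are
    # present), and grand_mean is their weighted average across batches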
B_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
grand_mean = np.dot((n_batches / n_array).T, B_hat[:n_batch,:])
var_pooled = (data - np.dot(design, B_hat).T)**2
var_pooled = np.dot(var_pooled, np.ones((int(n_array), 1)) / int(n_array))

    if np.sum(var_pooled == 0) > 0:
        print('Found {} genes with zero variance.'.format(np.sum(var_pooled == 0)))

    # compute the means
stand_mean = np.dot(grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, int(n_array))))
tmp = np.array(design.copy())
tmp[:, :n_batch] = 0
stand_mean += np.dot(tmp, B_hat).T

    # set the zero-variance genes to zero in the standardized data
    s_data = np.where(var_pooled == 0, 0,
((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, int(n_array))))))
s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)


return s_data, design, var_pooled, stand_mean


def combat(adata, key='batch', inplace=True):
    """
    ComBat function for batch effect correction [Johnson07]_ [Leek12]_.

    Corrects for batch effects by fitting linear models, and gains statistical
    power via an EB framework in which information is borrowed across genes.
    Expects normalised, log-transformed data. This uses the implementation of
    `ComBat <https://github.com/brentp/combat.py>`__ [Pedersen12]_.

    Parameters
    ----------
    adata : :class:`~anndata.AnnData`
        Annotated data matrix
    key : `str`, optional (default: `"batch"`)
        Key to a categorical annotation from adata.obs that will be used for
        batch effect removal
    inplace : `bool`, optional (default: `True`)
        Whether to replace adata.X with the corrected data or to return it

    Returns
    -------
    If `inplace=True`, updates adata.X with the batch-corrected data;
    otherwise returns the corrected data matrix.
    """

# check the input
if key not in adata.obs.keys():
raise ValueError('Could not find the key \'{}\' in adata.obs'.format(key))

# only works on dense matrices so far
if issparse(adata.X):
X = adata.X.A.T
else:
X = adata.X.T
data = pd.DataFrame(data=X, index=adata.var_names,
columns=adata.obs_names)

# construct a pandas series of the batch annotation
batch = pd.Series(adata.obs[key])
model = pd.DataFrame({'batch': batch})
batch_items = model.groupby("batch").groups.items()
batch_info = [v for k, v in batch_items]
n_batch = len(batch_info)
n_batches = np.array([len(v) for v in batch_info])
n_array = float(sum(n_batches))

# standardize across genes using a pooled variance estimator
sys.stderr.write("Standardizing Data across genes.\n")
s_data, design, var_pooled, stand_mean = stand_data(model, data)

# fitting the parameters on the standardized data
sys.stderr.write("Fitting L/S model and finding priors\n")
batch_design = design[design.columns[:n_batch]]
# first estimate of the additive batch effect
gamma_hat = np.dot(np.dot(la.inv(np.dot(batch_design.T, batch_design)), batch_design.T), s_data.T)
delta_hat = []

# first estimate for the multiplicative batch effect
for i, batch_idxs in enumerate(batch_info):
delta_hat.append(s_data[batch_idxs].var(axis=1))

# empirically fix the prior hyperparameters
gamma_bar = gamma_hat.mean(axis=1)
t2 = gamma_hat.var(axis=1)
    # a_prior and b_prior are the priors on lambda and theta from Johnson et al. (2007)
a_prior = list(map(aprior, delta_hat))
b_prior = list(map(bprior, delta_hat))

sys.stderr.write("Finding parametric adjustments\n")
# gamma star and delta star will be our empirical bayes (EB) estimators
# for the additive and multiplicative batch effect per batch and cell
gamma_star, delta_star = [], []
for i, batch_idxs in enumerate(batch_info):
# temp stores our estimates for the batch effect parameters.
# temp[0] is the additive batch effect
# temp[1] is the multiplicative batch effect
temp = _it_sol(s_data[batch_idxs], gamma_hat[i],
delta_hat[i], gamma_bar[i], t2[i], a_prior[i], b_prior[i])

gamma_star.append(temp[0])
delta_star.append(temp[1])
@@ -106,33 +186,69 @@ def combat(adata, key = 'batch'):
gamma_star = np.array(gamma_star)
delta_star = np.array(delta_star)


# we now apply the parametric adjustment to the standardized data from above
# loop over all batches in the data
for j, batch_idxs in enumerate(batch_info):


        # we basically subtract the additive batch effect, rescale by the ratio
        # of the multiplicative batch effect to the pooled variance, and add the
        # overall gene-wise mean
dsq = np.sqrt(delta_star[j,:])
dsq = dsq.reshape((len(dsq), 1))
denom = np.dot(dsq, np.ones((1, n_batches[j])))
numer = np.array(bayesdata[batch_idxs] - np.dot(batch_design.loc[batch_idxs], gamma_star).T)

bayesdata[batch_idxs] = numer / denom

vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean

# put back into the adata object or return
if inplace:
adata.X = bayesdata.values.transpose()
else:
return bayesdata.values.transpose()
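
# Example usage (an illustrative sketch, not part of the module; assumes an
# AnnData object whose .obs contains a categorical 'batch' column):
#
#     import scanpy.api as sc
#     adata = sc.read('my_dataset.h5ad')  # hypothetical file
#     sc.pp.log1p(adata)                  # ComBat expects log-transformed data
#     combat(adata, key='batch')          # corrects adata.X in place
#     corrected = combat(adata, key='batch', inplace=False)  # or return a matrix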



def _it_sol(s_data, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001):
"""
Iteratively compute the conditional posterior means for gamma and delta.
gamma is an estimator for the additive batch effect, deltat is an estimator
for the multiplicative batch effect. We use an EB framework to estimate these
two. Analytical expressions exist for both parameters, which however depend on each other.
We therefore iteratively evalutate these two expressions until convergence is reached.
Parameters
--------
s_data : pd.DataFrame
Contains the standardized Data
g_hat : float
Initial guess for gamma
d_hat : float
Initial guess for delta
g_bar, t_2, a, b : float
Hyperparameters
conv: float, optional (default: `0.0001`)
convergence criterium
Returns:
--------
adjust: tuple
contains estimated values for gamma and delta
"""

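    # n counts the non-NaN cells per gene and enters the posterior updates as
    # the effective number of observations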
n = (1 - np.isnan(s_data)).sum(axis=1)
g_old = g_hat.copy()
d_old = d_hat.copy()

change = 1
count = 0

    # we place a normally distributed prior on gamma and an inverse-gamma prior
    # on delta; in the loop below, gamma and delta are updated in turn until
    # convergence, since their posterior means depend on each other
while change > conv:
g_new = postmean(g_hat, g_bar, n, d_old, t2)
sum2 = ((s_data - np.dot(g_new.values.reshape((g_new.shape[0], 1)), np.ones((1, s_data.shape[1])))) ** 2).sum(axis=1)
d_new = postvar(sum2, n, a, b)

change = max((abs(g_new - g_old) / g_old).max(), (abs(d_new - d_old) / d_old).max())
Expand All @@ -142,20 +258,24 @@ def it_sol(sdat, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001):
adjust = (g_new, d_new)
return adjust



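# aprior and bprior are method-of-moments estimators for the hyperparameters
# of the inverse-gamma prior on the squared multiplicative batch effect:
# an inverse-gamma with shape a and scale b has mean b/(a-1) and variance
# b**2/((a-1)**2 (a-2)); matching these to the empirical mean m and variance
# s2 of delta_hat yields a = (2*s2 + m**2)/s2 and b = (m*s2 + m**3)/s2,
# exactly the expressions returned below.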
def aprior(delta_hat):
    m = delta_hat.mean()
    s2 = delta_hat.var()
    return (2*s2 + m**2) / s2

def bprior(delta_hat):
    m = delta_hat.mean()
    s2 = delta_hat.var()
    return (m*s2 + m**3) / s2


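# postmean is the conditional posterior mean of gamma: a precision-weighted
# average of the batch estimate g_hat (weight n/d_star) and the prior mean
# g_bar (weight 1/t2), written here with numerator and denominator multiplied
# by t2*d_star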
def postmean(g_hat, g_bar, n, d_star, t2):
    return (t2*n*g_hat + d_star*g_bar) / (t2*n + d_star)


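# postvar is the conditional posterior mean of delta**2: with an
# inverse-gamma(a, b) prior and residual sum of squares sum2 over n cells,
# the posterior is inverse-gamma(a + n/2, b + sum2/2), whose mean is the
# expression below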
def postvar(sum2, n, a, b):
    return (0.5*sum2 + b) / (n/2.0 + a - 1.0)

