switched to "reference" instead of "baseline"
johannesostner committed Dec 4, 2020
1 parent 7e32dfb commit d4981e4
Showing 8 changed files with 146 additions and 632 deletions.
2 changes: 1 addition & 1 deletion docs/models.rst
@@ -4,7 +4,7 @@ Models
scCODA uses Bayesian modeling to detect changes in compositional data.
The model is implemented in ``sccoda.model.dirichlet_models``.
The easiest way to create a compositional model is to call ``sccoda.util.comp_ana.CompositionalAnalysis``, which returns a compositional model.
scCODA automatically selects the correct model based on whether a baseline cell type was specified.
scCODA automatically selects the correct model based on whether a reference cell type was specified.
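
A minimal usage sketch of the auto-selection described above (the data loading call, `count_df`, and the cell type name are hypothetical; the `reference_cell_type` keyword reflects the renaming in this commit):

```python
from sccoda.util import cell_composition_data as dat
from sccoda.util import comp_ana

# Hypothetical count data frame with one covariate column.
data = dat.from_pandas(count_df, covariate_columns=["condition"])

# Passing a reference cell type selects the reference model;
# passing None selects the model without a reference.
model = comp_ana.CompositionalAnalysis(
    data, formula="condition", reference_cell_type="CD4 T cells")
results = model.sample_hmc()
```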

Model structure
~~~~~~~~~~~~~~~
104 changes: 56 additions & 48 deletions sccoda/model/dirichlet_models.py
@@ -1,8 +1,8 @@
"""
Dirichlet-multinomial models for statistical analysis of compositional changes
Dirichlet-multinomial models for statistical analysis of compositional changes in single-cell data.
For further reference, see:
Johannes Ostner: Development of a statistical framework for compositional analysis of single-cell data
Büttner et al.: scCODA: A Bayesian model for compositional single-cell data analysis
:authors: Johannes Ostner
"""
@@ -19,14 +19,14 @@
tfb = tfp.bijectors


class CompositionalModel:
class CompositionalModel():
"""
Implements class framework for compositional data models
Dynamical framework for formulation and inference of Bayesian models for compositional data analysis.
"""

def __init__(self, covariate_matrix, data_matrix, cell_types, covariate_names, formula, *args, **kwargs):
"""
Generalized Constructor of model class
Generalized Constructor of Bayesian compositional model class.
Parameters
----------
@@ -43,7 +43,7 @@ def __init__(self, covariate_matrix, data_matrix, cell_types, covariate_names, f
dtype = tf.float64
self.x = tf.convert_to_tensor(covariate_matrix, dtype)

# Add pseudocount if needed.
# Add pseudocount if zeroes are present.
if np.count_nonzero(data_matrix) != np.size(data_matrix):
print("Zero counts encountered in data! Added a pseudocount of 0.5.")
data_matrix += 0.5
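
A small self-contained sketch of the zero-handling above:

```python
import numpy as np

# If any entry of the count matrix is zero, add 0.5 everywhere so the
# downstream Dirichlet-multinomial likelihood stays well-defined.
data_matrix = np.array([[10., 0., 5.], [3., 2., 7.]])
if np.count_nonzero(data_matrix) != np.size(data_matrix):
    print("Zero counts encountered in data! Added a pseudocount of 0.5.")
    data_matrix += 0.5
```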
@@ -67,7 +67,7 @@ def __init__(self, covariate_matrix, data_matrix, cell_types, covariate_names, f

def sampling(self, num_results, num_burnin, kernel, init_state, trace_fn):
"""
MCMC sampling process
MCMC sampling process (tensorflow 2)
Parameters
----------
@@ -112,7 +112,8 @@ def sample_mcmc(num_results_, num_burnin_, kernel_, current_state_, trace_fn):

def get_chains_after_burnin(self, samples, kernel_results, num_burnin, is_nuts=False):
"""
Application of burnin after sampling
Application of burn-in after MCMC sampling.
Cuts the first `num_burnin` samples from all inferred variables and diagnostic statistics.
Parameters
----------
@@ -121,15 +121,16 @@ def get_chains_after_burnin(self, samples, kernel_results, num_burnin, is_nuts=F
kernel_results -- list
Kernel meta-information
num_burnin -- int
number of burnin iterations
number of burn-in iterations
Returns
-------
states_burnin -- list
Kernel states without burnin samples
Kernel states without burn-in samples
p_accept -- float
acceptance rate of MCMC process
"""

# Samples after burn-in
states_burnin = []
stats = {}
@@ -151,9 +153,11 @@ def get_chains_after_burnin(self, samples, kernel_results, num_burnin, is_nuts=F

return states_burnin, stats, p_accept
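
Conceptually, the burn-in step is just slicing the leading draws off every chain. A sketch, assuming the draw dimension comes first in each state (the example chains are hypothetical):

```python
import numpy as np

num_burnin = 5000
# `samples`: one chain per inferred variable, draws on axis 0 (assumption).
samples = [np.random.randn(20000, 3), np.random.randn(20000, 5)]
states_burnin = [chain[num_burnin:] for chain in samples]
```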

def sample_hmc(self, num_results=int(20e3), num_burnin=int(5e3), num_leapfrog_steps=10, step_size=0.01, num_adapt_steps=None):
def sample_hmc(self, num_results=int(20e3), num_burnin=int(5e3), num_adapt_steps=None,
num_leapfrog_steps=10, step_size=0.01):

"""
HMC sampling
HMC sampling in tensorflow 2.
Parameters
----------
@@ -165,6 +169,8 @@ def sample_hmc(self, num_results=int(20e3), num_burnin=int(5e3), num_leapfrog_st
HMC leapfrog steps (default 10)
step_size -- float
Initial step size (default 0.01)
num_adapt_steps -- int
Length of step size adaptation procedure
Returns
-------
@@ -191,7 +197,7 @@ def sample_hmc(self, num_results=int(20e3), num_burnin=int(5e3), num_leapfrog_st
hmc_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
inner_kernel=hmc_kernel, num_adaptation_steps=num_adapt_steps, target_accept_prob=0.8)

# tracing function
# diagnostics tracing function
def trace_fn(_, pkr):
return {
'target_log_prob': pkr.inner_results.inner_results.accepted_results.target_log_prob,
@@ -200,21 +206,22 @@ def trace_fn(_, pkr):
'step_size': pkr.inner_results.inner_results.accepted_results.step_size,
}
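
A sketch of the kernel stack the trace function above implies: the double `inner_results` access suggests a transforming kernel (e.g. `tfp.mcmc.TransformedTransitionKernel`) sits between the HMC kernel and the step size adaptation wrapper. That middle layer, `target_log_prob_fn`, and `constraining_bijectors` are assumptions here, not shown in this hunk:

```python
import tensorflow_probability as tfp

hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,   # model joint log-prob (assumed name)
    step_size=0.01,
    num_leapfrog_steps=10)
hmc_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=hmc_kernel,
    bijector=constraining_bijectors)          # assumed per-parameter bijectors
hmc_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=hmc_kernel,
    num_adaptation_steps=num_adapt_steps,     # as in the hunk above
    target_accept_prob=0.8)
```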

# HMC sampling
# The actual HMC sampling process
states, kernel_results, duration = self.sampling(num_results, num_burnin, hmc_kernel, self.params, trace_fn)

# apply burnin
states_burnin, sample_stats, acc_rate = self.get_chains_after_burnin(states, kernel_results, num_burnin)
# apply burn-in
states_burnin, sample_stats, acc_rate = self.get_chains_after_burnin(states, kernel_results, num_burnin,
is_nuts=False)

# Calculate posterior predictive
y_hat = self.get_y_hat(states_burnin, num_results, num_burnin)

params = dict(zip(self.param_names, states_burnin))

# Result object generation setup
# Get names of cell types that are not the baseline
if self.baseline_index is not None:
cell_types_nb = self.cell_types[:self.baseline_index] + self.cell_types[self.baseline_index+1:]
# Get names of cell types that are not the reference
if self.reference_cell_type is not None:
cell_types_nb = self.cell_types[:self.reference_cell_type] + self.cell_types[self.reference_cell_type+1:]
else:
cell_types_nb = self.cell_types
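
The slicing above simply drops the reference column, which assumes an integer index; a toy example with a hypothetical cell type list:

```python
cell_types = ["B", "CD4 T", "CD8 T", "NK"]
ref = 2  # integer index of the reference cell type
cell_types_nb = cell_types[:ref] + cell_types[ref + 1:]
# -> ["B", "CD4 T", "NK"]
```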

@@ -247,7 +254,7 @@ def trace_fn(_, pkr):
sampling_stats = {"chain_length": num_results, "num_burnin": num_burnin,
"acc_rate": acc_rate, "duration": duration, "y_hat": y_hat}

model_specs = {"baseline": self.baseline_index, "formula": self.formula}
model_specs = {"reference": self.reference_cell_type, "formula": self.formula}

return res.CAResultConverter(posterior=posterior,
posterior_predictive=posterior_predictive,
@@ -309,15 +316,16 @@ def trace_fn(_, pkr):

# HMC sampling
states, kernel_results, duration = self.sampling(num_results, num_burnin, hmc_kernel, self.params, trace_fn)
states_burnin, sample_stats, acc_rate = self.get_chains_after_burnin(states, kernel_results, num_burnin)
states_burnin, sample_stats, acc_rate = self.get_chains_after_burnin(states, kernel_results, num_burnin,
is_nuts=False)

y_hat = self.get_y_hat(states_burnin, num_results, num_burnin)

params = dict(zip(self.param_names, states_burnin))

# Result object generation setup
if self.baseline_index is not None:
cell_types_nb = self.cell_types[:self.baseline_index] + self.cell_types[self.baseline_index+1:]
if self.reference_cell_type is not None:
cell_types_nb = self.cell_types[:self.reference_cell_type] + self.cell_types[self.reference_cell_type+1:]
else:
cell_types_nb = self.cell_types

@@ -347,11 +355,11 @@ def trace_fn(_, pkr):
"sample": range(self.y.shape[0])
}

# build dicionary with sampling statistics
# build dictionary with sampling statistics
sampling_stats = {"chain_length": num_results, "num_burnin": num_burnin,
"acc_rate": acc_rate, "duration": duration, "y_hat": y_hat}

model_specs = {"baseline": self.baseline_index, "formula": self.formula}
model_specs = {"reference": self.reference_cell_type, "formula": self.formula}

return res.CAResultConverter(posterior=posterior,
posterior_predictive=posterior_predictive,
@@ -424,16 +432,17 @@ def trace_fn(_, pkr):

# HMC sampling
states, kernel_results, duration = self.sampling(num_results, num_burnin, nuts_kernel, self.params, trace_fn)
states_burnin, sample_stats, acc_rate = self.get_chains_after_burnin(states, kernel_results, num_burnin, is_nuts=True)
states_burnin, sample_stats, acc_rate = self.get_chains_after_burnin(states, kernel_results, num_burnin,
is_nuts=True)
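
For the `is_nuts=True` path, a sketch of the corresponding kernel; the construction itself lies outside this hunk, so the call below is an assumption:

```python
import tensorflow_probability as tfp

nuts_kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target_log_prob_fn,  # assumed joint log-prob
    step_size=0.01)
```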

y_hat = self.get_y_hat(states_burnin, num_results, num_burnin)

params = dict(zip(self.param_names, states_burnin))

# Result object generation setup
# Get names of cell types that are not the baseline
if self.baseline_index is not None:
cell_types_nb = self.cell_types[:self.baseline_index] + self.cell_types[self.baseline_index + 1:]
# Get names of cell types that are not the reference
if self.reference_cell_type is not None:
cell_types_nb = self.cell_types[:self.reference_cell_type] + self.cell_types[self.reference_cell_type + 1:]
else:
cell_types_nb = self.cell_types

@@ -466,7 +475,7 @@ def trace_fn(_, pkr):
sampling_stats = {"chain_length": num_results, "num_burnin": num_burnin,
"acc_rate": acc_rate, "duration": duration, "y_hat": y_hat}

model_specs = {"baseline": self.baseline_index, "formula": self.formula}
model_specs = {"reference": self.reference_cell_type, "formula": self.formula}

return res.CAResultConverter(posterior=posterior,
posterior_predictive=posterior_predictive,
@@ -477,10 +486,10 @@ def trace_fn(_, pkr):
model_specs=model_specs)


class NoBaselineModel(CompositionalModel):
class NoReferenceModel(CompositionalModel):

""""
implements statistical model for compositional differential change analysis without specification of a baseline cell type
statistical model for compositional differential change analysis without specification of a reference cell type
"""

def __init__(self, *args, **kwargs):
@@ -494,7 +503,7 @@ def __init__(self, *args, **kwargs):
"""
super(self.__class__, self).__init__(*args, **kwargs)

self.baseline_index = None
self.reference_cell_type = None
dtype = tf.float64

# All parameters that are returned for analysis
@@ -581,11 +590,11 @@ def get_y_hat(self, states_burnin, num_results, num_burnin):
Parameters
----------
states_burnin -- List
MCMC chain without burnin samples
MCMC chain without burn-in samples
num_results -- int
Chain length (with burnin)
Chain length (with burn-in)
num_burnin -- int
Number of burnin samples
Number of burn-in samples
Returns
-------
@@ -629,26 +638,26 @@ def get_y_hat(self, states_burnin, num_results, num_burnin):
return y_mean


class BaselineModel(CompositionalModel):
class ReferenceModel(CompositionalModel):
"""
implements statistical model for compositional differential change analysis with specification of a baseline cell type
implements statistical model for compositional differential change analysis with specification of a reference cell type
"""

def __init__(self, baseline_index, *args, **kwargs):
def __init__(self, reference_cell_type, *args, **kwargs):

"""
Constructor of model class
Parameters
----------
baseline_index -- string or int
reference_cell_type -- string or int
Index of reference cell type (column in count data matrix)
args -- arguments passed to top-level class
kwargs -- arguments passed to top-level class
"""
super(self.__class__, self).__init__(*args, **kwargs)

self.baseline_index = baseline_index
self.reference_cell_type = reference_cell_type
dtype = tf.float64

# All parameters that are returned for analysis
@@ -674,7 +683,7 @@ def define_model(x, n_total, K):
# normal prior on bias
alpha = ed.Normal(loc=tf.zeros([K], dtype=dtype), scale=tf.ones([K], dtype=dtype) * 5, name="alpha")

# Noncentered parametrization for raw slopes of all cell types except baseline type (before spike-and-slab)
# Noncentered parametrization for raw slopes of all cell types except reference type (before spike-and-slab)
mu_b = ed.Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype), name="mu_b")
sigma_b = ed.HalfCauchy(tf.zeros(1, dtype=dtype), tf.ones(1, dtype=dtype), name="sigma_b")
b_offset = ed.Normal(loc=tf.zeros([D, K-1], dtype=dtype), scale=tf.ones([D, K-1], dtype=dtype),
@@ -693,10 +702,10 @@ def define_model(x, n_total, K):
# Calculate betas
beta = ind * b_raw

# Include slope 0 for baseline cell type
beta = tf.concat(axis=1, values=[beta[:, :baseline_index],
# Include slope 0 for reference cell type
beta = tf.concat(axis=1, values=[beta[:, :reference_cell_type],
tf.zeros(shape=[D, 1], dtype=dtype),
beta[:, baseline_index:]])
beta[:, reference_cell_type:]])

# Concentration vector from intercepts, slopes
concentration_ = tf.exp(alpha + tf.matmul(x, beta))
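
The concat above pins the reference cell type's slope to zero by inserting a zero column at its index; a numpy sketch with hypothetical sizes:

```python
import numpy as np

D, K, ref = 2, 5, 3                      # covariates, cell types, ref index
beta_nb = np.random.randn(D, K - 1)      # slopes for non-reference types
beta = np.concatenate(
    [beta_nb[:, :ref], np.zeros((D, 1)), beta_nb[:, ref:]], axis=1)
assert beta.shape == (D, K) and np.all(beta[:, ref] == 0)
```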
@@ -772,9 +781,9 @@ def get_y_hat(self, states_burnin, num_results, num_burnin):

beta_ = np.zeros(chain_size_beta)
for i in range(num_results - num_burnin):
beta_[i] = np.concatenate([beta_temp[i, :, :self.baseline_index],
beta_[i] = np.concatenate([beta_temp[i, :, :self.reference_cell_type],
np.zeros(shape=[self.D, 1], dtype=np.float64),
beta_temp[i, :, self.baseline_index:]], axis=1)
beta_temp[i, :, self.reference_cell_type:]], axis=1)

conc_ = np.exp(np.einsum("jk, ...kl->...jl", self.x, beta_)
+ alphas.reshape((num_results - num_burnin, 1, self.K))).astype(np.float64)
@@ -794,4 +803,3 @@ def get_y_hat(self, states_burnin, num_results, num_burnin):
concentration = np.exp(np.matmul(self.x, betas_final) + alphas_final).astype(np.float64)
y_mean = concentration / np.sum(concentration, axis=1, keepdims=True) * self.n_total.numpy()[:, np.newaxis]
return y_mean
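
The last two lines compute the posterior-predictive mean: concentrations are normalized to proportions, then rescaled by each sample's total cell count. A numpy sketch with hypothetical shapes:

```python
import numpy as np

N, D, K = 4, 2, 5                          # samples, covariates, cell types
x = np.random.randn(N, D)
betas_final = np.random.randn(D, K)
alphas_final = np.random.randn(K)
n_total = np.array([100., 120., 90., 110.])

concentration = np.exp(x @ betas_final + alphas_final)
y_mean = (concentration / concentration.sum(axis=1, keepdims=True)
          * n_total[:, np.newaxis])        # expected counts per cell type
```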
