diff --git a/susiepca/infer.py b/susiepca/infer.py index 4bffce6..332d18c 100755 --- a/susiepca/infer.py +++ b/susiepca/infer.py @@ -25,8 +25,24 @@ def logdet(A): return ldet -# Define the class for model parameters class ModelParams(NamedTuple): + """Define the class for variational parameters. This + class include variational parameters of all the variable + we need to infer from the SuSiE PCA + + Args: + mu_z: mean parameter for factor Z + var_z: variance parameter for factor Z + mu_w: conditional mean parameter for loadings W + var_w: conditional variance parameter for loading W + alpha: parameter for the gamma that follows multinomial + distribution + tau: inverse variance parameter of observed data X + tau_0: inverse variance parameter of single effect w_kl + pi: prior probability for gamma + + """ + # variational params for Z mu_z: jnp.ndarray var_z: jnp.ndarray @@ -46,10 +62,11 @@ class ModelParams(NamedTuple): pi: jnp.ndarray -# Define the class of all components in ELBO. +# class ELBOResults(NamedTuple): - """ + """Define the class of all components in ELBO, + which is returned by function ``compute_elbo`` Args: elbo: the value of ELBO @@ -75,6 +92,17 @@ def __str__(self): class SuSiEPCAResults(NamedTuple): + """Define the object returned by function ``susie_pca`` + + Args: + params: the dictionary contain all the infered parameters + elbo: the value of ELBO + pve: the ndarray of percent of variance explained + pip: the ndarray of posterior inclusion probabilities + W: the posterior mean parameter for loadings + + """ + params: ModelParams elbo: ELBOResults pve: jnp.ndarray @@ -301,13 +329,13 @@ def compute_elbo(X, params) -> ELBOResults: return result -# Create a function to compute the posterior inclusion probabilities (PIPs). def compute_pip(params): - """ + """Create a function to compute the posterior inclusion probabilities (PIPs). Args: - params: the dictionary return from the function ``susie_pca``. + params: the dictionary contains all the infered parameters, + returned from the function ``susie_pca``. Returns: pip: the K by P array of posterior inclusion probabilities (PIPs) @@ -323,7 +351,8 @@ def compute_pve(params): """Create a function to compute the percent of variance explained (PVE). Args: - params: the dictionary return from the function susie_pca + params: the dictionary contains all the infered parameters, + returned from the function ``susie_pca`` Returns: pve: the length K array of percent of variance explained by each factor (PVE) diff --git a/susiepca/metrics.py b/susiepca/metrics.py index 9ed8680..2a73d74 100644 --- a/susiepca/metrics.py +++ b/susiepca/metrics.py @@ -7,8 +7,8 @@ def mse(X: jnp.ndarray, Xhat: jnp.ndarray): - """ - + """Create a function to compute relative + roote mean square error. Args: X : Input data. Should be a array-like @@ -23,7 +23,8 @@ def mse(X: jnp.ndarray, Xhat: jnp.ndarray): def get_credset(params, rho=0.9): - """ + """Creat a function to compute the rho-level + credible set Args: params: the dictionary return from the function susie_pca