Skip to content

Commit

Permalink
updated types and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Oct 21, 2022
1 parent d45174d commit 1be8614
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 64 deletions.
24 changes: 13 additions & 11 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
# serve to show the default.

import os
import sys
import shutil
import sys

# -- Path setup --------------------------------------------------------------
__location__ = os.path.dirname(__file__)
Expand All @@ -20,7 +20,7 @@
sys.path.insert(0, os.path.join(__location__, "../"))

# If your documentation needs a minimal Sphinx version, state it here.
needs_sphinx = '5.0'
needs_sphinx = "5.0"

# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
Expand All @@ -42,7 +42,7 @@
]

python_apigen_modules = {
"susiepca": "api",
"susiepca": "api",
}

python_apigen_rst_prolog = """
Expand Down Expand Up @@ -110,8 +110,10 @@
]

# type options
python_transform_type_annotations_pep585 = True # simplify typing names (e.g., typing.List -> list)
python_transform_type_annotations_pep604 = True # simplify Union and Optional types
python_transform_type_annotations_pep585 = (
True # simplify typing names (e.g., typing.List -> list)
)
python_transform_type_annotations_pep604 = True # simplify Union and Optional types


# If this is True, todo emits a warning for each TODO entries. The default is False.
Expand All @@ -126,7 +128,7 @@

# The name for this set of Sphinx documents. If None, it defaults to
html_static_path = ["_static"]
#html_css_files = ["extra_css.css"]
# html_css_files = ["extra_css.css"]
html_last_updated_fmt = ""
html_title = "SuSiE-PCA"
html_favicon = "_static/images/favicon.ico"
Expand All @@ -136,12 +138,12 @@
"icon": {
"repo": "fontawesome/brands/github",
},
"site_url": "https://mancusolab.github.io/susiepca/",
"site_url": "https://mancusolab.github.io/susiepca/",
"repo_url": "https://github.com/mancusolab/susiepca/",
"repo_name": "susiepca",
"repo_type": "github",
"edit_uri": "blob/main/docs",
"globaltoc_collapse": True,
"globaltoc_collapse": True,
"features": [
"navigation.expand",
# "navigation.tabs",
Expand All @@ -156,7 +158,7 @@
"toc.follow",
"toc.sticky",
],
"palette": [
"palette": [
{
"media": "(prefers-color-scheme: light)",
"scheme": "default",
Expand Down Expand Up @@ -195,10 +197,10 @@
"icon": "fontawesome/brands/github",
"link": "https://github.com/mancusolab/susiepca",
},
#{
# {
# "icon": "fontawesome/brands/python",
# "link": "https://pypi.org/project/sphinx-immaterial/",
#},
# },
],
# END: social icons
}
Expand Down
113 changes: 60 additions & 53 deletions susiepca/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from jax import jit, lax, nn, random
from sklearn.decomposition import PCA

# TODO: append internal functions to have '_'


__all__ = [
"ModelParams",
"ELBOResults",
Expand Down Expand Up @@ -56,21 +59,6 @@ def __str__(self):
)


class _FactorLoopResults(NamedTuple):
X: jnp.ndarray
W: jnp.ndarray
EZZ: jnp.ndarray
params: ModelParams


class _EffectLoopResults(NamedTuple):
E_zzk: jnp.ndarray
RtZk: jnp.ndarray
Wk: jnp.ndarray
k: int
params: ModelParams


class SuSiEPCAResults(NamedTuple):
params: ModelParams
elbo: ELBOResults
Expand All @@ -79,28 +67,30 @@ class SuSiEPCAResults(NamedTuple):
W: jnp.ndarray


# Create a function to initiate parameters in SuSiE PCA
def init_params(
rng_key, X, z_dim: int, l_dim: int, init: _init_type = "pca", tau: int = 10
rng_key: random.PRNGKey,
X: jnp.ndarray,
z_dim: int,
l_dim: int,
init: _init_type = "pca",
tau: float = 10.0,
) -> ModelParams:

"""
Initialize parameters for SuSiE PCA
Args:
rng_key: A random key return by function jax.random.PRNGkey()
rng_key: Random number generator seed
X: Input data. Should be a array-like
z_dim: Latent factor dimension (int; K)
l_dim: Number of single-effects comprising each factor (int; L)
z_dim: Latent factor dimension (K)
l_dim: Number of single-effects comprising each factor ( L)
init: How to initialize the variational mean parameters for latent factors.
Either "pca" or "random" (default = "pca")
tau: the initial value of tau
tau: initial value of residual precision
Returns:
ModelParams: Initial value of model parameters
ModelParams: initialized set of model parameters
"""

tau = tau

tau_0 = jnp.ones((l_dim, z_dim))

n_dim, p_dim = X.shape
Expand Down Expand Up @@ -160,7 +150,9 @@ def compute_W_moment(params):


# Update posterior mean and variance W
def update_w(RtZk, E_zzk, params, kdx, ldx):
def update_w(
RtZk: jnp.ndarray, E_zzk: jnp.ndarray, params: ModelParams, kdx: int, ldx: int
) -> ModelParams:
# n_dim, z_dim = params.mu_z.shape

# calculate update_var_w as the new V[w | gamma]
Expand Down Expand Up @@ -337,33 +329,19 @@ def compute_pve(params):
return pve


@jit
def _inner_loop(X, params):
n_dim, z_dim = params.mu_z.shape
l_dim, _, _ = params.mu_w.shape

# compute expected residuals
# use posterior mean of Z, W, and Alpha to calculate residuals
W = jnp.sum(params.mu_w * params.alpha, axis=0)
E_ZZ = params.mu_z.T @ params.mu_z + n_dim * params.var_z

# update effect precision via MLE
params = update_tau0_mle(params)

# update locals (W, alpha)
init_loop_param = _FactorLoopResults(X, W, E_ZZ, params)
_, W, _, params = lax.fori_loop(0, z_dim, _factor_loop, init_loop_param)

# update factor parameters
params = update_z(X, params)

# update precision parameters via MLE
params = update_tau(X, params)
class _FactorLoopResults(NamedTuple):
X: jnp.ndarray
W: jnp.ndarray
EZZ: jnp.ndarray
params: ModelParams

# compute elbo
elbo_res = compute_elbo(X, params)

return W, elbo_res, params
class _EffectLoopResults(NamedTuple):
E_zzk: jnp.ndarray
RtZk: jnp.ndarray
Wk: jnp.ndarray
k: int
params: ModelParams


def _factor_loop(kdx: int, loop_params: _FactorLoopResults) -> _FactorLoopResults:
Expand Down Expand Up @@ -408,7 +386,35 @@ def _effect_loop(ldx: int, effect_params: _EffectLoopResults) -> _EffectLoopResu
return effect_params._replace(Wk=Wk, params=params)


# The main inference function for SuSiE PCA
@jit
def _inner_loop(X: jnp.ndarray, params: ModelParams):
n_dim, z_dim = params.mu_z.shape
l_dim, _, _ = params.mu_w.shape

# compute expected residuals
# use posterior mean of Z, W, and Alpha to calculate residuals
W = jnp.sum(params.mu_w * params.alpha, axis=0)
E_ZZ = params.mu_z.T @ params.mu_z + n_dim * params.var_z

# update effect precision via MLE
params = update_tau0_mle(params)

# update locals (W, alpha)
init_loop_param = _FactorLoopResults(X, W, E_ZZ, params)
_, W, _, params = lax.fori_loop(0, z_dim, _factor_loop, init_loop_param)

# update factor parameters
params = update_z(X, params)

# update precision parameters via MLE
params = update_tau(X, params)

# compute elbo
elbo_res = compute_elbo(X, params)

return W, elbo_res, params


def susie_pca(
X: jnp.ndarray,
z_dim: int,
Expand All @@ -420,6 +426,7 @@ def susie_pca(
verbose: bool = True,
) -> SuSiEPCAResults:
"""
The main inference function for SuSiE PCA.
Args:
X: Input data. Should be a array-like
Expand Down Expand Up @@ -484,7 +491,7 @@ def susie_pca(
# type check for init
if init not in type_options:
raise ValueError(
f"Unknown initialization provided {init}; Choice: {type_options}"
f'Unknown initialization provided "{init}"; Choice: {type_options}'
)

# initialize PRNGkey and params
Expand Down

0 comments on commit 1be8614

Please sign in to comment.