Skip to content

Commit

Permalink
initial refactor to include annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Jun 21, 2023
1 parent a2c9497 commit 63ffb24
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ formats = bdist_wheel

[flake8]
# Some sane defaults for the code style checker flake8
max_line_length = 88
max_line_length = 120
extend_ignore = E203, W503
# ^ Black-compatible
# E203 and W503 have edge cases handled by black
Expand Down
65 changes: 36 additions & 29 deletions susiepca/infer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import get_args, Literal, NamedTuple, Optional, Tuple
from typing import Literal, NamedTuple, Optional, Tuple, get_args

import optax
from jax import Array, jit, lax, nn, random, numpy as jnp, grad

from jax import Array, grad, jit, lax, nn, numpy as jnp, random
from jax.scipy import special as spec
from jax.typing import ArrayLike


__all__ = [
"compute_elbo",
"susie_pca",
Expand Down Expand Up @@ -353,6 +353,38 @@ def _annotation_inner_loop(X: jnp.ndarray, A: jnp.ndarray, params: ModelParams):
return W, elbo_res, params


def _reorder_factors_by_pve(
A: ArrayLike, params: ModelParams, pve: ArrayLike
) -> Tuple[ModelParams, Array]:
sorted_indices = jnp.argsort(pve)[::-1]
pve = pve[sorted_indices]
sorted_mu_z = params.mu_z[:, sorted_indices]
sorted_var_z = params.var_z[sorted_indices, sorted_indices]
sorted_mu_w = params.mu_w[:, sorted_indices, :]
sorted_var_w = params.var_w[:, sorted_indices]
sorted_alpha = params.alpha[:, sorted_indices, :]
sorted_tau_0 = params.tau_0[:, sorted_indices]
if A is not None:
sorted_theta = params.theta[:, sorted_indices]
sorted_pi = _compute_pi(A, sorted_theta)
else:
sorted_pi = params.pi

params = ModelParams(
sorted_mu_z,
sorted_var_z,
sorted_mu_w,
sorted_var_w,
sorted_alpha,
params.tau,
sorted_tau_0,
sorted_theta,
sorted_pi,
)

return params, pve


def _init_params(
rng_key: random.PRNGKey,
X: ArrayLike,
Expand Down Expand Up @@ -575,32 +607,7 @@ def susie_pca(

# compute PVE and reorder in descending value
pve = compute_pve(params)
sorted_indices = jnp.argsort(pve)[::-1]

pve = pve[sorted_indices]
sorted_mu_z = params.mu_z[:, sorted_indices]
sorted_var_z = params.var_z[sorted_indices, sorted_indices]
sorted_mu_w = params.mu_w[:, sorted_indices, :]
sorted_var_w = params.var_w[:, sorted_indices]
sorted_alpha = params.alpha[:, sorted_indices, :]
sorted_tau_0 = params.tau_0[:, sorted_indices]
if A is not None:
sorted_theta = params.theta[:, sorted_indices]
sorted_pi = _compute_pi(A, sorted_theta)
else:
sorted_pi = params.pi

params = ModelParams(
sorted_mu_z,
sorted_var_z,
sorted_mu_w,
sorted_var_w,
sorted_alpha,
tau,
sorted_tau_0,
sorted_theta,
sorted_pi,
)
params, pve = _reorder_factors_by_pve(A, params, pve)

# compute PIPs
pip = compute_pip(params)
Expand Down

0 comments on commit 63ffb24

Please sign in to comment.