Skip to content

Commit

Permalink
update susiepca simulation example
Browse files Browse the repository at this point in the history
  • Loading branch information
Dong555 committed Nov 11, 2022
1 parent c074469 commit 7f5c1d7
Showing 1 changed file with 17 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main(args):
will expect to take ~10 seconds to converge on CPU, and ~1.5 seconds on GPU.
"""
argp = ap.ArgumentParser(description="")
argp.add_argument("--n-sim", default=100, type=int, help="number of simulations")
argp.add_argument("--n-sim", default=1, type=int, help="number of simulations")
argp.add_argument("--n-dim", default=1000, type=int, help="Number of samples")
argp.add_argument("--p-dim", default=6000, type=int, help="Number of features")
argp.add_argument("--z-dim", default=4, type=int, help="Number of latent factors")
Expand All @@ -42,11 +42,17 @@ def main(args):
type=str,
default="/home1/dongyuan/SuSiEPCA/simulation_results/test.csv",
)
argp.add_argument(
"--platform",
type=str,
default="cpu",
help="Choose the platform to run inference",
)

args = argp.parse_args(args)

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "gpu")
config.update("jax_platform_name", args.platform)

for sim in range(args.n_sim):
# simulate data
Expand Down Expand Up @@ -75,7 +81,7 @@ def main(args):
end_susie = time()
run_susie = end_susie - start_susie

# calculate procruste error
# calculate procruste error for W
W_hat = results.W
proc_trans_susie = procrustes.orthogonal(
np.asarray(W_hat.T), np.asarray(W.T), scale=True
Expand All @@ -86,6 +92,13 @@ def main(args):
X_hat = results.params.mu_z @ W_hat
rrmse_susie = sp.metrics.mse(X, X_hat)

# calculate procruste error for Z
Z_hat = results.params.mu_z
proc_trans_susie_z = procrustes.orthogonal(
np.asarray(Z_hat), np.asarray(Z), scale=True
)
error_z_susie = proc_trans_susie_z.error

# summarize results
summary = [
sim,
Expand All @@ -95,6 +108,7 @@ def main(args):
args.l_dim,
error_susie,
rrmse_susie,
error_z_susie,
run_susie,
]

Expand Down

0 comments on commit 7f5c1d7

Please sign in to comment.