Skip to content

Commit

Permalink
enable running on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
Dong555 committed Nov 8, 2022
1 parent bdc257a commit c9dd6e3
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions examples/susie_package_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from csv import writer
from time import time

import numpy as np
#import procrustes
# import numpy as np
# import procrustes
from jax.config import config

import susiepca as sp
Expand All @@ -16,7 +16,7 @@ def main(args):
"""
run the simulation with user-specific setting and
produce summary results in args.output path. Currently each simulation
will expect to take ~10 seconds to converge on CPU, and ~5 seconds on GPU.
will expect to take ~10 seconds to converge on CPU, and ~1.5 seconds on GPU.
"""
argp = ap.ArgumentParser(description="")
argp.add_argument("--n-sim", default=100, type=int, help="number of simulations")
Expand All @@ -40,7 +40,7 @@ def main(args):
args = argp.parse_args(args)

config.update("jax_enable_x64", True)
config.update("jax_platform_name","gpu")
config.update("jax_platform_name", "gpu")

# loading_df = pd.DataFrame()

Expand All @@ -65,10 +65,10 @@ def main(args):

# calculate error
W_hat = results.W
#proc_trans_susie = procrustes.orthogonal(
# proc_trans_susie = procrustes.orthogonal(
# np.asarray(W_hat.T), np.asarray(W.T), scale=True
#)
#error_susie = proc_trans_susie.error
# )
# error_susie = proc_trans_susie.error
# compute the predicted data
X_hat = results.params.mu_z @ W_hat
rrmse_susie = sp.metrics.mse(X, X_hat)
Expand All @@ -79,7 +79,7 @@ def main(args):
args.p_dim,
args.z_dim,
args.l_dim,
#error_susie,
# error_susie,
rrmse_susie,
run_susie,
]
Expand Down

0 comments on commit c9dd6e3

Please sign in to comment.