Skip to content

Commit

Permalink
#5 rewrite "get_credset_v2" function to remove the if statement
Browse files Browse the repository at this point in the history
  • Loading branch information
Dong555 committed Nov 7, 2022
1 parent c389259 commit 33badcf
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions susiepca/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax.numpy as jnp

# from jax import jit

__all__ = [
"mse",
"get_credset",
Expand Down Expand Up @@ -54,3 +56,29 @@ def get_credset(params, rho=0.9):
cs["z" + str(zdx)] = cs_s

return cs


def get_credset_v2(params, rho=0.9):
l_dim, z_dim, p_dim = params.alpha.shape
idxs = jnp.argsort(-params.alpha, axis=-1)
cs = {}
for zdx in range(z_dim):
cs_s = []
for ldx in range(l_dim):
cs_s.append([])

# idxs for all feature at this zdx and ldx
p_idxs = idxs[ldx, zdx, :]
# compute the cumulative sum
p_sums = jnp.cumsum(params.alpha[ldx, zdx, p_idxs])
# find all the index where the cumsum>rho
p_gts = jnp.where(p_sums >= rho)[0]
# get the minimum value that satisfy the above criterion
min_p_gts = p_gts[0]
# form the cs. note it's possible that min_p_gets is 0
idx = p_idxs[0 : max(min_p_gts, 1)]
cs_s[ldx].append(idx)

cs["z" + str(zdx)] = cs_s

return cs

0 comments on commit 33badcf

Please sign in to comment.