Skip to content

Commit

Permalink
Use int32 counters in optix. (jax-ml#2239)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomhennigan authored Feb 23, 2020
1 parent 89514f9 commit cd5bcb3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions jax/experimental/optix.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
def init_fn(params):
mu = tree_multimap(jnp.zeros_like, params) # First moment
nu = tree_multimap(jnp.zeros_like, params) # Second moment
return ScaleByAdamState(count=jnp.zeros([]), mu=mu, nu=nu)
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def update_fn(updates, state):
mu = _update_moment(updates, state.mu, b1, 1)
Expand Down Expand Up @@ -278,7 +278,7 @@ def scale_by_schedule(step_size_fn):
"""

def init_fn(_):
return ScaleByScheduleState(count=jnp.zeros([]))
return ScaleByScheduleState(count=jnp.zeros([], jnp.int32))

def update_fn(updates, state):
updates = tree_multimap(lambda g: step_size_fn(state.count) * g, updates)
Expand Down Expand Up @@ -306,7 +306,8 @@ def add_noise(eta, gamma, seed):
"""

def init_fn(_):
return AddNoiseState(count=jnp.zeros([]), rng_key=jrandom.PRNGKey(seed))
return AddNoiseState(count=jnp.zeros([], jnp.int32),
rng_key=jrandom.PRNGKey(seed))

def update_fn(updates, state): # pylint: disable=missing-docstring
num_vars = len(tree_leaves(updates))
Expand Down

0 comments on commit cd5bcb3

Please sign in to comment.