Skip to content

Commit

Permalink
Standardize adam signatures with multidiff
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanPearl committed Jul 10, 2024
1 parent 0558e61 commit c45ca59
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
7 changes: 4 additions & 3 deletions kdescent/descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None,
nsteps : int, optional
Number of gradient descent iterations to perform, by default 100
param_bounds : Sequence, optional
Lower and upper bounds of each parameter, by default None
        Lower and upper bounds of each parameter, as an array-like of shape
        (ndim, 2). Pass `None` as the bound for each unbounded parameter,
        by default None
learning_rate : float, optional
Initial Adam learning rate, by default 0.05
randkey : int, optional
Expand Down Expand Up @@ -68,10 +69,10 @@ def adam_unbounded(lossfunc, guess, nsteps=100,
randkey, key_i = jax.random.split(randkey)
kwargs["randkey"] = key_i
opt = optax.adam(learning_rate)
solver = jaxopt.OptaxSolver(opt=opt, fun=lossfunc, nsteps=nsteps)
solver = jaxopt.OptaxSolver(opt=opt, fun=lossfunc, maxiter=nsteps)
state = solver.init_state(guess, **kwargs)
params = [guess]
for _ in tqdm.trange(nsteps):
for _ in tqdm.trange(nsteps, desc="Adam Gradient Descent Progress"):
if randkey is not None:
randkey, key_i = jax.random.split(randkey)
kwargs["randkey"] = key_i
Expand Down
5 changes: 4 additions & 1 deletion kdescent/kstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def __init__(self, training_x, training_weights=None, num_kernels=20,
Increase or decrease the Fourier search space, by default 10.0
comm : MPI Communicator, optional
For parallel computing, this guarantees consistent kernel
placements within a shared comm, by default None
            placements by all MPI ranks within the comm, by default None.
            WARNING: Do not pass an MPI communicator if you plan on wrapping
            kernel drawing with a JIT-compiled function. In that case, be very
            careful to pass identical randkeys for each MPI rank.
"""
self.training_x = jnp.atleast_2d(jnp.asarray(training_x).T).T
assert self.training_x.ndim == 2, "x must have shape (ndata, ndim)"
Expand Down

0 comments on commit c45ca59

Please sign in to comment.