Skip to content

Commit

Permalink
Improve error message for reverse-mode of while loop
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 28, 2023
1 parent ae9160a commit 7a2fc0e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,8 +1471,8 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):

def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")
"lax.while_loop or lax.fori_loop with dynamic start/stop values. "
"Try using lax.scan, or using fori_loop with static start/stop.")

# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
Expand Down

0 comments on commit 7a2fc0e

Please sign in to comment.