Skip to content

Commit

Permalink
temporarily work around a bug that jax-ml#4008 will fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Oct 16, 2020
1 parent d0ab44d commit f3b4f43
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions jax/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ def batched_fwd_jaxpr_thunk():
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
out_trees=out_trees)
out_dims = out_dims2[0] if out_dims2 else out_dims1
out_dims = out_dims[:len(batched_outs)] # TODO(mattjj): remove after #4008
return batched_outs, out_dims
batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap

Expand Down

0 comments on commit f3b4f43

Please sign in to comment.