From f3b4f43c2036df1fe885127154eb75573d95bc2c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 15 Oct 2020 21:58:27 -0700 Subject: [PATCH] temporarily work around a bug that #4008 will fix --- jax/custom_derivatives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 809a5541850c..11cdf1bcee6b 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -608,6 +608,7 @@ def batched_fwd_jaxpr_thunk(): fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd, out_trees=out_trees) out_dims = out_dims2[0] if out_dims2 else out_dims1 + out_dims = out_dims[:len(batched_outs)] # TODO(mattjj): remove after #4008 return batched_outs, out_dims batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap