Skip to content

Commit

Permalink
api_util: make shaped_abstractify respect raise_to_shaped
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 6, 2022
1 parent 212edd6 commit 5d45458
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ def _dtype(x):

def shaped_abstractify(x):
try:
return core.raise_to_shaped(core.get_aval(x))
return core.raise_to_shaped(
x if isinstance(x, core.AbstractValue) else core.get_aval(x))
except TypeError:
pass

Expand Down
3 changes: 2 additions & 1 deletion tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages
from jax.errors import JAXTypeError
from jax import lax
Expand Down Expand Up @@ -827,7 +828,7 @@ def f(x, y):
return x @ y

shape = (8, 8)
aval = jax.ShapedArray(shape, jnp.int64)
aval = jax.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
x = jnp.arange(np.prod(shape)).reshape(shape)
exe = f.lower(aval, x, _global_avals=True).compile()
self.assertIsInstance(exe, stages.Compiled)
Expand Down

0 comments on commit 5d45458

Please sign in to comment.