Skip to content

Commit

Permalink
improve djax dim indexing type checking
Browse files Browse the repository at this point in the history
Co-authored-by: Edward Loper <[email protected]>
  • Loading branch information
mattjj and edloper committed Apr 2, 2021
1 parent d2c53e0 commit d4af8cd
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions jax/experimental/djax.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,16 @@ def typecheck_type(env, aval):
if not (isinstance(d.name.aval, AbsArray) and
isinstance(d.name.aval._eltTy, BoundedIntTy)):
raise TypeError(f'dim var of unexpected type: {d.name.aval}')
if not all(j < i for j in d.indices):
raise TypeError(f'out-of-bounds dim indexing: {aval}')
expected_shape = tuple(aval.shape[j] for j in d.indices)
if d.name.aval.shape != expected_shape:
d_indices_set = set(d.indices)
if i in d_indices_set:
raise TypeError(f"circular dim indexing expression: {d}")
for j in d.indices:
d_j = aval.shape[j]
if (isinstance(d_j, DimIndexingExpr) and
not d_indices_set.issuperset(d_j.indices)):
raise TypeError(f"dim indexing not transitively closed: {d}")
expected_idx_array_shape = tuple(aval.shape[j] for j in d.indices)
if d.name.aval.shape != expected_idx_array_shape:
raise TypeError(f'incompatible shapes in dim indexing: {aval}')
else:
raise TypeError(f'unexpected type in shape: {type(d)}')
Expand Down

0 comments on commit d4af8cd

Please sign in to comment.