Skip to content

Commit

Permalink
[typing] annotate jax.numpy ufuncs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 20, 2022
1 parent 2d563bf commit 6d30865
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 68 deletions.
2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,7 +2151,7 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize] = None,
if (not dtypes.issubdtype(start_dtype, np.integer) and
not core.is_opaque_dtype(start_dtype)):
ceil_ = ceil if isinstance(start, core.Tracer) else np.ceil
start = ceil_(start).astype(int)
start = ceil_(start).astype(int) # type: ignore
return lax.iota(dtype, start)
else:
start = require(start, msg("start"))
Expand Down
Loading

0 comments on commit 6d30865

Please sign in to comment.