Skip to content

Commit

Permalink
Give jax.numpy.array the type Callable.
Browse files Browse the repository at this point in the history
This is to prevent users from using as the type of arrays in type annotations.

PiperOrigin-RevId: 560754568
  • Loading branch information
hawkinsp authored and jax authors committed Aug 28, 2023
1 parent 3ea141d commit 70206ee
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import numpy as _np
from jax.numpy import fft, linalg
from typing import Any, Callable, Dict, Tuple, Type, Union
from jax._src.typing import Array, ArrayLike
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.numpy.reductions import CumulativeReduction as _CumulativeReduction
from jax._src.numpy.ufunc_api import ufunc as ufunc
Expand Down Expand Up @@ -33,7 +35,9 @@ argpartition: Any
argsort: Any
argwhere: Any
around: Any
array: Any
array: Callable
# def array(object: Any, dtype: DTypeLike | None = ..., copy: bool = True,
# order: str | None = ..., ndmin: int = ...) -> Array: ...
array_equal: Any
array_equiv: Any
array_repr: Any
Expand Down

0 comments on commit 70206ee

Please sign in to comment.