Skip to content

Commit

Permalink
Use lower-case PEP 585 names for types.
Browse files Browse the repository at this point in the history
Issue jax-ml#16537

PiperOrigin-RevId: 542969282
  • Loading branch information
hawkinsp authored and jax authors committed Jun 23, 2023
1 parent f67acee commit 816ba91
Show file tree
Hide file tree
Showing 148 changed files with 1,492 additions and 1,526 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/cat_slurm_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import argparse
import os
from typing import List

ISSUE_FORMAT = """\
<details><summary>Failure summary {name}</summary>
Expand All @@ -27,7 +26,7 @@
</details>
"""

def main(logfiles: List[str], outfile: str):
def main(logfiles: list[str], outfile: str):
print(f"extracting content of {logfiles}")
print(f"and writing to {outfile}")
with open(outfile, 'w') as f:
Expand Down
6 changes: 3 additions & 3 deletions docs/jax-101/05.1-pytrees.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,9 @@
}
],
"source": [
"from typing import Tuple, Iterable\n",
"from typing import Iterable\n",
"\n",
"def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:\n",
"def flatten_MyContainer(container) -> tuple[Iterable[int], str]:\n",
" \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
" flat_contents = [container.a, container.b, container.c]\n",
"\n",
Expand Down Expand Up @@ -593,7 +593,7 @@
"class MyKeyPathContainer(MyContainer):\n",
" pass\n",
"\n",
"def flatten_with_keys_MyKeyPathContainer(container) -> Tuple[Iterable[int], str]:\n",
"def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:\n",
" \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
" \n",
" # GetAttrKey is a common way to express an attribute key. Users are free\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/jax-101/05.1-pytrees.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,9 @@ except TypeError as e:
To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:

```{code-cell} ipython3
from typing import Tuple, Iterable
from typing import Iterable
def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
def flatten_MyContainer(container) -> tuple[Iterable[int], str]:
"""Returns an iterable over container contents, and aux data."""
flat_contents = [container.a, container.b, container.c]
Expand Down Expand Up @@ -312,7 +312,7 @@ Alternatively, using the key path API mentioned above, you can register this con
class MyKeyPathContainer(MyContainer):
pass
def flatten_with_keys_MyKeyPathContainer(container) -> Tuple[Iterable[int], str]:
def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:
"""Returns an iterable over container contents, and aux data."""
# GetAttrKey is a common way to express an attribute key. Users are free
Expand Down
4 changes: 2 additions & 2 deletions docs/jax-101/06-parallelism.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@
},
"outputs": [],
"source": [
"from typing import NamedTuple, Tuple\n",
"from typing import NamedTuple\n",
"import functools\n",
"\n",
"class Params(NamedTuple):\n",
Expand Down Expand Up @@ -571,7 +571,7 @@
"# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it\n",
"# 'num_devices', but could have used anything, so long as `pmean` used the same.\n",
"@functools.partial(jax.pmap, axis_name='num_devices')\n",
"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:\n",
"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:\n",
" \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n",
"\n",
" # Compute the gradients on the given minibatch (individually on each device).\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/jax-101/06-parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ If this example is too confusing, you can find the same example, but without par
```{code-cell} ipython3
:id: cI8xQqzRrc-4
from typing import NamedTuple, Tuple
from typing import NamedTuple
import functools
class Params(NamedTuple):
Expand Down Expand Up @@ -249,7 +249,7 @@ LEARNING_RATE = 0.005
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:
"""Performs one SGD update step on params using the given data."""
# Compute the gradients on the given minibatch (individually on each device).
Expand Down
4 changes: 1 addition & 3 deletions docs/jax-101/07-state.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,11 @@
}
],
"source": [
"from typing import Tuple\n",
"\n",
"CounterState = int\n",
"\n",
"class CounterV2:\n",
"\n",
" def count(self, n: CounterState) -> Tuple[int, CounterState]:\n",
" def count(self, n: CounterState) -> tuple[int, CounterState]:\n",
" # You could just return n+1, but here we separate its role as \n",
" # the output and as the counter state for didactic purposes.\n",
" return n+1, n+1\n",
Expand Down
4 changes: 1 addition & 3 deletions docs/jax-101/07-state.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,11 @@ Part of the problem with our counter was that the returned value didn't depend o
:id: 53pSdK4KoOEZ
:outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79
from typing import Tuple
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> Tuple[int, CounterState]:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
Expand Down
6 changes: 2 additions & 4 deletions docs/notebooks/autodiff_remat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1141,10 +1141,8 @@
},
"source": [
"```python\n",
"from typing import Tuple, List\n",
"\n",
"LayerParam = Tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer\n",
"ParamsList = List[LayerParam]\n",
"LayerParam = tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer\n",
"ParamsList = list[LayerParam]\n",
"\n",
"def net(params: ParamsList, x: jnp.ndarray):\n",
" for W, b in params:\n",
Expand Down
6 changes: 2 additions & 4 deletions docs/notebooks/autodiff_remat.md
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,8 @@ For example, one common pattern in large [Transformer models](https://en.wikiped
+++ {"id": "BUeqKFRS5yPU"}

```python
from typing import Tuple, List

LayerParam = Tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer
ParamsList = List[LayerParam]
LayerParam = tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
for W, b in params:
Expand Down
5 changes: 2 additions & 3 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

from functools import partial
from typing import Set

import numpy as np

Expand All @@ -40,7 +39,7 @@ def zeros_like_array(x):
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
return ad_util.zeros_like_aval(aval)

numpy_scalar_types: Set[type] = { # pylint: disable=g-bare-generic
numpy_scalar_types: set[type] = { # pylint: disable=g-bare-generic
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.complex64, np.complex128,
Expand All @@ -52,7 +51,7 @@ def zeros_like_array(x):
if dtypes.uint4 is not None:
numpy_scalar_types.add(dtypes.uint4)

array_types: Set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic

def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
Expand Down
25 changes: 12 additions & 13 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
import functools
from functools import partial
import logging
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
Union)
from typing import Any, Callable, Optional, Sequence, Union
import types

import numpy as np
Expand Down Expand Up @@ -134,7 +133,7 @@ def policy(prim, *args, **params):
@api_boundary
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
policy: Optional[Callable[..., bool]] = None,
static_argnums: Union[int, Tuple[int, ...]] = (),
static_argnums: Union[int, tuple[int, ...]] = (),
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
Expand Down Expand Up @@ -348,14 +347,14 @@ def __eq__(self, other):
# See api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums.
# On that benchmark, including this caching makes a ~10x difference (which can
# be made arbitrary large by involving larger functions to be traced).
def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
static_args: Tuple[WrapHashably, ...], nargs: int):
def _dyn_args_fun(fun: Callable, static_argnums: frozenset[int],
static_args: tuple[WrapHashably, ...], nargs: int):
if any(isinstance(x.val, core.Tracer) for x in static_args):
return _dyn_args_fun_uncached(fun, static_argnums, static_args, nargs)
return _dyn_args_fun_cached(fun, static_argnums, static_args, nargs)

def _dyn_args_fun_uncached(fun: Callable, static_argnums: FrozenSet[int],
static_args: Tuple[WrapHashably, ...], nargs: int):
def _dyn_args_fun_uncached(fun: Callable, static_argnums: frozenset[int],
static_args: tuple[WrapHashably, ...], nargs: int):
def new_fun(*dyn_args, **kwargs):
static_args_, dyn_args_ = iter(static_args), iter(dyn_args)
full_args = [next(static_args_).val if i in static_argnums
Expand Down Expand Up @@ -391,7 +390,7 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):

### Utilities

def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
in_leaves, in_tree = tree_flatten((args, kwargs))

def f_(*args):
Expand All @@ -409,7 +408,7 @@ def f_(*args):
arg_info = pe.arg_info_all(dbg)
return _saved_residuals(jaxpr, arg_info)

def _saved_residuals(jaxpr, arg_info) -> List[Tuple[core.AbstractValue, str]]:
def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

Expand Down Expand Up @@ -579,7 +578,7 @@ def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: Union[bool, Sequence[bool]],
out_zeros: Union[bool, Sequence[bool]],
reduce_axes: Sequence[core.AxisName],
) -> Tuple[core.ClosedJaxpr, List[bool]]:
) -> tuple[core.ClosedJaxpr, list[bool]]:
if type(in_linear) is bool:
in_linear = (in_linear,) * len(jaxpr.in_avals)
if type(out_zeros) is bool:
Expand Down Expand Up @@ -640,8 +639,8 @@ def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap

# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], Optional[core.JaxprEqn]]:
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
new_params = dict(eqn.params, jaxpr=new_jaxpr)
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
Expand Down Expand Up @@ -781,7 +780,7 @@ def checkpoint_wrapper(
*,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
static_argnums: Union[int, tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable:
if concrete:
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/ad_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

import types
from typing import Any, Callable, Dict, TypeVar, Union, cast
from typing import Any, Callable, TypeVar, Union, cast

from jax._src import core
from jax._src import traceback_util
Expand All @@ -30,7 +30,7 @@

map = safe_map

jaxval_adders: Dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}
jaxval_adders: dict[type, Callable[[ArrayLike, ArrayLike], Array]] = {}

def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
return add_jaxvals_p.bind(x, y)
Expand All @@ -46,7 +46,7 @@ def add_impl(xs, ys):
def add_abstract(xs, ys):
return lattice_join(xs, ys)

jaxval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
jaxval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

def instantiate(z: Union[Zero, Array]) -> Array:
if type(z) is Zero:
Expand All @@ -56,7 +56,7 @@ def instantiate(z: Union[Zero, Array]) -> Array:
def zeros_like_aval(aval: core.AbstractValue) -> Array:
return aval_zeros_likers[type(aval)](aval)

aval_zeros_likers: Dict[type, Callable[[Any], Array]] = {}
aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

def zeros_like_jaxval(val: ArrayLike) -> Array:
return zeros_like_p.bind(val)
Expand Down
Loading

0 comments on commit 816ba91

Please sign in to comment.