Skip to content

Commit

Permalink
update mypy & related package versions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 15, 2022
1 parent 375777f commit 7972b98
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ repos:
- id: flake8

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.931'
rev: 'v0.942'
hooks:
- id: mypy
files: jax/
additional_dependencies: [types-requests==0.1.11, jaxlib==0.1.74]
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]

- repo: https://github.com/mwouts/jupytext
rev: v1.13.6
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _one_hot(x: Array, num_classes: int, *,
f"but {num_classes} != {axis_size}") from None
axis_idx = lax.axis_index(axis)
return jnp.asarray(x == axis_idx, dtype=dtype)
axis = operator.index(axis)
axis = operator.index(axis) # type: ignore[arg-type]
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)
Expand Down
2 changes: 1 addition & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def sharding_to_proto(sharding: SpatialSharding):
proto.type = xc.OpSharding.Type.REPLICATED
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_dimensions = list(sharding) # type: ignore
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
return proto

Expand Down

0 comments on commit 7972b98

Please sign in to comment.