diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 63f238d93b92..24fd2a481b59 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -398,7 +398,7 @@ def xmap(fun: Callable, Note that the contraction in the program is performed over the positional axes, while named axes are just a convenient way to achieve batching. While this might seem like a silly example at first, it might turn out to be useful in - practice, since with conjuction with ``axis_resources`` this makes it possible + practice, since with conjunction with ``axis_resources`` this makes it possible to implement a distributed matrix-multiplication in just a few lines of code:: devices = np.array(jax.devices())[:4].reshape((2, 2))