Skip to content

Commit

Permalink
DOC: add examples to lax function docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 29, 2021
1 parent 23cbcbe commit 71a25cd
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,35 @@ def reshape(operand: Array, new_sizes: Shape,
For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
``lax.expand_dims``. These preserve information about axis identity that may
be useful for advanced transformation rules.
Args:
operand: array to be reshaped.
new_sizes: sequence of integers specifying the resulting shape. The size
of the final array must match the size of the input.
dimensions: optional sequence of integers specifying the permutation order of
the input shape. If specified, the length must match ``operand.shape``.
Returns:
out: reshaped array.
Examples:
Simple reshaping from one to two dimensions:
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
DeviceArray([[0, 1, 2],
[3, 4, 5]], dtype=int32)
Reshaping back to one dimension:
>>> reshape(y, (6,))
DeviceArray([0, 1, 2, 3, 4, 5], dtype=int32)
Reshaping to one dimension with permutation of dimensions:
>>> reshape(y, (6,), (1, 0))
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
"""
new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
Expand Down Expand Up @@ -823,6 +852,27 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
Returns:
An array containing the slice.
Examples:
Here is a simple two-dimensional dynamic slice:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
DeviceArray([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
DeviceArray([[ 5, 6, 7],
[ 9, 10, 11]], dtype=int32)
Note the potentially surprising behavior for the case where the requested slice
overruns the bounds of the array; in this case the start index is adjusted to
return a slice of the requested size:
>>> dynamic_slice(x, (1, 1), (2, 4))
DeviceArray([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
return dynamic_slice_p.bind(operand, *start_indices,
Expand Down

0 comments on commit 71a25cd

Please sign in to comment.