Skip to content

Commit

Permalink
[host_callback] Update the documentation
Browse files Browse the repository at this point in the history
The module-level documentation was out of date.
  • Loading branch information
gnecula committed Sep 23, 2020
1 parent 80fa22c commit 625be69
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,28 @@
This module introduces the host callback functions :func:`id_tap` and
:func:`id_print`, which behave like the identity function but have the
side-effect of sending the arguments from the accelerator to the host and
side-effect of sending the arguments from the device to the host and
invoking a user-specified Python function (for :func:`id_tap`) or printing the
arguments on the host (for :func:`id_print`). A few examples::
arguments on the host (for :func:`id_print`). The Python function passed
to :func:`id_tap` takes two positional arguments (the value tapped from the
device computation along with ``transforms`` sequence, described below).
A few examples::
# call func(2x) on host and return 2x
# calls func(2x, []) on host and returns 2x
y = id_tap(func, 2 * x)
# call func((2x, 3x)) and return (2x, 3x)
# calls func((2x, 3x), []) and returns (2x, 3x)
y, z = id_tap(func, (2 * x, 3 * x)) # The argument can be a pytree
# call func(2x) and return y
y = id_tap(func, 2 * x, result=y)
# call func(2x, what='activation') and return 2x
y = id_tap(func, 2 * x, what='activation')
# call func(dict(x=x, y=y), what='data') and return dict(x=x, y=y)
x, y = id_tap(func, dict(x=x, y=y), what='data')
# calls func(2x, []) and returns y
y = id_tap(func, 2 * x, result=y) # override the result of id_tap
# calls func(2x, [], what='activation') and returns 2x
y = id_tap(functools.partial(func, what='activation'), 2 * x)
# calls func(dict(x=x, y=y), what='data') and returns dict(x=x, y=y)
x, y = id_tap(lambda tap, transforms: func(tap, what='data'), dict(x=x, y=y))
The above examples can all be adapted to use :func:`id_print` instead, with
the difference that :func:`id_print` takes one positional argument (to print
on the host), the optional kwarg ``result``, and possibly additional kwargs
that are also printed along with the automatic kwarg ``transforms``.
The order of execution of the tap functions is constrained by data dependency:
the arguments are sent after all the arguments are computed and before the
Expand Down Expand Up @@ -78,45 +86,43 @@
def power3(x):
y = x * x
_, y = id_print(x, y, what="x,x^2")
_, y = id_print((x, y), what="x,x^2") # Must pack multiple arguments
return y * x
power3(3.)
# what: x,x^2 : [3., 9.]
During JAX transformations the special parameter ``transforms`` is added to
contain a list of transformation descriptors. Each descriptor is a dictionary
containing the key ``name`` holding the name of the transformation and
additional keys holding transformation parameters, if applicable. This
parameter is passed to the tap function (or printed), in addition to
user-defined parameters.
contain a list of transformation descriptors in the form
``(transform_name, transform_params)``.
For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended
with transformation name ``batch`` and ``batch_dims`` set to the the tuple of
batched dimensions (one entry per argument, ``None`` denotes an argument that
was broadcast)::
jax.vmap(power3)(np.arange(3.))
# what=x,x^2 transforms=({name=batch, batch_dims=(0, 0)}): ([0, 1, 2], [0, 1,
4])
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1,
4]]
For :func:`jax.jvp` there will be two callbacks, one with the values of
the primals and one with the tangents::
jax.jvp(power3, (3.,), (0.1,))
# what=x,x^2: (3., 9.)
# what=x,x^2 transforms={name=jvp}: (0.1, 0.6)
# what: x,x^2: [3., 9.]
# transforms: ['jvp'] what: x,x^2 : [0.1, 0.6]
For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the
values of the adjoints for the arguments. You may also see a callback with
the values of the primals from the forward pass, if those values are needed for
the backward pass::
jax.grad(power3)(3.)
# what=x,x^2: (3., 9.) # from forward pass, since y is needed in backward
pass
# what=x,x^2 transforms=({name=jvp}, {name=transpose}): (0., 3.) # from
backward pass, adjoints of _, y
# what=x,x^2: [3., 9.] # from forward pass, since y is used in backward pass
# transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.] # from backward pass, adjoints of _, y
See documentation for :func:`id_tap` and :func:`id_print`.
For usage example, see tests/host_callback_test.py.
For more usage example, see tests/host_callback_test.py.
Still to do:
* Performance tests.
Expand Down Expand Up @@ -187,20 +193,18 @@ def id_tap(tap_func, arg, *, result=None, **kwargs):
``id_tap`` behaves semantically like the identity function but has the
side-effect that a user-defined Python function is called with the runtime
values of the argument.
value of the argument.
Args:
tap_func: tap function to call like ``tap_func(arg, transforms)``, with
``arg`` as described below and where ``transforms`` is sequence of applied
JAX transformations in the form ``(name, params)``.
``arg`` as described below and where ``transforms`` is the sequence of
applied JAX transformations in the form ``(name, params)``.
arg: the argument passed to the tap function, can be a pytree of JAX
types.
result: if given, specifies the return value of ``id_tap``. This value is
not passed to the tap function, and in fact is not sent from the device to
the host. If the ``result`` parameter is not specified then the return
value of ``id_tap`` is ``arg``.
**kwargs: Deprecated option for passing additional keyword arguments to
``tap_func``.
Returns:
``arg``, or ``result`` if given.
Expand Down

0 comments on commit 625be69

Please sign in to comment.