Add JAX-defined augmentation examples (#5426)
* Add JAX operator examples: introduction and multi-GPU
* Fix typo in error message
---------

Signed-off-by: Kamil Tokarski <[email protected]>
stiepan committed Apr 16, 2024
1 parent d771a4f commit 2bfb162
Showing 8 changed files with 848 additions and 9 deletions.
@@ -42,7 +42,7 @@ def cpu_to_dlpack(tensor: jax.Array):
     if devices[0].platform != "cpu":
         raise ValueError(
             f"The function returned array residing on the device of "
-            f"kind `{devices[0].platform}`, expected `gpu`."
+            f"kind `{devices[0].platform}`, expected `cpu`."
         )
     return jax.dlpack.to_dlpack(tensor)
 
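The fixed message belongs to the CPU conversion path: the array a user's JAX function returns must already reside on the CPU before it is handed back to DALI over DLPack, and the old text misreported the expected device kind as `gpu`. A minimal standalone sketch of such a check (how `devices` is obtained is an assumption; the hunk above does not show that line):

    import jax
    import jax.dlpack

    def cpu_to_dlpack(tensor: jax.Array):
        # The wrapped JAX function must have produced a CPU-resident array.
        devices = list(tensor.devices())  # assumed; not shown in the hunk
        if devices[0].platform != "cpu":
            raise ValueError(
                f"The function returned array residing on the device of "
                f"kind `{devices[0].platform}`, expected `cpu`."
            )
        return jax.dlpack.to_dlpack(tensor)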
12 changes: 8 additions & 4 deletions dali/test/python/jax_plugin/test_jax_operator.py
@@ -366,7 +366,10 @@ def pipeline():
 
 
 @restrict_python_version(3, 9)
-def test_wrong_device_output():
+@params("cpu", "gpu")
+def test_wrong_device_output(device):
+
+    other_device = "gpu" if device == "cpu" else "cpu"
 
     @jax.vmap
     def flip(image):
@@ -375,14 +378,15 @@ def flip(image):
     @pipeline_def(batch_size=11, device_id=0, num_threads=4, seed=42, enable_conditionals=True)
     def pipeline():
         img, _ = fn.readers.file(name="Reader", file_root=images_dir, random_shuffle=True, seed=42)
-        img = fn.decoders.image(img, device="mixed")
+        img = fn.decoders.image(img, device="cpu" if device == "cpu" else "mixed")
         img = fn.resize(img, size=(224, 224))
-        return dax.fn.jax_function(jax.jit(flip, backend="cpu"))(img)
+        return dax.fn.jax_function(jax.jit(flip, backend=other_device), device=device)(img)
 
     p = pipeline()
     p.build()
     with assert_raises(
-        RuntimeError, glob="*array residing on the device of kind `cpu`, expected `gpu`*"
+        RuntimeError,
+        glob=f"*array residing on the device of kind `{other_device}`, expected `{device}`*",
     ):
         p.run()
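The parametrized test forces a mismatch in both directions: the callable is jitted for the opposite JAX backend from the operator's declared device, so the returned array can never reside where DALI expects it. In user code the failing pattern looks roughly like this (a sketch reusing the test's names; `dax` is assumed to alias `nvidia.dali.plugin.jax`, and `flip` stands in for any per-sample JAX function):

    import jax
    import nvidia.dali.plugin.jax as dax

    # Declared as a GPU operator, but the wrapped callable is pinned to the
    # CPU backend; its output lives on the wrong device kind, and DALI
    # raises a RuntimeError when the pipeline runs.
    bad = dax.fn.jax_function(jax.jit(flip, backend="cpu"), device="gpu")

    # Correct pairing: the jit backend matches the operator's device.
    good = dax.fn.jax_function(jax.jit(flip, backend="gpu"), device="gpu")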
16 changes: 15 additions & 1 deletion docs/examples/custom_operations/index.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -56,5 +56,19 @@
             "Running custom operations written as Numba JIT-compiled functions",
         ),
     ),
+    doc_entry(
+        "jax_operator_basic.ipynb",
+        op_reference(
+            "plugin.jax.fn.jax_function",
+            "Running custom JAX augmentations in DALI",
+        ),
+    ),
+    doc_entry(
+        "jax_operator_multi_gpu.ipynb",
+        op_reference(
+            "plugin.jax.fn.jax_function",
+            "Running JAX augmentations on multiple GPUs",
+        ),
+    ),
 ],
 )
448 changes: 448 additions & 0 deletions docs/examples/custom_operations/jax_operator_basic.ipynb

Large diffs are not rendered by default.

287 changes: 287 additions & 0 deletions docs/examples/custom_operations/jax_operator_multi_gpu.ipynb

Large diffs are not rendered by default.

@@ -0,0 +1,84 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This is the counterpart of the jax_operator_multi_gpu notebook.
# The notebook and this script run as a group: the notebook runs as
# process 0, this script as process 1.


from functools import partial
import os

import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental.shard_map import shard_map

import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator
from nvidia.dali.plugin.jax.fn import jax_function


# This process drives the second GPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

jax.distributed.initialize(
    coordinator_address="localhost:12321",
    num_processes=2,
    process_id=1,
)

assert len(jax.devices()) == 2
assert len(jax.local_devices()) == 1

# 1D mesh over both processes' GPUs with a single "batch" axis
mesh = Mesh(jax.devices(), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

dogs = [f"../data/images/dog/dog_{i}.jpg" for i in range(1, 9)]
kittens = [f"../data/images/kitten/cat_{i}.jpg" for i in range(1, 9)]


@data_iterator(
    output_map=["images"],
    sharding=sharding,
)
def iterator_function(shard_id, num_shards):
    assert num_shards == 2
    jpegs, _ = fn.readers.file(files=dogs if shard_id == 0 else kittens, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(244, 244))

    # mixup images between shards
    images = global_mixup(images)
    return images


# Runs as part of the DALI pipeline: shard_map spans the global "batch"
# axis, so pshuffle below can exchange the two shards' batches.
@jax_function(sharding=sharding)
@jax.jit
@partial(
    shard_map,
    mesh=sharding.mesh,
    in_specs=PartitionSpec("batch"),
    out_specs=PartitionSpec("batch"),
)
@jax.vmap
def global_mixup(sample):
    # Average each sample with its counterpart from the other shard
    mixed_up = 0.5 * sample + 0.5 * jax.lax.pshuffle(sample, "batch", [1, 0])
    mixed_up = jax.numpy.clip(mixed_up, 0, 255)
    return jax.numpy.array(mixed_up, dtype=jax.numpy.uint8)


local_batch_size = 8
num_shards = 2

iterator = iterator_function(batch_size=num_shards * local_batch_size, num_threads=4)
batch = next(iterator)
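Per the header comment, this script is only half of the group: the jax_operator_multi_gpu notebook plays process 0. The settings the notebook side has to mirror, condensed into a sketch (assumed from the values above, not the notebook's actual cells):

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # process 0 owns the other GPU

    import jax

    jax.distributed.initialize(
        coordinator_address="localhost:12321",  # same coordinator as above
        num_processes=2,
        process_id=0,  # the only value that differs from this script
    )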
5 changes: 3 additions & 2 deletions docs/plugins/jax_fn.rst
@@ -1,5 +1,6 @@
 Running JAX in DALI pipeline
 ============================
 
-See the :meth:`plugin.jax.fn.jax_function <nvidia.dali.plugin.jax.fn.jax_function>` to run
-JAX functions inside DALI pipeline or JAX plugin iterator.
+See the :meth:`plugin.jax.fn.jax_function <nvidia.dali.plugin.jax.fn.jax_function>` for
+documentation and examples on running custom JAX augmentations as a part of DALI
+iterator or pipeline.
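Distilled from the test above, the usage pattern the new examples document is roughly the following (a sketch: the body of `flip` and `images_dir` are illustrative stand-ins, and `dax` is assumed to alias `nvidia.dali.plugin.jax`):

    import jax
    import nvidia.dali.fn as fn
    import nvidia.dali.plugin.jax as dax
    from nvidia.dali import pipeline_def

    images_dir = "docs/examples/data/images"  # illustrative path

    @jax.vmap               # jax_function receives whole batches; vmap maps per sample
    def flip(image):
        return image[::-1]  # stand-in body; the test's actual flip is not shown

    @pipeline_def(batch_size=11, device_id=0, num_threads=4)
    def pipeline():
        img, _ = fn.readers.file(file_root=images_dir, random_shuffle=True)
        img = fn.decoders.image(img, device="mixed")
        img = fn.resize(img, size=(224, 224))
        # The jit backend and the operator's device must agree (see the test above)
        return dax.fn.jax_function(jax.jit(flip, backend="gpu"), device="gpu")(img)

    p = pipeline()
    p.build()
    (out,) = p.run()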
3 changes: 2 additions & 1 deletion qa/TL0_jupyter/test_jax.sh
@@ -1,11 +1,12 @@
 #!/bin/bash -e
 # used pip packages
-pip_packages='jupyter numpy jax flax'
+pip_packages='jupyter numpy matplotlib<3.5.3 jax flax'
 target_dir=./docs/examples
 
 test_body() {
     test_files=(
         "frameworks/jax/jax-basic_example.ipynb"
+        "custom_operations/jax_operator_basic.ipynb"
     )
     for f in ${test_files[@]}; do
         jupyter nbconvert --to notebook --inplace --execute \
