This document focuses on performance tips for neural network workloads.
On recent GPU generations, such as the Nvidia A100 generation or later, it can be a good idea to perform most computations in `bfloat16` precision. For example, if using Flax, instantiate `Dense` layers using `flax.linen.Dense(..., dtype=jax.numpy.bfloat16)`. Here are some code examples:
- In the Flax LM1B example, `Dense` modules are instantiated with a configurable dtype which defaults to `bfloat16`.
- In MaxText, `DenseGeneral` modules are also instantiated with a configurable dtype that defaults to `bfloat16`.
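The idea behind the `dtype` argument can be sketched in plain JAX: parameters stay in float32 while the layer's matmul runs in bfloat16. This is a hand-written illustration that assumes Flax's default of float32 parameters; `init_dense` and `dense_apply` are hypothetical helpers, not Flax's implementation.

```python
import jax
import jax.numpy as jnp


def init_dense(key, in_features, out_features):
    """Hypothetical helper: store parameters in float32 (like Flax's default param_dtype)."""
    kernel = jax.random.normal(key, (in_features, out_features), jnp.float32) * 0.02
    bias = jnp.zeros((out_features,), jnp.float32)
    return {"kernel": kernel, "bias": bias}


def dense_apply(params, x, dtype=jnp.bfloat16):
    """Hypothetical helper: cast inputs and parameters to `dtype`, then compute in it."""
    x = x.astype(dtype)
    kernel = params["kernel"].astype(dtype)
    bias = params["bias"].astype(dtype)
    return x @ kernel + bias


params = init_dense(jax.random.PRNGKey(0), 512, 256)
x = jnp.ones((8, 512), jnp.float32)
y = jax.jit(dense_apply)(params, x)
print(y.dtype)  # bfloat16
```

Keeping the master copy of the weights in float32 while computing in bfloat16 is the usual mixed-precision trade-off: the matmuls run at bfloat16 speed, and optimizer updates are still applied at full precision.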
JAX-Toolbox also has a page on [NVIDIA XLA performance FLAGS](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/GPU_performance.md).
The existence and exact behavior of XLA flags may be `jaxlib`-version dependent.

As of `jaxlib==0.4.18` (released Oct 6, 2023), setting these XLA flags can improve performance. Some are related to communication between GPUs, and so are only relevant when running computations on multiple devices, while others are related to code generation on each device.

Some of these may be set by default in future releases.
These flags can be set via the `XLA_FLAGS` shell environment variable. For example, we can add this to the top of a Python file:
```python
import os

# Set this before importing jax so the flags are in place before XLA reads them.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=true '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)
```
For more examples, see also XLA Flags recommended for Pax training on Nvidia GPUs.
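Assigning to `XLA_FLAGS` directly replaces any flags already exported in the shell. A minimal sketch of a merging alternative follows; `add_xla_flags` is a hypothetical helper, and it assumes flags are whitespace-separated `--name=value` tokens. Like the assignment above, it should run before JAX initializes its backends (before `import jax` is the safest place).

```python
import os


def add_xla_flags(flags, env=os.environ):
    """Hypothetical helper: merge `flags` into XLA_FLAGS, keeping unrelated existing flags.

    Assumes whitespace-separated `--name=value` tokens. Run it before importing jax.
    """
    existing = env.get("XLA_FLAGS", "").split()
    # Index existing flags by name so that new values override old ones in place.
    merged = {token.split("=", 1)[0]: token for token in existing}
    for name, value in flags.items():
        if isinstance(value, bool):
            value = "true" if value else "false"
        merged[name] = f"{name}={value}"
    env["XLA_FLAGS"] = " ".join(merged.values())
    return env["XLA_FLAGS"]


add_xla_flags({
    "--xla_gpu_enable_triton_softmax_fusion": True,
    "--xla_gpu_triton_gemm_any": True,
    "--xla_gpu_enable_async_collectives": True,
    "--xla_gpu_enable_latency_hiding_scheduler": True,
    "--xla_gpu_enable_highest_priority_async_stream": True,
})
```

Because flags are keyed by name, a value passed to the helper overrides the same flag from the shell, while unrelated shell flags are kept.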
- `--xla_gpu_enable_triton_softmax_fusion`: This flag enables an automatic softmax fusion, based on pattern matching backed by Triton code generation. The default value is False.
- `--xla_gpu_triton_gemm_any`: Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False.
- `--xla_gpu_enable_async_collectives`: This flag enables collective ops such as `AllReduce`, `AllGather`, `ReduceScatter` and `CollectivePermute` to be asynchronous. Asynchronous communication can overlap cross-core communication with computation. The default value is False.
- `--xla_gpu_enable_latency_hiding_scheduler`: This flag enables the latency hiding scheduler, which overlaps asynchronous communication with computation efficiently. The default value is False.
- `--xla_gpu_enable_pipelined_collectives`: When using pipeline parallelism, this flag enables overlapping the (i+1)-th layer weight `AllGather` with the i-th layer computation. It also enables overlapping the (i+1)-th layer weight `Reduce`/`ReduceScatter` with the i-th layer's computation. The default value is False. Note that there are known bugs when this flag is turned on.
- `--xla_gpu_collective_permute_decomposer_threshold`: This flag is useful when performing GSPMD pipelining. Setting a nonzero threshold decomposes `CollectivePermute`s into `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that computation can be performed between each corresponding `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the threshold is 0 and there is no decomposition. Setting it to a threshold > 0, such as `--xla_gpu_collective_permute_decomposer_threshold=1024`, enables this feature.
- `--xla_gpu_all_gather_combine_threshold_bytes`, `--xla_gpu_reduce_scatter_combine_threshold_bytes`, `--xla_gpu_all_reduce_combine_threshold_bytes`: These flags tune when to combine multiple small `AllGather`/`ReduceScatter`/`AllReduce` ops into one big `AllGather`/`ReduceScatter`/`AllReduce` to reduce time spent on cross-device communication. For example, for the `AllGather`/`ReduceScatter` thresholds on a Transformer-based workload, consider tuning them high enough to combine at least one Transformer layer's weight `AllGather`/`ReduceScatter`. By default, `combine_threshold_bytes` is set to 256.
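As a rough illustration of sizing the `AllGather`/`ReduceScatter` thresholds, the arithmetic below estimates the weight bytes of one Transformer layer. The layer dimensions, the 2-byte (bfloat16) weight size and the helper `transformer_layer_weight_bytes` are hypothetical; biases and normalization parameters are ignored, and the sketch assumes the threshold should cover the full, unsharded layer weights. Treat the result as a starting point for tuning.

```python
def transformer_layer_weight_bytes(d_model, d_ff, bytes_per_param=2):
    """Hypothetical helper: approximate weight size of one Transformer layer.

    Attention: Q, K, V and output projections, 4 * d_model**2 parameters.
    MLP: up and down projections, 2 * d_model * d_ff parameters.
    Biases and layer-norm parameters are ignored.
    """
    params = 4 * d_model * d_model + 2 * d_model * d_ff
    return params * bytes_per_param


# Hypothetical layer sizes; bfloat16 weights take 2 bytes per parameter.
threshold = transformer_layer_weight_bytes(d_model=4096, d_ff=16384)
flags = " ".join([
    f"--xla_gpu_all_gather_combine_threshold_bytes={threshold}",
    f"--xla_gpu_reduce_scatter_combine_threshold_bytes={threshold}",
])
print(threshold)
```

For these sizes the estimate is 402,653,184 bytes (384 MiB); the resulting `flags` string can be appended to `XLA_FLAGS`.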
These Nvidia NCCL flag values may be useful for single-host multi-device computations on Nvidia GPUs:

```python
import os

os.environ.update({
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
})
```
These NCCL flags could improve single-host communication speed. These flags don't seem useful for multi-host communication yet.
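A variant of the snippet above that only fills in values not already exported in the shell, so per-run overrides still take effect. This is a sketch; the name `SINGLE_HOST_NCCL_DEFAULTS` is hypothetical, and it assumes the code runs before NCCL is initialized (before the first multi-GPU computation; placing it before `import jax` is the safest choice).

```python
import os

# Hypothetical name for the single-host NCCL settings shown above.
SINGLE_HOST_NCCL_DEFAULTS = {
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
}

# setdefault leaves any value already exported in the shell untouched.
for name, value in SINGLE_HOST_NCCL_DEFAULTS.items():
    os.environ.setdefault(name, value)
```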
We recommend using one process per GPU and not one per node. In some cases, this can speed up jitted computation. The {func}`jax.distributed.initialize` API will automatically understand that configuration when run under SLURM. However, this is only a rule of thumb, and it may be useful to test both one process per GPU and one process per node on your use case.
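A minimal sketch of the one-process-per-GPU setup under SLURM. The `srun` command line in the comment and the `SLURM_PROCID` guard are assumptions for illustration; outside SLURM the script simply runs as a single process. With one process per GPU, each process is expected to report a single local device.

```python
import os

import jax

# Hypothetical launch with one SLURM task per GPU on each node, e.g.:
#   srun --nodes=2 --ntasks-per-node=8 --gpus-per-node=8 python train.py
# Calling initialize() with no arguments lets JAX auto-detect the coordinator
# address, the number of processes and this process's id from SLURM.
# It must be called before any other JAX computation.
if "SLURM_PROCID" in os.environ:  # assumption: only initialize when launched by srun
    jax.distributed.initialize()

print(f"process {jax.process_index()} of {jax.process_count()}: "
      f"{jax.local_device_count()} local device(s), {jax.device_count()} total")
```

To compare against one process per node, the same script can be launched with `--ntasks-per-node=1`, in which case each process should instead see all of the node's GPUs.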