Skip to content

Latest commit

 

History

History
45 lines (32 loc) · 1.04 KB

index.rst

File metadata and controls

45 lines (32 loc) · 1.04 KB

Pallas: a JAX kernel language

Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. It aims to provide fine-grained control over the generated code, combined with the high-level ergonomics of JAX tracing and the jax.numpy API.

This section contains tutorials, guides and examples for using Pallas. See also the :class:`jax.experimental.pallas` module API documentation.

Warning

Pallas is experimental and is changing frequently. See the :ref:`pallas-changelog` for the recent changes.

You can expect to encounter errors and unimplemented cases, e.g., when lowering of high-level JAX concepts that would require emulation, or simply because Pallas is still under development.

.. toctree::
   :caption: Guides
   :maxdepth: 2

   quickstart
   design
   grid_blockspec


.. toctree::
   :caption: Platform Features
   :maxdepth: 2

   tpu/index

.. toctree::
   :caption: Design Notes
   :maxdepth: 1

   async_note

.. toctree::
   :caption: Other
   :maxdepth: 1

   CHANGELOG