jax.tree_util module
====================

.. currentmodule:: jax.tree_util

.. automodule:: jax.tree_util

List of Functions
-----------------

.. autosummary::
   :toctree: _autosummary

   Partial
   all_leaves
   build_tree
   register_dataclass
   register_pytree_node
   register_pytree_node_class
   register_pytree_with_keys
   register_pytree_with_keys_class
   register_static
   tree_flatten_with_path
   tree_leaves_with_path
   tree_map_with_path
   treedef_children
   treedef_is_leaf
   treedef_tuple
   keystr
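
As a rough illustration of how several of the functions listed above fit together, the sketch below registers a custom container with ``register_pytree_node_class`` and then inspects it with ``tree_flatten_with_path``, ``keystr``, and ``tree_map_with_path``. The ``Interval`` class, its fields, and the sample ``tree`` are hypothetical names invented for this example; they are not part of ``jax.tree_util``.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Interval:
    """Hypothetical container: two array leaves plus a static label."""

    def __init__(self, lo, hi, name):
        self.lo = lo
        self.hi = hi
        self.name = name

    def tree_flatten(self):
        # Children become pytree leaves; aux_data is static metadata.
        return (self.lo, self.hi), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        lo, hi = children
        return cls(lo, hi, aux_data)


# A nested pytree mixing a dict, a list, and the custom node.
tree = {"a": Interval(jnp.zeros(2), jnp.ones(2), "unit"), "b": [jnp.arange(3)]}

# Flatten into (key path, leaf) pairs plus a treedef describing the structure.
paths_and_leaves, treedef = tree_util.tree_flatten_with_path(tree)

# keystr renders each key path as a readable string.
path_strings = [tree_util.keystr(path) for path, _ in paths_and_leaves]

# tree_map_with_path passes the key path as the first argument to the function.
shifted = tree_util.tree_map_with_path(lambda path, x: x + len(path), tree)
```

Because ``Interval`` is registered without key information, its children are addressed by flat index in the key path. ``register_pytree_with_keys_class`` is the variant to reach for when named keys are wanted.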

Legacy APIs
-----------

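These APIs are now accessed via :mod:`jax.tree`, where each function drops the ``tree_`` prefix (for example, ``tree_map`` becomes ``jax.tree.map``).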

.. autosummary::
   :toctree: _autosummary

   tree_all
   tree_flatten
   tree_leaves
   tree_map
   tree_reduce
   tree_structure
   tree_transpose
   tree_unflatten
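
The correspondence between the legacy names and :mod:`jax.tree` can be sketched as follows. This assumes a JAX release recent enough to ship the ``jax.tree`` module; the ``params`` dictionary is a made-up example pytree.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree used only for illustration.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Legacy spelling from jax.tree_util ...
doubled_old = jax.tree_util.tree_map(lambda x: 2 * x, params)

# ... and the jax.tree spelling of the same operation.
doubled_new = jax.tree.map(lambda x: 2 * x, params)

# Flatten and unflatten round-trip through the jax.tree names.
leaves, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, leaves)

# Check leaf-by-leaf that both spellings produced the same values.
same = jax.tree.all(
    jax.tree.map(lambda a, b: bool((a == b).all()), doubled_old, doubled_new)
)
```

Both spellings operate on the same pytree registry, so custom nodes registered through ``jax.tree_util`` are handled by the ``jax.tree`` functions as well.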