The Flax documentation can be found here: https://flax.readthedocs.io/en/latest/
- Clone the `flax` repository with `git clone https://github.com/google/flax.git`.
- In the main `flax` folder, install the required dependencies using `pip install -r docs/requirements.txt`.
- Install `pandoc`: `pip install pandoc`.
- [Optional] If you need to make any local changes to the docs, create and switch to a branch. Make your changes to the docs in that branch.
- To build the docs, in the `flax/docs` folder run the make script: `make html`. Alternatively, install `entr`, which helps run arbitrary commands when files change. Then run `find ../ ! -regex '.*/[\.|\_].*' | entr -s 'make html'`.
- If the build is successful, you should get the `The HTML pages are in _build/html.` message. You can preview the docs in `flax/docs/_build/html`.
We use `doctest` blocks for code embedded in documents, so that the code is also tested. Learn more at https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html
To run tests locally, run `make doctest`.
Our documentation is written in reStructuredText for Sphinx. It is a meta-language that is compiled into online documentation. For more details, check out Sphinx's documentation. As a result, our docstrings adhere to a specific syntax that has to be kept in mind. Below we provide some guidelines.
To learn how to contribute to Jupyter Notebooks or other formats in Flax docs, refer to the dedicated Contributing page.
Docstrings should be informative. We prefer to err on the side of too much documentation rather than too little. For instance, a one-line explanation for a new `Module` that implements new functionality is not sufficient.
Furthermore, we highly encourage adding examples to your docstrings, so users can directly see how code can be used.
We use doctest syntax for writing examples in documentation. These examples are run as tests as part of our CI process. To write `doctest` code in your documentation, please use the following notation:
# Example code::
#
#   def sum(a, b):
#     return a + b
#
#   sum(0, 1)
The `Example code` string at the beginning can be replaced by anything, as long as it is followed by two colons (`::`) and a newline, and the code is indented.
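As a rough illustration of how an example embedded in a docstring becomes a test, here is a minimal sketch that uses Python's standard-library `doctest` module rather than the Sphinx extension that `make doctest` runs. The function `add` and the `>>>` prompt style are assumptions chosen for the illustration, not taken from the Flax codebase.

```python
import doctest


def add(a, b):
  """Returns the sum of ``a`` and ``b``.

  Example usage::

    >>> add(0, 1)
    1
  """
  return a + b


# Parse the docstring into a DocTest object and execute its examples.
# `add` is passed in through the globals dict so the example can call it.
parser = doctest.DocTestParser()
test = parser.get_doctest(add.__doc__, {"add": add}, "add", None, 0)
runner = doctest.DocTestRunner(verbose=False)
results = runner.run(test)

# `results.failed` counts examples whose output did not match.
print(results)
```

If the expected output in the docstring were changed to something other than `1`, the runner would report a failure, which is the same mechanism that keeps documented examples from drifting out of date.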
When writing code font in a docstring, please use double backticks. Example:
# This returns a ``str`` object.
Note that argument names and objects like `True`, `None` or any strings should usually be put in `code`.
It is possible to create cross-references to other classes, functions, and methods. In the following, `obj_type` is either `class`, `func`, or `meth`.
# First method:
# :<obj_type>:`path_to_obj`
# Second method:
# :<obj_type>:`description <path_to_obj>`
You can use the second method if the `path_to_obj` is very long. Some examples:
# Create a reference to class flax.linen.Module.
# :class:`flax.linen.Module`
# Create a reference to local function my_func.
# :func:`my_func`
# Create a reference "Module.apply()" to method flax.linen.Module.apply.
# :meth:`Module.apply() <flax.linen.Module.apply>`
To create a hyperlink, use the following syntax:
# Note the double underscore at the end:
# `Link to Google <http://www.google.com>`__
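The inline markup conventions above (double backticks, cross-references, and hyperlinks) can be combined in a single docstring. The sketch below uses a hypothetical function, `describe_module`, that exists only for illustration; it does not import Flax, since the roles are plain text until Sphinx renders them.

```python
def describe_module():
  """Returns the import path of the base module class as a ``str``.

  See :class:`flax.linen.Module` for the class itself and
  :meth:`Module.apply() <flax.linen.Module.apply>` for how a module is
  applied. The rendered docs are available at
  `the Flax documentation <https://flax.readthedocs.io/en/latest/>`__.
  """
  # Hypothetical return value, used only to make the example runnable.
  return "flax.linen.Module"


print(describe_module.__doc__)
```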
- Class attributes should be specified using the `Attributes:` tag.
- Method arguments should be specified using the `Args:` tag.
- All attributes and arguments should have types.
Here is an example from our library:
class DenseGeneral(Module):
  """A linear transformation with flexible axes.

  Attributes:
    features: int or tuple with number of output features.
    axis: int or tuple with axes to apply the transformation on. For instance,
      (-2, -1) will apply the transformation to the last two axes.
    batch_dims: tuple with batch axes.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
  """
  features: Union[int, Iterable[int]]
  axis: Union[int, Iterable[int]] = -1
  batch_dims: Iterable[int] = ()
  use_bias: bool = True
  dtype: Dtype = jnp.float32
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
  precision: Any = None

  @compact
  def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along multiple dimensions.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    ...
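For experimenting with the `Attributes:` and `Args:` layout without installing Flax, the same conventions can be tried on a plain dataclass. The class `Scale` below is hypothetical and is not part of the library; it is only a sketch of how typed attributes and arguments might be documented.

```python
import dataclasses
from typing import List, Sequence


@dataclasses.dataclass
class Scale:
  """Multiplies every input value by a constant factor.

  Attributes:
    factor: float with the multiplier applied to each value.
    use_offset: bool, whether to add ``offset`` after scaling (default: False).
    offset: float with the value added when ``use_offset`` is True.
  """
  factor: float
  use_offset: bool = False
  offset: float = 0.0

  def __call__(self, inputs: Sequence[float]) -> List[float]:
    """Scales the inputs and optionally shifts them.

    Args:
      inputs: Sequence[float] with the values to scale.

    Returns:
      A list with the scaled (and optionally shifted) values.
    """
    shift = self.offset if self.use_offset else 0.0
    return [x * self.factor + shift for x in inputs]


scaled = Scale(factor=2.0)([1.0, 2.0])
shifted = Scale(factor=2.0, use_offset=True, offset=1.0)([1.0])
```

Each field named in the `Attributes:` section has a matching type annotation on the class, and the single argument in `Args:` is annotated in the method signature.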