diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index 114fd8a904..a0aa0bfad7 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -170,3 +170,13 @@ simple functions.
    gelu_fast_approx
    relu
    silu
+
+Loss Functions
+--------------
+
+.. autosummary::
+   :toctree: _autosummary_functions
+   :template: nn-module-template.rst
+
+   losses.cross_entropy
+   losses.l1_loss
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 067dcd6ddd..3445b686e0 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -12,12 +12,9 @@ def cross_entropy(
     Args:
         logits (mx.array): The predicted logits.
         targets (mx.array): The target values.
-        axis (int, optional): The axis over which to compute softmax. Defaults to -1.
-        reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
-            'none': no reduction will be applied.
-            'mean': the sum of the output will be divided by the number of elements in the output.
-            'sum': the output will be summed.
-            Defaults to 'none'.
+        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed cross entropy loss.
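
For reference, a minimal usage sketch of the two loss functions documented above, assuming the signatures implied by this diff (cross_entropy(logits, targets, axis=-1, reduction='none') and l1_loss(predictions, targets)); the input values are purely illustrative:

    import mlx.core as mx
    import mlx.nn as nn

    # Unnormalized class scores for a batch of 2 examples over 3 classes.
    logits = mx.array([[2.0, -1.0, 0.5],
                       [0.1, 1.5, -0.3]])
    # Integer class indices as targets.
    targets = mx.array([0, 1])

    # Per-example losses; reduction defaults to 'none' per the docstring above.
    per_example = nn.losses.cross_entropy(logits, targets)

    # Scalar loss via the 'mean' reduction.
    mean_loss = nn.losses.cross_entropy(logits, targets, reduction="mean")

    # L1 loss between predictions and regression targets.
    l1 = nn.losses.l1_loss(mx.array([0.5, 1.0]), mx.array([0.0, 1.0]))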