Commit
add losses to the docs, fix black failure (ml-explore#92)
awni authored Dec 9, 2023
1 parent 430bfb4 commit 2520dbc
Showing 2 changed files with 13 additions and 6 deletions.
10 changes: 10 additions & 0 deletions docs/src/python/nn.rst
@@ -170,3 +170,13 @@ simple functions.
    gelu_fast_approx
    relu
    silu
+
+Loss Functions
+--------------
+
+.. autosummary::
+   :toctree: _autosummary_functions
+   :template: nn-module-template.rst
+
+   losses.cross_entropy
+   losses.l1_loss
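As a rough illustration of what the two documented losses compute, here is a NumPy sketch (an assumption for illustration only, not the MLX implementation; the `_sketch` names are hypothetical):

```python
import numpy as np

def l1_loss_sketch(predictions: np.ndarray, targets: np.ndarray) -> float:
    # L1 loss: mean absolute difference over all elements.
    return float(np.mean(np.abs(predictions - targets)))

# One element differs by 1 out of three, so the mean absolute error is 1/3.
l1_loss_sketch(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 4.0]))
```

The real functions operate on `mx.array` inputs and return `mx.array` results, but the arithmetic is the same element-wise reduction.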
9 changes: 3 additions & 6 deletions python/mlx/nn/losses.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,9 @@ def cross_entropy(
     Args:
         logits (mx.array): The predicted logits.
         targets (mx.array): The target values.
-        axis (int, optional): The axis over which to compute softmax. Defaults to -1.
-        reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
-          'none': no reduction will be applied.
-          'mean': the sum of the output will be divided by the number of elements in the output.
-          'sum': the output will be summed.
-          Defaults to 'none'.
+        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed cross entropy loss.
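The `reduction` semantics described in the docstring can be sketched in NumPy (a stand-in for illustration, assuming integer class-index targets; not the MLX implementation):

```python
import numpy as np

def cross_entropy_sketch(logits, targets, axis=-1, reduction="none"):
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    # Negative log-probability of the target class for each example.
    loss = -np.take_along_axis(
        log_probs, np.expand_dims(targets, axis), axis=axis
    ).squeeze(axis)
    if reduction == "mean":
        return loss.mean()   # average over all per-example losses
    if reduction == "sum":
        return loss.sum()    # total loss
    return loss              # 'none': one loss value per example

logits = np.zeros((2, 3))           # uniform logits over 3 classes
targets = np.array([0, 1])
cross_entropy_sketch(logits, targets)  # each entry is log(3)
```

With uniform logits every class has probability 1/3, so each per-example loss is log(3); `'mean'` and `'sum'` then just reduce that vector.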
