Skip to content

Commit

Permalink
[DOC] Fix references to JAX and numpy functions (PythonOT#475)
Browse files Browse the repository at this point in the history
* Fix ref to JAX

* Fix references to numpy.random.*

* Typo in CONTRIBUTING

* Removed :any: reference from func parameters

* Make markup of params consistent with other docstrings

* Mentioned latest open PR in RELEASES

* Fix See Also references for ot.factored

---------

Co-authored-by: Rémi Flamary <[email protected]>
  • Loading branch information
kachayev and rflamary authored May 11, 2023
1 parent 8cc8dd2 commit 5faa4fb
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,4 @@ method does to the data and a figure (coming from an example)
illustrating it.


This Contribution guide is strongly inpired by the one of the [scikit-learn](https://github.com/scikit-learn/scikit-learn) team.
This Contribution guide is strongly inspired by the one of the [scikit-learn](https://github.com/scikit-learn/scikit-learn) team.
2 changes: 1 addition & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

- Fix circleci-redirector action and codecov (PR #460)
- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
- Major documentation cleanup (PR #462, #467)
- Major documentation cleanup (PR #462, #467, #475)
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
Expand Down
5 changes: 3 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,11 @@ def __getattr__(cls, name):

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
'matplotlib': ('http://matplotlib.org/', None),
'torch': ('https://pytorch.org/docs/stable/', None)}
'torch': ('https://pytorch.org/docs/stable/', None),
'jax': ('https://jax.readthedocs.io/en/latest/', None)}

sphinx_gallery_conf = {
'examples_dirs': ['../../examples', '../../examples/da'],
Expand Down
6 changes: 3 additions & 3 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def seed(self, seed=None):
This function follows the api from :any:`numpy.random.seed`
See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html
See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html
"""
raise NotImplementedError()

Expand All @@ -690,7 +690,7 @@ def rand(self, *size, type_as=None):
This function follows the api from :any:`numpy.random.rand`
See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html
"""
raise NotImplementedError()

Expand All @@ -700,7 +700,7 @@ def randn(self, *size, type_as=None):
This function follows the api from :any:`numpy.random.rand`
See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html
"""
raise NotImplementedError()

Expand Down
4 changes: 2 additions & 2 deletions ot/factored.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None,
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
regularized OT
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT
"""

nx = get_backend(Xa, Xb)
Expand Down
10 changes: 5 additions & 5 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
\lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
\lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
The regularization is selected with :any:`reg` (:math:`\lambda_r`) and :any:`reg_type`. By
The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
default ``reg=None`` and there is no regularization. The unbalanced marginal
penalization can be selected with :any:`unbalanced` (:math:`\lambda_u`) and
:any:`unbalanced_type`. By default ``unbalanced=None`` and the function
penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
`unbalanced_type`. By default ``unbalanced=None`` and the function
solves the exact optimal transport problem (respecting the marginals).
Parameters
Expand All @@ -46,12 +46,12 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
OT)
reg_type : str, optional
Type of regularization :math:`R` either "KL", "L2", 'entropy', by default "KL"
Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT)
unbalanced_type : str, optional
Type of unbalanced penalization unction :math:`U` either "KL", "L2", 'TV', by default 'KL'
Type of unbalanced penalization unction :math:`U` either "KL", "L2", "TV", by default "KL"
n_threads : int, optional
Number of OMP threads for exact OT solver, by default 1
max_iter : int, optional
Expand Down

0 comments on commit 5faa4fb

Please sign in to comment.