Skip to content

Commit

Permalink
Add JAX wrappers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 316634649
  • Loading branch information
mblondel committed Jun 30, 2020
1 parent c7353b3 commit 487fa7a
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 0 deletions.
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ TensorFlow Example
<tf.Tensor: shape=(2, 3), dtype=float64, numpy= array([[3., 1., 2.], [2., 1., 3.]])>
```

JAX Example
-----------

```python
>>> import jax.numpy as jnp
>>> from jax_ops import soft_rank, soft_sort
>>> values = jnp.array([[5., 1., 2.], [2., 1., 5.]], dtype=jnp.float64)
>>> soft_sort(values, regularization_strength=1.0)
[[1.66666667 2.66666667 3.66666667]
[1.66666667 2.66666667 3.66666667]]
>>> soft_sort(values, regularization_strength=0.1)
[[1. 2. 5.]
[1. 2. 5.]]
>>> soft_rank(values, regularization_strength=2.0)
[[3. 1.25 1.75]
[1.75 1.25 3. ]]
>>> soft_rank(values, regularization_strength=1.0)
[[3. 1. 2.]
[2. 1. 3.]]
```

PyTorch Example
---------------

Expand Down Expand Up @@ -64,4 +85,5 @@ Reference

> Fast Differentiable Sorting and Ranking
> Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
> In proceedings of ICML 2020
> [arXiv:2002.08871](https://arxiv.org/abs/2002.08871)
109 changes: 109 additions & 0 deletions fast_soft_sort/jax_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX operators for soft sorting and ranking.
Fast Differentiable Sorting and Ranking
Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
https://arxiv.org/abs/2002.08871
"""

from . import numpy_ops
import jax
import numpy as np
import jax.numpy as jnp
from jax import tree_util


def _wrap_numpy_op(cls, **kwargs):
"""Converts NumPy operator to a JAX one."""

def _func_fwd(values):
"""Converts values to numpy array, applies function and returns array."""
dtype = values.dtype
values = np.array(values)
obj = cls(values, **kwargs)
result = obj.compute()
return jnp.array(result, dtype=dtype), tree_util.Partial(obj.vjp)

def _func_bwd(vjp, g):
g = np.array(g)
result = jnp.array(vjp(g), dtype=g.dtype)
return (result,)

@jax.custom_vjp
def _func(values):
return _func_fwd(values)[0]

_func.defvjp(_func_fwd, _func_bwd)

return _func


def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
regularization="l2"):
r"""Soft rank the given values (array) along the second axis.
The regularization strength determines how close are the returned values
to the actual ranks.
Args:
values: A 2d-array holding the numbers to be ranked.
direction: Either 'ASCENDING' or 'DESCENDING'.
regularization_strength: The regularization strength to be used. The smaller
this number, the closer the values to the true ranks.
regularization: Which regularization method to use. It
must be set to one of ("l2", "kl", "log_kl").
Returns:
A 2d-array, soft-ranked along the second axis.
"""
if len(values.shape) != 2:
raise ValueError("'values' should be a 2d-array "
"but got %r." % values.shape)

func = _wrap_numpy_op(numpy_ops.SoftRank,
regularization_strength=regularization_strength,
direction=direction,
regularization=regularization)

return jnp.vstack([func(val) for val in values])


def soft_sort(values, direction="ASCENDING",
regularization_strength=1.0, regularization="l2"):
r"""Soft sort the given values (array) along the second axis.
The regularization strength determines how close are the returned values
to the actual sorted values.
Args:
values: A 2d-array holding the numbers to be sorted.
direction: Either 'ASCENDING' or 'DESCENDING'.
regularization_strength: The regularization strength to be used. The smaller
this number, the closer the values to the true sorted values.
regularization: Which regularization method to use. It
must be set to one of ("l2", "log_kl").
Returns:
A 2d-array, soft-sorted along the second axis.
"""
if len(values.shape) != 2:
raise ValueError("'values' should be a 2d-array "
"but got %s." % str(values.shape))

func = _wrap_numpy_op(numpy_ops.SoftSort,
regularization_strength=regularization_strength,
direction=direction,
regularization=regularization)

return jnp.vstack([func(val) for val in values])
72 changes: 72 additions & 0 deletions tests/jax_ops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for jax_ops.py."""

import functools
import itertools
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import jax.numpy as jnp
import jax

from jax.config import config
config.update("jax_enable_x64", True)

from fast_soft_sort import jax_ops

GAMMAS = (0.1, 1, 10.0)
DIRECTIONS = ("ASCENDING", "DESCENDING")
REGULARIZERS = ("l2", )


class JaxOpsTest(parameterized.TestCase):

def _test(self, func, regularization_strength, direction, regularization):

def loss_func(values):
soft_values = func(values,
regularization_strength=regularization_strength,
direction=direction,
regularization=regularization)
return jnp.sum(soft_values ** 2)

rng = np.random.RandomState(0)
values = jnp.array(rng.randn(5, 10))
mat = jnp.array(rng.randn(5, 10))
unitmat = mat / np.sqrt(np.vdot(mat, mat))
eps = 1e-5
numerical = (loss_func(values + 0.5 * eps * unitmat) -
loss_func(values - 0.5 * eps * unitmat)) / eps
autodiff = jnp.vdot(jax.grad(loss_func)(values), unitmat)
np.testing.assert_almost_equal(numerical, autodiff)


@parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
def test_soft_rank(self, regularization_strength, direction, regularization):
self._test(jax_ops.soft_rank,
regularization_strength, direction, regularization)

@parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
def test_soft_sort(self, regularization_strength, direction, regularization):
self._test(jax_ops.soft_sort,
regularization_strength, direction, regularization)


if __name__ == "__main__":
absltest.main()

0 comments on commit 487fa7a

Please sign in to comment.