Add an experimental lax.top_k operator. (jax-ml#2280)
hawkinsp authored Feb 21, 2020
1 parent 8372a70 commit af0967f
Showing 2 changed files with 42 additions and 1 deletion.
19 changes: 18 additions & 1 deletion jax/lax/lax.py
@@ -1036,6 +1036,11 @@ def sort_key_val(keys, values, dimension=-1):
  sorted_keys, sorted_values = result
  return sorted_keys, sorted_values

def top_k(operand, k):
  k = int(k)
  if k < 0:
    raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
  return top_k_p.bind(operand, k=k)

def tie_in(x, y):
  return tie_in_p.bind(x, y)
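
Not part of the diff, but a minimal usage sketch of the new operator. The outputs shown follow from top_k's semantics (the k largest entries along the last axis, in descending order, plus their int32 positions); they are not taken from the commit itself.

import jax.numpy as jnp
from jax import lax

x = jnp.array([[9., 3., 6., 1.],
               [2., 8., 4., 7.]])
values, indices = lax.top_k(x, k=2)   # selects along the last axis
# values  -> [[9., 6.],
#             [8., 7.]]
# indices -> [[0, 2],
#             [1, 3]]   (dtype int32)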
@@ -4034,7 +4039,6 @@ def _sort_batch_rule(batched_args, batch_dims, dimension):
ad.defjvp(sort_p, _sort_jvp_rule)
batching.primitive_batchers[sort_p] = _sort_batch_rule


def _sort_key_val_abstract_eval(keys, values, dimension):
  return raise_to_shaped(keys), raise_to_shaped(values)

@@ -4106,6 +4110,19 @@ def _sort_key_val_batch_rule(batched_args, batch_dims, dimension):
batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule


def _top_k_abstract_eval(operand, k):
  if len(operand.shape) == 0:
    raise TypeError("top_k operand must have >= 1 dimension, got {}"
                    .format(operand.shape))
  # Both outputs replace the trailing dimension with k; the second output
  # holds int32 indices into the operand's last axis.
  shape = list(operand.shape)
  shape[-1] = k
  return (ShapedArray(tuple(shape), operand.dtype),
          ShapedArray(tuple(shape), onp.int32))

top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.translations[top_k_p] = partial(standard_translate, 'top_k')


def _tie_in_transpose_rule(t):
  return [ad_util.zero, t]
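
As an aside (not in the commit): jax.eval_shape runs only abstract evaluation, so it is a cheap way to confirm the output shapes that _top_k_abstract_eval promises for the primitive registered above. A sketch, assuming a jax build where top_k is available:

import jax
import jax.numpy as jnp
from jax import lax

# No backend TopK kernel runs here; only the abstract eval rule fires.
out = jax.eval_shape(lambda x: lax.top_k(x, 3),
                     jnp.zeros((5, 8), jnp.float32))
# out -> (ShapeDtypeStruct((5, 3), float32), ShapeDtypeStruct((5, 3), int32))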

24 changes: 24 additions & 0 deletions tests/lax_test.py
@@ -18,6 +18,7 @@
from functools import partial
import itertools
from typing import Optional, cast
import unittest
from unittest import skip, SkipTest

from absl.testing import absltest
@@ -1319,6 +1320,29 @@ def args_maker():
    numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
      for dtype in [onp.float32, onp.int32, onp.uint32]
      for shape in [(3,), (5, 3)]
      for k in [1, 3]
      for rng_factory in [jtu.rand_default]))
  @unittest.skipIf(jax.lib.version <= (0, 1, 40),
                   "Test requires jaxlib newer than 0.1.40")
  def testTopK(self, shape, dtype, k, rng_factory):
    rng = rng_factory()
    perm_rng = onp.random.RandomState(0)
    def args_maker():
      flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
      values = perm_rng.permutation(flat_values).reshape(shape)
      return [values]
    def reference_top_k(x):
      bcast_idxs = onp.broadcast_to(onp.arange(shape[-1]), shape)
      sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
      return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]
    op = lambda vs: lax.top_k(vs, k=k)
    self._CheckAgainstNumpy(op, reference_top_k, args_maker)
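
The reversed slice in reference_top_k is easy to misread, so here is a standalone NumPy check of what it selects (a sketch only; np stands in for the test's onp alias, and argsort for lax_reference.sort_key_val):

import numpy as np

x = np.array([1, 5, 2, 9, 3])
k = 2
order = np.argsort(x, axis=-1)         # ascending, like sort_key_val
sorted_vals = x[order]                 # [1 2 3 5 9]
sorted_idxs = np.arange(x.shape[-1])[order]
print(sorted_vals[..., :-k-1:-1])      # [9 5]  last k entries, reversed => descending
print(sorted_idxs[..., :-k-1:-1])      # [3 1]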

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
