Skip to content

Commit

Permalink
remove _reduce_sum from public jax.lax module
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Mar 9, 2022
1 parent 7890fb7 commit 6f51957
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
6 changes: 4 additions & 2 deletions jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,11 @@ def chooser_taylor_rule(primals_in, series_in, **params):
location_indicators = lax.convert_element_type(
lax_internal._eq_meet(operand, lax.reshape(primal_out, shape)),
primal_dtype)
counts = lax._reduce_sum(location_indicators, axes)
counts = lax_internal._reduce_sum(location_indicators, axes)
def _reduce_chooser_taylor_rule(g):
return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
return lax.div(
lax_internal._reduce_sum(lax.mul(g, location_indicators), axes),
counts)
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
return primal_out, series_out
return chooser_taylor_rule
Expand Down
2 changes: 1 addition & 1 deletion jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
zeros_like_array as zeros_like_array,
)
from jax._src.lax.lax import (
_reduce_sum, _reduce_max, _reduce_min, _reduce_or, _reduce_and,
_reduce_max, _reduce_min, _reduce_or, _reduce_and,
_check_user_dtype_supported,
_upcast_fp16_for_computation, _broadcasting_shape_rule,
_eye, _tri, _delta, _ones, _zeros, _dilate_shape)
Expand Down
7 changes: 4 additions & 3 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
from jax import lax
from jax import numpy as jnp
from jax import linear_util as lu
from jax._src import test_util as jtu
from jax._src.abstract_arrays import make_shaped_array
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.interpreters import partial_eval as pe

from jax._src import test_util as jtu
from jax._src.abstract_arrays import make_shaped_array
from jax._src.lax import lax as lax_internal

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -610,7 +611,7 @@ def test_staging_primitive_applications(self):
def f(x, y):
z = lax.mul(x, y)
w = lax.sin(z)
u = lax._reduce_sum(w, [0])
u = lax_internal._reduce_sum(w, [0])
return (u,)

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
Expand Down
4 changes: 3 additions & 1 deletion tests/filecheck/array.filecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from jax import lax
import numpy as np

from jax._src.lax import lax as lax_internal

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
Expand Down Expand Up @@ -60,7 +62,7 @@ def main(_):
# CHECK: mhlo.add
# CHECK: tensor<3xi32>
print_ir(np.empty([2, 3, 7], np.int32))(
partial(lax._reduce_sum, axes=(0, 2)))
partial(lax_internal._reduce_sum, axes=(0, 2)))

# CHECK-LABEL: TEST: reshape int32[2,3,7]
# CHECK: mhlo.reshape
Expand Down

0 comments on commit 6f51957

Please sign in to comment.