Skip to content

Commit b8ae8e3

Browse files
Eugene Burmakojax authors
Eugene Burmako
authored and
jax authors
committedDec 16, 2022
(NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order to prepare for the upcoming migration. Unchanged occurrences: 1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo" argument value in Lowering.as_text and Lowering.compiler_ir. 2) Documentation (changelog, JEPs, IR examples, etc). 3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence, so both are necessary to disambiguate. PiperOrigin-RevId: 495771153
1 parent 523c6f7 commit b8ae8e3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+991
-882
lines changed
 

‎build/build_wheel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def prepare_wheel(sources_path):
169169
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
170170
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
171171
copy_to_jaxlib("__main__/jaxlib/lapack.py")
172-
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
172+
copy_to_jaxlib("__main__/jaxlib/hlo_helpers.py")
173173
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
174174
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
175175
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")

‎jax/_src/ad_checkpoint.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from jax._src.lax import lax as lax_internal
3535
from jax._src.lax import convolution as lax_convolution
3636
from jax._src.lib import xla_client as xc
37-
from jax._src.lib.mlir.dialects import mhlo
37+
from jax._src.lib.mlir.dialects import hlo
3838
from jax._src.traceback_util import api_boundary
3939
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
4040
safe_zip, merge_lists, weakref_lru_cache)
@@ -623,9 +623,9 @@ def _optimization_barrier_lowering_rule(ctx, *args):
623623
flat_barrier_types = util.flatten(barrier_types)
624624
flat_args = mlir.flatten_lowering_ir_args(args)
625625
if xc.mlir_api_version < 40:
626-
barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
626+
barrier_op = hlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
627627
else:
628-
barrier_op = mhlo.OptimizationBarrierOp(flat_args)
628+
barrier_op = hlo.OptimizationBarrierOp(flat_args)
629629
return util.unflatten(barrier_op.results, map(len, barrier_types))
630630

631631
def _optimization_barrier(arg):

0 commit comments

Comments
 (0)
Please sign in to comment.