forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmhlo_helpers.py
60 lines (56 loc) · 2.4 KB
/
mhlo_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Helpers for building MHLO operators
from typing import Optional, Sequence, Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import numpy as np
def _dense_layouts_attr(layouts: Sequence[Sequence[int]]) -> ir.ArrayAttr:
  """Packs layouts into an ``ArrayAttr`` of index-typed dense int attributes.

  Each entry of ``layouts`` is converted to an int64 NumPy array.
  ``np.atleast_1d`` promotes a scalar entry to a one-element array, then the
  array is wrapped in a ``DenseIntElementsAttr`` whose element type is
  ``index``.

  Args:
    layouts: one layout (a sequence of ints) per operand or result.

  Returns:
    An ``ir.ArrayAttr`` with one dense attribute per input layout, in order.
  """
  index_type = ir.IndexType.get()
  return ir.ArrayAttr.get([
      ir.DenseIntElementsAttr.get(
          np.atleast_1d(np.asarray(layout, dtype=np.int64)), type=index_type)
      for layout in layouts
  ])


def custom_call(call_target_name: str, out_types: Sequence[ir.Type],
                operands: Sequence[ir.Value],
                operand_layouts: Sequence[Sequence[int]],
                result_layouts: Sequence[Sequence[int]],
                backend_config: Optional[str] = None,
                has_side_effect: bool = False,
                api_version: int = 2,
                ) -> Union[ir.Value, Sequence[ir.Value]]:
  """Less-verbose helper for building an MHLO custom call op.

  Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
  may be able to go away.

  Args:
    call_target_name: name of the custom call target, stored as a string
      attribute.
    out_types: result types of the call. With exactly one type the op produces
      that type directly. Otherwise the op produces a single tuple of these
      types, which is then unpacked element by element.
    operands: operand values passed to the custom call.
    operand_layouts: one layout per operand, each a sequence of ints.
    result_layouts: one layout per result, each a sequence of ints.
    backend_config: opaque backend configuration string; ``None`` is emitted as
      the empty string.
    has_side_effect: stored as the op's ``has_side_effect`` bool attribute.
    api_version: stored as a 32-bit signless integer attribute.

  Returns:
    The single result value when ``len(out_types) == 1``. Otherwise a list with
    one value per entry of ``out_types``, extracted from the tuple result with
    ``GetTupleElementOp``.
  """
  i32_type = ir.IntegerType.get_signless(32)
  # Several results are packed into one tuple-typed result of the custom call
  # and unpacked again below; a single result is emitted directly.
  single_result = len(out_types) == 1
  result_types = (out_types if single_result else
                  [ir.TupleType.get_tuple(out_types)])
  out = mhlo.CustomCallOp(
      result_types,
      operands,
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=ir.StringAttr.get(
          "" if backend_config is None else backend_config),
      api_version=ir.IntegerAttr.get(i32_type, api_version),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=_dense_layouts_attr(operand_layouts),
      result_layouts=_dense_layouts_attr(result_layouts))
  if single_result:
    return out.result
  return [mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
          for i in range(len(out_types))]