# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A small library of helpers for use in jaxlib to build MLIR operations."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
from typing import Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np

# Maps NumPy dtypes to zero-argument factories for the corresponding MLIR
# element type. We store factories rather than types because constructing
# an ir.Type requires an active MLIR context.
_dtype_to_ir_type_factory: dict[np.dtype, Callable[[], ir.Type]] = {
    np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
    np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
    np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
    np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
    np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
    np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
    np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
    np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
    np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
    np.dtype(np.float16): ir.F16Type.get,
    np.dtype(np.float32): ir.F32Type.get,
    np.dtype(np.float64): ir.F64Type.get,
    np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
    np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}


def dtype_to_ir_type(dtype) -> ir.Type:
  """Returns the MLIR element type for a NumPy dtype."""
  return _dtype_to_ir_type_factory[np.dtype(dtype)]()


def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
  """Returns a ranked tensor type with the given shape and element dtype."""
  return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype))
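

# A hedged usage sketch (not part of the library): the type helpers must run
# under an active MLIR context, because the ir.*Type factories require one.
# The function name below is hypothetical and only illustrative.
def _example_make_f32_matrix_type() -> ir.Type:
  # Builtin types need only an active context, no dialect registration.
  with ir.Context():
    return shape_dtype_to_ir_type((3, 4), np.float32)  # tensor<3x4xf32>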


# When we generate custom calls with dynamic shapes we have to pass both the
# result_types, with ir.ShapedType.get_dynamic_size() in place of the dynamic
# dimensions, and also result_shapes, which are ir.Values representing 1-D
# int32 tensors. If all the shapes are static we can use result_shapes=None.
# We first construct, for each result, a pair of the shape and the element
# type, where the shape contains either Python ints or ir.Values for the
# dynamic dimensions.
DimensionSize = Union[int, ir.Value]  # an ir.Value when the dimension is dynamic
ShapeTypePair = tuple[Sequence[DimensionSize], ir.Type]


def mk_result_types_and_shapes(
    shape_type_pairs: Sequence[ShapeTypePair]
) -> tuple[list[ir.Type], list[ir.Value] | None]:
  """Constructs the result_types and result_shapes for a custom call.

  Returns the MLIR result types, with dynamic dimensions replaced by
  ir.ShapedType.get_dynamic_size(), and, if any dimension is dynamic, one
  1-D i32 shape tensor per result (None if all shapes are static).
  """
  result_types: list[ir.Type] = []
  result_shapes: list[ir.Value] = []
  has_dynamic_shapes = any(
      any(not isinstance(d, int) for d in rshape)
      for rshape, _ in shape_type_pairs)
  for (rshape, rtype) in shape_type_pairs:
    if has_dynamic_shapes:
      result_shapes.append(shape_tensor(rshape))
    result_types.append(
        ir.RankedTensorType.get(
            [d if isinstance(d, int) else ir.ShapedType.get_dynamic_size()
             for d in rshape],
            rtype))
  return (result_types,
          result_shapes if has_dynamic_shapes else None)
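

# A hedged sketch (not part of the library API): building the result types
# and shapes for one result of shape (batch, 4) where `batch` is dynamic.
# The ir.Value `batch` is hypothetical; in practice it would come from, e.g.,
# hlo.get_dimension_size on an operand. Must run under an active MLIR context
# and insertion point, as in a lowering rule.
def _example_result_types_and_shapes(batch: ir.Value):
  f32 = dtype_to_ir_type(np.float32)
  result_types, result_shapes = mk_result_types_and_shapes([((batch, 4), f32)])
  # result_types[0] is tensor<?x4xf32>; result_shapes holds one 1-D i32
  # tensor [batch, 4], suitable to pass to custom_call below.
  return result_types, result_shapes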


# TODO(necula): share this with mlir.shape_tensor
def shape_tensor(sizes: Sequence[int | ir.Value]) -> ir.Value:
  """Builds a 1-D i32 shape tensor from a mix of ints and ir.Values."""
  int1d = shape_dtype_to_ir_type((1,), np.int32)
  i32_type = shape_dtype_to_ir_type((), np.int32)
  def dim_to_i32x1(d):
    if type(d) is int:
      return hlo_const(np.array([d], dtype=np.int32))
    else:
      if d.type != i32_type:
        d = hlo.convert(i32_type, d)
      return hlo.reshape(int1d, d)
  ds = [dim_to_i32x1(sz) for sz in sizes]
  if not ds:
    return hlo_const(np.array([], np.int32))
  elif len(ds) == 1:
    return ds[0]
  else:
    return hlo.concatenate(
        ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0))
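

# A hedged sketch: shape_tensor turns mixed static/dynamic dimensions into a
# single 1-D i32 tensor. `dyn` is a hypothetical ir.Value for a dynamic
# dimension; static ints become constants and the pieces are concatenated
# along dimension 0. Requires an active MLIR context and insertion point.
def _example_shape_tensor(dyn: ir.Value) -> ir.Value:
  return shape_tensor((2, dyn, 3))  # a tensor<3xi32> holding [2, dyn, 3]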


def hlo_const(x: np.ndarray) -> ir.Value:
  assert isinstance(x, np.ndarray)
  return hlo.constant(
      ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)))


def hlo_u8(x: int):
  return hlo_const(np.array(x, dtype=np.uint8))


def hlo_s32(x: int):
  return hlo_const(np.array(x, dtype=np.int32))


def ensure_hlo_s32(x: DimensionSize):
  return hlo_s32(x) if isinstance(x, int) else x


def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
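

# A hedged sketch: the scalar helpers wrap Python numbers as HLO constants,
# and dense_int_array builds the attribute form that many StableHLO ops take
# for integer-array parameters. Requires an active MLIR context and
# insertion point.
def _example_constants():
  axis = hlo_s32(0)                   # a tensor<i32> constant holding 0
  sizes = dense_int_array([1, 2, 2])  # a DenseI64ArrayAttr [1, 2, 2]
  return axis, sizes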


def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  """Minimum of two dimension sizes; folds to an int if both are static."""
  if type(x) is int:
    if type(y) is int:
      return min(x, y)
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.minimum(x, y)


def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  """Sum of two dimension sizes; folds to an int if both are static."""
  if type(x) is int:
    if type(y) is int:
      return x + y
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.add(x, y)
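

# A hedged sketch: hlo_min/hlo_add let shape arithmetic be written once for
# both static and dynamic dimensions. With two ints they fold to a Python
# int; otherwise they emit StableHLO ops on i32 scalars. `dyn` is a
# hypothetical dynamic dimension (an ir.Value).
def _example_dim_arithmetic(dyn: ir.Value) -> DimensionSize:
  assert hlo_add(2, 3) == 5      # static + static folds to a plain int
  padded = hlo_add(dyn, 3)       # dynamic + static emits hlo.add
  return hlo_min(padded, 128)    # mixed again: emits hlo.minimum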


# TODO(necula): this is identical with mlir.custom_call, but meant for use
# in jaxlib. Find a way to share these implementations.
def custom_call(
    call_target_name: str,
    *,
    result_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    backend_config: str | bytes | dict[str, ir.Attribute] = "",
    has_side_effect: bool = False,
    result_shapes: Sequence[ir.Value] | None = None,
    called_computations: Sequence[str] = (),
    api_version: int = 2,
    operand_output_aliases: dict[int, int] | None = None,
    operand_layouts: Sequence[Sequence[int]] | None = None,
    result_layouts: Sequence[Sequence[int]] | None = None,
    extra_attributes: dict[str, ir.Attribute] | None = None,
) -> ir.Operation:
  """Helper function for building an hlo.CustomCall.

  Args:
    call_target_name: the name of the custom call target
    result_types: the MLIR types of the results of the custom call
    operands: the MLIR IR values that are arguments to the custom call
    backend_config: an opaque string passed to the custom call kernel
    has_side_effect: if True, marks the custom call as effectful
    result_shapes: tensors that represent the result shapes, to be used when
      the results have dynamic shapes. If not None, its length must match the
      number of results.
    called_computations: the list of function names called by the custom call.
    api_version: the ABI contract version of the custom call
    operand_output_aliases: a dict mapping operand numbers to the outputs
      they alias
    operand_layouts: a sequence of layouts (dimension orders), one per operand
    result_layouts: a sequence of layouts (dimension orders), one per result
    extra_attributes: additional IR attributes to apply to the custom call.
  """
  operands = list(operands)

  if backend_config is None:
    backend_config_attr = ir.StringAttr.get("")
  elif isinstance(backend_config, (str, bytes)):
    backend_config_attr = ir.StringAttr.get(backend_config)
  elif isinstance(backend_config, dict):
    # TODO(necula): it seems that the CustomCallOp constructor requires that
    # backend_config_attr be a string attribute, even though in some cases we
    # need it to be a DictAttr, e.g., for ApproxTopK on TPU:
    #   "Verification failed: 'stablehlo.custom_call' op attribute
    #   'backend_config' failed to satisfy constraint: string attribute"
    # To work around this limitation we first set it to the empty string and
    # use the unregistered attribute mhlo.backend_config to hold the DictAttr
    # (see the end of this function). We must also use api_version=1 to
    # ensure that mhlo.backend_config is handled properly.
    backend_config_attr = ir.StringAttr.get("")
    api_version = 1
  else:
    raise ValueError(
        "custom_call backend_config unexpected type: " + str(backend_config))
  attributes = dict(
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=backend_config_attr,
      api_version=ir.IntegerAttr.get(
          ir.IntegerType.get_signless(32), api_version),
      called_computations=ir.ArrayAttr.get(
          [ir.FlatSymbolRefAttr.get(name) for name in called_computations]
      ),
  )
  if operand_output_aliases is not None:
    attributes["output_operand_aliases"] = ir.ArrayAttr.get([
        hlo.OutputOperandAlias.get(
            # If len(result_types) == 1 then the aliasing refers implicitly
            # to the only output.
            output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
            operand_index=input_idx,
            operand_tuple_indices=[],
        )
        for input_idx, output_idx in operand_output_aliases.items()
    ])
  if extra_attributes is not None:
    attributes.update(extra_attributes)

  if result_shapes is not None:
    # We append the result_shapes at the end of the operands, and must pass
    # the indices_of_shape_operands attribute. This attribute is not yet
    # accepted by the CustomCallOp constructor, so we use build_generic below.
    attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
        np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
                   dtype=np.int64))
    if operand_layouts is not None:
      assert len(operand_layouts) == len(operands), (operand_layouts, operands)
      operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
    operands = list(operands) + list(result_shapes)
  if operand_layouts is not None:
    attributes["operand_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in operand_layouts
    ])
  if result_layouts is not None:
    assert len(result_layouts) == len(result_types), (
        result_layouts, result_types)
    attributes["result_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in result_layouts
    ])
  op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
                                      attributes=attributes)
  if isinstance(backend_config, dict):
    # See the TODO above: stash the DictAttr in mhlo.backend_config.
    backend_config_attr = ir.DictAttr.get(backend_config)
    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
  return op
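

# A hedged end-to-end sketch with a hypothetical target name "my_kernel":
# one f32 operand, one result of the same static shape. With static shapes,
# mk_result_types_and_shapes returns result_shapes=None, which custom_call
# accepts. Must run under an active MLIR context and insertion point.
def _example_custom_call(operand: ir.Value) -> ir.Operation:
  result_types, result_shapes = mk_result_types_and_shapes(
      [((3, 4), dtype_to_ir_type(np.float32))])
  return custom_call(
      "my_kernel",
      result_types=result_types,
      operands=[operand],
      backend_config=b"",           # opaque bytes for the kernel
      result_shapes=result_shapes,  # None here, since the shapes are static
      operand_layouts=[(1, 0)],     # major-to-minor (row-major) layouts
      result_layouts=[(1, 0)],
  )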