forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_partitioning_sharding_rule_test.py
468 lines (406 loc) · 18.9 KB
/
custom_partitioning_sharding_rule_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
from jax._src.custom_partitioning_sharding_rule import ArrayMapping, BATCHING, CompoundFactor, sdy_sharding_rule_to_mlir, str_to_sdy_sharding_rule, SdyShardingRule
from jax._src.lib.mlir.dialects import hlo as stablehlo
class SdyShardingRuleTest(jtu.JaxTestCase):
  """Tests for the CompoundFactor, ArrayMapping and SdyShardingRule classes."""

  def test_compound_factor_not_enough_factors(self):
    with self.assertRaisesRegex(
        ValueError, "A compound factor should contain at least two factors"):
      CompoundFactor("i")

  def test_compound_factor_batching_not_allowed(self):
    with self.assertRaisesRegex(
        ValueError, "Ellipsis can't be used in a compound factor"):
      CompoundFactor(BATCHING, "i")

  def test_compound_factor_element_not_a_str(self):
    with self.assertRaisesRegex(
        ValueError, "Each element of CompoundFactor must be a str"):
      CompoundFactor("i", 2)

  def test_compound_factor_str(self):
    c = CompoundFactor("i", "j", "k")
    self.assertEqual(str(c), "('i', 'j', 'k')")

  def test_array_mapping_element_not_a_str_or_compound_factor(self):
    with self.assertRaisesRegex(
        ValueError,
        "Each element of ArrayMapping must be a str or CompoundFactor"):
      ArrayMapping(CompoundFactor("i", "j"), 3)

  def test_array_mapping_factor_name_not_start_with_letter(self):
    with self.assertRaisesRegex(
        ValueError, "Factor names have to start with a letter"):
      ArrayMapping("3i", "j")

  def test_array_mapping_ellipsis_not_first(self):
    with self.assertRaisesRegex(
        ValueError,
        "Ellipsis can only be used at the beginning of a dimension"):
      ArrayMapping("i_j", BATCHING)

  def test_array_mapping_str(self):
    v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k")
    self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')")

  def test_sdy_sharding_rule_factor_size_not_used(self):
    # Use ArrayMapping operands/results like every other SdyShardingRule test
    # rather than bare tuples of str.
    with self.assertRaisesRegex(ValueError, "Factor k is not used"):
      SdyShardingRule((ArrayMapping("i"),), (ArrayMapping("j"),), k=10)

  def test_sdy_sharding_rule_factor_sizes_missing(self):
    # "k" appears only inside a compound factor, so its size must be given.
    with self.assertRaisesRegex(
        ValueError,
        "Factor k is only used in compound factors; must specify its size"):
      SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")),
                      (ArrayMapping(CompoundFactor("j", "k")),))

  def test_sdy_sharding_rule_factor_size_not_necessary(self):
    with self.assertRaisesRegex(
        ValueError,
        "Factor i represents a whole dimension; do not specify its size"):
      SdyShardingRule((ArrayMapping("i"),), (ArrayMapping("j"),), i=10)

  def test_sdy_sharding_rule_compound_factor_size_not_necessary(self):
    # "i" is also used as a whole result dimension, so giving its size is an
    # error even though it appears in a compound factor too.
    with self.assertRaisesRegex(
        ValueError,
        "Factor i represents a whole dimension; do not specify its size"):
      SdyShardingRule((ArrayMapping(CompoundFactor("i", "j")),),
                      (ArrayMapping("i"),), i=10, j=20)

  def test_sdy_sharding_rule_str(self):
    r = SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")),
                        (ArrayMapping(CompoundFactor("j", "k")),), k=10)
    self.assertEqual(
        str(r),
        "SdyShardingRule((('i',), ('j',)), ((('j', 'k'),),), {'k': 10})")
class StrToSdyShardingRuleTest(jtu.JaxTestCase):
  """Tests for parsing rule strings with str_to_sdy_sharding_rule."""

  def test_rule_is_not_a_str(self):
    with self.assertRaisesRegex(TypeError, "rule must be a str"):
      str_to_sdy_sharding_rule(1)

  def test_factor_sizes_is_not_a_proper_dict(self):
    with self.assertRaisesRegex(
        TypeError, "factor_sizes must be a dict of str to int"):
      str_to_sdy_sharding_rule("i->j", i="j")

  def test_sharding_rule_ellipsis_not_complete(self):
    with self.assertRaisesRegex(
        ValueError, "Character '.' must be used inside ellipsis '...'"):
      str_to_sdy_sharding_rule(".i -> j")

  def test_sharding_rule_invalid_factor_name(self):
    with self.assertRaisesRegex(
        ValueError, "Factor names have to start with a letter"):
      str_to_sdy_sharding_rule("2i -> j")

  def test_sharding_rule_missing_results(self):
    with self.assertRaisesRegex(ValueError, "There is no -> in rule"):
      str_to_sdy_sharding_rule("i")

  def test_sharding_rule_unbalanced_brackets(self):
    with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
      str_to_sdy_sharding_rule("i j, k)->j")

  def test_sharding_rule_unbalanced_brackets2(self):
    with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
      str_to_sdy_sharding_rule("i (j k->j")

  def test_sharding_rule_empty_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Brackets should contain at least two factors"):
      str_to_sdy_sharding_rule("i ( ) j k->j")

  def test_sharding_rule_one_factor_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Brackets should contain at least two factors"):
      str_to_sdy_sharding_rule("i (j ) k->j")

  def test_sharding_rule_nested_brackets(self):
    with self.assertRaisesRegex(
        ValueError, "Compound factors should be one level"):
      str_to_sdy_sharding_rule("i (j (k))->j")

  def test_sharding_rule_unknown_char(self):
    with self.assertRaisesRegex(ValueError, "Unknown character"):
      str_to_sdy_sharding_rule("i; j->j")

  def test_sharding_rule_unknown_single_char_ellipsis(self):
    # The ellipsis must be spelled as three dots "..."; the single Unicode
    # character "…" is rejected as an unknown character.
    with self.assertRaisesRegex(ValueError, "Unknown character"):
      str_to_sdy_sharding_rule("…j->…j")

  def test_sharding_rule_ellipsis_not_leading_dim(self):
    with self.assertRaisesRegex(
        ValueError,
        "Ellipsis can only be used at the beginning of a dimension"):
      str_to_sdy_sharding_rule("i ... -> j")

  def test_sharding_rule_ellipsis_inside_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError,
        "Ellipsis can only be used at the beginning of a dimension"):
      str_to_sdy_sharding_rule("i, (..., j) -> j")

  def test_sharding_rule_scalar_operand_scalar_result(self):
    rule = str_to_sdy_sharding_rule("->")
    self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})")

  def test_sharding_rule_one_scalar_operand(self):
    # An empty entry between commas denotes a scalar (rank-0) operand.
    rule = str_to_sdy_sharding_rule("i j, , k->j")
    self.assertEqual(
        str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")

  def test_sharding_rule_factor_elementwise_add(self):
    # "..." is parsed into the ellipsis marker, which prints as '…'.
    rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i',"
        " 'j'),), {})")

  def test_sharding_rule_factor_vector_scalar_add(self):
    rule = str_to_sdy_sharding_rule("...i, -> ...i")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})")

  def test_sharding_rule_factor_reshape_combining(self):
    rule = str_to_sdy_sharding_rule("i j -> (i j)")
    self.assertEqual(
        str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})")

  def test_sharding_rule_factor_reshape_reordering(self):
    rule = str_to_sdy_sharding_rule("(j i) -> (i j)", i=10, j=20)
    self.assertEqual(
        str(rule),
        "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':"
        " 20})")

  def test_sharding_rule_factor_compound_then_individual(self):
    rule = str_to_sdy_sharding_rule("(i j) (j k) i -> j k")
    self.assertEqual(
        str(rule),
        "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})")

  def test_sharding_rule_factor_individual_then_compound(self):
    rule = str_to_sdy_sharding_rule("i j k -> (i j) (j k)")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})")

  def test_sharding_rule_factor_infer_k(self):
    # Factor names may contain underscores and digits; sizes are supplied for
    # the factors that only appear inside compound factors.
    rule = str_to_sdy_sharding_rule(
        "i_ (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20)
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('i_', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))"
        ",), {'k': 10, 'm': 10, 'bar_24': 20})")
class SdyShardingRuleConversionTest(jtu.JaxTestCase):
  """Tests for sdy_sharding_rule_to_mlir.

  Each test builds a dummy `stablehlo.custom_call` op with the desired operand
  and result shapes and converts a parsed rule against that op's types.
  """

  def run(self, result=None):
    # Run each test inside an MLIR context with the sdy and stablehlo dialects
    # registered, a default location and an insertion point in a fresh module,
    # so the helpers below can create types and ops without passing these
    # explicitly.
    with ir.Context() as ctx, ir.Location.unknown(ctx):
      sdy.register_dialect(ctx)
      stablehlo.register_dialect(ctx)
      module = ir.Module.create()
      with ir.InsertionPoint(module.body):
        # Propagate the TestResult, as unittest.TestCase.run does.
        return super().run(result)

  def get_tensor_type(self, shape):
    """Returns an f32 ranked tensor type with the given shape."""
    return ir.RankedTensorType.get(shape, ir.F32Type.get())

  def create_tensor_value(self, shape):
    """Returns the single f32 result of a new dummy custom_call op."""
    return ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type(shape)],
        attributes=dict(call_target_name=ir.StringAttr.get("dummy_target"))
    ).result

  def get_custom_call_types(self, operand_shapes, result_shapes):
    """Returns (operand_types, result_types) of a dummy custom_call op.

    The op consumes one new tensor value per entry of `operand_shapes` and
    produces one f32 result per entry of `result_shapes`; the returned lists
    of types are read back from the constructed op.
    """
    operands = [self.create_tensor_value(shape) for shape in operand_shapes]
    op = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type(shape) for shape in result_shapes],
        operands=operands,
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    return [v.type for v in op.operands], [v.type for v in op.results]

  def test_conversion_rule_op_mismatch_in_operands_num(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 32)])
    rule = str_to_sdy_sharding_rule("i j-> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule has 1 operands, but the operation has 2 operands"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_rule_op_mismatch_in_operands_rank(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 32)])
    rule = str_to_sdy_sharding_rule("i j, i j k-> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 1th operand has rank 3, but the operation 1th "
        "operand has rank 2"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_rule_op_mismatch_in_results_num(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 32)])
    rule = str_to_sdy_sharding_rule("i j, i j -> i j, i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule has 2 results, but the operation has 1 results"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_rule_op_mismatch_in_results_dim(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 32)])
    rule = str_to_sdy_sharding_rule("i j, i j -> i j k")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 0th result has rank 3, but the operation 0th "
        "result has rank 2"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_factor_has_two_sizes(self):
    # "j" is 32 in the operands but 64 in the result.
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 64)])
    rule = str_to_sdy_sharding_rule("i j, i j -> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Factor j corresponds to two sizes: 32 and 64"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_batching_dim_has_two_sizes(self):
    # The second dimension covered by "..." is 32 in the operands but 64 in
    # the result.
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 64)])
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    with self.assertRaisesRegex(
        ValueError,
        "Batching dimension 1 corresponds to two sizes: 32 and 64"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_invalid_batching_dim(self):
    # With "... i j k" the rule needs at least rank 3, which a rank-2 operand
    # cannot satisfy.
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 32)])
    rule = str_to_sdy_sharding_rule("... i j k, ... i j k -> ... i j k")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 0th operand has rank 3, but the operation 0th operand has rank 2"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_compound_dimension_size_mismatch(self):
    # (i j) implies a result size of 2 * 4 = 8, but the result has size 9.
    operand_types, result_types = self.get_custom_call_types(
        [(2, 4)], [(9,)])
    rule = str_to_sdy_sharding_rule("i j -> (i j)")
    with self.assertRaisesRegex(
        ValueError,
        "0th result actual size 9 doesn't match the size 8 derived from the"
        " compound factors"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16,)], [(16, 32)])
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    with self.assertRaisesRegex(
        ValueError,
        "Ellipsis represents different number of leading dimensions 2 and 1"):
      sdy_sharding_rule_to_mlir(rule, operand_types, result_types)

  def test_conversion_compound_then_individual(self):
    operand_types, result_types = self.get_custom_call_types(
        [(8,)], [(2, 4)])
    rule = str_to_sdy_sharding_rule("(i j) -> i j")
    mlir_rule = sdy_sharding_rule_to_mlir(rule, operand_types, result_types)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>")

  def test_conversion_elementwise_rule_scalar_instance(self):
    operand_types, result_types = self.get_custom_call_types(
        [(), ()], [()])
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    mlir_rule = sdy_sharding_rule_to_mlir(rule, operand_types, result_types)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([], [])->([])>")

  def test_conversion_elementwise_rule_2D_instance(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (16, 32)], [(16, 32)])
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    mlir_rule = sdy_sharding_rule_to_mlir(rule, operand_types, result_types)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>")

  def test_conversion_vector_scalar_add_2D_instance(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), ()], [(16, 32)])
    rule = str_to_sdy_sharding_rule("..., -> ...")
    mlir_rule = sdy_sharding_rule_to_mlir(rule, operand_types, result_types)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>")

  def test_conversion_reshape_rule(self):
    operand_types, result_types = self.get_custom_call_types(
        [(2, 4)], [(8,)])
    rule = str_to_sdy_sharding_rule("i j -> (i j)")
    mlir_rule = sdy_sharding_rule_to_mlir(rule, operand_types, result_types)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>")

  def test_conversion_contracting_dim_matmul(self):
    operand_types, result_types = self.get_custom_call_types(
        [(16, 32), (32, 8)], [(16, 8)])
    rule = str_to_sdy_sharding_rule(
        "... contracting_dim, contracting_dim k -> ... k")
    mlir_rule = sdy_sharding_rule_to_mlir(rule, operand_types, result_types)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")
if __name__ == "__main__":
  # Run all test cases in this module through absltest, discovering them with
  # the JAX test loader.
  _test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=_test_loader)