# test_lite_script_module.py (from a fork of pytorch/pytorch)
# Owner(s): ["oncall: mobile"]
import inspect
import io
from typing import Dict, List

import torch
import torch.utils.bundled_inputs
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    AnnotatedSingleLayerLinearModel,
    TwoLayerLinearModel,
    AnnotatedNestedModel
)
from torch.testing._internal.common_quantization import QuantizationLiteTestCase
from torch.testing._internal.common_utils import TestCase, TemporaryFileName, run_tests
class TestLiteScriptModule(TestCase):
def getScriptExportImportCopy(self, m, save_mobile_debug_info=True, also_test_file=False):
m_scripted = torch.jit.script(m)
if not also_test_file:
buffer = io.BytesIO(m_scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=save_mobile_debug_info))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
return mobile_module
with TemporaryFileName() as fname:
m_scripted._save_for_lite_interpreter(fname, _save_mobile_debug_info=save_mobile_debug_info)
mobile_module = _load_for_lite_interpreter(fname)
return mobile_module
def test_load_mobile_module(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
def forward(self, x):
return x + 10
input = torch.tensor([1])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_close(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
def test_save_mobile_module_with_debug_info_with_trace(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self, x, y):
return x * y
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
self.A1 = A()
def forward(self, x, y, z):
return self.A0(x, y) + self.A1(y, z)
for export_method in ['trace', 'script']:
x = torch.rand((2, 3))
y = torch.rand((2, 3))
z = torch.rand((2, 3))
if export_method == 'trace':
trace_module = torch.jit.trace(B(), [x, y, z])
else:
trace_module = torch.jit.script(B())
exported_module = trace_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
buffer = io.BytesIO(exported_module)
buffer.seek(0)
assert(b"callstack_debug_map.pkl" in exported_module)
mobile_module = _load_for_lite_interpreter(buffer)
with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul"):
x = torch.rand((2, 3))
y = torch.rand((8, 10))
z = torch.rand((8, 10))
mobile_module(x, y, z)
with self.assertRaisesRegex(RuntimeError, r"Module hierarchy:top\(B\)::<unknown>.A1\(A\)::forward.aten::mul"):
x = torch.rand((2, 3))
y = torch.rand((2, 3))
z = torch.rand((8, 10))
mobile_module(x, y, z)
def test_load_mobile_module_with_debug_info(self):
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
def forward(self, x):
return x + 5
input = torch.tensor([3])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_close(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
def test_find_and_run_method(self):
class MyTestModule(torch.nn.Module):
def forward(self, arg):
return arg
input = (torch.tensor([1]), )
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(*input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertFalse(has_bundled_inputs)
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
script_module, [input], [])
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertTrue(has_bundled_inputs)
bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
mobile_module_result = mobile_module.forward(*bundled_inputs[0])
torch.testing.assert_close(script_module_result, mobile_module_result)
def test_method_calls_with_optional_arg(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
# opt arg in script-to-script invocation
def forward(self, x, two: int = 2):
return x + two
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
# opt arg in Python-to-script invocation
def forward(self, x, one: int = 1):
return self.A0(x) + one
script_module = torch.jit.script(B())
buffer = io.BytesIO(
script_module._save_to_buffer_for_lite_interpreter()
)
mobile_module = _load_for_lite_interpreter(buffer)
input = torch.tensor([5])
script_module_forward_result = script_module.forward(input)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(
script_module_forward_result,
mobile_module_forward_result
)
# change ref only
script_module_forward_result = script_module.forward(input, 2)
self.assertFalse(
(script_module_forward_result == mobile_module_forward_result)
.all()
.item()
)
# now both match again
mobile_module_forward_result = mobile_module.forward(input, 2)
torch.testing.assert_close(
script_module_forward_result,
mobile_module_forward_result
)
def test_unsupported_classtype(self):
class Foo():
def __init__(self):
return
def func(self, x: int, y: int):
return x + y
class MyTestModule(torch.nn.Module):
def forward(self, arg):
f = Foo()
return f.func(1, 2)
script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(RuntimeError,
r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\. "
r"The problematic type is: "):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_list_with_module_class(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
class MyTestModuleForListWithModuleClass(torch.nn.Module):
def __init__(self):
super(MyTestModuleForListWithModuleClass, self).__init__()
self.foo = Foo()
def forward(self):
my_list: List[Foo] = [self.foo]
return my_list
script_module = torch.jit.script(MyTestModuleForListWithModuleClass())
with self.assertRaisesRegex(RuntimeError,
r"^Returining a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$"):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_dict_with_module_class(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
class MyTestModuleForDictWithModuleClass(torch.nn.Module):
def __init__(self):
super(MyTestModuleForDictWithModuleClass, self).__init__()
self.foo = Foo()
def forward(self):
my_dict: Dict[int, Foo] = {1: self.foo}
return my_dict
script_module = torch.jit.script(MyTestModuleForDictWithModuleClass())
with self.assertRaisesRegex(RuntimeError,
r"^Returining a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$"):
script_module._save_to_buffer_for_lite_interpreter()
def test_module_export_operator_list(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.weight = torch.ones((20, 1, 5, 5))
self.bias = torch.ones(20)
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(
input,
self.weight,
self.bias,
[1, 1],
[0, 0],
[1, 1],
False,
[0, 0],
1,
False,
False,
True,
True,
)
return (x1, x2, x3)
m = torch.jit.script(Foo())
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
expected_ops = {
"aten::_convolution",
"aten::empty.memory_format",
"aten::empty_like",
"aten::zeros",
}
actual_ops = _export_operator_list(mobile_module)
self.assertEqual(actual_ops, expected_ops)
def test_source_range_simple(self):
class FooTest(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest)
with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
loaded(torch.rand(3, 4), torch.rand(30, 40))
def test_source_range_raise_exception(self):
class FooTest2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self):
raise RuntimeError('foo')
_, lineno = inspect.getsourcelines(FooTest2)
# In C++ code, the type of exception thrown is torch::jit::JITException
# which does not extend c10::Error, and hence it isn't possible to add
# additional context to the exception message and preserve the correct
# C++ stack trace for symbolication. i.e. it isn't possible to add
# the debug handle string to show where in the Python code the exception
# occured w/o first changing
# torch::jit::JITException to extend c10::Error.
with self.assertRaisesRegex(torch.jit.Error, 'foo'):
ft = FooTest2()
loaded = self.getScriptExportImportCopy(ft)
loaded()
def test_source_range_function_call(self):
class FooTest3(torch.jit.ScriptModule):
@torch.jit.script_method
def add_method(self, x, w):
return x + w
@torch.jit.script_method
def forward(self, x, y, w):
x = x * y
x = x + 2
return self.add_method(x, w)
ft = FooTest3()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest3)
try:
loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
error_message = f"{e}"
self.assertTrue('test_lite_script_module.py\", line {}'.format(lineno + 3) in error_message)
self.assertTrue('test_lite_script_module.py\", line {}'.format(lineno + 9) in error_message)
self.assertTrue('top(FooTest3)' in error_message)
def test_source_range_no_debug_info(self):
class FooTest4(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest4()
loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False)
try:
loaded(torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
error_message = f"{e}"
self.assertTrue("test_lite_script_module.py" not in error_message)
def test_source_range_raise_exc(self):
class FooTest5(torch.jit.ScriptModule):
def __init__(self, val: int):
super(FooTest5, self).__init__()
self.val = val
@torch.jit.script_method
def add_method(self, val: int, x, w):
if (val == self.val):
raise RuntimeError('self.val and val are same')
return x + w
@torch.jit.script_method
def forward(self, val: int, x, y, w):
x = x * y
x = x + 2
return self.add_method(val, x, w)
ft = FooTest5(42)
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest5)
try:
loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
except torch.jit.Error as e:
error_message = f"{e}"
# In C++ code, the type of exception thrown is torch::jit::JITException
# which does not extend c10::Error, and hence it isn't possible to add
# additional context to the exception message and preserve the correct
# C++ stack trace for symbolication. i.e. it isn't possible to add
# the debug handle string to show where in the Python code the exception
# occured w/o first changing
# torch::jit::JITException to extend c10::Error.
self.assertTrue('self.val and val are same' in error_message)
def test_stacktrace_interface_call(self):
@torch.jit.interface
class Forward(torch.nn.Module):
def forward(self, x) -> torch.Tensor:
pass
def forwardError(self, x) -> torch.Tensor:
pass
class B(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
def forwardError(self, x):
return self.call() + x
def call(self):
return torch.ones(-1)
class A(torch.nn.Module):
b : Forward
def __init__(self):
super().__init__()
self.b = B()
def forward(self):
self.b.forward(torch.ones(1))
self.b.forwardError(torch.ones(1))
a = torch.jit.script(A())
torch._C._enable_mobile_interface_call_export()
buffer = io.BytesIO(a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
try:
mobile_module()
self.assertTrue(False)
except RuntimeError as exp:
FileCheck().check("Trying to create tensor with negative dimension") \
.check("Traceback of TorchScript") \
.check("self.b.forwardError").check_next("~~~~~~~~~~~~~~~~~~~ <--- HERE") \
.check("return self.call").check_next("~~~~~~~~~ <--- HERE") \
.check("return torch.ones").check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
def test_single_layer(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(model_class=AnnotatedSingleLayerLinearModel, qengine="qnnpack")
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_two_layer(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(model_class=TwoLayerLinearModel)
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_annotated_nested(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(model_class=AnnotatedNestedModel, qengine="qnnpack")
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_quantization_example(self):
# From the example in Static Quantization section of https://pytorch.org/docs/stable/quantization.html
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.quant = torch.ao.quantization.QuantStub()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.relu = torch.nn.ReLU()
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.conv(x)
x = self.relu(x)
x = self.dequant(x)
return x
model_fp32 = M()
model_fp32.eval()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
input = torch.randn(4, 1, 4, 4)
self._compare_script_and_mobile(model=model_int8, input=input)
def test_bundled_input_with_dynamic_type(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(
self,
x: Dict[int, torch.Tensor],
y: Dict[int, torch.Tensor],
z: Dict[int, torch.Tensor],
):
return x
model = Model()
script_module = torch.jit.script(model)
sample_input = {
script_module.forward: [
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
)
]
}
bundled_model = torch.utils.bundled_inputs.bundle_inputs(
script_module, sample_input
)
buf = bundled_model._save_to_buffer_for_lite_interpreter()
mobile_module = _load_for_lite_interpreter(io.BytesIO(buf))
i = mobile_module.run_method("get_all_bundled_inputs")
self.assertEqual(
i[0],
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
),
)
# Script entry point: when executed directly, hand control to the test
# runner imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()