forked from lutzroeder/netron
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbackend.py
executable file
·59 lines (51 loc) · 2.44 KB
/
backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python

''' Experimental Python Server backend test '''

import os
import sys

# Repository root: the parent of the directory that contains this script.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make top-level directories of the repository (notably <root>/source)
# importable as packages.
sys.path.append(root_dir)
# Write bytecode of modules imported from here on under
# <root>/dist/pycache/test/backend instead of __pycache__ folders in the tree.
sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'test', 'backend')
# Import the <root>/source package and bind it to the name `netron`.
netron = __import__('source')

# Location of the sample models the test functions read.
third_party_dir = os.path.join(root_dir, 'third_party')
test_data_dir = os.path.join(third_party_dir, 'test')
def _test_onnx():
    ''' Load the candy.onnx sample with the onnx package and serve the
    resulting in-memory model through netron without a file name. '''
    onnx_module = __import__('onnx')
    candy_path = os.path.join(test_data_dir, 'onnx', 'candy.onnx')
    netron.serve(None, onnx_module.load(candy_path))
def _test_onnx_list():
    ''' Smoke-test every ONNX sample in <test_data_dir>/onnx.

    Each *.onnx file (except the excluded names below) is printed, loaded
    with the onnx package, served quietly through netron under its path and
    then stopped again. Files are visited in sorted order because
    os.listdir returns entries in arbitrary order, which made runs
    non-reproducible.
    '''
    # Loop-invariant: resolve the onnx module once, not per file.
    onnx = __import__('onnx')
    folder = os.path.join(test_data_dir, 'onnx')
    # Samples deliberately left out of the sweep (reason not recorded here).
    skip = {'super_resolution.onnx', 'arcface-resnet100.onnx'}
    for item in sorted(os.listdir(folder)):
        if not item.endswith('.onnx') or item in skip:
            continue
        print(item)
        file = os.path.join(folder, item)
        model = onnx.load(file)
        address = netron.serve(file, model, verbosity='quiet')
        netron.stop(address)
def _test_torchscript(name='d2go'):
    ''' Build or load a TorchScript module and serve its graph with netron.

    name: 'resnet34' traces torchvision's ResNet-34 after loading the local
    resnet34-333f7ec4.pth state dict. Any other value is treated as the stem
    of a TorchScript archive <test_data_dir>/pytorch/<name>.pt, e.g. 'd2go'
    (the default), 'fasterrcnn_resnet50_fpn', 'mobilenetv2-quant_full-nnapi',
    'inception_v3_traced' or 'netron_issue_920'. The same name is passed to
    netron.serve as the model's file name, so the label always matches the
    graph being shown.
    '''
    torch = __import__('torch')
    pytorch_dir = os.path.join(test_data_dir, 'pytorch')
    if name == 'resnet34':
        torchvision = __import__('torchvision')
        model = torchvision.models.resnet34()
        state_dict = torch.load(os.path.join(pytorch_dir, 'resnet34-333f7ec4.pth'))
        model.load_state_dict(state_dict)
        # A trace freezes the train/eval behaviour active while tracing
        # (e.g. BatchNorm), so switch to inference mode first.
        model.eval()
        args = torch.zeros([1, 3, 224, 224])
        trace = torch.jit.trace(model, args, strict=True)
    else:
        # Set before loading: quantized archives such as d2go.pt presumably
        # require the qnnpack engine. TODO confirm for the other archives.
        torch.backends.quantized.engine = 'qnnpack'
        # The file is a TorchScript archive: torch.load only forwards these to
        # torch.jit.load with a warning, and recent PyTorch versions refuse
        # them under weights_only=True (the default since 2.6).
        trace = torch.jit.load(os.path.join(pytorch_dir, name + '.pt'))
    # Private JIT pass that (per its name) inlines calls into trace.graph, in
    # place, before the graph is handed to netron.
    # IR reference: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.h
    torch._C._jit_pass_inline(trace.graph)  # pylint: disable=protected-access
    netron.serve(name, trace)
if __name__ == '__main__':
    # Run the tests named on the command line, in order; with no arguments
    # only the TorchScript test runs. Guarded so that importing this module
    # does not start a server.
    _tests = {
        'onnx': _test_onnx,
        'onnx_list': _test_onnx_list,
        'torchscript': _test_torchscript,
    }
    _selected = sys.argv[1:] or ['torchscript']
    # Validate every name up front so a typo is reported before any test runs.
    _unknown = [_name for _name in _selected if _name not in _tests]
    if _unknown:
        sys.exit(f"Unknown test(s): {', '.join(_unknown)}. "
                 f"Available: {', '.join(_tests)}")
    for _name in _selected:
        _tests[_name]()