-
Notifications
You must be signed in to change notification settings - Fork 30
/
seralizeEngineFromPythonAPI.py
138 lines (120 loc) · 6.38 KB
/
seralizeEngineFromPythonAPI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python2
# Module setup for serializing a TensorRT engine from the Python API:
# imports, shared-library loads and the logger objects used further down.
from __future__ import print_function
from __future__ import print_function
import numpy as np
import ctypes
# Load the stock TensorRT plugin library by absolute path before tensorrt is
# imported. NOTE(review): the path is hard-coded to a TensorRT 6.0.1.5 install
# under /opt and has to be adjusted per machine.
ctypes.cdll.LoadLibrary('/opt/TensorRT-6.0.1.5/lib/libnvinfer_plugin.so')
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
# NOTE(review): not referenced by name in the visible code; presumably imported
# for its side effect of setting up a CUDA context -- confirm.
import pycuda.autoinit
from PIL import ImageDraw
from yolov3_to_onnx import download_file
from data_processing import PreprocessYOLO, PostprocessYOLO, ALL_CATEGORIES
import sys, os
# Put the parent of the script directory on the import path so that the
# following `import common` can resolve from there.
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import common
# Logger handed to the Builder, OnnxParser and Runtime inside get_engine().
TRT_LOGGER = trt.Logger()
# INFO-level logger passed to trt.init_libnvinfer_plugins() in
# get_plugin_creator().
logger = trt.Logger(trt.Logger.INFO)
# Load the custom YOLO plugin library by a path relative to the current
# directory; presumably this is what registers the 'Yolo_TRT' creator that
# get_engine() looks up -- confirm against the plugin sources.
ctypes.cdll.LoadLibrary('./libyoloPlugin.so')
def get_plugin_creator(plugin_name):
    """Find a registered TensorRT plugin creator by its name.

    Calls ``trt.init_libnvinfer_plugins`` with the module-level ``logger``
    and then scans the creator list of the TensorRT plugin registry.

    Args:
        plugin_name: Value compared against each creator's ``name``.

    Returns:
        The last creator in the registry list whose ``name`` equals
        ``plugin_name``, or ``None`` when no creator matches.
    """
    trt.init_libnvinfer_plugins(logger, '')
    registry = trt.get_plugin_registry()
    matching = [
        creator
        for creator in registry.plugin_creator_list
        if creator.name == plugin_name
    ]
    # Later entries win over earlier ones with the same name.
    return matching[-1] if matching else None
def get_engine(onnx_file_path, engine_file_path=""):
    """Load a serialized TensorRT engine if available, otherwise build one.

    When ``engine_file_path`` names an existing file, its contents are
    deserialized and returned. Otherwise the ONNX model is parsed, a
    ``Yolo_TRT`` plugin layer is attached to each of the first three network
    outputs, an engine is built and, if ``engine_file_path`` is non-empty,
    the serialized engine is written to that path.

    Args:
        onnx_file_path: Path of the ONNX model parsed when building.
        engine_file_path: Path of the serialized engine to load or create.
            With the default empty string nothing is loaded or saved.

    Returns:
        The engine object returned by the runtime (load path) or by the
        builder (build path).

    Raises:
        SystemExit: With status 1 when the ONNX file is missing, the ONNX
            parser reports a failure, the network has fewer than three
            outputs, the ``Yolo_TRT`` plugin creator is not registered, or
            the builder returns no engine.
    """
    # (stride, grid size) of the plugin attached to network outputs 0, 1, 2.
    yolo_heads = ((32, 13), (16, 26), (8, 52))

    def fail(message):
        """Print ``message`` and terminate with a non-zero exit status."""
        print(message)
        sys.exit(1)

    def configure(builder):
        """Apply the builder limits shared by both build variants."""
        builder.max_workspace_size = 1 << 28  # 256MiB
        builder.max_batch_size = 1

    def parse_onnx(parser):
        """Feed the ONNX file to ``parser``; exit if it is missing or invalid."""
        if not os.path.exists(onnx_file_path):
            fail('ONNX file {} not found, please run yolov3_to_onnx.py first to generate it.'.format(onnx_file_path))
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parsed = parser.parse(model.read())
        if not parsed:
            # parse() signals failure through its return value; surface the
            # recorded errors instead of building from a broken network.
            print('Failed to parse ONNX file {}:'.format(onnx_file_path))
            for index in range(parser.num_errors):
                print(parser.get_error(index))
            sys.exit(1)
        print('Completed parsing of ONNX file')

    def build_and_save(builder, network):
        """Build the engine and, when a path was given, serialize it to disk."""
        print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
        engine = builder.build_cuda_engine(network)
        if engine is None:
            fail('Failed to build an engine from file {}.'.format(onnx_file_path))
        print("Completed creating Engine")
        if engine_file_path:
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
        return engine

    def add_yolo_layer(network, plugin_creator, tensor, stride, grid_size):
        """Append one Yolo_TRT plugin layer fed by ``tensor``; return its output."""
        def int32_field(name, value):
            return trt.PluginField(name, np.array(value, dtype=np.int32), trt.PluginFieldType.INT32)

        fields = trt.PluginFieldCollection([
            int32_field("numclass", 80),
            int32_field("stride", stride),
            int32_field("gridesize", grid_size),
            int32_field("numanchors", 3),
        ])
        plugin = plugin_creator.create_plugin('Yolo_TRT', fields)
        return network.add_plugin_v2([tensor], plugin).get_output(0)

    def build_engine():
        """Build an engine straight from the ONNX graph, without YOLO plugins.

        Kept as the plain alternative; get_engine() itself does not call it.
        """
        # NOTE(review): the literal 1 passed to create_network is presumably
        # the explicit-batch creation flag -- confirm against the TensorRT docs.
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(1) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            configure(builder)
            parse_onnx(parser)
            return build_and_save(builder, network)

    def build_engiinewithyoloplugin():
        """Build an engine whose three outputs each pass through a Yolo_TRT plugin."""
        plugin_creator = get_plugin_creator('Yolo_TRT')
        if plugin_creator is None:
            fail('Plugin not found. Exiting')
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(1) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            configure(builder)
            parse_onnx(parser)
            print('add the yolo plugin to original network')
            if network.num_outputs < len(yolo_heads):
                fail('Expected at least {} network outputs, found {}.'.format(
                    len(yolo_heads), network.num_outputs))
            # Fetch every raw output before the output list is modified, then
            # add the plugins, mark the new outputs and finally unmark the old
            # ones -- the same order of operations as before the refactoring.
            raw_outputs = [network.get_output(i) for i in range(len(yolo_heads))]
            yolo_outputs = [
                add_yolo_layer(network, plugin_creator, tensor, stride, grid_size)
                for tensor, (stride, grid_size) in zip(raw_outputs, yolo_heads)
            ]
            for tensor in yolo_outputs:
                network.mark_output(tensor)
            for tensor in raw_outputs:
                network.unmark_output(tensor)
            return build_and_save(builder, network)

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    return build_engiinewithyoloplugin()
# Script entry: reuse yolov3.trt if that file already exists, otherwise build
# the engine from yolov3.onnx.
# NOTE(review): there is no __main__ guard, so this also runs on import.
get_engine("yolov3.onnx", "yolov3.trt")