# engine.py
# Forked from leoxiaobin/deep-high-resolution-net.pytorch
import tensorrt as trt
# Module-wide TensorRT logger, constructed with the WARNING severity level.
# It is shared by the runtime below and by build_engine().
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# Module-level TensorRT runtime bound to TRT_LOGGER. Note that load_engine()
# takes its own `trt_runtime` parameter, which shadows this global inside
# that function.
trt_runtime = trt.Runtime(TRT_LOGGER)
def build_engine(onnx_path, shape=(1, 3, 384, 288)):
    """
    Build a TensorRT engine from an ONNX model file.

    Args:
        onnx_path : Path to onnx_file.
        shape : Shape assigned to the first input of the parsed network.
            Any sequence of ints is accepted; it is copied to a list before
            being handed to TensorRT.

    Returns:
        The object returned by ``builder.build_engine(network, config)``.
        NOTE(review): TensorRT presumably returns None when the build itself
        fails -- callers should check before serializing; confirm against
        the installed TensorRT version.

    Raises:
        RuntimeError: If the ONNX parser reports that parsing failed.
    """
    # Same value as the literal `1` used previously, but spelled with the
    # named flag so the intent (explicit-batch network) is visible.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(explicit_batch) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        # 256 MiB of scratch workspace for the builder.
        # NOTE(review): `max_workspace_size` and `build_engine` are believed
        # to be deprecated in newer TensorRT releases -- confirm the target
        # version before upgrading.
        config.max_workspace_size = (256 << 20)

        with open(onnx_path, 'rb') as model:
            parsed_ok = parser.parse(model.read())

        # Do not continue with a half-populated network: surface the
        # parser's own diagnostics instead of failing obscurely later.
        if not parsed_ok:
            details = "; ".join(
                str(parser.get_error(i)) for i in range(parser.num_errors)
            )
            raise RuntimeError(
                "Failed to parse ONNX file {!r}: {}".format(
                    onnx_path, details or "no parser errors reported"
                )
            )

        # Copy into a fresh list so the (immutable) default is never shared
        # or mutated, and TensorRT receives the same type as before.
        network.get_input(0).shape = list(shape)
        engine = builder.build_engine(network, config)
        return engine
def save_engine(engine, file_name):
    """
    Serialize an engine and write the resulting bytes to disk.

    Args:
        engine : Object exposing ``serialize()``; its result is written as-is.
        file_name : Destination path, opened in binary write mode.
    """
    # Serialize before opening the file so a serialization failure does not
    # leave an empty file behind.
    serialized = engine.serialize()
    with open(file_name, 'wb') as plan_file:
        plan_file.write(serialized)
def load_engine(trt_runtime, plan_path):
    """
    Read a serialized engine from disk and deserialize it.

    Args:
        trt_runtime : Object exposing ``deserialize_cuda_engine(bytes)``.
            This parameter shadows the module-level ``trt_runtime``.
        plan_path : Path of the serialized engine, opened in binary read mode.

    Returns:
        Whatever ``trt_runtime.deserialize_cuda_engine`` returns for the
        file's contents.
    """
    with open(plan_path, 'rb') as plan_file:
        serialized_plan = plan_file.read()
    return trt_runtime.deserialize_cuda_engine(serialized_plan)