forked from WongKinYiu/yolov7
-
Notifications
You must be signed in to change notification settings - Fork 0
/
add_nms.py
155 lines (139 loc) · 5.48 KB
/
add_nms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import logging

import numpy as np
import onnx
from onnx import shape_inference

# onnx_graphsurgeon is imported on a best-effort basis: any failure is reported
# on stdout and the module still loads. In that case the name `gs` is left
# unbound, so code that uses it will fail at call time.
try:
    import onnx_graphsurgeon as gs
except Exception as err:
    print('Import onnx_graphsurgeon failure: %s' % err)

# Module-level logger used by RegisterNMS for progress and error messages.
LOGGER = logging.getLogger(__name__)
class RegisterNMS:
    """Append an ``EfficientNMS_TRT`` plugin node to an exported ONNX graph.

    The ONNX model is loaded into an ONNX GraphSurgeon graph, constants are
    folded, and :meth:`register_nms` then wires the current graph outputs into
    a new NMS plugin node whose outputs become the graph outputs.
    """

    def __init__(
        self,
        onnx_model_path: str,
        precision: str = "fp32",
    ):
        """Load the ONNX model and prepare the graph for editing.

        Args:
            onnx_model_path: Path of the ONNX model file to load.
            precision: Precision of the NMS box/score outputs. Only "fp32"
                and "fp16" are accepted by :meth:`register_nms`; the value is
                stored here and validated there.
        """
        self.graph = gs.import_onnx(onnx.load(onnx_model_path))
        # Sanity check that the import produced a usable graph object.
        assert self.graph
        LOGGER.info("ONNX graph created successfully")
        # Fold constants via ONNX-GS that PyTorch2ONNX may have missed
        self.graph.fold_constants()
        self.precision = precision
        # The NMS output tensors are declared with this fixed batch dimension.
        self.batch_size = 1

    def infer(self):
        """
        Sanitize the graph by cleaning any unconnected nodes, do a topological resort,
        and fold constant inputs values. When possible, run shape inference on the
        ONNX graph to determine tensor shapes.

        The clean/infer/fold cycle runs at most three times and stops early
        once an iteration leaves the node count unchanged.

        Raises:
            TypeError: If the installed ONNX GraphSurgeon rejects the
                ``fold_shapes`` argument (logged, then re-raised).
        """
        for _ in range(3):
            count_before = len(self.graph.nodes)

            self.graph.cleanup().toposort()
            try:
                # Drop the recorded output shapes so shape inference recomputes
                # them from scratch on the exported model.
                for node in self.graph.nodes:
                    for out in node.outputs:
                        out.shape = None
                model = gs.export_onnx(self.graph)
                model = shape_inference.infer_shapes(model)
                self.graph = gs.import_onnx(model)
            except Exception as e:
                # Best effort: shape inference is optional, so log and carry on
                # with whatever graph state we currently have.
                LOGGER.info(f"Shape inference could not be performed at this time:\n{e}")

            try:
                self.graph.fold_constants(fold_shapes=True)
            except TypeError as e:
                LOGGER.error(
                    "This version of ONNX GraphSurgeon does not support folding shapes, "
                    f"please upgrade your onnx_graphsurgeon module. Error:\n{e}"
                )
                raise

            count_after = len(self.graph.nodes)
            if count_before == count_after:
                # No new folding occurred in this iteration, so we can stop for now.
                break

    def _output_dtype(self):
        """Return the numpy dtype of the NMS box/score outputs for ``self.precision``.

        Raises:
            NotImplementedError: If the precision is neither "fp32" nor "fp16".
        """
        if self.precision == "fp32":
            return np.float32
        if self.precision == "fp16":
            return np.float16
        raise NotImplementedError(f"Currently not supports precision: {self.precision}")

    def register_nms(
        self,
        *,
        score_thresh: float = 0.25,
        nms_thresh: float = 0.45,
        detections_per_img: int = 100,
    ):
        """
        Register the ``EfficientNMS_TRT`` plugin node.
        NMS expects these shapes for its input tensors:
            - box_net: [batch_size, number_boxes, 4]
            - class_net: [batch_size, number_boxes, number_labels]

        Args:
            score_thresh (float): The scalar threshold for score (low scoring boxes are removed).
            nms_thresh (float): The scalar threshold for IOU (new boxes that have high IOU
                overlap with previously selected boxes are removed).
            detections_per_img (int): Number of best detections to keep after NMS.

        Raises:
            NotImplementedError: If ``self.precision`` is not "fp32" or "fp16".
                This is checked before the graph is touched.
        """
        # Validate the precision first so an unsupported value fails fast,
        # before the (potentially slow) graph clean-up and shape inference.
        dtype_output = self._output_dtype()

        self.infer()
        # The current graph outputs are fed directly to the NMS plugin.
        op_inputs = self.graph.outputs
        op = "EfficientNMS_TRT"
        attrs = {
            "plugin_version": "1",
            "background_class": -1,  # no background class
            "max_output_boxes": detections_per_img,
            "score_threshold": score_thresh,
            "iou_threshold": nms_thresh,
            "score_activation": False,
            "box_coding": 0,
        }

        # NMS Outputs
        output_num_detections = gs.Variable(
            name="num_dets",
            dtype=np.int32,
            shape=[self.batch_size, 1],
        )  # A scalar indicating the number of valid detections per batch image.
        output_boxes = gs.Variable(
            name="det_boxes",
            dtype=dtype_output,
            shape=[self.batch_size, detections_per_img, 4],
        )
        output_scores = gs.Variable(
            name="det_scores",
            dtype=dtype_output,
            shape=[self.batch_size, detections_per_img],
        )
        output_labels = gs.Variable(
            name="det_classes",
            dtype=np.int32,
            shape=[self.batch_size, detections_per_img],
        )

        op_outputs = [output_num_detections, output_boxes, output_scores, output_labels]

        # Create the NMS Plugin node with the selected inputs. The outputs of the node will also
        # become the final outputs of the graph.
        self.graph.layer(op=op, name="batched_nms", inputs=op_inputs, outputs=op_outputs, attrs=attrs)
        LOGGER.info(f"Created NMS plugin '{op}' with attributes: {attrs}")

        self.graph.outputs = op_outputs

        # Re-run clean-up and shape inference now that the graph has new outputs.
        self.infer()

    def save(self, output_path):
        """
        Save the ONNX model to the given location.

        The graph is cleaned up and topologically sorted before export.

        Args:
            output_path: Path pointing to the location where to write
                out the updated ONNX model.
        """
        self.graph.cleanup().toposort()
        model = gs.export_onnx(self.graph)
        onnx.save(model, output_path)
        LOGGER.info(f"Saved ONNX model to {output_path}")