Skip to content

Commit

Permalink
optimized FP16 NMS/BatchedNMS plugins.
Browse files Browse the repository at this point in the history
- enable n-bit radix sort for NMS_TRT plugin
- implement NMSDynamic_TRT plugin and enable n-bit radix sort for it
- enable n-bit radix sort for BatchedNMS_TRT and BatchedNMSDynamic_TRT plugins
- fixed a bug in configurePlugin() method of BatchedNMS_TRT
- other minor fixes for BatchedNMS_TRT/BatchedNMSDynamic_TRT plugins

Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
zhimengf authored and rajeevsrao committed Mar 18, 2021
1 parent 10a51be commit d5b878a
Show file tree
Hide file tree
Showing 17 changed files with 1,027 additions and 342 deletions.
1 change: 1 addition & 0 deletions plugin/InferPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ extern "C"
initializePlugin<nvinfer1::plugin::MultilevelCropAndResizePluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::MultilevelProposeROIPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::NMSPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::NMSDynamicPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::NormalizePluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::PriorBoxPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::ProposalLayerPluginCreator>(logger, libNamespace);
Expand Down
2 changes: 1 addition & 1 deletion plugin/batchedNMSPlugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ The `batchedNMSPlugin` is created using `BatchedNMSPluginCreator` with `NMSParam
|`float` |`iouThreshold` |The scalar threshold for IOU (new boxes that have high IOU overlap with previously selected boxes are removed).
|`bool` |`isNormalized` |Set to `false` if the box coordinates are not normalized, meaning they are not in the range `[0,1]`. Defaults to `true`.
|`bool` |`clipBoxes` |Forcibly restrict bounding boxes to the normalized range `[0,1]`. Only applicable if `isNormalized` is also `true`. Defaults to `true`.

|`int` |`scoreBits` |The number of bits to represent the score values during radix sort. The number of bits to represent score values(confidences) during radix sort. This valid range is 0 < scoreBits <= 10. The default value is 16(which means to use full bits in radix sort). Setting this parameter to any invalid value will result in the same effect as setting it to 16. This parameter can be tuned to strike for a best trade-off between performance and accuracy. Lowering scoreBits will improve performance but with some minor degradation to the accuracy. This parameter is only valid for FP16 data type for now.

## Algorithms

Expand Down
15 changes: 9 additions & 6 deletions plugin/batchedNMSPlugin/batchedNMSInference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
const bool shareLocation, const int backgroundLabelId, const int numPredsPerClass, const int numClasses,
const int topK, const int keepTopK, const float scoreThreshold, const float iouThreshold, const DataType DT_BBOX,
const void* locData, const DataType DT_SCORE, const void* confData, void* keepCount, void* nmsedBoxes,
void* nmsedScores, void* nmsedClasses, void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes)
void* nmsedScores, void* nmsedClasses, void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes, int scoreBits)
{
// locCount = batch_size * number_boxes_per_sample * 4
const int locCount = N * perBatchBoxesSize;
Expand Down Expand Up @@ -86,7 +86,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
* [batch_size, numClasses, numPredsPerClass, 1]
*/
status = permuteData(
stream, numScores, numClasses, numPredsPerClass, 1, DT_BBOX, confSigmoid, confData, scores);
stream, numScores, numClasses, numPredsPerClass, 1, DT_SCORE, confSigmoid, confData, scores);
ASSERT_FAILURE(status == STATUS_SUCCESS);

size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
Expand All @@ -100,8 +100,11 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch

void* sortingWorkspace = nextWorkspacePtr((int8_t*) postNMSIndices, postNMSIndicesSize);
// Sort the scores so that the following NMS could be applied.
float scoreShift = 0.f;
if(DT_SCORE == DataType::kHALF && scoreBits > 0 && scoreBits <= 10)
scoreShift = 1.f;
status = sortScoresPerClass(stream, N, numClasses, numPredsPerClass, backgroundLabelId, scoreThreshold,
DT_SCORE, scores, indices, sortingWorkspace);
DT_SCORE, scores, indices, sortingWorkspace, scoreBits, scoreShift);

ASSERT_FAILURE(status == STATUS_SUCCESS);

Expand All @@ -110,18 +113,18 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
bool flipXY = true;
// NMS
status = allClassNMS(stream, N, numClasses, numPredsPerClass, topK, iouThreshold, shareLocation, isNormalized,
DT_SCORE, DT_BBOX, bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY);
DT_SCORE, DT_BBOX, bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY, scoreShift);
ASSERT_FAILURE(status == STATUS_SUCCESS);

// Sort the bounding boxes after NMS using scores
status = sortScoresPerImage(stream, N, numClasses * topK, DT_SCORE, postNMSScores, postNMSIndices, scores,
indices, sortingWorkspace);
indices, sortingWorkspace, scoreBits);

ASSERT_FAILURE(status == STATUS_SUCCESS);

// Gather data from the sorted bounding boxes after NMS
status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topK, keepTopK, DT_BBOX,
DT_SCORE, indices, scores, bboxData, keepCount, nmsedBoxes, nmsedScores, nmsedClasses, clipBoxes);
DT_SCORE, indices, scores, bboxData, keepCount, nmsedBoxes, nmsedScores, nmsedClasses, clipBoxes, scoreShift);
ASSERT_FAILURE(status == STATUS_SUCCESS);

return STATUS_SUCCESS;
Expand Down
Loading

0 comments on commit d5b878a

Please sign in to comment.