Skip to content

Commit

Permalink
add fp16 capability to batchedNMSPlugins
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Bridger <[email protected]>
  • Loading branch information
pbridger authored and rajeevsrao committed Jan 14, 2021
1 parent af8f24c commit 42dbbb0
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 44 deletions.
22 changes: 12 additions & 10 deletions plugin/batchedNMSPlugin/batchedNMSInference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
*/
const int numLocClasses = shareLocation ? 1 : numClasses;

size_t bboxDataSize = detectionForwardBBoxDataSize(N, perBatchBoxesSize, DataType::kFLOAT);
size_t bboxDataSize = detectionForwardBBoxDataSize(N, perBatchBoxesSize, DT_BBOX);
void* bboxDataRaw = workspace;
cudaMemcpyAsync(bboxDataRaw, locData, bboxDataSize, cudaMemcpyDeviceToDevice, stream);
pluginStatus_t status;
Expand All @@ -47,7 +47,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
*/
// float for now
void* bboxData;
size_t bboxPermuteSize = detectionForwardBBoxPermuteSize(shareLocation, N, perBatchBoxesSize, DataType::kFLOAT);
size_t bboxPermuteSize = detectionForwardBBoxPermuteSize(shareLocation, N, perBatchBoxesSize, DT_BBOX);
void* bboxPermute = nextWorkspacePtr((int8_t*) bboxDataRaw, bboxDataSize);

/*
Expand All @@ -58,7 +58,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
if (!shareLocation)
{
status = permuteData(
stream, locCount, numLocClasses, numPredsPerClass, 4, DataType::kFLOAT, false, bboxDataRaw, bboxPermute);
stream, locCount, numLocClasses, numPredsPerClass, 4, DT_BBOX, false, bboxDataRaw, bboxPermute);
ASSERT_FAILURE(status == STATUS_SUCCESS);
bboxData = bboxPermute;
}
Expand All @@ -77,6 +77,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
*/
const int numScores = N * perBatchScoresSize;
size_t totalScoresSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
if(DT_BBOX == DataType::kHALF) totalScoresSize /= 2; // detectionForwardPreNMSSize is implemented in terms of kFLOAT
void* scores = nextWorkspacePtr((int8_t*) bboxPermute, bboxPermuteSize);

// need a conf_scores
Expand All @@ -85,21 +86,22 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
* [batch_size, numClasses, numPredsPerClass, 1]
*/
status = permuteData(
stream, numScores, numClasses, numPredsPerClass, 1, DataType::kFLOAT, confSigmoid, confData, scores);
stream, numScores, numClasses, numPredsPerClass, 1, DT_BBOX, confSigmoid, confData, scores);
ASSERT_FAILURE(status == STATUS_SUCCESS);

size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
void* indices = nextWorkspacePtr((int8_t*) scores, totalScoresSize);

size_t postNMSScoresSize = detectionForwardPostNMSSize(N, numClasses, topK);
size_t postNMSIndicesSize = detectionForwardPostNMSSize(N, numClasses, topK);
if(DT_BBOX == DataType::kHALF) postNMSScoresSize /= 2; // detectionForwardPostNMSSize is implemented in terms of kFLOAT
size_t postNMSIndicesSize = detectionForwardPostNMSSize(N, numClasses, topK); // indices are full int32
void* postNMSScores = nextWorkspacePtr((int8_t*) indices, indicesSize);
void* postNMSIndices = nextWorkspacePtr((int8_t*) postNMSScores, postNMSScoresSize);

void* sortingWorkspace = nextWorkspacePtr((int8_t*) postNMSIndices, postNMSIndicesSize);
// Sort the scores so that the following NMS could be applied.
status = sortScoresPerClass(stream, N, numClasses, numPredsPerClass, backgroundLabelId, scoreThreshold,
DataType::kFLOAT, scores, indices, sortingWorkspace);
DT_SCORE, scores, indices, sortingWorkspace);

ASSERT_FAILURE(status == STATUS_SUCCESS);

Expand All @@ -108,18 +110,18 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
bool flipXY = true;
// NMS
status = allClassNMS(stream, N, numClasses, numPredsPerClass, topK, iouThreshold, shareLocation, isNormalized,
DataType::kFLOAT, DataType::kFLOAT, bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY);
DT_SCORE, DT_BBOX, bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY);
ASSERT_FAILURE(status == STATUS_SUCCESS);

// Sort the bounding boxes after NMS using scores
status = sortScoresPerImage(stream, N, numClasses * topK, DataType::kFLOAT, postNMSScores, postNMSIndices, scores,
status = sortScoresPerImage(stream, N, numClasses * topK, DT_SCORE, postNMSScores, postNMSIndices, scores,
indices, sortingWorkspace);

ASSERT_FAILURE(status == STATUS_SUCCESS);

// Gather data from the sorted bounding boxes after NMS
status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topK, keepTopK, DataType::kFLOAT,
DataType::kFLOAT, indices, scores, bboxData, keepCount, nmsedBoxes, nmsedScores, nmsedClasses, clipBoxes);
status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topK, keepTopK, DT_BBOX,
DT_SCORE, indices, scores, bboxData, keepCount, nmsedBoxes, nmsedScores, nmsedClasses, clipBoxes);
ASSERT_FAILURE(status == STATUS_SUCCESS);

return STATUS_SUCCESS;
Expand Down
29 changes: 19 additions & 10 deletions plugin/batchedNMSPlugin/batchedNMSPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ BatchedNMSPlugin::BatchedNMSPlugin(const void* data, size_t length)
scoresSize = read<int>(d);
numPriors = read<int>(d);
mClipBoxes = read<bool>(d);
mPrecision = read<DataType>(d);
ASSERT(d == a + length);
}

Expand All @@ -67,6 +68,7 @@ BatchedNMSDynamicPlugin::BatchedNMSDynamicPlugin(const void* data, size_t length
scoresSize = read<int>(d);
numPriors = read<int>(d);
mClipBoxes = read<bool>(d);
mPrecision = read<DataType>(d);
ASSERT(d == a + length);
}

Expand Down Expand Up @@ -195,14 +197,14 @@ DimsExprs BatchedNMSDynamicPlugin::getOutputDimensions(
size_t BatchedNMSPlugin::getWorkspaceSize(int maxBatchSize) const
{
return detectionInferenceWorkspaceSize(param.shareLocation, maxBatchSize, boxesSize, scoresSize, param.numClasses,
numPriors, param.topK, DataType::kFLOAT, DataType::kFLOAT);
numPriors, param.topK, mPrecision, mPrecision);
}

size_t BatchedNMSDynamicPlugin::getWorkspaceSize(
const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const
{
return detectionInferenceWorkspaceSize(param.shareLocation, inputs[0].dims.d[0], boxesSize, scoresSize,
param.numClasses, numPriors, param.topK, DataType::kFLOAT, DataType::kFLOAT);
param.numClasses, numPriors, param.topK, mPrecision, mPrecision);
}

int BatchedNMSPlugin::enqueue(
Expand All @@ -218,7 +220,7 @@ int BatchedNMSPlugin::enqueue(

pluginStatus_t status = nmsInference(stream, batchSize, boxesSize, scoresSize, param.shareLocation,
param.backgroundLabelId, numPriors, param.numClasses, param.topK, param.keepTopK, param.scoreThreshold,
param.iouThreshold, DataType::kFLOAT, locData, DataType::kFLOAT, confData, keepCount, nmsedBoxes, nmsedScores,
param.iouThreshold, mPrecision, locData, mPrecision, confData, keepCount, nmsedBoxes, nmsedScores,
nmsedClasses, workspace, param.isNormalized, false, mClipBoxes);
ASSERT(status == STATUS_SUCCESS);
return 0;
Expand All @@ -237,7 +239,7 @@ int BatchedNMSDynamicPlugin::enqueue(const PluginTensorDesc* inputDesc, const Pl

pluginStatus_t status = nmsInference(stream, inputDesc[0].dims.d[0], boxesSize, scoresSize, param.shareLocation,
param.backgroundLabelId, numPriors, param.numClasses, param.topK, param.keepTopK, param.scoreThreshold,
param.iouThreshold, DataType::kFLOAT, locData, DataType::kFLOAT, confData, keepCount, nmsedBoxes, nmsedScores,
param.iouThreshold, mPrecision, locData, mPrecision, confData, keepCount, nmsedBoxes, nmsedScores,
nmsedClasses, workspace, param.isNormalized, false, mClipBoxes);
ASSERT(status == STATUS_SUCCESS);
return 0;
Expand All @@ -257,6 +259,7 @@ void BatchedNMSPlugin::serialize(void* buffer) const
write(d, scoresSize);
write(d, numPriors);
write(d, mClipBoxes);
write(d, mPrecision);
ASSERT(d == a + getSerializationSize());
}

Expand All @@ -274,6 +277,7 @@ void BatchedNMSDynamicPlugin::serialize(void* buffer) const
write(d, scoresSize);
write(d, numPriors);
write(d, mClipBoxes);
write(d, mPrecision);
ASSERT(d == a + getSerializationSize());
}

Expand Down Expand Up @@ -320,11 +324,13 @@ void BatchedNMSDynamicPlugin::configurePlugin(
scoresSize = in[1].desc.dims.d[1] * in[1].desc.dims.d[2];
// num_boxes
numPriors = in[0].desc.dims.d[1];

mPrecision = in[0].desc.type;
}

// Reports which tensor type/format combinations this plugin accepts.
// With fp16 capability added, kHALF is accepted alongside kFLOAT for
// box/score data; kINT32 covers the keepCount output. Only the NCHW
// format is supported.
// NOTE: the stripped diff left both the pre- and post-commit return
// statements in place; this keeps the post-commit (fp16-enabled) one.
bool BatchedNMSPlugin::supportsFormat(DataType type, PluginFormat format) const
{
    return ((type == DataType::kHALF || type == DataType::kFLOAT || type == DataType::kINT32)
        && format == PluginFormat::kNCHW);
}

bool BatchedNMSDynamicPlugin::supportsFormatCombination(
Expand All @@ -333,14 +339,15 @@ bool BatchedNMSDynamicPlugin::supportsFormatCombination(
ASSERT(0 <= pos && pos < 6);
const auto* in = inOut;
const auto* out = inOut + nbInputs;
const bool consistentFloatPrecision = in[0].type == in[pos].type;
switch (pos)
{
case 0: return in[0].type == DataType::kFLOAT && in[0].format == PluginFormat::kLINEAR;
case 1: return in[1].type == DataType::kFLOAT && in[1].format == PluginFormat::kLINEAR;
case 0: return (in[0].type == DataType::kHALF || in[0].type == DataType::kFLOAT) && in[0].format == PluginFormat::kLINEAR && consistentFloatPrecision;
case 1: return (in[1].type == DataType::kHALF || in[1].type == DataType::kFLOAT) && in[1].format == PluginFormat::kLINEAR && consistentFloatPrecision;
case 2: return out[0].type == DataType::kINT32 && out[0].format == PluginFormat::kLINEAR;
case 3: return out[1].type == DataType::kFLOAT && out[1].format == PluginFormat::kLINEAR;
case 4: return out[2].type == DataType::kFLOAT && out[2].format == PluginFormat::kLINEAR;
case 5: return out[3].type == DataType::kFLOAT && out[3].format == PluginFormat::kLINEAR;
case 3: return (out[1].type == DataType::kHALF || out[1].type == DataType::kFLOAT) && out[1].format == PluginFormat::kLINEAR && consistentFloatPrecision;
case 4: return (out[2].type == DataType::kHALF || out[2].type == DataType::kFLOAT) && out[2].format == PluginFormat::kLINEAR && consistentFloatPrecision;
case 5: return (out[3].type == DataType::kHALF || out[3].type == DataType::kFLOAT) && out[3].format == PluginFormat::kLINEAR && consistentFloatPrecision;
}
return false;
}
Expand Down Expand Up @@ -383,6 +390,7 @@ IPluginV2Ext* BatchedNMSPlugin::clone() const
plugin->numPriors = numPriors;
plugin->setPluginNamespace(mNamespace.c_str());
plugin->setClipParam(mClipBoxes);
plugin->mPrecision = mPrecision;
return plugin;
}

Expand All @@ -394,6 +402,7 @@ IPluginV2DynamicExt* BatchedNMSDynamicPlugin::clone() const
plugin->numPriors = numPriors;
plugin->setPluginNamespace(mNamespace.c_str());
plugin->setClipParam(mClipBoxes);
plugin->mPrecision = mPrecision;
return plugin;
}

Expand Down
2 changes: 2 additions & 0 deletions plugin/batchedNMSPlugin/batchedNMSPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class BatchedNMSPlugin : public IPluginV2Ext
int numPriors{};
std::string mNamespace;
bool mClipBoxes{};
DataType mPrecision;
};

class BatchedNMSDynamicPlugin : public IPluginV2DynamicExt
Expand Down Expand Up @@ -114,6 +115,7 @@ class BatchedNMSDynamicPlugin : public IPluginV2DynamicExt
int numPriors{};
std::string mNamespace;
bool mClipBoxes{};
DataType mPrecision;
};

class BatchedNMSBasePluginCreator : public BaseCreator
Expand Down
39 changes: 29 additions & 10 deletions plugin/batchedNMSPlugin/gatherNMSOutputs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#include "kernel.h"
#include "plugin.h"
#include "cuda_fp16.h"
#include "gatherNMSOutputs.h"
#include <array>

Expand Down Expand Up @@ -64,23 +65,39 @@ __launch_bounds__(nthds_per_cta)
: index % (numClasses * numPredsPerClass)) + bboxOffset) * 4;
nmsedClasses[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass; // label
nmsedScores[i] = score; // confidence score
const T_BBOX xMin = bboxData[bboxId];
const T_BBOX yMin = bboxData[bboxId + 1];
const T_BBOX xMax = bboxData[bboxId + 2];
const T_BBOX yMax = bboxData[bboxId + 3];
// clipped bbox xmin
nmsedBoxes[i * 4] = clipBoxes ? max(min(bboxData[bboxId],
T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId];
nmsedBoxes[i * 4] = clipBoxes ? saturate(xMin) : xMin;
// clipped bbox ymin
nmsedBoxes[i * 4 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1],
T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 1];
nmsedBoxes[i * 4 + 1] = clipBoxes ? saturate(yMin) : yMin;
// clipped bbox xmax
nmsedBoxes[i * 4 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2],
T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 2];
nmsedBoxes[i * 4 + 2] = clipBoxes ? saturate(xMax) : xMax;
// clipped bbox ymax
nmsedBoxes[i * 4 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3],
T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 3];
nmsedBoxes[i * 4 + 3] = clipBoxes ? saturate(yMax) : yMax;
atomicAdd(&numDetections[i / keepTopK], 1);
}
}
}

// Clamp a bounding-box coordinate into the normalized range [0, 1].
// Nesting is max(min(v, hi), lo) to match the original exactly,
// including its behavior for NaN inputs under device min/max.
template <typename T_BBOX>
__device__ T_BBOX saturate(T_BBOX v)
{
    const T_BBOX kLowerBound(0);
    const T_BBOX kUpperBound(1);
    return max(min(v, kUpperBound), kLowerBound);
}

// fp16 specialization of saturate: clamp v to [0, 1].
template <>
__device__ __half saturate(__half v)
{
#if __CUDA_ARCH__ >= 800
// Native half-precision min/max intrinsics require SM80+ (Ampere).
// NOTE(review): __half(1) relies on an integral->__half constructor;
// verify it is unambiguous across supported CUDA toolkit versions.
return __hmax(__hmin(v, __half(1)), __half(0));
#else
// Pre-SM80 fallback: clamp in fp32; v is implicitly converted
// __half -> float here and float -> __half on return.
return max(min(v, float(1)), float(0));
#endif
}

template <typename T_BBOX, typename T_SCORE>
pluginStatus_t gatherNMSOutputs_gpu(
cudaStream_t stream,
Expand Down Expand Up @@ -158,8 +175,10 @@ struct nmsOutLaunchConfig

using nvinfer1::DataType;

// Dispatch table mapping the (score dtype, bbox dtype) pair requested at
// runtime to a concrete gatherNMSOutputs_gpu instantiation. Only uniform
// precision pairs are provided (float/float and half/half); mixed
// combinations will not match an entry.
// NOTE: the stripped diff left both the old 1-entry table and the new
// 2-entry table in place; this keeps the post-commit (fp16-enabled) one.
static std::array<nmsOutLaunchConfig, 2> nmsOutLCOptions = {
    nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, gatherNMSOutputs_gpu<float, float>),
    nmsOutLaunchConfig(DataType::kHALF, DataType::kHALF, gatherNMSOutputs_gpu<__half, __half>)
};

pluginStatus_t gatherNMSOutputs(
cudaStream_t stream,
Expand Down
Loading

0 comments on commit 42dbbb0

Please sign in to comment.