Skip to content

Commit

Permalink
Add switch for batch agnostic mode in NMS plugin
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
kevinch-nv authored and rajeevsrao committed Jun 23, 2021
1 parent 0953f2f commit a2b3d3d
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Algorithms optimization for NMS kernels and ROIAlign kernel
- Fix invalid cuda config issue when bs is larger than 32
- Fix issues found on Jetson NANO
- Add switch for batch-agnostic mode in NMS plugin

### Removed
- Removed fcplugin from demoBERT to improve latency
Expand Down
2 changes: 2 additions & 0 deletions include/NvInferPluginUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ enum class CodeTypeSSD : int32_t
//! \param inputOrder Specifies the order of inputs {loc_data, conf_data, priorbox_data}.
//! \param confSigmoid Set to true to calculate sigmoid of confidence scores.
//! \param isNormalized Set to true if bounding box data is normalized by the network.
//! \param isBatchAgnostic Defaults to true. Set to false if prior boxes are unique per batch
//!
struct DetectionOutputParameters
{
Expand All @@ -187,6 +188,7 @@ struct DetectionOutputParameters
int32_t inputOrder[3];
bool confSigmoid;
bool isNormalized;
bool isBatchAgnostic{true};
};

//!
Expand Down
19 changes: 12 additions & 7 deletions plugin/common/kernels/decodeBBoxes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ __launch_bounds__(nthds_per_cta)
const bool clip_bbox,
const T_BBOX* loc_data,
const T_BBOX* prior_data,
T_BBOX* bbox_data)
T_BBOX* bbox_data,
const bool batch_agnostic)
{
for (int index = blockIdx.x * nthds_per_cta + threadIdx.x;
index < nthreads;
Expand All @@ -113,7 +114,7 @@ __launch_bounds__(nthds_per_cta)
// do not assume each images' anchor boxes are identical
// e.g., in FasterRCNN, priors are ROIs from proposal layer and are different
// for each image.
const int pi = (batch * 2 * num_priors + d) * 4;
const int pi = batch_agnostic ? d * 4 : (batch * 2 * num_priors + d) * 4;
// Index to the right variances corresponding to the current bounding box
const int vi = pi + num_priors * 4;
// Encoding method: CodeTypeSSD::CORNER
Expand Down Expand Up @@ -296,15 +297,16 @@ pluginStatus_t decodeBBoxes_gpu(
const bool clip_bbox,
const void* loc_data,
const void* prior_data,
void* bbox_data)
void* bbox_data,
const bool batch_agnostic)
{
const int BS = 512;
const int GS = (nthreads + BS - 1) / BS;
decodeBBoxes_kernel<T_BBOX, BS><<<GS, BS, 0, stream>>>(nthreads, code_type, variance_encoded_in_target,
num_priors, share_location, num_loc_classes,
background_label_id, clip_bbox,
(const T_BBOX*) loc_data, (const T_BBOX*) prior_data,
(T_BBOX*) bbox_data);
(T_BBOX*) bbox_data, batch_agnostic);
CSC(cudaGetLastError(), STATUS_FAILURE);
return STATUS_SUCCESS;
}
Expand All @@ -321,7 +323,8 @@ typedef pluginStatus_t (*dbbFunc)(cudaStream_t,
const bool,
const void*,
const void*,
void*);
void*,
const bool);

struct dbbLaunchConfig
{
Expand Down Expand Up @@ -361,7 +364,8 @@ pluginStatus_t decodeBBoxes(
const DataType DT_BBOX,
const void* loc_data,
const void* prior_data,
void* bbox_data)
void* bbox_data,
const bool batch_agnostic)
{
dbbLaunchConfig lc = dbbLaunchConfig(DT_BBOX);
for (unsigned i = 0; i < dbbLCOptions.size(); ++i)
Expand All @@ -380,7 +384,8 @@ pluginStatus_t decodeBBoxes(
clip_bbox,
loc_data,
prior_data,
bbox_data);
bbox_data,
batch_agnostic);
}
}
return STATUS_BAD_PARAM;
Expand Down
12 changes: 8 additions & 4 deletions plugin/common/kernels/detectionForward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ pluginStatus_t detectionInference(
void* workspace,
bool isNormalized,
bool confSigmoid,
int scoreBits)
int scoreBits,
const bool isBatchAgnostic)
{
// Batch size * number bbox per sample * 4 = total number of bounding boxes * 4
const int locCount = N * C1;
Expand Down Expand Up @@ -70,7 +71,8 @@ pluginStatus_t detectionInference(
DT_BBOX,
locData,
priorData,
bboxDataRaw);
bboxDataRaw,
isBatchAgnostic);

ASSERT_FAILURE(status == STATUS_SUCCESS);

Expand Down Expand Up @@ -246,7 +248,8 @@ namespace plugin
void* workspace,
bool isNormalized,
bool confSigmoid,
int scoreBits)
int scoreBits,
const bool isBatchAgnostic)
{
// Batch size * number bbox per sample * 4 = total number of bounding boxes * 4
const int locCount = N * C1;
Expand Down Expand Up @@ -275,7 +278,8 @@ namespace plugin
DT_BBOX,
locData,
priorData,
bboxDataRaw);
bboxDataRaw,
isBatchAgnostic);

ASSERT_FAILURE(status == STATUS_SUCCESS);

Expand Down
4 changes: 2 additions & 2 deletions plugin/common/kernels/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pluginStatus_t detectionInference(cudaStream_t stream, int N, int C1, int C2, bo
bool varianceEncodedInTarget, int backgroundLabelId, int numPredsPerClass, int numClasses, int topK, int keepTopK,
float confidenceThreshold, float nmsThreshold, CodeTypeSSD codeType, DataType DT_BBOX, const void* locData,
const void* priorData, DataType DT_SCORE, const void* confData, void* keepCount, void* topDetections,
void* workspace, bool isNormalized = true, bool confSigmoid = false, int scoreBits = 16);
void* workspace, bool isNormalized = true, bool confSigmoid = false, int scoreBits = 16, const bool isBatchAgnostic = true);

pluginStatus_t nmsInference(cudaStream_t stream, int N, int boxesSize, int scoresSize, bool shareLocation,
int backgroundLabelId, int numPredsPerClass, int numClasses, int topK, int keepTopK, float scoreThreshold,
Expand Down Expand Up @@ -84,7 +84,7 @@ size_t detectionForwardPostNMSSize(int N, int numClasses, int topK);

pluginStatus_t decodeBBoxes(cudaStream_t stream, int nthreads, CodeTypeSSD code_type, bool variance_encoded_in_target,
int num_priors, bool share_location, int num_loc_classes, int background_label_id, bool clip_bbox, DataType DT_BBOX,
const void* loc_data, const void* prior_data, void* bbox_data);
const void* loc_data, const void* prior_data, void* bbox_data, const bool batch_agnostic);

size_t normalizePluginWorkspaceSize(bool acrossSpatial, int C, int H, int W);

Expand Down
6 changes: 6 additions & 0 deletions plugin/nmsPlugin/nmsPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ NMSBasePluginCreator::NMSBasePluginCreator() noexcept
mPluginAttributes.emplace_back(PluginField("isNormalized", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("codeType", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("scoreBits", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("isBatchAgnostic", nullptr, PluginFieldType::kINT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
Expand Down Expand Up @@ -684,6 +685,11 @@ IPluginV2Ext* NMSPluginCreator::createPlugin(const char* name, const PluginField
ASSERT(fields[i].type == PluginFieldType::kINT32);
mScoreBits = *(static_cast<const int32_t*>(fields[i].data));
}
else if (!strcmp(attrName, "isBatchAgnostic"))
{
ASSERT(fields[i].type == PluginFieldType::kINT32);
params.isBatchAgnostic = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
}
}

DetectionOutput* obj = new DetectionOutput(params);
Expand Down

0 comments on commit a2b3d3d

Please sign in to comment.