Skip to content

Commit

Permalink
Mark user-supplised plugins as supported in ONNXRT-TRT
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <[email protected]>
  • Loading branch information
kevinch-nv committed Jul 12, 2023
1 parent 6ba67d3 commit 0462dc3
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions ModelImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,17 @@ Status deserialize_onnx_model(int32_t fd, bool is_serialized_as_text, ::ONNX_NAM
return Status::success();
}

// Internal helper function used for ONNXRT-TRT EP to filter out DDS nodes
bool isDDSOp(char const* op_name)
{
auto is = [op_name](char const* name) { return std::strcmp(op_name, name) == 0; };
if (is("NonMaxSuppression") || is("NonZero") || is("RoiAlign"))
{
return true;
}
return false;
}

bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
SubGraphCollection_t& sub_graph_collection, char const* model_path)
{
Expand Down Expand Up @@ -446,13 +457,13 @@ bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t seri
{
::ONNX_NAMESPACE::NodeProto const& node = model.graph().node(node_idx);
// Add the node to the subgraph if:
// 1. There is an importer function registered for the operator type
// 1. It is not a node that requires DDS
// 2. It is not directly connected to an unsupported input
// 3. The importer function did not throw an assertion
bool registered = supportsOperator(node.op_type().c_str());
bool unsupportedDDS = isDDSOp(node.op_type().c_str());
bool unsupportedInput = (input_node.empty()) ? false : checkForInput(node);
bool unsuccessfulParse = node_idx == error_node;
if (registered && !unsupportedInput && !unsuccessfulParse)
if (!unsupportedDDS && !unsupportedInput && !unsuccessfulParse)
{
if (newSubGraph)
{
Expand Down Expand Up @@ -481,22 +492,8 @@ bool ModelImporter::supportsModel(void const* serialized_onnx_model, size_t seri
return allSupported;
}

// This funciton is used by ONNXRT to partition out unsupported nodes
bool ModelImporter::supportsOperator(char const* op_name) const
{
auto is = [op_name](char const* name) { return std::strcmp(op_name, name) == 0; };

// Mark these following plugins as supported
if (is("EfficientNMS_TRT") || is("PyramidROIAlign_TRT") || is("MultilevelCropAndResize_TRT")
|| is("DisentangledAttention_TRT"))
{
return true;
}
// Disable nodes that rely on DDS as ONNXRuntime does not support it at the moment
if (is("NonMaxSuppression") || is("NonZero") || is("RoiAlign"))
{
return false;
}
return _op_importers.count(op_name);
}

Expand Down

0 comments on commit 0462dc3

Please sign in to comment.