Skip to content

Commit

Permalink
ONNX-TensorRT 10.1 GA release (#975)
Browse files Browse the repository at this point in the history
Signed-off-by: Akhil Goel <[email protected]>
  • Loading branch information
akhilg-nv authored Jun 17, 2024
1 parent 06adf44 commit 96e7811
Show file tree
Hide file tree
Showing 21 changed files with 933 additions and 398 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[submodule "third_party/onnx"]
path = third_party/onnx
url = https://github.com/onnx/onnx.git
branch = rel-1.16.0
branch = v1.16.0
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ add_definitions("-DSOURCE_LENGTH=${SOURCE_LENGTH}")
# Version information
#--------------------------------------------------
set(ONNX2TRT_MAJOR 10)
set(ONNX2TRT_MINOR 0)
set(ONNX2TRT_PATCH 1)
set(ONNX2TRT_MINOR 1)
set(ONNX2TRT_PATCH 0)
set(ONNX2TRT_VERSION "${ONNX2TRT_MAJOR}.${ONNX2TRT_MINOR}.${ONNX2TRT_PATCH}" CACHE STRING "ONNX2TRT version")

#--------------------------------------------------
Expand Down
476 changes: 297 additions & 179 deletions ModelImporter.cpp

Large diffs are not rendered by default.

85 changes: 63 additions & 22 deletions ModelImporter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#include "ImporterContext.hpp"
#include "NvInferPlugin.h"
#include "NvOnnxParser.h"
#include "errorHelpers.hpp"
#include "onnxOpCheckers.hpp"
#include "onnxOpImporters.hpp"
#include <stdexcept>

namespace onnx2trt
{
Expand All @@ -24,32 +26,49 @@ Status parseGraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& grap

class ModelImporter : public nvonnxparser::IParser
{
using SubGraphSupport_t = std::pair<std::vector<int64_t>, bool>;
using SubGraphSupportVector_t = std::vector<SubGraphSupport_t>;

protected:
StringMap<NodeImporter> _op_importers;
virtual Status importModel(::ONNX_NAMESPACE::ModelProto const& model);
virtual Status importModel(::ONNX_NAMESPACE::ModelProto const& model) noexcept;

private:
ImporterContext mImporterCtx;
std::vector<std::string> mPluginLibraryList; // Array of strings containing plugin libs
std::vector<char const*>
mPluginLibraryListCStr; // Array of C-strings corresponding to the strings in mPluginLibraryList
std::list<::ONNX_NAMESPACE::ModelProto> mONNXModels; // Needed for ownership of weights
SubGraphSupportVector_t mSubGraphSupportVector;
int mCurrentNode;
std::vector<Status> mErrors;
nvonnxparser::OnnxParserFlags mOnnxParserFlags{1U << static_cast<uint32_t>(nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM)}; // kNATIVE_INSTANCENORM is ON by default.
mutable std::vector<Status> mErrors; // Marked as mutable so that errors could be reported from const functions
nvonnxparser::OnnxParserFlags mOnnxParserFlags{
1U << static_cast<uint32_t>(
nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM)}; // kNATIVE_INSTANCENORM is ON by default.
std::pair<bool, SubGraphSupportVector_t> doSupportsModel(
void const* serialized_onnx_model, size_t serialized_onnx_model_size, char const* model_path = nullptr);

public:
ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger) noexcept
: _op_importers(getBuiltinOpImporterMap())
, mImporterCtx(network, logger)
{
}
bool parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override;
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) override;
bool parseWithWeightDescriptors(
void const* serialized_onnx_model, size_t serialized_onnx_model_size) noexcept override;
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
const char* model_path = nullptr) noexcept override;

bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) override;
SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) noexcept override;
bool supportsModelV2(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
char const* model_path = nullptr) noexcept override;

int64_t getNbSubgraphs() noexcept override;
bool isSubgraphSupported(int64_t const index) noexcept override;
int64_t* getSubgraphNodes(int64_t const index, int64_t& subgraphLength) noexcept override;

bool supportsOperator(const char* op_name) const override;
bool supportsOperator(const char* op_name) const noexcept override;

void setFlags(nvonnxparser::OnnxParserFlags onnxParserFlags) noexcept override
{
Expand All @@ -62,44 +81,66 @@ class ModelImporter : public nvonnxparser::IParser

void clearFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) noexcept override
{
mOnnxParserFlags &= ~(1U << static_cast<uint32_t>(onnxParserFlag));
ONNXTRT_TRY
{
mOnnxParserFlags &= ~(1U << static_cast<uint32_t>(onnxParserFlag));
}
ONNXTRT_CATCH_RECORD
}

void setFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) noexcept override
{
mOnnxParserFlags |= 1U << static_cast<uint32_t>(onnxParserFlag);
ONNXTRT_TRY
{
mOnnxParserFlags |= 1U << static_cast<uint32_t>(onnxParserFlag);
}
ONNXTRT_CATCH_RECORD
}

bool getFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) const noexcept override
{
auto flag = 1U << static_cast<uint32_t>(onnxParserFlag);
return static_cast<bool>(mOnnxParserFlags & flag);
ONNXTRT_TRY
{
auto flag = 1U << static_cast<uint32_t>(onnxParserFlag);
return static_cast<bool>(mOnnxParserFlags & flag);
}
ONNXTRT_CATCH_RECORD
return false;
}

int32_t getNbErrors() const override
int32_t getNbErrors() const noexcept override
{
return mErrors.size();
}
nvonnxparser::IParserError const* getError(int32_t index) const override
nvonnxparser::IParserError const* getError(int32_t index) const noexcept override
{
assert(0 <= index && index < (int32_t) mErrors.size());
return &mErrors[index];
ONNXTRT_TRY
{
return &mErrors.at(index);
}
ONNXTRT_CATCH_RECORD
return nullptr;
}
void clearErrors() override
void clearErrors() noexcept override
{
mErrors.clear();
}

nvinfer1::ITensor const* getLayerOutputTensor(char const* name, int64_t i)
nvinfer1::ITensor const* getLayerOutputTensor(char const* name, int64_t i) noexcept override
{
if (!name)
ONNXTRT_TRY
{
return nullptr;
if (!name)
{
throw std::invalid_argument("name is a nullptr");
}
return mImporterCtx.findLayerOutputTensor(name, i);
}
return mImporterCtx.findLayerOutputTensor(name, i);
ONNXTRT_CATCH_RECORD
return nullptr;
}

bool parseFromFile(char const* onnxModelFile, int32_t verbosity) override;
bool parseFromFile(char const* onnxModelFile, int32_t verbosity) noexcept override;

virtual char const* const* getUsedVCPluginLibraries(int64_t& nbPluginLibs) const noexcept override;
};
Expand Down
74 changes: 42 additions & 32 deletions ModelRefitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Status deserializeOnnxModelFile(char const* onnxModelFile, ::ONNX_NAMESPACE::Mod
{
// Define S_ISREG macro for Windows
#if !defined(S_ISREG)
#define S_ISREG(mode) (((mode) &S_IFMT) == S_IFREG)
#define S_ISREG(mode) (((mode) & S_IFMT) == S_IFREG)
#endif

struct stat sb;
Expand Down Expand Up @@ -393,52 +393,62 @@ Status ModelRefitter::refitOnnxScanNode(::ONNX_NAMESPACE::NodeProto const& node)
bool ModelRefitter::refitFromBytes(
void const* serializedOnnxModel, size_t serializedOnnxModelSize, char const* modelPath) noexcept
{
if (modelPath)
ONNXTRT_TRY
{
// Keep track of the absolute path to the ONNX file.
mWeightsContext.setOnnxFileLocation(modelPath);
}
if (modelPath)
{
// Keep track of the absolute path to the ONNX file.
mWeightsContext.setOnnxFileLocation(modelPath);
}

Status status
= deserializeOnnxModel(serializedOnnxModel, serializedOnnxModelSize, &onnx_model);
if (status.is_error())
{
mErrors.push_back(status);
return false;
}
Status status = deserializeOnnxModel(serializedOnnxModel, serializedOnnxModelSize, &onnx_model);
if (status.is_error())
{
mErrors.push_back(status);
return false;
}

refittableWeights = getRefittableWeights();
status = refitOnnxWeights(onnx_model);
if (status.is_error())
{
mErrors.push_back(status);
return false;
refittableWeights = getRefittableWeights();
status = refitOnnxWeights(onnx_model);
if (status.is_error())
{
mErrors.push_back(status);
return false;
}
return true;
}
return true;
ONNXTRT_CATCH_LOG(mLogger)
return false;
}

bool ModelRefitter::refitFromFile(char const* onnxModelFile) noexcept
{
// Keep track of the absolute path to the ONNX file.
mWeightsContext.setOnnxFileLocation(onnxModelFile);

Status status = deserializeOnnxModelFile(onnxModelFile, onnx_model);
if (status.is_error())
ONNXTRT_TRY
{
mErrors.push_back(status);
return false;
}
// Keep track of the absolute path to the ONNX file.
mWeightsContext.setOnnxFileLocation(onnxModelFile);

refittableWeights = getRefittableWeights();
if (!refittableWeights.empty())
{
status = refitOnnxWeights(onnx_model);
Status status = deserializeOnnxModelFile(onnxModelFile, onnx_model);
if (status.is_error())
{
mErrors.push_back(status);
return false;
}

refittableWeights = getRefittableWeights();
if (!refittableWeights.empty())
{
status = refitOnnxWeights(onnx_model);
if (status.is_error())
{
mErrors.push_back(status);
return false;
}
}
return true;
}
return true;
ONNXTRT_CATCH_LOG(mLogger)

return false;
}
} // namespace onnx2trt
11 changes: 8 additions & 3 deletions ModelRefitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "NvInferRuntime.h"
#include "Status.hpp"
#include "WeightsContext.hpp"
#include "errorHelpers.hpp"
#include <onnx/onnx_pb.h>
#include <string>
#include <unordered_set>
Expand Down Expand Up @@ -51,7 +52,7 @@ class ModelRefitter : public nvonnxparser::IParserRefitter
std::unordered_set<std::string> refittableWeights;
std::unordered_set<std::string> refittedWeights;

std::vector<Status> mErrors;
mutable std::vector<Status> mErrors;

std::unordered_set<std::string> getRefittableWeights();

Expand Down Expand Up @@ -90,8 +91,12 @@ class ModelRefitter : public nvonnxparser::IParserRefitter

nvonnxparser::IParserError const* getError(int32_t index) const noexcept override
{
assert(0 <= index && index < (int32_t) mErrors.size());
return &mErrors[index];
ONNXTRT_TRY
{
return &mErrors.at(index);
}
ONNXTRT_CATCH_LOG(mLogger)
return nullptr;
}

void clearErrors() noexcept override
Expand Down
6 changes: 3 additions & 3 deletions NvOnnxParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
#include "ModelRefitter.hpp"
#include "NvInferRuntime.h"

extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version)
extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version) noexcept
{
auto network = static_cast<nvinfer1::INetworkDefinition*>(network_);
auto logger = static_cast<nvinfer1::ILogger*>(logger_);
return new onnx2trt::ModelImporter(network, logger);
}

extern "C" void* createNvOnnxParserRefitter_INTERNAL(void* refitter_, void* logger_, int32_t version)
extern "C" void* createNvOnnxParserRefitter_INTERNAL(void* refitter_, void* logger_, int32_t version) noexcept
{
auto refitter = static_cast<nvinfer1::IRefitter*>(refitter_);
auto logger = static_cast<nvinfer1::ILogger*>(logger_);
return new onnx2trt::ModelRefitter(refitter, logger);
}

extern "C" int getNvOnnxParserVersion()
extern "C" int getNvOnnxParserVersion() noexcept
{
return NV_ONNX_PARSER_VERSION;
}
Loading

0 comments on commit 96e7811

Please sign in to comment.