// engine.h (forked from cyrusbehr/tensorrt-cpp-api)
#pragma once
#include <array>
#include <atomic>
#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include <opencv2/core/cuda.hpp>
#include <opencv2/cudawarping.hpp>
#include <opencv2/cudaarithm.hpp>
#include "NvInfer.h"
#include <cuda_runtime.h>
// Utility methods
namespace Util {
    inline bool doesFileExist(const std::string& filepath) {
        std::ifstream f(filepath.c_str());
        return f.good();
    }

    inline void checkCudaErrorCode(cudaError_t code) {
        if (code != cudaSuccess) {
            std::string errMsg = "CUDA operation failed with code: " + std::to_string(code) + " (" + cudaGetErrorName(code) +
                                 "), with message: " + cudaGetErrorString(code);
            std::cout << errMsg << std::endl;
            throw std::runtime_error(errMsg);
        }
    }

    std::vector<std::string> getFilesInDirectory(const std::string& dirPath);
}
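// Usage sketch for Util::checkCudaErrorCode (hypothetical snippet; wraps any CUDA
// runtime call so a failure throws instead of being silently ignored):
//
//     void* devicePtr = nullptr;
//     Util::checkCudaErrorCode(cudaMalloc(&devicePtr, 1024));
//     Util::checkCudaErrorCode(cudaFree(devicePtr));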
// Utility timer
template <typename Clock = std::chrono::high_resolution_clock>
class Stopwatch {
    typename Clock::time_point start_point;

public:
    Stopwatch() : start_point(Clock::now()) {}

    // Returns the elapsed time since construction, cast to the requested representation and units.
    // The relaxed fences are intended to discourage reordering of the timed code across the clock reads.
    template <typename Rep = typename Clock::duration::rep, typename Units = typename Clock::duration>
    Rep elapsedTime() const {
        std::atomic_thread_fence(std::memory_order_relaxed);
        auto counted_time = std::chrono::duration_cast<Units>(Clock::now() - start_point).count();
        std::atomic_thread_fence(std::memory_order_relaxed);
        return static_cast<Rep>(counted_time);
    }
};

using preciseStopwatch = Stopwatch<>;
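// Usage sketch (hypothetical; times a section of code in milliseconds using the
// default high_resolution_clock):
//
//     preciseStopwatch stopwatch;
//     // ... work to be timed ...
//     auto elapsedMs = stopwatch.elapsedTime<float, std::chrono::milliseconds>();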
// Precision used for GPU inference
enum class Precision {
    // Full precision floating point value
    FP32,
    // Half precision floating point value
    FP16,
    // Int8 quantization.
    // Has reduced dynamic range and may result in a slight loss in accuracy.
    // If INT8 is selected, a path to a calibration dataset directory must be provided.
    INT8,
};
// Options for the network
struct Options {
    // Precision to use for GPU inference.
    Precision precision = Precision::FP16;
    // If INT8 precision is selected, a path to a calibration dataset directory must be provided.
    std::string calibrationDataDirectoryPath;
    // The batch size to be used when computing calibration data for INT8 inference.
    // Should be set to as large a batch as your GPU will support.
    int32_t calibrationBatchSize = 128;
    // The batch size which should be optimized for.
    int32_t optBatchSize = 1;
    // Maximum allowable batch size
    int32_t maxBatchSize = 16;
    // GPU device index
    int deviceIndex = 0;
};
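// Example configuration (a sketch; the calibration directory path is hypothetical
// and is only required when Precision::INT8 is selected):
//
//     Options options;
//     options.precision = Precision::INT8;
//     options.calibrationDataDirectoryPath = "/path/to/calibration/data";
//     options.optBatchSize = 4;
//     options.maxBatchSize = 16;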
// Class used for INT8 calibration
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
    Int8EntropyCalibrator2(int32_t batchSize, int32_t inputW, int32_t inputH, const std::string& calibDataDirPath,
                           const std::string& calibTableName, const std::string& inputBlobName,
                           const std::array<float, 3>& subVals = {0.f, 0.f, 0.f},
                           const std::array<float, 3>& divVals = {1.f, 1.f, 1.f},
                           bool normalize = true, bool readCache = true);
    virtual ~Int8EntropyCalibrator2();

    // Abstract base class methods which must be implemented
    int32_t getBatchSize() const noexcept override;
    bool getBatch(void* bindings[], char const* names[], int32_t nbBindings) noexcept override;
    void const* readCalibrationCache(std::size_t& length) noexcept override;
    void writeCalibrationCache(void const* ptr, std::size_t length) noexcept override;

private:
    const int32_t m_batchSize;
    const int32_t m_inputW;
    const int32_t m_inputH;
    int32_t m_imgIdx;
    std::vector<std::string> m_imgPaths;
    size_t m_inputCount;
    const std::string m_calibTableName;
    const std::string m_inputBlobName;
    const std::array<float, 3> m_subVals;
    const std::array<float, 3> m_divVals;
    const bool m_normalize;
    const bool m_readCache;
    void* m_deviceInput;
    std::vector<char> m_calibCache;
};
// Class to extend the TensorRT logger
class Logger : public nvinfer1::ILogger {
    void log(Severity severity, const char* msg) noexcept override;
};
class Engine {
public:
    Engine(const Options& options);
    ~Engine();

    // Build the network.
    // The default implementation will normalize values between [0.f, 1.f].
    // Setting the normalize flag to false will leave values between [0.f, 255.f] (some converted models may require this).
    // If the model requires values to be normalized between [-1.f, 1.f], use the following params:
    //   subVals = {0.5f, 0.5f, 0.5f};
    //   divVals = {0.5f, 0.5f, 0.5f};
    //   normalize = true;
    bool build(std::string onnxModelPath, const std::array<float, 3>& subVals = {0.f, 0.f, 0.f},
               const std::array<float, 3>& divVals = {1.f, 1.f, 1.f}, bool normalize = true);
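    // Usage sketch (the model path is hypothetical; shows the [-1.f, 1.f]
    // normalization described above):
    //
    //     Engine engine(options);
    //     bool succ = engine.build("model.onnx", {0.5f, 0.5f, 0.5f}, {0.5f, 0.5f, 0.5f}, true);
    //     if (!succ) { throw std::runtime_error("Unable to build the TensorRT engine."); }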
    // Load and prepare the network for inference
    bool loadNetwork();

    // Run inference.
    // Input format:  [input][batch][cv::cuda::GpuMat]
    // Output format: [batch][output][feature_vector]
    bool runInference(const std::vector<std::vector<cv::cuda::GpuMat>>& inputs,
                      std::vector<std::vector<std::vector<float>>>& featureVectors);
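    // Usage sketch (a minimal single-input, single-image example; assumes the image
    // has already been resized to the network's expected input dimensions, and cpuImg
    // is a hypothetical cv::Mat):
    //
    //     cv::cuda::GpuMat img;
    //     img.upload(cpuImg);
    //     std::vector<std::vector<cv::cuda::GpuMat>> inputs{{img}};
    //     std::vector<std::vector<std::vector<float>>> featureVectors;
    //     engine.runInference(inputs, featureVectors);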
    // Utility method for resizing an image while maintaining the aspect ratio, by adding padding to the smaller dimension after scaling.
    // While letterbox padding normally adds padding to the top & bottom or to the left & right sides, this implementation only adds
    // padding to the right or bottom side. This makes it easier to convert detected coordinates (ex. from a YOLO model) back to the
    // original reference frame.
    static cv::cuda::GpuMat resizeKeepAspectRatioPadRightBottom(const cv::cuda::GpuMat& input, size_t height, size_t width,
                                                                const cv::Scalar& bgcolor = cv::Scalar(0, 0, 0));
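    // Usage sketch (assumes a single input whose dims are ordered [C, H, W], so
    // d[1] is the height and d[2] is the width):
    //
    //     const auto& dims = engine.getInputDims()[0];
    //     auto resized = Engine::resizeKeepAspectRatioPadRightBottom(img, dims.d[1], dims.d[2]);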
    [[nodiscard]] const std::vector<nvinfer1::Dims3>& getInputDims() const { return m_inputDims; }
    [[nodiscard]] const std::vector<nvinfer1::Dims>& getOutputDims() const { return m_outputDims; }
    // Utility method for transforming the triple-nested output array into a 2D array.
    // Should be used when the output batch size is 1, but there are multiple output feature vectors.
    static void transformOutput(std::vector<std::vector<std::vector<float>>>& input, std::vector<std::vector<float>>& output);

    // Utility method for transforming the triple-nested output array into a single array.
    // Should be used when the output batch size is 1 and there is only a single output feature vector.
    static void transformOutput(std::vector<std::vector<std::vector<float>>>& input, std::vector<float>& output);
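    // Usage sketch (hypothetical; flattens the result of a batch-size-1 inference
    // that produced a single output feature vector):
    //
    //     std::vector<float> featureVector;
    //     Engine::transformOutput(featureVectors, featureVector);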
    // Convert NHWC to NCHW and apply scaling and mean subtraction
    static cv::cuda::GpuMat blobFromGpuMats(const std::vector<cv::cuda::GpuMat>& batchInput, const std::array<float, 3>& subVals,
                                            const std::array<float, 3>& divVals, bool normalize);
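    // Usage sketch (hypothetical; packs a batch of one image into a single NCHW
    // blob normalized to [0.f, 1.f] with no mean subtraction or scaling):
    //
    //     auto blob = Engine::blobFromGpuMats({img}, {0.f, 0.f, 0.f}, {1.f, 1.f, 1.f}, true);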
private:
    // Converts the engine options into a string
    std::string serializeEngineOptions(const Options& options, const std::string& onnxModelPath);

    void getDeviceNames(std::vector<std::string>& deviceNames);

    // Normalization, scaling, and mean subtraction parameters for the inputs
    std::array<float, 3> m_subVals{};
    std::array<float, 3> m_divVals{};
    bool m_normalize;

    // Holds pointers to the input and output GPU buffers
    std::vector<void*> m_buffers;
    std::vector<uint32_t> m_outputLengthsFloat{};
    std::vector<nvinfer1::Dims3> m_inputDims;
    std::vector<nvinfer1::Dims> m_outputDims;
    std::vector<std::string> m_IOTensorNames;

    // Must keep the IRuntime around for inference; see:
    // https://forums.developer.nvidia.com/t/is-it-safe-to-deallocate-nvinfer1-iruntime-after-creating-an-nvinfer1-icudaengine-but-before-running-inference-with-said-icudaengine/255381/2?u=cyruspk4w6
    std::unique_ptr<nvinfer1::IRuntime> m_runtime = nullptr;
    std::unique_ptr<Int8EntropyCalibrator2> m_calibrator = nullptr;
    std::unique_ptr<nvinfer1::ICudaEngine> m_engine = nullptr;
    std::unique_ptr<nvinfer1::IExecutionContext> m_context = nullptr;

    const Options m_options;
    Logger m_logger;
    std::string m_engineName;
};