forked from NVIDIA/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNvUffParser.h
242 lines (210 loc) · 6.94 KB
/
NvUffParser.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef NV_UFF_PARSER_H
#define NV_UFF_PARSER_H
#include "NvInfer.h"
//!
//! UFF (Universal Framework Format) version supported by this parser, as major.minor.patch
//! (currently 0.6.3). These are compile-time constants from this header; IUffParser also
//! exposes getUffRequiredVersionMajor/Minor/Patch() accessors.
//!
#define UFF_REQUIRED_VERSION_MAJOR 0
#define UFF_REQUIRED_VERSION_MINOR 6
#define UFF_REQUIRED_VERSION_PATCH 3
namespace nvuffparser
{
//!
//! \enum UffInputOrder
//! \brief Dimension order of a network input as laid out in the originating framework.
//!
//! Supplied to IUffParser::registerInput() together with the input's dimensions.
//!
enum class UffInputOrder : int
{
    kNCHW = 0, //!< NCHW order (batch, channels, height, width).
    kNHWC = 1, //!< NHWC order (batch, height, width, channels).
    kNC = 2,   //!< NC order (batch, channels).
};
//!
//! \enum FieldType
//! \brief Type tag describing the data carried by a FieldMap entry for a custom (plugin) layer.
//!
//! \note No enumerator uses the value 3. Keep every explicit value unchanged; they are
//!       presumably relied upon by the prebuilt parser library (confirm before renumbering).
//!
enum class FieldType : int
{
    kFLOAT = 0,    //!< FP32 field type.
    kINT32 = 1,    //!< INT32 field type.
    kCHAR = 2,     //!< char field type; treated as a string when length > 1.
    kDIMS = 4,     //!< nvinfer1::Dims field type.
    kDATATYPE = 5, //!< nvinfer1::DataType field type.
    kUNKNOWN = 6,  //!< Unknown / unset field type (the default of FieldMap::type).
};
//!
//! \class FieldMap
//!
//! \brief An array of field params used as a layer parameter for plugin layers.
//!
//! The node fields are passed by the parser to the API through the plugin
//! constructor. The implementation of the plugin should parse the contents of
//! the fieldMap as part of the plugin constructor
//!
class TENSORRTAPI FieldMap
{
public:
const char* name;
const void* data;
FieldType type = FieldType::kUNKNOWN;
int length = 1;
FieldMap(const char* name, const void* data, const FieldType type, int length = 1);
};
//!
//! \struct FieldCollection
//!
//! \brief Bundle of plugin-layer field parameters, as passed to IPluginFactory::createPlugin().
//!
//! \ref fields presumably points to an array of \ref nbFields FieldMap entries. This struct
//! declares no destructor, so it never frees that array.
//!
struct FieldCollection
{
    int nbFields;           //!< Number of FieldMap entries reachable through \ref fields.
    FieldMap const* fields; //!< Pointer to the first FieldMap entry.
};
//!
//! \class IPluginFactory
//!
//! \brief Plugin factory used to configure plugins.
//!
//! Implemented by the user and registered with IUffParser::setPluginFactory().
//!
class IPluginFactory
{
public:
    //!
    //! \brief A user implemented function that determines if a layer configuration is provided by an IPlugin.
    //!
    //! \param layerName Name of the layer which the user wishes to validate.
    //!
    //! \return True if the layer named \p layerName should be created by this factory as an IPlugin.
    //!
    virtual bool isPlugin(const char* layerName) = 0;

    //!
    //! \brief Creates a plugin.
    //!
    //! \param layerName Name of layer associated with the plugin.
    //! \param weights Weights used for the layer.
    //! \param nbWeights Number of weights.
    //! \param fc A collection of FieldMaps used as layer parameters for different plugin layers.
    //!
    //! \return The plugin for \p layerName. Ownership of the returned object is not stated in this
    //!         header (NOTE(review): confirm who is responsible for destroying it).
    //!
    //! \see FieldCollection
    //!
    virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const nvinfer1::Weights* weights, int nbWeights,
        const FieldCollection fc) = 0;

protected:
    //!
    //! \brief Protected, non-virtual destructor.
    //!
    //! This class has virtual functions but no virtual destructor, so deleting a concrete factory
    //! through an IPluginFactory* would be undefined behavior. Making the destructor protected
    //! turns that misuse into a compile error. It is intentionally NOT virtual: a virtual
    //! destructor would add vtable slots, shift the slots of IPluginFactoryExt, and break binary
    //! compatibility with already-compiled code.
    //!
    ~IPluginFactory() {}
};
//!
//! \class IPluginFactoryExt
//!
//! \brief Plugin factory used to configure plugins with added support for TRT versioning.
//!
//! Implemented by the user and registered with IUffParser::setPluginFactoryExt().
//!
class IPluginFactoryExt : public IPluginFactory
{
public:
    //!
    //! \brief Return the TensorRT version this factory was compiled against.
    //!
    //! \return NV_TENSORRT_VERSION as defined by the headers seen when the implementing code was built.
    //!
    virtual int getVersion() const
    {
        return NV_TENSORRT_VERSION;
    }

    //!
    //! \brief A user implemented function that determines if a layer configuration is provided by an IPluginExt.
    //!
    //! \param layerName Name of the layer which the user wishes to validate.
    //!
    //! \return True if the layer named \p layerName should be created by this factory as an IPluginExt.
    //!
    virtual bool isPluginExt(const char* layerName) = 0;

protected:
    //!
    //! \brief Protected, non-virtual destructor.
    //!
    //! Prevents deleting a concrete factory through an IPluginFactoryExt* (undefined behavior,
    //! since no destructor in this hierarchy is virtual). It is intentionally NOT virtual so
    //! the existing vtable layout, and therefore binary compatibility, is unchanged.
    //!
    ~IPluginFactoryExt() {}
};
//!
//! \class IUffParser
//!
//! \brief Class used for parsing models described using the UFF format.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
//! \note With common C++ ABIs the declaration order of the virtual functions below determines
//!       the vtable layout; do not reorder, insert, or remove them.
//!
class IUffParser
{
public:
//!
//! \brief Register an input name of a UFF network with the associated dimensions.
//!
//! \param inputName Input name.
//! \param inputDims Input dimensions.
//! \param inputOrder Dimension order of this input in the framework it originally came from (see UffInputOrder).
//!
//! \return Presumably true on success, false on failure (not stated in this header — confirm).
//!
virtual bool registerInput(const char* inputName, nvinfer1::Dims inputDims, UffInputOrder inputOrder) = 0;
//!
//! \brief Register an output name of a UFF network.
//!
//! \param outputName Output name.
//!
//! \return Presumably true on success, false on failure (not stated in this header — confirm).
//!
virtual bool registerOutput(const char* outputName) = 0;
//!
//! \brief Parse a UFF file.
//!
//! \param file File name of the UFF file.
//! \param network Network into which the parser will add the layers.
//! \param weightsType Data type to which the weights will be converted (defaults to kFLOAT).
//!
//! \return Presumably true if parsing succeeded (confirm).
//!
virtual bool parse(const char* file,
nvinfer1::INetworkDefinition& network,
nvinfer1::DataType weightsType=nvinfer1::DataType::kFLOAT) = 0;
//!
//! \brief Parse a UFF buffer; useful when the UFF file is already in memory.
//!
//! \param buffer Buffer holding the contents of the UFF file.
//! \param size Size of \p buffer in bytes.
//! \param network Network into which the parser will add the layers.
//! \param weightsType Data type to which the weights will be converted (defaults to kFLOAT).
//!
//! \return Presumably true if parsing succeeded (confirm).
//!
virtual bool parseBuffer(const char* buffer, std::size_t size,
nvinfer1::INetworkDefinition& network,
nvinfer1::DataType weightsType=nvinfer1::DataType::kFLOAT) = 0;
//!
//! \brief Destroy this parser object.
//!
//! The destructor is protected, so this is the public way to release a parser
//! (e.g. one obtained from createUffParser()).
//!
virtual void destroy() = 0;
//!
//! \brief Return the major version of the UFF format required by this parser.
//!
virtual int getUffRequiredVersionMajor() = 0;
//!
//! \brief Return the minor version of the UFF format required by this parser.
//!
virtual int getUffRequiredVersionMinor() = 0;
//!
//! \brief Return the patch version of the UFF format required by this parser.
//!
virtual int getUffRequiredVersionPatch() = 0;
//!
//! \brief Set the IPluginFactory used to create the user defined plugins.
//!
//! \param factory Pointer to an instance of the user implementation of IPluginFactory.
//!
virtual void setPluginFactory(IPluginFactory* factory) = 0;
//!
//! \brief Set the IPluginFactoryExt used to create the user defined pluginExts.
//!
//! \param factory Pointer to an instance of the user implementation of IPluginFactoryExt.
//!
virtual void setPluginFactoryExt(IPluginFactoryExt* factory) = 0;
//!
//! \brief Set the namespace used to look up and create plugins in the network.
//!
//! \param libNamespace Plugin library namespace to use.
//!
virtual void setPluginNamespace(const char* libNamespace) = 0;
protected:
// Protected so parsers cannot be deleted directly; use destroy() instead.
virtual ~IUffParser() {}
};
//!
//! \brief Create a new IUffParser.
//!
//! \return Pointer to the created parser. IUffParser's destructor is protected, so release the
//!         object with IUffParser::destroy() rather than delete.
//!
//! \see nvuffparser::IUffParser
//!
TENSORRTAPI IUffParser* createUffParser(void);
//!
//! \brief Shut down the Protocol Buffers library.
//!
//! \note Once this has been called, no part of the Protocol Buffers library may be used again.
//!
TENSORRTAPI void shutdownProtobufLibrary();
}
//! C-linkage factory entry point returning an untyped pointer. The _INTERNAL suffix suggests it is
//! not meant for direct use (NOTE(review): its relation to createUffParser() is not visible here — confirm).
extern "C" TENSORRTAPI void* createNvUffParser_INTERNAL(void);
#endif /* !NV_UFF_PARSER_H */