-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrt_dep.cpp
127 lines (103 loc) · 3.17 KB
/
trt_dep.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <array>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "trt_dep.hpp"
using nvinfer1::IHostMemory;
using nvinfer1::IBuilder;
using nvinfer1::INetworkDefinition;
using nvinfer1::ICudaEngine;
using nvinfer1::IInt8Calibrator;
using nvinfer1::IBuilderConfig;
using nvinfer1::IRuntime;
using nvinfer1::IExecutionContext;
using nvinfer1::ILogger;
using nvinfer1::Dims3;
using nvinfer1::Dims2;
using Severity = nvinfer1::ILogger::Severity;
using std::string;
using std::ios;
using std::ofstream;
using std::ifstream;
using std::vector;
using std::cout;
using std::endl;
using std::array;
Logger gLogger;
// Take ownership of a raw ICudaEngine pointer, releasing it through
// TrtDeleter when the last shared reference goes away.
TrtSharedEnginePtr shared_engine_ptr(ICudaEngine* ptr) {
    TrtDeleter deleter{};
    return TrtSharedEnginePtr{ptr, deleter};
}
// Load a serialized TensorRT engine from the file at `serpth` and
// deserialize it into a shared engine handle.
// Aborts on any failure: unreadable file, truncated read, runtime
// creation failure, or a blob the runtime cannot deserialize.
TrtSharedEnginePtr deserialize(string serpth) {
    ifstream ifile(serpth, ios::in | ios::binary);
    if (!ifile) {
        cout << "read serialized file failed\n";
        std::abort();
    }
    // Measure the file size, then rewind for a single full read.
    ifile.seekg(0, ios::end);
    const int mdsize = static_cast<int>(ifile.tellg());
    ifile.clear();
    ifile.seekg(0, ios::beg);

    vector<char> buf(mdsize);
    ifile.read(buf.data(), mdsize);
    // fix: a short/failed read previously went undetected and the
    // partially-filled buffer was handed to the deserializer.
    if (!ifile) {
        cout << "read serialized model data failed\n";
        std::abort();
    }
    ifile.close();
    cout << "model size: " << mdsize << endl;

    auto runtime = TrtUniquePtr<IRuntime>(nvinfer1::createInferRuntime(gLogger));
    // fix: createInferRuntime can return nullptr; dereferencing it was UB.
    if (!runtime) {
        cout << "create infer runtime failed\n";
        std::abort();
    }
    TrtSharedEnginePtr engine = shared_engine_ptr(
        runtime->deserializeCudaEngine(buf.data(), mdsize, nullptr));
    // fix: deserializeCudaEngine returns nullptr for corrupt/incompatible
    // blobs; fail loudly here instead of crashing at first use.
    if (!engine) {
        cout << "deserialize engine failed\n";
        std::abort();
    }
    return engine;
}
// Run one synchronous inference through `engine`.
// `data` is the flattened, preprocessed input tensor (host memory).
// Returns the flattened contents of the "preds" output binding,
// assumed shaped (1, H, W) per getBindingDimensions — TODO confirm
// against the exported model. Aborts on any CUDA/TensorRT failure.
vector<int> infer_with_engine(TrtSharedEnginePtr engine, vector<float>& data) {
    // fix: getBindingIndex returns -1 for an unknown name; indexing
    // getBindingDimensions with -1 was undefined behavior.
    const int out_idx = engine->getBindingIndex("preds");
    if (out_idx < 0) {
        cout << "cannot find output binding 'preds'\n";
        std::abort();
    }
    Dims3 out_dims = static_cast<Dims3&&>(engine->getBindingDimensions(out_idx));
    const int batchsize{1}, H{out_dims.d[1]}, W{out_dims.d[2]};
    const int in_size{static_cast<int>(data.size())};
    const int out_size{batchsize * H * W};
    vector<void*> buffs(2);  // [0] = device input, [1] = device output
    vector<int> res(out_size);

    auto context = TrtUniquePtr<IExecutionContext>(engine->createExecutionContext());
    if (!context) {
        cout << "create execution context failed\n";
        std::abort();
    }

    cudaError_t state;
    state = cudaMalloc(&buffs[0], in_size * sizeof(float));
    if (state) {
        cout << "allocate memory failed\n";
        std::abort();
    }
    state = cudaMalloc(&buffs[1], out_size * sizeof(int));
    if (state) {
        cout << "allocate memory failed\n";
        std::abort();
    }
    cudaStream_t stream;
    state = cudaStreamCreate(&stream);
    if (state) {
        cout << "create stream failed\n";
        std::abort();
    }

    state = cudaMemcpyAsync(
        buffs[0], data.data(), in_size * sizeof(float),
        cudaMemcpyHostToDevice, stream);
    if (state) {
        cout << "transmit to device failed\n";
        std::abort();
    }
    // fix: enqueueV2 returns false on failure; the result was ignored,
    // so a failed inference silently returned uninitialized output.
    if (!context->enqueueV2(buffs.data(), stream, nullptr)) {
        cout << "enqueue inference failed\n";
        std::abort();
    }
    state = cudaMemcpyAsync(
        res.data(), buffs[1], out_size * sizeof(int),
        cudaMemcpyDeviceToHost, stream);
    if (state) {
        cout << "transmit to host failed \n";
        std::abort();
    }
    // fix: surface deferred async errors instead of discarding the
    // cudaStreamSynchronize result.
    state = cudaStreamSynchronize(stream);
    if (state) {
        cout << "stream synchronize failed\n";
        std::abort();
    }

    cudaFree(buffs[0]);
    cudaFree(buffs[1]);
    cudaStreamDestroy(stream);
    return res;
}