forked from snuspl/nimble
-
Notifications
You must be signed in to change notification settings - Fork 0
/
onnx_exporter.h
114 lines (91 loc) · 3.63 KB
/
onnx_exporter.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/proto/caffe2.pb.h"
#include "onnx/onnx_pb.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace caffe2 {
namespace onnx {
// ONNX protobuf message types referred to by the declarations in this header.
// They are introduced directly into caffe2::onnx instead of through an unnamed
// namespace: an unnamed namespace in a header is re-created in every
// translation unit that includes it and adds no encapsulation here, since the
// names are visible throughout caffe2::onnx either way.
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::GraphProto;
using ::ONNX_NAMESPACE::ModelProto;
using ::ONNX_NAMESPACE::NodeProto;
using ::ONNX_NAMESPACE::TensorProto;
// Result type shared by every converter declared in this header: a pair whose
// first element is a list of ONNX nodes and whose second element is a list of
// ONNX tensors.
using ConvertedResult =
    std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
// Rewrites the given Caffe2 nets into SSA form. The external output names of
// the predict net are preserved by the rewrite.
//
// Both nets are passed by mutable pointer.
// NOTE(review): the returned string-to-string map presumably relates blob
// names before and after the rewrite; its exact direction is not visible from
// this header -- confirm against the definition.
std::unordered_map<std::string, std::string> SsaRewrite(
    caffe2::NetDef* init_net,
    caffe2::NetDef* pred_net);
class OnnxExporter {
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
const caffe2::OperatorDef&,
const std::unordered_map<std::string, caffe2::TensorShape>&);
public:
OnnxExporter(DummyName* dummy = nullptr, bool legacy_mode = false)
: legacy_mode_(legacy_mode) {
if (dummy) {
dummy_ = std::shared_ptr<DummyName>(dummy, [](DummyName*) {});
} else {
dummy_ = std::make_shared<DummyName>();
}
}
ConvertedResult Caffe2OpToOnnxNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
void InitOpToTensorProto(const caffe2::OperatorDef& def, TensorProto* tensor);
private:
ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
ConvertedResult CreateCastNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConvPoolNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateGemmNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateReshapeNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateSliceNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateChannelShuffleNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConcatNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateLrnNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
// \brief Check black listed arguemnts where we won't pass down when
// converting to ONNX node
bool IsBlackListed(const caffe2::Argument& arg);
// \brief Convert Caffe2 argument to Onnx attribute
void CopyCaffe2ArgToOnnxAttr(
AttributeProto* attr,
const std::string& op_type,
const caffe2::Argument& arg);
// LUT getters
const std::unordered_map<std::string, std::string>& get_renamed_operators()
const;
const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
const std::
unordered_map<std::string, std::unordered_map<std::string, std::string>>&
get_per_op_renamed_attrs() const;
const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
get_special_operators() const;
// To generate onnx models with opset < 6
bool legacy_mode_{false};
// Dummy name generator
std::shared_ptr<DummyName> dummy_;
};
} // namespace onnx
} // namespace caffe2