Skip to content

Commit 896a37b

Browse files
authored
fea/link ir to inference analysis and fc fuse support (PaddlePaddle#12789)
* link IR graph to analysis graph * add clean code and update * add infer_clean_pass * add ir_pass_manager * support fc fuse executation * fix ir circle
1 parent e23ddf6 commit 896a37b

39 files changed

+1114
-206
lines changed

doc/fluid/design/others/graph_survey.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def get_symbol(num_classes=10, **kwargs):
2828

2929

3030

31-
Varible here is actually a Symbol. Every basic Symbol will correspond to one Node, and every Node has its own NodeAttr. There is a op field in NodeAttr class, when a Symbol represents Variable(often input data), the op field is null.
31+
Variable here is actually a Symbol. Every basic Symbol will correspond to one Node, and every Node has its own AnyAttr. There is an op field in the AnyAttr class; when a Symbol represents a Variable (often input data), the op field is null.
3232

3333
Symbol contains a data member, std::vector<NodeEntry> outputs, and NodeEntry contains a pointer to Node. We can follow the Node pointer to traverse the whole Graph.
3434

paddle/fluid/framework/ir/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
55
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
66
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
77
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
8+
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
9+
cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
10+
811

912
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
1013
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
1114
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
1215
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
16+
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
+192
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
16+
#include <string>
17+
#include <vector>
18+
#include "paddle/fluid/platform/enforce.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
24+
// Tells whether any consumer of `node` (a variable Node) is an operator
// whose type equals `op_type`.
bool VarOutLinksToOp(Node* node, const std::string& op_type) {
  for (auto* consumer : node->outputs) {
    if (!consumer->IsOp()) continue;
    if (consumer->Op()->Type() == op_type) return true;
  }
  return false;
}
32+
33+
// Builds the mul -> elementwise_add pattern that an FC op can replace:
//   (mul_tmp_var, mul_weight) -> mul -> mul_out
//   (mul_out, elementwise_add_tmpvar) -> elementwise_add -> elementwise_add_out
void BuildFCPattern(PDPattern* pattern) {
  // make sure the selected MUL op has one input argument that is a parameter.
  auto* mul_parameter_var = pattern->NewNode(
      [](Node* node) {
        return node->IsVar() && node->outputs.size() == 1UL &&
               // Guard IsOp() before dereferencing Op(): keeps this teller
               // consistent with the mul_out teller below and avoids calling
               // Op() on a non-op Node.
               node->outputs.front()->IsOp() &&
               node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
               node->Var()->Persistable();  // check is a parameter
      },
      "mul_weight" /*name*/);

  auto* mul_tmp_input_var = pattern->NewNode(
      [](Node* node) {
        bool result =
            node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
            !node->Var()->Persistable();  // this input is not a parameter.
        if (!result) return false;
        // check whether one output is the MUL op.
        for (auto* op : node->outputs) {
          if (op->IsOp() && op->Op()->Type() == "mul") return true;
        }
        return false;
      },
      "mul_tmp_var" /*name*/);

  // select a MUL op
  auto* mul_op = pattern->NewNode(
      [](Node* node) {
        return node->IsOp() &&               // start from an Op
               node->Op()->Type() == "mul";  // type is mul
        // the output should be consumed only by one elementwise_add; that
        // check lives in the mul_out Var PDNode below.
      },
      "mul" /*name*/);

  // make sure the MUL op's output has only one consumer and links to an
  // ELEMENTWISE_ADD op.
  auto* mul_out_var = pattern->NewNode(
      [](Node* node) {
        return node->IsVar() &&                  // starts from a Var
               node->outputs.size() == 1UL &&    // only has one consumer
               node->outputs.front()->IsOp() &&  // check basic logic
               node->Var() &&                    // not a ControlDepVar
               node->outputs.front()->Op()->Type() ==
                   "elementwise_add";  // a very strong validation
      },
      "mul_out");
  // this check is not essential, just to make the corresponding variable Node
  // retrieval easier.
  auto* elementwise_add_tmp_var = pattern->NewNode(
      [](Node* node) {
        return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
               VarOutLinksToOp(node, "elementwise_add");
      },
      "elementwise_add_tmpvar");

  // select an ELEMENTWISE_ADD op
  auto* elementwise_add_op = pattern->NewNode(
      [](Node* node) {
        return node->IsOp() && node->Op()->Type() == "elementwise_add";
      },
      "elementwise_add" /*name*/);

  // get the ELEMENTWISE_ADD op's output
  auto* elementwise_add_out_var = pattern->NewNode(
      [](Node* node) {
        return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
               // Guard IsOp() before dereferencing Op() (same defensive check
               // as the other tellers above).
               node->inputs.front()->IsOp() &&
               node->inputs.front()->Op()->Type() == "elementwise_add";
      },
      "elementwise_add_out");

  pattern->AddEdge(mul_parameter_var, mul_op);
  pattern->AddEdge(mul_tmp_input_var, mul_op);
  pattern->AddEdge(mul_op, mul_out_var);
  pattern->AddEdge(mul_out_var, elementwise_add_op);
  pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
  pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
}
110+
111+
// Replace the node `from` in the links to `to`
112+
// Replace the node `from` with `to` in `links` (first occurrence only).
// Returns true iff a replacement was made.
bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
  for (size_t i = 0; i < links->size(); ++i) {
    if ((*links)[i] == from) {
      (*links)[i] = to;
      return true;
    }
  }
  return false;
}
121+
122+
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
123+
std::unique_ptr<ir::Graph> graph) const {
124+
PADDLE_ENFORCE(graph.get());
125+
126+
std::unordered_set<Node*> nodes2delete;
127+
128+
GraphPatternDetecter gpd;
129+
BuildFCPattern(gpd.mutable_pattern());
130+
131+
#define GET_NODE(id) \
132+
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
133+
"pattern has no Node called %s", #id); \
134+
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
135+
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
136+
137+
auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
138+
Graph* g) {
139+
VLOG(4) << "handle FC fuse";
140+
// Currently, there is no FC op available, so I will just simulate the
141+
// scenerio.
142+
// FC's fusion is simple, just op fuse, no need to process the
143+
// parameters.
144+
GET_NODE(mul_tmp_var); // x
145+
GET_NODE(mul_weight); // Y
146+
GET_NODE(elementwise_add_tmpvar); // bias
147+
GET_NODE(elementwise_add_out); // Out
148+
GET_NODE(mul); // MUL op
149+
GET_NODE(elementwise_add); // ELEMENT_ADD op
150+
GET_NODE(mul_out); // tmp
151+
#undef GET_NODE
152+
153+
// Create an FC Node.
154+
OpDesc desc;
155+
std::string fc_x_in = mul_tmp_var->Name();
156+
std::string fc_Y_in = mul_weight->Name();
157+
std::string fc_bias_in = elementwise_add_tmpvar->Name();
158+
std::string fc_out = elementwise_add_out->Name();
159+
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
160+
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
161+
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
162+
desc.SetOutput("Out", std::vector<std::string>({fc_out}));
163+
desc.SetType("fc");
164+
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
165+
fc_node->inputs =
166+
std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
167+
fc_node->outputs.push_back(elementwise_add_out);
168+
169+
// Update link relatons
170+
PADDLE_ENFORCE(LinksReplace(&mul_tmp_var->outputs, mul, fc_node));
171+
PADDLE_ENFORCE(LinksReplace(&mul_weight->outputs, mul, fc_node));
172+
PADDLE_ENFORCE(LinksReplace(&elementwise_add_tmpvar->outputs,
173+
elementwise_add, fc_node));
174+
PADDLE_ENFORCE(
175+
LinksReplace(&elementwise_add_out->inputs, elementwise_add, fc_node));
176+
177+
// Drop old nodes
178+
graph->RemoveNode(mul);
179+
graph->RemoveNode(elementwise_add);
180+
graph->RemoveNode(mul_out); // tmp variable
181+
};
182+
183+
gpd(graph.get(), handler);
184+
185+
return graph;
186+
}
187+
188+
} // namespace ir
189+
} // namespace framework
190+
} // namespace paddle
191+
192+
REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass);
+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/graph.h"
16+
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
17+
#include "paddle/fluid/framework/ir/pass.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
/*
24+
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
25+
*/
26+
class FCFusePass : public Pass {
27+
public:
28+
virtual ~FCFusePass() {}
29+
30+
protected:
31+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
32+
};
33+
34+
} // namespace ir
35+
} // namespace framework
36+
} // namespace paddle
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
16+
17+
#include <gtest/gtest.h>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
// Appends an op of `type` to block 0 of `prog`, wiring `inputs` under the
// "Xs" slot and `outputs` under the "Ys" slot.
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* new_op = prog->MutableBlock(0)->AppendOp();
  new_op->SetType(type);
  new_op->SetInput("Xs", inputs);
  new_op->SetOutput("Ys", outputs);
}
31+
32+
// a->OP0->b
33+
// a->OP1->c
34+
// (b, c)->mul->d
35+
// (d, e)->elementwise_add->f
36+
ProgramDesc BuildProgramDesc() {
37+
ProgramDesc prog;
38+
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
39+
auto* var = prog.MutableBlock(0)->Var(v);
40+
var->SetType(proto::VarType::SELECTED_ROWS);
41+
if (v == "c") {
42+
var->SetPersistable(true);
43+
}
44+
}
45+
46+
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
47+
std::vector<std::string>({"b"}));
48+
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
49+
std::vector<std::string>({"c"}));
50+
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
51+
std::vector<std::string>({"d"}));
52+
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
53+
std::vector<std::string>({"f"}));
54+
55+
return prog;
56+
}
57+
58+
// End-to-end check of the fc_fuse_pass: the node count must drop by exactly
// two (3 removed: mul op, elementwise_add op, mul_out var; 1 added: fc op)
// and exactly one "fc" op must exist afterwards.
TEST(FCFusePass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("fc_fuse_pass");

  // Nodes().size() is unsigned; keep the counts unsigned instead of
  // implicitly narrowing size_t to int.
  size_t pre_nodes = graph->Nodes().size();

  graph = pass->Apply(std::move(graph));

  size_t after_nodes = graph->Nodes().size();

  // Remove 3 Nodes: MUL, ELEMENTWISE_ADD, mul_out
  // Add 1 Node: FC
  EXPECT_EQ(pre_nodes - 2, after_nodes);

  // Assert the fc op appears exactly once in the newly generated graph.
  int fc_count = 0;

  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      ++fc_count;
    }
  }
  EXPECT_EQ(fc_count, 1);
}
85+
86+
} // namespace ir
87+
} // namespace framework
88+
} // namespace paddle
89+
90+
USE_PASS(fc_fuse_pass);

paddle/fluid/framework/ir/graph.h

+10-6
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ class Graph {
9898

9999
// Create a normal variable with non-null VarDesc.
100100
ir::Node *CreateVarNode(VarDesc *var_desc) {
  // Reject null descriptors up front; AddNode takes ownership of the Node.
  PADDLE_ENFORCE(var_desc);
  return AddNode(new ir::Node(var_desc));
}
103104

104105
// Create a normal runnable operator with OpDesc.
105106
ir::Node *CreateOpNode(OpDesc *op_desc) {
  // Reject null descriptors up front; AddNode takes ownership of the Node.
  PADDLE_ENFORCE(op_desc);
  return AddNode(new ir::Node(op_desc));
}
108110

@@ -134,6 +136,14 @@ class Graph {
134136
return ret;
135137
}
136138

139+
void RemoveNode(ir::Node *node) {
  // The node must currently belong to this graph.
  PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
  node_set_.erase(node);
  // NOTE(review): nodes_ appears to own the node (AddNode "takes ownership"),
  // so `node` is likely dangling after this erase — callers must not touch it.
  nodes_.erase(node);
}
144+
145+
// Read-only access to the ProgramDesc this graph was constructed from.
const ProgramDesc &program() const { return program_; }
146+
137147
private:
138148
// This method takes ownership of `node`.
139149
ir::Node *AddNode(ir::Node *node) {
@@ -143,12 +153,6 @@ class Graph {
143153
return node;
144154
}
145155

146-
void RemoveNode(ir::Node *node) {
147-
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
148-
node_set_.erase(node);
149-
nodes_.erase(node);
150-
}
151-
152156
// NOTE: program_ shouldn't be exposed to user.
153157
const ProgramDesc &program_;
154158
std::map<std::string, boost::any> attrs_;

0 commit comments

Comments
 (0)