forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into paralle…
…l_bcast
- Loading branch information
Showing
25 changed files
with
398 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" | ||
#include <string> | ||
#include <vector> | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl( | ||
std::unique_ptr<ir::Graph> graph) const { | ||
PADDLE_ENFORCE(graph.get()); | ||
FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get()); | ||
|
||
std::unordered_set<Node*> nodes2delete; | ||
|
||
GraphPatternDetector gpd; | ||
auto* conv_input = gpd.mutable_pattern() | ||
->NewNode("conv_relu_mkldnn_fuse/conv_input") | ||
->AsInput() | ||
->assert_is_op_input("conv2d", "Input"); | ||
patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(), | ||
"conv_relu_mkldnn_fuse"); | ||
conv_relu_pattern(conv_input); | ||
|
||
int found_conv_relu_count = 0; | ||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, | ||
Graph* g) { | ||
VLOG(4) << "handle ConvReLU fuse"; | ||
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, | ||
conv_relu_pattern); // Filter | ||
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias | ||
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp | ||
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op | ||
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out | ||
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op | ||
|
||
// Create an ConvReLU Node. | ||
OpDesc desc; | ||
std::string conv_relu_i_in = subgraph.at(conv_input)->Name(); | ||
std::string conv_relu_w_in = conv_weight->Name(); | ||
std::string conv_relu_b_in = conv_bias->Name(); | ||
std::string conv_relu_out = relu_out->Name(); | ||
desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in})); | ||
desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in})); | ||
desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in})); | ||
desc.SetOutput("Out", std::vector<std::string>({conv_relu_out})); | ||
desc.SetType("conv2d"); | ||
for (auto& attr : conv->Op()->GetAttrMap()) { | ||
desc.SetAttr(attr.first, attr.second); | ||
} | ||
desc.SetAttr("fuse_relu", true); | ||
auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied. | ||
GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out}); | ||
|
||
PADDLE_ENFORCE(subgraph.count(conv_input)); | ||
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node); | ||
IR_NODE_LINK_TO(conv_weight, conv_relu_node); | ||
IR_NODE_LINK_TO(conv_bias, conv_relu_node); | ||
IR_NODE_LINK_TO(conv_relu_node, relu_out); | ||
|
||
found_conv_relu_count++; | ||
}; | ||
|
||
gpd(graph.get(), handler); | ||
|
||
AddStatis(found_conv_relu_count); | ||
return graph; | ||
} | ||
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle | ||
|
||
REGISTER_PASS(conv_relu_mkldnn_fuse_pass, | ||
paddle::framework::ir::ConvReLUFusePass); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/fluid/framework/ir/fuse_pass_base.h" | ||
#include "paddle/fluid/framework/ir/graph.h" | ||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" | ||
#include "paddle/fluid/framework/ir/pass.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
/* | ||
* Fuse the CONV and ReLU to a ConvReLUOp. | ||
*/ | ||
class ConvReLUFusePass : public FusePassBase { | ||
public: | ||
virtual ~ConvReLUFusePass() {} | ||
|
||
protected: | ||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const; | ||
}; | ||
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle |
108 changes: 108 additions & 0 deletions
108
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
void SetOp(ProgramDesc* prog, const std::string& type, | ||
const std::vector<std::string>& inputs, | ||
const std::vector<std::string>& outputs) { | ||
auto* op = prog->MutableBlock(0)->AppendOp(); | ||
op->SetType(type); | ||
if (type == "conv2d") { | ||
op->SetAttr("use_mkldnn", true); | ||
op->SetInput("Input", {inputs[0]}); | ||
op->SetInput("Filter", {inputs[1]}); | ||
op->SetInput("Bias", {inputs[2]}); | ||
} else if (type == "relu") { | ||
op->SetInput("X", inputs); | ||
} | ||
op->SetOutput("Out", outputs); | ||
} | ||
|
||
// a->OP0->b | ||
// b->OP1->c | ||
// (c, weights, bias)->conv->f | ||
// (f)->relu->g | ||
ProgramDesc BuildProgramDesc() { | ||
ProgramDesc prog; | ||
for (auto& v : | ||
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) { | ||
auto* var = prog.MutableBlock(0)->Var(v); | ||
var->SetType(proto::VarType::SELECTED_ROWS); | ||
if (v == "weights" || v == "bias") { | ||
var->SetPersistable(true); | ||
} | ||
} | ||
|
||
SetOp(&prog, "OP0", std::vector<std::string>({"a"}), | ||
std::vector<std::string>({"b"})); | ||
SetOp(&prog, "OP1", std::vector<std::string>({"b"}), | ||
std::vector<std::string>({"c"})); | ||
SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}), | ||
std::vector<std::string>({"f"})); | ||
SetOp(&prog, "relu", std::vector<std::string>({"f"}), | ||
std::vector<std::string>({"g"})); | ||
|
||
return prog; | ||
} | ||
|
||
TEST(ConvReLUFusePass, basic) { | ||
auto prog = BuildProgramDesc(); | ||
|
||
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); | ||
|
||
auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass"); | ||
|
||
int original_nodes_num = graph->Nodes().size(); | ||
|
||
graph = pass->Apply(std::move(graph)); | ||
|
||
int current_nodes_num = graph->Nodes().size(); | ||
|
||
// Remove 3 Nodes: CONV, RELU, conv_out | ||
// Add 1 Node: ConvReLU | ||
EXPECT_EQ(original_nodes_num - 2, current_nodes_num); | ||
|
||
// Assert conv_relu op in newly generated graph | ||
int conv_relu_count = 0; | ||
|
||
for (auto* node : graph->Nodes()) { | ||
if (node->IsOp() && node->Op()->Type() == "conv2d") { | ||
if (node->Op()->HasAttr("use_mkldnn")) { | ||
bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn")); | ||
if (use_mkldnn) { | ||
if (node->Op()->HasAttr("fuse_relu")) { | ||
bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu")); | ||
if (fuse_relu) { | ||
++conv_relu_count; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
EXPECT_EQ(conv_relu_count, 1); | ||
} | ||
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle | ||
|
||
USE_PASS(conv_relu_mkldnn_fuse_pass); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.