Delete repeat ops add gather squeeze unsqueeze (PaddlePaddle#55371)
csy0225 authored Jul 19, 2023
1 parent bc15370 commit 552ed8d
Showing 6 changed files with 788 additions and 19 deletions.
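In brief: the commit extends delete_repeated_ops_pass to also deduplicate gather, squeeze2, and unsqueeze2 ops (switching the key-generator callbacks from OpDesc* to Node* so they can incorporate input variable names), makes that pass iterate to a fixed point, introduces a new fused_continuous_same_ops_pass that collapses chains of identical reshape2/unsqueeze2 ops into one, and wires the new pass into the XPU pass strategy.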
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -129,6 +129,7 @@ pass_library(dense_multihead_matmul_to_sparse_pass inference)
pass_library(delete_cast_op_pass inference)
pass_library(delete_elementwise_mul_op_pass inference)
pass_library(delete_repeated_ops_pass inference)
+pass_library(fused_continuous_same_ops_pass inference)
pass_library(sigmoid_elementmul_fuse_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
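The pass_library entry above declares the new pass to the inference build, alongside the existing delete_*_pass targets; its implementation is the new source file shown further below.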
77 changes: 59 additions & 18 deletions paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
@@ -101,18 +101,18 @@ class DeleteRepeatedOpsPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;

private:
-  void DeleteRepeatedOps(
-      ir::Graph* graph,
-      const std::string& op_type,
-      std::function<std::string(OpDesc*)> gen_op_key_fn) const;
+  void DeleteRepeatedOps(ir::Graph* graph,
+                         const std::string& op_type,
+                         std::function<std::string(Node*)> gen_op_key_fn) const;

const std::string name_scope_{"delete_repeated_ops_pass"};
+  mutable int delete_op_count{0};
};

void DeleteRepeatedOpsPass::DeleteRepeatedOps(
ir::Graph* graph,
const std::string& op_type,
-    std::function<std::string(OpDesc*)> gen_op_key_fn) const {
+    std::function<std::string(Node*)> gen_op_key_fn) const {
GraphPatternDetector gpd;
patterns::VarWithRepeatedOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, op_type);
@@ -140,7 +140,7 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps(
}
}
if (out_op_is_invalid) continue;
-    auto attr_key = gen_op_key_fn(op->Op());
+    auto attr_key = gen_op_key_fn(op);
ops_map[attr_key].push_back(op);
}
for (auto iter = ops_map.begin(); iter != ops_map.end();) {
@@ -173,16 +173,18 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps(
};

gpd(graph, handler);
+  delete_op_count += delete_counts;
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type
<< " ops";
}
}

-std::string GenShapeAttrKey(OpDesc* slice_op_desc) { return ""; }
+std::string GenShapeAttrKey(Node* shape_op_node) { return ""; }

-std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
+std::string GenSliceAttrKey(Node* slice_op_node) {
   std::string attr_key;
+  auto slice_op_desc = slice_op_node->Op();
auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
auto ends = slice_op_desc->GetAttrIfExists<std::vector<int>>("ends");
auto axes = slice_op_desc->GetAttrIfExists<std::vector<int>>("axes");
@@ -207,21 +209,24 @@ std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
return attr_key;
}

-std::string GenCastAttrKey(OpDesc* cast_op_desc) {
+std::string GenCastAttrKey(Node* cast_op_node) {
+  auto cast_op_desc = cast_op_node->Op();
auto in_dtype = cast_op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = cast_op_desc->GetAttrIfExists<int>("out_dtype");
return "in_dtype_" + std::to_string(in_dtype) + "_out_dtype_" +
std::to_string(out_dtype);
}

-std::string GenAddAttrKey(OpDesc* add_op_desc) {
+std::string GenAddAttrKey(Node* add_op_node) {
+  auto add_op_desc = add_op_node->Op();
std::string x_name = add_op_desc->Input("X")[0];
std::string y_name = add_op_desc->Input("Y")[0];
auto axis = add_op_desc->GetAttrIfExists<int>("axis");
return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
}

-std::string GenScaleAttrKey(OpDesc* scale_op_desc) {
+std::string GenScaleAttrKey(Node* scale_op_node) {
+  auto scale_op_desc = scale_op_node->Op();
auto scale = scale_op_desc->GetAttrIfExists<float>("scale");
auto bias = scale_op_desc->GetAttrIfExists<float>("bias");
auto bias_after_scale =
@@ -230,17 +235,53 @@ std::string GenScaleAttrKey(OpDesc* scale_op_desc) {
"_bias_after_scale_" + std::to_string(bias_after_scale);
}

+std::string GenGatherAttrKey(Node* gather_op_node) {
+  std::string input_names{""};
+  for (auto input_var : gather_op_node->inputs) {
+    input_names += input_var->Var()->Name();
+  }
+  auto gather_op_desc = gather_op_node->Op();
+  auto axis = gather_op_desc->GetAttrIfExists<int>("axis");
+  return "axis_" + std::to_string(axis) + "_input_names_" + input_names;
+}
+
+std::string GenSqueeze2AttrKey(Node* squeeze2_op_node) {
+  auto squeeze2_op_desc = squeeze2_op_node->Op();
+  auto axes = squeeze2_op_desc->GetAttrIfExists<std::vector<int>>("axes");
+  std::string attr_key{""};
+  attr_key += "axes_";
+  for (auto axis : axes) {
+    attr_key += std::to_string(axis) + "_";
+  }
+  return attr_key;
+}

void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);

DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
int repeat_time = 0;
int total_delete_op_count = 0;
// This pass needs to loop run until there are no nodes in the graph that need
// to be deleted.
while (true) {
delete_op_count = 0;
DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
DeleteRepeatedOps(graph, "gather", GenGatherAttrKey);
DeleteRepeatedOps(graph, "squeeze2", GenSqueeze2AttrKey);
DeleteRepeatedOps(graph, "unsqueeze2", GenSqueeze2AttrKey);
LOG(INFO) << "Round " << repeat_time++
<< ": delete op counts: " << delete_op_count;
total_delete_op_count += delete_op_count;
if (delete_op_count == 0) {
break; // No node need to delete.
}
}
LOG(INFO) << "Total delete op counts: " << total_delete_op_count;
}

} // namespace ir
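The mechanism above hinges on gen_op_key_fn: every op consuming the same variable is reduced to a string key built from its type-specific attributes (and, for gather, its input variable names), ops_map buckets the consumers by key, and all but the first op in each bucket are removed. A minimal standalone sketch of that bucketing idea, with hypothetical op names and keys (not the Paddle API):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an op node: a name plus the serialized
// attribute key that gen_op_key_fn would produce.
struct Op {
  std::string name;
  std::string attr_key;  // e.g. "axis_1_input_names_xindex" for gather
};

int main() {
  // Ops consuming the same input variable; duplicates share an attr_key.
  std::vector<Op> consumers = {{"gather_0", "axis_1_input_names_xindex"},
                               {"gather_1", "axis_1_input_names_xindex"},
                               {"gather_2", "axis_2_input_names_xindex"}};
  std::map<std::string, std::vector<Op>> ops_map;
  for (const auto& op : consumers) ops_map[op.attr_key].push_back(op);
  for (const auto& [key, ops] : ops_map) {
    // Keep the first op in each bucket; the pass would redirect the
    // outputs of the rest onto it and remove them from the graph.
    for (size_t i = 1; i < ops.size(); ++i) {
      std::cout << ops[i].name << " is redundant with " << ops[0].name
                << " (key: " << key << ")\n";
    }
  }
  return 0;
}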
237 changes: 237 additions & 0 deletions paddle/fluid/framework/ir/fused_continuous_same_ops_pass.cc
@@ -0,0 +1,237 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

namespace patterns {

struct ContinuousSameOpsPattern : public PatternBase {
ContinuousSameOpsPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type);
PATTERN_DECL_NODE(first_in_var_node);
PATTERN_DECL_NODE(first_out_var_node);
PATTERN_DECL_NODE(second_out_var_node);
// declare op node's name
PATTERN_DECL_NODE(first_op_node);
PATTERN_DECL_NODE(second_op_node);
std::string op_type_;
};

ContinuousSameOpsPattern::ContinuousSameOpsPattern(
PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type)
: PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
auto* first_in_var_node =
pattern->NewNode(first_in_var_node_repr())
->assert_var_not_persistable()
->assert_is_op_input(op_type_, "X")
->AsInput()
->assert_more([&](Node* node) {
// Assert that the preceding op is not of the same type, so the match
// starts at the first op of a chain.
auto input_nodes = node->inputs;
if (input_nodes.size() != 1) return false;
if (!input_nodes.empty() && input_nodes[0]->IsOp() &&
input_nodes[0]->Op()->Type() == op_type_) {
return false;
}
return true;
});
auto* first_op_node =
pattern->NewNode(first_op_node_repr())->assert_is_op(op_type_);
auto* first_out_var_node = pattern->NewNode(first_out_var_node_repr())
->assert_var_not_persistable()
->assert_is_op_output(op_type_, "Out")
->assert_has_n_outputs(1);
first_op_node->LinksFrom({first_in_var_node}).LinksTo({first_out_var_node});
auto* second_op_node =
pattern->NewNode(second_op_node_repr())->assert_is_op(op_type_);
auto* second_out_var_node = pattern->NewNode(second_out_var_node_repr())
->assert_var_not_persistable()
->assert_is_op_output(op_type_, "Out")
->AsOutput();
second_op_node->LinksFrom({first_out_var_node})
.LinksTo({second_out_var_node});
}

} // namespace patterns

/*
Fuse continuous ops of the same type into one.
Origin graph:
input
|
|
unsqueeze2
|
|
unsqueeze2
|
|
unsqueeze2
|
|
out
After:
input
|
|
unsqueeze2
|
|
out
*/

class FusedContinuousSameOpsPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
void FusedReshapeOps(ir::Graph* graph) const;
void FusedUnsqueezeOps(ir::Graph* graph) const;

const std::string name_scope_{"fused_continuous_same_ops_pass"};
mutable int delete_op_count{0};
};

void FusedContinuousSameOpsPass::FusedReshapeOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ContinuousSameOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "reshape2");
int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle fused continuous reshape ops.";
GET_IR_NODE_FROM_SUBGRAPH(first_in_var_node, first_in_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_out_var_node, first_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
second_out_var_node, second_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_op_node, first_op_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(second_op_node, second_op_node, pattern);
auto first_node_attr_shape =
first_op_node->Op()->GetAttrIfExists<std::vector<int>>("shape");
if (first_node_attr_shape.empty()) return;
auto second_node_attr_shape =
second_op_node->Op()->GetAttrIfExists<std::vector<int>>("shape");
if (second_node_attr_shape.empty()) return;
second_op_node->Op()->RenameInput(first_out_var_node->Name(),
first_in_var_node->Name());
IR_NODE_LINK_TO(first_in_var_node, second_op_node);
GraphSafeRemoveNodes(graph, {first_op_node, first_out_var_node});
delete_counts++;
};
gpd(graph, handler);
delete_op_count += delete_counts;
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated "
<< "reshape2"
<< " ops";
}
}
void FusedContinuousSameOpsPass::FusedUnsqueezeOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ContinuousSameOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "unsqueeze2");
int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle fused continuous unsqueeze ops.";
GET_IR_NODE_FROM_SUBGRAPH(first_in_var_node, first_in_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_out_var_node, first_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
second_out_var_node, second_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_op_node, first_op_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(second_op_node, second_op_node, pattern);
auto first_node_attr_axes =
first_op_node->Op()->GetAttrIfExists<std::vector<int>>("axes");
if (first_node_attr_axes.empty()) return;
auto second_node_attr_axes =
second_op_node->Op()->GetAttrIfExists<std::vector<int>>("axes");
if (second_node_attr_axes.empty()) return;
second_op_node->Op()->RenameInput(first_out_var_node->Name(),
first_in_var_node->Name());
second_node_attr_axes.insert(second_node_attr_axes.begin(),
first_node_attr_axes.begin(),
first_node_attr_axes.end());
second_op_node->Op()->SetAttr("axes", second_node_attr_axes);
IR_NODE_LINK_TO(first_in_var_node, second_op_node);
GraphSafeRemoveNodes(graph, {first_op_node, first_out_var_node});
delete_counts++;
};
gpd(graph, handler);
delete_op_count += delete_counts;
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated "
<< "unsqueeze2"
<< " ops";
}
}
void FusedContinuousSameOpsPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int repeat_time = 0;
int total_delete_op_count = 0;
  // This pass needs to run repeatedly until no more nodes in the graph can be
  // deleted.
while (true) {
delete_op_count = 0;
FusedReshapeOps(graph);
FusedUnsqueezeOps(graph);
LOG(INFO) << "Round " << repeat_time++
<< ": delete op counts: " << delete_op_count;
total_delete_op_count += delete_op_count;
if (delete_op_count == 0) {
LOG(INFO) << "--- no nodes need to delete --- break";
break; // No node need to delete.
}
}
LOG(INFO) << "Total delete op counts: " << total_delete_op_count;
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(fused_continuous_same_ops_pass,
paddle::framework::ir::FusedContinuousSameOpsPass);

REGISTER_PASS_CAPABILITY(fused_continuous_same_ops_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"reshape2", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"unsqueeze2", 0));
3 changes: 2 additions & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -507,8 +507,9 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"delete_assign_op_pass",
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_op_clean_pass",
"delete_repeated_ops_pass",
"identity_op_clean_pass",
"fused_continuous_same_ops_pass",
"reshape_unstack_concat_fuse_pass",
"delete_op_device_pass",
"constant_folding_pass",
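In the XPU pass list, delete_repeated_ops_pass now runs before identity_op_clean_pass, and the new fusion pass follows both. Internally, the two passes touched by this commit drive their rewrites with the same fixed-point loop: zero a counter, run every rewrite once, and stop after the first round that deletes nothing. A compact sketch of that driver, with an invented rewrite callback for illustration:

#include <functional>
#include <iostream>
#include <vector>

// Runs every rewrite repeatedly until a full round removes nothing,
// mirroring the while (true) loops in both ApplyImpl methods above.
int RunToFixedPoint(const std::vector<std::function<int()>>& rewrites) {
  int total_delete_op_count = 0;
  while (true) {
    int delete_op_count = 0;
    for (const auto& rewrite : rewrites) {
      delete_op_count += rewrite();  // each returns the number of nodes removed
    }
    total_delete_op_count += delete_op_count;
    if (delete_op_count == 0) {
      break;  // no more nodes to delete
    }
  }
  return total_delete_op_count;
}

int main() {
  int budget = 3;  // pretend rewrite that finds work three times, then stops
  auto fake_rewrite = [&budget]() {
    if (budget == 0) return 0;
    --budget;
    return 1;
  };
  std::cout << "deleted " << RunToFixedPoint({fake_rewrite}) << " ops\n";
  return 0;
}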