forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transformations to optimize SR model (openvinotoolkit#5854)
* Add transformations to optimize SR model
* Add test for SplitSqueezeConcatFusion
* Add TransposeFuse tests
* Return TransposeOptimization renamed to TransposeToReshape
* Fix docstring
* Fix codestyle
* Fix build
* Fix GNA build
* Fix TransposeToReshape tests
* Fix test
* Temporarily disable cpu test
* Fix codestyle
* Fix test
* Fix test
* Enable SplitSqueezeConcatFusion
* Apply suggestions from code review (Co-authored-by: Gleb Kazantaev <[email protected]>)
* Apply review feedback
* Apply review feedback
* Update split_squeeze_concat_fusion.hpp (Co-authored-by: Gleb Kazantaev <[email protected]>)
- Loading branch information
1 parent
5e8d1cc
commit 90a93be
Showing 12 changed files with 626 additions and 116 deletions.
There are no files selected for viewing
33 changes: 33 additions & 0 deletions
33
...nsformations/include/transformations/common_optimizations/split_squeeze_concat_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
#include <memory> | ||
|
||
#include <transformations_visibility.hpp> | ||
|
||
#include <ngraph/ngraph.hpp> | ||
#include <ngraph/pass/graph_rewrite.hpp> | ||
#include "ngraph/pattern/matcher.hpp" | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API SplitSqueezeConcatFusion; | ||
|
||
} // namespace pass | ||
} // namespace ngraph | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief SplitSqueezeConcatFusion transformation replaces group of | ||
* operations: Split -> Squeeze (multiple) -> Concat to Transpose -> Reshape ops. | ||
*/ | ||
class ngraph::pass::SplitSqueezeConcatFusion : public ngraph::pass::MatcherPass { | ||
public: | ||
NGRAPH_RTTI_DECLARATION; | ||
SplitSqueezeConcatFusion(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
...src/transformations/include/transformations/common_optimizations/transpose_to_reshape.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
#include <memory> | ||
|
||
#include <transformations_visibility.hpp> | ||
|
||
#include <ngraph/ngraph.hpp> | ||
#include <ngraph/pass/graph_rewrite.hpp> | ||
#include "ngraph/pattern/matcher.hpp" | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API TransposeToReshape; | ||
|
||
} // namespace pass | ||
} // namespace ngraph | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief TransposeToReshape transformation replaces suitable Transposes with Reshape operation or optimizes them out | ||
*/ | ||
class ngraph::pass::TransposeToReshape : public ngraph::pass::MatcherPass { | ||
public: | ||
NGRAPH_RTTI_DECLARATION; | ||
TransposeToReshape(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
.../transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "itt.hpp" | ||
#include "transformations/common_optimizations/split_squeeze_concat_fusion.hpp" | ||
|
||
#include <memory> | ||
#include <vector> | ||
#include <numeric> | ||
|
||
#include <ngraph/opsets/opset7.hpp> | ||
#include <ngraph/rt_info.hpp> | ||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
|
||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SplitSqueezeConcatFusion, "SplitSqueezeConcatFusion", 0); | ||
|
||
// Registers a matcher that fuses
//     Split(axis = A) -> Squeeze(axis = A) on every Split output -> Concat(axis = C)
// into
//     Transpose(order) -> Reshape(<static Concat output shape>)
//
// The match is rejected (callback returns false, graph untouched) unless:
//   * every Concat input is a two-input Squeeze with a constant, single-element axes input;
//   * all Squeezes consume the same Split, the i-th Concat input comes from the i-th
//     Split output, and that output has no other consumers;
//   * the number of Split outputs equals the number of Concat inputs;
//   * the Split axis is a constant scalar equal to the squeezed axis (after normalization);
//   * the Split input has a static rank and the Concat output has a static shape.
ngraph::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
    MATCHER_SCOPE(SplitSqueezeConcatFusion);
    // Only Concat is matched by the pattern because the number of its inputs is not known
    // up front; the producers are inspected manually inside the callback.
    auto concat_pattern = ngraph::pattern::wrap_type<ngraph::opset7::Concat>();

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto concat = std::dynamic_pointer_cast<ngraph::opset7::Concat>(pattern_to_output.at(concat_pattern).get_node_shared_ptr());
        if (!concat) return false;

        // Maps a possibly negative axis into [0, rank); out-of-range values stay out of range
        // and are rejected by the explicit bounds checks below.
        const auto normalize_axis = [](int64_t axis, int64_t rank) -> int64_t {
            return axis < 0 ? axis + rank : axis;
        };

        NodeVector nodes_to_delete{ concat };

        int64_t axis_value = 0;  // normalized axis that is split and then squeezed
        int64_t rank_len = 0;    // static rank of the Split input
        std::shared_ptr<ngraph::opset7::Split> split;

        const auto& concat_inputs = concat->input_values();
        if (concat_inputs.empty()) return false;
        for (size_t i = 0; i < concat_inputs.size(); i++) {
            auto squeeze = std::dynamic_pointer_cast<ngraph::opset7::Squeeze>(concat_inputs[i].get_node_shared_ptr());
            // A Squeeze without an explicit axes input cannot be analysed here.
            if (!squeeze || squeeze->get_input_size() < 2) return false;

            nodes_to_delete.push_back(squeeze);

            auto split_to_check = std::dynamic_pointer_cast<ngraph::opset7::Split>(squeeze->input_value(0).get_node_shared_ptr());
            auto squeeze_axes = std::dynamic_pointer_cast<ngraph::opset7::Constant>(squeeze->input_value(1).get_node_shared_ptr());
            if (!squeeze_axes || !split_to_check) return false;

            const auto squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
            if (squeeze_axes_vec.size() != 1) return false;

            if (i == 0) {
                // The rank is needed to normalize negative axes, so resolve it as soon as
                // the Split node is known.
                const auto rank = split_to_check->input_value(0).get_partial_shape().rank();
                if (!rank.is_static()) return false;
                rank_len = rank.get_length();

                nodes_to_delete.push_back(split_to_check);
                split = split_to_check;
            } else if (split_to_check != split) {
                return false;
            }

            // Squeeze operates on a Split output, which has the same rank as the Split input.
            const int64_t squeeze_axis = normalize_axis(squeeze_axes_vec[0], rank_len);
            if (squeeze_axis < 0 || squeeze_axis >= rank_len) return false;

            if (i == 0) {
                axis_value = squeeze_axis;
            } else if (axis_value != squeeze_axis) {
                return false;
            }

            // The i-th Concat input must be fed by the i-th Split output, and that output
            // must not be used anywhere else.
            auto split_output = squeeze->input_value(0);
            if (split_output.get_target_inputs().size() != 1 ||
                split_output.get_index() != i)
                return false;
        }

        if (split->get_num_splits() != concat_inputs.size()) return false;

        auto split_axis = std::dynamic_pointer_cast<ngraph::opset7::Constant>(split->input_value(1).get_node_shared_ptr());
        if (!split_axis) return false;

        const auto axis_vec = split_axis->cast_vector<int64_t>();
        if (axis_vec.size() != 1 || axis_value != normalize_axis(axis_vec[0], rank_len))
            return false;

        // The Reshape target is taken from the Concat output shape, so it must be static;
        // get_output_shape() is not usable for dynamic shapes.
        if (!concat->get_output_partial_shape(0).is_static())
            return false;

        // Concat works on the squeezed tensors, whose rank is one less than the Split input.
        const int64_t concat_rank = rank_len - 1;
        const int64_t concat_axis = normalize_axis(concat->get_axis(), concat_rank);
        if (concat_axis < 0 || concat_axis >= concat_rank)
            return false;

        auto input = split->input_value(0);

        // Move the split axis so that it sits right before the dimension it gets merged
        // with by Concat; the following Reshape then collapses the two dimensions.
        std::vector<int64_t> order(rank_len);
        std::iota(order.begin(), order.end(), 0);
        order.erase(order.begin() + axis_value);
        order.insert(order.begin() + concat_axis, axis_value);

        auto transpose_order = ngraph::opset7::Constant::create(element::i64, { static_cast<size_t>(rank_len) }, order);
        auto transpose = register_new_node<ngraph::opset7::Transpose>(input, transpose_order);
        auto shape_after = ngraph::opset7::Constant::create(element::i64, { static_cast<size_t>(concat_rank) }, concat->get_output_shape(0));
        auto reshape = std::make_shared<ngraph::opset7::Reshape>(transpose, shape_after, false);

        reshape->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info(nodes_to_delete, { transpose, reshape });
        ngraph::replace_node(m.get_match_root(), reshape);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(concat_pattern, matcher_name);
    register_matcher(m, callback);
}
Oops, something went wrong.