Skip to content

Commit

Permalink
Add transformations to optimize SR model (openvinotoolkit#5854)
Browse files Browse the repository at this point in the history
* Add transformations to optimize SR model

* Add test for SplitSqueezeConcatFusion

* Add TransposeFuse tests

* Return TransposeOptimization renamed to TransposeToReshape

* Fix docstring

* Fix codestyle

* Fix build

* Fix GNA build

* Fix TransposeToReshape tests

* Fix test

* Temporarily disable cpu test

* Fix codestyle

* Fix test

* Fix test

* Enable SplitSqueezeConcatFusion

* Apply suggestions from code review

Co-authored-by: Gleb Kazantaev <[email protected]>

* Apply review feedback

* Apply review feedback

* Update split_squeeze_concat_fusion.hpp

Co-authored-by: Gleb Kazantaev <[email protected]>
  • Loading branch information
mvafin and GlebKazantaev authored Jun 4, 2021
1 parent 5e8d1cc commit 90a93be
Show file tree
Hide file tree
Showing 12 changed files with 626 additions and 116 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API SplitSqueezeConcatFusion;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief SplitSqueezeConcatFusion transformation replaces group of
* operations: Split -> Squeeze (multiple) -> Concat to Transpose -> Reshape ops.
*/
class ngraph::pass::SplitSqueezeConcatFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SplitSqueezeConcatFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,42 @@ namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API TransposeSinking;
class TRANSFORMATIONS_API TransposeOptimization;
class TRANSFORMATIONS_API TransposeReduction;
class TRANSFORMATIONS_API TransposeFQReduction;
class TRANSFORMATIONS_API TransposeFuse;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief TransposeOptimization transformation replaces suitable Transposes with Reshape operation or optimises them out
* @brief TransposeReduction transformation sinks Transpose through Reduce operations
*/
class ngraph::pass::TransposeOptimization : public ngraph::pass::MatcherPass {
class ngraph::pass::TransposeReduction : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
TransposeOptimization();
TransposeReduction();
};

/**
* @ingroup ie_transformation_common_api
* @brief TransposeReduction transformation sinks Transpose through Reduce operations
* @brief TransposeFQReduction transformation sinks Transpose through FakeQuantize in case it is followed by reduction or squeeze
*/
class ngraph::pass::TransposeReduction : public ngraph::pass::MatcherPass {
class ngraph::pass::TransposeFQReduction : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
TransposeReduction();
TransposeFQReduction();
};

/**
* @ingroup ie_transformation_common_api
* @brief TransposeFQReduction transformation sinks Transpose through FakeQuantize in case it is followed by reduction or squeeze
* @brief TransposeFuse transformation eliminates 2 consequtive Transposes if they result in no changes to input or fuses them
* to single Transpose if input gets changed
*/
class ngraph::pass::TransposeFQReduction : public ngraph::pass::MatcherPass {
class ngraph::pass::TransposeFuse : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
TransposeFQReduction();
TransposeFuse();
};

/**
Expand All @@ -64,6 +65,6 @@ class ngraph::pass::TransposeSinking: public ngraph::pass::GraphRewrite {
TransposeSinking() {
add_matcher<ngraph::pass::TransposeFQReduction>();
add_matcher<ngraph::pass::TransposeReduction>();
add_matcher<ngraph::pass::TransposeOptimization>();
add_matcher<ngraph::pass::TransposeFuse>();
}
};
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API TransposeToReshape;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief TransposeToReshape transformation replaces suitable Transposes with Reshape operation or optimizes them out
*/
class ngraph::pass::TransposeToReshape : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
TransposeToReshape();
};
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
#include "transformations/common_optimizations/batch_to_space_fusion.hpp"
#include "transformations/common_optimizations/dilated_convolution_converter.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/split_squeeze_concat_fusion.hpp"
#include "transformations/common_optimizations/transpose_to_reshape.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
Expand Down Expand Up @@ -91,7 +93,13 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
manager.register_pass<ngraph::pass::TransposeSinking>();

auto transpose_sinking = manager.register_pass<ngraph::pass::GraphRewrite>();
transpose_sinking->add_matcher<ngraph::pass::TransposeSinking>();
// SplitSqueezeConcatFusion should work in same GraphRewrite as TransposesSinking,
// because it replaces pattern that may contain Transposes which must be optimized before
// the transformation and it also inserts Transpose that can be optimized by TransposeSinking
transpose_sinking->add_matcher<ngraph::pass::SplitSqueezeConcatFusion>();

auto eliminations = manager.register_pass<ngraph::pass::GraphRewrite>();
eliminations->add_matcher<ngraph::pass::EliminateUnsqueezeGather>();
Expand Down Expand Up @@ -119,6 +127,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
common_fusions->add_matcher<ngraph::pass::BatchToSpaceFusion>();
common_fusions->add_matcher<ngraph::pass::DilatedConvolutionConverter>();
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
common_fusions->set_name("ngraph::pass::CommonFusions");

manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "itt.hpp"
#include "transformations/common_optimizations/split_squeeze_concat_fusion.hpp"

#include <memory>
#include <vector>
#include <numeric>

#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::SplitSqueezeConcatFusion, "SplitSqueezeConcatFusion", 0);

ngraph::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
MATCHER_SCOPE(SplitSqueezeConcatFusion);
// Detect only concat, because we don't know how many inputs will go into concat
auto concat_pattern = ngraph::pattern::wrap_type<ngraph::opset7::Concat>();

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto concat = std::dynamic_pointer_cast<ngraph::opset7::Concat>(pattern_to_output.at(concat_pattern).get_node_shared_ptr());
if (!concat) return false;

NodeVector nodes_to_delete{ concat };

int64_t axis_value = 0;
std::shared_ptr<ngraph::opset7::Split> split;

const auto& concat_inputs = concat->input_values();
if (concat_inputs.empty()) return false;
for (size_t i = 0; i < concat_inputs.size(); i++) {
auto squeeze = std::dynamic_pointer_cast<ngraph::opset7::Squeeze>(concat_inputs[i].get_node_shared_ptr());
if (!squeeze) return false;

nodes_to_delete.push_back(squeeze);

auto split_to_check = std::dynamic_pointer_cast<ngraph::opset7::Split>(squeeze->input_value(0).get_node_shared_ptr());
auto squeeze_axes = std::dynamic_pointer_cast<ngraph::opset7::Constant>(squeeze->input_value(1).get_node_shared_ptr());
if (!squeeze_axes || !split_to_check) return false;

auto squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
if (squeeze_axes_vec.size() != 1) return false;

if (i == 0) {
axis_value = squeeze_axes_vec[0];
nodes_to_delete.push_back(split_to_check);
split = split_to_check;
} else if (axis_value != squeeze_axes_vec[0] || split_to_check != split) {
return false;
}

auto split_output = squeeze->input_value(0);
if (split_output.get_target_inputs().size() != 1 ||
split_output.get_index() != i)
return false;
}

if (split->get_num_splits() != concat_inputs.size()) return false;

auto split_axis = std::dynamic_pointer_cast<ngraph::opset7::Constant>(split->input_value(1).get_node_shared_ptr());
if (!split_axis) return false;

auto axis_vec = split_axis->cast_vector<int64_t>();
if (axis_vec.size() != 1 || axis_value != axis_vec[0])
return false;

auto input = split->input_value(0);

auto concat_axis = concat->get_axis();
auto rank = input.get_partial_shape().rank();
if (!rank.is_static())
return false;
std::vector<int64_t> order(rank.get_length());
std::iota(order.begin(), order.end(), 0);
order.erase(order.begin() + axis_value);
order.insert(order.begin() + concat_axis, axis_value);

auto transpose_order = ngraph::opset7::Constant::create(element::i64, { (size_t)rank.get_length() }, order);
auto transpose = register_new_node<ngraph::opset7::Transpose>(input, transpose_order);
auto shape_after = ngraph::opset7::Constant::create(element::i64, { (size_t)rank.get_length() - 1 }, concat->get_output_shape(0));
auto reshape = std::make_shared<ngraph::opset7::Reshape>(transpose, shape_after, false);

reshape->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(nodes_to_delete, { transpose, reshape });
ngraph::replace_node(m.get_match_root(), reshape);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(concat_pattern, matcher_name);
register_matcher(m, callback);
}
Loading

0 comments on commit 90a93be

Please sign in to comment.