From f6a4dcb5ac890e8ec9e54b411073576816a30f17 Mon Sep 17 00:00:00 2001 From: Szymon Irzabek Date: Mon, 22 Nov 2021 12:39:25 +0100 Subject: [PATCH] [GNA] Add MVN decomposition (#8142) --- .../src/gna_plugin/gna_plugin.cpp | 4 + .../transformations/decompose_mvn.cpp | 265 ++++++++++++++++++ .../transformations/decompose_mvn.hpp | 24 ++ .../plugin/gna/pass_tests/decompose_mvn.cpp | 157 +++++++++++ .../transformations/gna_decompose_mvn.cpp | 253 +++++++++++++++++ 5 files changed, 703 insertions(+) create mode 100644 inference-engine/src/gna_plugin/transformations/decompose_mvn.cpp create mode 100644 inference-engine/src/gna_plugin/transformations/decompose_mvn.hpp create mode 100644 inference-engine/tests/functional/plugin/gna/pass_tests/decompose_mvn.cpp create mode 100644 inference-engine/tests/unit/gna/ngraph/transformations/gna_decompose_mvn.cpp diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index 83c2c0dbaea57e..633490ce2a78bc 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -70,6 +70,8 @@ #include "transformations/op_conversions/lstm_cell_decomposition.hpp" #include "transformations/remove_single_input_concat.hpp" #include "transformations/broadcast_const.hpp" +#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp" +#include "transformations/decompose_mvn.hpp" #include "transformations/substitute_softsign.hpp" #include @@ -687,6 +689,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) { ngraph::pass::Manager manager; manager.register_pass(); fake_quantized = ngraph::op::util::has_op_with_type(graph); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/inference-engine/src/gna_plugin/transformations/decompose_mvn.cpp b/inference-engine/src/gna_plugin/transformations/decompose_mvn.cpp new file mode 100644 index 
00000000000000..5a1f5ecccefe32 --- /dev/null +++ b/inference-engine/src/gna_plugin/transformations/decompose_mvn.cpp @@ -0,0 +1,265 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/decompose_mvn.hpp" + +#include +#include +#include +#include +#include "backend/gna_limitations.hpp" + + +using namespace GNAPluginNS; +using namespace ngraph; + +NGRAPH_RTTI_DEFINITION(DecomposeMVN, "DecomposeMVN", 0); + +struct MVNData { + size_t N; + size_t C; + size_t H; + size_t W; + size_t num_parts; + float eps; + op::MVNEpsMode eps_mode; + bool normalize_variance; + element::Type element_type; + std::string name; +}; + +template +static bool ValidateAxes(const std::shared_ptr axes_const, const size_t& mvn_shape_size) { + T axes_value; + size_t axes_vector_size; + + std::vector axes_const_vector = axes_const->cast_vector(); + IE_ASSERT(!axes_const_vector.empty()); + axes_value = axes_const_vector[0]; + axes_vector_size = axes_const_vector.size(); + + if (axes_vector_size != mvn_shape_size - 2) { + return false; + } + + // Verify supported first axes value + if (axes_value != 2 && axes_value != 2 - mvn_shape_size) + return false; + + return true; +} + +static bool GetVerifiedMVNData(const std::shared_ptr mvn, MVNData& mvn_data) { + const auto mvn_shape = mvn->get_output_shape(0); + auto mvn_shape_size = mvn_shape.size(); + + // Validate axes parameter + auto axes_const = std::dynamic_pointer_cast(mvn->input_value(1).get_node_shared_ptr()); + IE_ASSERT(axes_const); + auto element_type = axes_const->get_element_type(); + + if (!(element_type == element::Type_t::i64 ? 
ValidateAxes(axes_const, mvn_shape_size) : + ValidateAxes(axes_const, mvn_shape_size))) + return false; + + if (mvn_shape_size == 4) { + mvn_data.N = mvn_shape[0]; + mvn_data.C = mvn_shape[1]; + mvn_data.H = mvn_shape[2]; + mvn_data.W = mvn_shape[3]; + } else if (mvn_shape_size == 3) { + mvn_data.N = 1; + mvn_data.C = mvn_shape[0]; + mvn_data.H = mvn_shape[1]; + mvn_data.W = mvn_shape[2]; + } + + // Check if average must be split + mvn_data.num_parts = 1; + while (mvn_data.W / mvn_data.num_parts > GNALimitations::convFilterMaxSize) { + mvn_data.num_parts *= 2; + } + + // Abort if W is not divisible by power of 2 + if ((mvn_data.W / mvn_data.num_parts) * mvn_data.num_parts != mvn_data.W) { + return false; + } + + mvn_data.eps = mvn->get_eps(); + mvn_data.eps_mode = mvn->get_eps_mode(); + mvn_data.normalize_variance = mvn->get_normalize_variance(); + mvn_data.element_type = mvn->get_element_type(); + mvn_data.name = mvn->get_friendly_name(); + + return true; +} + +static std::shared_ptr NormalizeVariance(const std::shared_ptr mvn, const MVNData& mvn_data, + const std::shared_ptr& subtract_mean, const std::shared_ptr& avg_broadcast_const) { + // Prepare consts + auto combined_C_H = mvn_data.C * mvn_data.H; + + std::vector avg_weights(8 * mvn_data.W / mvn_data.num_parts, 1.0f / mvn_data.W); + auto avg_weights_const = opset8::Constant::create(mvn_data.element_type, Shape{8, mvn_data.W / mvn_data.num_parts, 1, 1}, avg_weights); + std::vector eps_tensor(combined_C_H * mvn_data.W, mvn_data.eps); + auto eps_tensor_const = opset8::Constant::create(mvn_data.element_type, Shape{1, combined_C_H * mvn_data.W}, eps_tensor); + std::vector minus_half(combined_C_H * mvn_data.W, -0.5f); + auto minus_half_const = opset8::Constant::create(mvn_data.element_type, Shape{1, combined_C_H * mvn_data.W}, minus_half); + + // Calculate square of the difference between input and its mean + auto squared_diff = std::make_shared(subtract_mean, subtract_mean); + 
squared_diff->set_friendly_name(mvn_data.name + "_SqrDiff"); + + // Calculate sum of the squares + auto squared_diff_reshape = std::make_shared(squared_diff, + opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false); + auto transposed_input_3 = std::make_shared(squared_diff_reshape, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_3 = std::make_shared(transposed_input_3, avg_weights_const, + Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID); + transposed_avg_conv_3->set_friendly_name(mvn_data.name + "_Avg3"); + auto avg_conv_3 = std::make_shared(transposed_avg_conv_3, opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1})); + auto reshape_avg_conv_3 = std::make_shared(avg_conv_3, + opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false); + auto transposed_input_4 = std::make_shared(reshape_avg_conv_3, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_4 = std::make_shared(transposed_input_4, + avg_broadcast_const, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID); + transposed_avg_conv_4->set_friendly_name(mvn_data.name + "_Avg4"); + auto avg_conv_4 = std::make_shared(transposed_avg_conv_4, + opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1})); + auto reshape_avg_conv_4 = std::make_shared(avg_conv_4, + opset8::Constant::create(element::i32, Shape{2}, Shape{1ull, combined_C_H * mvn_data.W}), false); + std::shared_ptr inv_stdev; + + // Create normalization part of the graph + // We ignore inside/outside epsilon position here and always use inside, to get better accuracy + // even though the built-in MVN1 to MVN6 transformation enforces outside setting + + // Add epsilon inside the square root + auto add_epsilon = 
std::make_shared(eps_tensor_const, reshape_avg_conv_4); + + // Calculate square root and inversion + auto log_var_eps = std::make_shared(add_epsilon); + log_var_eps->set_friendly_name(mvn_data.name + "_LogVarEps"); + auto log_inv_stdev = std::make_shared(log_var_eps, minus_half_const); + log_inv_stdev->set_friendly_name(mvn_data.name + "_LogInvStdev"); + inv_stdev = std::make_shared(log_inv_stdev); + inv_stdev->set_friendly_name(mvn_data.name + "_InvStdev"); + copy_runtime_info(mvn, {add_epsilon, log_var_eps, log_inv_stdev, inv_stdev}); + + auto normalized_output = std::make_shared(subtract_mean, inv_stdev); + normalized_output->set_friendly_name(mvn_data.name + "_Output"); + + copy_runtime_info(mvn, {squared_diff, squared_diff_reshape, transposed_input_3, transposed_avg_conv_3, avg_conv_3, reshape_avg_conv_3, + transposed_input_4, transposed_avg_conv_4, avg_conv_4, reshape_avg_conv_4}); + + return normalized_output; +} + +static void Decompose(const std::shared_ptr mvn, const MVNData& mvn_data) { + // Prepare data + auto combined_C_H = mvn_data.C * mvn_data.H; + + std::vector neg_avg_weights(8 * mvn_data.W / mvn_data.num_parts, -1.0f / mvn_data.W); + auto neg_avg_weights_const = opset8::Constant::create(mvn_data.element_type, Shape{8, mvn_data.W / mvn_data.num_parts, 1, 1}, neg_avg_weights); + + std::vector avg_broadcast(8 * mvn_data.W * mvn_data.num_parts, 0.0f); + for (size_t i = 0; i < mvn_data.W * mvn_data.num_parts; i++) { + avg_broadcast[i * 8] = 1.0f; + } + auto avg_broadcast_const = opset8::Constant::create(mvn_data.element_type, Shape{mvn_data.W, 8 * mvn_data.num_parts, 1, 1}, avg_broadcast); + + // Create average calculation part of the graph + // We assume C = 1 case (combined channels) + const auto input = mvn->input_value(0); + auto reshape = std::make_shared(input, + opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, 1ull, combined_C_H, mvn_data.W}), false); + auto input_4d = std::make_shared(reshape, + 
opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false); + auto input_2d = std::make_shared(reshape, + opset8::Constant::create(element::i32, Shape{2}, Shape{1ull, combined_C_H * mvn_data.W}), false); + auto transposed_input_1 = std::make_shared(input_4d, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_1 = std::make_shared(transposed_input_1, neg_avg_weights_const, + Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID); + transposed_avg_conv_1->set_friendly_name(mvn_data.name + "_Avg1"); + auto avg_conv_1 = std::make_shared(transposed_avg_conv_1, opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1})); + auto reshape_avg_conv_1 = std::make_shared(avg_conv_1, + opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false); + auto transposed_input_2 = std::make_shared(reshape_avg_conv_1, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_2 = std::make_shared(transposed_input_2, + avg_broadcast_const, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID); + transposed_avg_conv_2->set_friendly_name(mvn_data.name + "_Avg2"); + auto avg_conv_2 = std::make_shared(transposed_avg_conv_2, + opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1})); + auto avg_conv_2_2d = std::make_shared(avg_conv_2, + opset8::Constant::create(element::i32, Shape{2}, Shape{1ull, combined_C_H * mvn_data.W}), false); + auto subtract_mean = std::make_shared(input_2d, avg_conv_2_2d); + subtract_mean->set_friendly_name(mvn_data.name + "_SubMean"); + + std::shared_ptr mvn_output, pre_output = subtract_mean; + + // Normalize variance if required + if (mvn_data.normalize_variance) { + pre_output = NormalizeVariance(mvn, mvn_data, subtract_mean, avg_broadcast_const); + } + + 
// Reshape (combined channels) back to get the final output + if (mvn->get_output_shape(0).size() == 3) { + mvn_output = std::make_shared(pre_output, + opset8::Constant::create(element::i32, Shape{3}, {mvn_data.C, mvn_data.H, mvn_data.W}), false); + } else { + mvn_output = std::make_shared(pre_output, + opset8::Constant::create(element::i32, Shape{4}, {mvn_data.N, mvn_data.C, mvn_data.H, mvn_data.W}), false); + } + + copy_runtime_info(mvn, {reshape, input_4d, input_2d, transposed_input_1, transposed_avg_conv_1, avg_conv_1, reshape_avg_conv_1, + transposed_input_2, transposed_avg_conv_2, avg_conv_2, avg_conv_2_2d, subtract_mean, mvn_output}); + + // We need to retain the MVN layer name, so its output can be used as a network result + replace_node(mvn, mvn_output); + mvn_output->set_friendly_name(mvn_data.name); +} + +static bool Convert(std::shared_ptr mvn_node) { + const auto mvn = std::dynamic_pointer_cast(mvn_node); + MVNData mvn_data; + + if (!GetVerifiedMVNData(mvn, mvn_data)) + return false; + + Decompose(mvn, mvn_data); + + return true; +} + +static std::function)> verify_rank_batch() { + return [=](Output output) -> bool { + // Only rank 3 and 4 and batch 1 are supported for now + auto rank = output.get_partial_shape().rank(); + if (rank != 3 && rank != 4) + return false; + + auto batch = (rank == 3 ? 
1 : output.get_partial_shape()[0]); + if (batch != 1) + return false; + + return true; + }; +} + +DecomposeMVN::DecomposeMVN() { + MATCHER_SCOPE(DecomposeMVN); + + auto axes = pattern::wrap_type(); + auto mvn = pattern::wrap_type({pattern::any_input(), axes}, verify_rank_batch()); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + return Convert(pattern_map.at(mvn).get_node_shared_ptr()); + }; + + auto m = std::make_shared(mvn, matcher_name); + this->register_matcher(m, callback); +} diff --git a/inference-engine/src/gna_plugin/transformations/decompose_mvn.hpp b/inference-engine/src/gna_plugin/transformations/decompose_mvn.hpp new file mode 100644 index 00000000000000..455503a8822880 --- /dev/null +++ b/inference-engine/src/gna_plugin/transformations/decompose_mvn.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace GNAPluginNS { + +/** + * @brief Decompose MVN operation + * See official OpenVINO documentation for the MVN formula + * implemented partially by this decomposition: + * https://docs.openvino.ai/latest/openvino_docs_ops_normalization_MVN_6.html + * + */ +class DecomposeMVN : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + DecomposeMVN(); +}; + +} // namespace GNAPluginNS diff --git a/inference-engine/tests/functional/plugin/gna/pass_tests/decompose_mvn.cpp b/inference-engine/tests/functional/plugin/gna/pass_tests/decompose_mvn.cpp new file mode 100644 index 00000000000000..49e03dc74a9511 --- /dev/null +++ b/inference-engine/tests/functional/plugin/gna/pass_tests/decompose_mvn.cpp @@ -0,0 +1,157 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "common_test_utils/test_common.hpp" +#include +#include +#include +#include +#include +#include + +#include "transformations/init_node_info.hpp" +#include 
"ngraph_functions/builders.hpp" +#include "shared_test_classes/base/layer_test_utils.hpp" + + +using namespace ngraph; +using namespace opset8; + + +namespace LayerTestsDefinitions { + +typedef std::tuple< + bool, // Normalize variance + float, // Epsilon + op::MVNEpsMode, // Epsilon mode + bool, // Across channels + bool // MVN version, true = v6, false = v1 +> mvnSpecificParams; + +typedef std::tuple< + mvnSpecificParams, // MVN parameters + InferenceEngine::Precision, // Network Precision + std::string, // Target Device + std::map, // Configuration + InferenceEngine::SizeVector // Input shapes +> decomposeMVNParams; + +class DecomposeMVNTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(testing::TestParamInfo obj) { + mvnSpecificParams mvnParams; + InferenceEngine::Precision netPrecision; + std::string targetDevice; + std::map configuration; + InferenceEngine::SizeVector inputShape; + std::tie(mvnParams, netPrecision, targetDevice, configuration, inputShape) = obj.param; + float eps; + op::MVNEpsMode epsMode; + bool normalizeVariance, acrossChannels, mvnVersion6; + std::tie(normalizeVariance, eps, epsMode, acrossChannels, mvnVersion6) = mvnParams; + + std::ostringstream result; + result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_"; + result << "NV=" << normalizeVariance << "_"; + result << "eps=" << eps << "_"; + result << "mode=" << static_cast(epsMode) << "_"; + result << "AC=" << acrossChannels << "_"; + result << "version=" << mvnVersion6 << "_"; + result << "netPRC=" << netPrecision.name() << "_"; + result << "targetDevice=" << targetDevice << "_"; + for (auto const& configItem : configuration) { + result << "_configItem=" << configItem.first << "_" << configItem.second; + } + return result.str(); + } + +protected: + void SetUp() override { + threshold = 0.2f; + mvnSpecificParams mvnParams; + InferenceEngine::Precision netPrecision; + 
InferenceEngine::SizeVector inputShape; + std::tie(mvnParams, netPrecision, targetDevice, configuration, inputShape) = this->GetParam(); + float eps; + op::MVNEpsMode epsMode; + bool normalizeVariance, acrossChannels, mvnVersion6; + std::tie(normalizeVariance, eps, epsMode, acrossChannels, mvnVersion6) = mvnParams; + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto input = builder::makeParams(ngPrc, {inputShape}); + InferenceEngine::SizeVector axes(inputShape.size() - 2); + std::iota(axes.begin(), axes.end(), 2); + std::shared_ptr mvn; + + if (mvnVersion6) { + const auto axesConst = std::make_shared(element::i64, Shape{axes.size()}, axes); + mvn = std::make_shared(input[0], axesConst, normalizeVariance, eps, epsMode); + } else { + mvn = std::make_shared(input[0], acrossChannels, normalizeVariance); + } + + auto result = std::make_shared(mvn); + function = std::make_shared(ResultVector{result}, ParameterVector{input}); + } +}; + +TEST_P(DecomposeMVNTest, CompareWithRefs) { + Run(); +} + +const std::vector netPrecisions = { + InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16 +}; + +const std::vector> configs = { + { + {"GNA_DEVICE_MODE", "GNA_SW_FP32"}, + {"GNA_SCALE_FACTOR_0", "1"} + } +}; + +const std::vector> inputs = {{1, 1, 5, 300}, {1, 6, 256}}; +const std::vector normalizeVariance = {true}; +const std::vector eps = {1.0e-09f}; +const std::vector epsMode = {op::MVNEpsMode::INSIDE_SQRT}; +const std::vector accrossChannels = {false}; + +const auto mvnParams_v6 = ::testing::Combine( + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(eps), + ::testing::ValuesIn(epsMode), + ::testing::Values(false), + ::testing::Values(true) +); + +const auto mvnParams_v1 = ::testing::Combine( + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(eps), + ::testing::ValuesIn(epsMode), + ::testing::ValuesIn(accrossChannels), + ::testing::Values(false) +); + +INSTANTIATE_TEST_SUITE_P(smoke_DecomposeMVN_v6, 
DecomposeMVNTest, + ::testing::Combine( + mvnParams_v6, + ::testing::ValuesIn(netPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GNA), + ::testing::ValuesIn(configs), + ::testing::ValuesIn(inputs)), + DecomposeMVNTest::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_DecomposeMVN_v1, DecomposeMVNTest, + ::testing::Combine( + mvnParams_v1, + ::testing::ValuesIn(netPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GNA), + ::testing::ValuesIn(configs), + ::testing::ValuesIn(inputs)), + DecomposeMVNTest::getTestCaseName); + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_decompose_mvn.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_decompose_mvn.cpp new file mode 100644 index 00000000000000..74de547532eff6 --- /dev/null +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_decompose_mvn.cpp @@ -0,0 +1,253 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include + +#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp" +#include "transformations/decompose_mvn.hpp" +#include "common_test_utils/ngraph_test_utils.hpp" +#include +#include +#include "backend/gna_limitations.hpp" + +namespace decomposeMVN { + +typedef std::tuple< + ngraph::Shape, // Input shape + bool, // Normalize variance + float, // Epsilon + ngraph::op::MVNEpsMode, // Epsilon mode + InferenceEngine::SizeVector, // Axes tensor + bool, // Across channels + bool // MVN version, true = v6, false = v1 +> decomposeMVNParams; + +struct MVNParams { + size_t N; + size_t C; + size_t H; + size_t W; + size_t num_parts; + float eps; + ngraph::op::MVNEpsMode eps_mode; + bool normalize_variance; +}; + +static std::shared_ptr NormalizeVariance(const MVNParams& mvn_data, const std::shared_ptr& subtract_mean, + const std::shared_ptr& avg_broadcast_const) { + // Prepare consts + auto combined_C_H = mvn_data.C * mvn_data.H; + + std::vector 
avg_weights(8 * mvn_data.W / mvn_data.num_parts, 1.0f / mvn_data.W); + auto avg_weights_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{8, mvn_data.W / mvn_data.num_parts, 1, 1}, avg_weights); + std::vector eps_tensor(combined_C_H * mvn_data.W, mvn_data.eps); + auto eps_tensor_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1, combined_C_H * mvn_data.W}, eps_tensor); + std::vector minus_half(combined_C_H * mvn_data.W, -0.5f); + auto minus_half_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1, combined_C_H * mvn_data.W}, minus_half); + + // Calculate square of the difference between input and its mean + auto squared_diff = std::make_shared(subtract_mean, subtract_mean); + squared_diff->set_friendly_name("MvnSqrDiff"); + + // Calculate sum of the squares + auto squared_diff_reshape = std::make_shared(squared_diff, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, + ngraph::Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false); + auto transposed_input_3 = std::make_shared(squared_diff_reshape, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_3 = std::make_shared(transposed_input_3, avg_weights_const, + ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{1, 1}, ngraph::op::PadType::VALID); + transposed_avg_conv_3->set_friendly_name("MvnAvg3"); + auto avg_conv_3 = std::make_shared(transposed_avg_conv_3, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1})); + auto reshape_avg_conv_3 = std::make_shared(avg_conv_3, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, + ngraph::Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false); + auto transposed_input_4 = std::make_shared(reshape_avg_conv_3, + 
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_4 = std::make_shared(transposed_input_4, + avg_broadcast_const, ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, + ngraph::Strides{1, 1}, ngraph::op::PadType::VALID); + transposed_avg_conv_4->set_friendly_name("MvnAvg4"); + auto avg_conv_4 = std::make_shared(transposed_avg_conv_4, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1})); + auto reshape_avg_conv_4 = std::make_shared(avg_conv_4, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{2}, ngraph::Shape{1ull, combined_C_H * mvn_data.W}), false); + std::shared_ptr inv_stdev; + + // Create normalization part of the graph + // We ignore inside/outside epsilon position here and always use inside, to get better accuracy + // even though the built-in MVN1 to MVN6 transformation enforces outside setting + + // Add epsilon inside the square root + auto add_epsilon = std::make_shared(eps_tensor_const, reshape_avg_conv_4); + + // Calculate square root and inversion + auto log_var_eps = std::make_shared(add_epsilon); + log_var_eps->set_friendly_name("MvnLogVarEps"); + auto log_inv_stdev = std::make_shared(log_var_eps, minus_half_const); + log_inv_stdev->set_friendly_name("MvnLogInvStdev"); + inv_stdev = std::make_shared(log_inv_stdev); + inv_stdev->set_friendly_name("MvnInvStdev"); + + auto normalized_output = std::make_shared(subtract_mean, inv_stdev); + normalized_output->set_friendly_name("MvnOutput"); + + return normalized_output; +} + +static std::shared_ptr Decompose(const std::shared_ptr input_node, const MVNParams& mvn_data) { + // Prepare data + auto combined_C_H = mvn_data.C * mvn_data.H; + + std::vector neg_avg_weights(8 * mvn_data.W / mvn_data.num_parts, -1.0f / mvn_data.W); + auto neg_avg_weights_const = ngraph::opset8::Constant::create(ngraph::element::i32, + ngraph::Shape{8, mvn_data.W / 
mvn_data.num_parts, 1, 1}, neg_avg_weights); + + std::vector avg_broadcast(8 * mvn_data.W * mvn_data.num_parts, 0.0f); + for (size_t i = 0; i < mvn_data.W * mvn_data.num_parts; i++) { + avg_broadcast[i * 8] = 1.0f; + } + auto avg_broadcast_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{mvn_data.W, 8 * mvn_data.num_parts, 1, 1}, avg_broadcast); + + // Create average calculation part of the graph + // We assume C = 1 case (combined channels) + auto reshape = std::make_shared(input_node, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, + ngraph::Shape{mvn_data.N, 1ull, combined_C_H, mvn_data.W}), false); + auto input_4d = std::make_shared(reshape, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, + ngraph::Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false); + auto input_2d = std::make_shared(reshape, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{2}, + ngraph::Shape{1ull, combined_C_H * mvn_data.W}), false); + auto transposed_input_1 = std::make_shared(input_4d, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_1 = std::make_shared(transposed_input_1, neg_avg_weights_const, + ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{1, 1}, ngraph::op::PadType::VALID); + transposed_avg_conv_1->set_friendly_name("MvnAvg1"); + auto avg_conv_1 = std::make_shared(transposed_avg_conv_1, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1})); + auto reshape_avg_conv_1 = std::make_shared(avg_conv_1, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, + ngraph::Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false); + auto transposed_input_2 = std::make_shared(reshape_avg_conv_1, + ngraph::opset8::Constant::create(ngraph::element::i32, 
ngraph::Shape{4}, {0, 3, 1, 2})); + auto transposed_avg_conv_2 = std::make_shared(transposed_input_2, + avg_broadcast_const, ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, + ngraph::Strides{1, 1}, ngraph::op::PadType::VALID); + transposed_avg_conv_2->set_friendly_name("MvnAvg2"); + auto avg_conv_2 = std::make_shared(transposed_avg_conv_2, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1})); + auto avg_conv_2_2d = std::make_shared(avg_conv_2, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{2}, ngraph::Shape{1ull, combined_C_H * mvn_data.W}), false); + auto subtract_mean = std::make_shared(input_2d, avg_conv_2_2d); + subtract_mean->set_friendly_name("MvnSubMean"); + + std::shared_ptr mvn_output, pre_output = subtract_mean; + + // Normalize variance if required + if (mvn_data.normalize_variance) { + pre_output = NormalizeVariance(mvn_data, subtract_mean, avg_broadcast_const); + } + + // Reshape (combined channels) back to get the final output + if (input_node->get_output_shape(0).size() == 3) { + mvn_output = std::make_shared(pre_output, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{3}, {mvn_data.C, mvn_data.H, mvn_data.W}), false); + } else { + mvn_output = std::make_shared(pre_output, + ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {mvn_data.N, mvn_data.C, mvn_data.H, mvn_data.W}), false); + } + + return std::make_shared(mvn_output); +} + +std::shared_ptr getReferenceFunction(const ngraph::Shape& input_shape, const bool& normalize_variance, + const float& eps, const ngraph::op::MVNEpsMode& eps_mode, const InferenceEngine::SizeVector& axes) { + MVNParams mvn_data; + auto mvn_shape_size = input_shape.size(); + + if (mvn_shape_size == 4) { + mvn_data.N = input_shape[0]; + mvn_data.C = input_shape[1]; + mvn_data.H = input_shape[2]; + mvn_data.W = input_shape[3]; + } else if (mvn_shape_size == 3) { + mvn_data.N = 1; + 
mvn_data.C = input_shape[0]; + mvn_data.H = input_shape[1]; + mvn_data.W = input_shape[2]; + } + + mvn_data.eps = eps; + mvn_data.eps_mode = eps_mode; + mvn_data.normalize_variance = normalize_variance; + mvn_data.num_parts = 1; + + while (mvn_data.W / mvn_data.num_parts > GNAPluginNS::GNALimitations::convFilterMaxSize) { + mvn_data.num_parts *= 2; + } + + // Create decomposed reference function + auto input_params = std::make_shared(ngraph::element::i32, input_shape); + std::shared_ptr result = Decompose(input_params, mvn_data); + + return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params}); +} + +std::shared_ptr getInitialFunction(const ngraph::Shape& input_shape, const bool& normalize_variance, + const float& eps, const ngraph::op::MVNEpsMode& eps_mode, const InferenceEngine::SizeVector& axes, + const bool& across_channels, const bool& mvn_version_6) { + auto input_params = std::make_shared(ngraph::element::i32, input_shape); + std::shared_ptr mvn; + + if (mvn_version_6) { + const auto axesConst = std::make_shared(ngraph::element::i32, ngraph::Shape{axes.size()}, axes); + mvn = std::make_shared(input_params, axesConst, normalize_variance, eps, eps_mode); + } else { + mvn = std::make_shared(input_params, across_channels, normalize_variance, eps); + } + + auto result = std::make_shared(mvn); + + return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params}); +} + +} // namespace decomposeMVN + +// --------------------------------------------------------------------------------------------------------------------- + +namespace { + + void execute_test(std::shared_ptr function, std::shared_ptr reference_function) { + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(function); + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const 
FunctionsComparator::Result result = func_comparator(function, reference_function); + ASSERT_TRUE(result.valid); +} + +} // namespace + +TEST(TransformationTests, DecomposeMVNTest) { + for (auto mvn_version_6 : {true, false}) { + for (auto normalize_variance : {true, false}) { + execute_test(decomposeMVN::getInitialFunction(ngraph::Shape{1, 1, 5, 300}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT, + InferenceEngine::SizeVector{2, 1}, false, mvn_version_6), + decomposeMVN::getReferenceFunction(ngraph::Shape{1, 1, 5, 300}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT, + InferenceEngine::SizeVector{2, 1})); + execute_test(decomposeMVN::getInitialFunction(ngraph::Shape{1, 6, 256}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT, + InferenceEngine::SizeVector{2}, false, mvn_version_6), + decomposeMVN::getReferenceFunction(ngraph::Shape{1, 6, 256}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT, + InferenceEngine::SizeVector{2})); + } + } +} +