Skip to content

Commit

Permalink
[GNA] Allow 2d reshape of the first diagonal layer (openvinotoolkit#6115
Browse files Browse the repository at this point in the history
)
  • Loading branch information
elilobanova authored Jun 16, 2021
1 parent 2c775d4 commit 5c55d39
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 12 deletions.
7 changes: 0 additions & 7 deletions inference-engine/src/gna_plugin/gna_groups.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ inline bool HasTo2DReshapeData(InferenceEngine::CNNLayerPtr layer) {
if (!GNAPluginNS::LayerInfo(layer).isSyntheticScaleShift())
return false;

// Don't reshape the first dnn layer since it breaks groups recognition
auto prevLayer = InferenceEngine::CNNNetPrevLayerSkipCertain(layer, 0, [](InferenceEngine::CNNLayerPtr ptr) {
return LayerInfo(ptr).isNonValuesChangable();
});
IE_ASSERT(prevLayer != nullptr);
if (LayerInfo(prevLayer).isInput()) return false;

// Don't reshape diagonallayers with bias connection
return !GNAPluginNS::LayerInfo(getCreatorLayer(layer->insData.front().lock()).lock()).has32BOutput();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ static void insertDiagonalLayerBetween(InferenceEngine::CNNLayerPtr prevLayer,
return LayerInfo(ptr).isNonValuesChangable();
});
IE_ASSERT(inputLayer != nullptr);
size_t weightsSize = (LayerInfo(prevLayer).has32BOutput() || LayerInfo(inputLayer).isInput()) ?
nextLayer->outData[0]->getDims().back() :
Get2DReshapedData(nextLayer->outData[0], 8)->getDims()[1];
size_t weightsSize = LayerInfo(prevLayer).has32BOutput() ? nextLayer->outData[0]->getDims().back() :
Get2DReshapedData(nextLayer->outData[0], 8)->getDims()[1];
std::vector<float> weightsValues(weightsSize, fillValue);
IE_ASSERT(diagLayer != nullptr);
diagLayer->_weights = make_shared_blob<float>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,7 @@ INSTANTIATE_TEST_CASE_P(smoke_ConvertMatmulToPointwiseConvTest, ConvertMatmulToP
::testing::ValuesIn(inputShape)),
ConvertMatmulToPointwiseConv::getTestCaseName);

// Issue 55662
INSTANTIATE_TEST_CASE_P(DISABLED_smoke_ConvertMatmulToPointwiseConvTest, ConvertMatmulToPointwiseConvWithFq,
INSTANTIATE_TEST_CASE_P(smoke_ConvertMatmulToPointwiseConvTest, ConvertMatmulToPointwiseConvWithFq,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
Expand Down

0 comments on commit 5c55d39

Please sign in to comment.