[CPU] ScatterUpdate: dynamic shapes support (openvinotoolkit#8581)
v-Golubev authored Nov 25, 2021
1 parent 629cf83 commit 66b75e6
Showing 5 changed files with 556 additions and 52 deletions.
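
This commit teaches the CPU plugin's ScatterUpdate node to run with shapes that are only known at execute time: operation support no longer rejects dynamic shapes, shape validation tolerates undefined dimensions, and execution is routed through the plugin's dynamic path. A minimal, hypothetical skeleton of the hooks involved (stand-in types; the real base class is MKLDNNNode, and the real overrides appear in the diffs below):

// Hypothetical sketch of the dynamic-shape hooks this commit implements.
struct Node {
    virtual ~Node() = default;
    virtual bool needPrepareParams() const { return true; }
    virtual void execute() = 0;              // shape-agnostic kernel
    virtual void executeDynamicImpl() = 0;   // dynamic-shape entry point
};

struct ScatterUpdateNode : Node {
    // The kernel reads the actual input shapes from memory on every call,
    // so nothing has to be recomputed when shapes change between infers.
    bool needPrepareParams() const override { return false; }
    void execute() override { /* scatter kernel */ }
    void executeDynamicImpl() override { execute(); }
};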
@@ -21,14 +21,10 @@ using namespace InferenceEngine;

bool MKLDNNScatterUpdateNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
-        if (isDynamicNgraphNode(op)) {
-            errorMessage = "Doesn't support op with dynamic shapes";
-            return false;
-        }
-        const auto scatterElemUpd = std::dynamic_pointer_cast<const ngraph::opset3::ScatterElementsUpdate>(op);
-        const auto scatterUpd = std::dynamic_pointer_cast<const ngraph::opset3::ScatterUpdate>(op);
-        const auto scatterNdUpd = std::dynamic_pointer_cast<const ngraph::opset4::ScatterNDUpdate>(op);
-        if (scatterElemUpd == nullptr && scatterUpd == nullptr && scatterNdUpd == nullptr) {
+        auto scatterElemUpd = ngraph::as_type_ptr<const ngraph::opset3::ScatterElementsUpdate>(op);
+        auto scatterUpd = ngraph::as_type_ptr<const ngraph::opset3::ScatterUpdate>(op);
+        auto scatterNdUpd = ngraph::as_type_ptr<const ngraph::opset4::ScatterNDUpdate>(op);
+        if (!scatterElemUpd && !scatterUpd && !scatterNdUpd) {
const std::string opType = op->get_type_name();
errorMessage = "Only opset" + opType == "ScatterNDUpdate" ? "4 " : "3 " + opType + " operation is supported";
return false;
@@ -81,10 +77,10 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

-    auto srcDataDim = getInputShapeAtPort(DATA_ID).getStaticDims();
-    auto indicesDim = getInputShapeAtPort(INDICES_ID).getStaticDims();
-    auto updateDim = getInputShapeAtPort(UPDATE_ID).getStaticDims();
-    auto dstDataDim = getOutputShapeAtPort(0).getStaticDims();
+    const auto& srcDataDim = getInputShapeAtPort(DATA_ID).getDims();
+    const auto& indicesDim = getInputShapeAtPort(INDICES_ID).getDims();
+    const auto& updateDim = getInputShapeAtPort(UPDATE_ID).getDims();
+    const auto& dstDataDim = getOutputShapeAtPort(0).getDims();

size_t srcRank = srcDataDim.size();
size_t indicesRank = indicesDim.size();
@@ -96,9 +92,9 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
IE_THROW() << errorPrefix << " should have same rank for input and output tensor";
} else {
for (size_t r = 0; r < srcRank; r++) {
-            if (srcDataDim[r] != dstDataDim[r]) {
+            if (!dimsEqualWeak(srcDataDim[r], dstDataDim[r])) {
                IE_THROW() << errorPrefix << " should have same shape for input and output tensor. The input shape is "
-                    << srcDataDim[r] << ", while output shape is " << dstDataDim[r] << " for " << r << "th dimension";
+                           << srcDataDim[r] << ", while output shape is " << dstDataDim[r] << " for " << r << "th dimension";
}
}
}
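
dimsEqualWeak is what lets the checks above pass while a dimension is still unknown at graph-compile time. A minimal sketch of the intended semantics, assuming undefined dimensions are encoded as a sentinel value (illustrative only; the plugin's real helper and Shape::UNDEFINED_DIM live in its shape utilities):

#include <cstddef>
#include <limits>

// Illustrative stand-in: dynamic dims are encoded as a sentinel value.
constexpr size_t UNDEFINED_DIM = std::numeric_limits<size_t>::max();

// Weak equality: dimensions match if either side is still undefined,
// or if both are defined and equal.
inline bool dimsEqualWeak(size_t lhs, size_t rhs) {
    return lhs == UNDEFINED_DIM || rhs == UNDEFINED_DIM || lhs == rhs;
}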
@@ -111,26 +107,28 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
break;
}
case ScatterUpdateMode::ScatterNDUpdate: {
-            size_t k = indicesDim[indicesRank - 1];
-            if (k > srcRank) {
-                IE_THROW() << errorPrefix << "' do not have an correct indices' last dimension value, "
-                           << "which should be smaller than or equal to input tensor rank";
-            }
+            if (indicesDim[indicesRank - 1] != Shape::UNDEFINED_DIM) {
+                size_t k = indicesDim[indicesRank - 1];
+                if (k > srcRank) {
+                    IE_THROW() << errorPrefix << "' do not have an correct indices' last dimension value, "
+                               << "which should be smaller than or equal to input tensor rank";
+                }

-            SizeVector expectUpdateShape = {};
-            size_t tupleRank = indicesRank - 1;
-            for (size_t ri = 0; ri < tupleRank; ri++) {
-                expectUpdateShape.push_back(indicesDim[ri]);
-            }
-            for (size_t rd = k; rd < srcRank; rd++) {
-                expectUpdateShape.push_back(srcDataDim[rd]);
-            }
-            if (expectUpdateShape.size() != updateRank) {
-                IE_THROW() << errorPrefix << " do not have matched tensor rank relationship for input, indices and update";
-            }
-            for (size_t ru = 0; ru < updateRank; ru++) {
-                if (updateDim[ru] != expectUpdateShape[ru]) {
-                    IE_THROW() << errorPrefix << " do not have matched tensor shape relationship for input, indices and update";
-                }
-            }
+                SizeVector expectUpdateShape = {};
+                size_t tupleRank = indicesRank - 1;
+                for (size_t ri = 0; ri < tupleRank; ri++) {
+                    expectUpdateShape.push_back(indicesDim[ri]);
+                }
+                for (size_t rd = k; rd < srcRank; rd++) {
+                    expectUpdateShape.push_back(srcDataDim[rd]);
+                }
+                if (expectUpdateShape.size() != updateRank) {
+                    IE_THROW() << errorPrefix << " do not have matched tensor rank relationship for input, indices and update";
+                }
+                for (size_t ru = 0; ru < updateRank; ru++) {
+                    if (!dimsEqualWeak(updateDim[ru], expectUpdateShape[ru])) {
+                        IE_THROW() << errorPrefix << " do not have matched tensor shape relationship for input, indices and update";
+                    }
+                }
+            }
break;
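
The expected updates shape assembled above is indices.shape[:-1] concatenated with data.shape[k:], where k is the last dimension of indices. A standalone sketch with hypothetical names, which agrees with the test data added later in this commit (data [10, 9, 9, 11] and indices [2, 3] give k = 3 and an expected updates shape of [2, 11]):

#include <cstddef>
#include <vector>

// Expected updates shape for ScatterNDUpdate, mirroring the loops above:
// indices.shape[:-1] ++ data.shape[k:], with k = indices.shape.back().
std::vector<size_t> expectedUpdateShape(const std::vector<size_t>& dataDims,
                                        const std::vector<size_t>& indicesDims) {
    const size_t k = indicesDims.back();
    std::vector<size_t> expected(indicesDims.begin(), indicesDims.end() - 1);
    expected.insert(expected.end(), dataDims.begin() + k, dataDims.end());
    return expected;
}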
@@ -140,7 +138,7 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
IE_THROW() << errorPrefix << " do not have the same tensor rank for input, indices and update";
}
for (size_t ri = 0; ri < indicesRank; ri++) {
-                if (indicesDim[ri] != updateDim[ri]) {
+                if (!dimsEqualWeak(indicesDim[ri], updateDim[ri])) {
IE_THROW() << errorPrefix << " do not have the same tensor shape for indices and update";
}
}
@@ -178,8 +176,8 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
dataPrec = getOriginalInputPrecisionAtPort(DATA_ID);
dataSize = dataPrec.size();

-    bool canBeInplace = getParentEdgeAt(DATA_ID)->getParent()->getChildEdges().size() == 1 &&
-            !getParentEdgeAt(DATA_ID)->getParent()->isConstant();
+    bool canBeInplace = !isDynamicNode() && getParentEdgeAt(DATA_ID)->getParent()->getChildEdges().size() == 1 &&
+                        !getParentEdgeAt(DATA_ID)->getParent()->isConstant();

NodeConfig config;
config.dynBatchSupport = false;
@@ -226,6 +224,18 @@ void MKLDNNScatterUpdateNode::createPrimitive() {
IE_THROW() << errorPrefix << " did not allocate update memory";
if (getSelectedPrimitiveDescriptor() == nullptr)
IE_THROW() << errorPrefix << " did not set preferable primitive descriptor";

+    if (inputShapesDefined()) {
+        updateLastInputDims();
+    }
}
+
+bool MKLDNNScatterUpdateNode::needPrepareParams() const {
+    return false;
+}
+
+void MKLDNNScatterUpdateNode::executeDynamicImpl(mkldnn::stream strm) {
+    return execute(strm);
+}

int64_t MKLDNNScatterUpdateNode::getIndicesValue(uint8_t *indices, size_t offset) {
@@ -245,7 +255,7 @@ int64_t MKLDNNScatterUpdateNode::getIndicesValue(uint8_t *indices, size_t offset
// shapeND: n c d h w
// blockND: ncdhw cdhw dhw hw w 1
// index : 0 1 2 3 4 5
-static std::vector<size_t> getBlockND(const SizeVector& shape) {
+static std::vector<size_t> getBlockND(const VectorDims& shape) {
size_t shapeRank = shape.size();
std::vector<size_t> blockND(shapeRank + 1, 1);
for (int i = shapeRank - 1; i >= 0; i--) {
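
getBlockND builds the suffix products described in the comment above: blockND[i] is the number of elements in one slice starting at axis i. A standalone copy with a quick check (hypothetical name blockNDRef):

#include <cstdio>
#include <vector>

// Suffix products of a shape: blockND[i] = shape[i] * shape[i + 1] * ... * 1.
static std::vector<size_t> blockNDRef(const std::vector<size_t>& shape) {
    std::vector<size_t> blockND(shape.size() + 1, 1);
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; i--)
        blockND[i] = shape[i] * blockND[i + 1];
    return blockND;
}

int main() {
    for (size_t b : blockNDRef({2, 3, 4}))
        printf("%zu ", b);  // prints: 24 12 4 1
    return 0;
}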
@@ -265,8 +275,8 @@ void MKLDNNScatterUpdateNode::execute(mkldnn::stream strm) {
uint8_t *indicesPtr = reinterpret_cast<uint8_t*>(indicesMemPtr->GetPtr());
uint8_t *updatePtr = reinterpret_cast<uint8_t*>(updateMemPtr->GetPtr());

-    SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
-    SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getMemory().getStaticDims();
+    const auto& srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
+    const auto& indicesDim = getParentEdgeAt(INDICES_ID)->getMemory().getStaticDims();
size_t srcRank = srcDataDim.size();
int axis = 0;
if (axisRelaxed) {
@@ -362,9 +372,9 @@ void MKLDNNScatterUpdateNode::execute(mkldnn::stream strm) {
// and indices tensor of shape [i_0, i_1, ..., i_k].
// Updates tensor shape should be [d_0, d_1, ... d_(axis - 1), i_0, i_1, ..., i_k, d_(axis + 1), ..., d_n].
void MKLDNNScatterUpdateNode::scatterUpdate(uint8_t *indices, uint8_t *update, int axis, uint8_t *dstData) {
-    SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
-    SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getMemory().getStaticDims();
-    SizeVector updateDim = getParentEdgeAt(UPDATE_ID)->getMemory().getStaticDims();
+    const auto& srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
+    const auto& indicesDim = getParentEdgeAt(INDICES_ID)->getMemory().getStaticDims();
+    const auto& updateDim = getParentEdgeAt(UPDATE_ID)->getMemory().getStaticDims();
size_t indicesRank = indicesDim.size();

std::vector<size_t> srcBlockND = getBlockND(srcDataDim);
@@ -395,8 +405,8 @@ void MKLDNNScatterUpdateNode::scatterUpdate(uint8_t *indices, uint8_t *update, i
// k is indices.shape[-1] and should not be greater than rank of input, q is rank of indicies.
// updates is a (q-1)-dimension tensor of replacement-slice-values
void MKLDNNScatterUpdateNode::scatterNDUpdate(uint8_t *indices, uint8_t *update, uint8_t *dstData) {
-    SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
-    SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getMemory().getStaticDims();
+    const auto& srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
+    const auto& indicesDim = getParentEdgeAt(INDICES_ID)->getMemory().getStaticDims();
size_t indicesRank = indicesDim.size();

std::vector<size_t> srcBlockND = getBlockND(srcDataDim);
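
The comment above fixes the ScatterNDUpdate semantics: each length-k tuple in indices selects a slice of the destination, which is overwritten by the matching slice of updates. A single-threaded reference sketch over flat float buffers (hypothetical helper; the plugin kernel works on raw bytes and parallelizes the tuple loop):

#include <cstdint>
#include <cstring>
#include <vector>

void scatterNDUpdateRef(float* dst, const std::vector<size_t>& dataDims,
                        const int64_t* indices, const std::vector<size_t>& indicesDims,
                        const float* update) {
    // Suffix products, as computed by getBlockND above.
    std::vector<size_t> blockND(dataDims.size() + 1, 1);
    for (int i = static_cast<int>(dataDims.size()) - 1; i >= 0; i--)
        blockND[i] = dataDims[i] * blockND[i + 1];

    const size_t k = indicesDims.back();   // tuple length, k <= rank(data)
    size_t tuples = 1;                     // number of index tuples
    for (size_t i = 0; i + 1 < indicesDims.size(); i++)
        tuples *= indicesDims[i];

    const size_t sliceSize = blockND[k];   // elements replaced per tuple
    for (size_t t = 0; t < tuples; t++) {
        size_t dstOffset = 0;              // flatten the k-dimensional index
        for (size_t i = 0; i < k; i++)
            dstOffset += static_cast<size_t>(indices[t * k + i]) * blockND[i + 1];
        std::memcpy(dst + dstOffset, update + t * sliceSize, sliceSize * sizeof(float));
    }
}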
@@ -425,9 +435,8 @@ void MKLDNNScatterUpdateNode::scatterNDUpdate(uint8_t *indices, uint8_t *update,
// output[i][indices[i][j][k]][k] = updates[i][j][k] if axis = 1,
// output[i][j][indices[i][j][k]] = updates[i][j][k] if axis = 2.
void MKLDNNScatterUpdateNode::scatterElementsUpdate(uint8_t *indices, uint8_t *update, int axis, uint8_t *dstData) {
-    SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
-    SizeVector updateDim = getParentEdgeAt(UPDATE_ID)->getMemory().getStaticDims();
-    SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getMemory().getStaticDims();
+    const auto& srcDataDim = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims();
+    const auto& updateDim = getParentEdgeAt(UPDATE_ID)->getMemory().getStaticDims();
size_t updateRank = updateDim.size();

std::vector<size_t> srcBlockND = getBlockND(srcDataDim);
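A tiny worked example of the per-element semantics in the comment above, for a 2x3 input with axis = 1 (illustrative only):

#include <cstdio>

int main() {
    // output[i][indices[i][j]] = updates[i][j] for axis = 1.
    float out[2][3] = {{1, 2, 3}, {4, 5, 6}};   // starts as a copy of the input
    int idx[2][2] = {{2, 0}, {1, 2}};
    float upd[2][2] = {{10, 20}, {30, 40}};
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            out[i][idx[i][j]] = upd[i][j];
    for (int i = 0; i < 2; i++)
        printf("%g %g %g\n", out[i][0], out[i][1], out[i][2]);
    // prints: 20 2 10
    //         4 30 40
    return 0;
}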
@@ -31,6 +31,9 @@ class MKLDNNScatterUpdateNode : public MKLDNNNode {
return false;
}

+    bool needPrepareParams() const override;
+    void executeDynamicImpl(mkldnn::stream strm) override;
+
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;

private:
@@ -40,10 +43,7 @@ class MKLDNNScatterUpdateNode : public MKLDNNNode {
inline int64_t getIndicesValue(uint8_t *indices, size_t offset);

ScatterUpdateMode scatterUpdateMode = ScatterUpdateMode::ScatterUpdate;
-    const size_t DATA_ID = 0;
-    const size_t INDICES_ID = 1;
-    const size_t UPDATE_ID = 2;
-    const size_t AXIS_ID = 3;
+    enum { DATA_ID, INDICES_ID, UPDATE_ID, AXIS_ID };

// if axis can be set other than default 0.
bool axisRelaxed = false;
@@ -53,4 +53,4 @@ class MKLDNNScatterUpdateNode : public MKLDNNNode {
std::string errorPrefix;
};

-} // namespace MKLDNNPlugin
\ No newline at end of file
+} // namespace MKLDNNPlugin
@@ -0,0 +1,174 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "functional_test_utils/ov_tensor_utils.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ngraph_functions/builders.hpp"

using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;

namespace CPULayerTestsDefinitions {
using ScatterNDUpdateShapes = std::vector<InputShape>;
using IndicesValues = std::vector<std::int64_t>;

struct ScatterNDUpdateLayerParams {
ScatterNDUpdateShapes inputShapes;
IndicesValues indicesValues;
};

using scatterUpdateParams = std::tuple<
ScatterNDUpdateLayerParams,
ElementType, // input precision
ElementType>; // indices precision

class ScatterNDUpdateLayerCPUTest : public testing::WithParamInterface<scatterUpdateParams>, public SubgraphBaseTest, public CPUTestsBase {
public:
static std::string getTestCaseName(testing::TestParamInfo<scatterUpdateParams> obj) {
ScatterNDUpdateLayerParams scatterParams;
ElementType inputPrecision;
ElementType idxPrecision;
std::tie(scatterParams, inputPrecision, idxPrecision) = obj.param;
const auto inputShapes = scatterParams.inputShapes;
const auto indicesValues = scatterParams.indicesValues;

std::ostringstream result;
result << inputPrecision << "_IS=";
for (const auto& shape : inputShapes) {
result << CommonTestUtils::partialShape2str({ shape.first }) << "_";
}
result << "TS=";
for (const auto& shape : inputShapes) {
result << "(";
for (const auto& targetShape : shape.second) {
result << CommonTestUtils::vec2str(targetShape) << "_";
}
result << ")_";
}
result << "indices_values=" << CommonTestUtils::vec2str(indicesValues) << "_idx_precision=" << idxPrecision;
return result.str();
}

protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
for (int i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
const auto& inputPrecision = funcInput.get_element_type();
const auto& targetShape = targetInputStaticShapes[i];
ov::runtime::Tensor tensor;
if (i == 1) {
tensor = ov::runtime::Tensor{ inputPrecision, targetShape };
const auto indicesVals = std::get<0>(this->GetParam()).indicesValues;
if (inputPrecision == ElementType::i32) {
auto data = tensor.data<std::int32_t>();
for (size_t i = 0; i < tensor.get_size(); ++i) {
data[i] = static_cast<std::int32_t>(indicesVals[i]);
}
} else if (inputPrecision == ElementType::i64) {
auto data = tensor.data<std::int64_t>();
for (size_t i = 0; i < tensor.get_size(); ++i) {
data[i] = indicesVals[i];
}
} else {
IE_THROW() << "GatherNDUpdate. Unsupported indices precision: " << inputPrecision;
}
} else {
if (inputPrecision.is_real()) {
tensor = ov::test::utils::create_and_fill_tensor(inputPrecision, targetShape, 10, 0, 1000);
} else {
tensor = ov::test::utils::create_and_fill_tensor(inputPrecision, targetShape);
}
}
inputs.insert({ funcInput.get_node_shared_ptr(), tensor });
}
}

void SetUp() override {
targetDevice = CommonTestUtils::DEVICE_CPU;
ScatterNDUpdateLayerParams scatterParams;
ElementType inputPrecision;
ElementType idxPrecision;
std::tie(scatterParams, inputPrecision, idxPrecision) = this->GetParam();
const auto inputShapes = scatterParams.inputShapes;
const auto indicesValues = scatterParams.indicesValues;

init_input_shapes(inputShapes);
selectedType = makeSelectedTypeStr("unknown", inputPrecision);

auto dataParams = ngraph::builder::makeDynamicParams(inputPrecision, { inputDynamicShapes[0], inputDynamicShapes[2] });
auto indicesParam = ngraph::builder::makeDynamicParams(idxPrecision, { inputDynamicShapes[1] });
dataParams[0]->set_friendly_name("Param_1");
indicesParam[0]->set_friendly_name("Param_2");
dataParams[1]->set_friendly_name("Param_3");

auto scatter = std::make_shared<ngraph::opset4::ScatterNDUpdate>(dataParams[0], indicesParam[0], dataParams[1]);

ngraph::ParameterVector allParams{ dataParams[0], indicesParam[0], dataParams[1] };
function = makeNgraphFunction(inputPrecision, allParams, scatter, "ScatterNDUpdateLayerCPUTest");
}
};

TEST_P(ScatterNDUpdateLayerCPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
CheckPluginRelatedResults(executableNetwork, "ScatterUpdate");
}

const std::vector<ScatterNDUpdateLayerParams> scatterParams = {
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{-1, -1, -1, -1, -1}, {{10, 9, 10, 9, 10}, {10, 1, 11, 2, 5}, {10, 15, 8, 1, 7}}},
{{2, 2, 1}, {{2, 2, 1}, {2, 2, 1}, {2, 2, 1}}},
{{-1, -1, -1, -1, -1, -1}, {{2, 2, 9, 10, 9, 10}, {2, 2, 1, 11, 2, 5}, {2, 2, 15, 8, 1, 7}}},
},
IndicesValues{ 5, 6, 2, 8 }
},
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{-1, -1, -1, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{2, 3}, {{2, 3}, {2, 3}, {2, 3}}},
{{-1, -1}, {{2, 11}, {2, 12}, {2, 8}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2 }
},
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{{3, 10}, -1, {3, 9}, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{2, 3}, {{2, 3}, {2, 3}, {2, 3}}},
{{{2, 4}, -1}, {{2, 11}, {2, 12}, {2, 8}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2 }
},
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{{3, 10}, {4, 11}, {3, 9}, {8, 15}}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{2, 3}, {{2, 3}, {2, 3}, {2, 3}}},
{{{2, 4}, -1}, {{2, 11}, {2, 12}, {2, 8}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2 }
},
};

const std::vector<ElementType> inputPrecisions = {
ElementType::f32,
ElementType::i32,
};

const std::vector<ElementType> constantPrecisions = {
ElementType::i32,
ElementType::i64,
};

INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs, ScatterNDUpdateLayerCPUTest,
::testing::Combine(
::testing::ValuesIn(scatterParams),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(constantPrecisions)),
ScatterNDUpdateLayerCPUTest::getTestCaseName);
} // namespace CPULayerTestsDefinitions