Skip to content

Commit

Permalink
[onert] Add Shape inference for Gather (Samsung#1838)
Browse files Browse the repository at this point in the history
* [onert] Add Shape inference for Gather

This adds shape inference for Gather op

* Add condition when indices shape is dynamic
* Apply function name rule. GatherShapes -> gatherShapes

Signed-off-by: YiHyunjin <[email protected]>
  • Loading branch information
YiHyunJin authored Jun 5, 2020
1 parent 9fa5f0d commit 57fdecb
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 2 deletions.
4 changes: 2 additions & 2 deletions runtime/onert/core/include/util/ShapeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class StaticInferer : public ir::OperationVisitor
void visit(const ir::operation::Exp &op);
void visit(const ir::operation::ExpandDims &op);
void visit(const ir::operation::FullyConnected &op);
// TODO write op starting from G
void visit(const ir::operation::Gather &op);
void visit(const ir::operation::If &op);
void visit(const ir::operation::Log &op);
void visit(const ir::operation::Logistic &op);
Expand Down Expand Up @@ -195,7 +195,7 @@ class DynamicInferer : public ir::OperationVisitor
void visit(const ir::operation::Exp &op);
void visit(const ir::operation::ExpandDims &op);
void visit(const ir::operation::FullyConnected &op);
// TODO write op starting from G
void visit(const ir::operation::Gather &op);
void visit(const ir::operation::Log &op);
void visit(const ir::operation::Logistic &op);
void visit(const ir::operation::Mul &op);
Expand Down
104 changes: 104 additions & 0 deletions runtime/onert/core/src/util/shapeinf/Gather.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "util/ShapeInference.h"

namespace onert
{
namespace shape_inference
{

// Computes the output shape of Gather.
//
// The result is `input_shape` with the dimension at `axis` replaced by the
// whole of `indices_shape`, i.e. output rank == rank + indices rank - 1
// (for scalar indices the axis dimension is simply dropped).
//
// @param input_shape   shape of the tensor being gathered from
// @param indices_shape shape of the indices tensor
// @param axis          non-negative axis along which gathering happens
// @param rank          rank of `input_shape` (callers pass it pre-computed)
// @return              inferred output shape
ir::Shape gatherShapes(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
                       int rank)
{
  ir::Shape output;
  for (int axis_idx = 0; axis_idx < rank; ++axis_idx)
  {
    if (axis_idx != axis)
    {
      output.append(input_shape.dim(axis_idx));
      continue;
    }
    // Splice every dimension of the indices shape in place of the gathered axis.
    const int indices_rank = indices_shape.rank();
    for (int i = 0; i < indices_rank; ++i)
      output.append(indices_shape.dim(i));
  }
  return output;
}

// Static (compile-time) shape inference for Gather.
// If either the input or the indices operand is dynamic the output shape
// cannot be known yet, so the output is only marked dynamic; otherwise the
// output operand is resized to the inferred shape.
void StaticInferer::visit(const ir::operation::Gather &op)
{
  const auto &input = _operands.at(op.getInputs().at(ir::operation::Gather::Input::INPUT));
  const auto &indices = _operands.at(op.getInputs().at(ir::operation::Gather::Input::INDICES));

  // mutable output operand
  auto &output = _operands.at(op.getOutputs().at(0));

  // Dynamic input or indices => output shape is only known at execution time.
  if (input.info().isDynamic() || indices.info().isDynamic())
  {
    output.info().setDynamic();
    return;
  }

  const auto rank = input.info().shape().rank();
  // Normalize a negative axis to its positive equivalent.
  const auto axis = (op.param().axis < 0) ? rank + op.param().axis : op.param().axis;
  assert(0 <= axis && axis < rank);

  // re-size the output operand to the inferred shape
  output.info().shape(gatherShapes(input.info().shape(), indices.info().shape(), axis, rank));
}

// Dynamic (execution-time) shape inference for Gather.
// Looks tensors up through the runtime tensor registry, and re-allocates the
// output via the dynamic tensor manager once the concrete shape is known.
void DynamicInferer::visit(const ir::operation::Gather &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
  const auto &input = _tensor_registry->getITensor(input_idx);
  auto input_shape = getShape(input.get());

  const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
  const auto &indices = _tensor_registry->getITensor(indices_idx);
  auto indices_shape = getShape(indices.get());

  // Nothing to do when both inputs are static: the output shape was already
  // fixed by StaticInferer at compile time.
  if (!(input->is_dynamic()) && !(indices->is_dynamic()))
    return;

  const auto rank = input_shape.rank();
  // Normalize a negative axis to its positive equivalent.
  const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);

  assert(0 <= axis && axis < rank);

  ir::Shape new_shape = gatherShapes(input_shape, indices_shape, axis, rank);

  auto output_ind = op.getOutputs().at(0);
  auto output = _tensor_registry->getITensor(output_ind);

  // applyShape resizes (and re-allocates) the output tensor; afterwards the
  // buffer must exist.
  _dynamic_tensor_manager->applyShape(output_ind, new_shape);
  assert(output->buffer() != nullptr);
}

} // namespace shape_inference
} // namespace onert
1 change: 1 addition & 0 deletions tests/nnapi/nnapi_gtest.skip.aarch64-linux.acl_cl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ GeneratedTests.fill_ex_1D_float
GeneratedTests.fill_ex_4D_float
GeneratedTests.fully_connected_dynamic_nnfw
GeneratedTests.fully_connected_float_2_weights_as_inputs
GeneratedTests.gather_dynamic_nnfw
GeneratedTests.gather_float16
GeneratedTests.gather_float16_2
GeneratedTests.gather_float16_3
Expand Down
1 change: 1 addition & 0 deletions tests/nnapi/nnapi_gtest.skip.aarch64-linux.acl_neon
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ GeneratedTests.fully_connected_float_4d_simple_relaxed
GeneratedTests.fully_connected_float_large_relaxed
GeneratedTests.fully_connected_float_relaxed
GeneratedTests.fully_connected_hybrid_1_nnfw
GeneratedTests.gather_dynamic_nnfw
GeneratedTests.gather_float16
GeneratedTests.gather_float16_2
GeneratedTests.gather_float16_3
Expand Down
1 change: 1 addition & 0 deletions tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_cl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ GeneratedTests.fill_ex_1D_float
GeneratedTests.fill_ex_4D_float
GeneratedTests.fully_connected_dynamic_nnfw
GeneratedTests.fully_connected_float_2_weights_as_inputs
GeneratedTests.gather_dynamic_nnfw
GeneratedTests.gather_float16
GeneratedTests.gather_float16_2
GeneratedTests.gather_float16_3
Expand Down
1 change: 1 addition & 0 deletions tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ GeneratedTests.fully_connected_float_2_relaxed
GeneratedTests.fully_connected_float_4d_simple_relaxed
GeneratedTests.fully_connected_float_large_relaxed
GeneratedTests.fully_connected_float_relaxed
GeneratedTests.gather_dynamic_nnfw
GeneratedTests.gather_float16
GeneratedTests.gather_float16_2
GeneratedTests.gather_float16_3
Expand Down
1 change: 1 addition & 0 deletions tests/nnapi/nnapi_gtest.skip.noarch.interp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ GeneratedTests.fully_connected_quant8_2
GeneratedTests.fully_connected_quant8_large
GeneratedTests.fully_connected_quant8_large_weights_as_inputs
GeneratedTests.fully_connected_quant8_weights_as_inputs
GeneratedTests.gather_dynamic_nnfw
GeneratedTests.gather_float16
GeneratedTests.gather_float16_2
GeneratedTests.gather_float16_3
Expand Down
89 changes: 89 additions & 0 deletions tests/nnapi/specs/V1_2/gather_dynamic_nnfw.mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#
# Copyright (C) 2018 The Android Open Source Project
# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# refer to tanh_v1_dynamic.mod.py about the structure

# This adds reshape as the first op in a model and
# returns output of reshape, which is dynamic tensor

'''
Testing Gather op when the input1 is dynamic.
input1 [1, 2, 3, 4] shape [4] (value of shape will be [1, 2, 3, 4])
| |
+-------------+
|
Reshape (added by DynamicInputGenerator since it generates its output to be dynamic)
|
| axis = 0 input2 [2]
| | |
+-------------+-------------+
|
|
| dynamic tensor at compilation time but the shape will be [2, 2, 3, 4] at execution time
|
Gather
|
output (dynamic tensor, [2, 2, 3, 4] at execution time)
'''
import dynamic_tensor

model = Model()

# Static shape of the float input; the Reshape inserted by
# DynamicInputGenerator turns it into a dynamic tensor at compile time.
input1_shape = [1, 2, 3, 4]

dynamic_layer = dynamic_tensor.DynamicInputGenerator(model, input1_shape, "TENSOR_FLOAT32")

node_input = dynamic_layer.getTestNodeInput()

# Gather inputs: dynamic data tensor, axis 0, and two indices.
# Fixed typo in the operand name ("intput2" -> "input2").
input2 = Input("input2", "TENSOR_INT32", "{2}")
axis = Int32Scalar("axis", 0)
# Gathering twice along axis 0 of a [1, 2, 3, 4] tensor yields [2, 2, 3, 4].
output = Output("output", "TENSOR_FLOAT32", "{2,2,3,4}")
model = model.Operation("GATHER", node_input, axis, input2).To(output)

input1_data = [1.123456789123456789, 2.123456789123456789, 3.123456789123456789, 4.123456789123456789,
               5.123456789123456789, 6.123456789123456789, 7.123456789123456789, 8.123456789123456789,
               9.123456789123456789, 10.123456789123456789, 11.123456789123456789, 12.123456789123456789,
               13.123456789123456789, 14.123456789123456789, 15.123456789123456789, 16.123456789123456789,
               17.123456789123456789, 18.123456789123456789, 19.123456789123456789, 20.123456789123456789,
               21.123456789123456789, 22.123456789123456789, 23.123456789123456789, 24.123456789123456789
               ]

input0 = {
    dynamic_layer.getModelInput() : input1_data,   # input 1 (data)
    dynamic_layer.getShapeInput() : input1_shape,  # shape fed to the Reshape

    input2 : [0, 0]  # input 2 (indices): gather row 0 twice
}

# Both gathered slices are the whole input, so the expected output is the
# input data repeated twice.
output0 = {
    output:  # output
        [1.123456789123456789, 2.123456789123456789, 3.123456789123456789, 4.123456789123456789,
         5.123456789123456789, 6.123456789123456789, 7.123456789123456789, 8.123456789123456789,
         9.123456789123456789, 10.123456789123456789, 11.123456789123456789, 12.123456789123456789,
         13.123456789123456789, 14.123456789123456789, 15.123456789123456789, 16.123456789123456789,
         17.123456789123456789, 18.123456789123456789, 19.123456789123456789, 20.123456789123456789,
         21.123456789123456789, 22.123456789123456789, 23.123456789123456789, 24.123456789123456789,
         1.123456789123456789, 2.123456789123456789, 3.123456789123456789, 4.123456789123456789,
         5.123456789123456789, 6.123456789123456789, 7.123456789123456789, 8.123456789123456789,
         9.123456789123456789, 10.123456789123456789, 11.123456789123456789, 12.123456789123456789,
         13.123456789123456789, 14.123456789123456789, 15.123456789123456789, 16.123456789123456789,
         17.123456789123456789, 18.123456789123456789, 19.123456789123456789, 20.123456789123456789,
         21.123456789123456789, 22.123456789123456789, 23.123456789123456789, 24.123456789123456789]
}

# Instantiate an example
Example((input0, output0))

0 comments on commit 57fdecb

Please sign in to comment.