Skip to content

Commit

Permalink
[onert-micro] Support Fill kernel (Samsung#10708)
Browse files Browse the repository at this point in the history
* [onert-micro] Support Fill kernel

This pr adds Fill kernel with unit test

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>

* add licence line

---------

Co-authored-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem and Artem Balyshev authored May 22, 2023
1 parent ed9fa9e commit 5915096
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 245 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_TEST_MODELS_FILL_KERNEL_H
#define LUCI_INTERPRETER_TEST_MODELS_FILL_KERNEL_H

#include "luci_interpreter/test_models/TestDataBase.h"

namespace luci_interpreter
{
namespace test_kernel
{
namespace fill_kernel
{
/*
* Fill Kernel:
*
* Dims(3, 2) Input(scalar)
* \ /
* Fill
* |
* |
* Output(3, 2)
*/
const unsigned char test_kernel_model_circle[] = {
0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x48, 0x00, 0x00, 0x00, 0x70, 0x01, 0x00, 0x00, 0x8c, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x34, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x8c, 0xff, 0xff, 0xff,
0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00,
0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00,
0x64, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x44, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x6c, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00,
0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00,
0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x64, 0x69, 0x6d, 0x73, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00,
0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x5e, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x5e, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69,
0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};

const std::vector<float> input_data = {1.1f};

const std::vector<float> reference_output_data = {1.1f, 1.1f, 1.1f, 1.1f, 1.1f, 1.1f};
} // namespace fill_kernel

template <typename T> class TestDataFillKernel : public TestDataBase<T>
{
public:
TestDataFillKernel()
{
_input_data = fill_kernel::input_data;
_reference_output_data = fill_kernel::reference_output_data;
_test_kernel_model_circle = fill_kernel::test_kernel_model_circle;
}

~TestDataFillKernel() override = default;

const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }

const std::vector<T> &get_input_data_by_index(int i) override final
{
switch (i)
{
case 0:
return _input_data;
default:
assert(false && "Wrong input index");
}
}

const std::vector<T> &get_output_data_by_index(int i) override final
{
assert(i == 0);
return _reference_output_data;
}

protected:
std::vector<T> _input_data;
std::vector<T> _reference_output_data;
const unsigned char *_test_kernel_model_circle;
};

} // namespace test_kernel
} // namespace luci_interpreter

#endif // LUCI_INTERPRETER_TEST_MODELS_FILL_KERNEL_H
1 change: 1 addition & 0 deletions onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D)
REGISTER_KERNEL(LOGISTIC, Logistic)
REGISTER_KERNEL(GATHER, Gather)
REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)
REGISTER_KERNEL(FILL, Fill)
REGISTER_KERNEL(PACK, Pack)
REGISTER_KERNEL(RESHAPE, Reshape)
REGISTER_KERNEL(REDUCE_PROD, ReduceCommon)
Expand Down
115 changes: 42 additions & 73 deletions onert-micro/luci-interpreter/src/kernels/Fill.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,103 +15,71 @@
* limitations under the License.
*/

#include "kernels/Fill.h"
#include "Builders.h"
#include "TISOKernel.h"
#include "kernels/Utils.h"
#include "PALFill.h"

namespace luci_interpreter
{
namespace kernels
namespace
{

Fill::Fill(const Tensor *dims, const Tensor *value, Tensor *output)
: Kernel({dims, value}, {output})
template <typename T> void fillImpl(const size_t flat_size, const T *value_data, T *output_data)
{
}

template <typename T> void Fill::configureShape()
{
const auto dims_data = getTensorData<T>(dims());
Shape output_shape(dims()->shape().dim(0));

for (int i = 0; i < output_shape.num_dims(); ++i)
for (int i = 0; i < flat_size; ++i)
{
T data = dims_data[i];
if (data < 0)
assert(false && "Fill dimensions must be >= 0");

output_shape.dim(i) = data;
output_data[i] = *value_data;
}
// TODO: enable it only if kernel with dynamic shapes
output()->resize(output_shape);
}

void Fill::configure()
} // namespace

void configure_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
const auto dims_shape = dims()->shape();
const auto value_shape = value()->shape();
kernels::TISOKernel kernel(cur_op, runtime_graph);
// value tensor must be a scalar or has one element
LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input2()) == 0 or
Tensor::num_elements(kernel.input2()) == 1);
// value and output type must match
LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) ==
Tensor::element_type(kernel.output()));
}

// Make sure the 1st input tensor is 1-D
LUCI_INTERPRETER_CHECK(dims_shape.num_dims() == 1);
void execute_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
bool)
{
kernels::TISOKernel kernel(cur_op, runtime_graph);

// Make sure the 1st input tensor is int32 or int64
LUCI_INTERPRETER_CHECK(dims()->element_type() == DataType::S32 or
dims()->element_type() == DataType::S64);
const circle::Tensor *value = kernel.input2();
const circle::Tensor *output = kernel.output();

// Make sure the 2nd input tensor is a scalar
LUCI_INTERPRETER_CHECK(value_shape.num_dims() == 0)
kernels::TISOData tiso_data = kernel.readData();
const uint8_t *value_data = tiso_data.input2_data;
uint8_t *output_data = tiso_data.output_data;

// Check zero point and scale for S16 and S8
if (value()->element_type() == DataType::S16 or value()->element_type() == DataType::S8)
{
LUCI_INTERPRETER_CHECK(value()->scale() == output()->scale());
LUCI_INTERPRETER_CHECK(value()->zero_point() == output()->zero_point());
const size_t flat_size = Tensor::num_elements(output);

if (value()->element_type() == DataType::S16)
LUCI_INTERPRETER_CHECK(value()->zero_point() == 0);
}
// Resize output
switch (dims()->element_type())
switch (Tensor::element_type(value))
{
case DataType::S32:
configureShape<int32_t>();
break;
case DataType::S64:
configureShape<int64_t>();
break;
default:
assert(false && "Unsupported type.");
}
}

void Fill::execute() const
{
switch (output()->element_type())
{
case DataType::S8:
tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int8_t>(value()),
getTensorShape(output()), getTensorData<int8_t>(output()));
break;
case DataType::S16:
tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int16_t>(value()),
getTensorShape(output()), getTensorData<int16_t>(output()));
#ifndef DIS_FLOAT
case DataType::FLOAT32:
fillImpl<float>(flat_size, kernels::getTensorData<float>(value_data),
kernels::getTensorData<float>(output_data));
break;
#endif // DIS_FLOAT
case DataType::S32:
tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int32_t>(value()),
getTensorShape(output()), getTensorData<int32_t>(output()));
fillImpl<int32_t>(flat_size, kernels::getTensorData<int32_t>(value_data),
kernels::getTensorData<int32_t>(output_data));
break;
case DataType::S64:
tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int64_t>(value()),
getTensorShape(output()), getTensorData<int64_t>(output()));
break;
case DataType::FLOAT32:
tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<float>(value()),
getTensorShape(output()), getTensorData<float>(output()));
#ifndef DIS_QUANT
case DataType::U8:
fillImpl<uint8_t>(flat_size, kernels::getTensorData<uint8_t>(value_data),
kernels::getTensorData<uint8_t>(output_data));
break;
#endif // DIS_QUANT
default:
assert(false && "Unsupported type.");
assert(false && "Not impl yet");
}
}

} // namespace kernels
} // namespace luci_interpreter
47 changes: 0 additions & 47 deletions onert-micro/luci-interpreter/src/kernels/Fill.h

This file was deleted.

Loading

0 comments on commit 5915096

Please sign in to comment.