Skip to content

Commit

Permalink
[luci/import] Use direct-access tensors from Abs to If (Samsung#7957)
Browse files Browse the repository at this point in the history
This commit replaces tensors() to native_tensors() in all builders from CircleAbs to CircleIf.

ONE-DCO-1.0-Signed-off-by: Maksim Bronnikov <[email protected]>
  • Loading branch information
m-bronnikov authored Nov 10, 2021
1 parent 4338bad commit 4f7c1d1
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 56 deletions.
12 changes: 7 additions & 5 deletions compiler/luci/import/src/Nodes/CircleCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
const auto *options = args.op.builtin_options.AsCastOptions();
if (options != nullptr)
{
const auto &tensors = args.reader.tensors();
const circle::TensorT &output_tensor = *tensors[outputs[0]];
const auto tensors = args.reader.native_tensors();
const auto output_tensor = tensors[outputs[0]];
assert(output_tensor != nullptr);
auto name = tensor_name(output_tensor);

const auto &tensor_in = tensors.at(inputs.at(0));
if (tensor_in->type != options->in_data_type)
const auto tensor_in = tensors.at(inputs.at(0));
assert(tensor_in != nullptr);
if (tensor_in->type() != options->in_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
Expand All @@ -57,7 +59,7 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
const auto &tensor_out = tensors.at(outputs[0]);
if (tensor_out->type != options->out_data_type)
if (tensor_out->type() != options->out_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
Expand Down
17 changes: 9 additions & 8 deletions compiler/luci/import/src/Nodes/CircleConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
namespace
{

std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect)
{
uint32_t seq = 0;
for (auto &v : vect)
for (const auto &v : vect)
{
if (seq)
os << ", ";
Expand Down Expand Up @@ -112,11 +112,12 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind

auto graph = context->graph();
auto reader = context->reader();
const auto &tensors = reader->tensors();
const circle::TensorT &const_tensor = *tensors[tensor_index];
const auto tensors = reader->native_tensors();
const auto const_tensor = tensors[tensor_index];
assert(const_tensor != nullptr);

const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor.buffer]->data;
std::vector<int32_t> const_dims = const_tensor.shape; // in NHWC
const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor->buffer()]->data;
const auto const_dims = wrap(const_tensor->shape()); // in NHWC
if (const_dims.size() == 0 && buffer.empty())
{
// unknown shape tensor and scalar tensor
Expand Down Expand Up @@ -150,7 +151,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind
<< const_dims << std::endl;
if (num_elements > 0)
{
switch (luci_datatype(const_tensor.type))
switch (luci_datatype(const_tensor->type()))
{
case loco::DataType::FLOAT32:
copy_data<loco::DataType::FLOAT32>(buffer, num_elements, const_node);
Expand Down Expand Up @@ -186,7 +187,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind

default:
throw oops::UserExn("Unsupported tensor type",
circle::EnumNameTensorType(const_tensor.type));
circle::EnumNameTensorType(const_tensor->type()));
}
}

Expand Down
5 changes: 3 additions & 2 deletions compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const
const auto &outputs = args.op.outputs;

const auto *options = args.op.builtin_options.AsDepthToSpaceOptions();
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();
assert(tensors[outputs[0]] != nullptr && tensors[inputs.at(0)] != nullptr);

if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type)
if (tensors[outputs[0]]->type() != tensors[inputs.at(0)]->type())
{
return false;
}
Expand Down
12 changes: 7 additions & 5 deletions compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,21 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.outputs.size() != 1)
return false;

const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

// input shape
const auto &input = tensors.at(args.op.inputs.at(0));
const auto &input_shape = input->shape;
const auto input = tensors.at(args.op.inputs.at(0));
assert(input != nullptr);
const auto input_shape = wrap(input->shape());

// input shape must be rank 4
if (input_shape.size() != 4)
return false;

// filter shape
const auto &filter = tensors.at(args.op.inputs.at(1));
const auto &filter_shape = filter->shape;
const auto filter = tensors.at(args.op.inputs.at(1));
assert(filter != nullptr);
const auto filter_shape = wrap(filter->shape());

// filter shape must be rank 4
if (filter_shape.size() != 4)
Expand Down
10 changes: 6 additions & 4 deletions compiler/luci/import/src/Nodes/CircleElu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;

const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);

switch (tensor->type)
switch (tensor->type())
{
case circle::TensorType_FLOAT64:
break;
Expand All @@ -48,7 +49,8 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
return false;
}

if (tensors[outputs[0]]->type != tensor->type)
assert(tensors[outputs[0]] != nullptr);
if (tensors[outputs[0]]->type() != tensor->type())
return false;

return true;
Expand Down
5 changes: 3 additions & 2 deletions compiler/luci/import/src/Nodes/CircleEqual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const
return false;

const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type;
assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
return tensors[inputs.at(0)]->type() == tensors[inputs.at(1)]->type();
}

CircleNode *CircleEqualGraphBuilder::build_node(const circle::OperatorT &,
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleExp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
// input type check
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
switch (tensor->type)
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);
switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
Expand Down
5 changes: 3 additions & 2 deletions compiler/luci/import/src/Nodes/CircleExpandDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const
return false;

const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

return tensors[inputs.at(1)]->type == circle::TensorType_INT32;
assert(tensors[inputs.at(1)] != nullptr);
return tensors[inputs.at(1)]->type() == circle::TensorType_INT32;
}

CircleNode *CircleExpandDimsGraphBuilder::build_node(const circle::OperatorT &,
Expand Down
17 changes: 10 additions & 7 deletions compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,18 @@ bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in_0 = tensors.at(inputs.at(0));
const auto &tensor_in_1 = tensors.at(inputs.at(1));
const auto &tensor_out = tensors.at(outputs[0]);

if (tensor_in_0->type != tensor_in_1->type)
const auto tensors = args.reader.native_tensors();
const auto tensor_in_0 = tensors.at(inputs.at(0));
const auto tensor_in_1 = tensors.at(inputs.at(1));
const auto tensor_out = tensors.at(outputs[0]);
assert(tensor_in_0 != nullptr);
assert(tensor_in_1 != nullptr);
assert(tensor_out != nullptr);

if (tensor_in_0->type() != tensor_in_1->type())
return false;

if (tensor_out->type != tensor_in_1->type)
if (tensor_out->type() != tensor_in_1->type())
{
return false;
}
Expand Down
9 changes: 5 additions & 4 deletions compiler/luci/import/src/Nodes/CircleFloorMod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const
return false;

const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in_0 = tensors.at(inputs.at(0));
const auto &tensor_in_1 = tensors.at(inputs.at(1));
if (tensor_in_0->type != tensor_in_1->type)
const auto tensors = args.reader.native_tensors();
const auto tensor_in_0 = tensors.at(inputs.at(0));
const auto tensor_in_1 = tensors.at(inputs.at(1));
assert(tensor_in_0 != nullptr && tensor_in_1 != nullptr);
if (tensor_in_0->type() != tensor_in_1->type())
return false;

// TODO dtype check
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleGatherNd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const
return false;

const auto &inputs = args.op.inputs;
auto &indices_tensor = args.reader.tensors()[inputs.at(1)];
auto indices_tensor = args.reader.native_tensors()[inputs.at(1)];
assert(indices_tensor != nullptr);

if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 ||
indices_tensor->type == circle::TensorType::TensorType_INT64))
if (!(indices_tensor->type() == circle::TensorType::TensorType_INT32 ||
indices_tensor->type() == circle::TensorType::TensorType_INT64))
{
return false;
}
Expand Down
10 changes: 6 additions & 4 deletions compiler/luci/import/src/Nodes/CircleGreater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,19 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
return false;

// NOTE: real models do have output dtype NOT BOOL
if (tensors[outputs[0]]->type != circle::TensorType_BOOL)
assert(tensors[outputs[0]] != nullptr);
if (tensors[outputs[0]]->type() != circle::TensorType_BOOL)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
const circle::TensorT &output_tensor = *tensors[outputs[0]];
const auto output_tensor = tensors[outputs[0]];
auto name = tensor_name(output_tensor);
WARN(l) << "Warning: import Greater(" << name << ") output dtype is not boolean";
}
Expand Down
8 changes: 5 additions & 3 deletions compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}

return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
assert(tensors[outputs[0]] != nullptr);
return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}

CircleNode *CircleGreaterEqualGraphBuilder::build_node(const circle::OperatorT &,
Expand Down
9 changes: 5 additions & 4 deletions compiler/luci/import/src/Nodes/CircleIf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const
return false;

// input 0 should be BOOL type
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
if (tensor->type != circle::TensorType_BOOL)
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);
if (tensor->type() != circle::TensorType_BOOL)
return false;

const auto &shape = tensor->shape;
const auto shape = wrap(tensor->shape());
if (shape.size() != 1 && shape.size() != 0)
return false;

Expand Down

0 comments on commit 4f7c1d1

Please sign in to comment.