Skip to content

Commit

Permalink
[luci/import] Use direct tensors from Round to Sin (Samsung#7995)
Browse files Browse the repository at this point in the history
This commit replaces tensors() to native_tensors() in all builders from CircleRound to CircleSin.

ONE-DCO-1.0-Signed-off-by: Maksim Bronnikov <[email protected]>
  • Loading branch information
m-bronnikov authored Nov 14, 2021
1 parent 970fe4e commit bf0efe8
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 30 deletions.
12 changes: 7 additions & 5 deletions compiler/luci/import/src/Nodes/CircleRound.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
const auto &tensors = args.reader.tensors();
const auto &tensor_in = tensors.at(inputs.at(0));
const auto &tensor_out = tensors.at(outputs[0]);
const auto tensors = args.reader.native_tensors();
const auto tensor_in = tensors.at(inputs.at(0));
const auto tensor_out = tensors.at(outputs[0]);
assert(tensor_in != nullptr);
assert(tensor_out != nullptr);

switch (tensor_in->type)
switch (tensor_in->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
Expand All @@ -49,7 +51,7 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
return false;
}

if (tensor_out->type != tensor_in->type)
if (tensor_out->type() != tensor_in->type())
return false;

return true;
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleRsqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
switch (tensor->type)
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);
switch (tensor->type())
{
case circle::TensorType_UINT8:
case circle::TensorType_INT16:
Expand Down
9 changes: 5 additions & 4 deletions compiler/luci/import/src/Nodes/CircleScatterNd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
// indices must have the same type as shape
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

if (tensors[inputs.at(0)]->type != tensors[inputs.at(2)]->type)
assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(2)] != nullptr);
if (tensors[inputs.at(0)]->type() != tensors[inputs.at(2)]->type())
return false;

// indices must be either int32 or int64
if (tensors[inputs.at(0)]->type != circle::TensorType_INT32 &&
tensors[inputs.at(0)]->type != circle::TensorType_INT64)
if (tensors[inputs.at(0)]->type() != circle::TensorType_INT32 &&
tensors[inputs.at(0)]->type() != circle::TensorType_INT64)
return false;

return true;
Expand Down
15 changes: 9 additions & 6 deletions compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in = tensors.at(inputs.at(0));
const auto &tensor_out = tensors.at(outputs[0]);
const auto &tensor_ids = tensors.at(inputs.at(1));
const auto tensors = args.reader.native_tensors();
const auto tensor_in = tensors.at(inputs.at(0));
const auto tensor_out = tensors.at(outputs[0]);
const auto tensor_ids = tensors.at(inputs.at(1));
assert(tensor_in != nullptr);
assert(tensor_out != nullptr);
assert(tensor_ids != nullptr);

switch (tensor_ids->type)
switch (tensor_ids->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
Expand All @@ -44,7 +47,7 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
return false;
}

if (tensor_out->type != tensor_in->type)
if (tensor_out->type() != tensor_in->type())
{
return false;
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const
return false;

const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
if (tensor->type != circle::TensorType_BOOL)
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);
if (tensor->type() != circle::TensorType_BOOL)
return false;
// TODO check dtypes for input 1, 2

Expand Down
14 changes: 8 additions & 6 deletions compiler/luci/import/src/Nodes/CircleSelectV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const
return false;

const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &condition = tensors.at(inputs.at(0));
if (condition->type != circle::TensorType_BOOL)
const auto tensors = args.reader.native_tensors();
const auto condition = tensors.at(inputs.at(0));
assert(condition != nullptr);
if (condition->type() != circle::TensorType_BOOL)
return false;

const auto &t = tensors.at(inputs.at(1));
const auto &e = tensors.at(inputs.at(2));
if (t->type != e->type)
const auto t = tensors.at(inputs.at(1));
const auto e = tensors.at(inputs.at(2));
assert(t != nullptr && e != nullptr);
if (t->type() != e->type())
return false;

return true;
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleSin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
// input type check
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
switch (tensor->type)
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);
switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
Expand Down

0 comments on commit bf0efe8

Please sign in to comment.