Skip to content

Commit

Permalink
[luci/import] Use direct tensors from Less to MatrixSetDiag (Samsung#…
Browse files Browse the repository at this point in the history
…7992)

This commit replaces tensors() to native_tensors() in all builders from CircleLess to CircleMatrixSetDiag.

ONE-DCO-1.0-Signed-off-by: Maksim Bronnikov <[email protected]>
  • Loading branch information
m-bronnikov authored Nov 12, 2021
1 parent 0b4e476 commit ab63ed1
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 28 deletions.
13 changes: 8 additions & 5 deletions compiler/luci/import/src/Nodes/CircleLess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);

switch (tensor->type)
switch (tensor->type())
{
case circle::TensorType_FLOAT32:
case circle::TensorType_FLOAT64:
Expand All @@ -48,12 +49,14 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
return false;
}

if (tensors[inputs.at(1)]->type != tensor->type)
assert(tensors[inputs.at(1)] != nullptr);
if (tensors[inputs.at(1)]->type() != tensor->type())
{
return false;
}

return tensors[outputs[0]]->type == circle::TensorType_BOOL;
assert(tensors[outputs[0]] != nullptr);
return tensors[outputs[0]]->type() == circle::TensorType_BOOL;
}

CircleNode *CircleLessGraphBuilder::build_node(const circle::OperatorT &,
Expand Down
8 changes: 5 additions & 3 deletions compiler/luci/import/src/Nodes/CircleLessEqual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}

return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
assert(tensors[outputs[0]] != nullptr);
return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}

CircleNode *CircleLessEqualGraphBuilder::build_node(const circle::OperatorT &,
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleLog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const
// input type check
// Must be one of bfloat16, half, float32, float64, complex64, complex128.
// Currently circle supports half(float16), float32, float64, complex64.
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
switch (tensor->type)
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);
switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const

// Only BOOL type is allowed for inputs
const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();
for (auto input : inputs)
{
const auto &tensor = tensors.at(input);
if (tensor->type != circle::TensorType::TensorType_BOOL)
const auto tensor = tensors.at(input);
assert(tensor != nullptr);
if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}

Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const

// Only BOOL type is allowed for the input
const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
if (tensor->type != circle::TensorType::TensorType_BOOL)
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));
assert(tensor != nullptr);
if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;

return true;
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const

// Only BOOL type is allowed for inputs
const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();
for (auto input : inputs)
{
const auto &tensor = tensors.at(input);
if (tensor->type != circle::TensorType::TensorType_BOOL)
const auto tensor = tensors.at(input);
assert(tensor != nullptr);
if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}

Expand Down
5 changes: 3 additions & 2 deletions compiler/luci/import/src/Nodes/CircleLogistic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
const auto tensors = args.reader.native_tensors();
assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;

return true;
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));

if (tensors[outputs[0]]->type != tensor->type)
assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
if (tensors[outputs[0]]->type() != tensor->type())
return false;

return true;
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
const auto tensors = args.reader.native_tensors();
const auto tensor = tensors.at(inputs.at(0));

if (tensors[outputs[0]]->type != tensor->type)
assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
if (tensors[outputs[0]]->type() != tensor->type())
return false;

return true;
Expand Down

0 comments on commit ab63ed1

Please sign in to comment.