Skip to content

Commit

Permalink
[luci/import] Use direct tensors from NonMaxSuppression to OneHot (Sa…
Browse files Browse the repository at this point in the history
…msung#7993)

This commit replaces tensors() to native_tensors() in all builders from CircleNonMaxSuppressionV4 to CircleOneHot.

ONE-DCO-1.0-Signed-off-by: Maksim Bronnikov <[email protected]>
  • Loading branch information
m-bronnikov authored Nov 12, 2021
1 parent ab63ed1 commit 5e55b73
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 30 deletions.
22 changes: 14 additions & 8 deletions compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,26 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 2)
return false;

const auto &tensors = args.reader.tensors();
const auto &boxes_tensor = tensors.at(inputs[0]);
if (boxes_tensor->shape.size() != 2)
const auto tensors = args.reader.native_tensors();
const auto boxes_tensor = tensors.at(inputs[0]);
assert(boxes_tensor != nullptr);
const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
if (boxes_tensor_shape.size() != 2)
return false;
if (boxes_tensor->shape.at(1) != 4)
if (boxes_tensor_shape.at(1) != 4)
return false;
if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
assert(tensors.at(inputs[1]) != nullptr);
if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;

if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
assert(tensors.at(inputs[2]) != nullptr);
if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
assert(tensors.at(inputs[3]) != nullptr);
if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
assert(tensors.at(inputs[4]) != nullptr);
if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;

return true;
Expand Down
25 changes: 16 additions & 9 deletions compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,29 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 3)
return false;

const auto &tensors = args.reader.tensors();
const auto &boxes_tensor = tensors.at(inputs[0]);
if (boxes_tensor->shape.size() != 2)
const auto tensors = args.reader.native_tensors();
const auto boxes_tensor = tensors.at(inputs[0]);
assert(boxes_tensor != nullptr);
const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
if (boxes_tensor_shape.size() != 2)
return false;
if (boxes_tensor->shape.at(1) != 4)
if (boxes_tensor_shape.at(1) != 4)
return false;
if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
assert(tensors.at(inputs[1]) != nullptr);
if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;

if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
assert(tensors.at(inputs[2]) != nullptr);
if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
assert(tensors.at(inputs[3]) != nullptr);
if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
assert(tensors.at(inputs[4]) != nullptr);
if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;
if (tensors.at(inputs[5])->type != circle::TensorType_FLOAT32)
assert(tensors.at(inputs[5]) != nullptr);
if (tensors.at(inputs[5])->type() != circle::TensorType_FLOAT32)
return false;

return true;
Expand Down
8 changes: 5 additions & 3 deletions compiler/luci/import/src/Nodes/CircleNotEqual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto tensors = args.reader.native_tensors();

if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}

return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
assert(tensors[outputs[0]] != nullptr);
return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}

CircleNode *CircleNotEqualGraphBuilder::build_node(const circle::OperatorT &,
Expand Down
24 changes: 14 additions & 10 deletions compiler/luci/import/src/Nodes/CircleOneHot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,25 @@ bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const

const auto &inputs = args.op.inputs;
const auto *options = args.op.builtin_options.AsOneHotOptions();
const auto &tensors = args.reader.tensors();
const auto &indices = tensors.at(inputs.at(0));
const auto &depth = tensors.at(inputs.at(1));
const auto &on_value = tensors.at(inputs.at(2));
const auto &off_value = tensors.at(inputs.at(3));
const auto tensors = args.reader.native_tensors();
const auto indices = tensors.at(inputs.at(0));
const auto depth = tensors.at(inputs.at(1));
const auto on_value = tensors.at(inputs.at(2));
const auto off_value = tensors.at(inputs.at(3));
assert(indices != nullptr);
assert(depth != nullptr);
assert(on_value != nullptr);
assert(off_value != nullptr);

if (options->axis < -1 || options->axis > static_cast<int32_t>(indices->shape.size()))
if (options->axis < -1 || options->axis > static_cast<int32_t>(wrap(indices->shape()).size()))
return false;
if (depth->shape.size() != 0)
if (wrap(depth->shape()).size() != 0)
return false;
if (on_value->shape.size() != 0)
if (wrap(on_value->shape()).size() != 0)
return false;
if (off_value->shape.size() != 0)
if (wrap(off_value->shape()).size() != 0)
return false;
if (on_value->type != off_value->type)
if (on_value->type() != off_value->type())
return false;

return true;
Expand Down

0 comments on commit 5e55b73

Please sign in to comment.