Skip to content

Commit

Permalink
[onert] OperationValidator for ArgMax (Samsung#5116)
Browse files Browse the repository at this point in the history
* [onert] OperationValidator for ArgMax

- Enable OperationValidator for ArgMax operation
- Add negative test for ArgMax
- Enable ArgMax test for int64 output

Signed-off-by: Hyeongseok Oh <[email protected]>

* Check axis type and fix test axis type
  • Loading branch information
hseok-oh authored Nov 24, 2020
1 parent a71ecb8 commit 17d2995
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
14 changes: 14 additions & 0 deletions runtime/onert/core/src/ir/OperationValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ void OperationValidator::visit(const operation::AddN &node)
}
}

void OperationValidator::visit(const operation::ArgMax &node)
{
const auto input_index(node.getInputs().at(operation::ArgMax::Input::INPUT));
const auto axis_index(node.getInputs().at(operation::ArgMax::Input::AXIS));
const auto output_index(node.getOutputs().at(0));
const auto output_type = node.param().output_type;

OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8,
DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64}));
OP_REQUIRES(isValidType(output_index, output_type));
}

void OperationValidator::visit(const operation::BatchMatMul &node)
{
const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/src/ir/OperationValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class OperationValidator : public OperationVisitor

public:
void visit(const operation::AddN &node) override;
void visit(const operation::ArgMax &node) override;
void visit(const operation::BatchMatMul &node) override;
void visit(const operation::BatchToSpaceND &node) override;
void visit(const operation::BinaryArithmetic &node) override;
Expand Down
74 changes: 73 additions & 1 deletion tests/nnfw_api/src/one_op_tests/ArgMax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ TEST_F(GenModelTest, OneOp_ArgMax_Int64_AxisToConst)

_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->addTestCase(TestCaseData{}.addInput<float>({1, 4, 2, 3}).addOutput<int64_t>({1, 0}));
_context->setBackends({"acl_cl"});
_context->setBackends({"acl_cl", "cpu"});

SUCCEED();
}
Expand Down Expand Up @@ -113,3 +113,75 @@ TEST_F(GenModelTest, neg_OneOp_ArgMax_InvalidAxis1)

SUCCEED();
}

TEST_F(GenModelTest, neg_OneOp_ArgMax_InType)
{
CircleGen cgen;
const auto output_type = circle::TensorType::TensorType_INT32;
std::vector<int32_t> axis_data{4};
uint32_t axis_buf = cgen.addBuffer(axis_data);
int axis = cgen.addTensor({{1}, circle::TensorType::TensorType_INT32, axis_buf});
int in = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_BOOL});
int out = cgen.addTensor({{1, 2, 1}, output_type});
cgen.addOperatorArgMax({{in, axis}, {out}}, output_type);
cgen.setInputsAndOutputs({in}, {out});

_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->expectFailModelLoad();

SUCCEED();
}

TEST_F(GenModelTest, neg_OneOp_ArgMax_AxisType)
{
CircleGen cgen;
const auto output_type = circle::TensorType::TensorType_FLOAT32;
std::vector<float> axis_data{4};
uint32_t axis_buf = cgen.addBuffer(axis_data);
int axis = cgen.addTensor({{1}, circle::TensorType::TensorType_FLOAT32, axis_buf});
int in = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
int out = cgen.addTensor({{1, 2, 1}, output_type});
cgen.addOperatorArgMax({{in, axis}, {out}}, output_type);
cgen.setInputsAndOutputs({in}, {out});

_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->expectFailModelLoad();

SUCCEED();
}

TEST_F(GenModelTest, neg_OneOp_ArgMax_OutType)
{
CircleGen cgen;
const auto output_type = circle::TensorType::TensorType_FLOAT32;
std::vector<int32_t> axis_data{4};
uint32_t axis_buf = cgen.addBuffer(axis_data);
int axis = cgen.addTensor({{1}, circle::TensorType::TensorType_INT32, axis_buf});
int in = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
int out = cgen.addTensor({{1, 2, 1}, output_type});
cgen.addOperatorArgMax({{in, axis}, {out}}, output_type);
cgen.setInputsAndOutputs({in}, {out});

_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->expectFailModelLoad();

SUCCEED();
}

TEST_F(GenModelTest, neg_OneOp_ArgMax_paramType)
{
CircleGen cgen;
const auto output_type = circle::TensorType::TensorType_INT32;
std::vector<int32_t> axis_data{4};
uint32_t axis_buf = cgen.addBuffer(axis_data);
int axis = cgen.addTensor({{1}, circle::TensorType::TensorType_INT32, axis_buf});
int in = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
int out = cgen.addTensor({{1, 2, 1}, output_type});
cgen.addOperatorArgMax({{in, axis}, {out}}, circle::TensorType::TensorType_INT64);
cgen.setInputsAndOutputs({in}, {out});

_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->expectFailModelLoad();

SUCCEED();
}

0 comments on commit 17d2995

Please sign in to comment.