Skip to content

Commit

Permalink
[onert] Frontend for ArgMinMax (Samsung#5146)
Browse files Browse the repository at this point in the history
* [onert] Frontend for ArgMinMax

- Support ArgMin on loader and nnapi frontend
- Set is_arg_max param
- Remove operand type validation duplication in loader: handled by the operation validator

Signed-off-by: Hyeongseok Oh <[email protected]>
  • Loading branch information
hseok-oh authored Nov 25, 2020
1 parent 9b94637 commit 9481500
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
29 changes: 11 additions & 18 deletions runtime/onert/frontend/base_loader/include/base_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ template <typename LoaderDomain> class BaseLoader
void loadOneHot(const Operator *op, ir::Graph &subg);
void loadIf(const Operator *op, ir::Graph &subg);
void loadWhile(const Operator *op, ir::Graph &subg);
void loadArgMax(const Operator *op, ir::Graph &subg);
void loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax);
void loadFusedBatchNorm(const Operator *op, ir::Graph &subg);
void loadLogSoftmax(const Operator *op, ir::Graph &subg);
void loadSpaceToDepth(const Operator *op, ir::Graph &subg);
Expand Down Expand Up @@ -1231,25 +1231,15 @@ void BaseLoader<LoaderDomain>::loadWhile(const Operator *op, ir::Graph &subg)
}

template <typename LoaderDomain>
/// @brief Load a TFLite ARG_MAX or ARG_MIN operator into the graph as an ArgMinMax operation.
/// @param op        Flatbuffer operator to load
/// @param subg      Graph receiving the new operation
/// @param is_argmax true for ARG_MAX, false for ARG_MIN
void BaseLoader<LoaderDomain>::loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax)
{
  ir::operation::ArgMinMax::Param param;
  // ArgMax and ArgMin carry their output type in different builtin-options tables,
  // so read from the table that matches the operator being loaded.
  const auto output_type = is_argmax ? op->builtin_options_as_ArgMaxOptions()->output_type()
                                     : op->builtin_options_as_ArgMinOptions()->output_type();
  // NOTE Output/axis type validation (int32/int64 only) is handled by the
  //      operation validator, not duplicated here.
  param.output_type = tensorTypeToDataType(output_type);
  param.is_arg_max = is_argmax;

  loadOperationTo<ir::operation::ArgMinMax>(op, subg, param);
}

template <typename LoaderDomain>
Expand Down Expand Up @@ -1509,7 +1499,10 @@ void BaseLoader<LoaderDomain>::loadOperation(const Operator *op, ir::Graph &subg
loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::NEG);
return;
case BuiltinOperator::BuiltinOperator_ARG_MAX:
loadArgMax(op, subg);
loadArgMinMax(op, subg, true);
return;
case BuiltinOperator::BuiltinOperator_ARG_MIN:
loadArgMinMax(op, subg, false);
return;
case BuiltinOperator::BuiltinOperator_LOG:
loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOG);
Expand Down
20 changes: 20 additions & 0 deletions runtime/onert/frontend/nnapi/wrapper/OperationFactory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1522,6 +1522,7 @@ OperationFactory::OperationFactory()
operation::ArgMinMax::Param param;
// NNAPI ARGMAX output type is always int32
param.output_type = DataType::INT32;
param.is_arg_max = true;

return new operation::ArgMinMax{inputs, outputs, param};
};
Expand All @@ -1530,6 +1531,25 @@ OperationFactory::OperationFactory()
// TODO Remove ANEURALNETWORKS_ARGMAX_EX
_map[ANEURALNETWORKS_ARGMAX_EX] = _map[ANEURALNETWORKS_ARGMAX];

// Factory for NNAPI ARG_MIN: builds an ArgMinMax operation with is_arg_max=false.
_map[ANEURALNETWORKS_ARGMIN] = [](const OperationFactory::Param &init_param, Operands &) {
// ARG_MIN takes exactly two inputs (tensor, axis) and produces one output
assert(init_param.input_count == 2 && init_param.output_count == 1);

OperandIndexSequence outputs{init_param.outputs[0]};

// Each input should be interpreted as follows:
//
// 0 -> Input Tensor Index
// 1 -> Axis Tensor Index
OperandIndexSequence inputs{init_param.inputs[0], init_param.inputs[1]};

operation::ArgMinMax::Param param;
// NNAPI ARGMIN output type is always int32
param.output_type = DataType::INT32;
// Distinguishes this operation from ARG_MAX in the shared ArgMinMax kernel
param.is_arg_max = false;

return new operation::ArgMinMax{inputs, outputs, param};
};

_map[ANEURALNETWORKS_DEQUANTIZE] =
getElementwiseUnaryGenerator(operation::ElementwiseUnary::Type::DEQUANTIZE);

Expand Down

0 comments on commit 9481500

Please sign in to comment.