Skip to content

Commit

Permalink
[BYOC][ACL] Support add operation (apache#6532)
Browse files Browse the repository at this point in the history
* [BYOC][ACL] Support add operation

Added support for an "add" operation implemented via ACL
for fp32 and quantized uint8 data types

* Addressed lhutton1 comments

* linter
  • Loading branch information
d-smirnov authored Oct 11, 2020
1 parent bf21371 commit e561007
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 6 deletions.
4 changes: 4 additions & 0 deletions docs/deploy/arm_compute_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ Operator support
+----------------------+-------------------------------------------------------------------------+
| maximum | fp32 |
+----------------------+-------------------------------------------------------------------------+
| add | fp32 |
+----------------------+-------------------------------------------------------------------------+
| qnn.add | uint8 |
+----------------------+-------------------------------------------------------------------------+

.. note::
A composite operator is a series of operators that map to a single Arm Compute Library operator. You can view this
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,23 @@ def maximum(attrs, args):
type_a = args[0].checked_type
type_b = args[0].checked_type
return (type_a.dtype == "float32") and (type_b.dtype == "float32")


@tvm.ir.register_op_attr("add", "target.arm_compute_lib")
def add(attrs, args):
"""Check if the external ACL codegen for add should be used."""
for typ in [args[0].checked_type, args[1].checked_type]:
if typ.dtype != "float32":
return False

return True


@tvm.ir.register_op_attr("qnn.add", "target.arm_compute_lib")
def qnn_add(attrs, args):
"""Check if the external ACL codegen for add should be used."""
for typ in [args[0].checked_type, args[1].checked_type]:
if typ.dtype != "uint8":
return False

return True
42 changes: 36 additions & 6 deletions src/runtime/contrib/arm_compute_lib/acl_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/NEON/functions/NEArithmeticAddition.h>
#include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
#include <arm_compute/runtime/NEON/functions/NEElementwiseOperations.h>
#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
Expand Down Expand Up @@ -142,6 +143,8 @@ class ACLRuntime : public JSONRuntimeBase {
CreateReshapeLayer(&layer_, node);
} else if ("maximum" == op_name) {
CreateMaximumLayer(&layer_, node);
} else if ("add" == op_name || "qnn.add" == op_name) {
CreateAddLayer(&layer_, node);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
Expand Down Expand Up @@ -417,6 +420,36 @@ class ACLRuntime : public JSONRuntimeBase {
function->configure(&layer->inputs[0], &layer->inputs[1], &layer->outputs[0]);
layer->function = function;
}
/*!
* \brief Creates an add/qnn.add layer
*
* \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
* \param node The JSON representation of the operator.
*/
void CreateAddLayer(CachedLayer* layer, const JSONGraphNode& node) {
auto op_name = node.GetOpName();
if ("add" == op_name) {
layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[1]));
layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
} else if ("qnn.add" == op_name) {
layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0], &node.GetInputs()[2],
&node.GetInputs()[3]));
layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[1], &node.GetInputs()[4],
&node.GetInputs()[5]));
layer->outputs.push_back(
MakeACLTensorFromJSONNode(node, &node.GetInputs()[6], &node.GetInputs()[7]));
} else {
throw std::runtime_error("Unsupported form of add op: " + op_name);
}

auto f = std::make_shared<arm_compute::NEArithmeticAddition>();

// SATURATE is used as add_QASYMM8_QASYMM8_QASYMM8 always saturates result
f->configure(&layer->inputs[0], &layer->inputs[1], &layer->outputs[0],
arm_compute::ConvertPolicy::SATURATE);
layer->function = f;
}

/*! \brief Allow ACL functions to request auxiliary memory from TVM. */
ACLAllocator allocator_;
Expand All @@ -437,18 +470,15 @@ class ACLRuntime : public JSONRuntimeBase {
}
#endif
};

runtime::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json,
const Array<String>& const_names) {
auto n = make_object<ACLRuntime>(symbol_name, graph_json, const_names);
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.arm_compute_lib_runtime_create").set_body_typed(ACLRuntimeCreate);

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_arm_compute_lib")
.set_body_typed(JSONRuntimeBase::LoadFromBinary<ACLRuntime>);

} // namespace contrib
} // namespace runtime
} // namespace tvm
} // namespace contrib
} // namespace runtime
} // namespace tvm
133 changes: 133 additions & 0 deletions tests/python/contrib/test_arm_compute_lib/test_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Arm Compute Library integration reshape tests."""

import numpy as np

import tvm
import tvm.testing
from tvm import relay

from test_arm_compute_lib.infrastructure import (
skip_runtime_test,
skip_codegen_test,
build_and_run,
verify,
verify_codegen,
)
from test_arm_compute_lib.infrastructure import Device

_qnn_params = {
"lhs_scale": relay.const(0.0156863, "float32"),
"lhs_zero_point": relay.const(127, "int32"),
"rhs_scale": relay.const(0.0117647, "float32"),
"rhs_zero_point": relay.const(85, "int32"),
"output_scale": relay.const(0.0235294, "float32"),
"output_zero_point": relay.const(128, "int32"),
}


def _get_model(shape, dtype, var_names, op, op_params):
a = relay.var(next(var_names), shape=shape, dtype=dtype)
b = relay.var(next(var_names), shape=shape, dtype=dtype)
return op(a, b, **op_params)


def _get_expected_codegen(shape, dtype, op_name, qnn_params):
input_a = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
input_b = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
input_qnn = [
{
"op": "const",
"name": "",
"attrs": {
"shape": [[list(qnn_params[_].data.shape)]],
"dtype": [[qnn_params[_].data.dtype]],
},
}
for _ in qnn_params
]
inputs = [input_a, input_b, *input_qnn]
node = {
"op": "kernel",
"name": op_name,
"inputs": [[_, 0, 0] for _ in range(len(inputs))],
"attrs": {
"num_inputs": str(len(inputs)),
"num_outputs": "1",
"shape": [[list(shape)]],
"dtype": [[dtype]],
},
}

return [*inputs, node]


def test_runtime_add():
Device.load("test_config.json")

if skip_runtime_test():
return

device = Device()
np.random.seed(0)

for dtype, low, high, atol, rtol, op, op_params in [
("float32", -127, 128, 1e-7, 1e-7, relay.add, {}),
("uint8", 0, 255, 0.0, 1.0, relay.qnn.op.add, _qnn_params),
]:
shape = (2, 2)
for inputs in [
{
"a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
"b": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
}
]:
outputs = []
func = _get_model(shape, dtype, iter(inputs), op, op_params)
for acl in [True, False]:
outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0])

config = {
"shape": shape,
"dtype": dtype,
"inputs": inputs,
"operation": op,
"op_params": op_params,
}

verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=False)


def test_codegen_add():
if skip_codegen_test():
return

inputs = {"a", "b"}
for dtype, op_name, op, qnn_params in [
("float32", "add", relay.add, {}),
("uint8", "qnn.add", relay.qnn.op.add, _qnn_params),
]:
for shape in [(1, 1), (2, 2, 2), (3, 3, 3, 3)]:
func = _get_model(shape, dtype, iter(inputs), op, qnn_params)
exp_codegen = _get_expected_codegen(shape, dtype, op_name, qnn_params)
verify_codegen(func, exp_codegen, 1)


if __name__ == "__main__":
test_codegen_add()
test_runtime_add()

0 comments on commit e561007

Please sign in to comment.