Skip to content

Commit

Permalink
[BYOC][ACL] Improved pooling support (apache#6248)
Browse files Browse the repository at this point in the history
* [BYOC][ACL] Improved pooling support

Adds support in ACL for the following relay pooling operators and composite functions:
  * nn.avg_pool2d (fp32), cast + nn.avg_pool2d(uint8) + cast => AVG pool
  * nn.global_max_pool2d => Global MAX pool
  * nn.global_avg_pool2d, cast + nn.global_avg_pool2d(uint8) + cast => Global AVG pool
  * power(2) + nn.avg_pool2d + sqrt => L2 pooling (for fp32 only)

Tests updated to reflect these changes.

Change-Id: I1644b67b60ebb252344eb9695a521d2d958c724e

* Address comments

Change-Id: Ibe8a61b4c42da246ce54701c89ea985b423c8f83

* Fix not checking output saturation

Change-Id: Ia6f3d9db31cfb8c417d8556d29961210fea418b2

* Use defined set of trials

Change-Id: Ib180e3a0cbb84d6fa00c7e1994f58cb62662db15

* Rebase master

Change-Id: I5c932751cd38da06d6f2b397be5d8ab7fdeb169f
  • Loading branch information
lhutton1 authored Aug 27, 2020
1 parent 44d97ad commit c958bc1
Show file tree
Hide file tree
Showing 11 changed files with 506 additions and 107 deletions.
69 changes: 44 additions & 25 deletions docs/deploy/arm_compute_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,31 +188,50 @@ An example configuration for `test_config.json`:
Operator support
----------------
+--------------+-------------------------------------------------------------------------+
| Relay Node | Remarks |
+==============+=========================================================================+
| nn.conv2d | fp32: |
| | Simple: nn.conv2d |
| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu? |
| | |
| | (only groups = 1 supported) |
+--------------+-------------------------------------------------------------------------+
| qnn.conv2d | uint8: |
| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize |
| | |
| | (only groups = 1 supported) |
+--------------+-------------------------------------------------------------------------+
| nn.dense | fp32: |
| | Simple: nn.dense |
| | Composite: nn.dense, nn.bias_add? |
+--------------+-------------------------------------------------------------------------+
| qnn.dense | uint8: |
| | Composite: qnn.dense, nn.bias_add?, qnn.requantize |
+--------------+-------------------------------------------------------------------------+
| nn.maxpool2d | fp32, uint8 |
+--------------+-------------------------------------------------------------------------+
| reshape | fp32, uint8 |
+--------------+-------------------------------------------------------------------------+
+----------------------+-------------------------------------------------------------------------+
| Relay Node | Remarks |
+======================+=========================================================================+
| nn.conv2d | fp32: |
| | Simple: nn.conv2d |
| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu? |
| | |
| | (only groups = 1 supported) |
+----------------------+-------------------------------------------------------------------------+
| qnn.conv2d | uint8: |
| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize |
| | |
| | (only groups = 1 supported) |
+----------------------+-------------------------------------------------------------------------+
| nn.dense | fp32: |
| | Simple: nn.dense |
| | Composite: nn.dense, nn.bias_add? |
+----------------------+-------------------------------------------------------------------------+
| qnn.dense | uint8: |
| | Composite: qnn.dense, nn.bias_add?, qnn.requantize |
+----------------------+-------------------------------------------------------------------------+
| nn.max_pool2d | fp32, uint8 |
+----------------------+-------------------------------------------------------------------------+
| nn.global_max_pool2d | fp32, uint8 |
+----------------------+-------------------------------------------------------------------------+
| nn.avg_pool2d | fp32: |
| | Simple: nn.avg_pool2d |
| | |
| | uint8: |
| | Composite: cast(int32), nn.avg_pool2d, cast(uint8) |
+----------------------+-------------------------------------------------------------------------+
| nn.global_avg_pool2d | fp32: |
| | Simple: nn.global_avg_pool2d |
| | |
| | uint8: |
| | Composite: cast(int32), nn.avg_pool2d, cast(uint8) |
+----------------------+-------------------------------------------------------------------------+
| power(of 2) + | A special case for L2 pooling. |
| nn.avg_pool2d + | |
| sqrt | fp32: |
| | Composite: power(of 2), nn.avg_pool2d, sqrt |
+----------------------+-------------------------------------------------------------------------+
| reshape | fp32, uint8 |
+----------------------+-------------------------------------------------------------------------+

.. note::
A composite operator is a series of operators that map to a single Arm Compute Library operator. You can view this
Expand Down
86 changes: 84 additions & 2 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
# pylint: disable=invalid-name, unused-argument
"""Arm Compute Library supported operators."""
import tvm
from tvm.relay.expr import const
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import wildcard, is_op, is_constant
from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
from .register import register_pattern_table


Expand Down Expand Up @@ -125,6 +126,33 @@ def qnn_dense_pattern():
pattern, wildcard(), wildcard(), is_constant(), is_constant())
return pattern

def avg_pool2d_pattern():
"""Creates a pattern that matches either quantized
avg_pool2d or quantized global_avg_pool2d.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the convolution pattern.
"""
pattern = is_op('cast')(wildcard())
pattern = is_op('nn.avg_pool2d')(pattern) | is_op('nn.global_avg_pool2d')(pattern)
pattern = is_op('cast')(pattern)
return pattern

def l2_pool2d_pattern():
"""Create an l2 pooling pattern from equivalent relay operators.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the convolution pattern.
"""
pattern = is_op('power')(wildcard(), is_expr(const(2.0)))
pattern = is_op('nn.avg_pool2d')(pattern)
pattern = is_op('sqrt')(pattern)
return pattern

def check_conv(extract):
"""Check conv pattern is supported by ACL."""
call = extract
Expand Down Expand Up @@ -157,10 +185,27 @@ def check_qnn_dense(extract):
call = call.args[0]
return qnn_dense(call.attrs, call.args)

def check_avg_pool2d(extract):
"""Check average pool2d pattern is supported by ACL."""
if extract.attrs.dtype != "uint8":
return False
pool = extract.args[0]
if pool.args[0].attrs.dtype != "int32":
return False
return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True)

def check_l2_pool2d(extract):
"""Check l2 pool2d pattern is supported by ACL."""
pool = extract.args[0]
return avg_pool2d(pool.attrs, pool.args)

return [('arm_compute_lib.conv2d', conv_pattern(), check_conv),
('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
('arm_compute_lib.dense', dense_pattern(), check_dense),
('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense)]
('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense),
('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
('arm_compute_lib.avg_pool2d', avg_pool2d_pattern(), check_avg_pool2d),
('arm_compute_lib.l2_pool2d', l2_pool2d_pattern(), check_l2_pool2d)]


def _register_external_op_helper(op_name, supported=True):
Expand Down Expand Up @@ -245,3 +290,40 @@ def max_pool2d(attrs, args):
if typ.dtype not in ["float32", "uint8"]:
return False
return True


@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
def avg_pool2d(attrs, args, from_quantized_composite=False):
"""Check if the external ACL codegen for avgpool2d should be used."""
typ = args[0].checked_type
if from_quantized_composite:
if typ.dtype != "int32":
return False
else:
if typ.dtype not in ["float32"]:
return False
if attrs.layout != "NHWC":
return False
return True


@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
def global_max_pool2d(attrs, args):
"""Check if the external ACL codegen for gloval_maxpool2d should be used."""
typ = args[0].checked_type
if typ.dtype not in ["float32", "uint8"]:
return False
if attrs.layout != "NHWC":
return False
return True


@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
def global_avg_pool2d(attrs, args):
"""Check if the external ACL codegen for global_avgpool2d should be used."""
typ = args[0].checked_type
if typ.dtype not in ["float32"]:
return False
if attrs.layout != "NHWC":
return False
return True
60 changes: 60 additions & 0 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
json_node = CreateCompositeConvJSONNode(cn);
} else if (name == "arm_compute_lib.dense" || name == "arm_compute_lib.qnn_dense") {
json_node = CreateCompositeDenseJSONNode(cn);
} else if (name == "arm_compute_lib.avg_pool2d") {
json_node = CreateCompositeAvgPool2DJSONNode(cn);
} else if (name == "arm_compute_lib.l2_pool2d") {
json_node = CreateCompositeL2Pool2DJSONNode(cn);
} else {
LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
}
Expand Down Expand Up @@ -267,6 +271,62 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
SetCallNodeAttribute(json_node, nodes.dense);
return json_node;
}

/*!
* \brief Create a JSON representation of a composite (global) average pooling operator.
*
* A composite function is only created when using the uint8 datatype for these operators.
*
* \param cn The call to be represented.
* \return A JSON representation of a specific operator.
*/
std::shared_ptr<JSONGraphNode> CreateCompositeAvgPool2DJSONNode(const CallNode* cn) {
const auto* fn = cn->op.as<FunctionNode>();
CHECK(fn);
const auto* cast = fn->body.as<CallNode>();
CHECK(cast);
const auto* avg_pool = cast->args[0].as<CallNode>();
CHECK(avg_pool);
const auto* avg_pool_op = avg_pool->op.as<OpNode>();
CHECK(avg_pool_op);
const std::string name = avg_pool_op->name;

std::vector<JSONGraphNodeEntry> inputs;
inputs.push_back(VisitExpr(cn->args[0])[0]);
auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
SetCallNodeAttribute(json_node, avg_pool);
return json_node;
}

/*!
* \brief Create a JSON representation of a composite L2 pooling operator.
*
* \note Relay does not have an operator for L2 pooling, instead we can create
* an equivalent from power(2) + nn.avg_pool2d + sqrt.
*
* \param cn The call to be represented.
* \return A JSON representation of a specific operator.
*/
std::shared_ptr<JSONGraphNode> CreateCompositeL2Pool2DJSONNode(const CallNode* cn) {
const std::string name = "nn.l2_pool2d";
const auto* fn = cn->op.as<FunctionNode>();
CHECK(fn);
const auto* sqrt = fn->body.as<CallNode>();
CHECK(sqrt);
const auto* avg_pool = sqrt->args[0].as<CallNode>();
CHECK(avg_pool);
const auto* pow = avg_pool->args[0].as<CallNode>();
CHECK(pow);
const auto* exponent = pow->args[1].as<ConstantNode>();
CHECK(exponent);
CHECK_EQ(*static_cast<float*>(exponent->data->data), 2) << "Exponent must be 2 for L2 pooling";

std::vector<JSONGraphNodeEntry> inputs;
inputs.push_back(VisitExpr(cn->args[0])[0]);
auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
SetCallNodeAttribute(json_node, avg_pool);
return json_node;
}
};

/*!
Expand Down
54 changes: 50 additions & 4 deletions src/runtime/contrib/arm_compute_lib/acl_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,11 @@ class ACLRuntime : public JSONRuntimeBase {
} else if ("nn.dense" == op_name || "qnn.dense" == op_name) {
CreateFullyConnectedLayer(&layer_, node, mm);
num_pools++;
} else if ("nn.max_pool2d" == op_name) {
} else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name ||
"nn.l2_pool2d" == op_name) {
CreatePoolingLayer(&layer_, node);
} else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" == op_name) {
CreateGlobalPoolingLayer(&layer_, node);
} else if ("reshape" == op_name) {
CreateReshapeLayer(&layer_, node);
} else {
Expand Down Expand Up @@ -308,30 +311,73 @@ class ACLRuntime : public JSONRuntimeBase {
/*!
* \brief Create a pooling layer.
*
* \note Currently only maxpool is supported.
* \note Currently max_pool2d, avg_pool2d and L2 pooling are supported.
*
* \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
* \param node The JSON representation of the operator.
*/
void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) {
std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding");
std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides);
bool ceil_mode = std::stoi(node.GetAttr<std::vector<std::string>>("ceil_mode")[0]);
arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides, ceil_mode);

auto attr_pool_size = node.GetAttr<std::vector<std::string>>("pool_size");
int pool_size_h = std::stoi(attr_pool_size[0]);
int pool_size_w = std::stoi(attr_pool_size[1]);

// Only applies to average pool and l2 pool.
// ACL exclude pad option is inverse to Relays include pad option.
bool exclude_pad = false;
if (node.HasAttr("count_include_pad")) {
int count_include_pad =
std::stoi(node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
exclude_pad = !count_include_pad;
}

arm_compute::PoolingType pool_type;
if (node.GetOpName() == "nn.max_pool2d") {
pool_type = arm_compute::PoolingType::MAX;
} else if (node.GetOpName() == "nn.avg_pool2d") {
pool_type = arm_compute::PoolingType::AVG;
} else if (node.GetOpName() == "nn.l2_pool2d") {
pool_type = arm_compute::PoolingType::L2;
} else {
LOG(FATAL) << "Pooling type not supported";
}

arm_compute::PoolingLayerInfo pool_info =
arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w),
arm_compute::DataLayout::NHWC, pad_stride_info);
arm_compute::DataLayout::NHWC, pad_stride_info, exclude_pad);

layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
layer->outputs.push_back(MakeACLTensorFromJSONNode(node));

auto function = std::make_shared<arm_compute::NEPoolingLayer>();
function->configure(&layer->inputs[0], &layer->outputs[0], pool_info);
layer->function = function;
}

/*!
* \brief Create a global pooling layer.
*
* \note Currently global_max_pool2d and global_avg_pool2d are supported.
*
* \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
* \param node The JSON representation of the operator.
*/
void CreateGlobalPoolingLayer(CachedLayer* layer, const JSONGraphNode& node) {
arm_compute::PoolingType pool_type;
if (node.GetOpName() == "nn.global_max_pool2d") {
pool_type = arm_compute::PoolingType::MAX;
} else if (node.GetOpName() == "nn.global_avg_pool2d") {
pool_type = arm_compute::PoolingType::AVG;
} else {
LOG(FATAL) << "Pooling type not supported";
}

arm_compute::PoolingLayerInfo pool_info =
arm_compute::PoolingLayerInfo(pool_type, arm_compute::DataLayout::NHWC);

layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
Expand Down
10 changes: 8 additions & 2 deletions src/runtime/contrib/arm_compute_lib/acl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeACLMemoryManager() {
}

arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
const std::vector<std::string>& stride) {
const std::vector<std::string>& stride,
bool ceil_mode) {
int pad_0 = 0, pad_1 = 0, pad_2 = 0, pad_3 = 0;
int stride_0 = std::stoi(stride[0]), stride_1 = std::stoi(stride[1]);
auto dimensions_rounding = arm_compute::DimensionRoundingType::FLOOR;
size_t size = pad.size();
if (size == 1) {
int pad_v = std::stoi(pad[0]);
Expand All @@ -109,8 +111,12 @@ arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
LOG(FATAL) << "Unsupported padding dimensions";
}

if (ceil_mode) {
dimensions_rounding = arm_compute::DimensionRoundingType::CEIL;
}

return arm_compute::PadStrideInfo(stride_0, stride_1, pad_0, pad_1, pad_2, pad_3,
arm_compute::DimensionRoundingType::FLOOR);
dimensions_rounding);
}

arm_compute::DataType MakeACLDataType(const DLDataType& data_type) {
Expand Down
4 changes: 3 additions & 1 deletion src/runtime/contrib/arm_compute_lib/acl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeACLMemoryManager();
*
* \param pad The pad vector.
* \param stride The stride vector.
* \param ceil_mode Dimensions rounding.
* \return arm_compute::PadStrideInfo
*/
arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
const std::vector<std::string>& stride);
const std::vector<std::string>& stride,
bool ceil_mode = false);

/*!
* \brief Convert DLDataType to arm_compute::DataType.
Expand Down
Loading

0 comments on commit c958bc1

Please sign in to comment.