Skip to content

Commit

Permalink
[Relay]Allow dynamic batch for arm conv2d (apache#6509)
Browse files Browse the repository at this point in the history
* Allow dynamic batch for arm conv2d

* Add TODO
  • Loading branch information
kevinthesun authored Sep 18, 2020
1 parent 28ea54a commit 1d6ee60
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 5 deletions.
42 changes: 42 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,48 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
#####################


# Hybrid-script shape function computing the output shape of a conv2d from the
# data shape (dshape) and the kernel shape (kshape).
# The indexing below reads dshape[2]/dshape[3] as input height/width,
# kshape[2]/kshape[3] as kernel height/width and kshape[0] as the number of
# output channels.
# NOTE(review): presumably this only matches NCHW data with OIHW kernels — confirm.
@script
def _conv2d_shape_func(dshape, kshape, strides, padding, dilation):
    # Output holds one int64 entry per entry of dshape.
    # NOTE(review): four entries are written below, so dshape is assumed to
    # describe a 4-D tensor — confirm.
    out = output_tensor((dshape.shape[0],), "int64")
    height = dshape[2]
    width = dshape[3]
    kheight = kshape[2]
    kwidth = kshape[3]
    # Extent covered by the kernel once dilation gaps are inserted.
    dilated_kh = (kheight - 1) * dilation[0] + 1
    dilated_kw = (kwidth - 1) * dilation[1] + 1

    oc = kshape[0]

    # NOTE(review): 2 * padding[i] assumes equal padding on both sides and that
    # padding[0]/padding[1] are the height/width pads — confirm this holds when
    # a 4-element padding is passed in.
    out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
    out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1

    # The first dimension is copied from the input unchanged.
    out[0] = dshape[0]
    out[1] = oc
    out[2] = out_height
    out[3] = out_width
    return out


@reg.register_shape_func("nn.conv2d", False)
def conv2d_shape_func(attrs, inputs, _):
    """
    Shape function for nn.conv2d op.

    Reads strides, padding and dilation from ``attrs`` as constant tuples,
    converts them, and forwards them together with ``inputs[0]`` and
    ``inputs[1]`` (used as the data and kernel shapes) to
    ``_conv2d_shape_func``.

    Returns
    -------
    A single-element list holding the result of ``_conv2d_shape_func``.
    """
    conv_params = [
        convert(get_const_tuple(attr))
        for attr in (attrs.strides, attrs.padding, attrs.dilation)
    ]
    return [_conv2d_shape_func(inputs[0], inputs[1], *conv_params)]


@script
def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
out = output_tensor((dshape.shape[0],), "int64")
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def _callback(op):

def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, tile_size):
N, CI, IH, IW = get_const_tuple(data.shape)
if isinstance(N, tvm.tir.Any):
N = tvm.te.size_var("n")
if not isinstance(IH, int) or not isinstance(IW, int):
raise RuntimeError("ARM winograd conv2d doesn't support dynamic input height or width.")

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
Expand Down Expand Up @@ -154,7 +158,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW

cfg.define_split("tile_p", cfg.axis(P), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
# TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
tile_p = P if isinstance(N, int) else nH * nW
cfg.define_split("tile_p", cfg.axis(tile_p), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
cfg.define_split("tile_k", cfg.axis(K), num_outputs=2, filter=lambda x: x.size[-1] <= 16)
VP = cfg["tile_p"].size[-1]
VK = cfg["tile_k"].size[-1]
Expand Down Expand Up @@ -236,7 +242,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
)

# we have to manually assign effective GFLOP for winograd
cfg.add_flop(2 * N * K * H * W * KH * KW * C)
if isinstance(N, int):
cfg.add_flop(2 * N * K * H * W * KH * KW * C)
return output


Expand Down
8 changes: 7 additions & 1 deletion python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation, out_
"""compute define for Conv2d Spatial Pack with NCHW layout"""
out_dtype = out_dtype or data.dtype
N, CI, IH, IW = get_const_tuple(data.shape)
if isinstance(N, tvm.tir.Any):
N = tvm.te.size_var("n")
if not isinstance(IH, int) or not isinstance(IW, int):
raise RuntimeError("ARM winograd conv2d doesn't support dynamic input height or width.")

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
Expand All @@ -54,7 +58,9 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation, out_
data_pad = nn.pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])

# ==================== define configuration space ====================
n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
# TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
n_tuning_axis = N if isinstance(N, int) else 1
n, co, oh, ow = cfg.axis(n_tuning_axis), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)

if num_tile == 2: # for arm cpu
Expand Down
9 changes: 8 additions & 1 deletion python/tvm/topi/arm_cpu/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def _decl_spatial_pack(
out_dtype = out_dtype or data.dtype

N, CI, IH, IW = get_const_tuple(data.shape)
if isinstance(N, tvm.tir.Any):
N = tvm.te.size_var("n")
if not isinstance(IH, int) or not isinstance(IW, int):
raise RuntimeError("ARM winograd conv2d doesn't support dynamic input height or width.")

_, CO, KH, KW = get_const_tuple(kernel.shape)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
opad_h, opad_w = output_padding
Expand All @@ -84,7 +89,9 @@ def _decl_spatial_pack(
data_pad = pad(dilated_input, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right])

# ==================== define configuration space ====================
n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
# TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
n_tuning_axis = N if isinstance(N, int) else 1
n, co, oh, ow = cfg.axis(n_tuning_axis), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)

if num_tile == 2: # for arm cpu
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/topi/nn/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
if len(pad_after) != n:
raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (n, len(pad_before)))
ana = tvm.arith.Analyzer()
out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))
dshape = []
for dim in data.shape:
if isinstance(dim, tvm.tir.Any):
dshape.append(tvm.te.size_var("dim"))
else:
dshape.append(dim)
out_shape = tuple(ana.simplify(dshape[i] + pad_before[i] + pad_after[i]) for i in range(n))
pad_value = (
pad_value
if isinstance(pad_value, tvm.tir.PrimExpr)
Expand Down
45 changes: 45 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,49 @@ def test_any_reshape_like():
check_result([data_np, shape_like_np], mod, shape_like_np.shape, assert_shape=True)


def verify_any_conv2d(
    data_shape,
    kernel_shape,
    strides,
    padding,
    dilation,
    static_data_shape,
    ref_out_shape,
):
    """Build a module with a single relay.nn.conv2d and check its result.

    ``data_shape`` and ``kernel_shape`` are the shapes the Relay variables are
    declared with, while ``static_data_shape`` is the concrete shape of the
    random float32 data fed in. ``ref_out_shape`` is passed to ``check_result``
    together with ``assert_shape=True``.
    """
    dtype = "float32"

    # Declare the graph inputs and the convolution under test.
    data = relay.var("data", shape=data_shape, dtype=dtype)
    kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype)
    conv = relay.nn.conv2d(
        data, kernel, strides, padding, dilation, kernel_size=kernel_shape[2:4]
    )

    mod = tvm.IRModule()
    mod["main"] = relay.Function([data, kernel], conv)

    # Random inputs are drawn in the same order as the function parameters.
    np_inputs = [
        np.random.uniform(size=shape).astype(dtype)
        for shape in (static_data_shape, kernel_shape)
    ]
    check_result(np_inputs, mod, ref_out_shape, assert_shape=True)


# TODO(@kevinthesun): Support dynamic input height and width.
# TODO(@kevinthesun): Support gpu to enable gpu tests.
def test_any_conv2d():
    """Run verify_any_conv2d on conv2d with a dynamic (relay.Any) batch dimension."""
    kernel_shape = (64, 64, 3, 3)
    strides = (1, 1)
    padding = (1, 1)

    # Each case: (dilation, static_data_shape, ref_out_shape)
    cases = [
        ((1, 1), (1, 64, 224, 224), (1, 64, 224, 224)),
        ((2, 2), (2, 64, 224, 224), (2, 64, 222, 222)),
    ]
    for dilation, static_data_shape, ref_out_shape in cases:
        verify_any_conv2d(
            (relay.Any(), 64, 224, 224),
            kernel_shape,
            strides,
            padding,
            dilation,
            static_data_shape,
            ref_out_shape,
        )


def verify_any_conv2d_NCHWc(
data_shape,
kernel_shape,
Expand Down Expand Up @@ -458,6 +501,7 @@ def verify_any_conv2d_NCHWc(


# TODO(@kevinthesun): Support dynamic input height and width.
# TODO(@kevinthesun): Support gpu to enable gpu tests.
def test_any_conv2d_NCHWc():
verify_any_conv2d_NCHWc(
(relay.Any(), 8, 224, 224, 8),
Expand Down Expand Up @@ -519,6 +563,7 @@ def verify_any_conv2d_transpose_nchw(


# TODO(@kevinthesun): Support dynamic input height and width.
# TODO(@kevinthesun): Support gpu to enable gpu tests.
def test_any_conv2d_transpose_nchw():
verify_any_conv2d_transpose_nchw(
(relay.Any(), 64, 224, 224),
Expand Down

0 comments on commit 1d6ee60

Please sign in to comment.