Skip to content

Commit

Permalink
[Relay/TOPI] Added dilation_value attribute to dilate operator. (apac…
Browse files Browse the repository at this point in the history
…he#6550)

* Added dilation_value attribute to dilate operator of Relay/TOPI.
  (Enables custom value for dilation, instead of always 0)
* Added tests for dilation_value of dilate operator in Relay and TOPI.
  • Loading branch information
jainris authored Sep 25, 2020
1 parent 0c3efc2 commit 63d203c
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 23 deletions.
2 changes: 2 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,11 +596,13 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
/*! \brief Attributes used in dilate operator */
struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
Array<IndexExpr> strides;
double dilation_value;

TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1}))
.describe("Dilation stride on each dimension, 1 means no dilation.");
TVM_ATTR_FIELD(dilation_value).set_default(0.0).describe("Value used to dilate the input.");
}
};

Expand Down
10 changes: 6 additions & 4 deletions include/tvm/topi/nn/dilate.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,20 @@ PrimExpr all(Array<PrimExpr> args) {
}

/*!
* \brief Dilate data with zeros
* \brief Dilate data with given dilation value (0 by default).
*
* \param x The input tensor, this can have any number of
* dimensions and any layout.
* \param strides Dilation stride for each dimension. Stride 1
* means no dilation.
* \param dilation_value Value used to dilate the input.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The output tensor.
*/
inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name = "tensor",
std::string tag = kInjective) {
inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_value,
std::string name = "tensor", std::string tag = kInjective) {
auto n = x->shape.size();
CHECK_EQ(n, strides.size()) << "strides size (" << strides.size()
<< ") must match dimension of x (" << n << ")";
Expand All @@ -94,7 +95,8 @@ inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name
}
if (not_zero.size() > 0) {
auto all_not_zero = all(not_zero);
return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
return tvm::if_then_else(all_not_zero, x(index_tuple),
make_const(x->dtype, dilation_value));
}
return x(index_tuple);
},
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
# dilate
@reg.register_compute("nn.dilate")
def compute_dilate(attrs, inputs, out_dtype):
return [topi.nn.dilate(inputs[0], attrs.strides)]
return [topi.nn.dilate(inputs[0], attrs.strides, attrs.dilation_value)]


reg.register_broadcast_schedule("nn.dilate")
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,23 +1549,26 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
return _make.pad(data, pad_width, pad_value, pad_mode)


def dilate(data, strides):
"""Dilate data with zeros.
def dilate(data, strides, dilation_value=0.0):
"""Dilate data with given dilation value (0 by default).
Parameters
----------
data : tvm.relay.Expr
n-D, can be any layout.
strides : <tuple of <int>
strides : tuple of <int>
Dilation stride on each dimension, 1 means no dilation.
dilation_value : int/float, optional
Value used to dilate the input.
Returns
-------
Output : tvm.relay.Expr
The computed result
"""
return _make.dilate(data, strides)
return _make.dilate(data, strides, dilation_value)


def mirror_pad(data, pad_width, mode="SYMMETRIC"):
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/topi/nn/dilate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@


@te.tag_scope(tag=tag.INJECTIVE + ",dilate")
def dilate(data, strides, name="DilatedInput"):
"""Dilate data with zeros.
def dilate(data, strides, dilation_value=0.0, name="DilatedInput"):
"""Dilate data with given dilation value (0 by default).
Parameters
----------
Expand All @@ -34,6 +34,9 @@ def dilate(data, strides, name="DilatedInput"):
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.
dilation_value : int/float, optional
Value used to dilate the input.
name : str, optional
The name prefix operators generated
Expand Down Expand Up @@ -62,7 +65,7 @@ def _dilate(*indices):
if not_zero:
not_zero = tvm.tir.all(*not_zero)
return tvm.tir.if_then_else(
not_zero, data(*index_tuple), tvm.tir.const(0.0, data.dtype)
not_zero, data(*index_tuple), tvm.tir.const(dilation_value, data.dtype)
)
return data(*index_tuple)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/topi/testing/dilate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def dilate_python(input_np, strides):
def dilate_python(input_np, strides, dilation_value=0.0):
"""Dilate operation.
Parameters
Expand All @@ -30,6 +30,9 @@ def dilate_python(input_np, strides):
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.
dilation_value : int/float, optional
Value used to dilate the input.
Returns
-------
output_np : numpy.ndarray
Expand All @@ -45,7 +48,8 @@ def dilate_python(input_np, strides):
for i in range(n):
output_size += ((input_np.shape[i] - 1) * strides[i] + 1,)
no_zero += ((range(0, output_size[i], strides[i])),)
output_np = np.zeros(shape=output_size)
output_np = np.ones(shape=output_size)
output_np = dilation_value * output_np
output_np[np.ix_(*no_zero)] = input_np

return output_np
5 changes: 3 additions & 2 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,10 @@ bool DilateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

// Positional relay function to create dilate operator used by frontend FFI.
Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
Expr MakeDilate(Expr data, Array<IndexExpr> strides, double dilation_value = 0.0) {
auto attrs = make_object<DilateAttrs>();
attrs->strides = std::move(strides);
attrs->dilation_value = std::move(dilation_value);
static const Op& op = Op::Get("nn.dilate");
return Call(op, {data}, Attrs(attrs), {});
}
Expand All @@ -972,7 +973,7 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate);

RELAY_REGISTER_OP("nn.dilate")
.describe(R"code(
Dilate data with zeros.
Dilate data with given dilation value (0 by default).
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("x", "1D Tensor", "Data to dilate.")
Expand Down
2 changes: 1 addition & 1 deletion src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValu

/* Ops from nn/dilate.h */
TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::dilate(args[0], args[1]);
*rv = nn::dilate(args[0], args[1], args[2]);
});

/* Ops from nn/flatten.h */
Expand Down
13 changes: 10 additions & 3 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,18 +740,24 @@ def test_any_pad():
verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))


def verify_any_dilate(data_shape, strides, static_data_shape):
def verify_any_dilate(data_shape, strides, static_data_shape, dilation_value=None):
assert len(data_shape) == len(strides)
mod = tvm.IRModule()
dtype = "float32"
data = relay.var("data", shape=data_shape, dtype=dtype)
y = relay.nn.dilate(data, strides)
if dilation_value is None:
y = relay.nn.dilate(data, strides)
else:
y = relay.nn.dilate(data, strides, dilation_value)
mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
ref_shape = tuple(
(static_data_shape[i] - 1) * strides[i] + 1 for i in range(len(static_data_shape))
)
ref_out = np.zeros(shape=ref_shape, dtype=dtype)
if dilation_value is None:
dilation_value = 0.0
ref_out = np.ones(shape=ref_shape, dtype=dtype)
ref_out = dilation_value * ref_out
ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np
check_result([data_np], mod, ref_out)

Expand All @@ -766,6 +772,7 @@ def test_any_dilate():
verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3))
verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3))
verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4))
verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4), 1.0)


def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
Expand Down
13 changes: 10 additions & 3 deletions tests/python/topi/python/test_topi_dilate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@ def test_dilate():
target = "llvm"
ctx = tvm.cpu(0)

def _test_dilate(input_size, strides):
def _test_dilate(input_size, strides, dilation_value=None):
Input = te.placeholder((input_size))
Output = topi.nn.dilate(Input, strides)
if dilation_value is None:
Output = topi.nn.dilate(Input, strides)
else:
Output = topi.nn.dilate(Input, strides, dilation_value)
schedule = te.create_schedule(Output.op)
input_np = np.random.uniform(size=input_size).astype(Input.dtype)
output_np = tvm.topi.testing.dilate_python(input_np, strides)
if dilation_value is None:
output_np = tvm.topi.testing.dilate_python(input_np, strides)
else:
output_np = tvm.topi.testing.dilate_python(input_np, strides, dilation_value)
input_tvm = tvm.nd.array(input_np, ctx=ctx)
output_size = topi.util.get_const_tuple(Output.shape)
output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx)
Expand All @@ -47,6 +53,7 @@ def _test_dilate(input_size, strides):
_test_dilate((1, 32, 32, 3, 3), (2, 2, 2, 2, 2))
_test_dilate((1, 32, 32, 32, 3, 3), (1, 1, 1, 2, 2, 2))
_test_dilate((1, 32, 32, 32, 3, 3), (2, 2, 2, 1, 1, 1))
_test_dilate((1, 32, 32, 32, 3, 3), (2, 2, 2, 1, 1, 1), 1.0)


if __name__ == "__main__":
Expand Down

0 comments on commit 63d203c

Please sign in to comment.