Skip to content

Commit

Permalink
Upsampling op support (dmlc#298)
Browse files Browse the repository at this point in the history
* add nnvm upsampling symbol

* add upsampling mxnet frontend

* add doc for upsampling op

* cleanup upsampling test

* minor fix

* use schedule_injective for upsampling

* upgrade tvm
  • Loading branch information
masahi authored and tqchen committed Jan 11, 2018
1 parent 0c08e4e commit a0f8612
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 1 deletion.
9 changes: 9 additions & 0 deletions include/nnvm/top/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,15 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
}
};

struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale;

DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.describe("upsampling scaling factor");
}
};

} // namespace top
} // namespace nnvm

Expand Down
7 changes: 7 additions & 0 deletions python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ def _softmax_output(inputs, attrs):
new_attrs['axis'] = 1
return _get_nnvm_op(op_name)(inputs[0], **new_attrs)

def _upsampling(inputs, attrs):
scale = attrs.get('scale')
new_attrs = {'scale':int(scale)}
return _get_nnvm_op('upsampling')(inputs[0], **new_attrs)


_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
Expand Down Expand Up @@ -231,6 +237,7 @@ def _softmax_output(inputs, attrs):
'min_axis' : _rename('min'),
'reshape' : _reshape,
'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling
}

def _convert_symbol(op_name, inputs, attrs,
Expand Down
15 changes: 15 additions & 0 deletions python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,18 @@ def schedule_global_avg_pool2d(_, outs, target):
return topi.generic.schedule_global_pool(outs)

reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("upsampling")
def compute_upsampling(attrs, inputs, _):
"""Compute definition of upsampling"""
scale = attrs.get_int("scale")
return topi.nn.upsampling(inputs[0], scale)

@reg.register_schedule("upsampling")
def schedule_upsampling(_, outs, target):
"""Compute definition of upsampling"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)

reg.register_pattern("upsampling", OpPattern.OUT_ELEMWISE_FUSABLE)
52 changes: 52 additions & 0 deletions src/top/nn/upsampling.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*!
* Copyright (c) 2017 by Contributors
* \file pooling.cc
* \brief Property def of pooling operators.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

DMLC_REGISTER_PARAMETER(UpSamplingParam);

inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[0];
if (dshape.ndim() == 0) return false;
TShape oshape = dshape;
oshape[2] = oshape[2] * param.scale;
oshape[3] = oshape[3] * param.scale;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}

NNVM_REGISTER_OP(upsampling)
.describe(R"(Perform nearest neighbor upsampling to input array.
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
- **out**: Output is 4D array of shape (batch_size, channels, in_height*scale, in_width*scale).
)" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(UpSamplingParam::__FIELDS__())
.set_attr_parser(ParamParser<UpSamplingParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>)
.set_attr<FInferShape>("FInferShape", UpSamplingInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);

} // namespace top
} // namespace nnvm
20 changes: 20 additions & 0 deletions tests/python/compiler/test_top_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,25 @@ def test_global_avg_pool2d():
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)


def test_upsampling():
x = sym.Variable("x")
scale = 2
y = sym.upsampling(x, scale=scale, name="y")
dtype = "float32"
dshape = (1, 16, 32, 32)
oshape = (1, 16, 32*scale, 32*scale)
shape_dict = {"x": dshape}
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = graph_runtime.create(graph, lib, ctx)
a_np = np.random.uniform(size=dshape).astype(dtype)
data = tvm.nd.array(a_np)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.upsampling_python(a_np, scale)
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)


if __name__ == "__main__":
test_conv2d()
test_grouped_conv2d()
Expand All @@ -156,3 +175,4 @@ def test_global_avg_pool2d():
test_avg_pool2d()
test_global_max_pool2d()
test_global_avg_pool2d()
test_upsampling()
2 changes: 1 addition & 1 deletion tvm
Submodule tvm updated from 528715 to 9800fe

0 comments on commit a0f8612

Please sign in to comment.