Skip to content

Commit

Permalink
Flip operator (dmlc#505)
Browse files Browse the repository at this point in the history
  • Loading branch information
PariksheetPinjari909 authored and tqchen committed May 25, 2018
1 parent 361a228 commit a9b896f
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/nnvm/top/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
}
};

struct FlipParam : public dmlc::Parameter<FlipParam> {
int axis;
DMLC_DECLARE_PARAMETER(FlipParam) {
DMLC_DECLARE_FIELD(axis).set_default(0)
.describe("the axis to be reveresed.");
}
};

struct BroadcastToParam : public dmlc::Parameter<BroadcastToParam> {
TShape shape;

Expand Down
4 changes: 4 additions & 0 deletions python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def compute_reshape_like(attrs, inputs, out_info):
reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective)

# flip
reg.register_pattern("flip", OpPattern.INJECTIVE)
reg.register_schedule("flip", _fschedule_injective)

# reshape
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_injective)
Expand Down
49 changes: 49 additions & 0 deletions src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -830,5 +830,54 @@ Examples::
};
});

// Flip
DMLC_REGISTER_PARAMETER(FlipParam);

NNVM_REGISTER_OP(flip)
.describe(R"code(Reverse the elements of an array.
Examples::
x = [[ 1, 2],
[ 3, 4]]
flip(x) = [[ 3., 4.],
[ 1., 2.]]
x = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
flip(x) = [[[ 5., 6.],
[ 7., 8.]],
[[ 1., 2.],
[ 3., 4.]]]
flip(x, axis=1) = [[[ 3., 4.],
[ 1., 2.]],
[[ 7., 8.],
[ 5., 6.]]]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Source input")
.add_arguments(FlipParam::__FIELDS__())
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FlipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const FlipParam& param = nnvm::get<FlipParam>(attrs.parsed);
return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
});

} // namespace top
} // namespace nnvm
23 changes: 23 additions & 0 deletions tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,28 @@ def test_reduce():
verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))

def verify_flip(ishape, axis):
x = sym.Variable("x")
y = sym.flip(x, axis=axis) + 1
dtype = "float32"
x_np = np.random.uniform(size=ishape).astype(dtype)
res = np.flip(x_np, axis) + 1

for target, ctx in ctx_list():
# set input
graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
m = graph_runtime.create(graph, lib, ctx)
m.run(x=x_np)
out = m.get_output(0, tvm.nd.empty(res.shape))
np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5)

def test_flip():
verify_flip((3, 4, 3), 1)
verify_flip((3, 4, 3), 0)
verify_flip((3, 4, 3), 2)
verify_flip((3, 4, 3), -1)
verify_flip((3, 4, 3), -3)
verify_flip((3, 4, 3), -2)

def verify_reshape(dshape, oshape):
x = sym.Variable("x")
Expand Down Expand Up @@ -347,4 +369,5 @@ def test_full():
test_elemwise_sum()
test_block_grad()
test_full()
test_flip()
print(nnvm.compiler.engine.dump())

0 comments on commit a9b896f

Please sign in to comment.