From ecba2f353237a7edcc634c9fe00c25ce5a77bcd4 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Thu, 10 Sep 2020 17:42:25 -0700
Subject: [PATCH] [Relay][Op] Fix Reshape Compute (#6396)

* Fix Reshape Compute

* Fix test

* Fix lint

* Fix lint

* Fix

* Fix lint

* Fix test

* Rebase test
---
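Reviewer note (this section sits after the "---" separator and is ignored by
git am): previously ReshapeCompute built its output shape from the checked
type, turning every Any() dimension into a fresh variable via ToVar(), so the
computed output shape was disconnected from the actual input; under dynamic
shapes this broke downstream ops such as concatenate. This patch factors the
shape logic of ReshapeRel out into infer_newshape() and reuses it in
ReshapeCompute whenever the checked output shape contains a dynamic
dimension; reshape_like (whose attrs are not ReshapeAttrs) takes a quick path
that reuses the second input's shape.

Below is a minimal sketch of the case this fixes, mirroring the new
test_reshape_concat. It assumes a TVM build with this patch applied and uses
the relay VM executor API as it existed at the time (create_executor with
ctx/target keywords, NDArray.asnumpy()):

    import numpy as np
    import tvm
    from tvm import relay

    # Two inputs whose every dimension is dynamic (relay.Any()).
    d0 = relay.var("d0", shape=(relay.Any(), relay.Any()), dtype="float32")
    d1 = relay.var("d1", shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32")
    # Flatten both and concatenate on axis 0; reshape to [-1] forces the
    # compute to derive the flat length from the runtime input shape.
    out = relay.concatenate([relay.reshape(d0, [-1]), relay.reshape(d1, [-1])], axis=0)
    mod = tvm.IRModule.from_expr(relay.Function([d0, d1], out))

    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    x = np.random.uniform(size=(4, 5)).astype("float32")
    y = np.random.uniform(size=(2, 5, 2)).astype("float32")
    res = ex.evaluate()(x, y)
    np.testing.assert_allclose(
        res.asnumpy(), np.concatenate([x.reshape(-1), y.reshape(-1)], axis=0))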
 src/relay/op/tensor/transform.cc | 85 +++++++++++++++++++++-----------
 src/relay/op/tensor/transform.h  |  9 ++++
 tests/python/relay/test_any.py   | 27 ++++++++++
 3 files changed, 91 insertions(+), 30 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index bf6ce4d27f99..88179b796888 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -451,27 +451,17 @@ RELAY_REGISTER_OP("transpose")
 /* relay.reshape */
 TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 
-bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                const TypeReporter& reporter) {
+Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
   const auto* param = attrs.as<ReshapeAttrs>();
-  // types: [data, result]
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) {
-    CHECK(types[0].as<IncompleteTypeNode>())
-        << "reshape: expect input type to be TensorType but get " << types[0];
-    return false;
-  }
-
   Array<IndexExpr> oshape;
-  Array<IndexExpr> data_shape;
+  Array<IndexExpr> ishape;
   Array<Integer> newshape;
 
   if (param->reverse) {
-    data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+    ishape.Assign(data_shape.rbegin(), data_shape.rend());
     newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
   } else {
-    data_shape = data->shape;
+    ishape = data_shape;
     newshape = param->newshape;
   }
 
@@ -488,10 +478,10 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       ++src_idx;
     } else if (svalue == 0) {
       // keep same
-      CHECK_LT(src_idx, data_shape.size());
+      CHECK_LT(src_idx, ishape.size());
       used_input_dims.insert(src_idx);
       used_output_dims.insert(oshape.size());
-      oshape.push_back(data_shape[src_idx++]);
+      oshape.push_back(ishape[src_idx++]);
     } else if (svalue == -1) {
       // inference based on rest
       CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred";
@@ -500,18 +490,18 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       ++src_idx;
     } else if (svalue == -2) {
       // copy all remaining dims from source
-      while (src_idx < data_shape.size()) {
+      while (src_idx < ishape.size()) {
         used_input_dims.insert(src_idx);
         used_output_dims.insert(oshape.size());
-        oshape.push_back(data_shape[src_idx++]);
+        oshape.push_back(ishape[src_idx++]);
       }
     } else if (svalue == -3) {
       // merge two dims from source
-      CHECK_LT(src_idx + 1, data_shape.size());
+      CHECK_LT(src_idx + 1, ishape.size());
       used_input_dims.insert(src_idx);
-      IndexExpr d1 = data_shape[src_idx++];
+      IndexExpr d1 = ishape[src_idx++];
       used_input_dims.insert(src_idx);
-      IndexExpr d2 = data_shape[src_idx++];
+      IndexExpr d2 = ishape[src_idx++];
       used_output_dims.insert(oshape.size());
       if (d1.as<AnyNode>() || d2.as<AnyNode>()) {
         oshape.push_back(Any());
@@ -522,13 +512,13 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       // split the source dim s into two dims
       // read the left dim and then the right dim (either can be -1)
       CHECK_LT(i + 2, newshape.size());
-      CHECK_LT(src_idx, data_shape.size());
+      CHECK_LT(src_idx, ishape.size());
       used_input_dims.insert(src_idx);
-      IndexExpr d0 = data_shape[src_idx++];
+      IndexExpr d0 = ishape[src_idx++];
       Integer d1 = newshape[++i];
       Integer d2 = newshape[++i];
       if (d1->value == -1) {
-        CHECK(d2->value != -1) << "Split dims cannot both be -1.";
+        CHECK_NE(d2->value, -1) << "Split dims cannot both be -1.";
         used_output_dims.insert(oshape.size());
         if (d0.as<AnyNode>()) {
           oshape.push_back(Any());
@@ -552,21 +542,21 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         }
       }
     } else {
-      CHECK(false) << "Unsupported special value: " << svalue;
+      LOG(FATAL) << "Unsupported special value: " << svalue;
     }
   }
 
   if (infer_idx >= 0) {
     IndexExpr infer_dim = 1;
-    for (size_t i = 0; i < data_shape.size(); ++i) {
+    for (size_t i = 0; i < ishape.size(); ++i) {
       if (used_input_dims.count(i) != 0) {
         continue;
       }
-      if (data_shape[i].as<AnyNode>()) {
+      if (ishape[i].as<AnyNode>()) {
         infer_dim = Any();
         break;
       }
-      infer_dim *= data_shape[i];
+      infer_dim *= ishape[i];
     }
     if (!infer_dim.as<AnyNode>()) {
       for (size_t i = 0; i < oshape.size(); ++i) {
@@ -585,8 +575,32 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     oshape.Set(infer_idx, infer_dim);
   }
 
+  return oshape;
+}
+
+bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const auto* param = attrs.as<ReshapeAttrs>();
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "reshape: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  const auto& oshape = infer_newshape(data->shape, attrs);
+
   // Verify that the sum of dimensions in the output shape is the sum of
   // dimensions in the input shape
+  Array<IndexExpr> data_shape;
+  if (param->reverse) {
+    data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+  } else {
+    data_shape = data->shape;
+  }
+
   bool found_dynamic = false;
   int64_t oshape_sum = 1;
   for (auto& x : oshape) {
@@ -626,16 +640,27 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
 Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
+  // Quick path for reshape_like
+  if (!attrs.as<ReshapeAttrs>()) {
+    return {topi::reshape(inputs[0], inputs[1]->shape)};
+  }
+
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   CHECK(out_ttype != nullptr);
   Array<IndexExpr> newshape;
+  bool newshape_has_any = false;
   for (auto val : out_ttype->shape) {
-    if (val->IsInstance<tir::AnyNode>()) {
-      newshape.push_back(val.as<tir::AnyNode>()->ToVar());
+    if (val->IsInstance<tir::AnyNode>() || val->IsInstance<tir::VarNode>()) {
+      newshape_has_any = true;
+      break;
     } else {
       newshape.push_back(val);
     }
   }
+
+  if (newshape_has_any) {
+    newshape = infer_newshape(inputs[0]->shape, attrs);
+  }
   return {topi::reshape(inputs[0], newshape)};
 }
 
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 4e5677a1af6d..0fe4734fe883 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -180,6 +180,15 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
   return Array<Array<Layout>>{Array<Layout>(old_in_layouts.size(), ret), {ret}};
 }
 
+/*!
+ * \brief Infer output shape for reshape.
+ *
+ * \param data_shape The input data shape.
+ * \param attrs The attributes.
+ * \return Output shape.
+ */
+Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_TENSOR_TRANSFORM_H_
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 1fa50396409e..6bb34d350253 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -865,5 +865,32 @@ def test_any_consecutive_broadcast():
         ((np_data2 + np_data3) - (np_data2 * np_data3))
     check_result([np_data0, np_data1, np_data2, np_data3], mod, ref_res)
 
+def test_reshape_concat():
+    dtype = "float32"
+    d0 = relay.var("d0", shape=any_dims(2), dtype=dtype)
+    d1 = relay.var("d1", shape=any_dims(3), dtype=dtype)
+    out = relay.op.concatenate([relay.op.reshape(d0, [-1]), relay.op.reshape(d1, [-1])], axis=0)
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([d0, d1], out)
+    np_data0 = np.random.uniform(size=(4, 5)).astype(dtype)
+    np_data1 = np.random.uniform(size=(2, 5, 2)).astype(dtype)
+    ref_res = np.concatenate([np.reshape(np_data0, [-1]), np.reshape(np_data1, [-1])], axis=0)
+    check_result([np_data0, np_data1], mod, ref_res)
+
+    d0 = relay.var("d0", shape=any_dims(2), dtype=dtype)
+    d1 = relay.var("d1", shape=any_dims(2), dtype=dtype)
+    s0 = relay.var("s0", shape=any_dims(3), dtype=dtype)
+    s1 = relay.var("s1", shape=any_dims(3), dtype=dtype)
+    out = relay.op.concatenate([relay.op.reshape_like(d0, s0), relay.op.reshape_like(d1, s1)], axis=0)
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([d0, d1, s0, s1], out)
+    np_data0 = np.random.uniform(size=(4, 5)).astype(dtype)
+    np_data1 = np.random.uniform(size=(8, 5)).astype(dtype)
+    np_shape_like0 = np.random.uniform(size=(2, 2, 5)).astype(dtype)
+    np_shape_like1 = np.random.uniform(size=(4, 2, 5)).astype(dtype)
+    ref_res = np.concatenate([np.reshape(np_data0, np_shape_like0.shape),
+                              np.reshape(np_data1, np_shape_like1.shape)], axis=0)
+    check_result([np_data0, np_data1, np_shape_like0, np_shape_like1], mod, ref_res)
+
 if __name__ == "__main__":
     pytest.main([__file__])
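
Note (not part of the patch): check_result and any_dims used by test_reshape_concat
are existing helpers defined earlier in tests/python/relay/test_any.py. To run only
the new test once the patch is applied, standard pytest selection works:

    python -m pytest tests/python/relay/test_any.py -k test_reshape_concat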