[Relay][Op] Fix Reshape Compute (apache#6396)
* Fix Reshape Compute

* Fix test

* Fix lint

* Fix lint

* Fix

* Fix lint

* Fix test

* Rebase test
kevinthesun authored Sep 11, 2020
1 parent aeef16d commit ecba2f3
Showing 3 changed files with 91 additions and 30 deletions.
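What the change does: ReshapeCompute used to rebuild the target shape from the checked output type, converting each dynamic (Any) dimension into a fresh variable via ToVar(); those variables are bound to nothing at compute time, which breaks reshape over dynamically shaped inputs. The commit factors the shape-inference logic of ReshapeRel into a shared infer_newshape helper and makes ReshapeCompute fall back to it, driven by the input tensor's own shape, whenever the checked output shape contains a dynamic dimension; it also adds a quick path for reshape_like, which carries no ReshapeAttrs. A minimal sketch of the pattern being fixed, condensed from the test_reshape_concat test added below (executor choice and shapes are illustrative; this assumes a TVM build with Relay VM support):

import numpy as np
import tvm
from tvm import relay

dtype = "float32"
# Inputs with known rank but dynamic (Any) dimensions.
d0 = relay.var("d0", shape=(relay.Any(), relay.Any()), dtype=dtype)
d1 = relay.var("d1", shape=(relay.Any(), relay.Any(), relay.Any()), dtype=dtype)
# Each reshape's checked output type is (?,), so the compute must recover
# the concrete shape from its input at runtime.
out = relay.op.concatenate([relay.op.reshape(d0, [-1]), relay.op.reshape(d1, [-1])], axis=0)
mod = tvm.IRModule()
mod["main"] = relay.Function([d0, d1], out)

ex = relay.create_executor("vm", mod=mod, target="llvm")
np_data0 = np.random.uniform(size=(4, 5)).astype(dtype)
np_data1 = np.random.uniform(size=(2, 5, 2)).astype(dtype)
res = ex.evaluate()(np_data0, np_data1).asnumpy()
np.testing.assert_allclose(
    res, np.concatenate([np_data0.reshape(-1), np_data1.reshape(-1)], axis=0))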
85 changes: 55 additions & 30 deletions src/relay/op/tensor/transform.cc
@@ -451,27 +451,17 @@ RELAY_REGISTER_OP("transpose")
 /* relay.reshape */
 TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 
-bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                const TypeReporter& reporter) {
+Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
   const auto* param = attrs.as<ReshapeAttrs>();
-  // types: [data, result]
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) {
-    CHECK(types[0].as<IncompleteTypeNode>())
-        << "reshape: expect input type to be TensorType but get " << types[0];
-    return false;
-  }
-
   Array<IndexExpr> oshape;
-  Array<IndexExpr> data_shape;
+  Array<IndexExpr> ishape;
   Array<Integer> newshape;
 
   if (param->reverse) {
-    data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+    ishape.Assign(data_shape.rbegin(), data_shape.rend());
     newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
   } else {
-    data_shape = data->shape;
+    ishape = data_shape;
     newshape = param->newshape;
   }
 
@@ -488,10 +478,10 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       ++src_idx;
     } else if (svalue == 0) {
       // keep same
-      CHECK_LT(src_idx, data_shape.size());
+      CHECK_LT(src_idx, ishape.size());
       used_input_dims.insert(src_idx);
       used_output_dims.insert(oshape.size());
-      oshape.push_back(data_shape[src_idx++]);
+      oshape.push_back(ishape[src_idx++]);
     } else if (svalue == -1) {
       // inference based on rest
       CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred";
@@ -500,18 +490,18 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       ++src_idx;
     } else if (svalue == -2) {
       // copy all remaining dims from source
-      while (src_idx < data_shape.size()) {
+      while (src_idx < ishape.size()) {
         used_input_dims.insert(src_idx);
         used_output_dims.insert(oshape.size());
-        oshape.push_back(data_shape[src_idx++]);
+        oshape.push_back(ishape[src_idx++]);
       }
     } else if (svalue == -3) {
       // merge two dims from source
-      CHECK_LT(src_idx + 1, data_shape.size());
+      CHECK_LT(src_idx + 1, ishape.size());
       used_input_dims.insert(src_idx);
-      IndexExpr d1 = data_shape[src_idx++];
+      IndexExpr d1 = ishape[src_idx++];
       used_input_dims.insert(src_idx);
-      IndexExpr d2 = data_shape[src_idx++];
+      IndexExpr d2 = ishape[src_idx++];
       used_output_dims.insert(oshape.size());
       if (d1.as<AnyNode>() || d2.as<AnyNode>()) {
         oshape.push_back(Any());
@@ -522,13 +512,13 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       // split the source dim s into two dims
       // read the left dim and then the right dim (either can be -1)
       CHECK_LT(i + 2, newshape.size());
-      CHECK_LT(src_idx, data_shape.size());
+      CHECK_LT(src_idx, ishape.size());
       used_input_dims.insert(src_idx);
-      IndexExpr d0 = data_shape[src_idx++];
+      IndexExpr d0 = ishape[src_idx++];
       Integer d1 = newshape[++i];
       Integer d2 = newshape[++i];
       if (d1->value == -1) {
-        CHECK(d2->value != -1) << "Split dims cannot both be -1.";
+        CHECK_NE(d2->value, -1) << "Split dims cannot both be -1.";
         used_output_dims.insert(oshape.size());
         if (d0.as<AnyNode>()) {
           oshape.push_back(Any());
@@ -552,21 +542,21 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         }
       }
     } else {
-      CHECK(false) << "Unsupported special value: " << svalue;
+      LOG(FATAL) << "Unsupported special value: " << svalue;
     }
   }
 
   if (infer_idx >= 0) {
     IndexExpr infer_dim = 1;
-    for (size_t i = 0; i < data_shape.size(); ++i) {
+    for (size_t i = 0; i < ishape.size(); ++i) {
       if (used_input_dims.count(i) != 0) {
         continue;
       }
-      if (data_shape[i].as<AnyNode>()) {
+      if (ishape[i].as<AnyNode>()) {
        infer_dim = Any();
        break;
      }
-      infer_dim *= data_shape[i];
+      infer_dim *= ishape[i];
     }
     if (!infer_dim.as<AnyNode>()) {
       for (size_t i = 0; i < oshape.size(); ++i) {
@@ -585,8 +575,32 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     oshape.Set(infer_idx, infer_dim);
   }
 
+  return oshape;
+}
+
+bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const auto* param = attrs.as<ReshapeAttrs>();
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "reshape: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  const auto& oshape = infer_newshape(data->shape, attrs);
+
   // Verify that the sum of dimensions in the output shape is the sum of
   // dimensions in the input shape
+  Array<IndexExpr> data_shape;
+  if (param->reverse) {
+    data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+  } else {
+    data_shape = data->shape;
+  }
+
   bool found_dynamic = false;
   int64_t oshape_sum = 1;
   for (auto& x : oshape) {
@@ -626,16 +640,27 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
 Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
+  // Quick path for reshape_like
+  if (!attrs.as<ReshapeAttrs>()) {
+    return {topi::reshape(inputs[0], inputs[1]->shape)};
+  }
+
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   CHECK(out_ttype != nullptr);
   Array<IndexExpr> newshape;
+  bool newshape_has_any = false;
   for (auto val : out_ttype->shape) {
-    if (val->IsInstance<tir::AnyNode>()) {
-      newshape.push_back(val.as<tir::AnyNode>()->ToVar());
+    if (val->IsInstance<tir::AnyNode>() || val->IsInstance<tir::VarNode>()) {
+      newshape_has_any = true;
+      break;
     } else {
       newshape.push_back(val);
     }
   }
+
+  if (newshape_has_any) {
+    newshape = infer_newshape(inputs[0]->shape, attrs);
+  }
   return {topi::reshape(inputs[0], newshape)};
 }
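The svalue branches above implement Relay reshape's special newshape values (the MXNet-style convention): a positive entry is taken as-is, 0 copies the matching input dim, -1 infers one dim from the remaining elements, -2 copies all remaining input dims, -3 merges the next two input dims, and -4 splits one input dim into the following two entries. A rough illustration via Relay type inference, with hypothetical shapes (the inferred_shape helper is made up for this example):

import tvm
from tvm import relay

def inferred_shape(in_shape, newshape):
    # Build reshape(x, newshape) and read back the statically inferred output type.
    x = relay.var("x", shape=in_shape, dtype="float32")
    func = relay.Function([x], relay.reshape(x, newshape))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    return mod["main"].body.checked_type.shape

print(inferred_shape((2, 3, 4), [0, -1]))           # [2, 12]: 0 copies dim 0, -1 infers the rest
print(inferred_shape((2, 3, 4), [-2]))              # [2, 3, 4]: -2 copies all remaining dims
print(inferred_shape((2, 3, 4), [-3, 4]))           # [6, 4]: -3 merges 2 and 3
print(inferred_shape((2, 6, 4), [0, -4, 2, 3, 0]))  # [2, 2, 3, 4]: -4 splits 6 into 2 and 3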
9 changes: 9 additions & 0 deletions src/relay/op/tensor/transform.h
@@ -180,6 +180,15 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
   return Array<Array<Layout>>{Array<Layout>(old_in_layouts.size(), ret), {ret}};
 }
 
+/*!
+ * \brief Infer output shape for reshape.
+ *
+ * \param data_shape The input data shape.
+ * \param attrs The attributes.
+ * \return Output shape.
+ */
+Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_TENSOR_TRANSFORM_H_
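Note: exposing infer_newshape in this header, rather than keeping it file-local in transform.cc, is what lets ReshapeCompute call it at compute time, and it leaves the same shape-inference logic available to other operators under src/relay/op that may need it.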
27 changes: 27 additions & 0 deletions tests/python/relay/test_any.py
@@ -865,5 +865,32 @@ def test_any_consecutive_broadcast():
               ((np_data2 + np_data3) - (np_data2 * np_data3))
     check_result([np_data0, np_data1, np_data2, np_data3], mod, ref_res)
 
+def test_reshape_concat():
+    dtype = "float32"
+    d0 = relay.var("d0", shape=any_dims(2), dtype=dtype)
+    d1 = relay.var("d1", shape=any_dims(3), dtype=dtype)
+    out = relay.op.concatenate([relay.op.reshape(d0, [-1]), relay.op.reshape(d1, [-1])], axis=0)
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([d0, d1], out)
+    np_data0 = np.random.uniform(size=(4, 5)).astype(dtype)
+    np_data1 = np.random.uniform(size=(2, 5, 2)).astype(dtype)
+    ref_res = np.concatenate([np.reshape(np_data0, [-1]), np.reshape(np_data1, [-1])], axis=0)
+    check_result([np_data0, np_data1], mod, ref_res)
+
+    d0 = relay.var("d0", shape=any_dims(2), dtype=dtype)
+    d1 = relay.var("d1", shape=any_dims(2), dtype=dtype)
+    s0 = relay.var("s0", shape=any_dims(3), dtype=dtype)
+    s1 = relay.var("s1", shape=any_dims(3), dtype=dtype)
+    out = relay.op.concatenate([relay.op.reshape_like(d0, s0), relay.op.reshape_like(d1, s1)], axis=0)
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([d0, d1, s0, s1], out)
+    np_data0 = np.random.uniform(size=(4, 5)).astype(dtype)
+    np_data1 = np.random.uniform(size=(8, 5)).astype(dtype)
+    np_shape_like0 = np.random.uniform(size=(2, 2, 5)).astype(dtype)
+    np_shape_like1 = np.random.uniform(size=(4, 2, 5)).astype(dtype)
+    ref_res = np.concatenate([np.reshape(np_data0, np_shape_like0.shape),
+                              np.reshape(np_data1, np_shape_like1.shape)], axis=0)
+    check_result([np_data0, np_data1, np_shape_like0, np_shape_like1], mod, ref_res)
+
 if __name__ == "__main__":
     pytest.main([__file__])
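To exercise only the new case, something like pytest.main([__file__, "-k", "reshape_concat"]) (or pytest -k reshape_concat tests/python/relay/test_any.py from the repository root) should work; check_result and any_dims are helpers defined earlier in this file, and the test assumes a TVM build with the Relay VM enabled.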
