Skip to content

Commit

Permalink
[CINN]fix bug in cinn stack op_mapper (PaddlePaddle#57307)
Browse files Browse the repository at this point in the history
fix bug in cinn stack op_mapper, the output shape is wrong when axis < 0
  • Loading branch information
lanxianghit authored Sep 18, 2023
1 parent 169b785 commit 1095f9e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 32 deletions.
47 changes: 15 additions & 32 deletions paddle/cinn/frontend/op_mappers/paddle/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,23 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc,
"but here cannot found! Please check.";
}

cinn::utils::ShapeType input_shape(ctx.GetVar(x_names.front())->shape);
auto axis = utils::GetAttrOrDefault<int>(op_desc, "axis", 0);
axis = axis >= 0 ? axis : axis + input_shape.size() + 1;
cinn::utils::ShapeType output_shape(input_shape);
output_shape.insert(output_shape.begin() + axis, 1);

std::vector<Variable> xs;
for (const auto& name : x_names) {
xs.emplace_back(ctx.GetVar(name));
auto x = ctx.GetVar(name);
CHECK(x->shape == input_shape)
<< "All input shape of [stack] should be the same, be the input "
<< x->id << "'s shape [" << cinn::utils::Join(x->shape, ", ")
<< "] not equal to "
<< "the first input " << ctx.GetVar(x_names.front())->id << "'s shape ["
<< cinn::utils::Join(input_shape, ", ") << "]";

xs.emplace_back(ctx.Builder()->Reshape(x, output_shape));
}

auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) {
Expand All @@ -83,39 +95,10 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc,
<< "] not equal to the first input " << xs.front()->id << "'s dtype ["
<< xs.front()->type << "]";

err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) {
return x->shape != xs.front()->shape;
});
CHECK(err_x == xs.end())
<< "All input shape of [stack] should be the same, be the input "
<< (*err_x)->id << "'s shape ["
<< cinn::utils::Join((*err_x)->shape, ", ") << "] not equal to "
<< "the first input " << xs.front()->id << "'s shape ["
<< cinn::utils::Join(xs.front()->shape, ", ") << "]";

auto concat_out = ctx.Builder()->Concat(xs, axis);

int rank = concat_out->shape.size();
axis = axis >= 0 ? axis : axis + rank;
CHECK(axis >= 0 && axis < rank)
<< "The axis of stack should >=0 and <rank(x)! Please check.";

// N * [A, B] with axis=0 --> [N, A, B]; N * [A, B] with axis=1 --> [A, N, B];
cinn::utils::ShapeType new_shape;
for (int i = 0; i < rank; ++i) {
auto dim = concat_out->shape[i];
if (i != axis) {
new_shape.emplace_back(dim);
} else {
new_shape.emplace_back(xs.size());
// the shape same ensure `dim % xs.size() == 0`
new_shape.emplace_back(dim / xs.size());
}
}
auto out = ctx.Builder()->Reshape(concat_out, new_shape);

ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
ctx.AddVar(out_name, concat_out);
ctx.AddVarModelToProgram(out_name, concat_out->id);
}

void SplitOpMapper(const paddle::cpp::OpDesc& op_desc,
Expand Down
5 changes: 5 additions & 0 deletions test/cinn/op_mappers/test_stack_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,10 @@ def test_check_results(self):
self.check_outputs_and_grads()


class TestStackOpAxisNegative(TestStackOp):
def set_op_attrs(self):
return {"axis": -1}


if __name__ == "__main__":
unittest.main()

0 comments on commit 1095f9e

Please sign in to comment.