[Operator]Delete XShape for squeeze output (PaddlePaddle#67355)
* [Operator]Delete XShape for squeeze output

* fix kernel output_def not matching op definition

* fix spmd unittest

* fix UT

* fix squeeze_grad translator

* fix unused var
Aurelius84 authored Aug 14, 2024
1 parent 47aa76e commit 413d127
Showing 26 changed files with 301 additions and 193 deletions.
@@ -915,8 +915,9 @@ class SqueezeOpPattern
in_shape[i]));
}
}

ReplaceWithCinnReshapeOp(op, rewriter, output_shape);
auto cinn_reshape = rewriter.Build<cinn::dialect::ReshapeOp>(
op->operand_source(0), output_shape);
rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0));
rewriter.EraseOp(op);

return true;
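
The hunk above lowers pd_op.squeeze straight to a cinn reshape on the already-computed output_shape and rewires the single remaining result, now that squeeze no longer carries an XShape output. For reference, a standalone sketch of the squeeze shape rule that output_shape computation follows (hand-rolled C++, independent of Paddle/CINN, purely illustrative):

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Drop the listed axes when their extent is 1; with no axes given, drop every
// size-1 dimension (the conventional squeeze semantics).
std::vector<int64_t> SqueezeShape(const std::vector<int64_t>& in_shape,
                                  const std::vector<int64_t>& axes) {
  const int64_t rank = static_cast<int64_t>(in_shape.size());
  std::set<int64_t> drop;
  for (int64_t a : axes) {
    if (a < 0) a += rank;  // normalize negative axes
    if (in_shape[a] == 1) drop.insert(a);
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    const bool should_drop =
        axes.empty() ? in_shape[i] == 1 : drop.count(i) > 0;
    if (!should_drop) out.push_back(in_shape[i]);
  }
  return out;
}

int main() {
  for (int64_t d : SqueezeShape({1, 3, 1, 5}, {0, 2})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 3 5
}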
@@ -956,7 +957,6 @@ class UnsqueezeOpPattern
output_shape.push_back(1);
}
}

ReplaceWithCinnReshapeOp(op, rewriter, output_shape);
rewriter.EraseOp(op);

4 changes: 0 additions & 4 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -141,10 +141,8 @@
"imag",
"diagonal",
"flatten",
"flatten_infer",
"reshape",
"slice",
"squeeze_infer",
"squeeze",
"strided_slice",
"strided_slice_raw",
@@ -164,9 +162,7 @@
"real_",
"imag_",
"diagonal_",
"flatten_infer_",
"slice_",
"squeeze_infer_",
"strided_slice_",
"strided_slice_raw_",
"tensor_unfold_",
159 changes: 112 additions & 47 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -3591,6 +3591,92 @@ struct QuantizeLinearOpTranscriber : public OpTranscriber {
}
};

// NOTE(Dev): helper functions for WithXShapeGradOpTranscriber
static std::pair<pir::Value, pir::Value> ParseXAndOutGradValue(
const OpDesc& op_desc,
pir::IrContext* ctx,
pir::Builder* builder,
TranslationContext* param_map,
pir::Block* block) {
auto& input_xshape_name = op_desc.Input("XShape")[0];
auto& input_outgrad_name = op_desc.Input("Out@GRAD")[0];
pir::Value xshape_value;
VLOG(10) << "create data op for " << input_xshape_name;
auto var_desc = op_desc.Block()->FindVarRecursive(input_xshape_name);
auto dtype = ::phi::TransToPhiDataType(var_desc->GetDataType());
auto shape_vec = var_desc->GetShape();
// NOTE(dev): GradOp depends on X instead of XShape, so we need to
// erase the first element in xshape.
shape_vec.erase(shape_vec.begin());
xshape_value = builder
->Build<paddle::dialect::DataOp>(
input_xshape_name, shape_vec, dtype, phi::Place())
.result(0);

VLOG(10) << "create data op for " << input_xshape_name << " done";

if (param_map->Has(input_xshape_name)) {
auto value =
param_map->at(input_xshape_name).value.dyn_cast<pir::OpResult>();
auto* defining_op = value.owner();
value.ReplaceAllUsesWith(xshape_value);
param_map->PopValue(input_xshape_name);
defining_op->Erase();
}

param_map->PushValue(input_xshape_name, xshape_value);
PADDLE_ENFORCE_EQ(param_map->Has(input_outgrad_name),
true,
common::errors::InvalidArgument(
"Reshape2_Grad op does not have input Out@GRAD"));
auto input_outgrad_value_info = param_map->at(input_outgrad_name);
if (input_outgrad_value_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, block, input_outgrad_value_info, input_outgrad_name);
input_outgrad_value_info = param_map->at(input_outgrad_name);
}
pir::Value input_outgrad_value = input_outgrad_value_info.value;

PADDLE_ENFORCE_EQ(
input_outgrad_value.type().isa<paddle::dialect::DenseTensorType>(),
true,
::common::errors::InvalidArgument(
"input type must be DenseTensorType, but received: %s.",
input_outgrad_value.type()));

return std::make_pair(xshape_value, input_outgrad_value);
}
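
ParseXAndOutGradValue above rebuilds a stand-in for X by creating a DataOp from the XShape variable's dims with the first element erased. A tiny standalone check of the shape bookkeeping involved (plain C++, not Paddle code; the legacy convention of storing X's dims behind a leading placeholder 0 is inferred from the erase and slice_ddim calls in this PR):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Legacy squeeze2 convention: XShape records X's dims behind a leading 0.
  std::vector<int64_t> xshape_dims = {0, 4, 1, 8};
  // Stripping the first element recovers the shape handed to the DataOp.
  std::vector<int64_t> x_dims(xshape_dims.begin() + 1, xshape_dims.end());
  assert((x_dims == std::vector<int64_t>{4, 1, 8}));
  return 0;
}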

static pir::Value ParseAxis(const OpDesc& op_desc,
TranslationContext* param_map,
pir::IrContext* ctx,
pir::Block* block) {
// process axes
if (op_desc.HasInput("AxesTensor") && !op_desc.Input("AxesTensor").empty()) {
// get axis from input
auto axis_var_list = op_desc.Input("AxesTensor");
PADDLE_ENFORCE_EQ(
axis_var_list.size(),
1UL,
common::errors::InvalidArgument(
"axis tensor input of %s MUST be a tensor", op_desc.Type()));
auto axis_defining_info = (*param_map)[axis_var_list[0]];
return axis_defining_info.value;
} else if (op_desc.HasInput("AxesTensorList") &&
!op_desc.Input("AxesTensorList").empty()) {
auto* combine_op = InsertCombineOperationForTarget(
ctx, param_map, block, op_desc.Input("AxesTensorList"));
return combine_op->result(0);
} else {
auto& attribute_translator = AttributeTranslator::instance();
pir::Attribute new_attr = attribute_translator(
"paddle::dialect::IntArrayAttribute", op_desc.GetAttr("axes"));
auto full_array_op =
InsertFullArrayOperationForAttributeInput(ctx, block, new_attr);
return full_array_op->result(0);
}
}

template <typename OpT>
struct WithXShapeGradOpTranscriber : public OpTranscriber {
pir::Operation* operator()(pir::IrContext* ctx,
@@ -3599,53 +3685,9 @@ struct WithXShapeGradOpTranscriber : public OpTranscriber {
pir::Block* block) override {
VLOG(4) << "Translate " << op_desc.Type() << ".....";
pir::Builder builder(ctx, block);
auto& input_xshape_name = op_desc.Input("XShape")[0];
auto& input_outgrad_name = op_desc.Input("Out@GRAD")[0];
auto [xshape_value, input_outgrad_value] =
ParseXAndOutGradValue(op_desc, ctx, &builder, param_map, block);
auto& out_name = op_desc.Output("X@GRAD")[0];
pir::Value xshape_value;
VLOG(10) << "create data op for " << input_xshape_name;
auto var_desc = op_desc.Block()->FindVarRecursive(input_xshape_name);
auto dtype = ::phi::TransToPhiDataType(var_desc->GetDataType());
auto shape_vec = var_desc->GetShape();
shape_vec.erase(shape_vec.begin());
xshape_value = builder
.Build<paddle::dialect::DataOp>(
input_xshape_name, shape_vec, dtype, phi::Place())
.result(0);

VLOG(10) << "create data op for " << input_xshape_name << " done";

if (param_map->Has(input_xshape_name)) {
auto value =
param_map->at(input_xshape_name).value.dyn_cast<pir::OpResult>();
auto* defining_op = value.owner();
value.ReplaceAllUsesWith(xshape_value);
param_map->PopValue(input_xshape_name);
defining_op->Erase();
}

param_map->PushValue(input_xshape_name, xshape_value);
auto* defining_op = xshape_value.dyn_cast<pir::OpResult>().owner();
auto attr_map = defining_op->attributes();

PADDLE_ENFORCE_EQ(param_map->Has(input_outgrad_name),
true,
common::errors::InvalidArgument(
"Reshape2_Grad op does not have input Out@GRAD"));
auto input_outgrad_value_info = param_map->at(input_outgrad_name);
if (input_outgrad_value_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, block, input_outgrad_value_info, input_outgrad_name);
input_outgrad_value_info = param_map->at(input_outgrad_name);
}
pir::Value input_outgrad_value = input_outgrad_value_info.value;

PADDLE_ENFORCE_EQ(
input_outgrad_value.type().isa<paddle::dialect::DenseTensorType>(),
true,
::common::errors::InvalidArgument(
"input type must be DenseTensorType, but received: %s.",
input_outgrad_value.type()));
// NOTE(Aurelius84): Even though we use xshape to construct the grad op,
// the GradKernel still uses dx->dims by default.
OpT grad_op = builder.Build<OpT>(xshape_value, input_outgrad_value);
@@ -3655,6 +3697,28 @@ struct WithXShapeGradOpTranscriber : public OpTranscriber {
}
};

// NOTE(dev): In case of squeeze_grad and unsqueeze_grad
template <typename OpT>
struct WithXShapeAndAxisGradOpTranscriber : public OpTranscriber {
pir::Operation* operator()(pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
pir::Block* block) override {
VLOG(4) << "Translate " << op_desc.Type() << ".....";
pir::Builder builder(ctx, block);
auto [x_value, input_outgrad_value] =
ParseXAndOutGradValue(op_desc, ctx, &builder, param_map, block);
auto& out_name = op_desc.Output("X@GRAD")[0];
// NOTE(Aurelius84): Even though we use xshape to construct the grad op,
// the GradKernel still uses dx->dims by default.
pir::Value axis = ParseAxis(op_desc, param_map, ctx, block);
OpT grad_op = builder.Build<OpT>(x_value, input_outgrad_value, axis);
param_map->PushValue(out_name, grad_op.result(0));

return grad_op.operation();
}
};

OpTranslator::OpTranslator() {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
@@ -3752,7 +3816,8 @@ OpTranslator::OpTranslator() {
WithXShapeGradOpTranscriber<dialect::ReshapeGradOp>();
special_handlers["flatten_contiguous_range_grad"] =
WithXShapeGradOpTranscriber<dialect::FlattenGradOp>();
special_handlers["squeeze2_grad"] =
WithXShapeAndAxisGradOpTranscriber<dialect::SqueezeGradOp>();
}

} // namespace translator
} // namespace paddle
@@ -2366,8 +2366,6 @@ bool SqueezeOpInferSymbolicShape(

pir::Value res = op->result(0);
infer_context->SetShapeOrDataForValue(res, shape_data);
infer_context->SetShapeOrDataForValue(
op->result(1), CreateShapeOrDataForXShape(x_shape_or_data));

return true;
}
@@ -104,11 +104,10 @@ class FusedRotaryPositionEmbeddingPattern : public paddle::drr::DrrPatternBase {
const auto &concat_op_k = pat.Op(paddle::dialect::ConcatOp::name());
const auto &combine_k = pat.Op(pir::CombineOp::name());

squeeze({&pat.Tensor("cos"), &full_13()},
{&pat.Tensor("squeeze_out_cos"), &pat.Tensor("xshape")});
squeeze({&pat.Tensor("cos"), &full_13()}, {&pat.Tensor("squeeze_out_cos")});

squeeze_1({&pat.Tensor("sin"), &full_12()},
{&pat.Tensor("squeeze_out_sin"), &pat.Tensor("xshape")});
{&pat.Tensor("squeeze_out_sin")});

unsqueeze({&pat.Tensor("position_ids"), &full_11()},
{&pat.Tensor("unsqueeze_s_out_cos"), &pat.Tensor("xshape")});
@@ -38,8 +38,7 @@ class SqueezeTransposePattern : public paddle::drr::DrrPatternBase {
const auto &full_1 = pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_1_value")}});

squeeze({&pat.Tensor("x"), &full_1()},
{&pat.Tensor("squeeze_out"), &pat.Tensor("xshape")});
squeeze({&pat.Tensor("x"), &full_1()}, {&pat.Tensor("squeeze_out")});

const auto &transpose = pat.Op(paddle::dialect::TransposeOp::name(),
{{"perm", pat.Attr("perm")}});
8 changes: 3 additions & 5 deletions paddle/fluid/primitive/composite/composite.h
@@ -571,13 +571,11 @@ Tensor relu6_decomp(const Tensor& x) {
}

template <typename T>
std::tuple<Tensor, Tensor> squeeze_decomp(const Tensor& x,
const IntArray& axis) {
Tensor squeeze_decomp(const Tensor& x, const IntArray& axis) {
auto axis_ = process_dims(x, axis.GetData());
auto out_shape = get_squeeze_dims(x, axis_);
Tensor out = reshape<T>(x, out_shape);
Tensor xshape;
return std::make_tuple(out, xshape);
return out;
}

template <typename T>
@@ -1460,7 +1458,7 @@ Tensor embedding_decomp(const Tensor& x,
if (x.dims().size() <= 1) {
res = gather<T>(weight_tmp, x);
if (x.dims().size() == 0) {
res = std::get<0>(squeeze_decomp<T>(res, {0}));
res = squeeze_decomp<T>(res, {0});
}
} else {
std::vector<int64_t> tar_shape{-1};
2 changes: 1 addition & 1 deletion paddle/fluid/primitive/rule/vjp/details.h
@@ -1119,7 +1119,7 @@ void softmax_grad(const Tensor& out,
}

template <typename T>
void squeeze_grad(const Tensor& xshape,
void squeeze_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
Tensor* x_grad) {
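
The signature change above reflects that squeeze_grad only needs X for its dims: squeeze is a metadata-only reshape, so its VJP is out_grad viewed with X's shape. A minimal standalone sketch of that rule (hand-written C++, not Paddle code):

#include <cassert>
#include <cstdint>
#include <vector>

// Toy dense tensor: row-major buffer plus a shape.
struct Tensor {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

// squeeze_grad: reuse out_grad's buffer and take the shape back from x.
Tensor SqueezeGrad(const Tensor& x, const Tensor& out_grad) {
  return Tensor{x.shape, out_grad.data};
}

int main() {
  Tensor x{{2, 1, 3}, std::vector<float>(6, 0.f)};
  Tensor out_grad{{2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}};
  Tensor x_grad = SqueezeGrad(x, out_grad);
  assert((x_grad.shape == std::vector<int64_t>{2, 1, 3}));
  assert(x_grad.data.size() == 6);
  return 0;
}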
18 changes: 4 additions & 14 deletions paddle/phi/infermeta/spmd_rules/squeeze.cc
@@ -30,14 +30,6 @@ namespace distributed {

using phi::distributed::auto_parallel::str_join;

TensorDistAttr CreateSqueezeXshape(const TensorDistAttr& x) {
TensorDistAttr out(x);
auto dims_mapping = x.dims_mapping();
dims_mapping.insert(dims_mapping.begin(), -1);
out.set_dims_mapping(dims_mapping);
return out;
}

void MakeSqueezeDimTransWithoutAxis(
const std::vector<int64_t>& x_shape,
std::vector<int64_t>* out_shape,
@@ -168,8 +160,7 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
<< "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "]\n\n";

return {{x_dist_attr_dst},
{out_dist_attr, CreateSqueezeXshape(x_dist_attr_dst)}};
return {{x_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
Expand Down Expand Up @@ -246,13 +237,12 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
return {{x_dist_attr}, {out_dist_attr_dst}};
}

SpmdInfo SqueezeGradInferSpmd(const DistMetaTensor& xshape,
SpmdInfo SqueezeGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& out_grad,
const IntArray& axis) {
auto shape = phi::vectorize(xshape.dims());
shape = std::vector<int64_t>(shape.begin() + 1, shape.end());
auto shape = phi::vectorize(x.dims());
const auto& spmd = ReshapeInferSpmd(out_grad, shape);
return {{xshape.dist_attr(), spmd.first[0]}, {spmd.second[0]}};
return {{x.dist_attr(), spmd.first[0]}, {spmd.second[0]}};
}

} // namespace distributed
30 changes: 15 additions & 15 deletions paddle/phi/kernels/onednn/squeeze_kernel.cc
@@ -54,10 +54,10 @@ void ExecuteSqueeze(const Context& dev_ctx,
}

template <typename T, typename Context>
void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
auto x_dims = x.dims();
auto x_dims_tz = x_dims.size();
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
@@ -87,13 +87,13 @@ void SqueezeInferKernel(const Context& dev_ctx,
}

template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
void SqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
if (xshape == nullptr) {
SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
} else {
auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
auto out_dims = out->dims();
@@ -102,12 +102,12 @@ void SqueezeKernel(const Context& dev_ctx,
}
} // namespace phi

PD_REGISTER_KERNEL(squeeze_infer,
PD_REGISTER_KERNEL(
squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(squeeze_with_xshape,
OneDNN,
ONEDNN,
phi::SqueezeInferKernel,
phi::SqueezeWithXShapeKernel,
float,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(
squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}