Skip to content

Commit

Permalink
Lowing shape_optimization_pass to paddle/pir/ directory (PaddlePaddle…
Browse files Browse the repository at this point in the history
…#63572)

* update

* to trigger CI

* fix paddle::dialect namespace

* fix

* change dir

* fix
  • Loading branch information
chen2016013 authored Apr 18, 2024
1 parent 7c6d7d7 commit b631ac7
Show file tree
Hide file tree
Showing 17 changed files with 98 additions and 76 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/op_base.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

namespace cinn {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"

#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
Expand All @@ -45,7 +46,6 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"

COMMON_DECLARE_bool(print_ir);
COMMON_DECLARE_bool(disable_dyshape_in_train);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"

COMMON_DECLARE_bool(check_infer_symbolic);
PD_DECLARE_bool(prim_all);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
#include "paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h"
#include "paddle/fluid/pir/transforms/passes.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "paddle/pir/include/pass/pass_registry.h"

Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/pir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,7 @@ set(op_dialect_srcs
${pir_op_source_file}
${pir_bwd_op_source_file}
${pir_update_op_source_file}
${api_source_file}
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/transforms/shape_optimization_pass.cc)
${api_source_file})

if(WITH_ONEDNN)
set(op_dialect_srcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,11 @@
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h"
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"

// Type inference is currently modelled executionally for operation creation
// using the `InferMetaInterface`. While `InferSymbolicShapeInterface` is used
// to implement the shape and element type inference. The return type can often
// be deduced from the deduced return shape and elemental type (queryable from
// `InferSymbolicShapeInterface`) and so type inference for tensor types can be
// implemented with `InferSymbolicShapeInterface`.
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"

namespace paddle::dialect {

class InferSymbolicShapeInterface
: public pir::OpInterfaceBase<InferSymbolicShapeInterface> {
public:
/// Defined these methods with the interface.
struct Concept {
explicit Concept(bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis))
: infer_symbolic_shapes(infer_symbolic_shapes) {}
bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
};

template <class ConcreteOp>
struct Model : public Concept {
static inline bool InferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return op->dyn_cast<ConcreteOp>().InferSymbolicShape(shape_analysis);
}

Model() : Concept(InferSymbolicShape) {}
};

/// Constructor
InferSymbolicShapeInterface(pir::Operation *op, Concept *impl)
: pir::OpInterfaceBase<InferSymbolicShapeInterface>(op), impl_(impl) {}

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);

private:
Concept *impl_;
};
using InferSymbolicShapeInterface = pir::InferSymbolicShapeInterface;

} // namespace paddle::dialect

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp, paddle::dialect::HasElementsOp,
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"

#include "paddle/phi/core/enforce.h"
#include "paddle/pir/include/core/builder.h"
#include "paddle/pir/include/core/builtin_attribute.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/transforms/passes.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/fluid/pybind/control_flow_api.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"
Expand All @@ -63,6 +62,7 @@
#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "paddle/pir/include/pass/pass_registry.h"
Expand Down
11 changes: 5 additions & 6 deletions paddle/pir/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
add_definitions(-DIR_LIBRARY)
set_property(GLOBAL PROPERTY IR_TARGETS "")

file(
GLOB_RECURSE
PIR_CPP_SOURCES
"*.cc"
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.cc
)
file(GLOB_RECURSE PIR_CPP_SOURCES "*.cc")

if(WIN32)
if(WITH_SHARED_IR)
Expand Down Expand Up @@ -56,3 +51,7 @@ else()
set(ir_targets pir)
set_property(GLOBAL PROPERTY IR_TARGETS "${ir_targets}")
endif()

if((CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
set_target_properties(pir PROPERTIES COMPILE_FLAGS "-Wno-maybe-uninitialized")
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"

// Type inference is currently modelled executionally for operation creation
// using the `InferMetaInterface`. While `InferSymbolicShapeInterface` is used
// to implement the shape and element type inference. The return type can often
// be deduced from the deduced return shape and elemental type (queryable from
// `InferSymbolicShapeInterface`) and so type inference for tensor types can be
// implemented with `InferSymbolicShapeInterface`.

namespace pir {

class InferSymbolicShapeInterface
: public pir::OpInterfaceBase<InferSymbolicShapeInterface> {
public:
/// Defined these methods with the interface.
struct Concept {
explicit Concept(bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis))
: infer_symbolic_shapes(infer_symbolic_shapes) {}
bool (*infer_symbolic_shapes)(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
};

template <class ConcreteOp>
struct Model : public Concept {
static inline bool InferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
return op->dyn_cast<ConcreteOp>().InferSymbolicShape(shape_analysis);
}

Model() : Concept(InferSymbolicShape) {}
};

/// Constructor
InferSymbolicShapeInterface(pir::Operation *op, Concept *impl)
: pir::OpInterfaceBase<InferSymbolicShapeInterface>(op), impl_(impl) {}

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);

private:
Concept *impl_;
};

} // namespace pir

IR_DECLARE_EXPLICIT_TYPE_ID(pir::InferSymbolicShapeInterface)
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"

// This file implements the infer_symbolic_shape interface for both paddle and
// cinn operators.

// Add `interfaces : paddle::dialect::InferSymbolicShapeInterface` in relative
// Add `interfaces : pir::InferSymbolicShapeInterface` in relative
// yaml file to conresponding op.

// Since necessary checks have been done in the Op's `InferMeta` and `VeriySig`,
// no more repetitive work here.

namespace paddle::dialect {
namespace pir {

bool InferSymbolicShapeInterface::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis *shape_analysis) {
return impl_->infer_symbolic_shapes(operation(), shape_analysis);
}
} // namespace paddle::dialect

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)
} // namespace pir

IR_DEFINE_EXPLICIT_TYPE_ID(pir::InferSymbolicShapeInterface)
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"

#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/dialect.h"
#include "paddle/pir/include/core/ir_printer.h"
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/pass/pass_manager.h"
Expand Down Expand Up @@ -173,9 +175,9 @@ void CheckInferSymWithInferMeta(
// InferMeta funcs of some Ops are not corrrect now, we don't check them.
if (!NeedCheckInferSymbolicWithInferMeta(op->name(), i)) continue;

if (res.type().isa<paddle::dialect::DenseTensorType>()) {
const std::vector<int64_t>& infer_meta_shape = common::vectorize(
res.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
if (res.type().isa<pir::DenseTensorType>()) {
const std::vector<int64_t>& infer_meta_shape =
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
const std::vector<symbol::DimExpr>& infer_sym_shape =
shape_analysis->GetShapeOrDataForValue(res).shape();

Expand Down Expand Up @@ -272,12 +274,11 @@ class ShapeOptimizationPass : public pir::Pass {

static inline bool IsStaticShape(const Value& value) {
const auto& value_type = value.type();
if (!value || !value_type ||
!value_type.isa<paddle::dialect::DenseTensorType>()) {
if (!value || !value_type || !value_type.isa<pir::DenseTensorType>()) {
return false;
}
return !::common::contain_unknown_dim(
value_type.dyn_cast<paddle::dialect::DenseTensorType>().dims());
value_type.dyn_cast<pir::DenseTensorType>().dims());
}

symbol::ShapeOrDataDimExprs CreateShapeOrDataByDDim(const pir::DDim& dims) {
Expand All @@ -292,7 +293,7 @@ void InferSymExprForBlock(const Block& block,
ShapeConstraintIRAnalysis* shape_analysis) {
for (auto& op : block) {
auto infer_symbolic_shape_interface =
op.dyn_cast<paddle::dialect::InferSymbolicShapeInterface>();
op.dyn_cast<pir::InferSymbolicShapeInterface>();
if (infer_symbolic_shape_interface) {
PrintOpInfo(&op);
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -326,10 +327,7 @@ void InferSymExprForBlock(const Block& block,
shape_analysis->SetShapeOrDataForValue(
op.result(i),
CreateShapeOrDataByDDim(
op.result(i)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims()));
op.result(i).type().dyn_cast<pir::DenseTensorType>().dims()));
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/cinn/adt/map_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/ir/shape_op.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/shape_dialect/infer_symbolic_shape_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include <gtest/gtest.h>
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type_interfaces.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/shape_dialect/shape_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include <gtest/gtest.h>
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type_interfaces.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/shape_dialect/shape_optimization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include <gtest/gtest.h>
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/core/builtin_type_interfaces.h"
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

Expand Down

0 comments on commit b631ac7

Please sign in to comment.