Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Add Div to replace Recipical in DimExpr #70376

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
22 changes: 18 additions & 4 deletions paddle/cinn/adt/dim_expr_match_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ struct UnaryDimExprMatchTrait {
}
};

template <template <typename> class Op, typename T0>
struct BinaryDimExprMatchTrait {
using base_type = Op<DimExpr>;

static constexpr int is_template = true;

template <template <typename, typename> class Matcher>
static bool MatchChildren(const base_type& value) {
const auto& lhs = value->lhs;
const auto& rhs = value->rhs;
return Matcher<T0, DimExpr>::Call(lhs) && Matcher<T0, DimExpr>::Call(rhs);
}
};

template <template <typename> class Op, typename T0>
struct ListDimExprMatchTrait {
using base_type = Op<DimExpr>;
Expand Down Expand Up @@ -65,10 +79,6 @@ template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Negative<T0>> final
: public UnaryDimExprMatchTrait<::symbol::Negative, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Reciprocal<T0>> final
: public UnaryDimExprMatchTrait<::symbol::Reciprocal, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Add<T0>> final
: public ListDimExprMatchTrait<::symbol::Add, T0> {};
Expand All @@ -77,6 +87,10 @@ template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Mul<T0>> final
: public ListDimExprMatchTrait<::symbol::Mul, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Div<T0>> final
: public BinaryDimExprMatchTrait<::symbol::Div, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Broadcast<T0>> final
: public ListDimExprMatchTrait<::symbol::Broadcast, T0> {};
Expand Down
19 changes: 14 additions & 5 deletions paddle/cinn/common/broadcast_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ bool SearchBroadcastImplForUnary(const T& unary, const DoEachT& DoEach) {
return SearchBroadcast(operand, DoEach);
}

template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Negative<symbol::DimExpr>& unary,
const DoEachT& DoEach) {
return SearchBroadcastImplForUnary(unary, DoEach);
template <typename T, typename DoEachT>
bool SearchBroadcastImplForBinary(const T& binary, const DoEachT& DoEach) {
const auto& lhs = binary->lhs;
const auto& rhs = binary->rhs;
if (SearchBroadcast(lhs, DoEach)) return true;
if (SearchBroadcast(rhs, DoEach)) return true;
return false;
}

template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Reciprocal<symbol::DimExpr>& unary,
bool SearchBroadcastImpl(const symbol::Negative<symbol::DimExpr>& unary,
const DoEachT& DoEach) {
return SearchBroadcastImplForUnary(unary, DoEach);
}
Expand All @@ -76,6 +79,12 @@ bool SearchBroadcastImpl(const symbol::Mul<symbol::DimExpr>& variadic,
return SearchBroadcastImplForVariadic(variadic, DoEach);
}

template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Div<symbol::DimExpr>& binary,
const DoEachT& DoEach) {
return SearchBroadcastImplForBinary(binary, DoEach);
}

template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Max<symbol::DimExpr>& variadic,
const DoEachT& DoEach) {
Expand Down
23 changes: 7 additions & 16 deletions paddle/cinn/common/dim_expr_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ struct DimExprToIrExprVisitor {
return ir::Sub::Make(ir::Expr(std::int64_t(0)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Reciprocal<DimExpr>& dim_expr) {
const auto& [operand] = *dim_expr;
return ir::Div::Make(ir::Expr(std::int64_t(1)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Add<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
if (operands->empty()) {
Expand All @@ -69,21 +64,17 @@ struct DimExprToIrExprVisitor {
}
ir::Expr product = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
// Convert Reciprocal<DimExpr>(S0) to (1 / S0) will result in precision
// error. For example, (S0 * S1 / S2) != (S0 * S1 * (1 / S2)). So we
// should use Div instead of Reciprocal here.
if (operands->at(i).isa<Reciprocal<DimExpr>>()) {
product = ir::Div::Make(
product,
ConvertToIrExpr(
operands->at(i).dyn_cast<Reciprocal<DimExpr>>()->data));
} else {
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
}
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
}
return product;
}

ir::Expr operator()(const Div<DimExpr>& dim_expr) {
const auto& lhs = ConvertToIrExpr(dim_expr->lhs);
const auto& rhs = ConvertToIrExpr(dim_expr->rhs);
return ir::Div::Make(lhs, rhs);
}

ir::Expr operator()(const Max<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
PADDLE_ENFORCE_EQ(
Expand Down
91 changes: 72 additions & 19 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ std::string GetSerializedTag<Negative<DimExpr>>() {
return "Negative";
}

template <>
std::string GetSerializedTag<Reciprocal<DimExpr>>() {
return "Reciprocal";
}

template <>
std::string GetSerializedTag<Add<DimExpr>>() {
return "Add";
Expand All @@ -45,6 +40,11 @@ std::string GetSerializedTag<Mul<DimExpr>>() {
return "Mul";
}

template <>
std::string GetSerializedTag<Div<DimExpr>>() {
return "Div";
}

template <>
std::string GetSerializedTag<Max<DimExpr>>() {
return "Max";
Expand Down Expand Up @@ -80,13 +80,20 @@ ::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::IrContext* ctx,
return pir::ArrayAttribute::get(ctx, attr_vecs);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::IrContext* ctx, const Negative<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
template <typename T>
::pir::Attribute ConvertBinaryDimExprToAttributeImpl(::pir::IrContext* ctx,
const T& dim_expr) {
std::vector<::pir::Attribute> attr_vecs{};
attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag<T>()));
const auto& lhs = dim_expr->lhs;
const auto& rhs = dim_expr->rhs;
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, lhs));
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, rhs));
return pir::ArrayAttribute::get(ctx, attr_vecs);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::IrContext* ctx, const Reciprocal<DimExpr>& dim_expr) {
::pir::IrContext* ctx, const Negative<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
}

Expand All @@ -112,6 +119,11 @@ ::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Div<DimExpr>& dim_expr) {
return ConvertBinaryDimExprToAttributeImpl(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Max<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
Expand Down Expand Up @@ -150,6 +162,23 @@ std::optional<DimExpr> ConvertArrayAttributeToUnaryDimExpr(
return T{operand.value()};
}

template <typename T>
std::optional<DimExpr> ConvertArrayAttributeToBinaryDimExpr(
const ::pir::ArrayAttribute& attribute) {
if (attribute.size() != 3) {
return std::nullopt;
}
std::optional<DimExpr> lhs = ConvertAttributeToDimExpr(attribute.at(1));
if (!lhs.has_value()) {
return std::nullopt;
}
std::optional<DimExpr> rhs = ConvertAttributeToDimExpr(attribute.at(2));
if (!rhs.has_value()) {
return std::nullopt;
}
return T{lhs.value(), rhs.value()};
}

template <typename T>
std::optional<DimExpr> ConvertArrayAttributeToVariadicDimExpr(
const ::pir::ArrayAttribute& attribute) {
Expand All @@ -175,12 +204,12 @@ std::optional<ArrayAttributeConverterT> GetArrayAttributeConverter(
static std::unordered_map<std::string, ArrayAttributeConverterT> map{
{GetSerializedTag<Negative<DimExpr>>(),
&ConvertArrayAttributeToUnaryDimExpr<Negative<DimExpr>>},
{GetSerializedTag<Reciprocal<DimExpr>>(),
&ConvertArrayAttributeToUnaryDimExpr<Reciprocal<DimExpr>>},
{GetSerializedTag<Add<DimExpr>>(),
&ConvertArrayAttributeToVariadicDimExpr<Add<DimExpr>>},
{GetSerializedTag<Mul<DimExpr>>(),
&ConvertArrayAttributeToVariadicDimExpr<Mul<DimExpr>>},
{GetSerializedTag<Div<DimExpr>>(),
&ConvertArrayAttributeToBinaryDimExpr<Div<DimExpr>>},
{GetSerializedTag<Max<DimExpr>>(),
&ConvertArrayAttributeToVariadicDimExpr<Max<DimExpr>>},
{GetSerializedTag<Min<DimExpr>>(),
Expand Down Expand Up @@ -276,9 +305,6 @@ class SubstituteDimExprHelper final {
std::optional<DimExpr> SubstituteImpl(const Negative<DimExpr>& dim_expr) {
return SubstituteUnary(dim_expr);
}
std::optional<DimExpr> SubstituteImpl(const Reciprocal<DimExpr>& dim_expr) {
return SubstituteUnary(dim_expr);
}

template <typename T>
std::optional<DimExpr> SubstituteUnary(const T& dim_expr) {
Expand All @@ -298,6 +324,25 @@ class SubstituteDimExprHelper final {
return SubstituteVariadic(dim_expr);
}

std::optional<DimExpr> SubstituteImpl(const Div<DimExpr>& dim_expr) {
return SubstituteBinary(dim_expr);
}

template <typename T>
std::optional<DimExpr> SubstituteBinary(const T& dim_expr) {
const auto& lhs = dim_expr->lhs;
const auto& rhs = dim_expr->rhs;
const auto& substituted_lhs = Substitute(lhs);
if (!substituted_lhs.has_value()) {
return std::nullopt;
}
const auto& substituted_rhs = Substitute(rhs);
if (!substituted_rhs.has_value()) {
return std::nullopt;
}
return T{substituted_lhs.value(), substituted_rhs.value()};
}

std::optional<DimExpr> SubstituteImpl(const Max<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}
Expand Down Expand Up @@ -412,12 +457,12 @@ bool IsAtomicImpl(const std::string&) { return true; }

bool IsAtomicImpl(const symbol::Negative<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Reciprocal<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Add<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Mul<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Div<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Max<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Min<symbol::DimExpr>&) { return false; }
Expand Down Expand Up @@ -484,9 +529,12 @@ void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr,
CollectSymbolNamesImplForUnary(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForUnary(dim_expr, ret);
template <typename T>
void CollectSymbolNamesImplForBinary(const T& dim_expr,
std::set<std::string>* ret) {
const auto& [lhs, rhs] = *dim_expr;
CollectSymbolNames(lhs, ret);
CollectSymbolNames(rhs, ret);
}

template <typename T>
Expand All @@ -508,6 +556,11 @@ void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr,
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Div<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForBinary(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Max<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ struct ShapeSignatureGenerator {
[&](const symbol::Negative<symbol::DimExpr>& negative) {
GetSymbolsForOneDimExpr(negative->data, symbols);
},
[&](const symbol::Reciprocal<symbol::DimExpr>& reciprocal) {
GetSymbolsForOneDimExpr(reciprocal->data, symbols);
},
[&](const symbol::Add<symbol::DimExpr>& add) {
for (const auto& dim_expr : *add.operands) {
GetSymbolsForOneDimExpr(dim_expr, symbols);
Expand All @@ -150,6 +147,10 @@ struct ShapeSignatureGenerator {
GetSymbolsForOneDimExpr(dim_expr, symbols);
}
},
[&](const symbol::Div<symbol::DimExpr>& div) {
GetSymbolsForOneDimExpr(div->lhs, symbols);
GetSymbolsForOneDimExpr(div->rhs, symbols);
},
[&](const symbol::Max<symbol::DimExpr>& max) {
for (const auto& dim_expr : *max.operands) {
GetSymbolsForOneDimExpr(dim_expr, symbols);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,16 @@ struct StaticDimToDynamicConverter {
return AppliedOnceUnaryImpl(dim_expr, symbol);
}

bool AppliedOnceImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
template <typename T>
bool AppliedOnceBinaryImpl(const T& dim_expr, const std::string& symbol) {
const auto& lhs = dim_expr->lhs;
const auto& rhs = dim_expr->rhs;
return AppliedOnce(lhs, symbol) || AppliedOnce(rhs, symbol);
}

bool AppliedOnceImpl(const symbol::Div<symbol::DimExpr>& dim_expr,
const std::string& symbol) {
return AppliedOnceUnaryImpl(dim_expr, symbol);
return AppliedOnceBinaryImpl(dim_expr, symbol);
}

template <typename T>
Expand Down Expand Up @@ -272,6 +279,24 @@ struct StaticDimToDynamicConverter {
return T{converted_operand.value()};
}

template <typename T>
std::optional<symbol::DimExpr> ConvertBinaryDimExprImpl(
const T& dim_expr, int64_t c, const std::string& symbol) {
const auto& lhs = dim_expr->lhs;
const auto& rhs = dim_expr->rhs;
const auto& converted_lhs = ConvertDimExpr(lhs, c, symbol);
const auto& converted_rhs = ConvertDimExpr(rhs, c, symbol);
if (!converted_lhs.has_value() && !converted_rhs.has_value())
return std::nullopt;
if (converted_lhs.has_value() && converted_rhs.has_value()) {
return T{converted_lhs.value(), converted_rhs.value()};
}
if (converted_lhs.has_value()) {
return T{converted_lhs.value(), rhs};
}
return T{lhs, converted_rhs.value()};
}

template <typename T>
std::optional<symbol::DimExpr> ConvertListDimExprImpl(
const T& dim_expr, int64_t c, const std::string& symbol) {
Expand All @@ -297,24 +322,24 @@ struct StaticDimToDynamicConverter {
}

std::optional<symbol::DimExpr> ConvertDimExprImpl(
const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
const symbol::Add<symbol::DimExpr>& dim_expr,
int64_t c,
const std::string& symbol) {
return ConvertUnaryDimExprImpl(dim_expr, c, symbol);
return ConvertListDimExprImpl(dim_expr, c, symbol);
}

std::optional<symbol::DimExpr> ConvertDimExprImpl(
const symbol::Add<symbol::DimExpr>& dim_expr,
const symbol::Mul<symbol::DimExpr>& dim_expr,
int64_t c,
const std::string& symbol) {
return ConvertListDimExprImpl(dim_expr, c, symbol);
}

std::optional<symbol::DimExpr> ConvertDimExprImpl(
const symbol::Mul<symbol::DimExpr>& dim_expr,
const symbol::Div<symbol::DimExpr>& dim_expr,
int64_t c,
const std::string& symbol) {
return ConvertListDimExprImpl(dim_expr, c, symbol);
return ConvertBinaryDimExprImpl(dim_expr, c, symbol);
}

std::optional<symbol::DimExpr> ConvertDimExprImpl(
Expand Down
Loading
Loading