Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Add Div to replace Recipical in DimExpr #70376

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
Prev Previous commit
Next Next commit
add BinaryExprMatchTrait
  • Loading branch information
gongshaotian committed Dec 24, 2024
commit 6d66e14f7bc49cd75cb2f57778e25c918a90d753
16 changes: 15 additions & 1 deletion paddle/cinn/adt/dim_expr_match_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ struct UnaryDimExprMatchTrait {
}
};

template <template <typename> class Op, typename t0>
struct BinaryDimExprMatchTrait {
using base_type = Op<DimExpr>;

static constexpr int is_template = true;

template <template <typename, typename> class Matcher>
static bool MatchChildren(const base_type& value) {
const auto& lhs = std::get<0>(value.tuple());
const auto& rhs = std::get<1>(value.tuple());
return Matcher<T0, DimExpr>::Call(lhs) && Matcher<T0, DimExpr>::Call(rhs);
}
};

template <template <typename> class Op, typename T0>
struct ListDimExprMatchTrait {
using base_type = Op<DimExpr>;
Expand Down Expand Up @@ -75,7 +89,7 @@ struct MatchTrait<DimExpr, ::symbol::Mul<T0>> final

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Div<T0>> final
: public ListDimExprMatchTrait<::symbol::Div, T0> {};
: public BinaryDimExprMatchTrait<::symbol::Div, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Broadcast<T0>> final
Expand Down