Skip to content

Commit

Permalink
Reland "[mlir][arith] Canonicalization patterns for arith.select (l…
Browse files Browse the repository at this point in the history
…lvm#67809)" (llvm#68941)

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b and extends the
pattern to handle vector types.

To reuse `getBoolAttribute` method, it moves the static method above the
include of generated file.
  • Loading branch information
hanhanW authored Oct 13, 2023
1 parent b1115f8 commit 6dbc6df
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 10 deletions.
49 changes: 49 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,55 @@ def CmpIExtUI :
CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
"$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

def GetScalarOrVectorTrueAttribute :
NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;

// select(not(pred), a, b) => select(pred, b, a)
def SelectNotCond :
Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
(SelectOp $pred, $b, $a),
[(IsScalarOrSplatNegativeOne $ones)]>;

// select(pred, select(pred, a, b), c) => select(pred, a, c)
def RedundantSelectTrue :
Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
(SelectOp $pred, $a, $c)>;

// select(pred, a, select(pred, b, c)) => select(pred, a, c)
def RedundantSelectFalse :
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
(SelectOp $pred, $a, $c)>;

// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
def SelectAndCond :
Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
(SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;

// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
def SelectAndNotCond :
Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
(SelectOp (Arith_AndIOp $predA,
(Arith_XOrIOp $predB,
(Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
$x, $y)>;

// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
def SelectOrCond :
Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
(SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;

// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
def SelectOrNotCond :
Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
(SelectOp (Arith_OrIOp $predA,
(Arith_XOrIOp $predB,
(Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
$x, $y)>;

//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 12 additions & 10 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
return failure();
}

static Attribute getBoolAttribute(Type type, bool value) {
auto boolAttr = BoolAttr::get(type.getContext(), value);
ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
if (!shapedType)
return boolAttr;
return DenseElementsAttr::get(shapedType, boolAttr);
}

//===----------------------------------------------------------------------===//
// TableGen'd canonicalization patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1696,14 +1704,6 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
llvm_unreachable("unknown cmpi predicate kind");
}

static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
auto boolAttr = BoolAttr::get(ctx, value);
ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
if (!shapedType)
return boolAttr;
return DenseElementsAttr::get(shapedType, boolAttr);
}

static std::optional<int64_t> getIntegerWidth(Type t) {
if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
return intType.getWidth();
Expand All @@ -1718,7 +1718,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
// cmpi(pred, x, x)
if (getLhs() == getRhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
return getBoolAttribute(getType(), getContext(), val);
return getBoolAttribute(getType(), val);
}

if (matchPattern(adaptor.getRhs(), m_Zero())) {
Expand Down Expand Up @@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {

void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SelectI1Simplify, SelectToExtUI>(context);
results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
SelectNotCond, SelectToExtUI>(context);
}

OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
Expand Down
100 changes: 100 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,106 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
return %res : i1
}

// CHECK-LABEL: @redundantSelectTrue
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
// CHECK-NEXT: return %[[res]]
func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
%0 = arith.select %arg0, %arg1, %arg2 : i32
%res = arith.select %arg0, %0, %arg3 : i32
return %res : i32
}

// CHECK-LABEL: @redundantSelectFalse
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
// CHECK-NEXT: return %[[res]]
func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
%0 = arith.select %arg0, %arg1, %arg2 : i32
%res = arith.select %arg0, %arg3, %0 : i32
return %res : i32
}

// CHECK-LABEL: @selNotCond
// CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
// CHECK-NEXT: %[[res2:.+]] = arith.select %arg0, %arg4, %arg3
// CHECK-NEXT: return %[[res1]], %[[res2]]
func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
%one = arith.constant 1 : i1
%cond1 = arith.xori %arg0, %one : i1
%cond2 = arith.xori %one, %arg0 : i1

%res1 = arith.select %cond1, %arg1, %arg2 : i32
%res2 = arith.select %cond2, %arg3, %arg4 : i32
return %res1, %res2 : i32, i32
}

// CHECK-LABEL: @selAndCond
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
// CHECK-NEXT: return %[[res]]
func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
%sel = arith.select %arg0, %arg2, %arg3 : i32
%res = arith.select %arg1, %sel, %arg3 : i32
return %res : i32
}

// CHECK-LABEL: @selAndNotCond
// CHECK-NEXT: %[[one:.+]] = arith.constant true
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
// CHECK-NEXT: return %[[res]]
func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
%sel = arith.select %arg0, %arg2, %arg3 : i32
%res = arith.select %arg1, %sel, %arg2 : i32
return %res : i32
}

// CHECK-LABEL: @selAndNotCondVec
// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
// CHECK-NEXT: return %[[res]]
func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
%sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
%res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
return %res : vector<4xi32>
}

// CHECK-LABEL: @selOrCond
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
// CHECK-NEXT: return %[[res]]
func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
%sel = arith.select %arg0, %arg2, %arg3 : i32
%res = arith.select %arg1, %arg2, %sel : i32
return %res : i32
}

// CHECK-LABEL: @selOrNotCond
// CHECK-NEXT: %[[one:.+]] = arith.constant true
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
// CHECK-NEXT: return %[[res]]
func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
%sel = arith.select %arg0, %arg2, %arg3 : i32
%res = arith.select %arg1, %arg3, %sel : i32
return %res : i32
}

// CHECK-LABEL: @selOrNotCondVec
// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
// CHECK-NEXT: return %[[res]]
func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
%sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
%res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
return %res : vector<4xi32>
}

// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
Expand Down

0 comments on commit 6dbc6df

Please sign in to comment.