Skip to content

Commit

Permalink
[InstCombine] fold signed absolute diff patterns
Browse files Browse the repository at this point in the history
This overlaps partially with the codegen patch D144789. This needs no-wrap
for correctness, and I'm not sure if there's an unsigned equivalent:
https://alive2.llvm.org/ce/z/ErmQ-9
https://alive2.llvm.org/ce/z/mr-c_A

This is obviously an improvement in IR, and it looks like a codegen win
for all targets and data types that I sampled.

The 'nabs' case is left as a potential follow-up (and seems less likely
to occur in real code).

Differential Revision: https://reviews.llvm.org/D145073
  • Loading branch information
rotateright committed Mar 6, 2023
1 parent 870e6b6 commit 74a5849
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 35 deletions.
44 changes: 44 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,47 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}

/// Recognize a select that chooses between the two opposite-order subtracts of
/// the compare operands and rewrite it as an absolute-value intrinsic call.
///
/// \param Cmp     integer compare that provides the select condition.
/// \param TVal    value chosen by the select when the condition is true.
/// \param FVal    value chosen by the select when the condition is false.
/// \param Builder IR builder used to emit the replacement intrinsic call.
/// \returns the newly created abs intrinsic call, or nullptr when the pattern
///          is not matched.
static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
                          InstCombiner::BuilderTy &Builder) {
  // Both select arms have to be instructions; their wrap flags are inspected
  // (and, for the surviving one, modified) below.
  auto *TrueInst = dyn_cast<Instruction>(TVal);
  auto *FalseInst = dyn_cast<Instruction>(FVal);
  if (TrueInst == nullptr || FalseInst == nullptr)
    return nullptr;

  Value *LHS = Cmp->getOperand(0);
  Value *RHS = Cmp->getOperand(1);

  // Only the strict (gt/lt) flavor of the predicate is considered from here
  // on; ge/le are folded into it.
  ICmpInst::Predicate StrictPred = Cmp->getStrictPredicate();

  // Canonicalize so that "LHS - RHS" is the true arm of the select.
  if (match(FalseInst, m_Sub(m_Specific(LHS), m_Specific(RHS)))) {
    std::swap(FalseInst, TrueInst);
    StrictPred = ICmpInst::getSwappedPredicate(StrictPred);
  }

  // Target pattern, valid for any pair of no-wrap subtracts:
  //   (LHS > RHS) ? (LHS - RHS) : (RHS - LHS) --> abs(LHS - RHS)
  if (StrictPred != CmpInst::ICMP_SGT)
    return nullptr;

  if (!match(TrueInst, m_Sub(m_Specific(LHS), m_Specific(RHS))))
    return nullptr;
  if (!match(FalseInst, m_Sub(m_Specific(RHS), m_Specific(LHS))))
    return nullptr;

  // The wrap flags are queried only after both arms are known to be subtracts.
  auto HasAnyNoWrapFlag = [](const Instruction *I) {
    return I->hasNoSignedWrap() || I->hasNoUnsignedWrap();
  };
  if (!HasAnyNoWrapFlag(TrueInst) || !HasAnyNoWrapFlag(FalseInst))
    return nullptr;

  // The subtract that survives can no longer be "nuw".
  TrueInst->setHasNoUnsignedWrap(false);

  // When the select being replaced is the subtract's only user, the subtract
  // is known to be "nsw" in this context even if it carried just "nuw" before.
  // With any other user the flag must not be added, because it may not hold in
  // that user's context.
  if (!TrueInst->hasNoSignedWrap() && TrueInst->hasOneUse())
    TrueInst->setHasNoSignedWrap(true);

  return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TrueInst,
                                       Builder.getTrue());
}

/// Fold the following code sequence:
/// \code
/// int a = ctlz(x & -x);
Expand Down Expand Up @@ -1790,6 +1831,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);

if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);

return Changed ? &SI : nullptr;
}

Expand Down
83 changes: 48 additions & 35 deletions llvm/test/Transforms/InstCombine/abs-1.ll
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ define i8 @nabs_extra_use_icmp_sub(i8 %x) {
ret i8 %s
}

; TODO: negate-of-abs-diff

define i32 @nabs_diff_signed_slt(i32 %a, i32 %b) {
; CHECK-LABEL: @nabs_diff_signed_slt(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
Expand All @@ -694,6 +696,8 @@ define i32 @nabs_diff_signed_slt(i32 %a, i32 %b) {
ret i32 %cond
}

; TODO: negate-of-abs-diff

define <2 x i8> @nabs_diff_signed_sle(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @nabs_diff_signed_sle(
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp sgt <2 x i8> [[A:%.*]], [[B:%.*]]
Expand All @@ -711,11 +715,9 @@ define <2 x i8> @nabs_diff_signed_sle(<2 x i8> %a, <2 x i8> %b) {

define i8 @abs_diff_signed_sgt(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B]], [[A]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
Expand All @@ -728,12 +730,11 @@ define i8 @abs_diff_signed_sgt(i8 %a, i8 %b) {

define i8 @abs_diff_signed_sge(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sge(
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B]], [[A]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i8 [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]])
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP_NOT]], i8 [[SUB_BA]], i8 [[SUB_AB]]
; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sge i8 %a, %b
Expand All @@ -745,6 +746,8 @@ define i8 @abs_diff_signed_sge(i8 %a, i8 %b) {
ret i8 %cond
}

; negative test - need nsw

define i32 @abs_diff_signed_slt_no_nsw(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_no_nsw(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
Expand All @@ -760,12 +763,12 @@ define i32 @abs_diff_signed_slt_no_nsw(i32 %a, i32 %b) {
ret i32 %cond
}

; bonus nuw - it's fine to match the pattern, but nuw can't propagate

define i8 @abs_diff_signed_sgt_nsw_nuw(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nsw_nuw(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw nsw i8 [[B]], [[A]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw nsw i8 [[A]], [[B]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
Expand All @@ -775,12 +778,12 @@ define i8 @abs_diff_signed_sgt_nsw_nuw(i8 %a, i8 %b) {
ret i8 %cond
}

; this is absolute diff, but nuw can't propagate and nsw can be set.

define i8 @abs_diff_signed_sgt_nuw(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
Expand All @@ -790,13 +793,14 @@ define i8 @abs_diff_signed_sgt_nuw(i8 %a, i8 %b) {
ret i8 %cond
}

; same as above

define i8 @abs_diff_signed_sgt_nuw_extra_use1(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use1(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]])
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i8 [[A]], [[B]]
; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
Expand All @@ -807,13 +811,13 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use1(i8 %a, i8 %b) {
ret i8 %cond
}

; nuw can't propagate, and the extra use prevents applying nsw

define i8 @abs_diff_signed_sgt_nuw_extra_use2(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use2(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
Expand All @@ -824,14 +828,15 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use2(i8 %a, i8 %b) {
ret i8 %cond
}

; same as above

define i8 @abs_diff_signed_sgt_nuw_extra_use3(i8 %a, i8 %b) {
; CHECK-LABEL: @abs_diff_signed_sgt_nuw_extra_use3(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B]], [[A]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nuw i8 [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_BA]])
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nuw i8 [[A]], [[B]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub i8 [[A]], [[B]]
; CHECK-NEXT: call void @extra_use(i8 [[SUB_AB]])
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[SUB_AB]], i8 [[SUB_BA]]
; CHECK-NEXT: [[COND:%.*]] = call i8 @llvm.abs.i8(i8 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i8 [[COND]]
;
%cmp = icmp sgt i8 %a, %b
Expand All @@ -843,6 +848,8 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use3(i8 %a, i8 %b) {
ret i8 %cond
}

; negative test - wrong predicate

define i32 @abs_diff_signed_slt_swap_wrong_pred1(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_pred1(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], [[B:%.*]]
Expand All @@ -858,6 +865,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_pred1(i32 %a, i32 %b) {
ret i32 %cond
}

; negative test - wrong predicate

define i32 @abs_diff_signed_slt_swap_wrong_pred2(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_pred2(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
Expand All @@ -873,6 +882,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_pred2(i32 %a, i32 %b) {
ret i32 %cond
}

; negative test - need common operands

define i32 @abs_diff_signed_slt_swap_wrong_op(i32 %a, i32 %b, i32 %z) {
; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_op(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], [[B:%.*]]
Expand All @@ -890,10 +901,8 @@ define i32 @abs_diff_signed_slt_swap_wrong_op(i32 %a, i32 %b, i32 %z) {

define i32 @abs_diff_signed_slt_swap(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_swap(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw i32 [[B]], [[A]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i32 [[A]], [[B]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[SUB_BA]], i32 [[SUB_AB]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.abs.i32(i32 [[SUB_AB]], i1 true)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp slt i32 %a, %b
Expand All @@ -905,10 +914,8 @@ define i32 @abs_diff_signed_slt_swap(i32 %a, i32 %b) {

define <2 x i8> @abs_diff_signed_sle_swap(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @abs_diff_signed_sle_swap(
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp sgt <2 x i8> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SUB_BA:%.*]] = sub nsw <2 x i8> [[B]], [[A]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw <2 x i8> [[A]], [[B]]
; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[CMP_NOT]], <2 x i8> [[SUB_AB]], <2 x i8> [[SUB_BA]]
; CHECK-NEXT: [[SUB_AB:%.*]] = sub nsw <2 x i8> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[COND:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[SUB_AB]], i1 true)
; CHECK-NEXT: ret <2 x i8> [[COND]]
;
%cmp = icmp sle <2 x i8> %a, %b
Expand All @@ -918,6 +925,8 @@ define <2 x i8> @abs_diff_signed_sle_swap(<2 x i8> %a, <2 x i8> %b) {
ret <2 x i8> %cond
}

; TODO: negate-of-abs-diff

define i8 @nabs_diff_signed_sgt_swap(i8 %a, i8 %b) {
; CHECK-LABEL: @nabs_diff_signed_sgt_swap(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
Expand All @@ -935,6 +944,8 @@ define i8 @nabs_diff_signed_sgt_swap(i8 %a, i8 %b) {
ret i8 %cond
}

; TODO: negate-of-abs-diff, but too many uses?

define i8 @nabs_diff_signed_sge_swap(i8 %a, i8 %b) {
; CHECK-LABEL: @nabs_diff_signed_sge_swap(
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]]
Expand All @@ -954,6 +965,8 @@ define i8 @nabs_diff_signed_sge_swap(i8 %a, i8 %b) {
ret i8 %cond
}

; negative test - need nsw

define i32 @abs_diff_signed_slt_no_nsw_swap(i32 %a, i32 %b) {
; CHECK-LABEL: @abs_diff_signed_slt_no_nsw_swap(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
Expand Down

0 comments on commit 74a5849

Please sign in to comment.