Skip to content

Commit

Permalink
[LIR] Teach LIR to avoid extending the BE count prior to adding one to
Browse files Browse the repository at this point in the history
it when safe.

Very often the BE count is the trip count minus one, and the plus one
here should fold with that minus one. But because the BE count might in
theory be UINT_MAX or some such, adding one before we extend could in
some cases wrap to zero and break when we scale things.

This patch checks to see if it would be safe to add one because the
specific case that would cause this is guarded for prior to entering the
preheader. This should handle essentially all of the common loop idioms
coming out of C/C++ code once canonicalized by LLVM.

Before this patch, both forms of loop in the added test cases ended up
subtracting one from the size, extending it, scaling it up by 8 and then
adding 8 back onto it. This is really silly, and it turns out made it
all the way into generated code very often, so this is a surprisingly
important cleanup to do.

Many thanks to Sanjoy for showing me how to do this with SCEV.

Differential Revision: https://reviews.llvm.org/D35758

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@308968 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
chandlerc committed Jul 25, 2017
1 parent 0000a71 commit 2dcaf78
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 18 deletions.
55 changes: 37 additions & 18 deletions lib/Transforms/Scalar/LoopIdiomRecognize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,41 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
return SE->getMinusSCEV(Start, Index);
}

/// Compute the number of bytes as a SCEV from the backedge taken count.
///
/// This also maps the SCEV into the provided type and tries to handle the
/// computation in a way that will fold cleanly.
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
unsigned StoreSize, Loop *CurLoop,
const DataLayout *DL, ScalarEvolution *SE) {
const SCEV *NumBytesS;
// The # stored bytes is (BECount+1)*Size. Expand the trip count out to
// pointer size if it isn't already.
//
// If we're going to need to zero extend the BE count, check if we can add
// one to it prior to zero extending without overflow. Provided this is safe,
// it allows better simplification of the +1.
if (DL->getTypeSizeInBits(BECount->getType()) <
DL->getTypeSizeInBits(IntPtr) &&
SE->isLoopEntryGuardedByCond(
CurLoop, ICmpInst::ICMP_NE, BECount,
SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
NumBytesS = SE->getZeroExtendExpr(
SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
IntPtr);
} else {
NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
SE->getOne(IntPtr), SCEV::FlagNUW);
}

// And scale it based on the store size.
if (StoreSize != 1) {
NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
SCEV::FlagNUW);
}
return NumBytesS;
}

/// processLoopStridedStore - We see a strided store of some value. If we can
/// transform this into a memset or memset_pattern in the loop preheader, do so.
bool LoopIdiomRecognize::processLoopStridedStore(
Expand Down Expand Up @@ -837,16 +872,8 @@ bool LoopIdiomRecognize::processLoopStridedStore(

// Okay, everything looks good, insert the memset.

// The # stored bytes is (BECount+1)*Size. Expand the trip count out to
// pointer size if it isn't already.
BECount = SE->getTruncateOrZeroExtend(BECount, IntPtr);

const SCEV *NumBytesS =
SE->getAddExpr(BECount, SE->getOne(IntPtr), SCEV::FlagNUW);
if (StoreSize != 1) {
NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
SCEV::FlagNUW);
}
getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE);

// TODO: ideally we should still be able to generate memset if SCEV expander
// is taught to generate the dependencies at the latest point.
Expand Down Expand Up @@ -976,16 +1003,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,

// Okay, everything is safe, we can transform this!

// The # stored bytes is (BECount+1)*Size. Expand the trip count out to
// pointer size if it isn't already.
BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);

const SCEV *NumBytesS =
SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);

if (StoreSize != 1)
NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
SCEV::FlagNUW);
getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE);

Value *NumBytes =
Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator());
Expand Down
69 changes: 69 additions & 0 deletions test/Transforms/LoopIdiom/basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,75 @@ for.end6: ; preds = %for.inc4
; CHECK: ret void
}

; Handle loops where the trip count is a narrow integer that needs to be
; extended.
define void @form_memset_narrow_size(i64* %ptr, i32 %size) {
; CHECK-LABEL: @form_memset_narrow_size(
entry:
%cmp1 = icmp sgt i32 %size, 0
br i1 %cmp1, label %loop.ph, label %exit
; CHECK: entry:
; CHECK: %[[C1:.*]] = icmp sgt i32 %size, 0
; CHECK-NEXT: br i1 %[[C1]], label %loop.ph, label %exit

loop.ph:
br label %loop.body
; CHECK: loop.ph:
; CHECK-NEXT: %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
; CHECK-NEXT: %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %{{.*}}, i8 0, i64 %[[SCALED_SIZE]], i32 8, i1 false)

loop.body:
%storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
%idxprom = sext i32 %storemerge4 to i64
%arrayidx = getelementptr inbounds i64, i64* %ptr, i64 %idxprom
store i64 0, i64* %arrayidx, align 8
%inc = add nsw i32 %storemerge4, 1
%cmp2 = icmp slt i32 %inc, %size
br i1 %cmp2, label %loop.body, label %loop.exit

loop.exit:
br label %exit

exit:
ret void
}

define void @form_memcpy_narrow_size(i64* noalias %dst, i64* noalias %src, i32 %size) {
; CHECK-LABEL: @form_memcpy_narrow_size(
entry:
%cmp1 = icmp sgt i32 %size, 0
br i1 %cmp1, label %loop.ph, label %exit
; CHECK: entry:
; CHECK: %[[C1:.*]] = icmp sgt i32 %size, 0
; CHECK-NEXT: br i1 %[[C1]], label %loop.ph, label %exit

loop.ph:
br label %loop.body
; CHECK: loop.ph:
; CHECK-NEXT: %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
; CHECK-NEXT: %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 %[[SCALED_SIZE]], i32 8, i1 false)

loop.body:
%storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
%idxprom1 = sext i32 %storemerge4 to i64
%arrayidx1 = getelementptr inbounds i64, i64* %src, i64 %idxprom1
%v = load i64, i64* %arrayidx1, align 8
%idxprom2 = sext i32 %storemerge4 to i64
%arrayidx2 = getelementptr inbounds i64, i64* %dst, i64 %idxprom2
store i64 %v, i64* %arrayidx2, align 8
%inc = add nsw i32 %storemerge4, 1
%cmp2 = icmp slt i32 %inc, %size
br i1 %cmp2, label %loop.body, label %loop.exit

loop.exit:
br label %exit

exit:
ret void
}

; Validate that "memset_pattern" has the proper attributes.
; CHECK: declare void @memset_pattern16(i8* nocapture, i8* nocapture readonly, i64) [[ATTRS:#[0-9]+]]
; CHECK: [[ATTRS]] = { argmemonly }

0 comments on commit 2dcaf78

Please sign in to comment.