Skip to content

Commit

Permalink
[LoopUnroll] Keep the loop test only on the first iteration of max-or…
Browse files Browse the repository at this point in the history
…-zero loops

When we have a loop with a known upper bound on the number of iterations, and
furthermore know that the number of iterations will be either exactly
that upper bound or zero, then we can fully unroll up to that upper bound
keeping only the first loop test to check for the zero iteration case.

Most of the work here is in plumbing this 'max-or-zero' information from the
part of scalar evolution where it's detected through to loop unrolling. I've
also gone for the safe default of 'false' everywhere but howManyLessThans which
could probably be improved.

Differential Revision: https://reviews.llvm.org/D25682


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@284818 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
john-brawn-arm committed Oct 21, 2016
1 parent 4b5784c commit 9e0c61c
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 56 deletions.
32 changes: 23 additions & 9 deletions include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,9 @@ class ScalarEvolution {
/// pair of exact and max expressions that are eventually summarized in
/// ExitNotTakenInfo and BackedgeTakenInfo.
struct ExitLimit {
const SCEV *ExactNotTaken;
const SCEV *MaxNotTaken;
const SCEV *ExactNotTaken; ///< The exit is not taken exactly this many times
const SCEV *MaxNotTaken; ///< The exit is not taken at most this many times
bool MaxOrZero; ///< Not taken either exactly MaxNotTaken or zero times

/// A set of predicate guards for this ExitLimit. The result is only valid
/// if all of the predicates in \c Predicates evaluate to 'true' at
Expand All @@ -561,12 +562,13 @@ class ScalarEvolution {
Predicates.insert(P);
}

/*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {}
/*implicit*/ ExitLimit(const SCEV *E)
: ExactNotTaken(E), MaxNotTaken(E), MaxOrZero(false) {}

ExitLimit(
const SCEV *E, const SCEV *M,
const SCEV *E, const SCEV *M, bool MaxOrZero,
ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
: ExactNotTaken(E), MaxNotTaken(M) {
: ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
!isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
"Exact is not allowed to be less precise than Max");
Expand All @@ -575,11 +577,12 @@ class ScalarEvolution {
addPredicate(P);
}

ExitLimit(const SCEV *E, const SCEV *M,
ExitLimit(const SCEV *E, const SCEV *M, bool MaxOrZero,
const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
: ExitLimit(E, M, {&PredSet}) {}
: ExitLimit(E, M, MaxOrZero, {&PredSet}) {}

ExitLimit(const SCEV *E, const SCEV *M) : ExitLimit(E, M, None) {}
ExitLimit(const SCEV *E, const SCEV *M, bool MaxOrZero)
: ExitLimit(E, M, MaxOrZero, None) {}

/// Test whether this ExitLimit contains any computed information, or
/// whether it's all SCEVCouldNotCompute values.
Expand Down Expand Up @@ -628,6 +631,9 @@ class ScalarEvolution {
/// ExitNotTaken has an element for every exiting block in the loop.
PointerIntPair<const SCEV *, 1> MaxAndComplete;

/// True iff the backedge is taken either exactly Max or zero times.
bool MaxOrZero;

/// \name Helper projection functions on \c MaxAndComplete.
/// @{
bool isComplete() const { return MaxAndComplete.getInt(); }
Expand All @@ -644,7 +650,7 @@ class ScalarEvolution {

/// Initialize BackedgeTakenInfo from a list of exact exit counts.
BackedgeTakenInfo(SmallVectorImpl<EdgeExitInfo> &&ExitCounts, bool Complete,
const SCEV *MaxCount);
const SCEV *MaxCount, bool MaxOrZero);

/// Test whether this BackedgeTakenInfo contains any computed information,
/// or whether it's all SCEVCouldNotCompute values.
Expand Down Expand Up @@ -683,6 +689,10 @@ class ScalarEvolution {
/// Get the max backedge taken count for the loop.
const SCEV *getMax(ScalarEvolution *SE) const;

/// Return true if the number of times this backedge is taken is either the
/// value returned by getMax or zero.
bool isMaxOrZero(ScalarEvolution *SE) const;

/// Return true if any backedge taken count expressions refer to the given
/// subexpression.
bool hasOperand(const SCEV *S, ScalarEvolution *SE) const;
Expand Down Expand Up @@ -1354,6 +1364,10 @@ class ScalarEvolution {
/// that is known never to be less than the actual backedge taken count.
const SCEV *getMaxBackedgeTakenCount(const Loop *L);

/// Return true if the backedge taken count is either the value returned by
/// getMaxBackedgeTakenCount or zero.
bool isBackedgeTakenCountMaxOrZero(const Loop *L);

/// Return true if the specified loop has an analyzable loop-invariant
/// backedge-taken count.
bool hasLoopInvariantBackedgeTakenCount(const Loop *L);
Expand Down
5 changes: 3 additions & 2 deletions include/llvm/Transforms/Utils/UnrollLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ class ScalarEvolution;

bool UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
bool AllowRuntime, bool AllowExpensiveTripCount,
bool UseUpperBound, unsigned TripMultiple, LoopInfo *LI,
ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
bool PreserveCondBr, bool PreserveOnlyFirst,
unsigned TripMultiple, LoopInfo *LI, ScalarEvolution *SE,
DominatorTree *DT, AssumptionCache *AC,
OptimizationRemarkEmitter *ORE, bool PreserveLCSSA);

bool UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
Expand Down
57 changes: 40 additions & 17 deletions lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5424,6 +5424,10 @@ const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
return getBackedgeTakenInfo(L).getMax(this);
}

/// Query the cached backedge-taken information for \p L and report whether it
/// records the backedge as being taken either the maximum count or zero times.
bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
  const auto &TakenInfo = getBackedgeTakenInfo(L);
  return TakenInfo.isMaxOrZero(this);
}

/// Push PHI nodes in the header of the given loop onto the given Worklist.
static void
PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
Expand Down Expand Up @@ -5656,6 +5660,13 @@ ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
return getMax();
}

/// Report whether the backedge is taken either exactly the maximum count or
/// zero times. The stored MaxOrZero flag is only honoured when every entry of
/// ExitNotTaken reports an always-true predicate; a single entry that does not
/// makes the answer false. The \p SE argument is not consulted here.
bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const {
  // Nothing to check if the flag was never set.
  if (!MaxOrZero)
    return false;

  // Any exit guarded by a predicate that is not always true defeats the claim.
  for (const ExitNotTakenInfo &Exit : ExitNotTaken)
    if (!Exit.hasAlwaysTruePredicate())
      return false;

  return true;
}

bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
ScalarEvolution *SE) const {
if (getMax() && getMax() != SE->getCouldNotCompute() &&
Expand All @@ -5675,8 +5686,8 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
&&ExitCounts,
bool Complete, const SCEV *MaxCount)
: MaxAndComplete(MaxCount, Complete) {
bool Complete, const SCEV *MaxCount, bool MaxOrZero)
: MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
ExitNotTaken.reserve(ExitCounts.size());
std::transform(
Expand Down Expand Up @@ -5714,6 +5725,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
const SCEV *MustExitMaxBECount = nullptr;
const SCEV *MayExitMaxBECount = nullptr;
bool MustExitMaxOrZero = false;

// Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
// and compute maxBECount.
Expand Down Expand Up @@ -5746,9 +5758,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
// computable EL.MaxNotTaken.
if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
DT.dominates(ExitBB, Latch)) {
if (!MustExitMaxBECount)
if (!MustExitMaxBECount) {
MustExitMaxBECount = EL.MaxNotTaken;
else {
MustExitMaxOrZero = EL.MaxOrZero;
} else {
MustExitMaxBECount =
getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
}
Expand All @@ -5763,8 +5776,11 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
}
const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
(MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
// The loop backedge will be taken the maximum or zero times if there's
// a single exit that must be taken the maximum or zero times.
bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
MaxBECount);
MaxBECount, MaxOrZero);
}

ScalarEvolution::ExitLimit
Expand Down Expand Up @@ -5901,7 +5917,8 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
!isa<SCEVCouldNotCompute>(BECount))
MaxBECount = BECount;

return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
return ExitLimit(BECount, MaxBECount, false,
{&EL0.Predicates, &EL1.Predicates});
}
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
Expand Down Expand Up @@ -5940,7 +5957,8 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
BECount = EL0.ExactNotTaken;
}

return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
return ExitLimit(BECount, MaxBECount, false,
{&EL0.Predicates, &EL1.Predicates});
}
}

Expand Down Expand Up @@ -6325,7 +6343,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *UpperBound =
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
return ExitLimit(getCouldNotCompute(), UpperBound);
return ExitLimit(getCouldNotCompute(), UpperBound, false);
}

return getCouldNotCompute();
Expand Down Expand Up @@ -7121,7 +7139,8 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// should not accept a root of 2.
const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
if (Val->isZero())
return ExitLimit(R1, R1, Predicates); // We found a quadratic root!
// We found a quadratic root!
return ExitLimit(R1, R1, false, Predicates);
}
}
return getCouldNotCompute();
Expand Down Expand Up @@ -7178,7 +7197,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
else
MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
: -CR.getUnsignedMin());
return ExitLimit(Distance, MaxBECount, Predicates);
return ExitLimit(Distance, MaxBECount, false, Predicates);
}

// As a special case, handle the instance where Step is a positive power of
Expand Down Expand Up @@ -7233,7 +7252,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,

const SCEV *Limit =
getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
return ExitLimit(Limit, Limit, Predicates);
return ExitLimit(Limit, Limit, false, Predicates);
}
}

Expand All @@ -7246,14 +7265,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
loopHasNoAbnormalExits(AddRec->getLoop())) {
const SCEV *Exact =
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
return ExitLimit(Exact, Exact, Predicates);
return ExitLimit(Exact, Exact, false, Predicates);
}

// Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
const SCEV *E = SolveLinEquationWithOverflow(
StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
return ExitLimit(E, E, Predicates);
return ExitLimit(E, E, false, Predicates);
}
return getCouldNotCompute();
}
Expand Down Expand Up @@ -8695,14 +8714,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
}

const SCEV *MaxBECount;
bool MaxOrZero = false;
if (isa<SCEVConstant>(BECount))
MaxBECount = BECount;
else if (isa<SCEVConstant>(BECountIfBackedgeTaken))
else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
// If we know exactly how many times the backedge will be taken if it's
// taken at least once, then the backedge count will either be that or
// zero.
MaxBECount = BECountIfBackedgeTaken;
else {
MaxOrZero = true;
} else {
// Calculate the maximum backedge count based on the range of values
// permitted by Start, End, and Stride.
APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
Expand Down Expand Up @@ -8739,7 +8760,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;

return ExitLimit(BECount, MaxBECount, Predicates);
return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
}

ScalarEvolution::ExitLimit
Expand Down Expand Up @@ -8816,7 +8837,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;

return ExitLimit(BECount, MaxBECount, Predicates);
return ExitLimit(BECount, MaxBECount, false, Predicates);
}

const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
Expand Down Expand Up @@ -9598,6 +9619,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,

if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
if (SE->isBackedgeTakenCountMaxOrZero(L))
OS << ", actual taken count either this or zero.";
} else {
OS << "Unpredictable max backedge-taken count. ";
}
Expand Down
28 changes: 18 additions & 10 deletions lib/Transforms/Scalar/LoopUnrollPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1000,14 +1000,22 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
if (Convergent)
UP.AllowRemainder = false;

// Try to find the trip count upper bound if it is allowed and we cannot find
// exact trip count.
if (UP.UpperBound) {
if (!TripCount) {
MaxTripCount = SE->getSmallConstantMaxTripCount(L);
// Only unroll with small upper bound.
if (MaxTripCount > UnrollMaxUpperBound)
MaxTripCount = 0;
// Try to find the trip count upper bound if we cannot find the exact trip
// count.
bool MaxOrZero = false;
if (!TripCount) {
MaxTripCount = SE->getSmallConstantMaxTripCount(L);
MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
// We can unroll by the upper bound amount if it's generally allowed or if
// we know that the loop is executed either the upper bound or zero times.
// (MaxOrZero unrolling keeps only the first loop test, so the number of
// loop tests remains the same compared to the non-unrolled version, whereas
// the generic upper bound unrolling keeps all but the last loop test so the
// number of loop tests goes up which may end up being worse on targets with
// constrained branch predictor resources so is controlled by an option.)
// In addition we only unroll small upper bounds.
if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) {
MaxTripCount = 0;
}
}

Expand All @@ -1025,8 +1033,8 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,

// Unroll the loop.
if (!UnrollLoop(L, UP.Count, TripCount, UP.Force, UP.Runtime,
UP.AllowExpensiveTripCount, UseUpperBound, TripMultiple, LI,
SE, &DT, &AC, &ORE, PreserveLCSSA))
UP.AllowExpensiveTripCount, UseUpperBound, MaxOrZero,
TripMultiple, LI, SE, &DT, &AC, &ORE, PreserveLCSSA))
return false;

// If loop has an unroll count pragma or unrolled by explicitly set count
Expand Down
13 changes: 7 additions & 6 deletions lib/Transforms/Utils/LoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks,
///
/// PreserveCondBr indicates whether the conditional branch of the LatchBlock
/// needs to be preserved. It is needed when we use trip count upper bound to
/// fully unroll the loop.
/// fully unroll the loop. If PreserveOnlyFirst is also set then only the first
/// conditional branch needs to be preserved.
///
/// Similarly, TripMultiple divides the number of times that the LatchBlock may
/// execute without exiting the loop.
Expand All @@ -207,10 +208,10 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks,
/// DominatorTree if they are non-null.
bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
bool AllowRuntime, bool AllowExpensiveTripCount,
bool PreserveCondBr, unsigned TripMultiple, LoopInfo *LI,
ScalarEvolution *SE, DominatorTree *DT,
AssumptionCache *AC, OptimizationRemarkEmitter *ORE,
bool PreserveLCSSA) {
bool PreserveCondBr, bool PreserveOnlyFirst,
unsigned TripMultiple, LoopInfo *LI, ScalarEvolution *SE,
DominatorTree *DT, AssumptionCache *AC,
OptimizationRemarkEmitter *ORE, bool PreserveLCSSA) {
BasicBlock *Preheader = L->getLoopPreheader();
if (!Preheader) {
DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n");
Expand Down Expand Up @@ -550,7 +551,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
assert(NeedConditional &&
"NeedCondition cannot be modified by both complete "
"unrolling and runtime unrolling");
NeedConditional = (PreserveCondBr && j);
NeedConditional = (PreserveCondBr && j && !(PreserveOnlyFirst && i != 0));
} else if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) {
// If we know the trip count or a multiple of it, we can safely use an
// unconditional branch for some iterations.
Expand Down
8 changes: 4 additions & 4 deletions test/Analysis/ScalarEvolution/trip-count13.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ loop:

; CHECK-LABEL: Determining loop execution counts for: @u_0
; CHECK-NEXT: Loop %loop: backedge-taken count is (-100 + (-1 * %rhs) + ((100 + %rhs) umax %rhs))
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.

leave:
ret void
Expand All @@ -34,7 +34,7 @@ loop:

; CHECK-LABEL: Determining loop execution counts for: @u_1
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-1 * %start) + ((-100 + %start) umax %start))
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.

leave:
ret void
Expand All @@ -54,7 +54,7 @@ loop:

; CHECK-LABEL: Determining loop execution counts for: @s_0
; CHECK-NEXT: Loop %loop: backedge-taken count is (-100 + (-1 * %rhs) + ((100 + %rhs) smax %rhs))
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.

leave:
ret void
Expand All @@ -74,7 +74,7 @@ loop:

; CHECK-LABEL: Determining loop execution counts for: @s_1
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-1 * %start) + ((-100 + %start) smax %start))
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.

leave:
ret void
Expand Down
Loading

0 comments on commit 9e0c61c

Please sign in to comment.