Skip to content

Commit

Permalink
[LV] Support efficient vectorization of an induction with redundant c…
Browse files Browse the repository at this point in the history
…asts

D30041 extended SCEVPredicateRewriter to improve handling of Phi nodes whose
update chain involves casts; PSCEV can now build an AddRecurrence for some
forms of such phi nodes, under the proper runtime overflow test. This means
that we can identify such phi nodes as an induction, and the loop-vectorizer
can now vectorize such inductions, however inefficiently. The vectorizer
doesn't know that it can ignore the casts, and so it vectorizes them.

This patch records the casts in the InductionDescriptor, so that they could
be marked to be ignored for cost calculation (we use VecValuesToIgnore for
that) and ignored for vectorization/widening/scalarization (i.e. treated as
TriviallyDead).

In addition to marking all these casts to be ignored, we also need to make
sure that each cast is mapped to the right vector value in the vector loop body
(be it a widened, vectorized, or scalarized induction). So whenever an
induction phi is mapped to a vector value (during vectorization/widening/
scalarization), we also map the respective cast instruction (if one exists) to that
vector value. (If the phi-update sequence of an induction involves more than one
cast, then the above mapping to vector value is relevant only for the last cast
of the sequence as we allow only the "last cast" to be used outside the
induction update chain itself).

This is the last step in addressing PR30654.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@320672 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
dnuzman committed Dec 14, 2017
1 parent 5b7b2d5 commit 3ce8b66
Showing 6 changed files with 492 additions and 25 deletions.
5 changes: 5 additions & 0 deletions include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
@@ -1884,6 +1884,11 @@ class PredicatedScalarEvolution {
/// The printed text is indented by \p Depth.
void print(raw_ostream &OS, unsigned Depth) const;

/// Check if \p AR1 and \p AR2 are equal, while taking into account
/// Equal predicates in Preds.
///
/// The two recurrences are considered equal when they are the same
/// expression, or when both their start values and their step recurrences
/// are pairwise identical or related by an Equal predicate that Preds
/// implies (the predicate is tried with the operands in either order).
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1,
const SCEVAddRecExpr *AR2) const;

private:
/// Increments the version number of the predicate. This needs to be called
/// every time the SCEV predicate changes.
24 changes: 19 additions & 5 deletions include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
@@ -306,10 +306,13 @@ class InductionDescriptor {
/// induction, the induction descriptor \p D will contain the data describing
/// this induction. If by some other means the caller has a better SCEV
/// expression for \p Phi than the one returned by the ScalarEvolution
/// analysis, it can be passed through \p Expr.
static bool isInductionPHI(PHINode *Phi, const Loop* L, ScalarEvolution *SE,
InductionDescriptor &D,
const SCEV *Expr = nullptr);
/// analysis, it can be passed through \p Expr. If the def-use chain
/// associated with the phi includes casts (that we know we can ignore
/// under proper runtime checks), they are passed through \p CastsToIgnore.
static bool
isInductionPHI(PHINode *Phi, const Loop* L, ScalarEvolution *SE,
InductionDescriptor &D, const SCEV *Expr = nullptr,
SmallVectorImpl<Instruction *> *CastsToIgnore = nullptr);

/// Returns true if \p Phi is a floating point induction in the loop \p L.
/// If \p Phi is an induction, the induction descriptor \p D will contain
@@ -348,10 +351,18 @@ class InductionDescriptor {
Instruction::BinaryOpsEnd;
}

/// Returns a reference to the type cast instructions in the induction
/// update chain, that are redundant when guarded with a runtime
/// SCEV overflow check.
///
/// The list is empty when no casts were recorded for this induction. The
/// returned reference refers to a member of this descriptor, so it must not
/// be used after the descriptor is destroyed or reassigned.
const SmallVectorImpl<Instruction *> &getCastInsts() const {
return RedundantCasts;
}

private:
/// Private constructor - used by \c isInductionPHI.
InductionDescriptor(Value *Start, InductionKind K, const SCEV *Step,
BinaryOperator *InductionBinOp = nullptr);
BinaryOperator *InductionBinOp = nullptr,
SmallVectorImpl<Instruction *> *Casts = nullptr);

/// Start value.
TrackingVH<Value> StartValue;
@@ -361,6 +372,9 @@ class InductionDescriptor {
const SCEV *Step = nullptr;
// Instruction that advances induction variable.
BinaryOperator *InductionBinOp = nullptr;
// Instructions used for type-casts of the induction variable,
// that are redundant when guarded with a runtime SCEV overflow check.
SmallVector<Instruction *, 2> RedundantCasts;
};

BasicBlock *InsertPreheaderForLoop(Loop *L, DominatorTree *DT, LoopInfo *LI,
36 changes: 30 additions & 6 deletions lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
@@ -4732,6 +4732,30 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
return Rewrite;
}

// FIXME: This helper exists only because the Rewriter does not yet fold
//   {0, +, (sext ix (trunc iy to ix) to iy)}
// into
//   {0, +, %step},
// even when the following Equal predicate is already known:
//   "%step == (sext ix (trunc iy to ix) to iy)".
//
// Returns true when \p AR1 and \p AR2 are the very same expression, or when
// their start values and their step recurrences are each either identical or
// tied together by an Equal predicate implied by Preds.
bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
  if (AR1 == AR2)
    return true;

  // Two SCEVs match if they are the same node, or if Preds implies their
  // equality with the operands taken in either order.
  auto MatchUnderPreds = [&](const SCEV *LHS, const SCEV *RHS) -> bool {
    return LHS == RHS || Preds.implies(SE.getEqualPredicate(LHS, RHS)) ||
           Preds.implies(SE.getEqualPredicate(RHS, LHS));
  };

  // The step recurrences are only compared once the start values matched.
  return MatchUnderPreds(AR1->getStart(), AR2->getStart()) &&
         MatchUnderPreds(AR1->getStepRecurrence(SE),
                         AR2->getStepRecurrence(SE));
}

/// A helper function for createAddRecFromPHI to handle simple cases.
///
/// This function tries to find an AddRec expression for the simplest (yet most
@@ -4874,33 +4898,33 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
// indices form a positive value.
if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
Flags = setFlags(Flags, SCEV::FlagNW);

const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
Flags = setFlags(Flags, SCEV::FlagNUW);
}

// We cannot transfer nuw and nsw flags from subtraction
// operations -- sub nuw X, Y is not the same as add nuw X, -Y
// for instance.
}

const SCEV *StartVal = getSCEV(StartValueV);
const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);

// Okay, for the entire analysis of this edge we assumed the PHI
// to be symbolic. We now need to go back and purge all of the
// entries for the scalars that use the symbolic expression.
forgetSymbolicName(PN, SymbolicName);
ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;

// We can add Flags to the post-inc expression only if we
// know that it is *undefined behavior* for BEValueV to
// overflow.
if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
(void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);

return PHISCEV;
}
}
141 changes: 133 additions & 8 deletions lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
@@ -678,7 +678,8 @@ Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder,
}

InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
const SCEV *Step, BinaryOperator *BOp)
const SCEV *Step, BinaryOperator *BOp,
SmallVectorImpl<Instruction *> *Casts)
: StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) {
assert(IK != IK_NoInduction && "Not an induction");

@@ -705,6 +706,12 @@ InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
(InductionBinOp->getOpcode() == Instruction::FAdd ||
InductionBinOp->getOpcode() == Instruction::FSub))) &&
"Binary opcode should be specified for FP induction");

if (Casts) {
for (auto &Inst : *Casts) {
RedundantCasts.push_back(Inst);
}
}
}

int InductionDescriptor::getConsecutiveDirection() const {
@@ -808,7 +815,7 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop,
StartValue = Phi->getIncomingValue(1);
} else {
assert(TheLoop->contains(Phi->getIncomingBlock(1)) &&
"Unexpected Phi node in the loop");
"Unexpected Phi node in the loop");
BEValue = Phi->getIncomingValue(1);
StartValue = Phi->getIncomingValue(0);
}
@@ -841,6 +848,110 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop,
return true;
}

/// This function is called when we suspect that the update-chain of a phi node
/// (whose symbolic SCEV expression is in \p PhiScev) contains redundant casts,
/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime
/// predicate P under which the SCEV expression for the phi can be the
/// AddRecurrence \p AR; See createAddRecFromPHIWithCasts). We want to find the
/// cast instructions that are involved in the update-chain of this induction.
/// A caller that adds the required runtime predicate can be free to drop these
/// cast instructions, and compute the phi using \p AR (instead of some scev
/// expression with casts).
///
/// For example, without a predicate the scev expression can take the following
/// form:
///      (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy)
///
/// It corresponds to the following IR sequence:
/// %for.body:
///   %x = phi i64 [ 0, %ph ], [ %add, %for.body ]
///   %casted_phi = "ExtTrunc i64 %x"
///   %add = add i64 %casted_phi, %step
///
/// where %x is given in \p PN,
/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate,
/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of
/// several forms, for example, such as:
///   ExtTrunc1:    %casted_phi = and  %x, 2^n-1
/// or:
///   ExtTrunc2:    %t = shl %x, m
///                 %casted_phi = ashr %t, m
///
/// If we are able to find such sequence, we return the instructions
/// we found, namely %casted_phi and the instructions on its use-def chain up
/// to the phi (not including the phi).
///
/// \param PSE       Predicated scalar evolution used to query SCEVs and to
///                  compare AddRecs under the recorded predicates.
/// \param PhiScev   SCEVUnknown wrapping the phi node under inspection.
/// \param AR        The AddRecurrence PSE computes for that phi.
/// \param CastInsts Output list; must be empty on entry. Instructions are
///                  appended in the order they are visited, walking from the
///                  latch value back towards the phi.
/// \returns true if a cast sequence was found and the walk reached the phi.
///          When false is returned \p CastInsts may hold a partial list and
///          must be ignored by the caller.
///
/// This helper is local to this file, hence the internal linkage.
static bool getCastsForInductionPHI(
    PredicatedScalarEvolution &PSE, const SCEVUnknown *PhiScev,
    const SCEVAddRecExpr *AR, SmallVectorImpl<Instruction *> &CastInsts) {

  assert(CastInsts.empty() && "CastInsts is expected to be empty.");
  auto *PN = cast<PHINode>(PhiScev->getValue());
  assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression");
  const Loop *L = AR->getLoop();

  // Find any cast instructions that participate in the def-use chain of
  // PhiScev in the loop.
  // FORNOW/TODO: We currently expect the def-use chain to include only
  // two-operand instructions, where one of the operands is an invariant.
  // createAddRecFromPHIWithCasts() currently does not support anything more
  // involved than that, so we keep the search simple. This can be
  // extended/generalized as needed.

  // Given a binary operator with one loop-invariant operand, return the other
  // operand. Returns nullptr if \p Val is not a BinaryOperator (this includes
  // any phi node) or if neither operand is loop-invariant.
  auto getDef = [&](const Value *Val) -> Value * {
    const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val);
    if (!BinOp)
      return nullptr;
    Value *Op0 = BinOp->getOperand(0);
    Value *Op1 = BinOp->getOperand(1);
    Value *Def = nullptr;
    if (L->isLoopInvariant(Op0))
      Def = Op1;
    else if (L->isLoopInvariant(Op1))
      Def = Op0;
    return Def;
  };

  // Look for the instruction that defines the induction via the
  // loop backedge.
  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return false;
  Value *Val = PN->getIncomingValueForBlock(Latch);
  if (!Val)
    return false;

  // Follow the def-use chain until the induction phi is reached.
  // If on the way we encounter a Value that has the same SCEV Expr as the
  // phi node, we can consider the instructions we visit from that point
  // as part of the cast-sequence that can be ignored.
  bool InCastSequence = false;
  auto *Inst = dyn_cast<Instruction>(Val);
  while (Val != PN) {
    // If the current value is not an instruction, or if we left the loop,
    // we bail out. (Phi nodes other than PN are rejected by getDef below.)
    if (!Inst || !L->contains(Inst))
      return false;

    auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val));
    if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR))
      InCastSequence = true;
    if (InCastSequence) {
      // Only the last instruction in the cast sequence is expected to have
      // uses outside the induction def-use chain. That instruction is the
      // first one recorded, so every later one must have a single use.
      if (!CastInsts.empty() && !Inst->hasOneUse())
        return false;
      CastInsts.push_back(Inst);
    }
    Val = getDef(Val);
    if (!Val)
      return false;
    Inst = dyn_cast<Instruction>(Val);
  }

  return InCastSequence;
}

bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
PredicatedScalarEvolution &PSE,
InductionDescriptor &D,
@@ -870,13 +981,26 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
return false;
}

// Record any Cast instructions that participate in the induction update
const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev);
// If we started from an UnknownSCEV, and managed to build an addRecurrence
// only after enabling Assume with PSCEV, this means we may have encountered
// cast instructions that required adding a runtime check in order to
// guarantee the correctness of the AddRecurrence representation of the
// induction.
if (PhiScev != AR && SymbolicPhi) {
SmallVector<Instruction *, 2> Casts;
if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts))
return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts);
}

return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR);
}

bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
ScalarEvolution *SE,
InductionDescriptor &D,
const SCEV *Expr) {
bool InductionDescriptor::isInductionPHI(
PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE,
InductionDescriptor &D, const SCEV *Expr,
SmallVectorImpl<Instruction *> *CastsToIgnore) {
Type *PhiTy = Phi->getType();
// We only handle integer and pointer inductions variables.
if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy())
@@ -895,7 +1019,7 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
// FIXME: We should treat this as a uniform. Unfortunately, we
// don't currently know how to handle uniform PHIs.
DEBUG(dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n");
return false;
return false;
}

Value *StartValue =
@@ -908,7 +1032,8 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
return false;

if (PhiTy->isIntegerTy()) {
D = InductionDescriptor(StartValue, IK_IntInduction, Step);
D = InductionDescriptor(StartValue, IK_IntInduction, Step, /*BOp=*/ nullptr,
CastsToIgnore);
return true;
}

Loading

0 comments on commit 3ce8b66

Please sign in to comment.