Skip to content

Commit

Permalink
[CallSiteSplitting] Refactor creating callsites.
Browse files Browse the repository at this point in the history
Summary:
This change makes the call site creation more general if any of the
arguments is predicated on a condition in the call site's predecessors.

If we find a callsite, that potentially can be split, we collect the set
of conditions for the call site's predecessors (currently only 2
predecessors are allowed). To do that, we traverse each predecessor's
predecessors as long as it only has single predecessors and record the
condition, if it is relevant to the call site. For each condition, we
also check if the condition is taken or not. In case it is not taken,
we record the inverse predicate.

We use the recorded conditions to create the new call sites and split
the basic block.

This has 2 benefits: (1) it is slightly easier to see what is going on
(IMO) and (2) we can easily extend it to handle more complex control
flow.

Reviewers: davidxl, junbuml

Reviewed By: junbuml

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D40728

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@320547 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
fhahn committed Dec 13, 2017
1 parent fa621d2 commit bb04a0e
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 115 deletions.
183 changes: 68 additions & 115 deletions lib/Transforms/Scalar/CallSiteSplitting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@ using namespace PatternMatch;

STATISTIC(NumCallSiteSplit, "Number of call-site split");

static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI,
static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI,
Value *Op) {
if (!NewCallI)
NewCallI = CallI->clone();
CallSite CS(NewCallI);
unsigned ArgNo = 0;
for (auto &I : CS.args()) {
Expand All @@ -85,10 +83,8 @@ static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI,
}
}

static void setConstantInArgument(Instruction *CallI, Instruction *&NewCallI,
static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI,
Value *Op, Constant *ConstValue) {
if (!NewCallI)
NewCallI = CallI->clone();
CallSite CS(NewCallI);
unsigned ArgNo = 0;
for (auto &I : CS.args()) {
Expand All @@ -114,99 +110,69 @@ static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
return false;
}

static SmallVector<BranchInst *, 2>
findOrCondRelevantToCallArgument(CallSite CS) {
SmallVector<BranchInst *, 2> BranchInsts;
for (auto PredBB : predecessors(CS.getInstruction()->getParent())) {
auto *PBI = dyn_cast<BranchInst>(PredBB->getTerminator());
if (!PBI || !PBI->isConditional())
continue;
/// If From has a conditional jump to To, add the condition to Conditions,
/// if it is relevant to any argument at CS.
static void
recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To,
SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
auto *BI = dyn_cast<BranchInst>(From->getTerminator());
if (!BI || !BI->isConditional())
return;

CmpInst::Predicate Pred;
Value *Cond = BI->getCondition();
if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
return;

ICmpInst *Cmp = cast<ICmpInst>(Cond);
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
if (isCondRelevantToAnyCallArgument(Cmp, CS))
Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
? Pred
: Cmp->getInversePredicate()});
}

CmpInst::Predicate Pred;
Value *Cond = PBI->getCondition();
if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
continue;
ICmpInst *Cmp = cast<ICmpInst>(Cond);
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
if (isCondRelevantToAnyCallArgument(Cmp, CS))
BranchInsts.push_back(PBI);
/// Record ICmp conditions relevant to any argument in CS following Pred's
/// single successors. If there are conflicting conditions along a path, like
/// x == 1 and x == 0, the first condition will be used.
static void
recordConditions(const CallSite &CS, BasicBlock *Pred,
SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions);
BasicBlock *From = Pred;
BasicBlock *To = Pred;
SmallPtrSet<BasicBlock *, 4> Visited = {From};
while (!Visited.count(From->getSinglePredecessor()) &&
(From = From->getSinglePredecessor())) {
recordCondition(CS, From, To, Conditions);
To = From;
}
return BranchInsts;
}

static bool tryCreateCallSitesOnOrPredicatedArgument(
CallSite CS, Instruction *&NewCSTakenFromHeader,
Instruction *&NewCSTakenFromNextCond, BasicBlock *HeaderBB) {
auto BranchInsts = findOrCondRelevantToCallArgument(CS);
assert(BranchInsts.size() <= 2 &&
"Unexpected number of blocks in the OR predicated condition");
Instruction *Instr = CS.getInstruction();
BasicBlock *CallSiteBB = Instr->getParent();
TerminatorInst *HeaderTI = HeaderBB->getTerminator();
bool IsCSInTakenPath = CallSiteBB == HeaderTI->getSuccessor(0);

for (auto *PBI : BranchInsts) {
assert(isa<ICmpInst>(PBI->getCondition()) &&
"Unexpected condition in a conditional branch.");
ICmpInst *Cmp = cast<ICmpInst>(PBI->getCondition());
Value *Arg = Cmp->getOperand(0);
assert(isa<Constant>(Cmp->getOperand(1)) &&
"Expected op1 to be a constant.");
Constant *ConstVal = cast<Constant>(Cmp->getOperand(1));
CmpInst::Predicate Pred = Cmp->getPredicate();

if (PBI->getParent() == HeaderBB) {
Instruction *&CallTakenFromHeader =
IsCSInTakenPath ? NewCSTakenFromHeader : NewCSTakenFromNextCond;
Instruction *&CallUntakenFromHeader =
IsCSInTakenPath ? NewCSTakenFromNextCond : NewCSTakenFromHeader;

assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
"Unexpected predicate in an OR condition");

// Set the constant value for agruments in the call predicated based on
// the OR condition.
Instruction *&CallToSetConst = Pred == ICmpInst::ICMP_EQ
? CallTakenFromHeader
: CallUntakenFromHeader;
setConstantInArgument(Instr, CallToSetConst, Arg, ConstVal);

// Add the NonNull attribute if compared with the null pointer.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
Instruction *&CallToSetAttr = Pred == ICmpInst::ICMP_EQ
? CallUntakenFromHeader
: CallTakenFromHeader;
addNonNullAttribute(Instr, CallToSetAttr, Arg);
}
continue;
}

if (Pred == ICmpInst::ICMP_EQ) {
if (PBI->getSuccessor(0) == Instr->getParent()) {
// Set the constant value for the call taken from the second block in
// the OR condition.
setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
} else {
// Add the NonNull attribute if compared with the null pointer for the
// call taken from the second block in the OR condition.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
}
} else {
if (PBI->getSuccessor(0) == Instr->getParent()) {
// Add the NonNull attribute if compared with the null pointer for the
// call taken from the second block in the OR condition.
if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
} else if (Pred == ICmpInst::ICMP_NE) {
// Set the constant value for the call in the untaken path from the
// header block.
setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
} else
llvm_unreachable("Unexpected condition");
static Instruction *
addConditions(CallSite &CS,
SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
if (Conditions.empty())
return nullptr;

Instruction *NewCI = CS.getInstruction()->clone();
for (auto &Cond : Conditions) {
Value *Arg = Cond.first->getOperand(0);
Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1));
if (Cond.second == ICmpInst::ICMP_EQ)
setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal);
else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
assert(Cond.second == ICmpInst::ICMP_NE);
addNonNullAttribute(CS.getInstruction(), NewCI, Arg);
}
}
return NewCSTakenFromHeader || NewCSTakenFromNextCond;
return NewCI;
}

static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
return Preds;
}

static bool canSplitCallSite(CallSite CS) {
Expand Down Expand Up @@ -358,12 +324,6 @@ static bool isPredicatedOnPHI(CallSite CS) {
return false;
}

static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
return Preds;
}

static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) {
if (!isPredicatedOnPHI(CS))
return false;
Expand All @@ -383,26 +343,19 @@ static bool isOrHeader(BasicBlock *HeaderBB, BasicBlock *OrBB) {

static bool tryToSplitOnOrPredicatedArgument(CallSite CS) {
auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
BasicBlock *HeaderBB = nullptr;
BasicBlock *OrBB = nullptr;
if (isOrHeader(Preds[0], Preds[1])) {
HeaderBB = Preds[0];
OrBB = Preds[1];
} else if (isOrHeader(Preds[1], Preds[0])) {
HeaderBB = Preds[1];
OrBB = Preds[0];
} else
if (!isOrHeader(Preds[0], Preds[1]) && !isOrHeader(Preds[1], Preds[0]))
return false;

Instruction *CallInst1 = nullptr;
Instruction *CallInst2 = nullptr;
if (!tryCreateCallSitesOnOrPredicatedArgument(CS, CallInst1, CallInst2,
HeaderBB)) {
assert(!CallInst1 && !CallInst2 && "Unexpected new call-sites cloned.");
SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2;
recordConditions(CS, Preds[0], C1);
recordConditions(CS, Preds[1], C2);

Instruction *CallInst1 = addConditions(CS, C1);
Instruction *CallInst2 = addConditions(CS, C2);
if (!CallInst1 && !CallInst2)
return false;
}

splitCallSite(CS, HeaderBB, OrBB, CallInst1, CallInst2);
splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1);
return true;
}

Expand Down
Loading

0 comments on commit bb04a0e

Please sign in to comment.