Skip to content

Commit

Permalink
[Profile] backward propagate profile info in JumpThreading
Browse files Browse the repository at this point in the history
Differential Revsion: http://reviews.llvm.org/D36864


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@311208 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
david-xl committed Aug 18, 2017
1 parent 066b24c commit 6d92310
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 1 deletion.
113 changes: 112 additions & 1 deletion lib/Transforms/Scalar/JumpThreading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,113 @@ JumpThreadingPass::JumpThreadingPass(int T) {
BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T);
}

/// runOnFunction - Top level algorithm.
// Update branch probability information according to conditional
// branch probablity. This is usually made possible for cloned branches
// in inline instances by the context specific profile in the caller.
// For instance,
//
// [Block PredBB]
// [Branch PredBr]
// if (t) {
// Block A;
// } else {
// Block B;
// }
//
// [Block BB]
// cond = PN([true, %A], [..., %B]); // PHI node
// [Branch CondBr]
// if (cond) {
// ... // P(cond == true) = 1%
// }
//
// Here we know that when block A is taken, c must be true, which means
// P(cond == true | A) = 1
//
// Given that P(cond == true) = P(cond == true | A) * P(A) +
// P(cond == true | B) * P(B)
// we get
// P(cond == true ) = P(A) + P(cond == true | B) * P(B)
//
// which gives us:
// P(A) <= P(c == true), i.e.
// P(t == true) <= P(cond == true)
//
// In other words, if we know P(cond == true), we know that P(t == true)
// can not be greater than 1%.
static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
if (!CondBr)
return;

BranchProbability BP;
uint64_t TrueWeight, FalseWeight;
if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight))
return;

// Returns the outgoing edge of the dominating predecessor block
// that leads to the PhiNode's incoming block:
auto GetPredOutEdge =
[](BasicBlock *IncomingBB,
BasicBlock *PhiBB) -> std::pair<BasicBlock *, BasicBlock *> {
auto *PredBB = IncomingBB;
while (auto *SinglePredBB = PredBB->getSinglePredecessor())
PredBB = SinglePredBB;

BranchInst *PredBr = dyn_cast<BranchInst>(IncomingBB->getTerminator());
if (PredBr && PredBr->isConditional())
return {IncomingBB, PhiBB};

return {nullptr, nullptr};
};

for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
Value *PhiOpnd = PN->getIncomingValue(i);
ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd);

if (!CI || !CI->getType()->isIntegerTy(1))
continue;

BP = (CI->isOne() ? BranchProbability::getBranchProbability(
TrueWeight, TrueWeight + FalseWeight)
: BranchProbability::getBranchProbability(
FalseWeight, TrueWeight + FalseWeight));

auto PredOutEdge = GetPredOutEdge(PN->getIncomingBlock(i), BB);
if (!PredOutEdge.first)
return;

BasicBlock *PredBB = PredOutEdge.first;
BranchInst *PredBr = cast<BranchInst>(PredBB->getTerminator());

uint64_t PredTrueWeight, PredFalseWeight;
// FIXME: We currently only set the profile data when it is missing.
// With PGO, this can be used to refine even existing profile data with
// context information. This needs to be done after more performance
// testing.
if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight))
continue;

// We can not infer anything useful when BP >= 50%, because BP is the
// upper bound probability value.
if (BP >= BranchProbability(50, 100))
continue;

SmallVector<uint32_t, 2> Weights;
if (PredBr->getSuccessor(0) == PredOutEdge.second) {
Weights.push_back(BP.getNumerator());
Weights.push_back(BP.getCompl().getNumerator());
} else {
Weights.push_back(BP.getCompl().getNumerator());
Weights.push_back(BP.getNumerator());
}
PredBr->setMetadata(LLVMContext::MD_prof,
MDBuilder(PredBr->getParent()->getContext())
.createBranchWeights(Weights));
}
}

/// runOnFunction - Toplevel algorithm.
///
bool JumpThreading::runOnFunction(Function &F) {
if (skipFunction(F))
Expand Down Expand Up @@ -991,6 +1097,11 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
if (SimplifyPartiallyRedundantLoad(LI))
return true;

// Before threading, try to propagate profile data backwards:
if (PHINode *PN = dyn_cast<PHINode>(CondInst))
if (PN->getParent() == BB && isa<BranchInst>(BB->getTerminator()))
updatePredecessorProfileMetadata(PN, BB);

// Handle a variety of cases where we are branching on something derived from
// a PHI node in the current block. If we can prove that any predecessors
// compute a predictable value based on a PHI node, thread those predecessors.
Expand Down
37 changes: 37 additions & 0 deletions test/Transforms/JumpThreading/threading_prof1.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; RUN: opt -jump-threading -S < %s | FileCheck %s
; RUN: opt -passes=jump-threading -S < %s | FileCheck %s

define void @test() {
bb:
%tmp = call i32 @a()
%tmp1 = icmp eq i32 %tmp, 1
br i1 %tmp1, label %bb5, label %bb2
; CHECK: br i1 %tmp1,{{.*}} !prof ![[PROF1:[0-9]+]]

bb2: ; preds = %bb
%tmp3 = call i32 @b()
%tmp4 = icmp ne i32 %tmp3, 1
br label %bb5
; CHECK: br i1 %tmp4, {{.*}} !prof ![[PROF2:[0-9]+]]

bb5: ; preds = %bb2, %bb
%tmp6 = phi i1 [ false, %bb ], [ %tmp4, %bb2 ]
br i1 %tmp6, label %bb8, label %bb7, !prof !0

bb7: ; preds = %bb5
call void @bar()
br label %bb8

bb8: ; preds = %bb7, %bb5
ret void
}

declare void @bar()

declare i32 @a()

declare i32 @b()

!0 = !{!"branch_weights", i32 2146410443, i32 1073205}
;CHECK: ![[PROF1]] = !{!"branch_weights", i32 1073205, i32 2146410443}
;CHECK: ![[PROF2]] = !{!"branch_weights", i32 2146410443, i32 1073205}
42 changes: 42 additions & 0 deletions test/Transforms/JumpThreading/threading_prof2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: opt -jump-threading -S < %s | FileCheck %s
; RUN: opt -passes=jump-threading -S < %s | FileCheck %s
define void @test() {
bb:
%tmp = call i32 @a()
%tmp1 = icmp eq i32 %tmp, 1
br i1 %tmp1, label %bb5, label %bb2
; CHECK: br i1 %tmp1,{{.*}} !prof ![[PROF1:[0-9]+]]

bb2:
%tmp3 = call i32 @b()
%tmp4 = icmp ne i32 %tmp3, 1
br label %bb5
; CHECK: br i1 %tmp4, {{.*}} !prof ![[PROF2:[0-9]+]]

bb5:
%tmp6 = phi i1 [ false, %bb ], [ %tmp4, %bb2 ]
br i1 %tmp6, label %bb8, label %bb7, !prof !0

bb7:
call void @bar()
br label %bb9

bb8:
call void @foo()
br label %bb9

bb9:
ret void
}

declare void @bar()

declare void @foo()

declare i32 @a()

declare i32 @b()

!0 = !{!"branch_weights", i32 2146410443, i32 1073205}
;CHECK: ![[PROF1]] = !{!"branch_weights", i32 1073205, i32 2146410443}
;CHECK: ![[PROF2]] = !{!"branch_weights", i32 2146410443, i32 1073205}

0 comments on commit 6d92310

Please sign in to comment.