Skip to content

Commit

Permalink
Teach simplifycfg to recompute branch weights when merging some branc…
Browse files Browse the repository at this point in the history
…hes, and

to discard weights when appropriate. Still more to do (and a new TODO), but
it's a start!


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@147286 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
nlewycky committed Dec 27, 2011
1 parent da32cc6 commit 06cc66f
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 2 deletions.
67 changes: 67 additions & 0 deletions lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "llvm/GlobalVariable.h"
#include "llvm/Instructions.h"
#include "llvm/IntrinsicInst.h"
#include "llvm/LLVMContext.h"
#include "llvm/Metadata.h"
#include "llvm/Type.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
Expand Down Expand Up @@ -1462,6 +1464,26 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI,
return true;
}

/// ExtractBranchMetadata - Given a conditional BranchInstruction, retrieve the
/// probabilities of the branch taking each edge. Fills in the two APInt
/// parameters and return true, or returns false if no or invalid metadata was
/// found.
static bool ExtractBranchMetadata(BranchInst *BI,
APInt &ProbTrue, APInt &ProbFalse) {
assert(BI->isConditional() &&
"Looking for probabilities on unconditional branch?");
MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof);
if (!ProfileData || ProfileData->getNumOperands() != 3) return 0;
ConstantInt *CITrue = dyn_cast<ConstantInt>(ProfileData->getOperand(1));
ConstantInt *CIFalse = dyn_cast<ConstantInt>(ProfileData->getOperand(2));
if (!CITrue || !CIFalse) return 0;
ProbTrue = CITrue->getValue();
ProbFalse = CIFalse->getValue();
assert(ProbTrue.getBitWidth() == 32 && ProbFalse.getBitWidth() == 32 &&
"Branch probability metadata must be 32-bit integers");
return true;
}

/// FoldBranchToCommonDest - If this basic block is simple enough, and if a
/// predecessor branches to us and one of our successors, fold the block into
/// the predecessor and use logical operations to pick the right destination.
Expand Down Expand Up @@ -1636,6 +1658,51 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) {
PBI->setSuccessor(1, FalseDest);
}

// TODO: If BB is reachable from all paths through PredBlock, then we
// could replace PBI's branch probabilities with BI's.

// Merge probability data into PredBlock's branch.
APInt A, B, C, D;
if (ExtractBranchMetadata(PBI, C, D) && ExtractBranchMetadata(BI, A, B)) {
// bbA: br bbB (a% probability), bbC (b% prob.)
// bbB: br bbD (c% probability), bbC (d% prob.)
// --> bbA: br bbD ((a*c)% prob.), bbC ((b+a*d)% prob.)
//
// Probabilities aren't stored as ratios directly. Converting to
// probability-numerator form, we get:
// (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.

bool Overflow1 = false, Overflow2 = false, Overflow3 = false;
bool Overflow4 = false, Overflow5 = false, Overflow6 = false;
APInt ProbTrue = A.umul_ov(C, Overflow1);

APInt Tmp1 = A.umul_ov(D, Overflow2);
APInt Tmp2 = B.umul_ov(C, Overflow3);
APInt Tmp3 = B.umul_ov(D, Overflow4);
APInt Tmp4 = Tmp1.uadd_ov(Tmp2, Overflow5);
APInt ProbFalse = Tmp4.uadd_ov(Tmp3, Overflow6);

APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
ProbTrue = ProbTrue.udiv(GCD);
ProbFalse = ProbFalse.udiv(GCD);

if (Overflow1 || Overflow2 || Overflow3 || Overflow4 || Overflow5 ||
Overflow6) {
DEBUG(dbgs() << "Overflow recomputing branch weight on: " << *PBI
<< "when merging with: " << *BI);
PBI->setMetadata(LLVMContext::MD_prof, NULL);
} else {
LLVMContext &Context = BI->getContext();
Value *Ops[3];
Ops[0] = BI->getMetadata(LLVMContext::MD_prof)->getOperand(0);
Ops[1] = ConstantInt::get(Context, ProbTrue);
Ops[2] = ConstantInt::get(Context, ProbFalse);
PBI->setMetadata(LLVMContext::MD_prof, MDNode::get(Context, Ops));
}
} else {
PBI->setMetadata(LLVMContext::MD_prof, NULL);
}

// Copy any debug value intrinsics into the end of PredBlock.
for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
if (isa<DbgInfoIntrinsic>(*I))
Expand Down
66 changes: 64 additions & 2 deletions test/Transforms/SimplifyCFG/preserve-branchweights.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,45 @@ entry:
br i1 %a, label %Y, label %X, !prof !0
; CHECK: br i1 %or.cond, label %Z, label %Y, !prof !0

X:
%c = or i1 %b, false
br i1 %c, label %Z, label %Y, !prof !1

Y:
call void @helper(i32 0)
ret void

Z:
call void @helper(i32 1)
ret void
}

define void @test2(i1 %a, i1 %b) {
; CHECK: @test2
entry:
br i1 %a, label %X, label %Y, !prof !1
; CHECK: br i1 %or.cond, label %Z, label %Y, !prof !1
; CHECK-NOT: !prof

X:
%c = or i1 %b, false
br i1 %c, label %Z, label %Y, !prof !2

Y:
call void @helper(i32 0)
ret void

Z:
call void @helper(i32 1)
ret void
}

define void @test3(i1 %a, i1 %b) {
; CHECK: @test3
; CHECK-NOT: !prof
entry:
br i1 %a, label %X, label %Y, !prof !1

X:
%c = or i1 %b, false
br i1 %c, label %Z, label %Y
Expand All @@ -21,6 +60,29 @@ Z:
ret void
}

!0 = metadata !{metadata !"branch_weights", i32 1, i32 2}
define void @test4(i1 %a, i1 %b) {
; CHECK: @test4
; CHECK-NOT: !prof
entry:
br i1 %a, label %X, label %Y

X:
%c = or i1 %b, false
br i1 %c, label %Z, label %Y, !prof !1

Y:
call void @helper(i32 0)
ret void

Z:
call void @helper(i32 1)
ret void
}

!0 = metadata !{metadata !"branch_weights", i32 3, i32 5}
!1 = metadata !{metadata !"branch_weights", i32 1, i32 1}
!2 = metadata !{metadata !"branch_weights", i32 1, i32 2}

; CHECK: !0 = metadata !{metadata !"branch_weights", i32 2, i32 1}
; CHECK: !0 = metadata !{metadata !"branch_weights", i32 5, i32 11}
; CHECK: !1 = metadata !{metadata !"branch_weights", i32 1, i32 5}
; CHECK-NOT: !2

0 comments on commit 06cc66f

Please sign in to comment.