Skip to content

Commit

Permalink
Retry: [BPI] Use a safer constructor to calculate branch probabilities
Browse files Browse the repository at this point in the history
BPI may trigger signed overflow UB while computing branch probabilities for
cold calls or to unreachables. For example, with our current choice of weights,
we'll crash if there are >= 2^12 branches to an unreachable.

Use a safer BranchProbability constructor which is better at handling fractions
with large denominators.

Changes since the initial commit:
  - Use explicit casts to ensure that multiplication operands are 64-bit
    ints.

rdar://problem/29368161

Differential Revision: https://reviews.llvm.org/D27862

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@290022 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
vedantk committed Dec 17, 2016
1 parent 8739616 commit 318da23
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 12 deletions.
24 changes: 12 additions & 12 deletions lib/Analysis/BranchProbabilityInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
return true;
}

BranchProbability UnreachableProb(UR_TAKEN_WEIGHT,
(UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) *
UnreachableEdges.size());
BranchProbability ReachableProb(UR_NONTAKEN_WEIGHT,
(UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) *
ReachableEdges.size());
auto UnreachableProb = BranchProbability::getBranchProbability(
UR_TAKEN_WEIGHT, (UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) *
uint64_t(UnreachableEdges.size()));
auto ReachableProb = BranchProbability::getBranchProbability(
UR_NONTAKEN_WEIGHT,
(UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) * uint64_t(ReachableEdges.size()));

for (unsigned SuccIdx : UnreachableEdges)
setEdgeProbability(BB, SuccIdx, UnreachableProb);
Expand Down Expand Up @@ -300,12 +300,12 @@ bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
return true;
}

BranchProbability ColdProb(CC_TAKEN_WEIGHT,
(CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) *
ColdEdges.size());
BranchProbability NormalProb(CC_NONTAKEN_WEIGHT,
(CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) *
NormalEdges.size());
auto ColdProb = BranchProbability::getBranchProbability(
CC_TAKEN_WEIGHT,
(CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
auto NormalProb = BranchProbability::getBranchProbability(
CC_NONTAKEN_WEIGHT,
(CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));

for (unsigned SuccIdx : ColdEdges)
setEdgeProbability(BB, SuccIdx, ColdProb);
Expand Down
88 changes: 88 additions & 0 deletions unittests/Analysis/BranchProbabilityInfoTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//===- BranchProbabilityInfoTest.cpp - BranchProbabilityInfo unit tests ---===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/DataTypes.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"

namespace llvm {
namespace {

struct BranchProbabilityInfoTest : public testing::Test {
std::unique_ptr<BranchProbabilityInfo> BPI;
std::unique_ptr<DominatorTree> DT;
std::unique_ptr<LoopInfo> LI;
LLVMContext C;

BranchProbabilityInfo &buildBPI(Function &F) {
DT.reset(new DominatorTree(F));
LI.reset(new LoopInfo(*DT));
BPI.reset(new BranchProbabilityInfo(F, *LI));
return *BPI;
}

std::unique_ptr<Module> makeLLVMModule() {
const char *ModuleString = "define void @f() { exit: ret void }\n";
SMDiagnostic Err;
return parseAssemblyString(ModuleString, Err, C);
}
};

TEST_F(BranchProbabilityInfoTest, StressUnreachableHeuristic) {
auto M = makeLLVMModule();
Function *F = M->getFunction("f");

// define void @f() {
// entry:
// switch i32 undef, label %exit, [
// i32 0, label %preexit
// ... ;;< Add lots of cases to stress the heuristic.
// ]
// preexit:
// unreachable
// exit:
// ret void
// }

auto *ExitBB = &F->back();
auto *EntryBB = BasicBlock::Create(C, "entry", F, /*insertBefore=*/ExitBB);

auto *PreExitBB =
BasicBlock::Create(C, "preexit", F, /*insertBefore=*/ExitBB);
new UnreachableInst(C, PreExitBB);

unsigned NumCases = 4096;
auto *I32 = IntegerType::get(C, 32);
auto *Undef = UndefValue::get(I32);
auto *Switch = SwitchInst::Create(Undef, ExitBB, NumCases, EntryBB);
for (unsigned I = 0; I < NumCases; ++I)
Switch->addCase(ConstantInt::get(I32, I), PreExitBB);

BranchProbabilityInfo &BPI = buildBPI(*F);

// FIXME: This doesn't seem optimal. Since all of the cases handled by the
// switch have the *same* destination block ("preexit"), shouldn't it be the
// hot one? I'd expect the results to be reversed here...
EXPECT_FALSE(BPI.isEdgeHot(EntryBB, PreExitBB));
EXPECT_TRUE(BPI.isEdgeHot(EntryBB, ExitBB));
}

} // end anonymous namespace
} // end namespace llvm
1 change: 1 addition & 0 deletions unittests/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS
add_llvm_unittest(AnalysisTests
AliasAnalysisTest.cpp
BlockFrequencyInfoTest.cpp
BranchProbabilityInfoTest.cpp
CallGraphTest.cpp
CFGTest.cpp
CGSCCPassManagerTest.cpp
Expand Down

0 comments on commit 318da23

Please sign in to comment.