Skip to content

Commit

Permalink
[ReachingDefAnalysis] Turn MBBReachingDefsInfo into a proper class (N…
Browse files Browse the repository at this point in the history
…FC) (llvm#110432)

I'm trying to speed up the reaching def analysis by changing the
underlying data structure.  Turning MBBReachingDefsInfo into a proper
class decouples the data structure and its users.  This patch does not
change the existing three-dimensional vector structure.

---------

Co-authored-by: Nikita Popov <[email protected]>
  • Loading branch information
kazutakahirata and nikic authored Sep 30, 2024
1 parent 1efd122 commit 64f2bff
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 25 deletions.
50 changes: 44 additions & 6 deletions llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,50 @@ struct PointerLikeTypeTraits<ReachingDef> {
}
};

// The storage for all reaching definitions.
class MBBReachingDefsInfo {
public:
void init(unsigned NumBlockIDs) { AllReachingDefs.resize(NumBlockIDs); }

unsigned numBlockIDs() const { return AllReachingDefs.size(); }

void startBasicBlock(unsigned MBBNumber, unsigned NumRegUnits) {
AllReachingDefs[MBBNumber].resize(NumRegUnits);
}

void append(unsigned MBBNumber, unsigned Unit, int Def) {
AllReachingDefs[MBBNumber][Unit].push_back(Def);
}

void prepend(unsigned MBBNumber, unsigned Unit, int Def) {
auto &Defs = AllReachingDefs[MBBNumber][Unit];
Defs.insert(Defs.begin(), Def);
}

void replaceFront(unsigned MBBNumber, unsigned Unit, int Def) {
assert(!AllReachingDefs[MBBNumber][Unit].empty());
*AllReachingDefs[MBBNumber][Unit].begin() = Def;
}

void clear() { AllReachingDefs.clear(); }

ArrayRef<ReachingDef> defs(unsigned MBBNumber, unsigned Unit) const {
if (AllReachingDefs[MBBNumber].empty())
// Block IDs are not necessarily dense.
return ArrayRef<ReachingDef>();
return AllReachingDefs[MBBNumber][Unit];
}

private:
/// All reaching defs of a given RegUnit for a given MBB.
using MBBRegUnitDefs = TinyPtrVector<ReachingDef>;
/// All reaching defs of all reg units for a given MBB
using MBBDefsInfo = std::vector<MBBRegUnitDefs>;

/// All reaching defs of all reg units for all MBBs
SmallVector<MBBDefsInfo, 4> AllReachingDefs;
};

/// This class provides the reaching def analysis.
class ReachingDefAnalysis : public MachineFunctionPass {
private:
Expand Down Expand Up @@ -93,12 +137,6 @@ class ReachingDefAnalysis : public MachineFunctionPass {
/// their basic blocks.
DenseMap<MachineInstr *, int> InstIds;

/// All reaching defs of a given RegUnit for a given MBB.
using MBBRegUnitDefs = TinyPtrVector<ReachingDef>;
/// All reaching defs of all reg units for a given MBB
using MBBDefsInfo = std::vector<MBBRegUnitDefs>;
/// All reaching defs of all reg units for a all MBBs
using MBBReachingDefsInfo = SmallVector<MBBDefsInfo, 4>;
MBBReachingDefsInfo MBBReachingDefs;

/// Default values are 'nothing happened a long time ago'.
Expand Down
40 changes: 21 additions & 19 deletions llvm/lib/CodeGen/ReachingDefAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ static bool isValidRegDefOf(const MachineOperand &MO, MCRegister PhysReg,

void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
unsigned MBBNumber = MBB->getNumber();
assert(MBBNumber < MBBReachingDefs.size() &&
assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");
MBBReachingDefs[MBBNumber].resize(NumRegUnits);
MBBReachingDefs.startBasicBlock(MBBNumber, NumRegUnits);

// Reset instruction counter in each basic block.
CurInstr = 0;
Expand All @@ -71,7 +71,7 @@ void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
// before the call.
if (LiveRegs[Unit] != -1) {
LiveRegs[Unit] = -1;
MBBReachingDefs[MBBNumber][Unit].push_back(-1);
MBBReachingDefs.append(MBBNumber, Unit, -1);
}
}
}
Expand All @@ -97,7 +97,7 @@ void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
// Insert the most recent reaching definition we found.
for (unsigned Unit = 0; Unit != NumRegUnits; ++Unit)
if (LiveRegs[Unit] != ReachingDefDefaultVal)
MBBReachingDefs[MBBNumber][Unit].push_back(LiveRegs[Unit]);
MBBReachingDefs.append(MBBNumber, Unit, LiveRegs[Unit]);
}

void ReachingDefAnalysis::leaveBasicBlock(MachineBasicBlock *MBB) {
Expand All @@ -122,7 +122,7 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
assert(!MI->isDebugInstr() && "Won't process debug instructions");

unsigned MBBNumber = MI->getParent()->getNumber();
assert(MBBNumber < MBBReachingDefs.size() &&
assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");

for (auto &MO : MI->operands()) {
Expand All @@ -136,7 +136,7 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
// How many instructions since this reg unit was last written?
if (LiveRegs[Unit] != CurInstr) {
LiveRegs[Unit] = CurInstr;
MBBReachingDefs[MBBNumber][Unit].push_back(CurInstr);
MBBReachingDefs.append(MBBNumber, Unit, CurInstr);
}
}
}
Expand All @@ -146,7 +146,7 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {

void ReachingDefAnalysis::reprocessBasicBlock(MachineBasicBlock *MBB) {
unsigned MBBNumber = MBB->getNumber();
assert(MBBNumber < MBBReachingDefs.size() &&
assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");

// Count number of non-debug instructions for end of block adjustment.
Expand All @@ -169,16 +169,16 @@ void ReachingDefAnalysis::reprocessBasicBlock(MachineBasicBlock *MBB) {
if (Def == ReachingDefDefaultVal)
continue;

auto Start = MBBReachingDefs[MBBNumber][Unit].begin();
if (Start != MBBReachingDefs[MBBNumber][Unit].end() && *Start < 0) {
if (*Start >= Def)
auto Defs = MBBReachingDefs.defs(MBBNumber, Unit);
if (!Defs.empty() && Defs.front() < 0) {
if (Defs.front() >= Def)
continue;

// Update existing reaching def from predecessor to a more recent one.
*Start = Def;
MBBReachingDefs.replaceFront(MBBNumber, Unit, Def);
} else {
// Insert new reaching def from predecessor.
MBBReachingDefs[MBBNumber][Unit].insert(Start, Def);
MBBReachingDefs.prepend(MBBNumber, Unit, Def);
}

// Update reaching def at end of BB. Keep in mind that these are
Expand Down Expand Up @@ -234,7 +234,7 @@ void ReachingDefAnalysis::reset() {

void ReachingDefAnalysis::init() {
NumRegUnits = TRI->getNumRegUnits();
MBBReachingDefs.resize(MF->getNumBlockIDs());
MBBReachingDefs.init(MF->getNumBlockIDs());
// Initialize the MBBOutRegsInfos
MBBOutRegsInfos.resize(MF->getNumBlockIDs());
LoopTraversal Traversal;
Expand All @@ -247,10 +247,11 @@ void ReachingDefAnalysis::traverse() {
processBasicBlock(TraversedMBB);
#ifndef NDEBUG
// Make sure reaching defs are sorted and unique.
for (MBBDefsInfo &MBBDefs : MBBReachingDefs) {
for (MBBRegUnitDefs &RegUnitDefs : MBBDefs) {
for (unsigned MBBNumber = 0, NumBlockIDs = MF->getNumBlockIDs();
MBBNumber != NumBlockIDs; ++MBBNumber) {
for (unsigned Unit = 0; Unit != NumRegUnits; ++Unit) {
int LastDef = ReachingDefDefaultVal;
for (int Def : RegUnitDefs) {
for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) {
assert(Def > LastDef && "Defs must be sorted and unique");
LastDef = Def;
}
Expand All @@ -265,11 +266,11 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI,
int InstId = InstIds.lookup(MI);
int DefRes = ReachingDefDefaultVal;
unsigned MBBNumber = MI->getParent()->getNumber();
assert(MBBNumber < MBBReachingDefs.size() &&
assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");
int LatestDef = ReachingDefDefaultVal;
for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
for (int Def : MBBReachingDefs[MBBNumber][Unit]) {
for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) {
if (Def >= InstId)
break;
DefRes = Def;
Expand Down Expand Up @@ -299,7 +300,8 @@ bool ReachingDefAnalysis::hasSameReachingDef(MachineInstr *A, MachineInstr *B,

MachineInstr *ReachingDefAnalysis::getInstFromId(MachineBasicBlock *MBB,
int InstId) const {
assert(static_cast<size_t>(MBB->getNumber()) < MBBReachingDefs.size() &&
assert(static_cast<size_t>(MBB->getNumber()) <
MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");
assert(InstId < static_cast<int>(MBB->size()) &&
"Unexpected instruction id.");
Expand Down

0 comments on commit 64f2bff

Please sign in to comment.