Skip to content

Commit

Permalink
[IR] Redesign the case iterator in SwitchInst to actually be an iterator
Browse files Browse the repository at this point in the history
and to expose a handle to represent the actual case rather than having
the iterator return a reference to itself.

All of this allows the iterator to be used with common STL facilities,
standard algorithms, etc.

Doing this exposed some missing facilities in the iterator facade that
I've fixed and required some work to the actual iterator to fully
support the necessary API.

Differential Revision: https://reviews.llvm.org/D31548

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@300032 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
chandlerc committed Apr 12, 2017
1 parent 00b7906 commit ddfada2
Show file tree
Hide file tree
Showing 28 changed files with 296 additions and 198 deletions.
6 changes: 6 additions & 0 deletions include/llvm/ADT/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,15 @@ class iterator_facade_base
return !static_cast<const DerivedT *>(this)->operator<(RHS);
}

PointerT operator->() { return &static_cast<DerivedT *>(this)->operator*(); }
PointerT operator->() const {
return &static_cast<const DerivedT *>(this)->operator*();
}
ReferenceProxy operator[](DifferenceTypeT n) {
static_assert(IsRandomAccess,
"Subscripting is only defined for random access iterators.");
return ReferenceProxy(static_cast<DerivedT *>(this)->operator+(n));
}
ReferenceProxy operator[](DifferenceTypeT n) const {
static_assert(IsRandomAccess,
"Subscripting is only defined for random access iterators.");
Expand Down
3 changes: 1 addition & 2 deletions include/llvm/Analysis/CFGPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ struct DOTGraphTraits<const Function*> : public DefaultDOTGraphTraits {

std::string Str;
raw_string_ostream OS(Str);
SwitchInst::ConstCaseIt Case =
SwitchInst::ConstCaseIt::fromSuccessorIndex(SI, SuccNo);
auto Case = *SwitchInst::ConstCaseIt::fromSuccessorIndex(SI, SuccNo);
OS << Case.getCaseValue()->getValue();
return OS.str();
}
Expand Down
191 changes: 118 additions & 73 deletions include/llvm/IR/Instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -3096,49 +3096,39 @@ class SwitchInst : public TerminatorInst {
// -2
static const unsigned DefaultPseudoIndex = static_cast<unsigned>(~0L-1);

template <class SwitchInstT, class ConstantIntT, class BasicBlockT>
class CaseIteratorT
: public iterator_facade_base<
CaseIteratorT<SwitchInstT, ConstantIntT, BasicBlockT>,
std::random_access_iterator_tag,
CaseIteratorT<SwitchInstT, ConstantIntT, BasicBlockT>> {
template <typename CaseHandleT> class CaseIteratorT;

/// A handle to a particular switch case. It exposes a convenient interface
/// to both the case value and the successor block.
///
/// We define this as a template and instantiate it to form both a const and
/// non-const handle.
template <typename SwitchInstT, typename ConstantIntT, typename BasicBlockT>
class CaseHandleT {
// Directly befriend both const and non-const iterators.
friend class SwitchInst::CaseIteratorT<
CaseHandleT<SwitchInstT, ConstantIntT, BasicBlockT>>;

protected:
// Expose the switch type we're parameterized with to the iterator.
typedef SwitchInstT SwitchInstType;

SwitchInstT *SI;
ptrdiff_t Index;

public:
typedef CaseIteratorT<SwitchInstT, ConstantIntT, BasicBlockT> Self;

/// Default constructed iterator is in an invalid state until assigned to
/// a case for a particular switch.
CaseIteratorT() : SI(nullptr) {}

/// Initializes case iterator for given SwitchInst and for given
/// case number.
CaseIteratorT(SwitchInstT *SI, unsigned CaseNum) {
this->SI = SI;
Index = CaseNum;
}

/// Initializes case iterator for given SwitchInst and for given
/// TerminatorInst's successor index.
static Self fromSuccessorIndex(SwitchInstT *SI, unsigned SuccessorIndex) {
assert(SuccessorIndex < SI->getNumSuccessors() &&
"Successor index # out of range!");
return SuccessorIndex != 0 ?
Self(SI, SuccessorIndex - 1) :
Self(SI, DefaultPseudoIndex);
}
CaseHandleT() = default;
CaseHandleT(SwitchInstT *SI, ptrdiff_t Index) : SI(SI), Index(Index) {}

public:
/// Resolves case value for current case.
ConstantIntT *getCaseValue() {
ConstantIntT *getCaseValue() const {
assert((unsigned)Index < SI->getNumCases() &&
"Index out the number of cases.");
return reinterpret_cast<ConstantIntT *>(SI->getOperand(2 + Index * 2));
}

/// Resolves successor for current case.
BasicBlockT *getCaseSuccessor() {
BasicBlockT *getCaseSuccessor() const {
assert(((unsigned)Index < SI->getNumCases() ||
(unsigned)Index == DefaultPseudoIndex) &&
"Index out the number of cases.");
Expand All @@ -3156,43 +3146,20 @@ class SwitchInst : public TerminatorInst {
return (unsigned)Index != DefaultPseudoIndex ? Index + 1 : 0;
}

Self &operator+=(ptrdiff_t N) {
// Check index correctness after addition.
// Note: Index == getNumCases() means end().
assert(Index + N >= 0 && (unsigned)(Index + N) <= SI->getNumCases() &&
"Index out the number of cases.");
Index += N;
return *this;
}
Self &operator-=(ptrdiff_t N) {
// Check index correctness after subtraction.
// Note: Index == getNumCases() means end().
assert(Index - N >= 0 && (unsigned)(Index - N) <= SI->getNumCases() &&
"Index out the number of cases.");
Index -= N;
return *this;
}
bool operator==(const Self& RHS) const {
bool operator==(const CaseHandleT &RHS) const {
assert(SI == RHS.SI && "Incompatible operators.");
return Index == RHS.Index;
}
bool operator<(const Self& RHS) const {
assert(SI == RHS.SI && "Incompatible operators.");
return Index < RHS.Index;
}
Self &operator*() { return *this; }
const Self &operator*() const { return *this; }
};

typedef CaseIteratorT<const SwitchInst, const ConstantInt, const BasicBlock>
ConstCaseIt;
typedef CaseHandleT<const SwitchInst, const ConstantInt, const BasicBlock>
ConstCaseHandle;

class CaseIt : public CaseIteratorT<SwitchInst, ConstantInt, BasicBlock> {
typedef CaseIteratorT<SwitchInst, ConstantInt, BasicBlock> ParentTy;
class CaseHandle : public CaseHandleT<SwitchInst, ConstantInt, BasicBlock> {
friend class SwitchInst::CaseIteratorT<CaseHandle>;

public:
CaseIt(const ParentTy &Src) : ParentTy(Src) {}
CaseIt(SwitchInst *SI, unsigned CaseNum) : ParentTy(SI, CaseNum) {}
CaseHandle(SwitchInst *SI, ptrdiff_t Index) : CaseHandleT(SI, Index) {}

/// Sets the new value for current case.
void setValue(ConstantInt *V) {
Expand All @@ -3207,6 +3174,74 @@ class SwitchInst : public TerminatorInst {
}
};

template <typename CaseHandleT>
class CaseIteratorT
: public iterator_facade_base<CaseIteratorT<CaseHandleT>,
std::random_access_iterator_tag,
CaseHandleT> {
typedef typename CaseHandleT::SwitchInstType SwitchInstT;

CaseHandleT Case;

public:
/// Default constructed iterator is in an invalid state until assigned to
/// a case for a particular switch.
CaseIteratorT() = default;

/// Initializes case iterator for given SwitchInst and for given
/// case number.
CaseIteratorT(SwitchInstT *SI, unsigned CaseNum) : Case(SI, CaseNum) {}

/// Initializes case iterator for given SwitchInst and for given
/// TerminatorInst's successor index.
static CaseIteratorT fromSuccessorIndex(SwitchInstT *SI,
unsigned SuccessorIndex) {
assert(SuccessorIndex < SI->getNumSuccessors() &&
"Successor index # out of range!");
return SuccessorIndex != 0 ? CaseIteratorT(SI, SuccessorIndex - 1)
: CaseIteratorT(SI, DefaultPseudoIndex);
}

/// Support converting to the const variant. This will be a no-op for const
/// variant.
operator CaseIteratorT<ConstCaseHandle>() const {
return CaseIteratorT<ConstCaseHandle>(Case.SI, Case.Index);
}

CaseIteratorT &operator+=(ptrdiff_t N) {
// Check index correctness after addition.
// Note: Index == getNumCases() means end().
assert(Case.Index + N >= 0 &&
(unsigned)(Case.Index + N) <= Case.SI->getNumCases() &&
"Case.Index out the number of cases.");
Case.Index += N;
return *this;
}
CaseIteratorT &operator-=(ptrdiff_t N) {
// Check index correctness after subtraction.
// Note: Case.Index == getNumCases() means end().
assert(Case.Index - N >= 0 &&
(unsigned)(Case.Index - N) <= Case.SI->getNumCases() &&
"Case.Index out the number of cases.");
Case.Index -= N;
return *this;
}
ptrdiff_t operator-(const CaseIteratorT &RHS) const {
assert(Case.SI == RHS.Case.SI && "Incompatible operators.");
return Case.Index - RHS.Case.Index;
}
bool operator==(const CaseIteratorT &RHS) const { return Case == RHS.Case; }
bool operator<(const CaseIteratorT &RHS) const {
assert(Case.SI == RHS.Case.SI && "Incompatible operators.");
return Case.Index < RHS.Case.Index;
}
CaseHandleT &operator*() { return Case; }
const CaseHandleT &operator*() const { return Case; }
};

typedef CaseIteratorT<CaseHandle> CaseIt;
typedef CaseIteratorT<ConstCaseHandle> ConstCaseIt;

static SwitchInst *Create(Value *Value, BasicBlock *Default,
unsigned NumCases,
Instruction *InsertBefore = nullptr) {
Expand Down Expand Up @@ -3290,30 +3325,40 @@ class SwitchInst : public TerminatorInst {
/// default case iterator to indicate that it is handled by the default
/// handler.
CaseIt findCaseValue(const ConstantInt *C) {
for (CaseIt i = case_begin(), e = case_end(); i != e; ++i)
if (i.getCaseValue() == C)
return i;
CaseIt I = llvm::find_if(
cases(), [C](CaseHandle &Case) { return Case.getCaseValue() == C; });
if (I != case_end())
return I;

return case_default();
}
ConstCaseIt findCaseValue(const ConstantInt *C) const {
for (ConstCaseIt i = case_begin(), e = case_end(); i != e; ++i)
if (i.getCaseValue() == C)
return i;
ConstCaseIt I = llvm::find_if(cases(), [C](ConstCaseHandle &Case) {
return Case.getCaseValue() == C;
});
if (I != case_end())
return I;

return case_default();
}

/// Finds the unique case value for a given successor. Returns null if the
/// successor is not found, not unique, or is the default case.
ConstantInt *findCaseDest(BasicBlock *BB) {
if (BB == getDefaultDest()) return nullptr;
if (BB == getDefaultDest())
return nullptr;

ConstantInt *CI = nullptr;
for (CaseIt i = case_begin(), e = case_end(); i != e; ++i) {
if (i.getCaseSuccessor() == BB) {
if (CI) return nullptr; // Multiple cases lead to BB.
else CI = i.getCaseValue();
}
for (auto Case : cases()) {
if (Case.getCaseSuccessor() != BB)
continue;

if (CI)
return nullptr; // Multiple cases lead to BB.

CI = Case.getCaseValue();
}

return CI;
}

Expand All @@ -3330,7 +3375,7 @@ class SwitchInst : public TerminatorInst {
/// This action invalidates iterators for all cases following the one removed,
/// including the case_end() iterator. It returns an iterator for the next
/// case.
CaseIt removeCase(CaseIt i);
CaseIt removeCase(CaseIt I);

unsigned getNumSuccessors() const { return getNumOperands()/2; }
BasicBlock *getSuccessor(unsigned idx) const {
Expand Down
6 changes: 3 additions & 3 deletions lib/Analysis/InlineCost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1014,8 +1014,8 @@ bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) {
// does not (yet) fire.
SmallPtrSet<BasicBlock *, 8> SuccessorBlocks;
SuccessorBlocks.insert(SI.getDefaultDest());
for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I)
SuccessorBlocks.insert(I.getCaseSuccessor());
for (auto Case : SI.cases())
SuccessorBlocks.insert(Case.getCaseSuccessor());
// Add cost corresponding to the number of distinct destinations. The first
// we model as free because of fallthrough.
Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost;
Expand Down Expand Up @@ -1379,7 +1379,7 @@ bool CallAnalyzer::analyzeCall(CallSite CS) {
Value *Cond = SI->getCondition();
if (ConstantInt *SimpleCond =
dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) {
BBWorklist.insert(SI->findCaseValue(SimpleCond).getCaseSuccessor());
BBWorklist.insert(SI->findCaseValue(SimpleCond)->getCaseSuccessor());
continue;
}
}
Expand Down
8 changes: 4 additions & 4 deletions lib/Analysis/LazyValueInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1430,14 +1430,14 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
unsigned BitWidth = Val->getType()->getIntegerBitWidth();
ConstantRange EdgesVals(BitWidth, DefaultCase/*isFullSet*/);

for (SwitchInst::CaseIt i : SI->cases()) {
ConstantRange EdgeVal(i.getCaseValue()->getValue());
for (auto Case : SI->cases()) {
ConstantRange EdgeVal(Case.getCaseValue()->getValue());
if (DefaultCase) {
// It is possible that the default destination is the destination of
// some cases. There is no need to perform difference for those cases.
if (i.getCaseSuccessor() != BBTo)
if (Case.getCaseSuccessor() != BBTo)
EdgesVals = EdgesVals.difference(EdgeVal);
} else if (i.getCaseSuccessor() == BBTo)
} else if (Case.getCaseSuccessor() == BBTo)
EdgesVals = EdgesVals.unionWith(EdgeVal);
}
Result = LVILatticeVal::getRange(std::move(EdgesVals));
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/SparsePropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ void SparseSolver::getFeasibleSuccessors(TerminatorInst &TI,
Succs.assign(TI.getNumSuccessors(), true);
return;
}
SwitchInst::CaseIt Case = SI.findCaseValue(cast<ConstantInt>(C));
SwitchInst::CaseHandle Case = *SI.findCaseValue(cast<ConstantInt>(C));
Succs[Case.getSuccessorIndex()] = true;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Bitcode/Writer/BitcodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2578,7 +2578,7 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
Vals.push_back(VE.getTypeID(SI.getCondition()->getType()));
pushValue(SI.getCondition(), InstID, Vals);
Vals.push_back(VE.getValueID(SI.getDefaultDest()));
for (SwitchInst::ConstCaseIt Case : SI.cases()) {
for (auto Case : SI.cases()) {
Vals.push_back(VE.getValueID(Case.getCaseValue()));
Vals.push_back(VE.getValueID(Case.getCaseSuccessor()));
}
Expand Down
2 changes: 1 addition & 1 deletion lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5457,7 +5457,7 @@ bool CodeGenPrepare::optimizeSwitchInst(SwitchInst *SI) {
auto *ExtInst = CastInst::Create(ExtType, Cond, NewType);
ExtInst->insertBefore(SI);
SI->setCondition(ExtInst);
for (SwitchInst::CaseIt Case : SI->cases()) {
for (auto Case : SI->cases()) {
APInt NarrowConst = Case.getCaseValue()->getValue();
APInt WideConst = (ExtType == Instruction::ZExt) ?
NarrowConst.zext(RegWidth) : NarrowConst.sext(RegWidth);
Expand Down
6 changes: 3 additions & 3 deletions lib/ExecutionEngine/Interpreter/Execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,10 +899,10 @@ void Interpreter::visitSwitchInst(SwitchInst &I) {

// Check to see if any of the cases match...
BasicBlock *Dest = nullptr;
for (SwitchInst::CaseIt i = I.case_begin(), e = I.case_end(); i != e; ++i) {
GenericValue CaseVal = getOperandValue(i.getCaseValue(), SF);
for (auto Case : I.cases()) {
GenericValue CaseVal = getOperandValue(Case.getCaseValue(), SF);
if (executeICMP_EQ(CondVal, CaseVal, ElTy).IntVal != 0) {
Dest = cast<BasicBlock>(i.getCaseSuccessor());
Dest = cast<BasicBlock>(Case.getCaseSuccessor());
break;
}
}
Expand Down
7 changes: 3 additions & 4 deletions lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2897,12 +2897,11 @@ void AssemblyWriter::printInstruction(const Instruction &I) {
Out << ", ";
writeOperand(SI.getDefaultDest(), true);
Out << " [";
for (SwitchInst::ConstCaseIt i = SI.case_begin(), e = SI.case_end();
i != e; ++i) {
for (auto Case : SI.cases()) {
Out << "\n ";
writeOperand(i.getCaseValue(), true);
writeOperand(Case.getCaseValue(), true);
Out << ", ";
writeOperand(i.getCaseSuccessor(), true);
writeOperand(Case.getCaseSuccessor(), true);
}
Out << "\n ]";
} else if (isa<IndirectBrInst>(I)) {
Expand Down
Loading

0 comments on commit ddfada2

Please sign in to comment.