Skip to content

Commit

Permalink
[SDAG] Simplify SDNodeFlags with bitwise logic (llvm#114061)
Browse files Browse the repository at this point in the history
This patch allows enumeration values to be used directly wherever an
SDNodeFlags is expected, and simplifies the implementation with bitwise
logic. It addresses a review comment on llvm#113808.
  • Loading branch information
dtcxzyw authored Oct 31, 2024
1 parent 36b7915 commit cf9d1c1
Show file tree
Hide file tree
Showing 16 changed files with 90 additions and 172 deletions.
16 changes: 4 additions & 12 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,7 @@ struct BinaryOpc_match {
if (!Flags.has_value())
return true;

SDNodeFlags TmpFlags = *Flags;
TmpFlags.intersectWith(N->getFlags());
return TmpFlags == *Flags;
return (*Flags & N->getFlags()) == *Flags;
}

return false;
Expand Down Expand Up @@ -668,9 +666,7 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
const RHS &R) {
SDNodeFlags Flags;
Flags.setDisjoint(true);
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint);
}

template <typename LHS, typename RHS>
Expand Down Expand Up @@ -813,9 +809,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
if (!Flags.has_value())
return true;

SDNodeFlags TmpFlags = *Flags;
TmpFlags.intersectWith(N->getFlags());
return TmpFlags == *Flags;
return (*Flags & N->getFlags()) == *Flags;
}

return false;
Expand Down Expand Up @@ -848,9 +842,7 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {

template <typename Opnd>
inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
SDNodeFlags Flags;
Flags.setNonNeg(true);
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg);
}

template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
Expand Down
8 changes: 2 additions & 6 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1064,17 +1064,13 @@ class SelectionDAG {
/// addressing some offset of an object. i.e. if a load is split into multiple
/// components, create an add nuw from the base pointer to the offset.
SDValue getObjectPtrOffset(const SDLoc &SL, SDValue Ptr, TypeSize Offset) {
SDNodeFlags Flags;
Flags.setNoUnsignedWrap(true);
return getMemBasePlusOffset(Ptr, Offset, SL, Flags);
return getMemBasePlusOffset(Ptr, Offset, SL, SDNodeFlags::NoUnsignedWrap);
}

SDValue getObjectPtrOffset(const SDLoc &SL, SDValue Ptr, SDValue Offset) {
// The object itself can't wrap around the address space, so it shouldn't be
// possible for the adds of the offsets to the split parts to overflow.
SDNodeFlags Flags;
Flags.setNoUnsignedWrap(true);
return getMemBasePlusOffset(Ptr, Offset, SL, Flags);
return getMemBasePlusOffset(Ptr, Offset, SL, SDNodeFlags::NoUnsignedWrap);
}

/// Return a new CALLSEQ_START node, that starts new call frame, in which
Expand Down
20 changes: 15 additions & 5 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ struct SDNodeFlags {
None = 0,
NoUnsignedWrap = 1 << 0,
NoSignedWrap = 1 << 1,
NoWrap = NoUnsignedWrap | NoSignedWrap,
Exact = 1 << 2,
Disjoint = 1 << 3,
NonNeg = 1 << 4,
Expand Down Expand Up @@ -419,7 +420,7 @@ struct SDNodeFlags {
};

/// Default constructor turns off all optimization flags.
SDNodeFlags() : Flags(0) {}
SDNodeFlags(unsigned Flags = SDNodeFlags::None) : Flags(Flags) {}

/// Propagate the fast-math-flags from an IR FPMathOperator.
void copyFMF(const FPMathOperator &FPMO) {
Expand Down Expand Up @@ -467,15 +468,23 @@ struct SDNodeFlags {
/// Two flag sets are equal exactly when their stored flag values compare
/// equal.
bool operator==(const SDNodeFlags &Other) const {
  const bool SameFlags = (Flags == Other.Flags);
  return SameFlags;
}

/// Clear any flags in this flag set that aren't also set in Flags. All
/// flags will be cleared if Flags are undefined.
void intersectWith(const SDNodeFlags Flags) { this->Flags &= Flags.Flags; }
void operator&=(const SDNodeFlags &OtherFlags) { Flags &= OtherFlags.Flags; }
void operator|=(const SDNodeFlags &OtherFlags) { Flags |= OtherFlags.Flags; }
};

LLVM_DECLARE_ENUM_AS_BITMASK(decltype(SDNodeFlags::None),
SDNodeFlags::Unpredictable);

/// Return the union of two flag sets: the result is LHS after RHS has been
/// merged into it with SDNodeFlags::operator|=. Neither caller-side
/// argument is modified, since both are taken by value.
inline SDNodeFlags operator|(SDNodeFlags LHS, SDNodeFlags RHS) {
  SDNodeFlags Combined = LHS;
  Combined |= RHS;
  return Combined;
}

/// Return the intersection of two flag sets: the result is LHS after it has
/// been masked by RHS with SDNodeFlags::operator&=. Neither caller-side
/// argument is modified, since both are taken by value.
inline SDNodeFlags operator&(SDNodeFlags LHS, SDNodeFlags RHS) {
  SDNodeFlags Common = LHS;
  Common &= RHS;
  return Common;
}

/// Represents one node in the SelectionDAG.
///
class SDNode : public FoldingSetNode, public ilist_node<SDNode> {
Expand Down Expand Up @@ -1013,6 +1022,7 @@ END_TWO_BYTE_PACK()

SDNodeFlags getFlags() const { return Flags; }
void setFlags(SDNodeFlags NewFlags) { Flags = NewFlags; }
/// Clear the flags whose bits are set in Mask, leaving all other flags
/// untouched (Flags is and-ed with the complement of Mask).
void dropFlags(unsigned Mask) { Flags &= ~Mask; }

/// Clear any flags in this node that aren't also set in Flags.
/// If Flags is not in a defined state then this has no effect.
Expand Down
40 changes: 15 additions & 25 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
SDNodeFlags NewFlags;
if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
Flags.hasNoUnsignedWrap())
NewFlags.setNoUnsignedWrap(true);
NewFlags |= SDNodeFlags::NoUnsignedWrap;

if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
// Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
Expand Down Expand Up @@ -2892,11 +2892,11 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
if (N->getFlags().hasNoUnsignedWrap() &&
N0->getFlags().hasNoUnsignedWrap() &&
N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
Flags.setNoUnsignedWrap(true);
Flags |= SDNodeFlags::NoUnsignedWrap;
if (N->getFlags().hasNoSignedWrap() &&
N0->getFlags().hasNoSignedWrap() &&
N0.getOperand(0)->getFlags().hasNoSignedWrap())
Flags.setNoSignedWrap(true);
Flags |= SDNodeFlags::NoSignedWrap;
}
SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
DAG.getConstant(CM, DL, VT), Flags);
Expand All @@ -2920,12 +2920,12 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
N0->getFlags().hasNoUnsignedWrap() &&
OMul->getFlags().hasNoUnsignedWrap() &&
OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
Flags.setNoUnsignedWrap(true);
Flags |= SDNodeFlags::NoUnsignedWrap;
if (N->getFlags().hasNoSignedWrap() &&
N0->getFlags().hasNoSignedWrap() &&
OMul->getFlags().hasNoSignedWrap() &&
OMul.getOperand(0)->getFlags().hasNoSignedWrap())
Flags.setNoSignedWrap(true);
Flags |= SDNodeFlags::NoSignedWrap;
}
SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
DAG.getConstant(CM, DL, VT), Flags);
Expand Down Expand Up @@ -2987,11 +2987,8 @@ SDValue DAGCombiner::visitADD(SDNode *N) {

// fold (a+b) -> (a|b) iff a and b share no bits.
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
DAG.haveNoCommonBitsSet(N0, N1)) {
SDNodeFlags Flags;
Flags.setDisjoint(true);
return DAG.getNode(ISD::OR, DL, VT, N0, N1, Flags);
}
DAG.haveNoCommonBitsSet(N0, N1))
return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);

// Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
Expand Down Expand Up @@ -9556,11 +9553,8 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {

// fold (a^b) -> (a|b) iff a and b share no bits.
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
DAG.haveNoCommonBitsSet(N0, N1)) {
SDNodeFlags Flags;
Flags.setDisjoint(true);
return DAG.getNode(ISD::OR, DL, VT, N0, N1, Flags);
}
DAG.haveNoCommonBitsSet(N0, N1))
return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);

// look for 'add-like' folds:
// XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
Expand Down Expand Up @@ -10210,7 +10204,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
SDNodeFlags Flags;
// Preserve the disjoint flag for Or.
if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
Flags.setDisjoint(true);
Flags |= SDNodeFlags::Disjoint;
return DAG.getNode(N0.getOpcode(), DL, VT, Shl0, Shl1, Flags);
}
}
Expand Down Expand Up @@ -13922,11 +13916,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// fold (sext x) -> (zext x) if the sign bit is known zero.
if (!TLI.isSExtCheaperThanZExt(N0.getValueType(), VT) &&
(!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
DAG.SignBitIsZero(N0)) {
SDNodeFlags Flags;
Flags.setNonNeg(true);
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, Flags);
}
DAG.SignBitIsZero(N0))
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, SDNodeFlags::NonNeg);

if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
return NewVSel;
Expand Down Expand Up @@ -14807,10 +14798,9 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
uint64_t PtrOff = PtrAdjustmentInBits / 8;
SDLoc DL(LN0);
// The original load itself didn't wrap, so an offset within it doesn't.
SDNodeFlags Flags;
Flags.setNoUnsignedWrap(true);
SDValue NewPtr = DAG.getMemBasePlusOffset(
LN0->getBasePtr(), TypeSize::getFixed(PtrOff), DL, Flags);
SDValue NewPtr =
DAG.getMemBasePlusOffset(LN0->getBasePtr(), TypeSize::getFixed(PtrOff),
DL, SDNodeFlags::NoUnsignedWrap);
AddToWorklist(NewPtr.getNode());

SDValue Load;
Expand Down
7 changes: 2 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1697,12 +1697,9 @@ SDValue SelectionDAGLegalize::ExpandFCOPYSIGN(SDNode *Node) const {
SignBit = DAG.getNode(ISD::TRUNCATE, DL, MagVT, SignBit);
}

SDNodeFlags Flags;
Flags.setDisjoint(true);

// Store the part with the modified sign and convert back to float.
SDValue CopiedSign =
DAG.getNode(ISD::OR, DL, MagVT, ClearedSign, SignBit, Flags);
SDValue CopiedSign = DAG.getNode(ISD::OR, DL, MagVT, ClearedSign, SignBit,
SDNodeFlags::Disjoint);

return modifySignAsInt(MagAsInt, DL, CopiedSign);
}
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4674,9 +4674,9 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
DAG.getNode(ISD::SHL, dl, ShAmtVT, SrlTmp,
DAG.getConstant(Log2_32(ShiftUnitInBits), dl, ShAmtVT));

Flags.setExact(true);
SDValue ByteOffset = DAG.getNode(ISD::SRL, dl, ShAmtVT, BitOffset,
DAG.getConstant(3, dl, ShAmtVT), Flags);
SDValue ByteOffset =
DAG.getNode(ISD::SRL, dl, ShAmtVT, BitOffset,
DAG.getConstant(3, dl, ShAmtVT), SDNodeFlags::Exact);
// And clamp it, because OOB load is an immediate UB,
// while shift overflow would have *just* been poison.
ByteOffset = DAG.getNode(ISD::AND, dl, ShAmtVT, ByteOffset,
Expand Down
12 changes: 3 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1700,11 +1700,8 @@ SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) {
SDValue ClearedSign =
DAG.getNode(ISD::VP_AND, DL, IntVT, Mag, ClearSignMask, Mask, EVL);

SDNodeFlags Flags;
Flags.setDisjoint(true);

SDValue CopiedSign = DAG.getNode(ISD::VP_OR, DL, IntVT, ClearedSign, SignBit,
Mask, EVL, Flags);
Mask, EVL, SDNodeFlags::Disjoint);

return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
}
Expand Down Expand Up @@ -1886,11 +1883,8 @@ SDValue VectorLegalizer::ExpandFCOPYSIGN(SDNode *Node) {
APInt::getSignedMaxValue(IntVT.getScalarSizeInBits()), DL, IntVT);
SDValue ClearedSign = DAG.getNode(ISD::AND, DL, IntVT, Mag, ClearSignMask);

SDNodeFlags Flags;
Flags.setDisjoint(true);

SDValue CopiedSign =
DAG.getNode(ISD::OR, DL, IntVT, ClearedSign, SignBit, Flags);
SDValue CopiedSign = DAG.getNode(ISD::OR, DL, IntVT, ClearedSign, SignBit,
SDNodeFlags::Disjoint);

return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
}
Expand Down
4 changes: 1 addition & 3 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1381,16 +1381,14 @@ void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinValue() / 8;

if (MemVT.isScalableVector()) {
SDNodeFlags Flags;
SDValue BytesIncrement = DAG.getVScale(
DL, Ptr.getValueType(),
APInt(Ptr.getValueSizeInBits().getFixedValue(), IncrementSize));
MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
Flags.setNoUnsignedWrap(true);
if (ScaledOffset)
*ScaledOffset += IncrementSize;
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement,
Flags);
SDNodeFlags::NoUnsignedWrap);
} else {
MPI = N->getPointerInfo().getWithOffset(IncrementSize);
// Increment the pointer to the other half.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12377,7 +12377,7 @@ bool SDNode::hasPredecessor(const SDNode *N) const {
}

void SDNode::intersectFlagsWith(const SDNodeFlags Flags) {
this->Flags.intersectWith(Flags);
this->Flags &= Flags;
}

SDValue
Expand Down
17 changes: 7 additions & 10 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4318,7 +4318,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
SDNodeFlags Flags;
if (NW.hasNoUnsignedWrap() ||
(int64_t(Offset) >= 0 && NW.hasNoUnsignedSignedWrap()))
Flags.setNoUnsignedWrap(true);
Flags |= SDNodeFlags::NoUnsignedWrap;

N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N,
DAG.getConstant(Offset, dl, N.getValueType()), Flags);
Expand Down Expand Up @@ -4484,10 +4484,9 @@ void SelectionDAGBuilder::visitAlloca(const AllocaInst &I) {
// Round the size of the allocation up to the stack alignment size
// by add SA-1 to the size. This doesn't overflow because we're computing
// an address inside an alloca.
SDNodeFlags Flags;
Flags.setNoUnsignedWrap(true);
AllocSize = DAG.getNode(ISD::ADD, dl, AllocSize.getValueType(), AllocSize,
DAG.getConstant(StackAlignMask, dl, IntPtr), Flags);
DAG.getConstant(StackAlignMask, dl, IntPtr),
SDNodeFlags::NoUnsignedWrap);

// Mask out the low bits for alignment purposes.
AllocSize = DAG.getNode(ISD::AND, dl, AllocSize.getValueType(), AllocSize,
Expand Down Expand Up @@ -11224,15 +11223,13 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {

// An aggregate return value cannot wrap around the address space, so
// offsets to its parts don't wrap either.
SDNodeFlags Flags;
Flags.setNoUnsignedWrap(true);

MachineFunction &MF = CLI.DAG.getMachineFunction();
Align HiddenSRetAlign = MF.getFrameInfo().getObjectAlign(DemoteStackIdx);
for (unsigned i = 0; i < NumValues; ++i) {
SDValue Add = CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
CLI.DAG.getConstant(Offsets[i], CLI.DL,
PtrVT), Flags);
SDValue Add =
CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
CLI.DAG.getConstant(Offsets[i], CLI.DL, PtrVT),
SDNodeFlags::NoUnsignedWrap);
SDValue L = CLI.DAG.getLoad(
RetTys[i], CLI.DL, CLI.Chain, Add,
MachinePointerInfo::getFixedStack(CLI.DAG.getMachineFunction(),
Expand Down
7 changes: 2 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4224,11 +4224,8 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,

// Set the NoFPExcept flag when no original matched node could
// raise an FP exception, but the new node potentially might.
if (!MayRaiseFPException && mayRaiseFPException(Res)) {
SDNodeFlags Flags = Res->getFlags();
Flags.setNoFPExcept(true);
Res->setFlags(Flags);
}
if (!MayRaiseFPException && mayRaiseFPException(Res))
Res->setFlags(Res->getFlags() | SDNodeFlags::NoFPExcept);

// If the node had chain/glue results, update our notion of the current
// chain and glue.
Expand Down
Loading

0 comments on commit cf9d1c1

Please sign in to comment.