Skip to content

Commit

Permalink
[X86] Teach how to combine horizontal binop even in the presence of undefs.
Browse files Browse the repository at this point in the history

Before this change, the backend was unable to fold a build_vector dag
node with UNDEF operands into a single horizontal add/sub.

This patch teaches the backend how to combine a build_vector with UNDEF operands
into a horizontal add/sub when possible. The algorithm conservatively avoids
combining a build_vector that has only a single non-UNDEF operand.

Added test haddsub-undef.ll to verify that we correctly fold horizontal binops
even in the presence of UNDEFs.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211265 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
adibiagio committed Jun 19, 2014
1 parent 8317509 commit cfdf805
Show file tree
Hide file tree
Showing 2 changed files with 440 additions and 40 deletions.
155 changes: 115 additions & 40 deletions lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6077,21 +6077,35 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const {
/// This function only analyzes elements of \p N whose indices are
/// in range [BaseIdx, LastIdx).
static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
SelectionDAG &DAG,
unsigned BaseIdx, unsigned LastIdx,
SDValue &V0, SDValue &V1) {
EVT VT = N->getValueType(0);

assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
assert(N->getValueType(0).isVector() &&
N->getValueType(0).getVectorNumElements() >= LastIdx &&
assert(VT.isVector() && VT.getVectorNumElements() >= LastIdx &&
"Invalid Vector in input!");

bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD);
bool CanFold = true;
unsigned ExpectedVExtractIdx = BaseIdx;
unsigned NumElts = LastIdx - BaseIdx;
V0 = DAG.getUNDEF(VT);
V1 = DAG.getUNDEF(VT);

// Check if N implements a horizontal binop.
for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) {
SDValue Op = N->getOperand(i + BaseIdx);

// Skip UNDEFs.
if (Op->getOpcode() == ISD::UNDEF) {
// Update the expected vector extract index.
if (i * 2 == NumElts)
ExpectedVExtractIdx = BaseIdx;
ExpectedVExtractIdx += 2;
continue;
}

CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();

if (!CanFold)
Expand All @@ -6112,12 +6126,15 @@ static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,

unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();

if (i == 0)
V0 = Op0.getOperand(0);
else if (i * 2 == NumElts) {
V1 = Op0.getOperand(0);
ExpectedVExtractIdx = BaseIdx;

if (i * 2 < NumElts) {
if (V0.getOpcode() == ISD::UNDEF)
V0 = Op0.getOperand(0);
} else {
if (V1.getOpcode() == ISD::UNDEF)
V1 = Op0.getOperand(0);
if (i * 2 == NumElts)
ExpectedVExtractIdx = BaseIdx;
}

SDValue Expected = (i * 2 < NumElts) ? V0 : V1;
Expand Down Expand Up @@ -6163,9 +6180,14 @@ static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
/// Example:
/// HADD V0_LO, V1_LO
/// HADD V0_HI, V1_HI
///
/// If \p isUndefLO is set, then the algorithm propagates UNDEF to the lower
/// 128 bits of the result. If \p isUndefHI is set, then UNDEF is propagated to
/// the upper 128 bits of the result.
static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1,
SDLoc DL, SelectionDAG &DAG,
unsigned X86Opcode, bool Mode) {
unsigned X86Opcode, bool Mode,
bool isUndefLO, bool isUndefHI) {
EVT VT = V0.getValueType();
assert(VT.is256BitVector() && VT == V1.getValueType() &&
"Invalid nodes in input!");
Expand All @@ -6177,13 +6199,24 @@ static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1,
SDValue V1_HI = Extract128BitVector(V1, NumElts/2, DAG, DL);
EVT NewVT = V0_LO.getValueType();

SDValue LO, HI;
SDValue LO = DAG.getUNDEF(NewVT);
SDValue HI = DAG.getUNDEF(NewVT);

if (Mode) {
LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V0_HI);
HI = DAG.getNode(X86Opcode, DL, NewVT, V1_LO, V1_HI);
// Don't emit a horizontal binop if the result is expected to be UNDEF.
if (!isUndefLO && V0->getOpcode() != ISD::UNDEF)
LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V0_HI);
if (!isUndefHI && V1->getOpcode() != ISD::UNDEF)
HI = DAG.getNode(X86Opcode, DL, NewVT, V1_LO, V1_HI);
} else {
LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V1_LO);
HI = DAG.getNode(X86Opcode, DL, NewVT, V1_HI, V1_HI);
// Don't emit a horizontal binop if the result is expected to be UNDEF.
if (!isUndefLO && (V0_LO->getOpcode() != ISD::UNDEF ||
V1_LO->getOpcode() != ISD::UNDEF))
LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V1_LO);

if (!isUndefHI && (V0_HI->getOpcode() != ISD::UNDEF ||
V1_HI->getOpcode() != ISD::UNDEF))
HI = DAG.getNode(X86Opcode, DL, NewVT, V0_HI, V1_HI);
}

return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI);
Expand All @@ -6198,19 +6231,37 @@ static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
SDValue InVec0, InVec1;

// Try to match horizontal ADD/SUB.
unsigned NumUndefsLO = 0;
unsigned NumUndefsHI = 0;
unsigned Half = NumElts/2;

// Count the number of UNDEF operands in the build_vector in input.
for (unsigned i = 0, e = Half; i != e; ++i)
if (BV->getOperand(i)->getOpcode() == ISD::UNDEF)
NumUndefsLO++;

for (unsigned i = Half, e = NumElts; i != e; ++i)
if (BV->getOperand(i)->getOpcode() == ISD::UNDEF)
NumUndefsHI++;

// Early exit if this is either a build_vector of all UNDEFs or all the
// operands but one are UNDEF.
if (NumUndefsLO + NumUndefsHI + 1 >= NumElts)
return SDValue();

if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) {
// Try to match an SSE3 float HADD/HSUB.
if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);

if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
} else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) {
// Try to match an SSSE3 integer HADD/HSUB.
if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);

if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1))
return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
}

Expand All @@ -6221,32 +6272,40 @@ static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
// Try to match an AVX horizontal add/sub of packed single/double
// precision floating point values from 256-bit vectors.
SDValue InVec2, InVec3;
if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts/2, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::FADD, NumElts/2, NumElts, InVec2, InVec3) &&
InVec0.getNode() == InVec2.getNode() &&
InVec1.getNode() == InVec3.getNode())
if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, Half, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::FADD, DAG, Half, NumElts, InVec2, InVec3) &&
((InVec0.getOpcode() == ISD::UNDEF ||
InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) &&
((InVec1.getOpcode() == ISD::UNDEF ||
InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3))
return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);

if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts/2, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::FSUB, NumElts/2, NumElts, InVec2, InVec3) &&
InVec0.getNode() == InVec2.getNode() &&
InVec1.getNode() == InVec3.getNode())
if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, Half, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::FSUB, DAG, Half, NumElts, InVec2, InVec3) &&
((InVec0.getOpcode() == ISD::UNDEF ||
InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) &&
((InVec1.getOpcode() == ISD::UNDEF ||
InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3))
return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
} else if (VT == MVT::v8i32 || VT == MVT::v16i16) {
// Try to match an AVX2 horizontal add/sub of signed integers.
SDValue InVec2, InVec3;
unsigned X86Opcode;
bool CanFold = true;

if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts/2, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::ADD, NumElts/2, NumElts, InVec2, InVec3) &&
InVec0.getNode() == InVec2.getNode() &&
InVec1.getNode() == InVec3.getNode())
if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::ADD, DAG, Half, NumElts, InVec2, InVec3) &&
((InVec0.getOpcode() == ISD::UNDEF ||
InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) &&
((InVec1.getOpcode() == ISD::UNDEF ||
InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3))
X86Opcode = X86ISD::HADD;
else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts/2, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::SUB, NumElts/2, NumElts, InVec2, InVec3) &&
InVec0.getNode() == InVec2.getNode() &&
InVec1.getNode() == InVec3.getNode())
else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, Half, InVec0, InVec1) &&
isHorizontalBinOp(BV, ISD::SUB, DAG, Half, NumElts, InVec2, InVec3) &&
((InVec0.getOpcode() == ISD::UNDEF ||
InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) &&
((InVec1.getOpcode() == ISD::UNDEF ||
InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3))
X86Opcode = X86ISD::HSUB;
else
CanFold = false;
Expand All @@ -6257,29 +6316,45 @@ static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
if (Subtarget->hasAVX2())
return DAG.getNode(X86Opcode, DL, VT, InVec0, InVec1);

// Do not try to expand this build_vector into a pair of horizontal
// add/sub if we can emit a pair of scalar add/sub.
if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half)
return SDValue();

// Convert this build_vector into a pair of horizontal binop followed by
// a concat vector.
return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, false);
bool isUndefLO = NumUndefsLO == Half;
bool isUndefHI = NumUndefsHI == Half;
return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, false,
isUndefLO, isUndefHI);
}
}

if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
VT == MVT::v16i16) && Subtarget->hasAVX()) {
unsigned X86Opcode;
if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::HADD;
else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::HSUB;
else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
else if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::FHADD;
else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
else if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1))
X86Opcode = X86ISD::FHSUB;
else
return SDValue();

// Don't try to expand this build_vector into a pair of horizontal add/sub
// if we can simply emit a pair of scalar add/sub.
if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half)
return SDValue();

// Convert this build_vector into two horizontal add/sub followed by
// a concat vector.
return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, true);
bool isUndefLO = NumUndefsLO == Half;
bool isUndefHI = NumUndefsHI == Half;
return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, true,
isUndefLO, isUndefHI);
}

return SDValue();
Expand Down
Loading

0 comments on commit cfdf805

Please sign in to comment.