Skip to content

Commit

Permalink
[X86][SSE] Use MOVMSK for all_of/any_of reduction patterns
Browse files Browse the repository at this point in the history
This is a first attempt at using the MOVMSK instructions to replace all_of/any_of reduction patterns (i.e. an and/or + shuffle chain).

So far this only matches patterns where we are reducing an all/none bits source vector (i.e. a comparison result) but we should be able to expand on this in conjunction with improvements to 'bool vector' handling both in the x86 backend as well as the vectorizers etc.

Differential Revision: https://reviews.llvm.org/D28810

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@293880 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
RKSimon committed Feb 2, 2017
1 parent 69ad39a commit 178f518
Show file tree
Hide file tree
Showing 3 changed files with 365 additions and 542 deletions.
83 changes: 83 additions & 0 deletions lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28627,6 +28627,85 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1);
}

// Attempt to replace an all_of/any_of style horizontal reduction with a MOVMSK.
static SDValue combineHorizontalPredicateResult(SDNode *Extract,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// Bail without SSE2 or with AVX512VL (which uses predicate registers).
if (!Subtarget.hasSSE2() || Subtarget.hasVLX())
return SDValue();

EVT ExtractVT = Extract->getValueType(0);
unsigned BitWidth = ExtractVT.getSizeInBits();
if (ExtractVT != MVT::i64 && ExtractVT != MVT::i32 && ExtractVT != MVT::i16 &&
ExtractVT != MVT::i8)
return SDValue();

// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
SDValue Match = matchBinOpReduction(Extract, Op);
if (!Match)
continue;

// EXTRACT_VECTOR_ELT can require implicit extension of the vector element
// which we can't support here for now.
if (Match.getScalarValueSizeInBits() != BitWidth)
continue;

// We require AVX2 for PMOVMSKB for v16i16/v32i8;
unsigned MatchSizeInBits = Match.getValueSizeInBits();
if (!(MatchSizeInBits == 128 ||
(MatchSizeInBits == 256 &&
((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
return SDValue();

// Don't bother performing this for 2-element vectors.
if (Match.getValueType().getVectorNumElements() <= 2)
return SDValue();

// Check that we are extracting a reduction of all sign bits.
if (DAG.ComputeNumSignBits(Match) != BitWidth)
return SDValue();

// For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
MVT MaskVT;
if (64 == BitWidth || 32 == BitWidth)
MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
MatchSizeInBits / BitWidth);
else
MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);

APInt CompareBits;
ISD::CondCode CondCode;
if (Op == ISD::OR) {
// any_of -> MOVMSK != 0
CompareBits = APInt::getNullValue(32);
CondCode = ISD::CondCode::SETNE;
} else {
// all_of -> MOVMSK == ((1 << NumElts) - 1)
CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
CondCode = ISD::CondCode::SETEQ;
}

// Perform the select as i32/i64 and then truncate to avoid partial register
// stalls.
unsigned ResWidth = std::max(BitWidth, 32u);
APInt ResOnes = APInt::getAllOnesValue(ResWidth);
APInt ResZero = APInt::getNullValue(ResWidth);
EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);

SDLoc DL(Extract);
SDValue Res = DAG.getBitcast(MaskVT, Match);
Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
DAG.getConstant(ResOnes, DL, ResVT),
DAG.getConstant(ResZero, DL, ResVT), CondCode);
return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
}

return SDValue();
}

static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// PSADBW is only supported on SSE2 and up.
Expand Down Expand Up @@ -28738,6 +28817,10 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
return SAD;

// Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
if (SDValue Cmp = combineHorizontalPredicateResult(N, DAG, Subtarget))
return Cmp;

// Only operate on vectors of 4 elements, where the alternative shuffling
// gets to be more expensive.
if (InputVector.getValueType() != MVT::v4i32)
Expand Down
Loading

0 comments on commit 178f518

Please sign in to comment.