[X86][AVX] Generalized matching for target shuffle combines
This patch is a first step towards a more extendible method of matching combined target shuffle masks.

Initially this just pulls out the existing basic mask matches and adds support for some 256/512 bit equivalents. Future patterns will require a number of features to be added but I wanted to keep this patch simple.

I hope we can avoid duplication between shuffle lowering and combining and share more complex pattern match functions in future commits.

Differential Revision: http://reviews.llvm.org/D19198

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@270230 91177308-0d34-0410-b5e6-96231b3b80d8
RKSimon committed May 20, 2016
1 parent a7f9ea7 commit 11c52a1
Showing 3 changed files with 161 additions and 133 deletions.
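At a glance, the patch replaces the inlined per-pattern checks in combineX86ShuffleChain with small matcher helpers. The following is an editorial sketch of the new flow, condensed from the diff below rather than the verbatim committed code: each matcher inspects the combined mask and, on success, reports an X86 opcode plus the vector type to bitcast to, and the combiner then emits a single target shuffle node.

// Condensed sketch (not the exact committed code) of the matcher-driven
// combine this patch introduces in combineX86ShuffleChain.
unsigned Shuffle;
MVT ShuffleVT;
if (matchUnaryVectorShuffle(VT, Mask, Subtarget, Shuffle, ShuffleVT)) {
  if (Depth == 1 && Root.getOpcode() == Shuffle)
    return false; // Nothing to do!
  Res = DAG.getBitcast(ShuffleVT, Input);          // move into the matched type
  Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);  // single-input target shuffle
  DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), /*AddTo*/ true);
  return true;
}
if (matchBinaryVectorShuffle(VT, Mask, Shuffle, ShuffleVT)) {
  // Same idea, except the matched two-operand node is fed the input twice.
}

The patterns covered by this first step are MOVDDUP/MOVSLDUP/MOVSHDUP, MOVQ (VZEXT_MOVL), MOVLHPS/MOVHLPS, and UNPCKL/UNPCKH.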
245 changes: 146 additions & 99 deletions lib/Target/X86/X86ISelLowering.cpp
@@ -24054,6 +24054,136 @@ static SDValue combineShuffle256(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

// Attempt to match a combined shuffle mask against supported unary shuffle
// instructions.
// TODO: Investigate sharing more of this with shuffle lowering.
// TODO: Investigate using isShuffleEquivalent() instead of Mask.equals().
static bool matchUnaryVectorShuffle(MVT SrcVT, ArrayRef<int> Mask,
                                    const X86Subtarget &Subtarget,
                                    unsigned &Shuffle, MVT &ShuffleVT) {
  bool FloatDomain = SrcVT.isFloatingPoint();

  // Match a 128-bit integer vector against a VZEXT_MOVL (MOVQ) instruction.
  if (!FloatDomain && SrcVT.is128BitVector() && Mask.size() == 2 &&
      Mask[0] == 0 && Mask[1] < 0) {
    Shuffle = X86ISD::VZEXT_MOVL;
    ShuffleVT = MVT::v2i64;
    return true;
  }

  if (!FloatDomain)
    return false;

  // Check if we have SSE3 which will let us use MOVDDUP etc. The
  // instructions are no slower than UNPCKLPD but have the option to
  // fold the input operand into even an unaligned memory load.
  if (SrcVT.is128BitVector() && Subtarget.hasSSE3()) {
    if (Mask.equals({0, 0})) {
      Shuffle = X86ISD::MOVDDUP;
      ShuffleVT = MVT::v2f64;
      return true;
    }
    if (Mask.equals({0, 0, 2, 2})) {
      Shuffle = X86ISD::MOVSLDUP;
      ShuffleVT = MVT::v4f32;
      return true;
    }
    if (Mask.equals({1, 1, 3, 3})) {
      Shuffle = X86ISD::MOVSHDUP;
      ShuffleVT = MVT::v4f32;
      return true;
    }
  }

  if (SrcVT.is256BitVector()) {
    assert(Subtarget.hasAVX() && "AVX required for 256-bit vector shuffles");
    if (Mask.equals({0, 0, 2, 2})) {
      Shuffle = X86ISD::MOVDDUP;
      ShuffleVT = MVT::v4f64;
      return true;
    }
    if (Mask.equals({0, 0, 2, 2, 4, 4, 6, 6})) {
      Shuffle = X86ISD::MOVSLDUP;
      ShuffleVT = MVT::v8f32;
      return true;
    }
    if (Mask.equals({1, 1, 3, 3, 5, 5, 7, 7})) {
      Shuffle = X86ISD::MOVSHDUP;
      ShuffleVT = MVT::v8f32;
      return true;
    }
  }

  if (SrcVT.is512BitVector()) {
    assert(Subtarget.hasAVX512() &&
           "AVX512 required for 512-bit vector shuffles");
    if (Mask.equals({0, 0, 2, 2, 4, 4, 6, 6})) {
      Shuffle = X86ISD::MOVDDUP;
      ShuffleVT = MVT::v8f64;
      return true;
    }
    if (Mask.equals({0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14})) {
      Shuffle = X86ISD::MOVSLDUP;
      ShuffleVT = MVT::v16f32;
      return true;
    }
    if (Mask.equals({1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15})) {
      Shuffle = X86ISD::MOVSHDUP;
      ShuffleVT = MVT::v16f32;
      return true;
    }
  }

  return false;
}

// Attempt to match a combined unary shuffle mask against supported binary
// shuffle instructions.
// TODO: Investigate sharing more of this with shuffle lowering.
// TODO: Investigate using isShuffleEquivalent() instead of Mask.equals().
static bool matchBinaryVectorShuffle(MVT SrcVT, ArrayRef<int> Mask,
                                     unsigned &Shuffle, MVT &ShuffleVT) {
  bool FloatDomain = SrcVT.isFloatingPoint();

  if (SrcVT.is128BitVector()) {
    if (Mask.equals({0, 0}) && FloatDomain) {
      Shuffle = X86ISD::MOVLHPS;
      ShuffleVT = MVT::v4f32;
      return true;
    }
    if (Mask.equals({1, 1}) && FloatDomain) {
      Shuffle = X86ISD::MOVHLPS;
      ShuffleVT = MVT::v4f32;
      return true;
    }
    if (Mask.equals({0, 0, 1, 1}) && FloatDomain) {
      Shuffle = X86ISD::UNPCKL;
      ShuffleVT = MVT::v4f32;
      return true;
    }
    if (Mask.equals({2, 2, 3, 3}) && FloatDomain) {
      Shuffle = X86ISD::UNPCKH;
      ShuffleVT = MVT::v4f32;
      return true;
    }
    if (Mask.equals({0, 0, 1, 1, 2, 2, 3, 3}) ||
        Mask.equals({0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7})) {
      Shuffle = X86ISD::UNPCKL;
      ShuffleVT = Mask.size() == 8 ? MVT::v8i16 : MVT::v16i8;
      return true;
    }
    if (Mask.equals({4, 4, 5, 5, 6, 6, 7, 7}) ||
        Mask.equals(
            {8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15})) {
      Shuffle = X86ISD::UNPCKH;
      ShuffleVT = Mask.size() == 8 ? MVT::v8i16 : MVT::v16i8;
      return true;
    }
  }

  return false;
}

/// \brief Combine an arbitrary chain of shuffles into a single instruction if
/// possible.
///
@@ -24095,117 +24225,34 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root,
if (MaskEltSizeInBits > 64)
return false;

// Use the float domain if the operand type is a floating point type.
bool FloatDomain = VT.isFloatingPoint();

// For floating point shuffles, we don't have free copies in the shuffle
// instructions or the ability to load as part of the instruction, so
// canonicalize their shuffles to UNPCK or MOV variants.
//
// Note that even with AVX we prefer the PSHUFD form of shuffle for integer
// vectors because it can have a load folded into it that UNPCK cannot. This
// doesn't preclude something switching to the shorter encoding post-RA.
//
// FIXME: Should teach these routines about AVX vector widths.
if (FloatDomain && VT.is128BitVector()) {
if (Mask.equals({0, 0}) || Mask.equals({1, 1})) {
bool Lo = Mask.equals({0, 0});
unsigned Shuffle;
MVT ShuffleVT;
// Check if we have SSE3 which will let us use MOVDDUP. That instruction
// is no slower than UNPCKLPD but has the option to fold the input operand
// into even an unaligned memory load.
if (Lo && Subtarget.hasSSE3()) {
Shuffle = X86ISD::MOVDDUP;
ShuffleVT = MVT::v2f64;
} else {
// We have MOVLHPS and MOVHLPS throughout SSE and they encode smaller
// than the UNPCK variants.
Shuffle = Lo ? X86ISD::MOVLHPS : X86ISD::MOVHLPS;
ShuffleVT = MVT::v4f32;
}
if (Depth == 1 && Root.getOpcode() == Shuffle)
return false; // Nothing to do!
Res = DAG.getBitcast(ShuffleVT, Input);
DCI.AddToWorklist(Res.getNode());
if (Shuffle == X86ISD::MOVDDUP)
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
else
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res);
DCI.AddToWorklist(Res.getNode());
DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res),
/*AddTo*/ true);
return true;
}
if (Subtarget.hasSSE3() &&
(Mask.equals({0, 0, 2, 2}) || Mask.equals({1, 1, 3, 3}))) {
bool Lo = Mask.equals({0, 0, 2, 2});
unsigned Shuffle = Lo ? X86ISD::MOVSLDUP : X86ISD::MOVSHDUP;
MVT ShuffleVT = MVT::v4f32;
if (Depth == 1 && Root.getOpcode() == Shuffle)
return false; // Nothing to do!
Res = DAG.getBitcast(ShuffleVT, Input);
DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
DCI.AddToWorklist(Res.getNode());
DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res),
/*AddTo*/ true);
return true;
}
if (Mask.equals({0, 0, 1, 1}) || Mask.equals({2, 2, 3, 3})) {
bool Lo = Mask.equals({0, 0, 1, 1});
unsigned Shuffle = Lo ? X86ISD::UNPCKL : X86ISD::UNPCKH;
MVT ShuffleVT = MVT::v4f32;
if (Depth == 1 && Root.getOpcode() == Shuffle)
return false; // Nothing to do!
Res = DAG.getBitcast(ShuffleVT, Input);
DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res);
DCI.AddToWorklist(Res.getNode());
DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res),
/*AddTo*/ true);
return true;
}
// Don't combine if we are an AVX512/EVEX target and the mask element size
// is different from the root element size - this would prevent writemasks
// from being reused.
// TODO - check for writemasks usage instead of always preventing combining.
// TODO - attempt to narrow Mask back to writemask size.
if (RootVT.getScalarSizeInBits() != MaskEltSizeInBits &&
(RootSizeInBits == 512 ||
(Subtarget.hasVLX() && RootSizeInBits >= 128))) {
return false;
}

// We always canonicalize the 8 x i16 and 16 x i8 shuffles into their UNPCK
// variants as none of these have single-instruction variants that are
// superior to the UNPCK formulation.
if (!FloatDomain && VT.is128BitVector() &&
(Mask.equals({0, 0, 1, 1, 2, 2, 3, 3}) ||
Mask.equals({4, 4, 5, 5, 6, 6, 7, 7}) ||
Mask.equals({0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7}) ||
Mask.equals(
{8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15}))) {
bool Lo = Mask[0] == 0;
unsigned Shuffle = Lo ? X86ISD::UNPCKL : X86ISD::UNPCKH;
// Attempt to match the mask against known shuffle patterns.
MVT ShuffleVT;
unsigned Shuffle;

if (matchUnaryVectorShuffle(VT, Mask, Subtarget, Shuffle, ShuffleVT)) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return false; // Nothing to do!
MVT ShuffleVT;
switch (NumMaskElts) {
case 8:
ShuffleVT = MVT::v8i16;
break;
case 16:
ShuffleVT = MVT::v16i8;
break;
default:
llvm_unreachable("Impossible mask size!");
};
Res = DAG.getBitcast(ShuffleVT, Input);
DCI.AddToWorklist(Res.getNode());
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res);
Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
DCI.AddToWorklist(Res.getNode());
DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res),
/*AddTo*/ true);
return true;
}

// Match a 128-bit integer vector against a VZEXT_MOVL (MOVQ) instruction.
if (!FloatDomain && VT.is128BitVector() &&
Mask.size() == 2 && Mask[0] == 0 && Mask[1] < 0) {
unsigned Shuffle = X86ISD::VZEXT_MOVL;
MVT ShuffleVT = MVT::v2i64;
if (matchBinaryVectorShuffle(VT, Mask, Shuffle, ShuffleVT)) {
if (Depth == 1 && Root.getOpcode() == Shuffle)
return false; // Nothing to do!
Res = DAG.getBitcast(ShuffleVT, Input);
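The collapsed diff cuts off the rest of the binary-match path here. Based on the unary path above and the preexisting two-operand UNPCK/MOVLHPS emission it replaces, the remainder presumably continues roughly as follows (an assumption added for illustration, not the verbatim commit):

    // Hypothetical continuation of the binary-match path (not shown in the
    // truncated diff): the matched two-operand shuffle is built with the
    // single input used as both operands, then the root is replaced as in
    // the unary path.
    DCI.AddToWorklist(Res.getNode());
    Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res);
    DCI.AddToWorklist(Res.getNode());
    DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res),
                  /*AddTo*/ true);
    return true;
  }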
11 changes: 5 additions & 6 deletions test/CodeGen/X86/vector-shuffle-combining-avx.ll
@@ -94,16 +94,15 @@ define <8 x float> @combine_vpermilvar_8f32_identity(<8 x float> %a0) {
define <8 x float> @combine_vpermilvar_8f32_movddup(<8 x float> %a0) {
; ALL-LABEL: combine_vpermilvar_8f32_movddup:
; ALL: # BB#0:
; ALL-NEXT: vpermilps {{.*#+}} ymm0 = ymm0[0,1,0,1,4,5,4,5]
; ALL-NEXT: vmovddup {{.*#+}} ymm0 = ymm0[0,0,2,2]
; ALL-NEXT: retq
%1 = tail call <8 x float> @llvm.x86.avx.vpermilvar.ps.256(<8 x float> %a0, <8 x i32> <i32 0, i32 1, i32 0, i32 1, i32 4, i32 5, i32 4, i32 5>)
ret <8 x float> %1
}
define <8 x float> @combine_vpermilvar_8f32_movddup_load(<8 x float> *%a0) {
; ALL-LABEL: combine_vpermilvar_8f32_movddup_load:
; ALL: # BB#0:
; ALL-NEXT: vmovaps (%rdi), %ymm0
; ALL-NEXT: vpermilps {{.*#+}} ymm0 = ymm0[0,1,0,1,4,5,4,5]
; ALL-NEXT: vmovddup {{.*#+}} ymm0 = mem[0,0,2,2]
; ALL-NEXT: retq
%1 = load <8 x float>, <8 x float> *%a0
%2 = tail call <8 x float> @llvm.x86.avx.vpermilvar.ps.256(<8 x float> %1, <8 x i32> <i32 0, i32 1, i32 0, i32 1, i32 4, i32 5, i32 4, i32 5>)
@@ -113,7 +112,7 @@ define <8 x float> @combine_vpermilvar_8f32_movddup_load(<8 x float> *%a0) {
define <8 x float> @combine_vpermilvar_8f32_movshdup(<8 x float> %a0) {
; ALL-LABEL: combine_vpermilvar_8f32_movshdup:
; ALL: # BB#0:
; ALL-NEXT: vpermilps {{.*#+}} ymm0 = ymm0[1,1,3,3,5,5,7,7]
; ALL-NEXT: vmovshdup {{.*#+}} ymm0 = ymm0[1,1,3,3,5,5,7,7]
; ALL-NEXT: retq
%1 = tail call <8 x float> @llvm.x86.avx.vpermilvar.ps.256(<8 x float> %a0, <8 x i32> <i32 1, i32 1, i32 3, i32 3, i32 5, i32 5, i32 7, i32 7>)
ret <8 x float> %1
@@ -122,7 +121,7 @@ define <8 x float> @combine_vpermilvar_8f32_movshdup(<8 x float> %a0) {
define <8 x float> @combine_vpermilvar_8f32_movsldup(<8 x float> %a0) {
; ALL-LABEL: combine_vpermilvar_8f32_movsldup:
; ALL: # BB#0:
; ALL-NEXT: vpermilps {{.*#+}} ymm0 = ymm0[0,0,2,2,4,4,6,6]
; ALL-NEXT: vmovsldup {{.*#+}} ymm0 = ymm0[0,0,2,2,4,4,6,6]
; ALL-NEXT: retq
%1 = tail call <8 x float> @llvm.x86.avx.vpermilvar.ps.256(<8 x float> %a0, <8 x i32> <i32 0, i32 0, i32 2, i32 2, i32 4, i32 4, i32 6, i32 6>)
ret <8 x float> %1
@@ -159,7 +158,7 @@ define <4 x double> @combine_vpermilvar_4f64_identity(<4 x double> %a0) {
define <4 x double> @combine_vpermilvar_4f64_movddup(<4 x double> %a0) {
; ALL-LABEL: combine_vpermilvar_4f64_movddup:
; ALL: # BB#0:
; ALL-NEXT: vpermilpd {{.*#+}} ymm0 = ymm0[0,0,2,2]
; ALL-NEXT: vmovddup {{.*#+}} ymm0 = ymm0[0,0,2,2]
; ALL-NEXT: retq
%1 = tail call <4 x double> @llvm.x86.avx.vpermilvar.pd.256(<4 x double> %a0, <4 x i64> <i64 0, i64 0, i64 4, i64 4>)
ret <4 x double> %1
