Skip to content

Commit

Permalink
Add lowering for AVX2 shift instructions.
Browse files Browse the repository at this point in the history
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144380 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
topperc committed Nov 11, 2011
1 parent 1c47de8 commit 46154eb
Show file tree
Hide file tree
Showing 4 changed files with 377 additions and 190 deletions.
185 changes: 131 additions & 54 deletions lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,21 +1050,9 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
setOperationAction(ISD::MUL, MVT::v4i64, Custom);
setOperationAction(ISD::MUL, MVT::v8i32, Legal);
setOperationAction(ISD::MUL, MVT::v16i16, Legal);
// Don't lower v32i8 because there is no 128-bit byte mul

setOperationAction(ISD::VSELECT, MVT::v32i8, Legal);

setOperationAction(ISD::SHL, MVT::v4i32, Legal);
setOperationAction(ISD::SHL, MVT::v2i64, Legal);
setOperationAction(ISD::SRL, MVT::v4i32, Legal);
setOperationAction(ISD::SRL, MVT::v2i64, Legal);
setOperationAction(ISD::SRA, MVT::v4i32, Legal);

setOperationAction(ISD::SHL, MVT::v8i32, Legal);
setOperationAction(ISD::SHL, MVT::v4i64, Legal);
setOperationAction(ISD::SRL, MVT::v8i32, Legal);
setOperationAction(ISD::SRL, MVT::v4i64, Legal);
setOperationAction(ISD::SRA, MVT::v8i32, Legal);
// Don't lower v32i8 because there is no 128-bit byte mul
} else {
setOperationAction(ISD::ADD, MVT::v4i64, Custom);
setOperationAction(ISD::ADD, MVT::v8i32, Custom);
Expand Down Expand Up @@ -10130,47 +10118,6 @@ SDValue X86TargetLowering::LowerShift(SDValue Op, SelectionDAG &DAG) const {
if (!Subtarget->hasXMMInt())
return SDValue();

// Decompose 256-bit shifts into smaller 128-bit shifts.
if (VT.getSizeInBits() == 256) {
int NumElems = VT.getVectorNumElements();
MVT EltVT = VT.getVectorElementType().getSimpleVT();
EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);

// Extract the two vectors
SDValue V1 = Extract128BitVector(R, DAG.getConstant(0, MVT::i32), DAG, dl);
SDValue V2 = Extract128BitVector(R, DAG.getConstant(NumElems/2, MVT::i32),
DAG, dl);

// Recreate the shift amount vectors
SDValue Amt1, Amt2;
if (Amt.getOpcode() == ISD::BUILD_VECTOR) {
// Constant shift amount
SmallVector<SDValue, 4> Amt1Csts;
SmallVector<SDValue, 4> Amt2Csts;
for (int i = 0; i < NumElems/2; ++i)
Amt1Csts.push_back(Amt->getOperand(i));
for (int i = NumElems/2; i < NumElems; ++i)
Amt2Csts.push_back(Amt->getOperand(i));

Amt1 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
&Amt1Csts[0], NumElems/2);
Amt2 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
&Amt2Csts[0], NumElems/2);
} else {
// Variable shift amount
Amt1 = Extract128BitVector(Amt, DAG.getConstant(0, MVT::i32), DAG, dl);
Amt2 = Extract128BitVector(Amt, DAG.getConstant(NumElems/2, MVT::i32),
DAG, dl);
}

// Issue new vector shifts for the smaller types
V1 = DAG.getNode(Op.getOpcode(), dl, NewVT, V1, Amt1);
V2 = DAG.getNode(Op.getOpcode(), dl, NewVT, V2, Amt2);

// Concatenate the result back
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, V1, V2);
}

// Optimize shl/srl/sra with constant shift amount.
if (isSplatVector(Amt.getNode())) {
SDValue SclrAmt = Amt->getOperand(0);
Expand Down Expand Up @@ -10259,9 +10206,97 @@ SDValue X86TargetLowering::LowerShift(SDValue Op, SelectionDAG &DAG) const {
Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
return Res;
}

if (Subtarget->hasAVX2()) {
if (VT == MVT::v4i64 && Op.getOpcode() == ISD::SHL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_pslli_q, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));

if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SHL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_pslli_d, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));

if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SHL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_pslli_w, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));

if (VT == MVT::v4i64 && Op.getOpcode() == ISD::SRL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrli_q, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));

if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SRL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrli_d, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));

if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SRL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrli_w, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));

if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SRA)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrai_d, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));

if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SRA)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrai_w, MVT::i32),
R, DAG.getConstant(ShiftAmt, MVT::i32));
}
}
}

// AVX2 variable shifts
if (Subtarget->hasAVX2()) {
if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psllv_d, MVT::i32),
R, Amt);
if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SHL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psllv_d_256, MVT::i32),
R, Amt);
if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SHL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psllv_q, MVT::i32),
R, Amt);
if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SHL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psllv_q_256, MVT::i32),
R, Amt);

if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrlv_d, MVT::i32),
R, Amt);
if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrlv_d_256, MVT::i32),
R, Amt);
if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SRL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrlv_q, MVT::i32),
R, Amt);
if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SRL)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrlv_q_256, MVT::i32),
R, Amt);

if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRA)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrav_d, MVT::i32),
R, Amt);
if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRA)
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getConstant(Intrinsic::x86_avx2_psrav_d_256, MVT::i32),
R, Amt);
}

// Lower SHL with variable shift amount.
if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) {
Op = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
Expand Down Expand Up @@ -10328,6 +10363,48 @@ SDValue X86TargetLowering::LowerShift(SDValue Op, SelectionDAG &DAG) const {
R, DAG.getNode(ISD::ADD, dl, VT, R, R));
return R;
}

// Decompose 256-bit shifts into smaller 128-bit shifts.
if (VT.getSizeInBits() == 256) {
int NumElems = VT.getVectorNumElements();
MVT EltVT = VT.getVectorElementType().getSimpleVT();
EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);

// Extract the two vectors
SDValue V1 = Extract128BitVector(R, DAG.getConstant(0, MVT::i32), DAG, dl);
SDValue V2 = Extract128BitVector(R, DAG.getConstant(NumElems/2, MVT::i32),
DAG, dl);

// Recreate the shift amount vectors
SDValue Amt1, Amt2;
if (Amt.getOpcode() == ISD::BUILD_VECTOR) {
// Constant shift amount
SmallVector<SDValue, 4> Amt1Csts;
SmallVector<SDValue, 4> Amt2Csts;
for (int i = 0; i < NumElems/2; ++i)
Amt1Csts.push_back(Amt->getOperand(i));
for (int i = NumElems/2; i < NumElems; ++i)
Amt2Csts.push_back(Amt->getOperand(i));

Amt1 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
&Amt1Csts[0], NumElems/2);
Amt2 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
&Amt2Csts[0], NumElems/2);
} else {
// Variable shift amount
Amt1 = Extract128BitVector(Amt, DAG.getConstant(0, MVT::i32), DAG, dl);
Amt2 = Extract128BitVector(Amt, DAG.getConstant(NumElems/2, MVT::i32),
DAG, dl);
}

// Issue new vector shifts for the smaller types
V1 = DAG.getNode(Op.getOpcode(), dl, NewVT, V1, Amt1);
V2 = DAG.getNode(Op.getOpcode(), dl, NewVT, V2, Amt2);

// Concatenate the result back
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, V1, V2);
}

return SDValue();
}

Expand Down
99 changes: 36 additions & 63 deletions lib/Target/X86/X86InstrSSE.td
Original file line number Diff line number Diff line change
Expand Up @@ -7655,7 +7655,6 @@ defm VPMASKMOVQ : avx2_pmovmask<"vpmaskmovq",
// Variable Bit Shifts
//
multiclass avx2_var_shift<bits<8> opc, string OpcodeStr,
PatFrag pf128, PatFrag pf256,
Intrinsic Int128, Intrinsic Int256> {
def rr : AVX28I<opc, MRMSrcReg, (outs VR128:$dst),
(ins VR128:$src1, VR128:$src2),
Expand All @@ -7664,7 +7663,8 @@ multiclass avx2_var_shift<bits<8> opc, string OpcodeStr,
def rm : AVX28I<opc, MRMSrcMem, (outs VR128:$dst),
(ins VR128:$src1, i128mem:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR128:$dst, (Int128 VR128:$src1, (pf128 addr:$src2)))]>,
[(set VR128:$dst,
(Int128 VR128:$src1, (bitconvert (memopv2i64 addr:$src2))))]>,
VEX_4V;
def Yrr : AVX28I<opc, MRMSrcReg, (outs VR256:$dst),
(ins VR256:$src1, VR256:$src2),
Expand All @@ -7673,70 +7673,43 @@ multiclass avx2_var_shift<bits<8> opc, string OpcodeStr,
def Yrm : AVX28I<opc, MRMSrcMem, (outs VR256:$dst),
(ins VR256:$src1, i256mem:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR256:$dst, (Int256 VR256:$src1, (pf256 addr:$src2)))]>,
[(set VR256:$dst,
(Int256 VR256:$src1, (bitconvert (memopv4i64 addr:$src2))))]>,
VEX_4V;
}

defm VPSLLVD : avx2_var_shift<0x47, "vpsllvd", memopv4i32, memopv8i32,
int_x86_avx2_psllv_d, int_x86_avx2_psllv_d_256>;
defm VPSLLVQ : avx2_var_shift<0x47, "vpsllvq", memopv2i64, memopv4i64,
int_x86_avx2_psllv_q, int_x86_avx2_psllv_q_256>,
VEX_W;
defm VPSRLVD : avx2_var_shift<0x45, "vpsrlvd", memopv4i32, memopv8i32,
int_x86_avx2_psrlv_d, int_x86_avx2_psrlv_d_256>;
defm VPSRLVQ : avx2_var_shift<0x45, "vpsrlvq", memopv2i64, memopv4i64,
int_x86_avx2_psrlv_q, int_x86_avx2_psrlv_q_256>,
VEX_W;
defm VPSRAVD : avx2_var_shift<0x46, "vpsravd", memopv4i32, memopv8i32,
int_x86_avx2_psrav_d, int_x86_avx2_psrav_d_256>;


let Predicates = [HasAVX2] in {

def : Pat<(v4i32 (shl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
(VPSLLVDrr VR128:$src1, VR128:$src2)>;
def : Pat<(v2i64 (shl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
(VPSLLVQrr VR128:$src1, VR128:$src2)>;
def : Pat<(v4i32 (srl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
(VPSRLVDrr VR128:$src1, VR128:$src2)>;
def : Pat<(v2i64 (srl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
(VPSRLVQrr VR128:$src1, VR128:$src2)>;
def : Pat<(v4i32 (sra (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
(VPSRAVDrr VR128:$src1, VR128:$src2)>;
def : Pat<(v8i32 (shl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
(VPSLLVDYrr VR256:$src1, VR256:$src2)>;
def : Pat<(v4i64 (shl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
(VPSLLVQYrr VR256:$src1, VR256:$src2)>;
def : Pat<(v8i32 (srl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
(VPSRLVDYrr VR256:$src1, VR256:$src2)>;
def : Pat<(v4i64 (srl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
(VPSRLVQYrr VR256:$src1, VR256:$src2)>;
def : Pat<(v8i32 (sra (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
(VPSRAVDYrr VR256:$src1, VR256:$src2)>;

def : Pat<(v4i32 (shl (v4i32 VR128:$src1),(loadv4i32 addr:$src2))),
(VPSLLVDrm VR128:$src1, addr:$src2)>;
def : Pat<(v4i32 (shl (v4i32 VR128:$src1),(loadv2i64 addr:$src2))),
(VPSLLVDrm VR128:$src1, addr:$src2)>;
def : Pat<(v2i64 (shl (v2i64 VR128:$src1),(loadv2i64 addr:$src2))),
(VPSLLVQrm VR128:$src1, addr:$src2)>;
def : Pat<(v4i32 (srl (v4i32 VR128:$src1),(loadv4i32 addr:$src2))),
(VPSRLVDrm VR128:$src1, addr:$src2)>;
def : Pat<(v2i64 (srl (v2i64 VR128:$src1),(loadv2i64 addr:$src2))),
(VPSRLVQrm VR128:$src1, addr:$src2)>;
def : Pat<(v4i32 (sra (v4i32 VR128:$src1),(loadv4i32 addr:$src2))),
(VPSRAVDrm VR128:$src1, addr:$src2)>;
def : Pat<(v8i32 (shl (v8i32 VR256:$src1),(loadv8i32 addr:$src2))),
(VPSLLVDYrm VR256:$src1, addr:$src2)>;
def : Pat<(v4i64 (shl (v4i64 VR256:$src1),(loadv4i64 addr:$src2))),
(VPSLLVQYrm VR256:$src1, addr:$src2)>;
def : Pat<(v8i32 (srl (v8i32 VR256:$src1),(loadv8i32 addr:$src2))),
(VPSRLVDYrm VR256:$src1, addr:$src2)>;
def : Pat<(v4i64 (srl (v4i64 VR256:$src1),(loadv4i64 addr:$src2))),
(VPSRLVQYrm VR256:$src1, addr:$src2)>;
def : Pat<(v8i32 (sra (v8i32 VR256:$src1),(loadv8i32 addr:$src2))),
(VPSRAVDYrm VR256:$src1, addr:$src2)>;
multiclass avx2_var_shift_i64<bits<8> opc, string OpcodeStr,
Intrinsic Int128, Intrinsic Int256> {
def rr : AVX28I<opc, MRMSrcReg, (outs VR128:$dst),
(ins VR128:$src1, VR128:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR128:$dst, (Int128 VR128:$src1, VR128:$src2))]>, VEX_4V;
def rm : AVX28I<opc, MRMSrcMem, (outs VR128:$dst),
(ins VR128:$src1, i128mem:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR128:$dst,
(Int128 VR128:$src1, (memopv2i64 addr:$src2)))]>,
VEX_4V;
def Yrr : AVX28I<opc, MRMSrcReg, (outs VR256:$dst),
(ins VR256:$src1, VR256:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR256:$dst, (Int256 VR256:$src1, VR256:$src2))]>, VEX_4V;
def Yrm : AVX28I<opc, MRMSrcMem, (outs VR256:$dst),
(ins VR256:$src1, i256mem:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR256:$dst,
(Int256 VR256:$src1, (memopv4i64 addr:$src2)))]>,
VEX_4V;
}


defm VPSLLVD : avx2_var_shift<0x47, "vpsllvd", int_x86_avx2_psllv_d,
int_x86_avx2_psllv_d_256>;
defm VPSLLVQ : avx2_var_shift_i64<0x47, "vpsllvq", int_x86_avx2_psllv_q,
int_x86_avx2_psllv_q_256>, VEX_W;
defm VPSRLVD : avx2_var_shift<0x45, "vpsrlvd", int_x86_avx2_psrlv_d,
int_x86_avx2_psrlv_d_256>;
defm VPSRLVQ : avx2_var_shift_i64<0x45, "vpsrlvq", int_x86_avx2_psrlv_q,
int_x86_avx2_psrlv_q_256>, VEX_W;
defm VPSRAVD : avx2_var_shift<0x46, "vpsravd", int_x86_avx2_psrav_d,
int_x86_avx2_psrav_d_256>;

Loading

0 comments on commit 46154eb

Please sign in to comment.