Skip to content

Commit

Permalink
AVX-512: Added FMA intrinsics with rounding mode
Browse files Browse the repository at this point in the history
By Asaf Badouh and Elena Demikhovsky

Added special nodes for rounding: FMADD_RND, FMSUB_RND..
It will prevent merge between nodes with rounding and other standard nodes.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@227303 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
Elena Demikhovsky committed Jan 28, 2015
1 parent aef3618 commit b9d3801
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 137 deletions.
148 changes: 36 additions & 112 deletions lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17039,54 +17039,6 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask,
return DAG.getNode(X86ISD::SELECT, dl, VT, IMask, Op, PreservedSrc);
}

static unsigned getOpcodeForFMAIntrinsic(unsigned IntNo) {
switch (IntNo) {
default: llvm_unreachable("Impossible intrinsic"); // Can't reach here.
case Intrinsic::x86_fma_vfmadd_ps:
case Intrinsic::x86_fma_vfmadd_pd:
case Intrinsic::x86_fma_vfmadd_ps_256:
case Intrinsic::x86_fma_vfmadd_pd_256:
case Intrinsic::x86_fma_mask_vfmadd_ps_512:
case Intrinsic::x86_fma_mask_vfmadd_pd_512:
return X86ISD::FMADD;
case Intrinsic::x86_fma_vfmsub_ps:
case Intrinsic::x86_fma_vfmsub_pd:
case Intrinsic::x86_fma_vfmsub_ps_256:
case Intrinsic::x86_fma_vfmsub_pd_256:
case Intrinsic::x86_fma_mask_vfmsub_ps_512:
case Intrinsic::x86_fma_mask_vfmsub_pd_512:
return X86ISD::FMSUB;
case Intrinsic::x86_fma_vfnmadd_ps:
case Intrinsic::x86_fma_vfnmadd_pd:
case Intrinsic::x86_fma_vfnmadd_ps_256:
case Intrinsic::x86_fma_vfnmadd_pd_256:
case Intrinsic::x86_fma_mask_vfnmadd_ps_512:
case Intrinsic::x86_fma_mask_vfnmadd_pd_512:
return X86ISD::FNMADD;
case Intrinsic::x86_fma_vfnmsub_ps:
case Intrinsic::x86_fma_vfnmsub_pd:
case Intrinsic::x86_fma_vfnmsub_ps_256:
case Intrinsic::x86_fma_vfnmsub_pd_256:
case Intrinsic::x86_fma_mask_vfnmsub_ps_512:
case Intrinsic::x86_fma_mask_vfnmsub_pd_512:
return X86ISD::FNMSUB;
case Intrinsic::x86_fma_vfmaddsub_ps:
case Intrinsic::x86_fma_vfmaddsub_pd:
case Intrinsic::x86_fma_vfmaddsub_ps_256:
case Intrinsic::x86_fma_vfmaddsub_pd_256:
case Intrinsic::x86_fma_mask_vfmaddsub_ps_512:
case Intrinsic::x86_fma_mask_vfmaddsub_pd_512:
return X86ISD::FMADDSUB;
case Intrinsic::x86_fma_vfmsubadd_ps:
case Intrinsic::x86_fma_vfmsubadd_pd:
case Intrinsic::x86_fma_vfmsubadd_ps_256:
case Intrinsic::x86_fma_vfmsubadd_pd_256:
case Intrinsic::x86_fma_mask_vfmsubadd_ps_512:
case Intrinsic::x86_fma_mask_vfmsubadd_pd_512:
return X86ISD::FMSUBADD;
}
}

static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget *Subtarget,
SelectionDAG &DAG) {
SDLoc dl(Op);
Expand Down Expand Up @@ -17123,9 +17075,43 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget *Subtarget
Mask, Src0, Subtarget, DAG);
}
case INTR_TYPE_2OP_MASK: {
return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Op.getOperand(1),
SDValue Mask = Op.getOperand(4);
SDValue PassThru = Op.getOperand(3);
unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
if (IntrWithRoundingModeOpcode != 0) {
unsigned Round = cast<ConstantSDNode>(Op.getOperand(5))->getZExtValue();
if (Round != X86::STATIC_ROUNDING::CUR_DIRECTION) {
return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode,
dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2),
Op.getOperand(3), Op.getOperand(5)),
Mask, PassThru, Subtarget, DAG);
}
}
return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT,
Op.getOperand(1),
Op.getOperand(2)),
Op.getOperand(4), Op.getOperand(3), Subtarget, DAG);
Mask, PassThru, Subtarget, DAG);
}
case FMA_OP_MASK: {
SDValue Src1 = Op.getOperand(1);
SDValue Src2 = Op.getOperand(2);
SDValue Src3 = Op.getOperand(3);
SDValue Mask = Op.getOperand(4);
unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
if (IntrWithRoundingModeOpcode != 0) {
SDValue Rnd = Op.getOperand(5);
if (cast<ConstantSDNode>(Rnd)->getZExtValue() !=
X86::STATIC_ROUNDING::CUR_DIRECTION)
return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode,
dl, Op.getValueType(),
Src1, Src2, Src3, Rnd),
Mask, Src1, Subtarget, DAG);
}
return getVectorMaskingNode(DAG.getNode(IntrData->Opc0,
dl, Op.getValueType(),
Src1, Src2, Src3),
Mask, Src1, Subtarget, DAG);
}
case CMP_MASK:
case CMP_MASK_CC: {
Expand Down Expand Up @@ -17215,16 +17201,6 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget *Subtarget
return DAG.getNode(IntrData->Opc0, dl, VT, VMask, Op.getOperand(1),
Op.getOperand(2));
}
case FMA_OP_MASK:
{
return getVectorMaskingNode(DAG.getNode(IntrData->Opc0,
dl, Op.getValueType(),
Op.getOperand(1),
Op.getOperand(2),
Op.getOperand(3)),
Op.getOperand(4), Op.getOperand(1),
Subtarget, DAG);
}
default:
break;
}
Expand Down Expand Up @@ -17395,58 +17371,6 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget *Subtarget
SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
return DAG.getNode(Opcode, dl, VTs, NewOps);
}

case Intrinsic::x86_fma_mask_vfmadd_ps_512:
case Intrinsic::x86_fma_mask_vfmadd_pd_512:
case Intrinsic::x86_fma_mask_vfmsub_ps_512:
case Intrinsic::x86_fma_mask_vfmsub_pd_512:
case Intrinsic::x86_fma_mask_vfnmadd_ps_512:
case Intrinsic::x86_fma_mask_vfnmadd_pd_512:
case Intrinsic::x86_fma_mask_vfnmsub_ps_512:
case Intrinsic::x86_fma_mask_vfnmsub_pd_512:
case Intrinsic::x86_fma_mask_vfmaddsub_ps_512:
case Intrinsic::x86_fma_mask_vfmaddsub_pd_512:
case Intrinsic::x86_fma_mask_vfmsubadd_ps_512:
case Intrinsic::x86_fma_mask_vfmsubadd_pd_512: {
auto *SAE = cast<ConstantSDNode>(Op.getOperand(5));
if (SAE->getZExtValue() == X86::STATIC_ROUNDING::CUR_DIRECTION)
return getVectorMaskingNode(DAG.getNode(getOpcodeForFMAIntrinsic(IntNo),
dl, Op.getValueType(),
Op.getOperand(1),
Op.getOperand(2),
Op.getOperand(3)),
Op.getOperand(4), Op.getOperand(1),
Subtarget, DAG);
else
return SDValue();
}

case Intrinsic::x86_fma_vfmadd_ps:
case Intrinsic::x86_fma_vfmadd_pd:
case Intrinsic::x86_fma_vfmsub_ps:
case Intrinsic::x86_fma_vfmsub_pd:
case Intrinsic::x86_fma_vfnmadd_ps:
case Intrinsic::x86_fma_vfnmadd_pd:
case Intrinsic::x86_fma_vfnmsub_ps:
case Intrinsic::x86_fma_vfnmsub_pd:
case Intrinsic::x86_fma_vfmaddsub_ps:
case Intrinsic::x86_fma_vfmaddsub_pd:
case Intrinsic::x86_fma_vfmsubadd_ps:
case Intrinsic::x86_fma_vfmsubadd_pd:
case Intrinsic::x86_fma_vfmadd_ps_256:
case Intrinsic::x86_fma_vfmadd_pd_256:
case Intrinsic::x86_fma_vfmsub_ps_256:
case Intrinsic::x86_fma_vfmsub_pd_256:
case Intrinsic::x86_fma_vfnmadd_ps_256:
case Intrinsic::x86_fma_vfnmadd_pd_256:
case Intrinsic::x86_fma_vfnmsub_ps_256:
case Intrinsic::x86_fma_vfnmsub_pd_256:
case Intrinsic::x86_fma_vfmaddsub_ps_256:
case Intrinsic::x86_fma_vfmaddsub_pd_256:
case Intrinsic::x86_fma_vfmsubadd_ps_256:
case Intrinsic::x86_fma_vfmsubadd_pd_256:
return DAG.getNode(getOpcodeForFMAIntrinsic(IntNo), dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
}
}

Expand Down
7 changes: 7 additions & 0 deletions lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,13 @@ namespace llvm {
FNMSUB,
FMADDSUB,
FMSUBADD,
// FMA with rounding mode
FMADD_RND,
FNMADD_RND,
FMSUB_RND,
FNMSUB_RND,
FMADDSUB_RND,
FMSUBADD_RND,

// Compress and expand
COMPRESS,
Expand Down
41 changes: 32 additions & 9 deletions lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -3582,6 +3582,24 @@ multiclass avx512_fma3p_rm<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
}
} // Constraints = "$src1 = $dst"

let Constraints = "$src1 = $dst" in {
// Omitting the parameter OpNode (= null_frag) disables ISel pattern matching.
multiclass avx512_fma3_round_rrb<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
SDPatternOperator OpNode> {
defm rb: AVX512_maskable_3src<opc, MRMSrcReg, _, (outs _.RC:$dst),
(ins _.RC:$src2, _.RC:$src3, AVX512RC:$rc),
OpcodeStr, "$rc, $src3, $src2", "$src2, $src3, $rc",
(_.VT ( OpNode _.RC:$src1, _.RC:$src2, _.RC:$src3, (i32 imm:$rc)))>,
AVX512FMA3Base, EVEX_B, EVEX_RC;
}
} // Constraints = "$src1 = $dst"

multiclass avx512_fma3_round_forms<bits<8> opc213, string OpcodeStr,
X86VectorVTInfo VTI, SDPatternOperator OpNode> {
defm v213r : avx512_fma3_round_rrb<opc213, !strconcat(OpcodeStr, "213", VTI.Suffix),
VTI, OpNode>, EVEX_CD8<VTI.EltSize, CD8VF>;
}

multiclass avx512_fma3p_forms<bits<8> opc213, bits<8> opc231,
string OpcodeStr, X86VectorVTInfo VTI,
SDPatternOperator OpNode> {
Expand All @@ -3594,31 +3612,36 @@ multiclass avx512_fma3p_forms<bits<8> opc213, bits<8> opc231,

multiclass avx512_fma3p<bits<8> opc213, bits<8> opc231,
string OpcodeStr,
SDPatternOperator OpNode> {
SDPatternOperator OpNode,
SDPatternOperator OpNodeRnd> {
let ExeDomain = SSEPackedSingle in {
defm NAME##PSZ : avx512_fma3p_forms<opc213, opc231, OpcodeStr,
v16f32_info, OpNode>, EVEX_V512;
v16f32_info, OpNode>,
avx512_fma3_round_forms<opc213, OpcodeStr,
v16f32_info, OpNodeRnd>, EVEX_V512;
defm NAME##PSZ256 : avx512_fma3p_forms<opc213, opc231, OpcodeStr,
v8f32x_info, OpNode>, EVEX_V256;
defm NAME##PSZ128 : avx512_fma3p_forms<opc213, opc231, OpcodeStr,
v4f32x_info, OpNode>, EVEX_V128;
}
let ExeDomain = SSEPackedDouble in {
defm NAME##PDZ : avx512_fma3p_forms<opc213, opc231, OpcodeStr,
v8f64_info, OpNode>, EVEX_V512, VEX_W;
v8f64_info, OpNode>,
avx512_fma3_round_forms<opc213, OpcodeStr,
v8f64_info, OpNodeRnd>, EVEX_V512, VEX_W;
defm NAME##PDZ256 : avx512_fma3p_forms<opc213, opc231, OpcodeStr,
v4f64x_info, OpNode>, EVEX_V256, VEX_W;
defm NAME##PDZ128 : avx512_fma3p_forms<opc213, opc231, OpcodeStr,
v2f64x_info, OpNode>, EVEX_V128, VEX_W;
}
}

defm VFMADD : avx512_fma3p<0xA8, 0xB8, "vfmadd", X86Fmadd>;
defm VFMSUB : avx512_fma3p<0xAA, 0xBA, "vfmsub", X86Fmsub>;
defm VFMADDSUB : avx512_fma3p<0xA6, 0xB6, "vfmaddsub", X86Fmaddsub>;
defm VFMSUBADD : avx512_fma3p<0xA7, 0xB7, "vfmsubadd", X86Fmsubadd>;
defm VFNMADD : avx512_fma3p<0xAC, 0xBC, "vfnmadd", X86Fnmadd>;
defm VFNMSUB : avx512_fma3p<0xAE, 0xBE, "vfnmsub", X86Fnmsub>;
defm VFMADD : avx512_fma3p<0xA8, 0xB8, "vfmadd", X86Fmadd, X86FmaddRnd>;
defm VFMSUB : avx512_fma3p<0xAA, 0xBA, "vfmsub", X86Fmsub, X86FmsubRnd>;
defm VFMADDSUB : avx512_fma3p<0xA6, 0xB6, "vfmaddsub", X86Fmaddsub, X86FmaddsubRnd>;
defm VFMSUBADD : avx512_fma3p<0xA7, 0xB7, "vfmsubadd", X86Fmsubadd, X86FmsubaddRnd>;
defm VFNMADD : avx512_fma3p<0xAC, 0xBC, "vfnmadd", X86Fnmadd, X86FnmaddRnd>;
defm VFNMSUB : avx512_fma3p<0xAE, 0xBE, "vfnmsub", X86Fnmsub, X86FnmsubRnd>;

let Constraints = "$src1 = $dst" in {
multiclass avx512_fma3p_m132<bits<8> opc, string OpcodeStr, SDNode OpNode,
Expand Down
9 changes: 9 additions & 0 deletions lib/Target/X86/X86InstrFragmentsSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def SDTBlend : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0,1>,

def SDTFma : SDTypeProfile<1, 3, [SDTCisSameAs<0,1>,
SDTCisSameAs<1,2>, SDTCisSameAs<1,3>]>;
def SDTFmaRound : SDTypeProfile<1, 4, [SDTCisSameAs<0,1>,
SDTCisSameAs<1,2>, SDTCisSameAs<1,3>, SDTCisInt<4>]>;
def STDFp1SrcRm : SDTypeProfile<1, 2, [SDTCisSameAs<0,1>,
SDTCisVec<0>, SDTCisInt<2>]>;
def STDFp2SrcRm : SDTypeProfile<1, 3, [SDTCisSameAs<0,1>,
Expand Down Expand Up @@ -265,6 +267,13 @@ def X86Fnmsub : SDNode<"X86ISD::FNMSUB", SDTFma>;
def X86Fmaddsub : SDNode<"X86ISD::FMADDSUB", SDTFma>;
def X86Fmsubadd : SDNode<"X86ISD::FMSUBADD", SDTFma>;

def X86FmaddRnd : SDNode<"X86ISD::FMADD_RND", SDTFmaRound>;
def X86FnmaddRnd : SDNode<"X86ISD::FNMADD_RND", SDTFmaRound>;
def X86FmsubRnd : SDNode<"X86ISD::FMSUB_RND", SDTFmaRound>;
def X86FnmsubRnd : SDNode<"X86ISD::FNMSUB_RND", SDTFmaRound>;
def X86FmaddsubRnd : SDNode<"X86ISD::FMADDSUB_RND", SDTFmaRound>;
def X86FmsubaddRnd : SDNode<"X86ISD::FMSUBADD_RND", SDTFmaRound>;

def X86rsqrt28 : SDNode<"X86ISD::RSQRT28", STDFp1SrcRm>;
def X86rcp28 : SDNode<"X86ISD::RCP28", STDFp1SrcRm>;
def X86exp2 : SDNode<"X86ISD::EXP2", STDFp1SrcRm>;
Expand Down
Loading

0 comments on commit b9d3801

Please sign in to comment.