Skip to content

Commit

Permalink
Teach the DAG combiner to turn chains of FADDs (x+x+x+x+...) into FMU…
Browse files Browse the repository at this point in the history
…Ls by constants. This is only enabled in unsafe FP math mode, since it does not preserve rounding effects for all such constants.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@162956 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
resistor committed Aug 30, 2012
1 parent fafa283 commit 43da6c7
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 1 deletion.
123 changes: 122 additions & 1 deletion lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5681,6 +5681,127 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
N0.getOperand(1), N1));

// In unsafe math mode, we can fold chains of FADD's of the same value
// into multiplications. This transform is not safe in general because
// we are reducing the number of rounding steps.
if (DAG.getTarget().Options.UnsafeFPMath &&
TLI.isOperationLegalOrCustom(ISD::FMUL, VT) &&
!N0CFP && !N1CFP) {
if (N0.getOpcode() == ISD::FMUL) {
ConstantFPSDNode *CFP00 = dyn_cast<ConstantFPSDNode>(N0.getOperand(0));
ConstantFPSDNode *CFP01 = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));

// (fadd (fmul c, x), x) -> (fmul c+1, x)
if (CFP00 && !CFP01 && N0.getOperand(1) == N1) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP00, 0),
DAG.getConstantFP(1.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N1, NewCFP);
}

// (fadd (fmul x, c), x) -> (fmul c+1, x)
if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP01, 0),
DAG.getConstantFP(1.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N1, NewCFP);
}

// (fadd (fadd x, x), x) -> (fmul 3.0, x)
if (!CFP00 && !CFP01 && N0.getOperand(0) == N0.getOperand(1) &&
N0.getOperand(0) == N1) {
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N1, DAG.getConstantFP(3.0, VT));
}

// (fadd (fmul c, x), (fadd x, x)) -> (fmul c+2, x)
if (CFP00 && !CFP01 && N1.getOpcode() == ISD::FADD &&
N1.getOperand(0) == N1.getOperand(1) &&
N0.getOperand(1) == N1.getOperand(0)) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP00, 0),
DAG.getConstantFP(2.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0.getOperand(1), NewCFP);
}

// (fadd (fmul x, c), (fadd x, x)) -> (fmul c+2, x)
if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
N1.getOperand(0) == N1.getOperand(1) &&
N0.getOperand(0) == N1.getOperand(0)) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP01, 0),
DAG.getConstantFP(2.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0.getOperand(0), NewCFP);
}
}

if (N1.getOpcode() == ISD::FMUL) {
ConstantFPSDNode *CFP10 = dyn_cast<ConstantFPSDNode>(N1.getOperand(0));
ConstantFPSDNode *CFP11 = dyn_cast<ConstantFPSDNode>(N1.getOperand(1));

// (fadd x, (fmul c, x)) -> (fmul c+1, x)
if (CFP10 && !CFP11 && N1.getOperand(1) == N0) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP10, 0),
DAG.getConstantFP(1.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0, NewCFP);
}

// (fadd x, (fmul x, c)) -> (fmul c+1, x)
if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP11, 0),
DAG.getConstantFP(1.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0, NewCFP);
}

// (fadd x, (fadd x, x)) -> (fmul 3.0, x)
if (!CFP10 && !CFP11 && N1.getOperand(0) == N1.getOperand(1) &&
N1.getOperand(0) == N0) {
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0, DAG.getConstantFP(3.0, VT));
}

// (fadd (fadd x, x), (fmul c, x)) -> (fmul c+2, x)
if (CFP10 && !CFP11 && N1.getOpcode() == ISD::FADD &&
N1.getOperand(0) == N1.getOperand(1) &&
N0.getOperand(1) == N1.getOperand(0)) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP10, 0),
DAG.getConstantFP(2.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0.getOperand(1), NewCFP);
}

// (fadd (fadd x, x), (fmul x, c)) -> (fmul c+2, x)
if (CFP11 && !CFP10 && N1.getOpcode() == ISD::FADD &&
N1.getOperand(0) == N1.getOperand(1) &&
N0.getOperand(0) == N1.getOperand(0)) {
SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
SDValue(CFP11, 0),
DAG.getConstantFP(2.0, VT));
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0.getOperand(0), NewCFP);
}
}

// (fadd (fadd x, x), (fadd x, x)) -> (fmul 4.0, x)
if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
N0.getOperand(0) == N0.getOperand(1) &&
N1.getOperand(0) == N1.getOperand(1) &&
N0.getOperand(0) == N1.getOperand(0)) {
return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
N0.getOperand(0),
DAG.getConstantFP(4.0, VT));
}
}

// FADD -> FMA combines:
if ((DAG.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast ||
DAG.getTarget().Options.UnsafeFPMath) &&
Expand All @@ -5692,7 +5813,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
return DAG.getNode(ISD::FMA, N->getDebugLoc(), VT,
N0.getOperand(0), N0.getOperand(1), N1);
}

// fold (fadd x, (fmul y, z)) -> (fma x, y, z)
// Note: Commutes FADD operands.
if (N1.getOpcode() == ISD::FMUL && N1->hasOneUse()) {
Expand Down
37 changes: 37 additions & 0 deletions test/CodeGen/X86/fp-fast.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; RUN: llc -march=x86-64 -mtriple=x86_64-apple-darwin -enable-unsafe-fp-math < %s | FileCheck %s

; CHECK: test1
define float @test1(float %a) {
; CHECK-NOT: vaddss
; CHECK: vmulss
; CHECK-NOT: vaddss
; CHECK: ret
%t1 = fadd float %a, %a
%r = fadd float %t1, %t1
ret float %r
}

; CHECK: test2
define float @test2(float %a) {
; CHECK-NOT: vaddss
; CHECK: vmulss
; CHECK-NOT: vaddss
; CHECK: ret
%t1 = fmul float 4.0, %a
%t2 = fadd float %a, %a
%r = fadd float %t1, %t2
ret float %r
}

; CHECK: test3
define float @test3(float %a) {
; CHECK-NOT: vaddss
; CHECK: vxorps
; CHECK-NOT: vaddss
; CHECK: ret
%t1 = fmul float 2.0, %a
%t2 = fadd float %a, %a
%r = fsub float %t1, %t2
ret float %r
}

0 comments on commit 43da6c7

Please sign in to comment.