Skip to content

Commit

Permalink
[flang] Adjust transformational folding to match runtime (#90132)
Browse files Browse the repository at this point in the history
The transformational intrinsic functions MATMUL, DOT_PRODUCT, and NORM2
all involve summing up intermediate products into accumulators. In the
constant folding library, this is done with extended precision Kahan
summation for REAL and COMPLEX arguments, but in the runtime
implementations it is not, and this leads to discrepancies between
folded results and dynamic results.

Disable the use of Kahan summation in folding to resolve these
discrepancies, but don't discard the code, in case we want to add Kahan
summation in the runtime for some or all of these intrinsic functions.
  • Loading branch information
klausler authored May 1, 2024
1 parent a1c1279 commit 3502d34
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 43 deletions.
6 changes: 6 additions & 0 deletions flang/lib/Evaluate/fold-implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@

namespace Fortran::evaluate {

// Don't use Kahan extended precision summation any more when folding
// transformational intrinsic functions other than SUM, since it is
// not used in the runtime implementations of those functions and we
// want results to match.
static constexpr bool useKahanSummation{false};

// Utilities
template <typename T> class Folder {
public:
Expand Down
27 changes: 17 additions & 10 deletions flang/lib/Evaluate/fold-matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,25 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
Element bElt{mb->At(bAt)};
if constexpr (T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex) {
// Kahan summation
auto product{aElt.Multiply(bElt, rounding)};
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
auto next{correction.Add(product.value, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
if constexpr (useKahanSummation) {
auto next{correction.Add(product.value, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(product.value)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
} else if constexpr (T::category == TypeCategory::Integer) {
// Don't use Kahan summation in numeric MATMUL folding;
// the runtime doesn't use it, and results should match.
auto product{aElt.MultiplySigned(bElt)};
overflow |= product.SignedMultiplicationOverflowed();
auto added{sum.AddSigned(product.lower)};
Expand Down
33 changes: 18 additions & 15 deletions flang/lib/Evaluate/fold-real.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ template <int KIND> class Norm2Accumulator {
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
void operator()(
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
// Kahan summation of scaled elements:
// Summation of scaled elements:
// Naively,
// NORM2(A(:)) = SQRT(SUM(A(:)**2))
// For any T > 0, we have mathematically
Expand All @@ -76,24 +76,27 @@ template <int KIND> class Norm2Accumulator {
auto item{array_.At(at)};
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
auto next{square.Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
if constexpr (useKahanSummation) {
auto next{square.Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
} else {
auto sum{element.Add(square, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
element = sum.value;
}
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &result) {
// result+correction == SUM((data(:)/maxAbs)**2)
// result = maxAbs * SQRT(result+correction)
auto corrected{result.Add(correction_, rounding_)};
overflow_ |= corrected.flags.test(RealFlag::Overflow);
correction_ = Scalar<T>{};
auto root{corrected.value.SQRT().value};
// incoming result = SUM((data(:)/maxAbs)**2)
// outgoing result = maxAbs * SQRT(result)
auto root{result.SQRT().value};
auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
maxAbs_.IncrementSubscripts(maxAbsAt_);
overflow_ |= product.flags.test(RealFlag::Overflow);
Expand Down
48 changes: 30 additions & 18 deletions flang/lib/Evaluate/fold-reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{Fold(
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
Element correction{}; // Use Kahan summation for greater precision.
[[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
if constexpr (useKahanSummation) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
}
} else if constexpr (T::category == TypeCategory::Logical) {
Expr<T> conjunctions{Fold(context,
Expand All @@ -80,17 +86,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
Element correction{}; // Use Kahan summation for greater precision.
[[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
if constexpr (useKahanSummation) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
}
}
if (overflow) {
Expand Down

0 comments on commit 3502d34

Please sign in to comment.