Skip to content

Commit

Permalink
ARROW-15614: [C++] Add sqrt binary scalar kernel
Browse files Browse the repository at this point in the history
This is a proposal for resolving [ARROW-15614](https://issues.apache.org/jira/browse/ARROW-15614). Tests will be added when the resolution approach is agreed

Closes apache#12412 from rtpsw/ARROW-15614

Lead-authored-by: Yaron Gvili <[email protected]>
Co-authored-by: rtpsw <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
rtpsw authored and lidavidm committed Feb 20, 2022
1 parent f9f2c08 commit 773da64
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ SCALAR_ARITHMETIC_UNARY(Ln, "ln", "ln_checked")
SCALAR_ARITHMETIC_UNARY(Log10, "log10", "log10_checked")
SCALAR_ARITHMETIC_UNARY(Log1p, "log1p", "log1p_checked")
SCALAR_ARITHMETIC_UNARY(Log2, "log2", "log2_checked")
SCALAR_ARITHMETIC_UNARY(Sqrt, "sqrt", "sqrt_checked")
SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked")
SCALAR_ARITHMETIC_UNARY(Sin, "sin", "sin_checked")
SCALAR_ARITHMETIC_UNARY(Tan, "tan", "tan_checked")
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,18 @@ Result<Datum> Logb(const Datum& arg, const Datum& base,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Get the square-root of a value.
///
/// If argument is null the result will be null.
///
/// \param[in] arg The values to compute the square-root for.
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise square-root
ARROW_EXPORT
Result<Datum> Sqrt(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Round to the nearest integer less than or equal in magnitude to the
/// argument.
///
Expand Down
44 changes: 44 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,29 @@ struct PowerChecked {
}
};

struct SquareRoot {
template <typename T, typename Arg>
static enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg, Status*) {
static_assert(std::is_same<T, Arg>::value, "");
if (arg < 0.0) {
return std::numeric_limits<T>::quiet_NaN();
}
return std::sqrt(arg);
}
};

struct SquareRootChecked {
template <typename T, typename Arg>
static enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
static_assert(std::is_same<T, Arg>::value, "");
if (arg < 0.0) {
*st = Status::Invalid("square root of negative number");
return arg;
}
return std::sqrt(arg);
}
};

struct Sign {
template <typename T, typename Arg>
static constexpr enable_if_floating_value<Arg, T> Call(KernelContext*, Arg arg,
Expand Down Expand Up @@ -2360,6 +2383,18 @@ const FunctionDoc pow_checked_doc{
"or integer overflow is encountered."),
{"base", "exponent"}};

const FunctionDoc sqrt_doc{
"Takes the square root of arguments element-wise",
("A negative argument returns a NaN. For a variant that returns an\n"
"error, use function \"sqrt_checked\"."),
{"x"}};

const FunctionDoc sqrt_checked_doc{
"Takes the square root of arguments element-wise",
("A negative argument returns an error. For a variant that returns a\n"
"NaN, use function \"sqrt\"."),
{"x"}};

const FunctionDoc sign_doc{
"Get the signedness of the arguments element-wise",
("Output is any of (-1,1) for nonzero inputs and 0 for zero input.\n"
Expand Down Expand Up @@ -2809,6 +2844,15 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
"power_checked", &pow_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(power_checked)));

// ----------------------------------------------------------------------
auto sqrt = MakeUnaryArithmeticFunctionFloatingPoint<SquareRoot>("sqrt", &sqrt_doc);
DCHECK_OK(registry->AddFunction(std::move(sqrt)));

// ----------------------------------------------------------------------
auto sqrt_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<SquareRootChecked>(
"sqrt_checked", &sqrt_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(sqrt_checked)));

// ----------------------------------------------------------------------
auto sign =
MakeUnaryArithmeticFunctionWithFixedIntOutType<Sign, Int8Type>("sign", &sign_doc);
Expand Down
57 changes: 57 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ class TestBaseUnaryArithmetic : public ::testing::Test {
}
}

// (CScalar, CScalar)
void AssertUnaryOpRaises(UnaryFunction func, CType argument,
const std::string& expected_msg) {
auto arg = MakeScalar(argument);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_msg),
func(arg, options_, nullptr));
}

void AssertUnaryOpRaises(UnaryFunction func, const std::string& argument,
const std::string& expected_msg) {
auto arg = ArrayFromJSON(type_singleton(), argument);
Expand Down Expand Up @@ -1542,6 +1550,26 @@ TEST_F(TestUnaryArithmeticDecimal, Log) {
}
}

TEST_F(TestUnaryArithmeticDecimal, SquareRoot) {
std::vector<std::string> funcs = {"sqrt", "sqrt_checked"};
for (const auto& func : funcs) {
for (const auto& ty : PositiveScaleTypes()) {
CheckDecimalToFloat(func, {DecimalArrayFromJSON(ty, R"([])")});
CheckDecimalToFloat(
func, {DecimalArrayFromJSON(ty, R"(["4.00", "16.00", "36.00", null])")});
CheckRaises("sqrt_checked", {DecimalArrayFromJSON(ty, R"(["-2.00"])")},
"square root of negative number");
}
for (const auto& ty : NegativeScaleTypes()) {
CheckDecimalToFloat(func, {DecimalArrayFromJSON(ty, R"([])")});
CheckDecimalToFloat(func,
{DecimalArrayFromJSON(ty, R"(["400", "1600", "3600", null])")});
CheckRaises("sqrt_checked", {DecimalArrayFromJSON(ty, R"(["-400"])")},
"square root of negative number");
}
}
}

TEST_F(TestUnaryArithmeticDecimal, Negate) {
auto max128 = Decimal128::GetMaxValue(38);
auto max256 = Decimal256::GetMaxValue(76);
Expand Down Expand Up @@ -3284,6 +3312,35 @@ TYPED_TEST(TestUnaryArithmeticSigned, Log) {
this->AssertUnaryOpRaises(Log1p, "[-2]", "logarithm of negative number");
}

TYPED_TEST(TestUnaryArithmeticIntegral, Sqrt) {
// Integer arguments promoted to double, sanity check here
for (auto check_overflow : {false, true}) {
this->SetOverflowCheck(check_overflow);
this->AssertUnaryOp(Sqrt, "[1, null]", ArrayFromJSON(float64(), "[1, null]"));
this->AssertUnaryOp(Sqrt, "[4, null]", ArrayFromJSON(float64(), "[2, null]"));
this->AssertUnaryOp(Sqrt, "[null, 9]", ArrayFromJSON(float64(), "[null, 3]"));
}
}

TYPED_TEST(TestUnaryArithmeticFloating, Sqrt) {
using CType = typename TestFixture::CType;
this->SetNansEqual(true);
auto min_val = std::numeric_limits<CType>::min();
auto max_val = std::numeric_limits<CType>::max();
for (auto check_overflow : {false, true}) {
this->SetOverflowCheck(check_overflow);
this->AssertUnaryOp(Sqrt, "[1, 2, null, NaN, Inf]",
"[1, 1.414213562, null, NaN, Inf]");
this->AssertUnaryOp(Sqrt, min_val, static_cast<CType>(std::sqrt(min_val)));
#ifndef __MINGW32__
// this is problematic and produces a slight difference on MINGW
this->AssertUnaryOp(Sqrt, max_val, static_cast<CType>(std::sqrt(max_val)));
#endif
}
this->AssertUnaryOpRaises(Sqrt, "[-1]", "square root of negative number");
this->AssertUnaryOpRaises(Sqrt, "[-Inf]", "square root of negative number");
}

TYPED_TEST(TestUnaryArithmeticSigned, Sign) {
using CType = typename TestFixture::CType;
auto min = std::numeric_limits<CType>::min();
Expand Down
4 changes: 4 additions & 0 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ Mixed time resolution temporal inputs will be cast to finest input resolution.
+------------------+--------+------------------+----------------------+-------+
| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) |
+------------------+--------+------------------+----------------------+-------+
| sqrt | Unary | Numeric | Numeric | |
+------------------+--------+------------------+----------------------+-------+
| sqrt_checked | Unary | Numeric | Numeric | |
+------------------+--------+------------------+----------------------+-------+
| subtract | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
+------------------+--------+------------------+----------------------+-------+
| subtract_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) |
Expand Down
2 changes: 2 additions & 0 deletions docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ throws an ``ArrowInvalid`` exception when overflow is detected.
power
power_checked
sign
sqrt
sqrt_checked
subtract
subtract_checked

Expand Down

0 comments on commit 773da64

Please sign in to comment.