Skip to content

Commit

Permalink
ARROW-9196: [C++][Compute] All casts accept scalar and sliced inputs
Browse files Browse the repository at this point in the history
This is primarily a refactoring of the tests for "cast" to ensure that every case is also verified with scalar and sliced inputs through `CheckScalarUnary`. Caveat: ExtensionScalar is not implemented, so those checks aren't enabled. It should be straightforward for them to be enabled in a follow-up, however.

NB: also resolves ARROW-9198

Closes apache#9490 from bkietz/9196-Make-temporal-casts-work-

Authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
bkietz authored and pitrou committed Feb 17, 2021
1 parent 848c803 commit 2c707d4
Show file tree
Hide file tree
Showing 11 changed files with 1,219 additions and 1,608 deletions.
4 changes: 2 additions & 2 deletions cpp/src/arrow/array/array_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class ARROW_EXPORT Array {
///
/// Note that for `null_count == 0` or for null type, this will be null.
/// This buffer does not account for any slice offset
std::shared_ptr<Buffer> null_bitmap() const { return data_->buffers[0]; }
const std::shared_ptr<Buffer>& null_bitmap() const { return data_->buffers[0]; }

/// Raw pointer to the null bitmap.
///
Expand Down Expand Up @@ -160,7 +160,7 @@ class ARROW_EXPORT Array {
/// Input-checking variant of Array::Slice
Result<std::shared_ptr<Array>> SliceSafe(int64_t offset) const;

std::shared_ptr<ArrayData> data() const { return data_; }
const std::shared_ptr<ArrayData>& data() const { return data_; }

int num_fields() const { return static_cast<int>(data_->child_data.size()); }

Expand Down
7 changes: 3 additions & 4 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,14 +663,13 @@ struct ScalarUnaryNotNullStateful {
static void Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0,
Datum* out) {
ArrayData* out_arr = out->mutable_array();
auto out_data = out_arr->GetMutableValues<uint8_t>(1);
auto out_data = out_arr->GetMutableValues<Decimal128>(1);
VisitArrayValuesInline<Arg0Type>(
arg0,
[&](Arg0Value v) {
functor.op.template Call<OutValue, Arg0Value>(ctx, v).ToBytes(out_data);
out_data += 16;
*out_data++ = functor.op.template Call<OutValue, Arg0Value>(ctx, v);
},
[&]() { out_data += 16; });
[&]() { ++out_data; });
}
};

Expand Down
64 changes: 49 additions & 15 deletions cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,71 @@
#include <vector>

#include "arrow/array/builder_nested.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/compute/kernels/scalar_cast_internal.h"
#include "arrow/util/bitmap_ops.h"

namespace arrow {

using internal::CopyBitmap;

namespace compute {
namespace internal {

template <typename Type>
void CastListExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
using offset_type = typename Type::offset_type;
using ScalarType = typename TypeTraits<Type>::ScalarType;

const CastOptions& options = CastState::Get(ctx);

const ArrayData& input = *batch[0].array();
ArrayData* result = out->mutable_array();
auto child_type = checked_cast<const Type&>(*out->type()).value_type();

if (input.offset != 0) {
ctx->SetStatus(Status::NotImplemented(
"Casting sliced lists (non-zero offset) not yet implemented"));
if (out->kind() == Datum::SCALAR) {
const auto& in_scalar = checked_cast<const ScalarType&>(*batch[0].scalar());
auto out_scalar = checked_cast<ScalarType*>(out->scalar().get());

DCHECK(!out_scalar->is_valid);
if (in_scalar.is_valid) {
KERNEL_ASSIGN_OR_RAISE(
out_scalar->value, ctx,
Cast(*in_scalar.value, child_type, options, ctx->exec_context()));

out_scalar->is_valid = true;
}
return;
}
// Copy buffers from parent
result->buffers = input.buffers;

auto child_type = checked_cast<const Type&>(*result->type).value_type();
const ArrayData& in_array = *batch[0].array();
ArrayData* out_array = out->mutable_array();

// Copy from parent
out_array->buffers = in_array.buffers;
Datum values = in_array.child_data[0];

if (in_array.offset != 0) {
KERNEL_ASSIGN_OR_RAISE(out_array->buffers[0], ctx,
CopyBitmap(ctx->memory_pool(), in_array.buffers[0]->data(),
in_array.offset, in_array.length));
KERNEL_ASSIGN_OR_RAISE(out_array->buffers[1], ctx,
ctx->Allocate(sizeof(offset_type) * (in_array.length + 1)));

auto offsets = in_array.GetValues<offset_type>(1);
auto shifted_offsets = out_array->GetMutableValues<offset_type>(1);

for (int64_t i = 0; i < in_array.length + 1; ++i) {
shifted_offsets[i] = offsets[i] - offsets[0];
}
values = in_array.child_data[0]->Slice(offsets[0], offsets[in_array.length]);
}

KERNEL_ASSIGN_OR_RAISE(Datum cast_values, ctx,
Cast(values, child_type, options, ctx->exec_context()));

Datum casted_child;
KERNEL_RETURN_IF_ERROR(
ctx, Cast(Datum(input.child_data[0]), child_type, options, ctx->exec_context())
.Value(&casted_child));
DCHECK_EQ(Datum::ARRAY, casted_child.kind());
result->child_data.push_back(casted_child.array());
DCHECK_EQ(Datum::ARRAY, cast_values.kind());
out_array->child_data.push_back(cast_values.array());
}

template <typename Type>
Expand Down
98 changes: 47 additions & 51 deletions cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/compute/kernels/scalar_cast_internal.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/int_util.h"
#include "arrow/util/value_parsing.h"
Expand Down Expand Up @@ -361,8 +362,7 @@ struct CastFunctor<O, Decimal128Type, enable_if_t<is_integer_type<O>::value>> {
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& options = checked_cast<const CastState*>(ctx->state())->options;

const ArrayData& input = *batch[0].array();
const auto& in_type_inst = checked_cast<const Decimal128Type&>(*input.type);
const auto& in_type_inst = checked_cast<const Decimal128Type&>(*batch[0].type());
const auto in_scale = in_type_inst.scale();

if (options.allow_decimal_truncate) {
Expand Down Expand Up @@ -395,34 +395,34 @@ struct CastFunctor<O, Decimal128Type, enable_if_t<is_integer_type<O>::value>> {
struct UnsafeUpscaleDecimal {
template <typename... Unused>
Decimal128 Call(KernelContext* ctx, Decimal128 val) const {
return val.IncreaseScaleBy(out_scale_ - in_scale_);
return val.IncreaseScaleBy(by_);
}

int32_t out_scale_, in_scale_;
int32_t by_;
};

struct UnsafeDownscaleDecimal {
template <typename... Unused>
Decimal128 Call(KernelContext* ctx, Decimal128 val) const {
return val.ReduceScaleBy(in_scale_ - out_scale_, false);
return val.ReduceScaleBy(by_, false);
}

int32_t out_scale_, in_scale_;
int32_t by_;
};

struct SafeRescaleDecimal {
template <typename... Unused>
Decimal128 Call(KernelContext* ctx, Decimal128 val) const {
auto result = val.Rescale(in_scale_, out_scale_);
if (ARROW_PREDICT_FALSE(!result.ok())) {
ctx->SetStatus(result.status());
return Decimal128(); // Zero
} else if (ARROW_PREDICT_FALSE(!(*result).FitsInPrecision(out_precision_))) {
ctx->SetStatus(Status::Invalid("Decimal value does not fit in precision"));
return Decimal128(); // Zero
} else {
return *std::move(result);
auto maybe_rescaled = val.Rescale(in_scale_, out_scale_);
if (ARROW_PREDICT_FALSE(!maybe_rescaled.ok())) {
ctx->SetStatus(maybe_rescaled.status());
return {}; // Zero
}

if (ARROW_PREDICT_TRUE(maybe_rescaled->FitsInPrecision(out_precision_))) {
return maybe_rescaled.MoveValueUnsafe();
}

ctx->SetStatus(Status::Invalid("Decimal value does not fit in precision"));
return {}; // Zero
}

int32_t out_scale_, out_precision_, in_scale_;
Expand All @@ -432,36 +432,33 @@ template <>
struct CastFunctor<Decimal128Type, Decimal128Type> {
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& options = checked_cast<const CastState*>(ctx->state())->options;
const ArrayData& input = *batch[0].array();
ArrayData* output = out->mutable_array();

const auto& in_type_inst = checked_cast<const Decimal128Type&>(*input.type);
const auto& out_type_inst = checked_cast<const Decimal128Type&>(*output->type);
const auto in_scale = in_type_inst.scale();
const auto out_scale = out_type_inst.scale();
const auto out_precision = out_type_inst.precision();
const auto& in_type = checked_cast<const Decimal128Type&>(*batch[0].type());
const auto& out_type = checked_cast<const Decimal128Type&>(*out->type());
const auto in_scale = in_type.scale();
const auto out_scale = out_type.scale();

if (options.allow_decimal_truncate) {
if (in_scale < out_scale) {
// Unsafe upscale
applicator::ScalarUnaryNotNullStateful<Decimal128Type, Decimal128Type,
UnsafeUpscaleDecimal>
kernel(UnsafeUpscaleDecimal{out_scale, in_scale});
kernel(UnsafeUpscaleDecimal{out_scale - in_scale});
return kernel.Exec(ctx, batch, out);
} else {
// Unsafe downscale
applicator::ScalarUnaryNotNullStateful<Decimal128Type, Decimal128Type,
UnsafeDownscaleDecimal>
kernel(UnsafeDownscaleDecimal{out_scale, in_scale});
kernel(UnsafeDownscaleDecimal{in_scale - out_scale});
return kernel.Exec(ctx, batch, out);
}
} else {
// Safe rescale
applicator::ScalarUnaryNotNullStateful<Decimal128Type, Decimal128Type,
SafeRescaleDecimal>
kernel(SafeRescaleDecimal{out_scale, out_precision, in_scale});
return kernel.Exec(ctx, batch, out);
}

// Safe rescale
applicator::ScalarUnaryNotNullStateful<Decimal128Type, Decimal128Type,
SafeRescaleDecimal>
kernel(SafeRescaleDecimal{out_scale, out_type.precision(), in_scale});
return kernel.Exec(ctx, batch, out);
}
};

Expand All @@ -471,15 +468,16 @@ struct CastFunctor<Decimal128Type, Decimal128Type> {
struct RealToDecimal {
template <typename OutValue, typename RealType>
Decimal128 Call(KernelContext* ctx, RealType val) const {
auto result = Decimal128::FromReal(val, out_precision_, out_scale_);
if (ARROW_PREDICT_FALSE(!result.ok())) {
if (!allow_truncate_) {
ctx->SetStatus(result.status());
}
return Decimal128(); // Zero
} else {
return *std::move(result);
auto maybe_decimal = Decimal128::FromReal(val, out_precision_, out_scale_);

if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) {
return maybe_decimal.MoveValueUnsafe();
}

if (!allow_truncate_) {
ctx->SetStatus(maybe_decimal.status());
}
return {}; // Zero
}

int32_t out_scale_, out_precision_;
Expand All @@ -490,10 +488,9 @@ template <typename I>
struct CastFunctor<Decimal128Type, I, enable_if_t<is_floating_type<I>::value>> {
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& options = checked_cast<const CastState*>(ctx->state())->options;
ArrayData* output = out->mutable_array();
const auto& out_type_inst = checked_cast<const Decimal128Type&>(*output->type);
const auto out_scale = out_type_inst.scale();
const auto out_precision = out_type_inst.precision();
const auto& out_type = checked_cast<const Decimal128Type&>(*out->type());
const auto out_scale = out_type.scale();
const auto out_precision = out_type.precision();

applicator::ScalarUnaryNotNullStateful<Decimal128Type, I, RealToDecimal> kernel(
RealToDecimal{out_scale, out_precision, options.allow_decimal_truncate});
Expand All @@ -516,9 +513,8 @@ struct DecimalToReal {
template <typename O>
struct CastFunctor<O, Decimal128Type, enable_if_t<is_floating_type<O>::value>> {
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& in_type_inst =
checked_cast<const Decimal128Type&>(*batch[0].array()->type);
const auto in_scale = in_type_inst.scale();
const auto& in_type = checked_cast<const Decimal128Type&>(*batch[0].type());
const auto in_scale = in_type.scale();

applicator::ScalarUnaryNotNullStateful<O, Decimal128Type, DecimalToReal> kernel(
DecimalToReal{in_scale});
Expand Down Expand Up @@ -564,7 +560,7 @@ std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
AddCommonNumberCasts<OutType>(out_ty, func.get());

// From decimal to integer
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType::Array(Type::DECIMAL)}, out_ty,
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
CastFunctor<OutType, Decimal128Type>::Exec));
return func;
}
Expand All @@ -588,7 +584,7 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
AddCommonNumberCasts<OutType>(out_ty, func.get());

// From decimal to floating point
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType::Array(Type::DECIMAL)}, out_ty,
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
CastFunctor<OutType, Decimal128Type>::Exec));
return func;
}
Expand All @@ -608,8 +604,8 @@ std::shared_ptr<CastFunction> GetCastToDecimal128() {
// Cast from other decimal
auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
// We resolve the output type of this kernel from the CastOptions
DCHECK_OK(func->AddKernel(Type::DECIMAL128, {InputType::Array(Type::DECIMAL128)},
sig_out_ty, exec));
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
return func;
}

Expand Down
Loading

0 comments on commit 2c707d4

Please sign in to comment.