Skip to content

Commit

Permalink
apacheGH-36203: [C++] Support casting in both ways for is_in and inde…
Browse files Browse the repository at this point in the history
…x_in (apache#36204)

### Rationale for this change

This is a follow-up to apache#36058 (review). Currently the kernels only try to cast the value set to the input type, not the other way around. This causes some valid input types to be rejected.

### What changes are included in this PR?

The kernels first try to cast value_set to the input type during preparation. If that fails, they try to cast the input to the value_set type before the lookup happens.

### Are these changes tested?

Yes. 

### Are there any user-facing changes?

Some previously rejected input types will now be valid.

* Closes: apache#36203

Lead-authored-by: Jin Shang <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
js8544 and pitrou authored Jun 28, 2023
1 parent efd5686 commit 6f3bd25
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 34 deletions.
90 changes: 69 additions & 21 deletions cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ namespace arrow {
using internal::checked_cast;
using internal::HashTraits;

namespace compute {
namespace internal {
namespace compute::internal {
namespace {

// Non-templated base for the set lookup kernel states. It exposes the value set
// type so that code which does not know the concrete SetLookupState<Type>
// instantiation can still read it.
struct SetLookupStateBase : public KernelState {
// Type of the value set; assigned by the derived states in their Init().
std::shared_ptr<DataType> value_set_type;
};

template <typename Type>
struct SetLookupState : public KernelState {
struct SetLookupState : public SetLookupStateBase {
explicit SetLookupState(MemoryPool* pool) : memory_pool(pool) {}

Status Init(const SetLookupOptions& options) {
Expand Down Expand Up @@ -65,6 +69,7 @@ struct SetLookupState : public KernelState {
if (!options.skip_nulls && lookup_table->GetNull() >= 0) {
null_index = memo_index_to_value_index[lookup_table->GetNull()];
}
value_set_type = options.value_set.type();
return Status::OK();
}

Expand Down Expand Up @@ -115,11 +120,12 @@ struct SetLookupState : public KernelState {
};

template <>
struct SetLookupState<NullType> : public KernelState {
struct SetLookupState<NullType> : public SetLookupStateBase {
explicit SetLookupState(MemoryPool*) {}

Status Init(const SetLookupOptions& options) {
value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls;
value_set_type = null();
return Status::OK();
}

Expand Down Expand Up @@ -215,16 +221,31 @@ struct InitStateVisitor {
return Status::Invalid("Array type didn't match type of values set: ", *arg_type,
" vs ", *options.value_set.type());
}

if (!options.value_set.is_arraylike()) {
return Status::Invalid("Set lookup value set must be Array or ChunkedArray");
} else if (!options.value_set.type()->Equals(*arg_type)) {
ARROW_ASSIGN_OR_RAISE(
options.value_set,
auto cast_result =
Cast(options.value_set, CastOptions::Safe(arg_type.GetSharedPtr()),
ctx->exec_context()));
ctx->exec_context());
if (cast_result.ok()) {
options.value_set = *cast_result;
} else if (CanCast(*arg_type.type, *options.value_set.type())) {
// Avoid casting from non binary types to string like above
// Otherwise, will try to cast input array to value set type during kernel exec
if ((options.value_set.type()->id() == Type::STRING ||
options.value_set.type()->id() == Type::LARGE_STRING) &&
!is_base_binary_like(arg_type.id())) {
return Status::Invalid("Array type didn't match type of values set: ",
*arg_type, " vs ", *options.value_set.type());
}
} else {
return Status::Invalid("Array type doesn't match type of values set: ", *arg_type,
" vs ", *options.value_set.type());
}
}

RETURN_NOT_OK(VisitTypeInline(*arg_type, this));
RETURN_NOT_OK(VisitTypeInline(*options.value_set.type(), this));
return std::move(result);
}
};
Expand Down Expand Up @@ -263,15 +284,12 @@ struct IndexInVisitor {
}

template <typename Type>
Status ProcessIndexIn() {
Status ProcessIndexIn(const SetLookupState<Type>& state, const ArraySpan& input) {
using T = typename GetViewType<Type>::T;

const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());

FirstTimeBitmapWriter bitmap_writer(out_bitmap, out->offset, out->length);
int32_t* out_data = out->GetValues<int32_t>(1);
VisitArraySpanInline<Type>(
data,
input,
[&](T v) {
int32_t index = state.lookup_table->Get(v);
if (index != -1) {
Expand Down Expand Up @@ -303,6 +321,19 @@ struct IndexInVisitor {
return Status::OK();
}

// Runs index_in against the lookup state built for `Type`. When the input's type
// differs from the value set's type, the input is first safely cast to the value
// set's type; a failed cast is returned to the caller as an error Status.
template <typename Type>
Status ProcessIndexIn() {
  const auto& lookup_state = checked_cast<const SetLookupState<Type>&>(*ctx->state());

  // Input already has the value set's type: look it up directly.
  if (data.type->Equals(lookup_state.value_set_type)) {
    return ProcessIndexIn(lookup_state, data);
  }

  // Types differ: materialize the span and cast it to the value set's type.
  auto owned_input = data.ToArrayData();
  ARROW_ASSIGN_OR_RAISE(auto converted_input,
                        Cast(*owned_input, lookup_state.value_set_type,
                             CastOptions::Safe(), ctx->exec_context()));
  return ProcessIndexIn(lookup_state, *converted_input.array());
}

template <typename Type>
enable_if_boolean<Type, Status> Visit(const Type&) {
return ProcessIndexIn<BooleanType>();
Expand Down Expand Up @@ -331,7 +362,10 @@ struct IndexInVisitor {
return ProcessIndexIn<MonthDayNanoIntervalType>();
}

Status Execute() { return VisitTypeInline(*data.type, this); }
// Dispatches on the value set's type (read from the kernel state), not on the
// input's type: the two may differ, and the lookup table is keyed by the former.
Status Execute() {
  const auto& base_state = checked_cast<const SetLookupStateBase&>(*ctx->state());
  const DataType& lookup_type = *base_state.value_set_type;
  return VisitTypeInline(lookup_type, this);
}
};

Status ExecIndexIn(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
Expand Down Expand Up @@ -359,13 +393,11 @@ struct IsInVisitor {
}

template <typename Type>
Status ProcessIsIn() {
Status ProcessIsIn(const SetLookupState<Type>& state, const ArraySpan& input) {
using T = typename GetViewType<Type>::T;
const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());

FirstTimeBitmapWriter writer(out->buffers[1].data, out->offset, out->length);
VisitArraySpanInline<Type>(
this->data,
input,
[&](T v) {
if (state.lookup_table->Get(v) != -1) {
writer.Set();
Expand All @@ -386,6 +418,20 @@ struct IsInVisitor {
return Status::OK();
}

// Runs is_in against the lookup state built for `Type`. When the input's type
// differs from the value set's type, the input is first safely cast to the value
// set's type; a failed cast is returned to the caller as an error Status.
template <typename Type>
Status ProcessIsIn() {
  const auto& lookup_state = checked_cast<const SetLookupState<Type>&>(*ctx->state());

  // Input already has the value set's type: look it up directly.
  if (data.type->Equals(lookup_state.value_set_type)) {
    return ProcessIsIn(lookup_state, data);
  }

  // Types differ: materialize the span and cast it to the value set's type.
  auto owned_input = data.ToArrayData();
  ARROW_ASSIGN_OR_RAISE(auto converted_input,
                        Cast(*owned_input, lookup_state.value_set_type,
                             CastOptions::Safe(), ctx->exec_context()));
  return ProcessIsIn(lookup_state, *converted_input.array());
}

template <typename Type>
enable_if_boolean<Type, Status> Visit(const Type&) {
return ProcessIsIn<BooleanType>();
Expand Down Expand Up @@ -413,7 +459,10 @@ struct IsInVisitor {
return ProcessIsIn<MonthDayNanoIntervalType>();
}

Status Execute() { return VisitTypeInline(*data.type, this); }
// Dispatches on the value set's type (read from the kernel state), not on the
// input's type: the two may differ, and the lookup table is keyed by the former.
Status Execute() {
  const auto& base_state = checked_cast<const SetLookupStateBase&>(*ctx->state());
  const DataType& lookup_type = *base_state.value_set_type;
  return VisitTypeInline(lookup_type, this);
}
};

Status ExecIsIn(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
Expand Down Expand Up @@ -566,6 +615,5 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) {
}
}

} // namespace internal
} // namespace compute
} // namespace compute::internal
} // namespace arrow
87 changes: 74 additions & 13 deletions cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,40 @@ TEST_F(TestIsInKernel, ImplicitlyCastValueSet) {
"true, false, true, false]"));
AssertArraysEqual(*expected, *out.make_array());

// fails; value_set cannot be cast to int8
opts = SetLookupOptions{ArrayFromJSON(float32(), "[2.5, 3.1, 5.0]")};
ASSERT_RAISES(Invalid, CallFunction("is_in", {input}, &opts));
// value_set cannot be casted to int8, but int8 is castable to float
CheckIsIn(input, ArrayFromJSON(float32(), "[1.0, 2.5, 3.1, 5.0]"),
"[false, true, false, false, false, true, false, false, false]");

// Allow implicit casts between binary types...
CheckIsIn(ArrayFromJSON(binary(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
CheckIsIn(ArrayFromJSON(binary(), R"(["aaa", "bb", "ccc", null, "bbb"])"),
ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb"])"),
"[true, true, false, false, true]");
"[true, false, false, false, true]");
CheckIsIn(ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
ArrayFromJSON(binary(), R"(["aa", "bbb"])"),
"[false, true, false, false, true]");
CheckIsIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
ArrayFromJSON(large_utf8(), R"(["aaa", "bbb"])"),
"[true, true, false, false, true]");
CheckIsIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
ArrayFromJSON(utf8(), R"(["aaa", "bbb"])"),
"[true, true, false, false, true]");

// But explicitly deny implicit casts from non-binary to utf8 to
// avoid surprises
ASSERT_RAISES(Invalid,
IsIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
SetLookupOptions(ArrayFromJSON(
utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));

ASSERT_RAISES(Invalid,
IsIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
ASSERT_RAISES(Invalid,
IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
SetLookupOptions(ArrayFromJSON(
large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
}

template <typename Type>
Expand Down Expand Up @@ -253,11 +268,12 @@ TEST_F(TestIsInKernel, TimeDuration) {
"[true, false, false, true, true]", /*skip_nulls=*/true);
}

// Different units, invalid cast
ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 2]")));
// Different units, cast value_set to values will fail, then cast values to value_set
CheckIsIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
ArrayFromJSON(duration(TimeUnit::MILLI), "[1, 2, 2000]"),
"[false, false, true]");

// Different units, valid cast
// Different units, cast value_set to values
CheckIsIn(ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 1, 2000]"),
ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 2]"), "[true, false, true]");
}
Expand Down Expand Up @@ -779,11 +795,12 @@ TEST_F(TestIndexInKernel, TimeDuration) {
CheckIndexIn(duration(TimeUnit::SECOND), "[null, null, null, null]", "[null]",
"[0, 0, 0, 0]");

// Different units, invalid cast
ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 2]")));
// Different units, cast value_set to values will fail, then cast values to value_set
CheckIndexIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
ArrayFromJSON(duration(TimeUnit::MILLI), "[1, 2, 2000]"),
"[null, null, 2]");

// Different units, valid cast
// Different units, cast value_set to values
CheckIndexIn(ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 1, 2000]"),
ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 2]"), "[0, null, 1]");
}
Expand Down Expand Up @@ -822,6 +839,50 @@ TEST_F(TestIndexInKernel, Boolean) {
CheckIndexIn(boolean(), "[null, null, null, null]", "[null]", "[0, 0, 0, 0]");
}

// Checks index_in when the input type and the value_set type differ: integer
// and float combinations, binary-like combinations, and the rejected
// string-like vs float64 combinations.
TEST_F(TestIndexInKernel, ImplicitlyCastValueSet) {
auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]");

// int32 value_set with an int8 input
SetLookupOptions opts{ArrayFromJSON(int32(), "[2, 3, 5, 7]")};
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("index_in", {input}, &opts));

auto expected = ArrayFromJSON(int32(), ("[null, null, 0, 1, null,"
"2, null, 3, null]"));
AssertArraysEqual(*expected, *out.make_array());

// value_set cannot be cast to int8, but int8 is castable to float
CheckIndexIn(input, ArrayFromJSON(float32(), "[1.0, 2.5, 3.1, 5.0]"),
"[null, 0, null, null, null, 3, null, null, null]");

// Allow implicit casts between binary types, in both directions...
CheckIndexIn(ArrayFromJSON(binary(), R"(["aaa", "bb", "ccc", null, "bbb"])"),
ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb"])"),
"[0, null, null, null, 1]");
CheckIndexIn(
ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
ArrayFromJSON(binary(), R"(["aa", "bbb"])"), "[null, 1, null, null, 1]");
CheckIndexIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
ArrayFromJSON(large_utf8(), R"(["aaa", "bbb"])"), "[0, 1, null, null, 1]");
CheckIndexIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
ArrayFromJSON(utf8(), R"(["aaa", "bbb"])"), "[0, 1, null, null, 1]");
// ...but explicitly deny implicit casts between non-binary types and
// utf8/large_utf8, whichever side is the string, to avoid surprises
ASSERT_RAISES(Invalid,
IndexIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
SetLookupOptions(ArrayFromJSON(
utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));

// Same rejections with large_utf8 in place of utf8
ASSERT_RAISES(
Invalid,
IndexIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
ASSERT_RAISES(Invalid,
IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
SetLookupOptions(ArrayFromJSON(
large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
}

template <typename Type>
class TestIndexInKernelBinary : public TestIndexInKernel {};

Expand Down

0 comments on commit 6f3bd25

Please sign in to comment.