Skip to content

Commit

Permalink
ARROW-11457: [Rust] Make string comparison kernels generic over Utf8…
Browse files Browse the repository at this point in the history
… and LargeUtf8

This PR makes the existing comparison kernels, which previously operated only on `StringArray`, generic over both `StringArray` and `LargeStringArray`.

Closes apache#9362 from ritchie46/generic_str_comparissons

Authored-by: Ritchie Vink <[email protected]>
Signed-off-by: Andrew Lamb <[email protected]>
ritchie46 authored and alamb committed Feb 1, 2021
1 parent 4f74ae4 commit b0b622b
Showing 1 changed file with 80 additions and 16 deletions.
96 changes: 80 additions & 16 deletions rust/arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
@@ -109,7 +109,10 @@ where
compare_op_scalar!(left, right, op)
}

pub fn like_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn like_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
let mut map = HashMap::new();
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
@@ -158,7 +161,10 @@ fn is_like_pattern(c: char) -> bool {
c == '%' || c == '_'
}

pub fn like_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn like_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
let null_bit_buffer = left.data().null_buffer().cloned();
let bytes = bit_util::ceil(left.len(), 8);
let mut bool_buf = MutableBuffer::from_len_zeroed(bytes);
@@ -217,7 +223,10 @@ pub fn like_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray>
Ok(BooleanArray::from(Arc::new(data)))
}

pub fn nlike_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn nlike_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
let mut map = HashMap::new();
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
@@ -262,7 +271,10 @@ pub fn nlike_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArra
Ok(BooleanArray::from(Arc::new(data)))
}

pub fn nlike_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn nlike_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
let null_bit_buffer = left.data().null_buffer().cloned();
let mut result = BooleanBufferBuilder::new(left.len());

@@ -308,51 +320,87 @@ pub fn nlike_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray
Ok(BooleanArray::from(Arc::new(data)))
}

pub fn eq_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn eq_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
compare_op!(left, right, |a, b| a == b)
}

pub fn eq_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a == b)
}

pub fn neq_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn neq_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
compare_op!(left, right, |a, b| a != b)
}

pub fn neq_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn neq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a != b)
}

pub fn lt_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn lt_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
compare_op!(left, right, |a, b| a < b)
}

pub fn lt_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn lt_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a < b)
}

pub fn lt_eq_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn lt_eq_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
compare_op!(left, right, |a, b| a <= b)
}

pub fn lt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn lt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a <= b)
}

pub fn gt_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn gt_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
compare_op!(left, right, |a, b| a > b)
}

pub fn gt_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn gt_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a > b)
}

pub fn gt_eq_utf8(left: &StringArray, right: &StringArray) -> Result<BooleanArray> {
pub fn gt_eq_utf8<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
) -> Result<BooleanArray> {
compare_op!(left, right, |a, b| a >= b)
}

pub fn gt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result<BooleanArray> {
pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a >= b)
}

@@ -1227,6 +1275,22 @@ mod tests {
$right
);
}

let left = LargeStringArray::from($left);
let res = $op(&left, $right).unwrap();
let expected = $expected;
assert_eq!(expected.len(), res.len());
for i in 0..res.len() {
let v = res.value(i);
assert_eq!(
v,
expected[i],
"unexpected result when comparing {} at position {} to {} ",
left.value(i),
i,
$right
);
}
}
};
}

0 comments on commit b0b622b

Please sign in to comment.