Skip to content

Commit

Permalink
ARROW-4865: [Rust] Support list casts
Browse files Browse the repository at this point in the history
This is a follow up from the initial cast kernel PR, and adds support for:

* List<Primitive> to List<Primitive>
* Primitive to List<Primitive>

The only remaining expansion to the cast kernel will be temporal casts, then I think we'll be able to cover all cast use-cases.

Author: Neville Dipale <[email protected]>

Closes apache#3896 from nevi-me/ARROW-4865 and squashes the following commits:

36530ba <Neville Dipale> restrict sliced-array limitation to primitive->list casts
fd1c49c <Neville Dipale> disable casting sliced arrays
184d0d7 <Neville Dipale> address review comments
a68e472 <Neville Dipale> fix comment
3d27160 <Neville Dipale> ARROW-4865:  Support list casts
  • Loading branch information
nevi-me authored and andygrove committed Mar 15, 2019
1 parent afffe3a commit 90d665e
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 9 deletions.
2 changes: 1 addition & 1 deletion rust/arrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ pub type ArrayRef = Arc<Array>;

/// Constructs an array using the input `data`. Returns a reference-counted `Array`
/// instance.
fn make_array(data: ArrayDataRef) -> ArrayRef {
pub(crate) fn make_array(data: ArrayDataRef) -> ArrayRef {
// TODO: here data_type() needs to clone the type - maybe add a type tag enum to
// avoid the cloning.
match data.data_type().clone() {
Expand Down
244 changes: 236 additions & 8 deletions rust/arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
// specific language governing permissions and limitations
// under the License.

//! Defines cast kernels for `ArrayRef`, allowing casting arrays between supported datatypes.
//! Defines cast kernels for `ArrayRef`, allowing casting arrays between supported
//! datatypes.
//!
//! Example:
//!
Expand All @@ -37,6 +38,8 @@
use std::sync::Arc;

use crate::array::*;
use crate::array_data::ArrayData;
use crate::buffer::Buffer;
use crate::builder::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
Expand All @@ -48,10 +51,12 @@ use crate::error::{ArrowError, Result};
/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings
/// in integer casts return null
/// * Numeric to boolean: 0 returns `false`, any other value returns `true`
/// * List to List: the underlying data type is cast
/// * Primitive to List: a list array with 1 value per slot is created
///
/// Unsupported Casts
/// * To or from `StructArray`
/// * To or from `ListArray`
/// * List to primitive
/// * Utf8 to boolean
pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
use DataType::*;
Expand All @@ -68,15 +73,60 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
(_, Struct(_)) => Err(ArrowError::ComputeError(
"Cannot cast to struct from other types".to_string(),
)),
(List(_), List(_)) => Err(ArrowError::ComputeError(
"Casting between lists not yet supported".to_string(),
)),
(List(_), List(ref to)) => {
let data = array.data_ref();
let underlying_array = make_array(data.child_data()[0].clone());
let cast_array = cast(&underlying_array, &to)?;
let array_data = ArrayData::new(
*to.clone(),
array.len(),
Some(cast_array.null_count()),
cast_array
.data()
.null_bitmap()
.clone()
.map(|bitmap| bitmap.bits),
array.offset(),
// reuse offset buffer
data.buffers().to_vec(),
vec![cast_array.data()],
);
let list = ListArray::from(Arc::new(array_data));
Ok(Arc::new(list) as ArrayRef)
}
(List(_), _) => Err(ArrowError::ComputeError(
"Cannot cast list to non-list data types".to_string(),
)),
(_, List(_)) => Err(ArrowError::ComputeError(
"Cannot cast primitive types to lists".to_string(),
)),
(_, List(ref to)) => {
// see ARROW-4886 for this limitation
if array.offset() != 0 {
return Err(ArrowError::ComputeError(
"Cast kernel does not yet support sliced (non-zero offset) arrays"
.to_string(),
));
}
// cast primitive to list's primitive
let cast_array = cast(array, &to)?;
// create offsets, where if array.len() = 2, we have [0,1,2]
let offsets: Vec<i32> = (0..array.len() as i32 + 1).collect();
let value_offsets = Buffer::from(offsets[..].to_byte_slice());
let list_data = ArrayData::new(
*to.clone(),
array.len(),
Some(cast_array.null_count()),
cast_array
.data()
.null_bitmap()
.clone()
.map(|bitmap| bitmap.bits),
0,
vec![value_offsets],
vec![cast_array.data()],
);
let list_array = Arc::new(ListArray::from(Arc::new(list_data))) as ArrayRef;

Ok(list_array)
}
(_, Boolean) => match from_type {
UInt8 => cast_numeric_to_bool::<UInt8Type>(array),
UInt16 => cast_numeric_to_bool::<UInt16Type>(array),
Expand Down Expand Up @@ -458,6 +508,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::buffer::Buffer;

#[test]
fn test_cast_i32_to_f64() {
Expand Down Expand Up @@ -486,6 +537,23 @@ mod tests {
assert_eq!(false, c.is_valid(4));
}

#[test]
fn test_cast_i32_to_u8_sliced() {
let a = Int32Array::from(vec![-5, 6, -7, 8, 100000000]);
let array = Arc::new(a) as ArrayRef;
assert_eq!(0, array.offset());
let array = array.slice(2, 3);
assert_eq!(2, array.offset());
let b = cast(&array, &DataType::UInt8).unwrap();
assert_eq!(3, b.len());
assert_eq!(0, b.offset());
let c = b.as_any().downcast_ref::<UInt8Array>().unwrap();
assert_eq!(false, c.is_valid(0));
assert_eq!(8, c.value(1));
// overflows return None
assert_eq!(false, c.is_valid(2));
}

#[test]
fn test_cast_i32_to_i32() {
let a = Int32Array::from(vec![5, 6, 7, 8, 9]);
Expand All @@ -499,6 +567,88 @@ mod tests {
assert_eq!(9, c.value(4));
}

#[test]
fn test_cast_i32_to_list_i32() {
let a = Int32Array::from(vec![5, 6, 7, 8, 9]);
let array = Arc::new(a) as ArrayRef;
let b = cast(&array, &DataType::List(Box::new(DataType::Int32))).unwrap();
assert_eq!(5, b.len());
let arr = b.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(0, arr.value_offset(0));
assert_eq!(1, arr.value_offset(1));
assert_eq!(2, arr.value_offset(2));
assert_eq!(3, arr.value_offset(3));
assert_eq!(4, arr.value_offset(4));
assert_eq!(1, arr.value_length(0));
assert_eq!(1, arr.value_length(1));
assert_eq!(1, arr.value_length(2));
assert_eq!(1, arr.value_length(3));
assert_eq!(1, arr.value_length(4));
let values = arr.values();
let c = values.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(5, c.value(0));
assert_eq!(6, c.value(1));
assert_eq!(7, c.value(2));
assert_eq!(8, c.value(3));
assert_eq!(9, c.value(4));
}

#[test]
fn test_cast_i32_to_list_i32_nullable() {
let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]);
let array = Arc::new(a) as ArrayRef;
let b = cast(&array, &DataType::List(Box::new(DataType::Int32))).unwrap();
assert_eq!(5, b.len());
assert_eq!(1, b.null_count());
let arr = b.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(0, arr.value_offset(0));
assert_eq!(1, arr.value_offset(1));
assert_eq!(2, arr.value_offset(2));
assert_eq!(3, arr.value_offset(3));
assert_eq!(4, arr.value_offset(4));
assert_eq!(1, arr.value_length(0));
assert_eq!(1, arr.value_length(1));
assert_eq!(1, arr.value_length(2));
assert_eq!(1, arr.value_length(3));
assert_eq!(1, arr.value_length(4));
let values = arr.values();
let c = values.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(1, c.null_count());
assert_eq!(5, c.value(0));
assert_eq!(false, c.is_valid(1));
assert_eq!(7, c.value(2));
assert_eq!(8, c.value(3));
assert_eq!(9, c.value(4));
}

#[test]
#[should_panic(
expected = "Cast kernel does not yet support sliced (non-zero offset) arrays"
)]
fn test_cast_i32_to_list_i32_nullable_sliced() {
let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), None]);
let array = Arc::new(a) as ArrayRef;
let array = array.slice(2, 3);
let b = cast(&array, &DataType::List(Box::new(DataType::Int32))).unwrap();
assert_eq!(3, b.len());
assert_eq!(1, b.null_count());
let arr = b.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(0, arr.value_offset(0));
assert_eq!(1, arr.value_offset(1));
assert_eq!(2, arr.value_offset(2));
assert_eq!(1, arr.value_length(0));
assert_eq!(1, arr.value_length(1));
assert_eq!(1, arr.value_length(2));
let values = arr.values();
let c = values.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(1, c.null_count());
assert_eq!(7, c.value(0));
assert_eq!(8, c.value(1));
// if one removes the non-zero-offset limitation, this assertion passes when it
// shouldn't
assert_eq!(0, c.value(2));
}

#[test]
fn test_cast_utf_to_i32() {
let a = BinaryArray::from(vec!["5", "6", "seven", "8", "9.1"]);
Expand Down Expand Up @@ -543,4 +693,82 @@ mod tests {
let array = Arc::new(a) as ArrayRef;
cast(&array, &DataType::Timestamp(TimeUnit::Microsecond)).unwrap();
}

#[test]
fn test_cast_list_i32_to_list_u16() {
// Construct a value array
let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).data();

let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice());

// Construct a list array from the above two
let list_data_type = DataType::List(Box::new(DataType::Int32));
let list_data = ArrayData::builder(list_data_type.clone())
.len(3)
.add_buffer(value_offsets.clone())
.add_child_data(value_data.clone())
.build();
let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef;

let cast_array =
cast(&list_array, &DataType::List(Box::new(DataType::UInt16))).unwrap();
// 3 negative values should get lost when casting to unsigned,
// 1 value should overflow
assert_eq!(4, cast_array.null_count());
// offsets should be the same
assert_eq!(
list_array.data().buffers().to_vec(),
cast_array.data().buffers().to_vec()
);
let array = cast_array
.as_ref()
.as_any()
.downcast_ref::<ListArray>()
.unwrap();
assert_eq!(DataType::UInt16, array.value_type());
assert_eq!(4, array.values().null_count());
assert_eq!(3, array.value_length(0));
assert_eq!(3, array.value_length(1));
assert_eq!(2, array.value_length(2));
let values = array.values();
let u16arr = values.as_any().downcast_ref::<UInt16Array>().unwrap();
assert_eq!(8, u16arr.len());
assert_eq!(4, u16arr.null_count());

assert_eq!(0, u16arr.value(0));
assert_eq!(0, u16arr.value(1));
assert_eq!(0, u16arr.value(2));
assert_eq!(false, u16arr.is_valid(3));
assert_eq!(false, u16arr.is_valid(4));
assert_eq!(false, u16arr.is_valid(5));
assert_eq!(2, u16arr.value(6));
assert_eq!(false, u16arr.is_valid(7));
}

#[test]
#[should_panic(
expected = "Casting from Int32 to Timestamp(Microsecond) not supported"
)]
fn test_cast_list_i32_to_list_timestamp() {
// Construct a value array
let value_data =
Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 8, 100000000]).data();

let value_offsets = Buffer::from(&[0, 3, 6, 9].to_byte_slice());

// Construct a list array from the above two
let list_data_type = DataType::List(Box::new(DataType::Int32));
let list_data = ArrayData::builder(list_data_type.clone())
.len(3)
.add_buffer(value_offsets.clone())
.add_child_data(value_data.clone())
.build();
let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef;

cast(
&list_array,
&DataType::List(Box::new(DataType::Timestamp(TimeUnit::Microsecond))),
)
.unwrap();
}
}

0 comments on commit 90d665e

Please sign in to comment.