Skip to content

Commit

Permalink
Compute timestamp diff for binning in projection
Browse files: browse the repository at this point in the history
  • Loading branch information
ankoh committed Oct 4, 2024
1 parent eedd029 commit 3e75efb
Showing 1 changed file with 37 additions and 20 deletions.
57 changes: 37 additions & 20 deletions packages/sqlynx-compute/src/datafusion_binning_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow::array::{ArrayRef, AsArray, Int32Array, RecordBatch, TimestampMillisecondArray};
use arrow::datatypes::{Field, SchemaBuilder, TimestampMillisecondType};
use arrow::datatypes::{Field, Int64Type, SchemaBuilder, TimestampMillisecondType};
use arrow::datatypes::DataType;
use arrow::datatypes::TimeUnit;
use arrow::util::pretty::pretty_format_batches;
Expand All @@ -14,7 +14,7 @@ use datafusion_physical_expr::expressions::{binary, CastExpr};
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::expressions::lit;
use datafusion_physical_plan::aggregates::{AggregateMode, AggregateExec, PhysicalGroupBy};
use datafusion_physical_plan::collect;
use datafusion_physical_plan::{collect, ExecutionPlan};
use datafusion_physical_plan::memory::MemoryExec;
use datafusion_physical_plan::projection::ProjectionExec;
use indoc::indoc;
Expand Down Expand Up @@ -51,14 +51,14 @@ async fn test_bin_timestamps() -> anyhow::Result<()> {
let udaf_min = Arc::new(AggregateUDF::new_from_impl(Min::new()));
let udaf_max = Arc::new(AggregateUDF::new_from_impl(Max::new()));

let when_min = AggregateExprBuilder::new(udaf_min, vec![col_when.clone()])
.schema(data.schema())
.alias("min")
.build()?;
let when_max = AggregateExprBuilder::new(udaf_max, vec![col_when.clone()])
.schema(data.schema())
.alias("max")
.build()?;
let when_min = AggregateExprBuilder::new(udaf_min, vec![col_when.clone()])
.schema(data.schema())
.alias("min")
.build()?;

let data_scan = Arc::new(
MemoryExec::try_new(&[vec![data.clone()]], data.schema(), None)?,
Expand All @@ -71,26 +71,43 @@ async fn test_bin_timestamps() -> anyhow::Result<()> {
data_scan.clone(),
data.schema()
)?);
let minmax_diff = binary(
col("max", &groupby_exec.schema())?,
Operator::Minus,
col("min", &groupby_exec.schema())?,
&groupby_exec.schema())?;
let minmax_ms = Arc::new(CastExpr::new(minmax_diff.clone(), DataType::Int64, None));

let projection_exec = Arc::new(ProjectionExec::try_new(
vec![
(col("min", &groupby_exec.schema())?, "min".to_string()),
(col("max", &groupby_exec.schema())?, "max".to_string()),
(minmax_diff, "diff".to_string()),
(minmax_ms, "diff_ms".to_string())
],
groupby_exec
)?);

let task_ctx = Arc::new(TaskContext::default());
let minmax_result = collect(groupby_exec, task_ctx.clone()).await?;
let minmax_result = collect(projection_exec, task_ctx.clone()).await?;

assert_eq!(minmax_result.len(), 1);
assert_eq!(minmax_result[0].num_columns(), 2);
assert_eq!(minmax_result[0].num_rows(), 1);
assert_eq!(format!("{}", pretty_format_batches(&minmax_result)?), indoc! {"
+---------------------+---------------------+
| min | max |
+---------------------+---------------------+
| 2024-04-01T12:00:00 | 2024-04-03T12:00:00 |
+---------------------+---------------------+
+---------------------+---------------------+-----------+-----------+
| min | max | diff | diff_ms |
+---------------------+---------------------+-----------+-----------+
| 2024-04-01T12:00:00 | 2024-04-03T12:00:00 | PT172800S | 172800000 |
+---------------------+---------------------+-----------+-----------+
"}.trim());

let min_col = minmax_result[0].column(0).as_primitive::<TimestampMillisecondType>();
let max_col = minmax_result[0].column(1).as_primitive::<TimestampMillisecondType>();
let diff_ms_col = minmax_result[0].column(3).as_primitive::<Int64Type>();
assert_eq!(min_col.value_as_datetime(0).unwrap().to_string(), "2024-04-01 12:00:00");
assert_eq!(max_col.value_as_datetime(0).unwrap().to_string(), "2024-04-03 12:00:00");
assert_eq!(diff_ms_col.value(0), 172800000);

let range_millis = max_col.value(0) - min_col.value(0);
let range_millis = diff_ms_col.value(0);
let bin_millis = range_millis / 100;

// For now, we may be able to avoid date_part entirely by casting the timestamp difference to Int64 milliseconds.
Expand All @@ -99,14 +116,14 @@ async fn test_bin_timestamps() -> anyhow::Result<()> {
// let diff_extract = Arc::new(ScalarFunctionExpr::new("min_ms", date_part_udf.clone(), vec![lit("day"), time_diff.clone()], DataType::Float64));
// let time_diff_cast = Arc::new(CastExpr::new(time_diff.clone(), DataType::Int64, None));

let when_diff_min = binary(
let when_diff = binary(
col("when", data.schema_ref())?,
Operator::Minus,
lit(ScalarValue::TimestampMillisecond(Some(min_col.value(0)), None)),
data.schema_ref())?;
let when_diff_min_cast = Arc::new(CastExpr::new(when_diff_min, DataType::Int64, None));
let when_bin = binary(
when_diff_min_cast,
let when_diff_ms = Arc::new(CastExpr::new(when_diff, DataType::Int64, None));
let when_diff_bin = binary(
when_diff_ms,
Operator::Divide,
lit(ScalarValue::Int64(Some(bin_millis))),
data.schema_ref())?;
Expand All @@ -117,7 +134,7 @@ async fn test_bin_timestamps() -> anyhow::Result<()> {
let projection_exec = Arc::new(ProjectionExec::try_new(
vec![
(col("id", data.schema_ref())?, "id".to_string()),
(when_bin, "bin".to_string())
(when_diff_bin, "bin".to_string())
],
data_scan
)?);
Expand Down

0 comments on commit 3e75efb

Please sign in to comment.