diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 7b0c8e97c63a1..e3d5820327ec2 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -996,10 +996,7 @@ mod tests { let my_add = ScalarFunction::new( "my_add", - vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ], + vec![DataType::Int32, DataType::Int32], DataType::Int32, myfunc, ); diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index 97098d65b5e3a..ea40ac5efc613 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -21,7 +21,7 @@ use crate::error::ExecutionError; use crate::execution::physical_plan::udf::ScalarFunction; use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use std::sync::Arc; @@ -29,7 +29,7 @@ macro_rules! math_unary_function { ($NAME:expr, $FUNC:ident) => { ScalarFunction::new( $NAME, - vec![Field::new("n", DataType::Float64, true)], + vec![DataType::Float64], DataType::Float64, Arc::new(|args: &[ArrayRef]| { let n = &args[0].as_any().downcast_ref::(); @@ -86,7 +86,7 @@ mod tests { execution::context::ExecutionContext, logicalplan::{col, sqrt, LogicalPlanBuilder}, }; - use arrow::datatypes::Schema; + use arrow::datatypes::{Field, Schema}; #[test] fn cast_i8_input() -> Result<()> { diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index f79907d2f0a28..780cdba395dbf 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -26,7 +26,7 @@ use crate::error::Result; use crate::execution::context::ExecutionContextState; use crate::logicalplan::{LogicalPlan, ScalarValue}; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::{ compute::kernels::length::length, record_batch::{RecordBatch, RecordBatchReader}, @@ -138,7 +138,7 @@ pub trait Accumulator: Debug { pub fn scalar_functions() -> Vec { let mut udfs = vec![ScalarFunction::new( "length", - vec![Field::new("n", DataType::Utf8, true)], + vec![DataType::Utf8], DataType::UInt32, Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))), )]; diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index fb777f9552448..ca5908748a13d 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -20,7 +20,7 @@ use std::fmt; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Schema}; use crate::error::Result; use crate::execution::physical_plan::PhysicalExpr; @@ -38,7 +38,7 @@ pub struct ScalarFunction { /// Function name pub name: String, /// Function argument meta-data - pub args: Vec, + pub arg_types: Vec, /// Return type pub return_type: DataType, /// UDF implementation @@ -61,7 +61,7 @@ impl Debug for ScalarFunction { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("ScalarFunction") .field("name", &self.name) - .field("args", &self.args) + .field("arg_types", &self.arg_types) .field("return_type", &self.return_type) .field("fun", &"") .finish() @@ -72,13 +72,13 @@ impl ScalarFunction { /// Create a new ScalarFunction pub fn new( name: &str, - args: Vec, + arg_types: Vec, return_type: DataType, fun: ScalarUdf, ) -> Self { Self { name: name.to_owned(), - args, + arg_types, return_type, fun, } diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 22d49b5ff65b4..65940baa34c38 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -69,9 +69,8 @@ where match self.scalar_functions.lookup(name) { Some(func_meta) => { for i in 0..expressions.len() { - let field = &func_meta.args[i]; let actual_type = expressions[i].get_type(schema)?; - let required_type = field.data_type(); + let required_type = &func_meta.arg_types[i]; if &actual_type != required_type { // attempt to coerce using numerical coercion // todo: also try string coercion. diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 8f366de0a4af9..f627d05cf4319 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -520,10 +520,8 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { let mut safe_args: Vec = vec![]; for i in 0..rex_args.len() { - safe_args.push( - rex_args[i] - .cast_to(fm.args[i].data_type(), schema)?, - ); + safe_args + .push(rex_args[i].cast_to(&fm.arg_types[i], schema)?); } Ok(Expr::ScalarFunction { @@ -908,7 +906,7 @@ mod tests { match name { "sqrt" => Some(Arc::new(ScalarFunction::new( "sqrt", - vec![Field::new("n", DataType::Float64, false)], + vec![DataType::Float64], DataType::Float64, Arc::new(|_| Err(ExecutionError::NotImplemented("".to_string()))), ))), diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 43ac27ad7fb8f..15120d754162e 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -220,7 +220,7 @@ fn create_ctx() -> Result { // register a custom UDF ctx.register_udf(ScalarFunction::new( "custom_sqrt", - vec![Field::new("n", DataType::Float64, true)], + vec![DataType::Float64], DataType::Float64, Arc::new(custom_sqrt), ));