@@ -273,9 +273,7 @@ impl CommonSubexprEliminate {
273
273
274
274
let mut proj_exprs = vec ! [ ] ;
275
275
for expr in & new_group_expr {
276
- let out_col: Column =
277
- expr. to_field ( & new_input_schema) ?. qualified_column ( ) ;
278
- proj_exprs. push ( Expr :: Column ( out_col) ) ;
276
+ extract_expressions ( expr, & new_input_schema, & mut proj_exprs) ?
279
277
}
280
278
for ( expr_rewritten, expr_orig) in rewritten. into_iter ( ) . zip ( new_aggr_expr) {
281
279
if expr_rewritten == expr_orig {
@@ -488,6 +486,22 @@ fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalP
488
486
)
489
487
}
490
488
489
+ fn extract_expressions (
490
+ expr : & Expr ,
491
+ schema : & DFSchema ,
492
+ result : & mut Vec < Expr > ,
493
+ ) -> Result < ( ) > {
494
+ if let Expr :: GroupingSet ( groupings) = expr {
495
+ for e in groupings. distinct_expr ( ) {
496
+ result. push ( Expr :: Column ( e. to_field ( schema) ?. qualified_column ( ) ) )
497
+ }
498
+ } else {
499
+ result. push ( Expr :: Column ( expr. to_field ( schema) ?. qualified_column ( ) ) ) ;
500
+ }
501
+
502
+ Ok ( ( ) )
503
+ }
504
+
491
505
/// Which type of [expressions](Expr) should be considered for rewriting?
492
506
#[ derive( Debug , Clone , Copy ) ]
493
507
enum ExprMask {
@@ -773,8 +787,8 @@ mod test {
773
787
avg, col, lit, logical_plan:: builder:: LogicalPlanBuilder , sum,
774
788
} ;
775
789
use datafusion_expr:: {
776
- AccumulatorFactoryFunction , AggregateUDF , ReturnTypeFunction , Signature ,
777
- StateTypeFunction , Volatility ,
790
+ grouping_set , AccumulatorFactoryFunction , AggregateUDF , ReturnTypeFunction ,
791
+ Signature , StateTypeFunction , Volatility ,
778
792
} ;
779
793
780
794
use crate :: optimizer:: OptimizerContext ;
@@ -1251,4 +1265,52 @@ mod test {
1251
1265
1252
1266
Ok ( ( ) )
1253
1267
}
1268
+
1269
+ #[ test]
1270
+ fn test_extract_expressions_from_grouping_set ( ) -> Result < ( ) > {
1271
+ let mut result = Vec :: with_capacity ( 3 ) ;
1272
+ let grouping = grouping_set ( vec ! [ vec![ col( "a" ) , col( "b" ) ] , vec![ col( "c" ) ] ] ) ;
1273
+ let schema = DFSchema :: new_with_metadata (
1274
+ vec ! [
1275
+ DFField :: new_unqualified( "a" , DataType :: Int32 , false ) ,
1276
+ DFField :: new_unqualified( "b" , DataType :: Int32 , false ) ,
1277
+ DFField :: new_unqualified( "c" , DataType :: Int32 , false ) ,
1278
+ ] ,
1279
+ HashMap :: default ( ) ,
1280
+ ) ?;
1281
+ extract_expressions ( & grouping, & schema, & mut result) ?;
1282
+
1283
+ assert ! ( result. len( ) == 3 ) ;
1284
+ Ok ( ( ) )
1285
+ }
1286
+
1287
+ #[ test]
1288
+ fn test_extract_expressions_from_grouping_set_with_identical_expr ( ) -> Result < ( ) > {
1289
+ let mut result = Vec :: with_capacity ( 2 ) ;
1290
+ let grouping = grouping_set ( vec ! [ vec![ col( "a" ) , col( "b" ) ] , vec![ col( "a" ) ] ] ) ;
1291
+ let schema = DFSchema :: new_with_metadata (
1292
+ vec ! [
1293
+ DFField :: new_unqualified( "a" , DataType :: Int32 , false ) ,
1294
+ DFField :: new_unqualified( "b" , DataType :: Int32 , false ) ,
1295
+ ] ,
1296
+ HashMap :: default ( ) ,
1297
+ ) ?;
1298
+ extract_expressions ( & grouping, & schema, & mut result) ?;
1299
+
1300
+ assert ! ( result. len( ) == 2 ) ;
1301
+ Ok ( ( ) )
1302
+ }
1303
+
1304
+ #[ test]
1305
+ fn test_extract_expressions_from_col ( ) -> Result < ( ) > {
1306
+ let mut result = Vec :: with_capacity ( 1 ) ;
1307
+ let schema = DFSchema :: new_with_metadata (
1308
+ vec ! [ DFField :: new_unqualified( "a" , DataType :: Int32 , false ) ] ,
1309
+ HashMap :: default ( ) ,
1310
+ ) ?;
1311
+ extract_expressions ( & col ( "a" ) , & schema, & mut result) ?;
1312
+
1313
+ assert ! ( result. len( ) == 1 ) ;
1314
+ Ok ( ( ) )
1315
+ }
1254
1316
}
0 commit comments