Skip to content

Commit a76b09e

Browse files
fsdvh and andygrove authored
Properly project grouping set expressions (apache#6777)
* update version to 26.0.0 * update Cargo.lock * changelog * prettier * update changelog * VTX-1613: update ignore rule * VTX-1613: revert * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: handle grouping sets * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: debug * VTX-1613: cleanup & fix for grouping set * VTX-1613: cleanup * VTX-1613: cleanup * VTX-1613: cleanup * VTX-1613: cleanup * VTX-1613: cleanup * VTX-1613: cleanup * VTX-1613: cleanup * VTX-1613: fix import --------- Co-authored-by: Andy Grove <[email protected]>
1 parent b9ecfc5 commit a76b09e

File tree

2 files changed

+71
-9
lines changed

2 files changed

+71
-9
lines changed

datafusion-cli/Cargo.lock

+4-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/optimizer/src/common_subexpr_eliminate.rs

+67-5
Original file line number | Diff line number | Diff line change
@@ -273,9 +273,7 @@ impl CommonSubexprEliminate {
273273

274274
let mut proj_exprs = vec![];
275275
for expr in &new_group_expr {
276-
let out_col: Column =
277-
expr.to_field(&new_input_schema)?.qualified_column();
278-
proj_exprs.push(Expr::Column(out_col));
276+
extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
279277
}
280278
for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) {
281279
if expr_rewritten == expr_orig {
@@ -488,6 +486,22 @@ fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalP
488486
)
489487
}
490488

489+
fn extract_expressions(
490+
expr: &Expr,
491+
schema: &DFSchema,
492+
result: &mut Vec<Expr>,
493+
) -> Result<()> {
494+
if let Expr::GroupingSet(groupings) = expr {
495+
for e in groupings.distinct_expr() {
496+
result.push(Expr::Column(e.to_field(schema)?.qualified_column()))
497+
}
498+
} else {
499+
result.push(Expr::Column(expr.to_field(schema)?.qualified_column()));
500+
}
501+
502+
Ok(())
503+
}
504+
491505
/// Which type of [expressions](Expr) should be considered for rewriting?
492506
#[derive(Debug, Clone, Copy)]
493507
enum ExprMask {
@@ -773,8 +787,8 @@ mod test {
773787
avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
774788
};
775789
use datafusion_expr::{
776-
AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
777-
StateTypeFunction, Volatility,
790+
grouping_set, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction,
791+
Signature, StateTypeFunction, Volatility,
778792
};
779793

780794
use crate::optimizer::OptimizerContext;
@@ -1251,4 +1265,52 @@ mod test {
12511265

12521266
Ok(())
12531267
}
1268+
1269+
/// A grouping set over `(a, b)` and `(c)` must produce three projected expressions.
#[test]
fn test_extract_expressions_from_grouping_set() -> Result<()> {
    let mut result = Vec::with_capacity(3);
    let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
    let schema = DFSchema::new_with_metadata(
        vec![
            DFField::new_unqualified("a", DataType::Int32, false),
            DFField::new_unqualified("b", DataType::Int32, false),
            DFField::new_unqualified("c", DataType::Int32, false),
        ],
        HashMap::default(),
    )?;
    extract_expressions(&grouping, &schema, &mut result)?;

    // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
    assert_eq!(result.len(), 3);
    Ok(())
}
1286+
1287+
/// A grouping set that mentions `a` in two groups, `(a, b)` and `(a)`, must
/// produce only two projected expressions.
#[test]
fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
    let mut result = Vec::with_capacity(2);
    let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
    let schema = DFSchema::new_with_metadata(
        vec![
            DFField::new_unqualified("a", DataType::Int32, false),
            DFField::new_unqualified("b", DataType::Int32, false),
        ],
        HashMap::default(),
    )?;
    extract_expressions(&grouping, &schema, &mut result)?;

    // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
    assert_eq!(result.len(), 2);
    Ok(())
}
1303+
1304+
/// A plain column expression must produce exactly one projected expression.
#[test]
fn test_extract_expressions_from_col() -> Result<()> {
    let mut result = Vec::with_capacity(1);
    let schema = DFSchema::new_with_metadata(
        vec![DFField::new_unqualified("a", DataType::Int32, false)],
        HashMap::default(),
    )?;
    extract_expressions(&col("a"), &schema, &mut result)?;

    // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
    assert_eq!(result.len(), 1);
    Ok(())
}
12541316
}

0 commit comments

Comments
 (0)