Skip to content

Commit

Permalink
fix(rust, python): no cse in groupby until fixed (pola-rs#10216)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 1, 2023
1 parent 1c0ad40 commit f4f3f98
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 38 deletions.
77 changes: 39 additions & 38 deletions crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,44 +526,45 @@ impl<'a> RewritingVisitor for CommonSubExprOptimizer<'a> {
node.replace(lp);
}
}
ALogicalPlan::Aggregate {
input,
keys,
aggs,
options,
maintain_order,
apply,
schema,
} => {
if let Some(aggs) =
self.find_cse(aggs, &mut expr_arena, &mut id_array_offsets, true)?
{
let keys = keys.clone();
let options = options.clone();
let schema = schema.clone();
let apply = apply.clone();
let maintain_order = *maintain_order;
let input = *input;

let input = node.with_arena_mut(|lp_arena| {
let lp = ALogicalPlanBuilder::new(input, &mut expr_arena, lp_arena)
.with_columns(aggs.cse_exprs().to_vec())
.build();
lp_arena.add(lp)
});

let lp = ALogicalPlan::Aggregate {
input,
keys,
aggs: aggs.default_exprs().to_vec(),
options,
schema,
maintain_order,
apply,
};
node.replace(lp);
}
}
// TODO! activate once fixed
// ALogicalPlan::Aggregate {
// input,
// keys,
// aggs,
// options,
// maintain_order,
// apply,
// schema,
// } => {
// if let Some(aggs) =
// self.find_cse(aggs, &mut expr_arena, &mut id_array_offsets, true)?
// {
// let keys = keys.clone();
// let options = options.clone();
// let schema = schema.clone();
// let apply = apply.clone();
// let maintain_order = *maintain_order;
// let input = *input;
//
// let input = node.with_arena_mut(|lp_arena| {
// let lp = ALogicalPlanBuilder::new(input, &mut expr_arena, lp_arena)
// .with_columns(aggs.cse_exprs().to_vec())
// .build();
// lp_arena.add(lp)
// });
//
// let lp = ALogicalPlan::Aggregate {
// input,
// keys,
// aggs: aggs.default_exprs().to_vec(),
// options,
// schema,
// maintain_order,
// apply,
// };
// node.replace(lp);
// }
// }
_ => {}
};
std::mem::swap(self.expr_arena, &mut expr_arena);
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def test_cse_expr_selection_streaming(monkeypatch: Any, capfd: Any) -> None:
assert "df -> hstack[cse] -> ordered_sink" in err


@pytest.mark.skip(reason="activate once fixed")
def test_cse_expr_groupby() -> None:
q = pl.LazyFrame(
{
Expand Down Expand Up @@ -269,3 +270,24 @@ def test_windows_cse_excluded() -> None:
"c_diff": [None, 2, -2, 1, 1, 1, -4],
"c_diff_by_a": [None, 2, -2, None, 1, 1, None],
}


def test_cse_groupby_10215() -> None:
assert (
pl.DataFrame(
{
"a": [1, 2, 3],
"b": [1, 1, 1],
}
)
.lazy()
.groupby(
"a",
)
.agg(
(pl.col("a").sum() * pl.col("a").sum()).alias("x"),
(pl.col("b").sum() * pl.col("b").sum()).alias("y"),
)
.collect()
.sort("a")
).to_dict(False) == {"a": [1, 2, 3], "x": [1, 4, 9], "y": [1, 1, 1]}

0 comments on commit f4f3f98

Please sign in to comment.