New Transform for GROUPING SETS.
RinChanNOWWW committed Mar 15, 2023
1 parent f154f15 commit fea23fc
Showing 11 changed files with 198 additions and 6 deletions.
13 changes: 10 additions & 3 deletions src/query/expression/src/values.rs
@@ -238,6 +238,13 @@ impl Value<AnyType> {
pub fn try_downcast<T: ValueType>(&self) -> Option<Value<T>> {
Some(self.as_ref().try_downcast::<T>()?.to_owned())
}

pub fn wrap_nullable(&self) -> Self {
match self {
Value::Column(c) => Value::Column(c.wrap_nullable()),
scalar => scalar.clone(),
}
}
}

impl<'a> ValueRef<'a, AnyType> {
@@ -1479,14 +1486,14 @@ impl Column {
}
}

- pub fn wrap_nullable(self) -> Self {
+ pub fn wrap_nullable(&self) -> Self {
match self {
- col @ Column::Nullable(_) => col,
+ col @ Column::Nullable(_) => col.clone(),
col => {
let mut validity = MutableBitmap::with_capacity(col.len());
validity.extend_constant(col.len(), true);
Column::Nullable(Box::new(NullableColumn {
- column: col,
+ column: col.clone(),
validity: validity.into(),
}))
}
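
The `wrap_nullable` change above switches the receiver from `self` to `&self`, so callers keep ownership of the column and already-nullable columns are returned as clones. A minimal standalone sketch of the same pattern, using simplified stand-in types rather than the real `common_expression` API:

    // Simplified stand-ins for Column / NullableColumn, for illustration only;
    // the real code tracks validity with an arrow2 bitmap, not a Vec<bool>.
    #[derive(Clone, Debug)]
    enum Column {
        Int(Vec<i64>),
        Nullable(Box<NullableColumn>),
    }

    #[derive(Clone, Debug)]
    struct NullableColumn {
        column: Column,
        validity: Vec<bool>,
    }

    impl Column {
        fn len(&self) -> usize {
            match self {
                Column::Int(v) => v.len(),
                Column::Nullable(n) => n.column.len(),
            }
        }

        // Borrowing `&self`: an already-nullable column is cloned as-is;
        // anything else is wrapped with an all-true validity mask.
        fn wrap_nullable(&self) -> Self {
            match self {
                col @ Column::Nullable(_) => col.clone(),
                col => Column::Nullable(Box::new(NullableColumn {
                    column: col.clone(),
                    validity: vec![true; col.len()],
                })),
            }
        }
    }

    fn main() {
        let c = Column::Int(vec![1, 2, 3]);
        let n = c.wrap_nullable(); // `c` stays usable after the call
        println!("{n:?}");
    }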
60 changes: 58 additions & 2 deletions src/query/service/src/pipelines/pipeline_builder.rs
@@ -66,6 +66,7 @@ use common_sql::IndexType;
use common_storage::DataOperator;

use super::processors::ProfileWrapper;
use super::processors::TransformExpandGroupingSets;
use crate::api::DefaultExchangeInjector;
use crate::api::ExchangeInjector;
use crate::pipelines::processors::transforms::build_partition_bucket;
@@ -424,8 +425,63 @@ impl PipelineBuilder {
})
}

- fn build_aggregate_expand(&mut self, _aggregate: &AggregateExpand) -> Result<()> {
- todo!()
+ fn build_aggregate_expand(&mut self, expand: &AggregateExpand) -> Result<()> {
self.build_pipeline(&expand.input)?;
let input_schema = expand.input.output_schema()?;
let group_bys = expand
.group_bys
.iter()
.filter_map(|i| {
// Do not collect the virtual column "_grouping_id".
if *i != expand.grouping_id_index {
match input_schema.index_of(&i.to_string()) {
Ok(index) => {
let ty = input_schema.field(index).data_type().clone();
Some(Ok((index, ty)))
}
Err(e) => Some(Err(e)),
}
} else {
None
}
})
.collect::<Result<Vec<_>>>()?;
let grouping_sets = expand
.grouping_sets
.iter()
.map(|sets| {
sets.iter()
.map(|i| {
let i = input_schema.index_of(&i.to_string())?;
let offset = group_bys.iter().position(|(j, _)| *j == i).unwrap();
Ok(offset)
})
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()?;
let mut grouping_ids = Vec::with_capacity(grouping_sets.len());
for set in grouping_sets {
let mut id = 0;
for i in set {
id |= 1 << i;
}
// For each element in `group_bys`:
// if it is in the current grouping set, its bit is 0; otherwise it is 1
// (a 1 bit means the column will be NULL in this grouping).
// Example: GROUP BY GROUPING SETS ((a, b), (a), (b), ())
// group_bys: [a, b]
// grouping_sets: [[0, 1], [0], [1], []]
// grouping_ids: 00, 01, 10, 11 (bits listed in [a, b] order)
grouping_ids.push(!id);
}

self.main_pipeline.add_transform(|input, output| {
Ok(TransformExpandGroupingSets::create(
input,
output,
group_bys.clone(),
grouping_ids.clone(),
))
})
}

fn build_aggregate_partial(&mut self, aggregate: &AggregatePartial) -> Result<()> {
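
The grouping-id computation in `build_aggregate_expand` can be exercised in isolation. A minimal sketch reproducing the bit logic for the example in the source comment, GROUPING SETS ((a, b), (a), (b), ()):

    fn main() {
        // group_bys = [a, b], so a is bit 0 and b is bit 1.
        let grouping_sets: Vec<Vec<usize>> = vec![vec![0, 1], vec![0], vec![1], vec![]];
        let mut grouping_ids: Vec<u32> = Vec::with_capacity(grouping_sets.len());
        for set in &grouping_sets {
            let mut id: u32 = 0;
            for &i in set {
                id |= 1 << i; // bit i = 1: column i is in this grouping set
            }
            // Negate, so that a 1 bit now means "NULL in this grouping".
            grouping_ids.push(!id);
        }
        for (set, gid) in grouping_sets.iter().zip(&grouping_ids) {
            // Show the NULL flag per group-by column, in [a, b] order.
            let flags: Vec<u32> = (0..2).map(|i| (gid >> i) & 1).collect();
            println!("set {set:?} -> null flags {flags:?}");
        }
        // Prints [0, 0], [0, 1], [1, 0], [1, 1] -- the 00/01/10/11 ids
        // listed in the source comment.
    }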
1 change: 1 addition & 0 deletions src/query/service/src/pipelines/processors/mod.rs
@@ -33,6 +33,7 @@ pub use transforms::TransformBlockCompact;
pub use transforms::TransformCastSchema;
pub use transforms::TransformCompact;
pub use transforms::TransformCreateSets;
pub use transforms::TransformExpandGroupingSets;
pub use transforms::TransformHashJoinProbe;
pub use transforms::TransformLimit;
pub use transforms::TransformResortAddOn;
@@ -17,6 +17,7 @@ mod aggregate_exchange_injector;
mod aggregate_meta;
mod aggregator_params;
mod serde;
mod transform_aggregate_expand;
mod transform_aggregate_final;
mod transform_aggregate_partial;
mod transform_group_by_final;
@@ -29,6 +30,7 @@ pub use aggregate_cell::HashTableCell;
pub use aggregate_cell::PartitionedHashTableDropper;
pub use aggregate_exchange_injector::AggregateInjector;
pub use aggregator_params::AggregatorParams;
pub use transform_aggregate_expand::TransformExpandGroupingSets;
pub use transform_aggregate_final::TransformFinalAggregate;
pub use transform_aggregate_partial::TransformPartialAggregate;
pub use transform_group_by_final::TransformFinalGroupBy;
@@ -0,0 +1,98 @@
// Copyright 2023 Datafuse Labs.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_exception::Result;
use common_expression::types::DataType;
use common_expression::types::NumberDataType;
use common_expression::types::NumberScalar;
use common_expression::BlockEntry;
use common_expression::DataBlock;
use common_expression::Scalar;
use common_expression::Value;
use common_pipeline_core::processors::port::InputPort;
use common_pipeline_core::processors::port::OutputPort;
use common_pipeline_core::processors::processor::ProcessorPtr;
use common_pipeline_transforms::processors::transforms::Transform;
use common_pipeline_transforms::processors::transforms::Transformer;

pub struct TransformExpandGroupingSets {
group_bys: Vec<(usize, DataType)>,
grouping_ids: Vec<usize>,
}

impl TransformExpandGroupingSets {
pub fn create(
input: Arc<InputPort>,
output: Arc<OutputPort>,
group_bys: Vec<(usize, DataType)>,
grouping_ids: Vec<usize>,
) -> ProcessorPtr {
ProcessorPtr::create(Transformer::create(
input,
output,
TransformExpandGroupingSets {
grouping_ids,
group_bys,
},
))
}
}

impl Transform for TransformExpandGroupingSets {
const NAME: &'static str = "TransformExpandGroupingSets";

fn transform(&mut self, data: DataBlock) -> Result<DataBlock> {
let num_rows = data.num_rows();
let num_group_bys = self.group_bys.len();
let mut output_blocks = Vec::with_capacity(self.grouping_ids.len());

for &id in &self.grouping_ids {
// Repeat data for each grouping set.
let grouping_column = BlockEntry {
data_type: DataType::Number(NumberDataType::UInt32),
value: Value::Scalar(Scalar::Number(NumberScalar::UInt32(id as u32))),
};
let mut columns = data
.columns()
.iter()
.cloned()
.chain(vec![grouping_column])
.collect::<Vec<_>>();
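// Undo the negation from the pipeline builder: after this, bit i = 1
// means group-by column i is in the current grouping set (kept as-is),
// and bit i = 0 means it must be replaced with NULL.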
let bits = !id;
for i in 0..num_group_bys {
let entry = unsafe {
let offset = self.group_bys.get_unchecked(i).0;
columns.get_unchecked_mut(offset)
};
if bits & (1 << i) == 0 {
// This column should be set to NULLs.
*entry = BlockEntry {
data_type: entry.data_type.wrap_nullable(),
value: Value::Scalar(Scalar::Null),
}
} else {
*entry = BlockEntry {
data_type: entry.data_type.wrap_nullable(),
value: entry.value.wrap_nullable(),
}
}
}
output_blocks.push(DataBlock::new(columns, num_rows));
}

DataBlock::concat(&output_blocks)
}
}
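
Conceptually, the transform above emits one copy of the input block per grouping id, NULL-ing out each group-by column that is absent from the corresponding set, appending the id as an extra column, and finally concatenating all copies. A hypothetical sketch of that expansion on plain vectors (not the real `DataBlock` API; here a set bit directly means "make this column NULL"):

    fn main() {
        let rows = vec![("a", "A"), ("b", "B")];
        // One id per grouping set; bit 0 = NULL out column a, bit 1 = column b.
        let grouping_ids: [u32; 4] = [0b00, 0b01, 0b10, 0b11];
        let mut out: Vec<(Option<&str>, Option<&str>, u32)> = Vec::new();
        for &id in &grouping_ids {
            // Repeat every input row once per grouping set.
            for &(a, b) in &rows {
                out.push((
                    if id & 0b01 != 0 { None } else { Some(a) },
                    if id & 0b10 != 0 { None } else { Some(b) },
                    id, // stand-in for the appended _grouping_id column
                ));
            }
        }
        // 2 input rows x 4 grouping sets = 8 output rows.
        assert_eq!(out.len(), 8);
        for row in &out {
            println!("{row:?}");
        }
    }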
@@ -42,6 +42,7 @@ pub use aggregator::TransformAggregateDeserializer;
pub use aggregator::TransformAggregateSerializer;
pub use aggregator::TransformAggregateSpillReader;
pub use aggregator::TransformAggregateSpillWriter;
pub use aggregator::TransformExpandGroupingSets;
pub use aggregator::TransformFinalAggregate;
pub use aggregator::TransformGroupByDeserializer;
pub use aggregator::TransformGroupBySerializer;
1 change: 1 addition & 0 deletions src/query/sql/src/executor/physical_plan.rs
@@ -170,6 +170,7 @@ pub struct AggregateExpand {
pub plan_id: u32,

pub input: Box<PhysicalPlan>,
pub group_bys: Vec<usize>,
pub grouping_id_index: IndexType,
pub grouping_sets: Vec<Vec<usize>>,
/// Only used for explain
2 changes: 2 additions & 0 deletions src/query/sql/src/executor/physical_plan_builder.rs
@@ -478,6 +478,7 @@ impl PhysicalPlanBuilder {
let expand = AggregateExpand {
plan_id: self.next_plan_id(),
input,
group_bys: group_items.clone(),
grouping_id_index: agg.grouping_id_index,
grouping_sets: agg.grouping_sets.clone(),
stat_info: Some(stat_info.clone()),
@@ -528,6 +529,7 @@ impl PhysicalPlanBuilder {
let expand = AggregateExpand {
plan_id: self.next_plan_id(),
input: Box::new(input),
group_bys: group_items.clone(),
grouping_id_index: agg.grouping_id_index,
grouping_sets: agg.grouping_sets.clone(),
stat_info: Some(stat_info.clone()),
1 change: 1 addition & 0 deletions src/query/sql/src/executor/physical_plan_visitor.rs
@@ -100,6 +100,7 @@ pub trait PhysicalPlanReplacer {
Ok(PhysicalPlan::AggregateExpand(AggregateExpand {
plan_id: plan.plan_id,
input: Box::new(input),
group_bys: plan.group_bys.clone(),
grouping_id_index: plan.grouping_id_index,
grouping_sets: plan.grouping_sets.clone(),
stat_info: plan.stat_info.clone(),
4 changes: 3 additions & 1 deletion src/query/sql/src/planner/binder/aggregate.rs
@@ -24,6 +24,7 @@ use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::types::DataType;
use common_expression::types::NumberDataType;
use itertools::Itertools;

use super::prune_by_children;
use crate::binder::scalar::ScalarBinder;
@@ -371,7 +372,8 @@ impl Binder {
set.sort();
set
})
- .collect();
+ .collect::<Vec<_>>();
+ let grouping_sets = grouping_sets.into_iter().unique().collect();
bind_context.aggregate_info.grouping_sets = grouping_sets;
// Add a virtual column `_grouping_id` to group items.
let grouping_id_column = self.create_column_binding(
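
The binder change sorts each grouping set and then deduplicates with `Itertools::unique`, so permutations such as (a, b) and (b, a) collapse into one set. A small sketch of that normalization, assuming the `itertools` crate as a dependency:

    use itertools::Itertools;

    fn main() {
        // Column indices for GROUPING SETS ((b, a), (a, b), (a)).
        let mut sets = vec![vec![1, 0], vec![0, 1], vec![0]];
        for set in &mut sets {
            set.sort(); // sort first so permutations compare equal
        }
        let sets: Vec<Vec<usize>> = sets.into_iter().unique().collect();
        assert_eq!(sets, vec![vec![0, 1], vec![0]]);
    }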
@@ -0,0 +1,21 @@
statement ok
drop table if exists t;

statement ok
create table t (a string, b string, c int);

statement ok
insert into t values ('a','A',1),('a','A',2),('a','B',1),('a','B',3),('b','A',1),('b','A',4),('b','B',1),('b','B',5);

query TTI
select a, b, sum(c) as sc from t group by grouping sets ((a,b),(),(b),(a)) order by sc;
----
a A 3
a B 4
b A 5
b B 6
a NULL 7
NULL A 8
NULL B 10
b NULL 11
NULL NULL 18

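As a sanity check on the expected rows: the per-pair sums are (a, A) = 1 + 2 = 3, (a, B) = 1 + 3 = 4, (b, A) = 1 + 4 = 5, and (b, B) = 1 + 5 = 6; the single-column sets roll these up (a = 3 + 4 = 7, A = 3 + 5 = 8, B = 4 + 6 = 10, b = 5 + 6 = 11), and the empty set gives the grand total 18, matching the output above once ordered by sc.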