Skip to content

Commit 727b6ff

Browse files
Dandandan, Daniël Heres, alamb, mustafasrepo
authored
Add fetch to SortPreservingMergeExec and SortPreservingMergeStream (apache#6811)
* Add fetch to sortpreservingmergeexec
* Add fetch to sortpreservingmergeexec
* fmt
* Deserialize
* Fmt
* Fix test
* Fix test
* Fix test
* Fix plan output
* Doc
* Update datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
  Co-authored-by: Andrew Lamb <[email protected]>
* Extract into method
* Remove from sort enforcement
* Update datafusion/core/src/physical_plan/sorts/merge.rs
  Co-authored-by: Mustafa Akur <[email protected]>
* Update datafusion/proto/src/physical_plan/mod.rs
  Co-authored-by: Mustafa Akur <[email protected]>
---------
Co-authored-by: Daniël Heres <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
1 parent c1698b7 commit 727b6ff

File tree

19 files changed

+103
-26
lines changed

19 files changed

+103
-26
lines changed

datafusion/core/src/physical_optimizer/global_sort_selection.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ impl PhysicalOptimizerRule for GlobalSortSelection {
7070
Arc::new(SortPreservingMergeExec::new(
7171
sort_exec.expr().to_vec(),
7272
Arc::new(sort),
73-
));
73+
).with_fetch(sort_exec.fetch()));
7474
Some(global_sort)
7575
} else {
7676
None

datafusion/core/src/physical_plan/repartition/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ impl ExecutionPlan for RepartitionExec {
497497
sort_exprs,
498498
BaselineMetrics::new(&self.metrics, partition),
499499
context.session_config().batch_size(),
500+
None,
500501
)
501502
} else {
502503
Ok(Box::pin(RepartitionStream {

datafusion/core/src/physical_plan/sorts/merge.rs

+31-7
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ macro_rules! primitive_merge_helper {
3939
}
4040

4141
macro_rules! merge_helper {
42-
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
42+
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident) => {{
4343
let streams = FieldCursorStream::<$t>::new($sort, $streams);
4444
return Ok(Box::pin(SortPreservingMergeStream::new(
4545
Box::new(streams),
4646
$schema,
4747
$tracking_metrics,
4848
$batch_size,
49+
$fetch,
4950
)));
5051
}};
5152
}
@@ -57,17 +58,18 @@ pub(crate) fn streaming_merge(
5758
expressions: &[PhysicalSortExpr],
5859
metrics: BaselineMetrics,
5960
batch_size: usize,
61+
fetch: Option<usize>,
6062
) -> Result<SendableRecordBatchStream> {
6163
// Special case single column comparisons with optimized cursor implementations
6264
if expressions.len() == 1 {
6365
let sort = expressions[0].clone();
6466
let data_type = sort.expr.data_type(schema.as_ref())?;
6567
downcast_primitive! {
66-
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size),
67-
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size)
68-
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size)
69-
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size)
70-
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size)
68+
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch),
69+
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch)
70+
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch)
71+
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch)
72+
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch)
7173
_ => {}
7274
}
7375
}
@@ -78,6 +80,7 @@ pub(crate) fn streaming_merge(
7880
schema,
7981
metrics,
8082
batch_size,
83+
fetch,
8184
)))
8285
}
8386

@@ -140,6 +143,12 @@ struct SortPreservingMergeStream<C> {
140143

141144
/// Vector that holds cursors for each non-exhausted input partition
142145
cursors: Vec<Option<C>>,
146+
147+
/// Optional number of rows to fetch
148+
fetch: Option<usize>,
149+
150+
/// number of rows produced
151+
produced: usize,
143152
}
144153

145154
impl<C: Cursor> SortPreservingMergeStream<C> {
@@ -148,6 +157,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
148157
schema: SchemaRef,
149158
metrics: BaselineMetrics,
150159
batch_size: usize,
160+
fetch: Option<usize>,
151161
) -> Self {
152162
let stream_count = streams.partitions();
153163

@@ -160,6 +170,8 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
160170
loser_tree: vec![],
161171
loser_tree_adjusted: false,
162172
batch_size,
173+
fetch,
174+
produced: 0,
163175
}
164176
}
165177

@@ -227,15 +239,27 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
227239
if self.advance(stream_idx) {
228240
self.loser_tree_adjusted = false;
229241
self.in_progress.push_row(stream_idx);
230-
if self.in_progress.len() < self.batch_size {
242+
243+
// stop sorting if fetch has been reached
244+
if self.fetch_reached() {
245+
self.aborted = true;
246+
} else if self.in_progress.len() < self.batch_size {
231247
continue;
232248
}
233249
}
234250

251+
self.produced += self.in_progress.len();
252+
235253
return Poll::Ready(self.in_progress.build_record_batch().transpose());
236254
}
237255
}
238256

257+
fn fetch_reached(&mut self) -> bool {
258+
self.fetch
259+
.map(|fetch| self.produced + self.in_progress.len() >= fetch)
260+
.unwrap_or(false)
261+
}
262+
239263
fn advance(&mut self, stream_idx: usize) -> bool {
240264
let slot = &mut self.cursors[stream_idx];
241265
match slot.as_mut() {

datafusion/core/src/physical_plan/sorts/sort.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ impl ExternalSorter {
189189
&self.expr,
190190
self.metrics.baseline.clone(),
191191
self.batch_size,
192+
self.fetch,
192193
)
193194
} else if !self.in_mem_batches.is_empty() {
194195
let result = self.in_mem_sort_stream(self.metrics.baseline.clone());
@@ -285,14 +286,13 @@ impl ExternalSorter {
285286
})
286287
.collect::<Result<_>>()?;
287288

288-
// TODO: Pushdown fetch to streaming merge (#6000)
289-
290289
streaming_merge(
291290
streams,
292291
self.schema.clone(),
293292
&self.expr,
294293
metrics,
295294
self.batch_size,
295+
self.fetch,
296296
)
297297
}
298298

datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs

+25-5
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ pub struct SortPreservingMergeExec {
7171
expr: Vec<PhysicalSortExpr>,
7272
/// Execution metrics
7373
metrics: ExecutionPlanMetricsSet,
74+
/// Optional number of rows to fetch. Merging stops once this many rows have been produced
75+
fetch: Option<usize>,
7476
}
7577

7678
impl SortPreservingMergeExec {
@@ -80,8 +82,14 @@ impl SortPreservingMergeExec {
8082
input,
8183
expr,
8284
metrics: ExecutionPlanMetricsSet::new(),
85+
fetch: None,
8386
}
8487
}
88+
/// Sets the number of rows to fetch
89+
pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
90+
self.fetch = fetch;
91+
self
92+
}
8593

8694
/// Input plan
8795
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
@@ -92,6 +100,11 @@ impl SortPreservingMergeExec {
92100
pub fn expr(&self) -> &[PhysicalSortExpr] {
93101
&self.expr
94102
}
103+
104+
/// Optional number of rows to fetch, or `None` if no limit is set
105+
pub fn fetch(&self) -> Option<usize> {
106+
self.fetch
107+
}
95108
}
96109

97110
impl ExecutionPlan for SortPreservingMergeExec {
@@ -137,10 +150,10 @@ impl ExecutionPlan for SortPreservingMergeExec {
137150
self: Arc<Self>,
138151
children: Vec<Arc<dyn ExecutionPlan>>,
139152
) -> Result<Arc<dyn ExecutionPlan>> {
140-
Ok(Arc::new(SortPreservingMergeExec::new(
141-
self.expr.clone(),
142-
children[0].clone(),
143-
)))
153+
Ok(Arc::new(
154+
SortPreservingMergeExec::new(self.expr.clone(), children[0].clone())
155+
.with_fetch(self.fetch),
156+
))
144157
}
145158

146159
fn execute(
@@ -192,6 +205,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
192205
&self.expr,
193206
BaselineMetrics::new(&self.metrics, partition),
194207
context.session_config().batch_size(),
208+
self.fetch,
195209
)?;
196210

197211
debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
@@ -209,7 +223,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
209223
match t {
210224
DisplayFormatType::Default | DisplayFormatType::Verbose => {
211225
let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
212-
write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
226+
write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
227+
if let Some(fetch) = self.fetch {
228+
write!(f, ", fetch={fetch}")?;
229+
};
230+
231+
Ok(())
213232
}
214233
}
215234
}
@@ -814,6 +833,7 @@ mod tests {
814833
sort.as_slice(),
815834
BaselineMetrics::new(&metrics, 0),
816835
task_ctx.session_config().batch_size(),
836+
None,
817837
)
818838
.unwrap();
819839

datafusion/core/tests/sql/explain_analyze.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ async fn test_physical_plan_display_indent() {
599599
let physical_plan = dataframe.create_physical_plan().await.unwrap();
600600
let expected = vec![
601601
"GlobalLimitExec: skip=0, fetch=10",
602-
" SortPreservingMergeExec: [the_min@2 DESC]",
602+
" SortPreservingMergeExec: [the_min@2 DESC], fetch=10",
603603
" SortExec: fetch=10, expr=[the_min@2 DESC]",
604604
" ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]",
605605
" AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",

datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ Limit: skip=0, fetch=10
7171
------------TableScan: nation projection=[n_nationkey, n_name]
7272
physical_plan
7373
GlobalLimitExec: skip=0, fetch=10
74-
--SortPreservingMergeExec: [revenue@2 DESC]
74+
--SortPreservingMergeExec: [revenue@2 DESC], fetch=10
7575
----SortExec: fetch=10, expr=[revenue@2 DESC]
7676
------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, c_phone@3 as c_phone, c_comment@6 as c_comment]
7777
--------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as n_name, c_address@5 as c_address, c_comment@6 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]

datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Limit: skip=0, fetch=10
7575
----------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
7676
physical_plan
7777
GlobalLimitExec: skip=0, fetch=10
78-
--SortPreservingMergeExec: [value@1 DESC]
78+
--SortPreservingMergeExec: [value@1 DESC], fetch=10
7979
----SortExec: fetch=10, expr=[value@1 DESC]
8080
------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value]
8181
--------NestedLoopJoinExec: join_type=Inner, filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@1

datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ Limit: skip=0, fetch=10
5656
------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")]
5757
physical_plan
5858
GlobalLimitExec: skip=0, fetch=10
59-
--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC]
59+
--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10
6060
----SortExec: fetch=10, expr=[custdist@1 DESC,c_count@0 DESC]
6161
------ProjectionExec: expr=[c_count@0 as c_count, COUNT(UInt8(1))@1 as custdist]
6262
--------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(UInt8(1))]

datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Limit: skip=0, fetch=10
6767
------------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")]
6868
physical_plan
6969
GlobalLimitExec: skip=0, fetch=10
70-
--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
70+
--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10
7171
----SortExec: fetch=10, expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
7272
------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1 as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt]
7373
--------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)]

datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ Limit: skip=0, fetch=10
101101
----------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]
102102
physical_plan
103103
GlobalLimitExec: skip=0, fetch=10
104-
--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
104+
--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10
105105
----SortExec: fetch=10, expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
106106
------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment]
107107
--------CoalesceBatchesExec: target_batch_size=8192

datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Limit: skip=0, fetch=10
6060
----------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("9204")]
6161
physical_plan
6262
GlobalLimitExec: skip=0, fetch=10
63-
--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
63+
--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST], fetch=10
6464
----SortExec: fetch=10, expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
6565
------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority]
6666
--------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]

datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Limit: skip=0, fetch=10
7777
--------------TableScan: nation projection=[n_nationkey, n_name]
7878
physical_plan
7979
GlobalLimitExec: skip=0, fetch=10
80-
--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC]
80+
--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10
8181
----SortExec: fetch=10, expr=[nation@0 ASC NULLS LAST,o_year@1 DESC]
8282
------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, SUM(profit.amount)@2 as sum_profit]
8383
--------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)]

datafusion/core/tests/sqllogictests/test_files/union.slt

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ Limit: skip=0, fetch=5
308308
--------TableScan: aggregate_test_100 projection=[c1, c3]
309309
physical_plan
310310
GlobalLimitExec: skip=0, fetch=5
311-
--SortPreservingMergeExec: [c9@1 DESC]
311+
--SortPreservingMergeExec: [c9@1 DESC], fetch=5
312312
----UnionExec
313313
------SortExec: expr=[c9@1 DESC]
314314
--------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9]

datafusion/core/tests/sqllogictests/test_files/window.slt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,7 @@ Limit: skip=0, fetch=5
17921792
------------TableScan: aggregate_test_100 projection=[c2, c3, c9]
17931793
physical_plan
17941794
GlobalLimitExec: skip=0, fetch=5
1795-
--SortPreservingMergeExec: [c3@0 ASC NULLS LAST]
1795+
--SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5
17961796
----ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2]
17971797
------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]
17981798
--------SortExec: expr=[c3@0 ASC NULLS LAST,c9@1 DESC]

datafusion/proto/proto/datafusion.proto

+2
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,8 @@ message SortExecNode {
13661366
message SortPreservingMergeExecNode {
13671367
PhysicalPlanNode input = 1;
13681368
repeated PhysicalExprNode expr = 2;
1369+
// Maximum number of highest/lowest rows to fetch; negative means no limit
1370+
int64 fetch = 3;
13691371
}
13701372

13711373
message CoalesceBatchesExecNode {

datafusion/proto/src/generated/pbjson.rs

+19
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)