[Enhancement] aggregation pipeline shared limit (StarRocks#42449)
Signed-off-by: zombee0 <[email protected]>
zombee0 authored Apr 7, 2024
1 parent be769fb commit f07cf05
Showing 11 changed files with 76 additions and 13 deletions.
27 changes: 27 additions & 0 deletions be/src/exec/aggregator.cpp
@@ -95,6 +95,8 @@ AggregatorParamsPtr convert_to_aggregator_params(const TPlanNode& tnode) {
         params->grouping_exprs = tnode.agg_node.grouping_exprs;
         params->aggregate_functions = tnode.agg_node.aggregate_functions;
         params->intermediate_aggr_exprs = tnode.agg_node.intermediate_aggr_exprs;
+        params->enable_pipeline_share_limit =
+                tnode.agg_node.__isset.enable_pipeline_share_limit ? tnode.agg_node.enable_pipeline_share_limit : false;
         break;
     }
     default:
@@ -1315,6 +1317,31 @@ void Aggregator::build_hash_map(size_t chunk_size, bool agg_group_by_with_limit)
     });
 }

+void Aggregator::build_hash_map(size_t chunk_size, std::atomic<int64_t>& shared_limit_countdown,
+                                bool agg_group_by_with_limit) {
+    if (agg_group_by_with_limit && _params->enable_pipeline_share_limit) {
+        _build_hash_map_with_shared_limit(chunk_size, shared_limit_countdown);
+        return;
+    }
+    build_hash_map(chunk_size, agg_group_by_with_limit);
+}
+
+void Aggregator::_build_hash_map_with_shared_limit(size_t chunk_size, std::atomic<int64_t>& shared_limit_countdown) {
+    auto start_size = _hash_map_variant.size();
+    if (_hash_map_variant.size() >= _limit || shared_limit_countdown.load(std::memory_order_relaxed) <= 0) {
+        build_hash_map_with_selection(chunk_size);
+        return;
+    } else {
+        _streaming_selection.assign(chunk_size, 0);
+    }
+    _hash_map_variant.visit([&](auto& hash_map_with_key) {
+        using MapType = std::remove_reference_t<decltype(*hash_map_with_key)>;
+        hash_map_with_key->build_hash_map(chunk_size, _group_by_columns, _mem_pool.get(), AllocateState<MapType>(this),
+                                          &_tmp_agg_states);
+    });
+    shared_limit_countdown.fetch_sub(_hash_map_variant.size() - start_size, std::memory_order_relaxed);
+}
+
 void Aggregator::build_hash_map_with_selection(size_t chunk_size) {
     _hash_map_variant.visit([&](auto& hash_map_with_key) {
         using MapType = std::remove_reference_t<decltype(*hash_map_with_key)>;
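The core of the change is the pattern visible in `_build_hash_map_with_shared_limit` above: every pipeline driver keeps its own hash map, but one `std::atomic<int64_t>` budget, seeded with the query's LIMIT, is debited by each driver's hash-map growth, so the drivers collectively stop admitting new groups once roughly LIMIT groups exist in total. Below is a minimal standalone sketch of that pattern; it is illustrative only (`kLimit`, `driver`, and the fake group keys are invented for the example), not StarRocks code.

```cpp
// Minimal sketch (not StarRocks code) of the shared-limit pattern:
// four threads stand in for pipeline drivers, an unordered_set stands
// in for each driver's aggregation hash map, and one atomic countdown
// seeded with the LIMIT is debited by every driver's table growth.
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <unordered_set>
#include <vector>

int main() {
    constexpr int64_t kLimit = 100;          // hypothetical query LIMIT
    std::atomic<int64_t> countdown{kLimit};  // shared budget, seeded once

    auto driver = [&countdown](int seq) {
        std::unordered_set<int> groups;  // per-driver "hash map"
        for (int i = 0; i < 100000; ++i) {
            // Mirrors _build_hash_map_with_shared_limit's early-out: stop
            // admitting new groups when the local table already holds LIMIT
            // groups or the global budget is exhausted.
            if (static_cast<int64_t>(groups.size()) >= kLimit ||
                countdown.load(std::memory_order_relaxed) <= 0) {
                break;  // the real code falls back to selection-only probing
            }
            size_t before = groups.size();
            groups.insert((seq * 7919 + i) % 1000);  // fake group key
            // Debit only the newly created groups against the shared budget.
            countdown.fetch_sub(static_cast<int64_t>(groups.size() - before),
                                std::memory_order_relaxed);
        }
        std::printf("driver %d holds %zu groups\n", seq, groups.size());
    };

    std::vector<std::thread> drivers;
    for (int seq = 0; seq < 4; ++seq) drivers.emplace_back(driver, seq);
    for (auto& t : drivers) t.join();
    return 0;
}
```

Because the loads and `fetch_sub` calls are relaxed and not atomic with the insertions, concurrent drivers can overshoot the budget slightly; that only costs some memory, since the exact LIMIT is still enforced when rows are returned (see `_reached_limit()` in aggregator.h below).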
11 changes: 10 additions & 1 deletion be/src/exec/aggregator.h
@@ -208,6 +208,7 @@ struct AggregatorParams {
     bool needs_finalize;
     bool has_outer_join_child;
     int64_t limit;
+    bool enable_pipeline_share_limit;
     TStreamingPreaggregationMode::type streaming_preaggregation_mode;
     TupleId intermediate_tuple_id;
     TupleId output_tuple_id;
@@ -500,6 +501,7 @@ class Aggregator : public pipeline::ContextWithDependency {

 public:
     void build_hash_map(size_t chunk_size, bool agg_group_by_with_limit = false);
+    void build_hash_map(size_t chunk_size, std::atomic<int64_t>& shared_limit_countdown, bool agg_group_by_with_limit);
     void build_hash_map_with_selection(size_t chunk_size);
     void build_hash_map_with_selection_and_allocation(size_t chunk_size, bool agg_group_by_with_limit = false);
     [[nodiscard]] Status convert_hash_map_to_chunk(int32_t chunk_size, ChunkPtr* chunk,
@@ -512,6 +514,8 @@ class Aggregator : public pipeline::ContextWithDependency {
 protected:
     bool _reached_limit() { return _limit != -1 && _num_rows_returned >= _limit; }

+    void _build_hash_map_with_shared_limit(size_t chunk_size, std::atomic<int64_t>& shared_limit_countdown);
+
     bool _use_intermediate_as_input() {
         if (is_pending_reset_state()) {
             DCHECK(_aggr_mode == AM_BLOCKING_PRE_CACHE || _aggr_mode == AM_STREAMING_PRE_CACHE);
@@ -615,7 +619,9 @@ class AggregatorFactoryBase {
 public:
     using Ptr = std::shared_ptr<T>;
     AggregatorFactoryBase(const TPlanNode& tnode)
-            : _tnode(tnode), _aggregator_param(convert_to_aggregator_params(_tnode)) {}
+            : _tnode(tnode), _aggregator_param(convert_to_aggregator_params(_tnode)) {
+        _shared_limit_countdown.store(_aggregator_param->limit);
+    }

     Ptr get_or_create(size_t id) {
         auto it = _aggregators.find(id);
@@ -635,11 +641,14 @@ class AggregatorFactoryBase {
     const TPlanNode& t_node() { return _tnode; }
     const AggrMode aggr_mode() { return _aggr_mode; }

+    std::atomic<int64_t>& get_shared_limit_countdown() { return _shared_limit_countdown; }
+
 private:
     const TPlanNode& _tnode;
     AggregatorParamsPtr _aggregator_param;
     std::unordered_map<size_t, Ptr> _aggregators;
     AggrMode _aggr_mode = AggrMode::AM_DEFAULT;
+    std::atomic<int64_t> _shared_limit_countdown;
 };

 using AggregatorFactory = AggregatorFactoryBase<Aggregator>;
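Note where the countdown lives: on the factory, not on any single aggregator. One `AggregatorFactoryBase` serves all drivers of a fragment and hands out per-driver aggregators via `get_or_create(driver_sequence)`, so every sink operator built from it receives a reference to the same atomic, seeded once with the plan's limit in the constructor.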
@@ -94,7 +94,7 @@ Status AggregateBlockingSinkOperator::push_chunk(RuntimeState* state, const ChunkPtr& chunk) {
         SCOPED_TIMER(_aggregator->agg_compute_timer());
         // try to build hash table if has group by keys
         if (!_aggregator->is_none_group_by_exprs()) {
-            TRY_CATCH_BAD_ALLOC(_aggregator->build_hash_map(chunk_size, _agg_group_by_with_limit));
+            TRY_CATCH_BAD_ALLOC(_aggregator->build_hash_map(chunk_size, _shared_limit_countdown, _agg_group_by_with_limit));
             TRY_CATCH_BAD_ALLOC(_aggregator->try_convert_to_two_level_map());
         }

@@ -130,7 +130,8 @@ Status AggregateBlockingSinkOperatorFactory::prepare(RuntimeState* state) {
 OperatorPtr AggregateBlockingSinkOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) {
     // init operator
     auto aggregator = _aggregator_factory->get_or_create(driver_sequence);
-    auto op = std::make_shared<AggregateBlockingSinkOperator>(aggregator, this, _id, _plan_node_id, driver_sequence);
+    auto op = std::make_shared<AggregateBlockingSinkOperator>(aggregator, this, _id, _plan_node_id, driver_sequence,
+                                                              _aggregator_factory->get_shared_limit_countdown());
     return op;
 }

@@ -25,8 +25,11 @@ namespace starrocks::pipeline {
 class AggregateBlockingSinkOperator : public Operator {
 public:
     AggregateBlockingSinkOperator(AggregatorPtr aggregator, OperatorFactory* factory, int32_t id, int32_t plan_node_id,
-                                  int32_t driver_sequence, const char* name = "aggregate_blocking_sink")
-            : Operator(factory, id, name, plan_node_id, false, driver_sequence), _aggregator(std::move(aggregator)) {
+                                  int32_t driver_sequence, std::atomic<int64_t>& shared_limit_countdown,
+                                  const char* name = "aggregate_blocking_sink")
+            : Operator(factory, id, name, plan_node_id, false, driver_sequence),
+              _aggregator(std::move(aggregator)),
+              _shared_limit_countdown(shared_limit_countdown) {
         _aggregator->set_aggr_phase(AggrPhase2);
         _aggregator->ref();
     }
@@ -58,6 +61,7 @@ class AggregateBlockingSinkOperator : public Operator {
     std::atomic_bool _is_finished = false;
     // whether enable aggregate group by limit optimize
     bool _agg_group_by_with_limit = false;
+    std::atomic<int64_t>& _shared_limit_countdown;
 };

 class AggregateBlockingSinkOperatorFactory final : public OperatorFactory {
@@ -66,15 +66,19 @@ Status AggregateDistinctBlockingSinkOperator::push_chunk(RuntimeState* state, const ChunkPtr& chunk) {
     {
         SCOPED_TIMER(_aggregator->agg_compute_timer());
         bool limit_with_no_agg = _aggregator->limit() != -1;
+        auto size = _aggregator->hash_set_variant().size();
         if (limit_with_no_agg) {
-            auto size = _aggregator->hash_set_variant().size();
-            if (size >= _aggregator->limit()) {
+            if (size >= _aggregator->limit() || (_aggregator->params()->enable_pipeline_share_limit &&
+                                                 _shared_limit_countdown.load(std::memory_order_relaxed) <= 0)) {
                 (void)set_finishing(state);
                 return Status::OK();
             }
         }
         RETURN_IF_ERROR(_aggregator->evaluate_groupby_exprs(chunk.get()));
         TRY_CATCH_BAD_ALLOC(_aggregator->build_hash_set(chunk->num_rows()));
+        if (limit_with_no_agg && _aggregator->params()->enable_pipeline_share_limit) {
+            _shared_limit_countdown.fetch_sub(_aggregator->hash_set_variant().size() - size, std::memory_order_relaxed);
+        }
         TRY_CATCH_BAD_ALLOC(_aggregator->try_convert_to_two_level_set());

         _aggregator->update_num_input_rows(chunk->num_rows());
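In this distinct path (group by with no aggregate functions), the shared countdown does more than stop new group creation: once it reaches zero, `push_chunk` calls `set_finishing(state)` and returns, so the driver stops consuming input entirely rather than continuing to probe its hash set.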
@@ -25,8 +25,11 @@ class AggregateDistinctBlockingSinkOperator : public Operator {
 public:
     AggregateDistinctBlockingSinkOperator(AggregatorPtr aggregator, OperatorFactory* factory, int32_t id,
                                           int32_t plan_node_id, int32_t driver_sequence,
+                                          std::atomic<int64_t>& shared_limit_countdown,
                                           const char* name = "aggregate_distinct_blocking_sink")
-            : Operator(factory, id, name, plan_node_id, false, driver_sequence), _aggregator(std::move(aggregator)) {
+            : Operator(factory, id, name, plan_node_id, false, driver_sequence),
+              _aggregator(std::move(aggregator)),
+              _shared_limit_countdown(shared_limit_countdown) {
         _aggregator->set_aggr_phase(AggrPhase2);
         _aggregator->ref();
     }
@@ -56,6 +59,7 @@ class AggregateDistinctBlockingSinkOperator : public Operator {
 private:
     // Whether prev operator has no output
     bool _is_finished = false;
+    std::atomic<int64_t>& _shared_limit_countdown;
 };

 class AggregateDistinctBlockingSinkOperatorFactory final : public OperatorFactory {
@@ -75,7 +79,8 @@ class AggregateDistinctBlockingSinkOperatorFactory final : public OperatorFactory {
     void close(RuntimeState* state) override { OperatorFactory::close(state); }
     OperatorPtr create(int32_t degree_of_parallelism, int32_t driver_sequence) override {
         return std::make_shared<AggregateDistinctBlockingSinkOperator>(
-                _aggregator_factory->get_or_create(driver_sequence), this, _id, _plan_node_id, driver_sequence);
+                _aggregator_factory->get_or_create(driver_sequence), this, _id, _plan_node_id, driver_sequence,
+                _aggregator_factory->get_shared_limit_countdown());
     }

 private:
@@ -303,8 +303,8 @@ OperatorPtr SpillableAggregateBlockingSinkOperatorFactory::create(int32_t degree_of_parallelism,
                                                                   int32_t driver_sequence) {
     auto aggregator = _aggregator_factory->get_or_create(driver_sequence);

-    auto op = std::make_shared<SpillableAggregateBlockingSinkOperator>(aggregator, this, _id, _plan_node_id,
-                                                                       driver_sequence);
+    auto op = std::make_shared<SpillableAggregateBlockingSinkOperator>(
+            aggregator, this, _id, _plan_node_id, driver_sequence, _aggregator_factory->get_shared_limit_countdown());
     // create spiller
     auto spiller = _spill_factory->create(*_spill_options);
     // create spill process channel
@@ -155,8 +155,8 @@ OperatorPtr SpillableAggregateDistinctBlockingSinkOperatorFactory::create(int32_t degree_of_parallelism,
                                                                           int32_t driver_sequence) {
     auto aggregator = _aggregator_factory->get_or_create(driver_sequence);

-    auto op = std::make_shared<SpillableAggregateDistinctBlockingSinkOperator>(aggregator, this, _id, _plan_node_id,
-                                                                               driver_sequence);
+    auto op = std::make_shared<SpillableAggregateDistinctBlockingSinkOperator>(
+            aggregator, this, _id, _plan_node_id, driver_sequence, _aggregator_factory->get_shared_limit_countdown());
     // create spiller
     auto spiller = _spill_factory->create(*_spill_options);
     // create spill process channel
@@ -270,6 +270,7 @@ protected void toThrift(TPlanNode msg) {
         msg.agg_node.setAgg_func_set_version(FeConstants.AGG_FUNC_VERSION);
         msg.agg_node.setInterpolate_passthrough(
                 useStreamingPreagg && ConnectContext.get().getSessionVariable().isInterpolatePassthrough());
+        msg.agg_node.setEnable_pipeline_share_limit(ConnectContext.get().getSessionVariable().getEnableAggregationPipelineShareLimit());
     }

     protected String getDisplayLabelDetail() {
@@ -671,6 +671,8 @@ public static MaterializedViewRewriteMode parse(String str) {

     public static final String CBO_PUSHDOWN_TOPN_LIMIT = "cbo_push_down_topn_limit";

+    public static final String ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT = "enable_aggregation_pipeline_share_limit";
+
     public static final String ENABLE_EXPR_PRUNE_PARTITION = "enable_expr_prune_partition";

     public static final String AUDIT_EXECUTE_STMT = "audit_execute_stmt";
@@ -1382,6 +1384,9 @@ public static MaterializedViewRewriteMode parse(String str) {
     @VarAttr(name = CBO_PUSHDOWN_TOPN_LIMIT)
     private long cboPushDownTopNLimit = 1000;

+    @VarAttr(name = ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT, flag = VariableMgr.INVISIBLE)
+    private boolean enableAggregationPipelineShareLimit = true;
+
     @VarAttr(name = ENABLE_HYPERSCAN_VEC)
     private boolean enableHyperscanVec = true;

@@ -1420,6 +1425,10 @@ public void setCboPushDownTopNLimit(long cboPushDownTopNLimit) {
         this.cboPushDownTopNLimit = cboPushDownTopNLimit;
     }

+    public boolean getEnableAggregationPipelineShareLimit() {
+        return enableAggregationPipelineShareLimit;
+    }
+
     public String getThriftPlanProtocol() {
         return thriftPlanProtocol;
     }
3 changes: 3 additions & 0 deletions gensrc/thrift/PlanNodes.thrift
@@ -749,6 +749,9 @@ struct TAggregationNode {
   27: optional bool use_sort_agg

   28: optional bool use_per_bucket_optimize
+
+  // enable runtime limit, pipelines share one limit
+  29: optional bool enable_pipeline_share_limit = false
 }

 struct TRepeatNode {
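Since the new field is optional with an explicit `false` default and the BE reads it only behind the `__isset` guard shown in the aggregator.cpp hunk, the optimization presumably degrades gracefully in a mixed-version cluster: an FE that never sets the field simply leaves pipeline limit sharing disabled.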
