ARROW-16523: [C++] Part 1 of ExecPlan cleanup: Centralized Task Group (apache#13143)

Authored-by: Sasha Krassovsky <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
save-buffer authored Jul 14, 2022
1 parent 87d1889 commit cf03901
Showing 20 changed files with 465 additions and 337 deletions.
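
The recurring change in this commit is to drop per-node ThreadIndexer members and ad-hoc output counters in favor of task groups owned by the ExecPlan. Below is a minimal sketch of that pattern as the GroupByNode diff adopts it; only the ExecPlan calls (RegisterTaskGroup, StartTaskGroup) and the overall shape are taken from the diff, while the ExampleNode name and its members are hypothetical stand-ins.

    // Illustrative sketch (not code from this commit) of the centralized
    // task-group pattern. Members and helpers such as OutputNthBatch(),
    // output_batch_size(), outputs_, and finished_ are assumed to exist on
    // the node, as they do for GroupByNode in the diff below.
    class ExampleNode : public ExecNode {
      Status Init() override {
        // Register the output work once, up front. The first callback runs
        // one task; the second runs after every task in the group finishes.
        output_task_group_id_ = plan_->RegisterTaskGroup(
            [this](size_t /*thread_index*/, int64_t task_id) {
              OutputNthBatch(task_id);
              return Status::OK();
            },
            [this](size_t /*thread_index*/) {
              finished_.MarkFinished();
              return Status::OK();
            });
        return Status::OK();
      }

      Status OutputResult() {
        // Tell the downstream node how many batches to expect, then let the
        // plan schedule that many tasks against the registered group.
        int64_t num_output_batches =
            bit_util::CeilDiv(out_data_.length, output_batch_size());
        outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
        return plan_->StartTaskGroup(output_task_group_id_, num_output_batches);
      }

      int output_task_group_id_;
      ExecBatch out_data_;
    };

Registering in Init() means the group id exists before any data flows; StartTaskGroup() is only called once the total number of output tasks is known.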
cpp/src/arrow/compute/exec/aggregate_node.cc (33 additions, 47 deletions)
@@ -113,7 +113,7 @@ class ScalarAggregateNode : public ExecNode {
       }
 
       KernelContext kernel_ctx{exec_ctx};
-      states[i].resize(ThreadIndexer::Capacity());
+      states[i].resize(plan->max_concurrency());
       RETURN_NOT_OK(Kernel::InitAll(&kernel_ctx,
                                     KernelInitArgs{kernels[i],
                                                    {
@@ -168,7 +168,7 @@ class ScalarAggregateNode : public ExecNode {
                                     {"batch.length", batch.length}});
     DCHECK_EQ(input, inputs_[0]);
 
-    auto thread_index = get_thread_index_();
+    auto thread_index = plan_->GetThreadIndex();
 
     if (ErrorIfNotOk(DoConsume(std::move(batch), thread_index))) return;
@@ -196,8 +196,6 @@ class ScalarAggregateNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
     // Scalar aggregates will only output a single batch
     outputs_[0]->InputFinished(this, 1);
     return Status::OK();
@@ -224,8 +222,6 @@ class ScalarAggregateNode : public ExecNode {
     inputs_[0]->StopProducing(this);
   }
 
-  Future<> finished() override { return finished_; }
-
  protected:
   std::string ToStringExtra(int indent = 0) const override {
     std::stringstream ss;
@@ -266,7 +262,6 @@ class ScalarAggregateNode : public ExecNode {
 
   std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
 
-  ThreadIndexer get_thread_index_;
   AtomicCounter input_counter_;
 };

@@ -284,6 +279,19 @@ class GroupByNode : public ExecNode {
         aggs_(std::move(aggs)),
         agg_kernels_(std::move(agg_kernels)) {}
 
+  Status Init() override {
+    output_task_group_id_ = plan_->RegisterTaskGroup(
+        [this](size_t, int64_t task_id) {
+          OutputNthBatch(task_id);
+          return Status::OK();
+        },
+        [this](size_t) {
+          finished_.MarkFinished();
+          return Status::OK();
+        });
+    return Status::OK();
+  }
+
   static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
                                 const ExecNodeOptions& options) {
     RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
@@ -358,7 +366,7 @@ class GroupByNode : public ExecNode {
                                    {{"group_by", ToStringExtra()},
                                     {"node.label", label()},
                                     {"batch.length", batch.length}});
-    size_t thread_index = get_thread_index_();
+    size_t thread_index = plan_->GetThreadIndex();
     if (thread_index >= local_states_.size()) {
       return Status::IndexError("thread index ", thread_index, " is out of range [0, ",
                                 local_states_.size(), ")");
@@ -465,47 +473,32 @@ class GroupByNode : public ExecNode {
     std::move(out_keys.values.begin(), out_keys.values.end(),
               out_data.values.begin() + agg_kernels_.size());
     state->grouper.reset();
 
-    if (output_counter_.SetTotal(
-            static_cast<int>(bit_util::CeilDiv(out_data.length, output_batch_size())))) {
-      // this will be hit if out_data.length == 0
-      finished_.MarkFinished();
-    }
     return out_data;
   }
 
-  void OutputNthBatch(int n) {
-    // bail if StopProducing was called
-    if (finished_.is_finished()) return;
-
+  void OutputNthBatch(int64_t n) {
     int64_t batch_size = output_batch_size();
     outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
-
-    if (output_counter_.Increment()) {
-      finished_.MarkFinished();
-    }
   }
 
   Status OutputResult() {
-    RETURN_NOT_OK(Merge());
-    ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
-
-    int num_output_batches = *output_counter_.total();
-    outputs_[0]->InputFinished(this, num_output_batches);
-
-    auto executor = ctx_->executor();
-    for (int i = 0; i < num_output_batches; ++i) {
-      if (executor) {
-        // bail if StopProducing was called
-        if (finished_.is_finished()) break;
-
-        auto plan = this->plan()->shared_from_this();
-        RETURN_NOT_OK(executor->Spawn([plan, this, i] { OutputNthBatch(i); }));
-      } else {
-        OutputNthBatch(i);
+    // To simplify merging, ensure that the first grouper is nonempty
+    for (size_t i = 0; i < local_states_.size(); i++) {
+      if (local_states_[i].grouper) {
+        std::swap(local_states_[i], local_states_[0]);
+        break;
       }
     }
 
+    RETURN_NOT_OK(Merge());
+    ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
+
+    int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
+    outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
+    RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_, num_output_batches));
     return Status::OK();
   }

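As a worked example of the batch-count arithmetic in the rewritten OutputResult() above (the numbers are made up):

    // Hypothetical numbers, only to illustrate the CeilDiv / Slice arithmetic
    // used by OutputResult() and OutputNthBatch().
    int64_t length = 10;      // stands in for out_data_.length
    int64_t batch_size = 4;   // stands in for output_batch_size()
    int64_t num_batches = bit_util::CeilDiv(length, batch_size);  // (10 + 3) / 4 == 3
    // Task 0 -> rows [0, 4), task 1 -> rows [4, 8), task 2 -> rows [8, 10);
    // Slice(batch_size * n, batch_size) clamps the last slice to the rows that remain.
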
@@ -555,10 +548,8 @@ class GroupByNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
 
-    local_states_.resize(ThreadIndexer::Capacity());
+    local_states_.resize(plan_->max_concurrency());
     return Status::OK();
   }

@@ -576,17 +567,12 @@ class GroupByNode : public ExecNode {
     EVENT(span_, "StopProducing");
     DCHECK_EQ(output, outputs_[0]);
 
-    ARROW_UNUSED(input_counter_.Cancel());
-    if (output_counter_.Cancel()) {
-      finished_.MarkFinished();
-    }
+    if (input_counter_.Cancel()) finished_.MarkFinished();
     inputs_[0]->StopProducing(this);
   }
 
   void StopProducing() override { StopProducing(outputs_[0]); }
 
-  Future<> finished() override { return finished_; }
-
  protected:
   std::string ToStringExtra(int indent = 0) const override {
     std::stringstream ss;
Expand All @@ -608,7 +594,7 @@ class GroupByNode : public ExecNode {
};

ThreadLocalState* GetLocalState() {
size_t thread_index = get_thread_index_();
size_t thread_index = plan_->GetThreadIndex();
return &local_states_[thread_index];
}

@@ -650,14 +636,14 @@ class GroupByNode : public ExecNode {
   }
 
   ExecContext* ctx_;
+  int output_task_group_id_;
 
   const std::vector<int> key_field_ids_;
   const std::vector<int> agg_src_field_ids_;
   const std::vector<Aggregate> aggs_;
   const std::vector<const HashAggregateKernel*> agg_kernels_;
 
-  ThreadIndexer get_thread_index_;
-  AtomicCounter input_counter_, output_counter_;
+  AtomicCounter input_counter_;
 
   std::vector<ThreadLocalState> local_states_;
   ExecBatch out_data_;

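The thread-local state handling in the hunks above follows the same centralization: instead of a node-owned ThreadIndexer, per-thread state is sized by plan_->max_concurrency() and indexed by plan_->GetThreadIndex(). A minimal sketch of that pattern, with hypothetical member names apart from those two ExecPlan calls:

    // Sketch of the per-thread state pattern; ThreadLocalState stands in for
    // whatever per-thread data a node keeps (groupers, kernel states, ...).
    Status StartProducing() {
      // One slot per worker thread the plan may ever hand out.
      local_states_.resize(plan_->max_concurrency());
      return Status::OK();
    }

    Status Consume(ExecBatch batch) {
      // Thread indices are stable and bounded by max_concurrency().
      size_t thread_index = plan_->GetThreadIndex();
      if (thread_index >= local_states_.size()) {
        return Status::IndexError("thread index ", thread_index, " is out of range [0, ",
                                  local_states_.size(), ")");
      }
      ThreadLocalState* state = &local_states_[thread_index];
      // ... consume the batch into *state ...
      return Status::OK();
    }

Sizing the vector up front keeps the consume path lock-free; the IndexError check mirrors the one in GroupByNode::Consume above.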