Skip to content

Commit

Permalink
ARROW-11797: [C++][Dataset] Provide batch stream Scanner methods
Browse files Browse the repository at this point in the history
Closes apache#9589 from bkietz/11797-Provide-Scanner-methods-t

Lead-authored-by: Benjamin Kietzman <[email protected]>
Co-authored-by: David Li <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
bkietz and lidavidm committed Apr 15, 2021
1 parent 02cdeab commit d575858
Show file tree
Hide file tree
Showing 17 changed files with 675 additions and 170 deletions.
16 changes: 16 additions & 0 deletions cpp/src/arrow/dataset/file_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,25 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio
//
// NB: neither of these will have any impact whatsoever on the common case of writing
// an in-memory table to disk.

#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#elif defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#endif

// TODO: (ARROW-11782/ARROW-12288) Remove calls to Scan()
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, scanner->Scan());
ARROW_ASSIGN_OR_RAISE(ScanTaskVector scan_tasks, scan_task_it.ToVector());

#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#elif defined(_MSC_VER)
#pragma warning(pop)
#endif

WriteState state(write_options);
auto res = internal::RunSynchronously<arrow::detail::Empty>(
[&](internal::Executor* cpu_executor) -> Future<> {
Expand Down
16 changes: 5 additions & 11 deletions cpp/src/arrow/dataset/file_csv_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,11 @@ N/A,bar
ASSERT_OK(builder.Project({"str"}));
ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());

ASSERT_OK_AND_ASSIGN(auto scan_task_it, scanner->Scan());
for (auto maybe_scan_task : scan_task_it) {
ASSERT_OK_AND_ASSIGN(auto scan_task, maybe_scan_task);
ASSERT_OK_AND_ASSIGN(auto batch_it, scan_task->Execute());
for (auto maybe_batch : batch_it) {
ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
// Run through the scan checking for errors to ensure that "f64" is read with the
// specified type and does not revert to the inferred type (if it reverts to
// inferring float64 then evaluation of the comparison expression should break)
}
}
ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
// Run through the scan checking for errors to ensure that "f64" is read with the
// specified type and does not revert to the inferred type (if it reverts to
// inferring float64 then evaluation of the comparison expression should break)
ASSERT_OK(batch_it.Visit([](TaggedRecordBatch) { return Status::OK(); }));
}

INSTANTIATE_TEST_SUITE_P(TestUncompressedCsv, TestCsvFileFormat,
Expand Down
264 changes: 234 additions & 30 deletions cpp/src/arrow/dataset/scanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@
#include "arrow/dataset/scanner.h"

#include <algorithm>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <sstream>

#include "arrow/array/array_primitive.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/cast.h"
#include "arrow/dataset/dataset.h"
#include "arrow/dataset/dataset_internal.h"
#include "arrow/dataset/scanner_internal.h"
Expand Down Expand Up @@ -132,37 +137,124 @@ Result<EnumeratedRecordBatchIterator> Scanner::AddPositioningToInOrderScan(
EnumeratingIterator{std::make_shared<State>(std::move(scan), std::move(first))});
}

Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches() {
// TODO(ARROW-11797) Provide a better implementation that does readahead. Also, add
// unit testing
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, Scan());
struct BatchIter {
explicit BatchIter(ScanTaskIterator scan_task_it)
: scan_task_it(std::move(scan_task_it)) {}

Result<TaggedRecordBatch> Next() {
while (true) {
if (current_task == nullptr) {
ARROW_ASSIGN_OR_RAISE(current_task, scan_task_it.Next());
if (IsIterationEnd<std::shared_ptr<ScanTask>>(current_task)) {
return IterationEnd<TaggedRecordBatch>();
}
ARROW_ASSIGN_OR_RAISE(batch_it, current_task->Execute());
}
ARROW_ASSIGN_OR_RAISE(auto next, batch_it.Next());
if (IsIterationEnd<std::shared_ptr<RecordBatch>>(next)) {
current_task = nullptr;
} else {
return TaggedRecordBatch{next, current_task->fragment()};
}
struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState> {
explicit ScanBatchesState(ScanTaskIterator scan_task_it,
std::shared_ptr<TaskGroup> task_group_)
: scan_tasks(std::move(scan_task_it)), task_group(std::move(task_group_)) {}

void ResizeBatches(size_t task_index) {
if (task_batches.size() <= task_index) {
task_batches.resize(task_index + 1);
task_drained.resize(task_index + 1);
}
}

void Push(TaggedRecordBatch batch, size_t task_index) {
{
std::lock_guard<std::mutex> lock(mutex);
ResizeBatches(task_index);
task_batches[task_index].push_back(std::move(batch));
}
ready.notify_one();
}

Status Finish(size_t task_index) {
{
std::lock_guard<std::mutex> lock(mutex);
ResizeBatches(task_index);
task_drained[task_index] = true;
}
ready.notify_one();
return Status::OK();
}

void PushScanTask() {
if (no_more_tasks) return;
std::unique_lock<std::mutex> lock(mutex);
auto maybe_task = scan_tasks.Next();
if (!maybe_task.ok()) {
no_more_tasks = true;
iteration_error = maybe_task.status();
return;
}
auto scan_task = maybe_task.ValueOrDie();
if (IsIterationEnd(scan_task)) {
no_more_tasks = true;
return;
}
auto state = shared_from_this();
auto id = next_scan_task_id++;
ResizeBatches(id);

lock.unlock();
task_group->Append([state, id, scan_task]() {
ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
for (auto maybe_batch : batch_it) {
ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
state->Push(TaggedRecordBatch{std::move(batch), scan_task->fragment()}, id);
}
return state->Finish(id);
});
}

Result<TaggedRecordBatch> Pop() {
std::unique_lock<std::mutex> lock(mutex);
ready.wait(lock, [this, &lock] {
while (pop_cursor < task_batches.size()) {
// queue for current scan task contains at least one batch, pop that
if (!task_batches[pop_cursor].empty()) return true;
// queue is empty but will be appended to eventually, wait for that
if (!task_drained[pop_cursor]) return false;

// Finished draining current scan task, enqueue a new one
++pop_cursor;
// Must unlock since serial task group will execute synchronously
lock.unlock();
PushScanTask();
lock.lock();
}
DCHECK(no_more_tasks);
// all scan tasks drained (or getting next task failed), terminate
return true;
});

if (pop_cursor == task_batches.size()) {
// Don't report an error until we yield up everything we can first
RETURN_NOT_OK(iteration_error);
return IterationEnd<TaggedRecordBatch>();
}

ScanTaskIterator scan_task_it;
RecordBatchIterator batch_it;
std::shared_ptr<ScanTask> current_task;
};
return TaggedRecordBatchIterator(BatchIter(std::move(scan_task_it)));
auto batch = std::move(task_batches[pop_cursor].front());
task_batches[pop_cursor].pop_front();
return batch;
}

/// Protecting mutating accesses to batches
std::mutex mutex;
std::condition_variable ready;
ScanTaskIterator scan_tasks;
std::shared_ptr<TaskGroup> task_group;
int next_scan_task_id = 0;
bool no_more_tasks = false;
Status iteration_error;
std::vector<std::deque<TaggedRecordBatch>> task_batches;
std::vector<bool> task_drained;
size_t pop_cursor = 0;
};

Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches() {
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
auto task_group = scan_options_->TaskGroup();
auto state = std::make_shared<ScanBatchesState>(std::move(scan_task_it), task_group);
for (int i = 0; i < scan_options_->fragment_readahead; i++) {
state->PushScanTask();
}
return MakeFunctionIterator([task_group, state]() -> Result<TaggedRecordBatch> {
ARROW_ASSIGN_OR_RAISE(auto batch, state->Pop());
if (!IsIterationEnd(batch)) return batch;
RETURN_NOT_OK(task_group->Finish());
return IterationEnd<TaggedRecordBatch>();
});
}

Result<FragmentIterator> SyncScanner::GetFragments() {
Expand All @@ -176,7 +268,30 @@ Result<FragmentIterator> SyncScanner::GetFragments() {
return GetFragmentsFromDatasets({dataset_}, scan_options_->filter);
}

Result<ScanTaskIterator> SyncScanner::Scan() {
Result<ScanTaskIterator> SyncScanner::Scan() { return ScanInternal(); }

Status SyncScanner::Scan(std::function<Status(TaggedRecordBatch)> visitor) {
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());

auto task_group = scan_options_->TaskGroup();

for (auto maybe_scan_task : scan_task_it) {
ARROW_ASSIGN_OR_RAISE(auto scan_task, maybe_scan_task);
task_group->Append([scan_task, visitor] {
ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
for (auto maybe_batch : batch_it) {
ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
RETURN_NOT_OK(
visitor(TaggedRecordBatch{std::move(batch), scan_task->fragment()}));
}
return Status::OK();
});
}

return task_group->Finish();
}

Result<ScanTaskIterator> SyncScanner::ScanInternal() {
// Transforms Iterator<Fragment> into a unified
// Iterator<ScanTask>. The first Iterator::Next invocation is going to do
// all the work of unwinding the chained iterators.
Expand Down Expand Up @@ -315,7 +430,7 @@ Result<std::shared_ptr<Table>> SyncScanner::ToTable() {

Future<std::shared_ptr<Table>> SyncScanner::ToTableInternal(
internal::Executor* cpu_executor) {
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, Scan());
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
auto task_group = scan_options_->TaskGroup();

/// Wraps the state in a shared_ptr to ensure that failing ScanTasks don't
Expand Down Expand Up @@ -343,5 +458,94 @@ Future<std::shared_ptr<Table>> SyncScanner::ToTableInternal(
FlattenRecordBatchVector(std::move(state->batches)));
}

Result<std::shared_ptr<Table>> Scanner::TakeRows(const Array& indices) {
if (indices.null_count() != 0) {
return Status::NotImplemented("null take indices");
}

compute::ExecContext ctx(scan_options_->pool);

const Array* original_indices;
// If we have to cast, this is the backing reference
std::shared_ptr<Array> original_indices_ptr;
if (indices.type_id() != Type::INT64) {
ARROW_ASSIGN_OR_RAISE(
original_indices_ptr,
compute::Cast(indices, int64(), compute::CastOptions::Safe(), &ctx));
original_indices = original_indices_ptr.get();
} else {
original_indices = &indices;
}

std::shared_ptr<Array> unsort_indices;
{
ARROW_ASSIGN_OR_RAISE(
auto sort_indices,
compute::SortIndices(*original_indices, compute::SortOrder::Ascending, &ctx));
ARROW_ASSIGN_OR_RAISE(original_indices_ptr,
compute::Take(*original_indices, *sort_indices,
compute::TakeOptions::Defaults(), &ctx));
original_indices = original_indices_ptr.get();
ARROW_ASSIGN_OR_RAISE(
unsort_indices,
compute::SortIndices(*sort_indices, compute::SortOrder::Ascending, &ctx));
}

RecordBatchVector out_batches;

auto raw_indices = static_cast<const Int64Array&>(*original_indices).raw_values();
int64_t offset = 0, row_begin = 0;

ARROW_ASSIGN_OR_RAISE(auto batch_it, ScanBatches());
while (true) {
ARROW_ASSIGN_OR_RAISE(auto batch, batch_it.Next());
if (IsIterationEnd(batch)) break;
if (offset == original_indices->length()) break;
DCHECK_LT(offset, original_indices->length());

int64_t length = 0;
while (offset + length < original_indices->length()) {
auto rel_index = raw_indices[offset + length] - row_begin;
if (rel_index >= batch.record_batch->num_rows()) break;
++length;
}
DCHECK_LE(offset + length, original_indices->length());
if (length == 0) {
row_begin += batch.record_batch->num_rows();
continue;
}

Datum rel_indices = original_indices->Slice(offset, length);
ARROW_ASSIGN_OR_RAISE(rel_indices,
compute::Subtract(rel_indices, Datum(row_begin),
compute::ArithmeticOptions(), &ctx));

ARROW_ASSIGN_OR_RAISE(Datum out_batch,
compute::Take(batch.record_batch, rel_indices,
compute::TakeOptions::Defaults(), &ctx));
out_batches.push_back(out_batch.record_batch());

offset += length;
row_begin += batch.record_batch->num_rows();
}

if (offset < original_indices->length()) {
std::stringstream error;
const int64_t max_values_shown = 3;
const int64_t num_remaining = original_indices->length() - offset;
for (int64_t i = 0; i < std::min<int64_t>(max_values_shown, num_remaining); i++) {
if (i > 0) error << ", ";
error << static_cast<const Int64Array*>(original_indices)->Value(offset + i);
}
if (num_remaining > max_values_shown) error << ", ...";
return Status::IndexError("Some indices were out of bounds: ", error.str());
}
ARROW_ASSIGN_OR_RAISE(Datum out, Table::FromRecordBatches(options()->projected_schema,
std::move(out_batches)));
ARROW_ASSIGN_OR_RAISE(
out, compute::Take(out, unsort_indices, compute::TakeOptions::Defaults(), &ctx));
return out.table();
}

} // namespace dataset
} // namespace arrow
15 changes: 14 additions & 1 deletion cpp/src/arrow/dataset/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,14 @@ class ARROW_DS_EXPORT Scanner {
/// in a concurrent fashion and outlive the iterator.
///
/// Note: Not supported by the async scanner
/// TODO(ARROW-11797) Deprecate Scan()
/// Planned for removal from the public API in ARROW-11782.
ARROW_DEPRECATED("Deprecated in 4.0.0 for removal in 5.0.0. Use ScanBatches().")
virtual Result<ScanTaskIterator> Scan();

/// \brief Apply a visitor to each RecordBatch as it is scanned. If multiple threads
/// are used (via use_threads), the visitor will be invoked from those threads and is
/// responsible for any synchronization.
virtual Status Scan(std::function<Status(TaggedRecordBatch)> visitor) = 0;
/// \brief Convert a Scanner into a Table.
///
/// Use this convenience utility with care. This will serially materialize the
Expand All @@ -279,6 +285,10 @@ class ARROW_DS_EXPORT Scanner {
/// To make up for the out-of-order iteration each batch is further tagged with
/// positional information.
virtual Result<EnumeratedRecordBatchIterator> ScanBatchesUnordered();
/// \brief A convenience to synchronously load the given rows by index.
///
/// Will only consume as many batches as needed from ScanBatches().
virtual Result<std::shared_ptr<Table>> TakeRows(const Array& indices);

/// \brief Get the options for this scan.
const std::shared_ptr<ScanOptions>& options() const { return scan_options_; }
Expand Down Expand Up @@ -306,12 +316,15 @@ class ARROW_DS_EXPORT SyncScanner : public Scanner {

Result<ScanTaskIterator> Scan() override;

Status Scan(std::function<Status(TaggedRecordBatch)> visitor) override;

Result<std::shared_ptr<Table>> ToTable() override;

protected:
/// \brief GetFragments returns an iterator over all Fragments in this scan.
Result<FragmentIterator> GetFragments();
Future<std::shared_ptr<Table>> ToTableInternal(internal::Executor* cpu_executor);
Result<ScanTaskIterator> ScanInternal();

std::shared_ptr<Dataset> dataset_;
// TODO(ARROW-8065) remove fragment_ after a Dataset is constuctible from fragments
Expand Down
Loading

0 comments on commit d575858

Please sign in to comment.