Spark: Parallelize task init when fetching locality info (apache#2800)
jshmchenxi authored Jul 12, 2021
1 parent 87aea34 commit 0bb89d0
Showing 3 changed files with 41 additions and 25 deletions.
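All three changed files replace a sequential for-loop with the same parallel-initialization pattern: the read tasks are written into a pre-sized array by Iceberg's Tasks.range(...) builder, and the builder is handed the shared ThreadPools worker pool only when localityPreferred is set, since that is the case where each ReadTask has to fetch block-location (locality) information and task construction becomes slow. Below is a minimal standalone sketch of that pattern, distilled from the diff; the class ParallelTaskInit and the names initInParallel and buildTask are illustrative only, not part of the Iceberg API.

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.util.Tasks;
import org.apache.iceberg.util.ThreadPools;

class ParallelTaskInit {

  // Build one read task per scan task, in parallel on the shared worker pool when requested.
  static <T> List<T> initInParallel(List<CombinedScanTask> scanTasks,
                                    boolean localityPreferred,
                                    Function<CombinedScanTask, T> buildTask) {
    @SuppressWarnings("unchecked")
    T[] readTasks = (T[]) new Object[scanTasks.size()];

    Tasks.range(readTasks.length)
        // stop scheduling the remaining indexes after the first failure
        .stopOnFailure()
        // a null executor keeps all the work on the calling thread
        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
        .run(index -> readTasks[index] = buildTask.apply(scanTasks.get(index)));

    return Arrays.asList(readTasks);
  }
}

Because executeWith(null) runs the indexes serially on the calling thread, scans that do not ask for locality keep their previous behavior; only locality-aware scans use (and pay for) the worker pool.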
spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java (39 changes: 24 additions & 15 deletions)
@@ -21,6 +21,7 @@
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -52,6 +53,8 @@
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.RuntimeConfig;
@@ -205,15 +208,18 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
     // broadcast the table metadata as input partitions will be sent to executors
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
 
-    List<InputPartition<ColumnarBatch>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableBroadcast, expectedSchemaString, caseSensitive,
-          localityPreferred, new BatchReaderFactory(batchSize)));
-    }
-    LOG.info("Batching input partitions with {} tasks.", readTasks.size());
+    List<CombinedScanTask> scanTasks = tasks();
+    InputPartition<ColumnarBatch>[] readTasks = new InputPartition[scanTasks.size()];
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask<>(
+            scanTasks.get(index), tableBroadcast, expectedSchemaString, caseSensitive,
+            localityPreferred, new BatchReaderFactory(batchSize)));
+    LOG.info("Batching input partitions with {} tasks.", readTasks.length);
 
-    return readTasks;
+    return Arrays.asList(readTasks);
   }
 
   /**
@@ -226,14 +232,17 @@ public List<InputPartition<InternalRow>> planInputPartitions() {
     // broadcast the table metadata as input partitions will be sent to executors
     Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTable.copyOf(table));
 
-    List<InputPartition<InternalRow>> readTasks = Lists.newArrayList();
-    for (CombinedScanTask task : tasks()) {
-      readTasks.add(new ReadTask<>(
-          task, tableBroadcast, expectedSchemaString, caseSensitive,
-          localityPreferred, InternalRowReaderFactory.INSTANCE));
-    }
+    List<CombinedScanTask> scanTasks = tasks();
+    InputPartition<InternalRow>[] readTasks = new InputPartition[scanTasks.size()];
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask<>(
+            scanTasks.get(index), tableBroadcast, expectedSchemaString, caseSensitive,
+            localityPreferred, InternalRowReaderFactory.INSTANCE));
 
-    return readTasks;
+    return Arrays.asList(readTasks);
   }
 
   @Override
@@ -40,6 +40,8 @@
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.RuntimeConfig;
@@ -132,11 +134,13 @@ public InputPartition[] planInputPartitions() {
 
     List<CombinedScanTask> scanTasks = tasks();
     InputPartition[] readTasks = new InputPartition[scanTasks.size()];
-    for (int i = 0; i < scanTasks.size(); i++) {
-      readTasks[i] = new ReadTask(
-          scanTasks.get(i), tableBroadcast, expectedSchemaString,
-          caseSensitive, localityPreferred);
-    }
+
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask(
+            scanTasks.get(index), tableBroadcast, expectedSchemaString,
+            caseSensitive, localityPreferred));
 
     return readTasks;
   }
@@ -52,6 +52,8 @@
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.iceberg.util.TableScanUtil;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.connector.read.InputPartition;
@@ -136,11 +138,12 @@ public InputPartition[] planInputPartitions(Offset start, Offset end) {
         TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost));
     InputPartition[] readTasks = new InputPartition[combinedScanTasks.size()];
 
-    for (int i = 0; i < combinedScanTasks.size(); i++) {
-      readTasks[i] = new ReadTask(
-          combinedScanTasks.get(i), tableBroadcast, expectedSchema,
-          caseSensitive, localityPreferred);
-    }
+    Tasks.range(readTasks.length)
+        .stopOnFailure()
+        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
+        .run(index -> readTasks[index] = new ReadTask(
+            combinedScanTasks.get(index), tableBroadcast, expectedSchema,
+            caseSensitive, localityPreferred));
 
     return readTasks;
   }
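A usage note on the pool used above: ThreadPools.getWorkerPool() returns Iceberg's shared worker pool. Assuming the defaults at the time of this commit, it is sized to the number of available processors and can be overridden with the iceberg.worker.num-threads JVM system property (treat the exact property name and default as assumptions, not something stated in this diff), for example via the Spark driver's JVM options:

spark.driver.extraJavaOptions=-Diceberg.worker.num-threads=16

Since partition planning runs on the driver, raising this value mainly helps when a scan produces many combined tasks and locality lookups dominate planning time.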
