
Refactor batch load job path, and add support for data-dependent tables.
reuvenlax authored and jkff committed Apr 19, 2017
1 parent 58ed5c7 commit 8581caf
Showing 9 changed files with 256 additions and 149 deletions.
@@ -20,20 +20,22 @@

import com.google.api.services.bigquery.model.TableReference;
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.TableRefToJson;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition;
import org.apache.beam.sdk.options.BigQueryOptions;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
@@ -47,20 +49,37 @@
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.TypeDescriptor;


/**
* PTransform that uses BigQuery batch-load jobs to write a PCollection to BigQuery.
*/
class BatchLoadBigQuery<T> extends PTransform<PCollection<T>, WriteResult> {
class BatchLoads<T> extends
PTransform<PCollection<KV<TableDestination, TableRow>>, WriteResult> {
BigQueryIO.Write<T> write;

BatchLoadBigQuery(BigQueryIO.Write<T> write) {
private static class ConstantSchemaFunction implements
SerializableFunction<TableDestination, TableSchema> {
private final @Nullable String jsonSchema;

ConstantSchemaFunction(TableSchema schema) {
this.jsonSchema = BigQueryHelpers.toJsonString(schema);
}

@Override
@Nullable
public TableSchema apply(TableDestination table) {
return BigQueryHelpers.fromJsonString(jsonSchema, TableSchema.class);
}
}

BatchLoads(BigQueryIO.Write<T> write) {
this.write = write;
}

@Override
public WriteResult expand(PCollection<T> input) {
public WriteResult expand(PCollection<KV<TableDestination, TableRow>> input) {
Pipeline p = input.getPipeline();
BigQueryOptions options = p.getOptions().as(BigQueryOptions.class);
ValueProvider<TableReference> table = write.getTableWithDefaultProject(options);
@@ -80,7 +99,8 @@ public WriteResult expand(PCollection<T> input) {
e);
}

// Create a singleton job ID token at execution time.
// Create a singleton job ID token at execution time. This will be used as the base for all
// load jobs issued from this instance of the transform.
PCollection<String> singleton = p.apply("Create", Create.of(tempFilePrefix));
PCollectionView<String> jobIdTokenView = p
.apply("TriggerIdCreation", Create.of("ignored"))
@@ -93,67 +113,71 @@ public String apply(String input) {
}))
.apply(View.<String>asSingleton());

PCollection<T> typedInputInGlobalWindow =
PCollection<KV<TableDestination, TableRow>> inputInGlobalWindow =
input.apply(
Window.<T>into(new GlobalWindows())
Window.<KV<TableDestination, TableRow>>into(new GlobalWindows())
.triggering(DefaultTrigger.of())
.discardingFiredPanes());
// Avoid applying the formatFunction if it is the identity formatter.
PCollection<TableRow> inputInGlobalWindow;
if (write.getFormatFunction() == BigQueryIO.IDENTITY_FORMATTER) {
inputInGlobalWindow = (PCollection<TableRow>) typedInputInGlobalWindow;
} else {
inputInGlobalWindow =
typedInputInGlobalWindow.apply(
MapElements.into(new TypeDescriptor<TableRow>() {}).via(write.getFormatFunction()));
}

// PCollection of filename, file byte size.
PCollection<KV<String, Long>> results = inputInGlobalWindow
.apply("WriteBundles",
ParDo.of(new WriteBundles(tempFilePrefix)));
// PCollection of filename, file byte size, and table destination.
PCollection<WriteBundlesToFiles.Result> results = inputInGlobalWindow
.apply("WriteBundlesToFiles",
ParDo.of(new WriteBundlesToFiles(tempFilePrefix)));

TupleTag<KV<Long, List<String>>> multiPartitionsTag =
new TupleTag<KV<Long, List<String>>>("multiPartitionsTag") {};
TupleTag<KV<Long, List<String>>> singlePartitionTag =
new TupleTag<KV<Long, List<String>>>("singlePartitionTag") {};
TupleTag<KV<KV<TableDestination, Integer>, List<String>>> multiPartitionsTag =
new TupleTag<KV<KV<TableDestination, Integer>, List<String>>>("multiPartitionsTag") {};
TupleTag<KV<KV<TableDestination, Integer>, List<String>>> singlePartitionTag =
new TupleTag<KV<KV<TableDestination, Integer>, List<String>>>("singlePartitionTag") {};

// Turn the list of files, byte sizes, and destinations into a PCollectionView that can be
// used as a side input.
PCollectionView<Iterable<KV<String, Long>>> resultsView = results
.apply("ResultsView", View.<KV<String, Long>>asIterable());
PCollectionView<Iterable<WriteBundlesToFiles.Result>> resultsView = results
.apply("ResultsView", View.<WriteBundlesToFiles.Result>asIterable());
// This transform will look at the set of files written for each table, and if any table has
// too many files or bytes, will partition that table's files into multiple partitions for
// loading.
PCollectionTuple partitions = singleton.apply(ParDo
.of(new WritePartition(
write.getTable(),
write.getTableDescription(),
resultsView,
multiPartitionsTag,
singlePartitionTag))
.withSideInputs(resultsView)
.withOutputTags(multiPartitionsTag, TupleTagList.of(singlePartitionTag)));

// If WriteBundles produced more than MAX_NUM_FILES files or MAX_SIZE_BYTES bytes, then
// Since BigQueryIO.java does not yet have support for per-table schemas, inject a constant
// schema function here. If no schema is specified, this function will return null.
SerializableFunction<TableDestination, TableSchema> schemaFunction =
new ConstantSchemaFunction(write.getSchema());

// If WriteBundlesToFiles produced more than MAX_NUM_FILES files or MAX_SIZE_BYTES bytes, then
// the import needs to be split into multiple partitions, and those partitions will be
// specified in multiPartitionsTag.
PCollection<String> tempTables = partitions.get(multiPartitionsTag)
.apply("MultiPartitionsGroupByKey", GroupByKey.<Long, List<String>>create())
PCollection<KV<TableDestination, String>> tempTables = partitions.get(multiPartitionsTag)
// What's this GroupByKey for? Is this so we have deterministic temp tables? If so, maybe
// Reshuffle is better here.
.apply("MultiPartitionsGroupByKey",
GroupByKey.<KV<TableDestination, Integer>, List<String>>create())
.apply("MultiPartitionsWriteTables", ParDo.of(new WriteTables(
false,
write.getBigQueryServices(),
jobIdTokenView,
tempFilePrefix,
NestedValueProvider.of(table, new TableRefToJson()),
write.getJsonSchema(),
WriteDisposition.WRITE_EMPTY,
CreateDisposition.CREATE_IF_NEEDED,
write.getTableDescription()))
schemaFunction))
.withSideInputs(jobIdTokenView));

PCollectionView<Iterable<String>> tempTablesView = tempTables
.apply("TempTablesView", View.<String>asIterable());
// This view maps each final table destination to the set of temporary partitioned tables
// the PCollection was loaded into.
PCollectionView<Map<TableDestination, Iterable<String>>> tempTablesView = tempTables
.apply("TempTablesView", View.<TableDestination, String>asMultimap());

singleton.apply(ParDo
.of(new WriteRename(
write.getBigQueryServices(),
jobIdTokenView,
NestedValueProvider.of(table, new TableRefToJson()),
write.getWriteDisposition(),
write.getCreateDisposition(),
tempTablesView,
@@ -162,17 +186,16 @@ public String apply(String input) {

// Write single partition to final table
partitions.get(singlePartitionTag)
.apply("SinglePartitionGroupByKey", GroupByKey.<Long, List<String>>create())
.apply("SinglePartitionGroupByKey",
GroupByKey.<KV<TableDestination, Integer>, List<String>>create())
.apply("SinglePartitionWriteTables", ParDo.of(new WriteTables(
true,
write.getBigQueryServices(),
jobIdTokenView,
tempFilePrefix,
NestedValueProvider.of(table, new TableRefToJson()),
write.getJsonSchema(),
write.getWriteDisposition(),
write.getCreateDisposition(),
write.getTableDescription()))
schemaFunction))
.withSideInputs(jobIdTokenView));

return WriteResult.in(input.getPipeline());
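
The comments in the hunk above describe how WritePartition splits each table's set of written files into load-job partitions whenever a table accumulates too many files or too many bytes. The following is a minimal, self-contained sketch of that partitioning rule only; the FileResult type and the MAX_NUM_FILES / MAX_SIZE_BYTES values are illustrative stand-ins, not the transform's actual constants or code.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * A minimal sketch (not the actual WritePartition DoFn) of the partitioning rule described
 * above: for each destination table, the files written for that table are split into load-job
 * partitions whenever a partition would exceed a maximum file count or total byte size.
 */
class PartitionSketch {
  static final int MAX_NUM_FILES = 10_000;             // assumed limit, for illustration only
  static final long MAX_SIZE_BYTES = 11L * (1L << 40); // assumed limit, for illustration only

  /** Hypothetical stand-in for WriteBundlesToFiles.Result. */
  static class FileResult {
    final String filename;
    final long byteSize;

    FileResult(String filename, long byteSize) {
      this.filename = filename;
      this.byteSize = byteSize;
    }
  }

  /** Splits one table's files into partitions that respect both limits. */
  static List<List<String>> partitionFiles(List<FileResult> files) {
    List<List<String>> partitions = new ArrayList<>();
    List<String> current = new ArrayList<>();
    long currentBytes = 0;
    for (FileResult file : files) {
      // Start a new partition when adding this file would break either limit.
      if (!current.isEmpty()
          && (current.size() >= MAX_NUM_FILES || currentBytes + file.byteSize > MAX_SIZE_BYTES)) {
        partitions.add(current);
        current = new ArrayList<>();
        currentBytes = 0;
      }
      current.add(file.filename);
      currentBytes += file.byteSize;
    }
    if (!current.isEmpty()) {
      partitions.add(current);
    }
    return partitions;
  }
}
```

In the transform above, a table that ends up with a single partition is loaded directly into its final destination (singlePartitionTag), while a table with multiple partitions is first loaded into temporary tables (multiPartitionsTag, WriteTables) and then copied into the final table by WriteRename.
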
@@ -984,7 +984,8 @@ public WriteResult expand(PCollection<T> input) {
if (input.isBounded() == IsBounded.UNBOUNDED) {
return rowsWithDestination.apply(new StreamingInserts(this));
} else {
return input.apply(new BatchLoadBigQuery<T>(this));

return rowsWithDestination.apply(new BatchLoads<T>(this));
}
}

@@ -20,6 +20,8 @@

import com.google.api.services.bigquery.model.TableReference;

import java.util.Objects;

/**
* Encapsulates a BigQuery table destination.
*/
@@ -42,12 +44,25 @@ public String getTableSpec() {
return tableSpec;
}


public TableReference getTableReference() {
return BigQueryHelpers.parseTableSpec(tableSpec);
}

public String getTableDescription() {
return tableDescription;
}

@Override
public boolean equals(Object o) {
if (!(o instanceof TableDestination)) {
return false;
}
TableDestination other = (TableDestination) o;
return Objects.equals(tableSpec, other.tableSpec)
&& Objects.equals(tableDescription, other.tableDescription);
}

@Override
public int hashCode() {
return Objects.hash(tableSpec, tableDescription);
}
}
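
With value-based equals() and hashCode(), TableDestination can serve as a grouping key (the GroupByKey over KV<TableDestination, Integer> above) and as a HashMap key (the per-destination writer map in WriteBundlesToFiles below). Comparing the string fields with == would only test reference identity, which is why the equals() above uses Objects.equals. A minimal sketch of why that matters, using a simplified stand-in class rather than the real TableDestination:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * A minimal sketch showing why value-based equals()/hashCode() matter here: grouping by
 * destination and keeping a per-destination writer map both rely on two logically equal
 * destinations hashing and comparing as equal. Destination is a simplified stand-in.
 */
class DestinationKeySketch {
  static final class Destination {
    final String tableSpec;
    final String description;

    Destination(String tableSpec, String description) {
      this.tableSpec = tableSpec;
      this.description = description;
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof Destination)) {
        return false;
      }
      Destination other = (Destination) o;
      // Objects.equals compares values (and tolerates nulls); == would compare references.
      return Objects.equals(tableSpec, other.tableSpec)
          && Objects.equals(description, other.description);
    }

    @Override
    public int hashCode() {
      return Objects.hash(tableSpec, description);
    }
  }

  public static void main(String[] args) {
    Destination a = new Destination("project:dataset.events", "events table");
    Destination b = new Destination("project:dataset.events", "events table");
    Map<Destination, Integer> writersPerTable = new HashMap<>();
    writersPerTable.put(a, 1);
    // Because equality is value-based, b is the "same" key and the lookup succeeds.
    System.out.println(writersPerTable.containsKey(b)); // true
  }
}
```
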
@@ -48,6 +48,14 @@ class TableRowWriter {
protected String mimeType = MimeTypes.TEXT;
private CountingOutputStream out;

public class Result {
String filename;
long byteSize;
public Result(String filename, long byteSize) {
this.filename = filename;
this.byteSize = byteSize;
}
}
TableRowWriter(String basename) {
this.tempFilePrefix = basename;
}
@@ -77,8 +85,8 @@ public void write(TableRow value) throws Exception {
out.write(NEWLINE);
}

public final KV<String, Long> close() throws IOException {
public final Result close() throws IOException {
channel.close();
return KV.of(fileName, out.getCount());
return new Result(fileName, out.getCount());
}
}
@@ -19,8 +19,11 @@
package org.apache.beam.sdk.io.gcp.bigquery;

import com.google.api.services.bigquery.model.TableRow;

import java.util.Map;
import java.util.UUID;

import com.google.common.collect.Maps;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.KV;
@@ -31,25 +34,41 @@
* Writes each bundle of {@link TableRow} elements out to a separate file using
* {@link TableRowWriter}.
*/
class WriteBundles extends DoFn<TableRow, KV<String, Long>> {
private static final Logger LOG = LoggerFactory.getLogger(WriteBundles.class);
class WriteBundlesToFiles extends DoFn<KV<TableDestination, TableRow>, WriteBundlesToFiles.Result> {
private static final Logger LOG = LoggerFactory.getLogger(WriteBundlesToFiles.class);

private transient TableRowWriter writer = null;
// Map from table destination to a writer for that table.
private transient Map<TableDestination, TableRowWriter> writers;
private final String tempFilePrefix;

WriteBundles(String tempFilePrefix) {
public static class Result {
public String filename;
public Long fileByteSize;
public TableDestination tableDestination;

public Result(String filename, Long fileByteSize, TableDestination tableDestination) {
this.filename = filename;
this.fileByteSize = fileByteSize;
this.tableDestination = tableDestination;
}
}
WriteBundlesToFiles(String tempFilePrefix) {
this.tempFilePrefix = tempFilePrefix;
this.writers = Maps.newHashMap();
}

@ProcessElement
public void processElement(ProcessContext c) throws Exception {
TableRowWriter writer = writers.get(c.element().getKey());
if (writer == null) {
writer = new TableRowWriter(tempFilePrefix);
writer.open(UUID.randomUUID().toString());
writers.put(c.element().getKey(), writer);
LOG.debug("Done opening writer {}", writer);
}
try {
writer.write(c.element());
writer.write(c.element().getValue());
} catch (Exception e) {
// Discard write result and close the write.
try {
@@ -65,10 +84,11 @@ public void processElement(ProcessContext c) throws Exception {

@FinishBundle
public void finishBundle(Context c) throws Exception {
if (writer != null) {
c.output(writer.close());
writer = null;
for (Map.Entry<TableDestination, TableRowWriter> entry : writers.entrySet()) {
TableRowWriter.Result result = entry.getValue().close();
c.output(new Result(result.filename, result.byteSize, entry.getKey()));
}
writers.clear();
}

@Override
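
WriteBundlesToFiles now keeps one open TableRowWriter per destination and emits one Result per writer when the bundle finishes. Below is a minimal sketch of that per-key writer pattern, with a simplified in-memory Writer standing in for TableRowWriter and plain method calls standing in for @ProcessElement / @FinishBundle; it is an illustration of the pattern, not the DoFn itself.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A minimal sketch of the per-destination writer pattern used by WriteBundlesToFiles:
 * lazily open one writer per destination key the first time an element for that key is
 * seen, route each element to its key's writer, and close/emit all writers when the
 * bundle finishes. Writer and the String keys are simplified stand-ins, not Beam classes.
 */
class PerDestinationWriterSketch {
  static final class Writer {
    private final StringBuilder contents = new StringBuilder();

    void write(String row) {
      contents.append(row).append('\n');
    }

    long close() {
      return contents.length(); // stand-in for "bytes written to the temp file"
    }
  }

  private final Map<String, Writer> writers = new HashMap<>();

  /** Analogous to @ProcessElement: one writer per destination, created on first use. */
  void process(String destination, String row) {
    Writer writer = writers.get(destination);
    if (writer == null) {
      writer = new Writer();
      writers.put(destination, writer);
    }
    writer.write(row);
  }

  /** Analogous to @FinishBundle: close every writer and report (destination, size) results. */
  List<String> finishBundle() {
    List<String> results = new ArrayList<>();
    for (Map.Entry<String, Writer> entry : writers.entrySet()) {
      results.add(entry.getKey() + " -> " + entry.getValue().close() + " bytes");
    }
    writers.clear();
    return results;
  }
}
```

A real bundle also has to cope with a failed write by closing and discarding the partial file, which is what the try/catch around writer.write() in the diff above does.
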