Skip to content

Commit

Permalink
Merge pull request apache#4844: Add ExecutableStagePayload to simplif…
Browse files Browse the repository at this point in the history
…y runner stage reconstruction

[BEAM-3565]
  • Loading branch information
tgroh authored Mar 19, 2018
2 parents f22b1d2 + 21472cc commit 58e3b06
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 70 deletions.
17 changes: 17 additions & 0 deletions model/pipeline/src/main/proto/beam_runner_api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,23 @@ message PCollection {
DisplayData display_data = 5;
}

// The payload for an executable stage. This will eventually be passed to an SDK in the form of a
// ProcessBundleDescriptor.
message ExecutableStagePayload {

// The environment for this executable stage. ExecutableStage#toPTransform populates this from
// the stage's getEnvironment().
Environment environment = 1;

// Input PCollection id.
string input = 2;

// PTransform ids contained within this executable stage.
repeated string transforms = 3;

// Output PCollection ids.
repeated string outputs = 4;

}

// The payload for the primitive ParDo transform.
message ParDoPayload {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.apache.beam.runners.core.construction;

import static com.google.common.base.Preconditions.checkArgument;

import java.io.IOException;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.sdk.runners.AppliedPTransform;

/**
 * Utilities for converting {@link ExecutableStage}s to and from {@link RunnerApi} protocol buffers.
 *
 * <p>This class only hosts static helpers and is not meant to be instantiated.
 */
public class ExecutableStageTranslation {

  // Static utility holder: prevent accidental instantiation.
  private ExecutableStageTranslation() {}

  /**
   * Extracts an {@link ExecutableStagePayload} from the given transform.
   *
   * <p>The applied transform is first translated to its {@link RunnerApi.PTransform} form using a
   * fresh {@link SdkComponents}; the payload is then parsed from the resulting spec.
   *
   * @param appliedTransform the applied transform expected to translate to an executable stage
   * @return the payload parsed from the translated transform's spec
   * @throws IllegalArgumentException if the translated transform's URN is not
   *     {@link ExecutableStage#URN}
   * @throws IOException propagated from translating the transform or parsing its payload
   */
  public static ExecutableStagePayload getExecutableStagePayload(
      AppliedPTransform<?, ?, ?> appliedTransform) throws IOException {
    RunnerApi.PTransform transform =
        PTransformTranslation.toProto(appliedTransform, SdkComponents.create());
    String urn = transform.getSpec().getUrn();
    // Fail with a descriptive message so a mismatched transform is easy to diagnose.
    checkArgument(
        ExecutableStage.URN.equals(urn),
        "Expected a transform with URN %s, but got %s",
        ExecutableStage.URN,
        urn);
    return ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@

package org.apache.beam.runners.core.construction.graph;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;

import java.util.Collection;
import java.util.Optional;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
import org.apache.beam.runners.core.construction.Environments;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;

Expand Down Expand Up @@ -84,64 +81,72 @@ public interface ExecutableStage {
* follows:
*
* <ul>
* <li>The {@link PTransform#getSubtransformsList()} contains the result of {@link
* #getTransforms()}.
* <li>The {@link PTransform#getSubtransformsList()} contains no subtransforms. This ensures
* that executable stages are treated as primitive transforms.
* <li>The only {@link PCollection} in the {@link PTransform#getInputsMap()} is the result of
* {@link #getInputPCollection()}.
* <li>The output {@link PCollection PCollections} in the values of {@link
* PTransform#getOutputsMap()} are the {@link PCollectionNode PCollections} returned by
* {@link #getOutputPCollections()}.
* <li>The {@link FunctionSpec} contains an {@link ExecutableStagePayload} which has its input
* and output PCollections set to the same values as the outer PTransform itself. It further
* contains the environment and the set of transforms for this stage.
* </ul>
*
* <p>The executable stage can be reconstructed from the resulting {@link ExecutableStagePayload}
* and components alone via {@link #fromPayload(ExecutableStagePayload, Components)}.
*/
default PTransform toPTransform() {
// Collect this stage's environment, input id, transform ids and output ids into the payload.
ExecutableStagePayload.Builder payload = ExecutableStagePayload.newBuilder();

payload.setEnvironment(getEnvironment());

PCollectionNode input = getInputPCollection();
payload.setInput(input.getId());

for (PTransformNode transform : getTransforms()) {
payload.addTransforms(transform.getId());
}

for (PCollectionNode output : getOutputPCollections()) {
payload.addOutputs(output.getId());
}

// The spec carries the executable stage URN together with the serialized payload.
PTransform.Builder pt = PTransform.newBuilder();
pt.setSpec(FunctionSpec.newBuilder()
.setUrn(ExecutableStage.URN)
.setPayload(payload.build().toByteString())
.build());
// The stage's single input PCollection is registered under the fixed key "input".
pt.putInputs("input", getInputPCollection().getId());
// NOTE(review): the statements from here on appear to interleave superseded and current lines
// of a rendered diff (two output loops, a subtransform loop, a second setSpec) — confirm
// against the committed file before relying on them.
int i = 0;
for (PCollectionNode materializedPCollection : getOutputPCollections()) {
pt.putOutputs(String.format("materialized_%s", i), materializedPCollection.getId());
i++;
}
for (PTransformNode fusedTransform : getTransforms()) {
pt.addSubtransforms(fusedTransform.getId());
int outputIndex = 0;
for (PCollectionNode pcNode : getOutputPCollections()) {
// Register each output PCollection under a positional "materialized_<index>" key.
pt.putOutputs(String.format("materialized_%d", outputIndex), pcNode.getId());
outputIndex++;
}
pt.setSpec(FunctionSpec.newBuilder().setUrn(ExecutableStage.URN));
return pt.build();
}

// TODO: Should this live under ExecutableStageTranslation?
/**
* Return an {@link ExecutableStage} constructed from the provided {@link PTransform}
* Return an {@link ExecutableStage} constructed from the provided {@link ExecutableStagePayload}
* representation.
*
* <p>See {@link #toPTransform()} for information about the required format of the {@link
* PTransform}. The environment will be determined by an arbitrary {@link PTransform} contained
* within the {@link PTransform#getSubtransformsList()}.
* <p>See {@link #toPTransform()} for how the payload is constructed. Note that the payload
* contains some information redundant with the {@link PTransform} due to runner implementations
* not having the full transform context at translation time, but rather access to an
* {@link org.apache.beam.sdk.runners.AppliedPTransform}.
*/
static ExecutableStage fromPTransform(PTransform ptransform, Components components) {
checkArgument(ptransform.getSpec().getUrn().equals(URN));
// It may be better to put this in an explicit Payload if other metadata becomes required
Optional<Environment> environment =
Environments.getEnvironment(ptransform.getSubtransforms(0), components);
checkArgument(
environment.isPresent(),
"%s with no %s",
ExecutableStage.class.getSimpleName(),
Environment.class.getSimpleName());
String inputId = getOnlyElement(ptransform.getInputsMap().values());
PCollectionNode inputNode =
PipelineNode.pCollection(inputId, components.getPcollectionsOrThrow(inputId));
Collection<PCollectionNode> outputNodes =
ptransform
.getOutputsMap()
.values()
.stream()
.map(id -> PipelineNode.pCollection(id, components.getPcollectionsOrThrow(id)))
.collect(Collectors.toSet());
Collection<PTransformNode> transformNodes =
ptransform
.getSubtransformsList()
.stream()
.map(id -> PipelineNode.pTransform(id, components.getTransformsOrThrow(id)))
.collect(Collectors.toSet());
return ImmutableExecutableStage.of(environment.get(), inputNode, transformNodes, outputNodes);
static ExecutableStage fromPayload(ExecutableStagePayload payload, Components components) {
  // Resolve the stage input: the payload stores only the id, the components hold the definition.
  String inputId = payload.getInput();
  PCollectionNode inputNode =
      PipelineNode.pCollection(inputId, components.getPcollectionsOrThrow(inputId));

  // Resolve every fused transform id, keeping the order recorded in the payload.
  List<PTransformNode> transformNodes =
      payload
          .getTransformsList()
          .stream()
          .map(
              transformId ->
                  PipelineNode.pTransform(
                      transformId, components.getTransformsOrThrow(transformId)))
          .collect(Collectors.toList());

  // Resolve every output PCollection id the same way.
  List<PCollectionNode> outputNodes =
      payload
          .getOutputsList()
          .stream()
          .map(
              outputId ->
                  PipelineNode.pCollection(
                      outputId, components.getPcollectionsOrThrow(outputId)))
          .collect(Collectors.toList());

  return ImmutableExecutableStage.of(
      payload.getEnvironment(), inputNode, transformNodes, outputNodes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Collections;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
Expand All @@ -46,7 +47,7 @@
@RunWith(JUnit4.class)
public class ExecutableStageTest {
@Test
public void testRoundTripToFromTransform() {
public void testRoundTripToFromTransform() throws Exception {
Environment env = Environment.newBuilder().setUrl("foo").build();
PTransform pt =
PTransform.newBuilder()
Expand Down Expand Up @@ -84,13 +85,15 @@ public void testRoundTripToFromTransform() {
assertThat(stagePTransform.getOutputsCount(), equalTo(1));
assertThat(stagePTransform.getInputsMap(), hasValue("input.out"));
assertThat(stagePTransform.getInputsCount(), equalTo(1));
assertThat(stagePTransform.getSubtransformsList(), contains("pt"));

assertThat(ExecutableStage.fromPTransform(stagePTransform, components), equalTo(stage));
ExecutableStagePayload payload = ExecutableStagePayload.parseFrom(
stagePTransform.getSpec().getPayload());
assertThat(payload.getTransformsList(), contains("pt"));
assertThat(ExecutableStage.fromPayload(payload, components), equalTo(stage));
}

@Test
public void testRoundTripToFromTransformFused() {
public void testRoundTripToFromTransformFused() throws Exception {
PTransform parDoTransform =
PTransform.newBuilder()
.putInputs("input", "impulse.out")
Expand Down Expand Up @@ -148,9 +151,11 @@ public void testRoundTripToFromTransformFused() {
assertThat(ptransform.getSpec().getUrn(), equalTo(ExecutableStage.URN));
assertThat(ptransform.getInputsMap().values(), containsInAnyOrder("impulse.out"));
assertThat(ptransform.getOutputsMap().values(), emptyIterable());
assertThat(ptransform.getSubtransformsList(), contains("parDo", "window"));

ExecutableStage desered = ExecutableStage.fromPTransform(ptransform, components);
ExecutableStagePayload payload = ExecutableStagePayload.parseFrom(
ptransform.getSpec().getPayload());
assertThat(payload.getTransformsList(), contains("parDo", "window"));
ExecutableStage desered = ExecutableStage.fromPayload(payload, components);
assertThat(desered, equalTo(subgraph));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import static com.google.common.collect.Iterables.getOnlyElement;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
Expand All @@ -30,6 +29,7 @@
import com.google.common.collect.ImmutableSet;
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
Expand All @@ -44,6 +44,8 @@
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
import org.hamcrest.Description;
import org.hamcrest.TypeSafeMatcher;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -239,8 +241,7 @@ public void fusesCompatibleEnvironments() {
PipelineNode.pTransform("window", windowTransform)));
// Nothing consumes the outputs of ParDo or Window, so they don't have to be materialized
assertThat(subgraph.getOutputPCollections(), emptyIterable());
assertThat(
subgraph.toPTransform().getSubtransformsList(), containsInAnyOrder("parDo", "window"));
assertThat(subgraph, hasSubtransforms("parDo", "window"));
}

@Test
Expand Down Expand Up @@ -299,8 +300,7 @@ public void materializesWithStatefulConsumer() {
contains(
PipelineNode.pCollection(
"parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build())));
assertThat(
subgraph.toPTransform().getSubtransformsList(), containsInAnyOrder("parDo"));
assertThat(subgraph, hasSubtransforms("parDo"));
}

@Test
Expand Down Expand Up @@ -359,8 +359,7 @@ public void materializesWithConsumerWithTimer() {
contains(
PipelineNode.pCollection(
"parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build())));
assertThat(
subgraph.toPTransform().getSubtransformsList(), containsInAnyOrder("parDo"));
assertThat(subgraph, hasSubtransforms("parDo"));
}

@Test
Expand Down Expand Up @@ -440,9 +439,7 @@ public void fusesFlatten() {
GreedyStageFuser.forGrpcPortRead(
p, impulseOutputNode, p.getPerElementConsumers(impulseOutputNode));
assertThat(subgraph.getOutputPCollections(), emptyIterable());
assertThat(
subgraph.toPTransform().getSubtransformsList(),
containsInAnyOrder("read", "parDo", "flatten", "window"));
assertThat(subgraph, hasSubtransforms("read", "parDo", "flatten", "window"));
}

@Test
Expand Down Expand Up @@ -524,9 +521,7 @@ public void fusesFlattenWithDifferentEnvironmentInputs() {
GreedyStageFuser.forGrpcPortRead(
p, impulseOutputNode, ImmutableSet.of(PipelineNode.pTransform("read", readTransform)));
assertThat(subgraph.getOutputPCollections(), emptyIterable());
assertThat(
subgraph.toPTransform().getSubtransformsList(),
containsInAnyOrder("read", "flatten", "window"));
assertThat(subgraph, hasSubtransforms("read", "flatten", "window"));

// Flatten shows up in both of these subgraphs, but elements only go through a path to the
// flatten once.
Expand All @@ -540,9 +535,7 @@ public void fusesFlattenWithDifferentEnvironmentInputs() {
contains(
PipelineNode.pCollection(
"flatten.out", components.getPcollectionsOrThrow("flatten.out"))));
assertThat(
readFromOtherEnv.toPTransform().getSubtransformsList(),
containsInAnyOrder("envRead", "flatten"));
assertThat(readFromOtherEnv, hasSubtransforms("envRead", "flatten"));
}

@Test
Expand Down Expand Up @@ -892,7 +885,7 @@ public void materializesWithSideInputConsumer() {
GreedyStageFuser.forGrpcPortRead(
p, impulseOutputNode, ImmutableSet.of(readNode));
assertThat(subgraph.getOutputPCollections(), contains(readOutput));
assertThat(subgraph.toPTransform().getSubtransformsList(), contains(readNode.getId()));
assertThat(subgraph, hasSubtransforms(readNode.getId()));
}

@Test
Expand Down Expand Up @@ -943,6 +936,28 @@ public void materializesWithGroupByKeyConsumer() {
GreedyStageFuser.forGrpcPortRead(
p, impulseOutputNode, ImmutableSet.of(readNode));
assertThat(subgraph.getOutputPCollections(), contains(readOutput));
assertThat(subgraph.toPTransform().getSubtransformsList(), contains(readNode.getId()));
assertThat(subgraph, hasSubtransforms(readNode.getId()));
}

/**
 * Returns a matcher accepting an {@link ExecutableStage} whose transform ids, taken as a set,
 * are exactly the given ids. Ordering is not considered.
 */
private static TypeSafeMatcher<ExecutableStage> hasSubtransforms(String id, String... ids) {
  final Set<String> wantedIds = ImmutableSet.<String>builder().add(id).add(ids).build();
  return new TypeSafeMatcher<ExecutableStage>() {
    @Override
    protected boolean matchesSafely(ExecutableStage stage) {
      // Ids are gathered into a set, so the comparison ignores ordering and repeated ids.
      Set<String> actualIds =
          stage
              .getTransforms()
              .stream()
              .map(node -> node.getId())
              .collect(Collectors.toSet());
      // Set equality is the same check as containment in both directions.
      return actualIds.equals(wantedIds);
    }

    @Override
    public void describeTo(Description description) {
      String text = "ExecutableStage with subtransform ids: " + wantedIds;
      description.appendText(text);
    }
  };
}

}

0 comments on commit 58e3b06

Please sign in to comment.