Skip to content

Commit

Permalink
[BEAM-14251] add output_coder_override to ExpansionRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
ihji committed Apr 13, 2022
1 parent dffa7c1 commit e4cfeec
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ message ExpansionRequest {
// A namespace (prefix) to use for the id of any newly created
// components.
string namespace = 3;

// (Optional) Map from a local output tag to a coder id.
// If it is set, asks the expansion service to use the given
// coders for the output PCollections. Note that the request
// may not be fulfilled.
map<string, string> output_coder_override = 4;
}

message ExpansionResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ SingleOutputExpandableTransform<InputT, OutputT> of(
Endpoints.ApiServiceDescriptor apiDesc =
Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build();
return new SingleOutputExpandableTransform<>(
urn, payload, apiDesc, DEFAULT, getFreshNamespaceIndex());
urn, payload, apiDesc, DEFAULT, getFreshNamespaceIndex(), null);
}

@VisibleForTesting
Expand All @@ -103,7 +103,7 @@ static <InputT extends PInput, OutputT> SingleOutputExpandableTransform<InputT,
Endpoints.ApiServiceDescriptor apiDesc =
Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build();
return new SingleOutputExpandableTransform<>(
urn, payload, apiDesc, clientFactory, getFreshNamespaceIndex());
urn, payload, apiDesc, clientFactory, getFreshNamespaceIndex(), null);
}

/** Expandable transform for output type of PCollection. */
Expand All @@ -114,8 +114,9 @@ public static class SingleOutputExpandableTransform<InputT extends PInput, Outpu
byte[] payload,
Endpoints.ApiServiceDescriptor endpoint,
ExpansionServiceClientFactory clientFactory,
Integer namespaceIndex) {
super(urn, payload, endpoint, clientFactory, namespaceIndex);
Integer namespaceIndex,
Map<String, Coder> outputCoders) {
super(urn, payload, endpoint, clientFactory, namespaceIndex, outputCoders);
}

@Override
Expand All @@ -126,12 +127,33 @@ PCollection<OutputT> toOutputCollection(Map<TupleTag<?>, PCollection> output) {

public MultiOutputExpandableTransform<InputT> withMultiOutputs() {
return new MultiOutputExpandableTransform<>(
getUrn(), getPayload(), getEndpoint(), getClientFactory(), getNamespaceIndex());
getUrn(),
getPayload(),
getEndpoint(),
getClientFactory(),
getNamespaceIndex(),
getOutputCoders());
}

public <T> SingleOutputExpandableTransform<InputT, T> withOutputType() {
public SingleOutputExpandableTransform<InputT, OutputT> withOutputCoder(Coder outputCoder) {
return new SingleOutputExpandableTransform<>(
getUrn(), getPayload(), getEndpoint(), getClientFactory(), getNamespaceIndex());
getUrn(),
getPayload(),
getEndpoint(),
getClientFactory(),
getNamespaceIndex(),
ImmutableMap.of("0", outputCoder));
}

/**
 * Returns a new {@link SingleOutputExpandableTransform} that carries the given output coder map.
 *
 * <p>The URN, payload, endpoint, client factory and namespace index are read from this
 * transform's getters and passed through unchanged; only the coder map differs.
 *
 * @param outputCoders coders keyed by string; the map is handed to the constructor as-is (no
 *     copy is made here)
 * @return a newly constructed transform
 */
public SingleOutputExpandableTransform<InputT, OutputT> withOutputCoder(
Map<String, Coder> outputCoders) {
return new SingleOutputExpandableTransform<>(
getUrn(),
getPayload(),
getEndpoint(),
getClientFactory(),
getNamespaceIndex(),
outputCoders);
}
}

Expand All @@ -143,8 +165,9 @@ public static class MultiOutputExpandableTransform<InputT extends PInput>
byte[] payload,
Endpoints.ApiServiceDescriptor endpoint,
ExpansionServiceClientFactory clientFactory,
Integer namespaceIndex) {
super(urn, payload, endpoint, clientFactory, namespaceIndex);
Integer namespaceIndex,
Map<String, Coder> outputCoders) {
super(urn, payload, endpoint, clientFactory, namespaceIndex, outputCoders);
}

@Override
Expand All @@ -167,6 +190,7 @@ public abstract static class ExpandableTransform<InputT extends PInput, OutputT
private final Endpoints.ApiServiceDescriptor endpoint;
private final ExpansionServiceClientFactory clientFactory;
private final Integer namespaceIndex;
private final Map<String, Coder> outputCoders;

private transient RunnerApi.@Nullable Components expandedComponents;
private transient RunnerApi.@Nullable PTransform expandedTransform;
Expand All @@ -179,12 +203,14 @@ public abstract static class ExpandableTransform<InputT extends PInput, OutputT
byte[] payload,
Endpoints.ApiServiceDescriptor endpoint,
ExpansionServiceClientFactory clientFactory,
Integer namespaceIndex) {
Integer namespaceIndex,
Map<String, Coder> outputCoders) {
this.urn = urn;
this.payload = payload;
this.endpoint = endpoint;
this.clientFactory = clientFactory;
this.namespaceIndex = namespaceIndex;
this.outputCoders = outputCoders;
}

@Override
Expand Down Expand Up @@ -225,9 +251,25 @@ public OutputT expand(InputT input) {
}
}

// Start from an empty request builder; an output coder override is attached only when the
// transform actually carries one.
ExpansionApi.ExpansionRequest.Builder requestBuilder =
    ExpansionApi.ExpansionRequest.newBuilder();
// The coder map may be null (the factory methods construct the transform with a null map), so
// check for null before calling isEmpty() to avoid a NullPointerException on the default path.
if (outputCoders != null && !outputCoders.isEmpty()) {
  // Register every requested coder with the pipeline components and put the id returned by
  // registerCoder into the request, keyed by the same key as in the coder map.
  requestBuilder.putAllOutputCoderOverride(
      outputCoders.entrySet().stream()
          .collect(
              Collectors.toMap(
                  Map.Entry::getKey,
                  entry -> {
                    try {
                      return components.registerCoder(entry.getValue());
                    } catch (IOException e) {
                      // A lambda cannot propagate the checked IOException, so rethrow it
                      // unchecked with the original exception preserved as the cause.
                      throw new RuntimeException(e);
                    }
                  })));
}
RunnerApi.Components originalComponents = components.toComponents();
ExpansionApi.ExpansionRequest request =
ExpansionApi.ExpansionRequest.newBuilder()
requestBuilder
.setComponents(originalComponents)
.setTransform(ptransformBuilder.build())
.setNamespace(getNamespace())
Expand Down Expand Up @@ -434,5 +476,9 @@ ExpansionServiceClientFactory getClientFactory() {
/** Returns the value of the {@code namespaceIndex} field of this transform. */
Integer getNamespaceIndex() {
return namespaceIndex;
}

/**
 * Returns the value of the {@code outputCoders} field of this transform.
 *
 * <p>The field is returned directly, without copying. NOTE(review): it can be {@code null} when
 * the transform was constructed without a coder map — callers should confirm before
 * dereferencing.
 */
Map<String, Coder> getOutputCoders() {
return outputCoders;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def with_pipeline(component, pcoll_id=None):
}
transform = with_pipeline(
ptransform.PTransform.from_runner_api(request.transform, context))
if len(request.output_coder_override) == 1:
output_coder = {
k: context.element_type_from_coder_id(v)
for k,
v in request.output_coder_override.items()
}
transform = transform.with_output_types(list(output_coder.values())[0])
inputs = transform._pvaluish_from_dict({
tag:
with_pipeline(context.pcollections.get_by_id(pcoll_id), pcoll_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,29 @@ def from_runner_api_parameter(unused_ptransform, payload, unused_context):
return PayloadTransform(payload.decode('ascii'))


@ptransform.PTransform.register_urn('map_to_union_types', None)
class MapToUnionTypesTransform(ptransform.PTransform):
  """Transform whose DoFn emits a str, an int or a float per input element."""
  class CustomDoFn(beam.DoFn):
    """Emits '1' for 1, 2 for 2, and 3.0 for any other element."""
    def process(self, element):
      # Guard clauses, checked in the same order: 1 first, then 2.
      if element == 1:
        return ['1']
      if element == 2:
        return [2]
      return [3.0]

  def expand(self, pcoll):
    # Apply the nested DoFn to the input collection.
    custom_fn = self.CustomDoFn()
    return pcoll | beam.ParDo(custom_fn)

  def to_runner_api_parameter(self, unused_context):
    # The second element of the returned pair is None: no payload is produced.
    return b'map_to_union_types', None

  @staticmethod
  def from_runner_api_parameter(
      unused_ptransform, unused_payload, unused_context):
    # None of the arguments are used; a fresh instance is always returned.
    return MapToUnionTypesTransform()


@ptransform.PTransform.register_urn('fib', bytes)
class FibTransform(ptransform.PTransform):
def __init__(self, level):
Expand Down
19 changes: 18 additions & 1 deletion sdks/python/apache_beam/transforms/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import artifact_service
from apache_beam.transforms import ptransform
from apache_beam.typehints import WithTypeHints
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import row_type
from apache_beam.typehints.schemas import named_fields_to_schema
Expand Down Expand Up @@ -433,6 +434,9 @@ def __init__(self, urn, payload, expansion_service=None):
self._inputs = {} # type: Dict[str, pvalue.PCollection]
self._outputs = {} # type: Dict[str, pvalue.PCollection]

def with_output_types(self, *args, **kwargs):
  """Delegates to WithTypeHints.with_output_types and returns its result.

  All positional and keyword arguments are forwarded unchanged, with this
  transform passed explicitly as the receiver.
  """
  # NOTE(review): presumably the explicit WithTypeHints call is meant to skip
  # an override on an intermediate base class -- confirm.
  return WithTypeHints.with_output_types(self, *args, **kwargs)

def replace_named_inputs(self, named_inputs):
self._inputs = named_inputs

Expand Down Expand Up @@ -498,11 +502,24 @@ def expand(self, pvalueish):
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.primitives.IMPULSE.urn),
outputs={'out': transform_proto.inputs[tag]}))
output_coder = None
if self._type_hints.output_types:
if self._type_hints.output_types[0]:
output_coder = dict((str(k), context.coder_id_from_element_type(v))
for k,
v in enumerate(self._type_hints.output_types[0]))
elif self._type_hints.output_types[1]:
output_coder = {
k: context.coder_id_from_element_type(v)
for k,
v in self._type_hints.output_types[1].items()
}
components = context.to_runner_api()
request = beam_expansion_api_pb2.ExpansionRequest(
components=components,
namespace=self._external_namespace, # type: ignore # mypy thinks self._namespace is threading.local
transform=transform_proto)
transform=transform_proto,
output_coder_override=output_coder)

with self._service() as service:
response = service.Expand(request)
Expand Down
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/transforms/external_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,36 @@ def test_payload(self):
'payload', b's', expansion_service.ExpansionServiceServicer()))
assert_that(res, equal_to(['as', 'bbs']))

def test_output_coder(self):
  """With an int output type hint, the expanded output has element type int."""
  servicer = expansion_service.ExpansionServiceServicer()
  external_transform = beam.ExternalTransform(
      'map_to_union_types', None, servicer).with_output_types(int)
  with beam.Pipeline() as p:
    res = p | beam.Create([2, 2], reshuffle=False) | external_transform
    assert_that(res, equal_to([2, 2]))
  # Inspect what the expansion produced: exactly one output, typed as int.
  expanded = external_transform._expanded_transform
  context = pipeline_context.PipelineContext(
      external_transform._expanded_components)
  self.assertEqual(len(expanded.outputs), 1)
  for pcol_id in expanded.outputs.values():
    pcol = context.pcollections.get_by_id(pcol_id)
    self.assertEqual(pcol.element_type, int)

def test_no_output_coder(self):
  """Without an output type hint, the expanded output has element type Any."""
  servicer = expansion_service.ExpansionServiceServicer()
  external_transform = beam.ExternalTransform(
      'map_to_union_types', None, servicer)
  with beam.Pipeline() as p:
    res = p | beam.Create([2, 2], reshuffle=False) | external_transform
    assert_that(res, equal_to([2, 2]))
  # Inspect what the expansion produced: exactly one output, typed as Any.
  expanded = external_transform._expanded_transform
  context = pipeline_context.PipelineContext(
      external_transform._expanded_components)
  self.assertEqual(len(expanded.outputs), 1)
  for pcol_id in expanded.outputs.values():
    pcol = context.pcollections.get_by_id(pcol_id)
    self.assertEqual(pcol.element_type, typehints.Any)

def test_nested(self):
with beam.Pipeline() as p:
assert_that(p | FibTransform(6), equal_to([8]))
Expand Down

0 comments on commit e4cfeec

Please sign in to comment.