Skip to content

Commit

Permalink
[BEAM-11355] Fix topological ordering of pipeline_from_stages after s…
Browse files Browse the repository at this point in the history
…ort_stages (apache#13432)
  • Loading branch information
Yifan Mai authored Dec 3, 2020
1 parent bcea54f commit 8ab1955
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,14 @@ def add_parent(child, parent):
pipeline_proto.components.transforms[parent])
copy_output_pcollections(components.transforms[parent])
del components.transforms[parent].subtransforms[:]
add_parent(parent, parents.get(parent))
# Ensure that child is the last item in the parent's subtransforms.
# If the stages were previously sorted into topological order using
# sort_stages, this ensures that the parent transforms are also
# added in topological order.
if child in components.transforms[parent].subtransforms:
components.transforms[parent].subtransforms.remove(child)
components.transforms[parent].subtransforms.append(child)
add_parent(parent, parents.get(parent))

def copy_subtransforms(transform):
for subtransform_id in transform.subtransforms:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,37 @@ def test_optimize_multiple_combine_globally(self):
beam.Pipeline.from_runner_api(
optimized_pipeline_proto, runner, pipeline_options.PipelineOptions())

def test_pipeline_from_sorted_stages_is_toplogically_ordered(self):
pipeline = beam.Pipeline()
side = pipeline | 'side' >> Create([3, 4])

class CreateAndMultiplyBySide(beam.PTransform):
def expand(self, pcoll):
return (
pcoll | 'main' >> Create([1, 2]) | 'compute' >> beam.FlatMap(
lambda x, s: [x * y for y in s], beam.pvalue.AsIter(side)))

_ = pipeline | 'create-and-multiply-by-side' >> CreateAndMultiplyBySide()
pipeline_proto = pipeline.to_runner_api()
optimized_pipeline_proto = translations.optimize_pipeline(
pipeline_proto, [
(lambda stages, _: reversed(list(stages))),
translations.sort_stages,
],
known_runner_urns=frozenset(),
partial=True)

def assert_is_topologically_sorted(transform_id, visited_pcolls):
transform = optimized_pipeline_proto.components.transforms[transform_id]
self.assertTrue(set(transform.inputs.values()).issubset(visited_pcolls))
visited_pcolls.update(transform.outputs.values())
for subtransform in transform.subtransforms:
assert_is_topologically_sorted(subtransform, visited_pcolls)

self.assertEqual(len(optimized_pipeline_proto.root_transform_ids), 1)
assert_is_topologically_sorted(
optimized_pipeline_proto.root_transform_ids[0], set())


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down

0 comments on commit 8ab1955

Please sign in to comment.