diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py index 8f12779d67f8..b3d6e027b579 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py @@ -557,8 +557,14 @@ def add_parent(child, parent): pipeline_proto.components.transforms[parent]) copy_output_pcollections(components.transforms[parent]) del components.transforms[parent].subtransforms[:] - add_parent(parent, parents.get(parent)) + # Ensure that child is the last item in the parent's subtransforms. + # If the stages were previously sorted into topological order using + # sort_stages, this ensures that the parent transforms are also + # added in topological order. + if child in components.transforms[parent].subtransforms: + components.transforms[parent].subtransforms.remove(child) components.transforms[parent].subtransforms.append(child) + add_parent(parent, parents.get(parent)) def copy_subtransforms(transform): for subtransform_id in transform.subtransforms: diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py index 0ce11bda45c4..8eb79609367d 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py @@ -186,6 +186,37 @@ def test_optimize_multiple_combine_globally(self): beam.Pipeline.from_runner_api( optimized_pipeline_proto, runner, pipeline_options.PipelineOptions()) + def test_pipeline_from_sorted_stages_is_toplogically_ordered(self): + pipeline = beam.Pipeline() + side = pipeline | 'side' >> Create([3, 4]) + + class CreateAndMultiplyBySide(beam.PTransform): + def expand(self, pcoll): + return ( + pcoll | 'main' >> Create([1, 2]) | 'compute' >> beam.FlatMap( + lambda x, s: [x * y for y in s], beam.pvalue.AsIter(side))) + + _ = pipeline | 'create-and-multiply-by-side' >> CreateAndMultiplyBySide() + pipeline_proto = pipeline.to_runner_api() + optimized_pipeline_proto = translations.optimize_pipeline( + pipeline_proto, [ + (lambda stages, _: reversed(list(stages))), + translations.sort_stages, + ], + known_runner_urns=frozenset(), + partial=True) + + def assert_is_topologically_sorted(transform_id, visited_pcolls): + transform = optimized_pipeline_proto.components.transforms[transform_id] + self.assertTrue(set(transform.inputs.values()).issubset(visited_pcolls)) + visited_pcolls.update(transform.outputs.values()) + for subtransform in transform.subtransforms: + assert_is_topologically_sorted(subtransform, visited_pcolls) + + self.assertEqual(len(optimized_pipeline_proto.root_transform_ids), 1) + assert_is_topologically_sorted( + optimized_pipeline_proto.root_transform_ids[0], set()) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)