From 34faad3e967184d68e135eacc8b25d0c03041af0 Mon Sep 17 00:00:00 2001
From: Ruslan Kuprieiev <ruslan@iterative.ai>
Date: Wed, 19 Sep 2018 09:24:38 +0300
Subject: [PATCH] repro: introduce -p|--pipeline option

Fixes #1107

Signed-off-by: Ruslan Kuprieiev <ruslan@iterative.ai>
---
 dvc/cli.py           |  7 +++++
 dvc/command/repro.py |  3 +-
 dvc/project.py       | 74 ++++++++++++++++++++++++++++++++------------
 tests/test_repro.py  | 12 +++++++
 4 files changed, 75 insertions(+), 21 deletions(-)

diff --git a/dvc/cli.py b/dvc/cli.py
index a00d76b22a..46d3c90f67 100644
--- a/dvc/cli.py
+++ b/dvc/cli.py
@@ -396,6 +396,13 @@ def parse_args(argv=None):
                         default=False,
                         help='Ask for confirmation before reproducing each '
                              'stage.')
+    repro_parser.add_argument(
+                        '-p',
+                        '--pipeline',
+                        action='store_true',
+                        default=False,
+                        help='Reproduce the whole pipeline that the '
+                             'specified stage file belongs to.')
     repro_parser.set_defaults(func=CmdRepro)
 
     # Remove
diff --git a/dvc/command/repro.py b/dvc/command/repro.py
index 284174b6db..7efa6c5ef9 100644
--- a/dvc/command/repro.py
+++ b/dvc/command/repro.py
@@ -20,7 +20,8 @@ def run(self):
                                        recursive=recursive,
                                        force=self.args.force,
                                        dry=self.args.dry,
-                                       interactive=self.args.interactive)
+                                       interactive=self.args.interactive,
+                                       pipeline=self.args.pipeline)
 
                 if len(stages) == 0:
                     self.project.logger.info(CmdDataStatus.UP_TO_DATE_MSG)
diff --git a/dvc/project.py b/dvc/project.py
index 8367061a10..7a0a8a0bfd 100644
--- a/dvc/project.py
+++ b/dvc/project.py
@@ -311,39 +311,73 @@ def reproduce(self,
                   recursive=True,
                   force=False,
                   dry=False,
-                  interactive=False):
-        import networkx as nx
-
-        stage = Stage.load(self, target)
-        G = self.graph()[1]
-        stages = nx.get_node_attributes(G, 'stage')
-        node = os.path.relpath(stage.path, self.root_dir)
+                  interactive=False,
+                  pipeline=False):
 
         if not interactive:
             config = self.config
             core = config._config[config.SECTION_CORE]
             interactive = core.get(config.SECTION_CORE_INTERACTIVE, False)
 
+        targets = []
+        if pipeline:
+            stage = Stage.load(self, target)
+            node = os.path.relpath(stage.path, self.root_dir)
+            pipelines = list(filter(lambda g: node in g.nodes(),
+                                    self.pipelines()))
+            assert len(pipelines) == 1
+            G = pipelines[0]
+            for node in G.nodes():
+                if G.in_degree(node) == 0:
+                    targets.append(os.path.join(self.root_dir, node))
+        else:
+            targets.append(target)
+
         self._files_to_git_add = []
+
+        ret = []
         with self.state:
-            if recursive:
-                ret = self._reproduce_stages(G,
-                                             stages,
-                                             node,
-                                             force,
-                                             dry,
-                                             interactive)
-            else:
-                ret = self._reproduce_stage(stages,
-                                            node,
-                                            force,
-                                            dry,
-                                            interactive)
+            for target in targets:
+                stages = self._reproduce(target,
+                                         recursive=recursive,
+                                         force=force,
+                                         dry=dry,
+                                         interactive=interactive)
+                ret.extend(stages)
 
         self._remind_to_git_add()
 
         return ret
 
+    def _reproduce(self,
+                   target,
+                   recursive=True,
+                   force=False,
+                   dry=False,
+                   interactive=False):
+        import networkx as nx
+
+        stage = Stage.load(self, target)
+        G = self.graph()[1]
+        stages = nx.get_node_attributes(G, 'stage')
+        node = os.path.relpath(stage.path, self.root_dir)
+
+        if recursive:
+            ret = self._reproduce_stages(G,
+                                         stages,
+                                         node,
+                                         force,
+                                         dry,
+                                         interactive)
+        else:
+            ret = self._reproduce_stage(stages,
+                                        node,
+                                        force,
+                                        dry,
+                                        interactive)
+
+        return ret
+
     def _reproduce_stages(self, G, stages, node, force, dry, interactive):
         import networkx as nx
 
diff --git a/tests/test_repro.py b/tests/test_repro.py
index 46450c1a99..75d2e026c3 100644
--- a/tests/test_repro.py
+++ b/tests/test_repro.py
@@ -205,6 +205,18 @@ def test(self):
         self.assertEqual(len(stages), 3)
 
 
+class TestReproPipeline(TestReproChangedDeepData):
+    def test(self):
+        stages = self.dvc.reproduce(self.file1_stage,
+                                    force=True,
+                                    pipeline=True)
+        self.assertEqual(len(stages), 3)
+
+    def test_cli(self):
+        ret = main(['repro', '--pipeline', '-f', self.file1_stage])
+        self.assertEqual(ret, 0)
+
+
 class TestReproLocked(TestReproChangedData):
     def test(self):
         file2 = 'file2'