Skip to content

Commit

Permalink
exp push/pull: add --all-commits flag and unify the collection of r…
Browse files Browse the repository at this point in the history
…evs (iterative#7154)

fix: iterative#7154
1. rename flags `-A/--all-commits` in the `exp pull/push`
2. add new flag "--rev" and "--num" in the `exp pull/push`
3. Unify the collection of revs in `exp push/pull`
4. add unit and func tests for `exp push/pull`
  • Loading branch information
karajan1001 committed Mar 9, 2022
1 parent d033345 commit 244637e
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 67 deletions.
33 changes: 27 additions & 6 deletions dvc/commands/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,45 @@

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsPull(CmdBase):
def raise_error_if_all_disabled(self):
if not any(
[self.args.experiment, self.args.all_commits, self.args.rev]
):
raise InvalidArgumentError(
"Either provide an `experiment` argument, or use the "
"`--rev` or `--all-commits` flag."
)

def run(self):
self.raise_error_if_all_disabled()

self.repo.experiments.pull(
pulled_exps = self.repo.experiments.pull(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
pull_cache=self.args.pull_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)

ui.write(
f"Pulled experiment '{self.args.experiment}'",
f"from Git remote '{self.args.git_remote}'.",
)
if pulled_exps:
ui.write(
f"Pulled experiment '{pulled_exps}'",
f"from Git remote '{self.args.git_remote}'.",
)
else:
ui.write("No experiments to pull.")
if not self.args.pull_cache:
ui.write(
"To pull cached outputs for this experiment"
Expand All @@ -36,6 +53,8 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags

EXPERIMENTS_PULL_HELP = "Pull an experiment from a Git remote."
experiments_pull_parser = experiments_subparsers.add_parser(
"pull",
Expand All @@ -44,6 +63,7 @@ def add_parser(experiments_subparsers, parent_parser):
help=EXPERIMENTS_PULL_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
add_rev_selection_flags(experiments_pull_parser, "Pull", False)
experiments_pull_parser.add_argument(
"-f",
"--force",
Expand Down Expand Up @@ -89,7 +109,8 @@ def add_parser(experiments_subparsers, parent_parser):
)
experiments_pull_parser.add_argument(
"experiment",
nargs="+",
nargs="*",
default=None,
help="Experiments to pull.",
metavar="<experiment>",
)
Expand Down
34 changes: 28 additions & 6 deletions dvc/commands/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,46 @@
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.commands import completion
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsPush(CmdBase):
def raise_error_if_all_disabled(self):
if not any(
[self.args.experiment, self.args.all_commits, self.args.rev]
):
raise InvalidArgumentError(
"Either provide an `experiment` argument, or use the "
"`--rev` or `--all-commits` flag."
)

def run(self):

self.repo.experiments.push(
self.raise_error_if_all_disabled()

pushed_exps = self.repo.experiments.push(
self.args.git_remote,
self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
force=self.args.force,
push_cache=self.args.push_cache,
dvc_remote=self.args.dvc_remote,
jobs=self.args.jobs,
run_cache=self.args.run_cache,
)

ui.write(
f"Pushed experiment '{self.args.experiment}'"
f"to Git remote '{self.args.git_remote}'."
)
if pushed_exps:
ui.write(
f"Pushed experiment '{pushed_exps}'"
f"to Git remote '{self.args.git_remote}'."
)
else:
ui.write("No experiments to pull.")
if not self.args.push_cache:
ui.write(
"To push cached outputs",
Expand All @@ -37,6 +55,8 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags

EXPERIMENTS_PUSH_HELP = "Push a local experiment to a Git remote."
experiments_push_parser = experiments_subparsers.add_parser(
"push",
Expand All @@ -45,6 +65,7 @@ def add_parser(experiments_subparsers, parent_parser):
help=EXPERIMENTS_PUSH_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
add_rev_selection_flags(experiments_push_parser, "Push", False)
experiments_push_parser.add_argument(
"-f",
"--force",
Expand Down Expand Up @@ -90,7 +111,8 @@ def add_parser(experiments_subparsers, parent_parser):
)
experiments_push_parser.add_argument(
"experiment",
nargs="+",
nargs="*",
default=None,
help="Experiments to push.",
metavar="<experiment>",
).complete = completion.EXPERIMENT
Expand Down
71 changes: 44 additions & 27 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from typing import Iterable, Union
from typing import Iterable, Optional, Set, Union

from dvc.exceptions import DvcException
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import TqdmGit
from dvc.scm import TqdmGit, iter_revs

from .base import ExpRefInfo
from .exceptions import UnresolvedExpNamesError
from .utils import exp_commits, resolve_name
from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name

logger = logging.getLogger(__name__)

Expand All @@ -18,33 +19,50 @@ def pull(
repo,
git_remote: str,
exp_names: Union[Iterable[str], str],
*args,
all_commits=False,
rev: Optional[str] = None,
num=1,
force: bool = False,
pull_cache: bool = False,
**kwargs,
):
if isinstance(exp_names, str):
exp_names = [exp_names]
exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)
unresolved_exp_names = [
exp_name
for exp_name, exp_ref in exp_ref_dict.items()
if exp_ref is None
]
if unresolved_exp_names:
raise UnresolvedExpNamesError(unresolved_exp_names)
) -> Iterable[str]:
exp_ref_set: Set["ExpRefInfo"] = set()
if all_commits:
exp_ref_set.update(exp_refs(repo.scm, git_remote))
else:
if exp_names:
if isinstance(exp_names, str):
exp_names = [exp_names]
exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)

unresolved_exp_names = []
for exp_name, exp_ref in exp_ref_dict.items():
if exp_ref is None:
unresolved_exp_names.append(exp_name)
else:
exp_ref_set.add(exp_ref)

if unresolved_exp_names:
raise UnresolvedExpNamesError(unresolved_exp_names)

if rev:
rev_dict = iter_revs(repo.scm, [rev], num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
exp_ref_set.update(ref_info_list)

exp_ref_set = exp_ref_dict.values()
_pull(repo, git_remote, exp_ref_set, force, pull_cache, **kwargs)
_pull(repo, git_remote, exp_ref_set, force)
if pull_cache:
_pull_cache(repo, exp_ref_set, **kwargs)
return [ref.name for ref in exp_ref_set]


def _pull(
repo,
git_remote: str,
exp_refs,
refs,
force: bool,
pull_cache: bool,
**kwargs,
):
def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
Expand All @@ -56,7 +74,7 @@ def on_diverged(refname: str, rev: str) -> bool:
"re-run with '--force'."
)

refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in exp_refs]
refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs]
logger.debug(f"git pull experiment '{git_remote}' -> '{refspec_list}'")

with TqdmGit(desc="Fetching git refs") as pbar:
Expand All @@ -68,20 +86,19 @@ def on_diverged(refname: str, rev: str) -> bool:
progress=pbar.update_git,
)

if pull_cache:
_pull_cache(repo, exp_refs, **kwargs)


def _pull_cache(
repo,
exp_ref,
refs: Union[ExpRefInfo, Iterable["ExpRefInfo"]],
dvc_remote=None,
jobs=None,
run_cache=False,
odb=None,
):
revs = list(exp_commits(repo.scm, exp_ref))
logger.debug(f"dvc fetch experiment '{exp_ref}'")
if isinstance(refs, ExpRefInfo):
refs = [refs]
revs = list(exp_commits(repo.scm, refs))
logger.debug(f"dvc fetch experiment '{refs}'")
repo.fetch(
jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, odb=odb
)
Loading

0 comments on commit 244637e

Please sign in to comment.