From c07592131a616afc51f093ca4d4fe8615b63062e Mon Sep 17 00:00:00 2001 From: Gao Date: Mon, 25 Oct 2021 19:16:37 +0800 Subject: [PATCH] exp run: add experiment name check. (#6848) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add experiment name check. 1. Add experiment name check (https://git-scm.com/docs/git-check-ref-format) 2. Add duplicate exp name check. 3. Add some unit test for it. * Ban slash / in dvc exp names * Use dulwich backend for ref name checking * Some bug fix * Update dvc/repo/experiments/__init__.py Co-authored-by: Peter Rowlands (변기호) * Update dvc/repo/experiments/__init__.py Co-authored-by: Peter Rowlands (변기호) * Some review changes * Make some functions more reusable. Co-authored-by: Peter Rowlands (변기호) --- dvc/repo/experiments/__init__.py | 21 +++++++++++++++++++++ dvc/repo/experiments/utils.py | 9 +++++++++ dvc/scm/git/__init__.py | 1 + dvc/scm/git/backend/dulwich/__init__.py | 5 +++++ tests/func/experiments/test_experiments.py | 15 +++++++++++++++ tests/unit/repo/experiments/test_utils.py | 16 +++++++++++++++- 6 files changed, 66 insertions(+), 1 deletion(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 30b9a5d55e..43d1f1db1c 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -473,6 +473,17 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False): "\tdvc exp branch \n" ) + def _validate_new_ref(self, exp_ref: ExpRefInfo): + from .utils import check_ref_format + + if not exp_ref.name: + return + + check_ref_format(self.scm, exp_ref) + + if self.scm.get_ref(str(exp_ref)): + raise ExperimentExistsError(exp_ref.name) + @scm_locked def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): """Create a new experiment. 
@@ -485,6 +496,16 @@ def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): *args, resume_rev=checkpoint_resume, **kwargs ) + name = kwargs.get("name", None) + baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev() + exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name) + + try: + self._validate_new_ref(exp_ref) + except ExperimentExistsError as err: + if not (kwargs.get("force", False) or kwargs.get("reset", False)): + raise err + return self._stash_exp(*args, **kwargs) def _resume_checkpoint( diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index c2a8f9cb05..1c68a2b0a9 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -157,3 +157,12 @@ def resolve_exp_ref( msg.extend([f"\t{info}" for info in exp_ref_list]) raise InvalidArgumentError("\n".join(msg)) return exp_ref_list[0] + + +def check_ref_format(scm: "Git", ref: ExpRefInfo): + # "/" forbidden, only in dvc exp as we didn't support it for now. 
+ if not scm.check_ref_format(str(ref)) or "/" in ref.name: + raise InvalidArgumentError( + f"Invalid exp name {ref.name}, the exp name must follow rules in " + "https://git-scm.com/docs/git-check-ref-format" + ) diff --git a/dvc/scm/git/__init__.py b/dvc/scm/git/__init__.py index 5d60ec232f..ba8e76c4ce 100644 --- a/dvc/scm/git/__init__.py +++ b/dvc/scm/git/__init__.py @@ -346,6 +346,7 @@ def get_fs(self, rev: str): status = partialmethod(_backend_func, "status") merge = partialmethod(_backend_func, "merge") validate_git_remote = partialmethod(_backend_func, "validate_git_remote") + check_ref_format = partialmethod(_backend_func, "check_ref_format") def resolve_rev(self, rev: str) -> str: from dvc.repo.experiments.utils import exp_refs_by_name diff --git a/dvc/scm/git/backend/dulwich/__init__.py b/dvc/scm/git/backend/dulwich/__init__.py index 84bfd2f4f7..1f706a8e2d 100644 --- a/dvc/scm/git/backend/dulwich/__init__.py +++ b/dvc/scm/git/backend/dulwich/__init__.py @@ -681,3 +681,8 @@ def validate_git_remote(self, url: str, **kwargs): os.path.join("", path) ): raise InvalidRemoteSCMRepo(url) + + def check_ref_format(self, refname: str): + from dulwich.refs import check_ref_format + + return check_ref_format(refname.encode()) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index cf94871f5b..ee44e83753 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -53,6 +53,7 @@ def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): tmp_dir=not workspace, ) + new_mock = mocker.spy(dvc.experiments, "_stash_exp") with pytest.raises(ExperimentExistsError): dvc.experiments.run( exp_stage.addressing, @@ -60,6 +61,7 @@ def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): params=["foo=3"], tmp_dir=not workspace, ) + new_mock.assert_not_called() results = dvc.experiments.run( exp_stage.addressing, @@ -685,3 +687,16 @@ def 
test_exp_run_recursive(tmp_dir, scm, dvc, run_copy_metrics): ) assert dvc.experiments.run(".", recursive=True) assert (tmp_dir / "metric.json").parse() == {"foo": 1} + + +def test_experiment_name_invalid(tmp_dir, scm, dvc, exp_stage, mocker): + from dvc.exceptions import InvalidArgumentError + + new_mock = mocker.spy(dvc.experiments, "_stash_exp") + with pytest.raises(InvalidArgumentError): + dvc.experiments.run( + exp_stage.addressing, + name="fo^o", + params=["foo=3"], + ) + new_mock.assert_not_called() diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index b0fd3da808..0d5cb8ddf0 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -1,7 +1,8 @@ import pytest +from dvc.exceptions import InvalidArgumentError from dvc.repo.experiments.base import EXPS_NAMESPACE, ExpRefInfo -from dvc.repo.experiments.utils import resolve_exp_ref +from dvc.repo.experiments.utils import check_ref_format, resolve_exp_ref def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"): @@ -25,3 +26,16 @@ def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): remote_ref_info = resolve_exp_ref(scm, "foo" if name_only else ref, remote) assert isinstance(remote_ref_info, ExpRefInfo) assert str(remote_ref_info) == ref + + +@pytest.mark.parametrize( + "name,result", [("name", True), ("group/name", False), ("na me", False)] +) +def test_run_check_ref_format(scm, name, result): + + ref = ExpRefInfo("abc123", name) + if result: + check_ref_format(scm, ref) + else: + with pytest.raises(InvalidArgumentError): + check_ref_format(scm, ref)