Skip to content

Commit

Permalink
exp list: cleanup and move logic inside repo api
Browse files: browse the repository at this point in the history
  • Loading branch information
shcheklein committed Nov 26, 2022
1 parent bb80eab commit d758755
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 75 deletions.
12 changes: 3 additions & 9 deletions dvc/commands/experiments/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.ui import ui

logger = logging.getLogger(__name__)

Expand All @@ -16,20 +17,13 @@ def run(self):
num=self.args.num,
git_remote=self.args.git_remote,
)
tags = self.repo.scm.describe(exps)
remained = {baseline for baseline, tag in tags.items() if tag is None}
base = "refs/heads"
ref_heads = self.repo.scm.describe(remained, base=base)

for baseline in exps:
name = baseline[:7]
if tags[baseline] or ref_heads[baseline]:
name = tags[baseline] or ref_heads[baseline][len(base) + 1 :]
if not name_only:
print(f"{name}:")
ui.write(f"{baseline}:")
for exp_name in exps[baseline]:
indent = "" if name_only else "\t"
print(f"{indent}{exp_name}")
ui.write(f"{indent}{exp_name}")

return 0

Expand Down
29 changes: 16 additions & 13 deletions dvc/repo/experiments/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dvc.scm import iter_revs
from dvc.types import Optional

from .utils import exp_refs, exp_refs_by_baseline
from .utils import exp_refs_by_baseline

logger = logging.getLogger(__name__)

Expand All @@ -20,19 +20,22 @@ def ls(
num: int = 1,
git_remote: Optional[str] = None,
):
results = defaultdict(list)
if all_commits:
gen = exp_refs(repo.scm, git_remote)
for info in gen:
results[info.baseline_sha].append(info.name)
return results
rev_set = None
if not all_commits:
revs = iter_revs(repo.scm, [rev or "HEAD"], num)
rev_set = set(revs.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)

rev = rev or "HEAD"
tags = repo.scm.describe(ref_info_dict.keys())
remained = {baseline for baseline, tag in tags.items() if tag is None}
base = "refs/heads"
ref_heads = repo.scm.describe(remained, base=base)

revs = iter_revs(repo.scm, [rev], num)
rev_set = set(revs.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for rev, ref_info_list in ref_info_dict.items():
results[rev] = [ref_info.name for ref_info in ref_info_list]
results = defaultdict(list)
for baseline in ref_info_dict:
name = baseline[:7]
if tags[baseline] or ref_heads[baseline]:
name = tags[baseline] or ref_heads[baseline][len(base) + 1 :]
results[name] = [info.name for info in ref_info_dict[baseline]]

return results
17 changes: 5 additions & 12 deletions dvc/repo/experiments/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,23 @@ class ExpRefInfo:

namespace = EXPS_NAMESPACE

def __init__(
self, baseline_sha: Optional[str] = None, name: Optional[str] = None
):
def __init__(self, baseline_sha: str, name: Optional[str] = None):
self.baseline_sha = baseline_sha
self.name: str = name if name else ""

def __str__(self):
return "/".join(self.parts)

def __repr__(self):
baseline = f"'{self.baseline_sha}'" if self.baseline_sha else "None"
baseline = f"'{self.baseline_sha}'"
name = f"'{self.name}'" if self.name else "None"
return f"ExpRefInfo(baseline_sha={baseline}, name={name})"

@property
def parts(self):
return (
(self.namespace,)
+ (
(self.baseline_sha[:2], self.baseline_sha[2:])
if self.baseline_sha
else ()
)
+ ((self.baseline_sha[:2], self.baseline_sha[2:]))
+ ((self.name,) if self.name else ())
)

Expand All @@ -54,14 +48,13 @@ def from_ref(cls, ref: str):
try:
parts = ref.split("/")
if (
len(parts) < 2
or len(parts) == 3
len(parts) < 4
or len(parts) > 5
or "/".join(parts[:2]) != EXPS_NAMESPACE
):
raise InvalidExpRefError(ref)
except ValueError:
raise InvalidExpRefError(ref)
baseline_sha = parts[2] + parts[3] if len(parts) >= 4 else None
baseline_sha = parts[2] + parts[3]
name = parts[4] if len(parts) == 5 else None
return cls(baseline_sha, name)
4 changes: 2 additions & 2 deletions dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def exp_refs_by_rev(scm: "Git", rev: str) -> Generator[ExpRefInfo, None, None]:


def exp_refs_by_baseline(
scm: "Git", revs: Set[str], url: Optional[str] = None
scm: "Git", revs: Optional[Set[str]] = None, url: Optional[str] = None
) -> Mapping[str, List[ExpRefInfo]]:
"""Iterate over all experiment refs with the specified baseline."""
all_exp_refs = exp_refs(scm, url)
result = defaultdict(list)
for ref in all_exp_refs:
if ref.baseline_sha in revs:
if revs is None or ref.baseline_sha in revs:
result[ref.baseline_sha].append(ref)
return result

Expand Down
45 changes: 10 additions & 35 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def test_packed_args_exists(tmp_dir, scm, dvc, exp_stage, caplog):
assert "Temporary DVC file" in caplog.text


def _prepare_experiments(tmp_dir, scm, dvc, exp_stage):
def test_list(tmp_dir, scm, dvc, exp_stage):
baseline_a = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp_a = first(results)
Expand All @@ -428,54 +428,29 @@ def _prepare_experiments(tmp_dir, scm, dvc, exp_stage):
ref_info_b = first(exp_refs_by_rev(scm, exp_b))

tmp_dir.scm_gen("new", "new", commit="new")
baseline_c = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"])
exp_c = first(results)
ref_info_c = first(exp_refs_by_rev(scm, exp_c))

return baseline_a, baseline_c, ref_info_a, ref_info_b, ref_info_c


def test_list(tmp_dir, scm, dvc, exp_stage):
(
baseline_a,
baseline_c,
ref_info_a,
ref_info_b,
ref_info_c,
) = _prepare_experiments(tmp_dir, scm, dvc, exp_stage)

assert dvc.experiments.ls() == {baseline_c: [ref_info_c.name]}
assert dvc.experiments.ls() == {"master": [ref_info_c.name]}

exp_list = dvc.experiments.ls(rev=ref_info_a.baseline_sha)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {ref_info_a.name, ref_info_b.name}
baseline_a[:7]: {ref_info_a.name, ref_info_b.name}
}

exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {ref_info_a.name, ref_info_b.name},
baseline_c: {ref_info_c.name},
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"master": {ref_info_c.name},
}


def test_list_cli(tmp_dir, scm, dvc, capsys, exp_stage):
from dvc.cli import main

baseline_a, _, ref_info_a, ref_info_b, ref_info_c = _prepare_experiments(
tmp_dir, scm, dvc, exp_stage
)

# Make sure that we prioritize the current branch name
scm.checkout("branch", True)

capsys.readouterr()
assert main(["exp", "list", "-A"]) == 0
cap = capsys.readouterr()
assert set(cap.out.split()) == set(
["branch:", baseline_a[:7] + ":"]
+ [ref_info_a.name, ref_info_b.name, ref_info_c.name]
)
exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"branch": {ref_info_c.name},
}


def test_subdir(tmp_dir, scm, dvc, workspace):
Expand Down
7 changes: 3 additions & 4 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
ref_info_b = first(exp_refs_by_rev(scm, exp_b))

tmp_dir.scm_gen("new", "new", commit="new")
baseline_c = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"])
exp_c = first(results)
ref_info_c = first(exp_refs_by_rev(scm, exp_c))
Expand All @@ -165,13 +164,13 @@ def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
git_downstream.tmp_dir.scm.fetch_refspecs(remote, ["master:master"])
exp_list = downstream_exp.ls(rev=baseline_a, git_remote=remote)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {ref_info_a.name, ref_info_b.name}
baseline_a[:7]: {ref_info_a.name, ref_info_b.name}
}

exp_list = downstream_exp.ls(all_commits=True, git_remote=remote)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {ref_info_a.name, ref_info_b.name},
baseline_c: {ref_info_c.name},
baseline_a[:7]: {ref_info_a.name, ref_info_b.name},
"master": {ref_info_c.name},
}


Expand Down

0 comments on commit d758755

Please sign in to comment.