Skip to content

Commit

Permalink
Use a more granular lock on exp operations.
Browse files Browse the repository at this point in the history
For now in `experiment.new`, `init_executor` and `collect_exectuor` we
lock the whole scm repo. We can make a more granular lock in these
operations.
  • Loading branch information
karajan1001 committed Nov 28, 2022
1 parent 5fa33a0 commit 6bfce5a
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 124 deletions.
10 changes: 1 addition & 9 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ExpRefInfo,
)
from .stash import ExpStashEntry
from .utils import exp_refs_by_rev, scm_locked, unlocked_repo
from .utils import exp_refs_by_rev, unlocked_repo

logger = logging.getLogger(__name__)

Expand All @@ -50,18 +50,12 @@ class Experiments:
)

def __init__(self, repo):
from dvc.lock import make_lock
from dvc.scm import NoSCMError

if repo.config["core"].get("no_scm", False):
raise NoSCMError

self.repo = repo
self.scm_lock = make_lock(
os.path.join(self.repo.tmp_dir, "exp_scm_lock"),
tmp_dir=self.repo.tmp_dir,
hardlink_lock=repo.config["core"].get("hardlink_lock", False),
)

@property
def scm(self):
Expand Down Expand Up @@ -253,7 +247,6 @@ def _validate_new_ref(self, exp_ref: ExpRefInfo):
if self.scm.get_ref(str(exp_ref)):
raise ExperimentExistsError(exp_ref.name)

@scm_locked
def new(
self,
queue: BaseStashQueue,
Expand All @@ -279,7 +272,6 @@ def new(
except ExperimentExistsError as err:
if not (kwargs.get("force", False) or kwargs.get("reset", False)):
raise err

return queue.put(*args, **kwargs)

def _resume_checkpoint(
Expand Down
27 changes: 10 additions & 17 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Callable,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Tuple,
Expand Down Expand Up @@ -40,9 +41,6 @@
EXEC_CHECKPOINT,
EXEC_HEAD,
EXEC_MERGE,
EXEC_NAMESPACE,
EXPS_NAMESPACE,
EXPS_STASH,
ExpRefInfo,
)

Expand Down Expand Up @@ -168,6 +166,7 @@ def __init__(
@abstractmethod
def init_git(
self,
repo: "Repo",
scm: "Git",
stash_rev: str,
entry: "ExpStashEntry",
Expand Down Expand Up @@ -315,6 +314,7 @@ def unpack_repro_args(path):
def fetch_exps(
self,
dest_scm: "Git",
refs: List[str],
force: bool = False,
on_diverged: Callable[[str, bool], None] = None,
**kwargs,
Expand All @@ -323,26 +323,19 @@ def fetch_exps(
Args:
dest_scm: Destination Git instance.
refs: reference names to be fetched from the remotes.
force: If True, diverged refs will be overwritten
on_diverged: Callback in the form on_diverged(ref, is_checkpoint)
to be called when an experiment ref has diverged.
Extra kwargs will be passed into the remote git client.
"""
from ..utils import iter_remote_refs

refs = []
has_checkpoint = False
for ref in iter_remote_refs(
dest_scm,
self.git_url,
base=EXPS_NAMESPACE,
**kwargs,
):
if ref == EXEC_CHECKPOINT:
has_checkpoint = True
elif not ref.startswith(EXEC_NAMESPACE) and ref != EXPS_STASH:
refs.append(ref)

if EXEC_CHECKPOINT in refs:
refs.remove(EXEC_CHECKPOINT)
has_checkpoint = True
else:
has_checkpoint = False

def on_diverged_ref(orig_ref: str, new_rev: str):
if force:
Expand Down
85 changes: 46 additions & 39 deletions dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional

from funcy import cached_property
from funcy import cached_property, retry
from scmrepo.exceptions import SCMError as _SCMError

from dvc.lock import LockError
from dvc.scm import SCM, GitMergeError
from dvc.utils.fs import makedirs, remove

Expand All @@ -19,7 +20,7 @@
EXEC_MERGE,
EXEC_NAMESPACE,
)
from ..utils import EXEC_TMP_DIR
from ..utils import EXEC_TMP_DIR, get_exp_rwlock
from .base import BaseExecutor, TaskStatus

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,8 +69,10 @@ class TempDirExecutor(BaseLocalExecutor):
QUIET = True
DEFAULT_LOCATION = "tempdir"

@retry(180, errors=LockError, timeout=1)
def init_git(
self,
repo: "Repo",
scm: "Git",
stash_rev: str,
entry: "ExpStashEntry",
Expand All @@ -86,28 +89,29 @@ def init_git(
if infofile:
self.info.dump_json(infofile)

with self.set_exec_refs(scm, stash_rev, entry):
refspec = f"{EXEC_NAMESPACE}/"
push_refspec(scm, self.git_url, refspec, refspec)
with get_exp_rwlock(repo, writes=[EXEC_NAMESPACE]):
with self.set_exec_refs(scm, stash_rev, entry):
refspec = f"{EXEC_NAMESPACE}/"
push_refspec(scm, self.git_url, refspec, refspec)

if branch:
push_refspec(scm, self.git_url, branch, branch)
self.scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
elif self.scm.get_ref(EXEC_BRANCH):
self.scm.remove_ref(EXEC_BRANCH)
if branch:
push_refspec(scm, self.git_url, branch, branch)
self.scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
elif self.scm.get_ref(EXEC_BRANCH):
self.scm.remove_ref(EXEC_BRANCH)

if self.scm.get_ref(EXEC_CHECKPOINT):
self.scm.remove_ref(EXEC_CHECKPOINT)
# checkout EXEC_HEAD and apply EXEC_MERGE on top of it without
# committing
head = EXEC_BRANCH if branch else EXEC_HEAD
self.scm.checkout(head, detach=True)
merge_rev = self.scm.get_ref(EXEC_MERGE)
if self.scm.get_ref(EXEC_CHECKPOINT):
self.scm.remove_ref(EXEC_CHECKPOINT)
# checkout EXEC_HEAD and apply EXEC_MERGE on top of it without
# committing
head = EXEC_BRANCH if branch else EXEC_HEAD
self.scm.checkout(head, detach=True)
merge_rev = self.scm.get_ref(EXEC_MERGE)

try:
self.scm.merge(merge_rev, squash=True, commit=False)
except _SCMError as exc:
raise GitMergeError(str(exc), scm=self.scm)
try:
self.scm.merge(merge_rev, squash=True, commit=False)
except _SCMError as exc:
raise GitMergeError(str(exc), scm=self.scm)

def _config(self, cache_dir):
local_config = os.path.join(
Expand Down Expand Up @@ -168,8 +172,10 @@ def from_stash_entry(
logger.debug("Init workspace executor in '%s'", root_dir)
return executor

@retry(180, errors=LockError, timeout=1)
def init_git(
self,
repo: "Repo",
scm: "Git",
stash_rev: str,
entry: "ExpStashEntry",
Expand All @@ -180,25 +186,26 @@ def init_git(
if infofile:
self.info.dump_json(infofile)

scm.set_ref(EXEC_HEAD, entry.head_rev)
scm.set_ref(EXEC_MERGE, stash_rev)
scm.set_ref(EXEC_BASELINE, entry.baseline_rev)
self._detach_stack.enter_context(
self.scm.detach_head(
self.scm.get_ref(EXEC_HEAD),
force=True,
client="dvc",
with get_exp_rwlock(repo, writes=[EXEC_NAMESPACE]):
scm.set_ref(EXEC_HEAD, entry.head_rev)
scm.set_ref(EXEC_MERGE, stash_rev)
scm.set_ref(EXEC_BASELINE, entry.baseline_rev)
self._detach_stack.enter_context(
self.scm.detach_head(
self.scm.get_ref(EXEC_HEAD),
force=True,
client="dvc",
)
)
)
merge_rev = self.scm.get_ref(EXEC_MERGE)
try:
self.scm.merge(merge_rev, squash=True, commit=False)
except _SCMError as exc:
raise GitMergeError(str(exc), scm=self.scm)
if branch:
self.scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
elif scm.get_ref(EXEC_BRANCH):
self.scm.remove_ref(EXEC_BRANCH)
merge_rev = self.scm.get_ref(EXEC_MERGE)
try:
self.scm.merge(merge_rev, squash=True, commit=False)
except _SCMError as exc:
raise GitMergeError(str(exc), scm=self.scm)
if branch:
self.scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
elif scm.get_ref(EXEC_BRANCH):
self.scm.remove_ref(EXEC_BRANCH)

def init_cache(self, repo: "Repo", rev: str, run_cache: bool = True):
pass
Expand Down
1 change: 1 addition & 0 deletions dvc/repo/experiments/executor/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _git_client_args(fs):

def init_git(
self,
repo: "Repo",
scm: "Git",
stash_rev: str,
entry: "ExpStashEntry",
Expand Down
59 changes: 37 additions & 22 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from ..executor.local import WorkspaceExecutor
from ..refs import ExpRefInfo
from ..stash import ExpStash, ExpStashEntry
from ..utils import EXEC_PID_DIR, EXEC_TMP_DIR, exp_refs_by_rev, scm_locked
from ..utils import EXEC_PID_DIR, EXEC_TMP_DIR, exp_refs_by_rev, get_exp_rwlock
from .utils import get_remote_executor_refs

if TYPE_CHECKING:
from scmrepo.git import Git
Expand Down Expand Up @@ -550,28 +551,39 @@ def _update_params(self, params: Dict[str, List[str]]):

@staticmethod
@retry(180, errors=LockError, timeout=1)
@scm_locked
def get_stash_entry(
exp: "Experiments",
queue_entry: QueueEntry,
) -> "ExpStashEntry":
stash = ExpStash(exp.scm, queue_entry.stash_ref)
stash_rev = queue_entry.stash_rev
with get_exp_rwlock(exp.repo, writes=[queue_entry.stash_ref]):
stash_entry = stash.stash_revs.get(
stash_rev,
ExpStashEntry(None, stash_rev, stash_rev, None, None),
)
if stash_entry.stash_index is not None:
stash.drop(stash_entry.stash_index)
return stash_entry

@classmethod
def init_executor(
cls,
exp: "Experiments",
queue_entry: QueueEntry,
executor_cls: Type[BaseExecutor] = WorkspaceExecutor,
**kwargs,
) -> BaseExecutor:
scm = exp.scm
stash = ExpStash(scm, queue_entry.stash_ref)
stash_rev = queue_entry.stash_rev
stash_entry = stash.stash_revs.get(
stash_rev,
ExpStashEntry(None, stash_rev, stash_rev, None, None),
)
if stash_entry.stash_index is not None:
stash.drop(stash_entry.stash_index)
stash_entry = cls.get_stash_entry(exp, queue_entry)

executor = executor_cls.from_stash_entry(
exp.repo, stash_entry, **kwargs
)

stash_rev = queue_entry.stash_rev
infofile = exp.celery_queue.get_infofile_path(stash_rev)
executor.init_git(
exp.repo,
exp.repo.scm,
stash_rev,
stash_entry,
Expand All @@ -592,7 +604,6 @@ def get_infofile_path(self, name: str) -> str:

@staticmethod
@retry(180, errors=LockError, timeout=1)
@scm_locked
def collect_git(
exp: "Experiments",
executor: BaseExecutor,
Expand All @@ -606,16 +617,20 @@ def on_diverged(ref: str, checkpoint: bool):
raise CheckpointExistsError(ref_info.name)
raise ExperimentExistsError(ref_info.name)

for ref in executor.fetch_exps(
exp.scm,
force=exec_result.force,
on_diverged=on_diverged,
):
exp_rev = exp.scm.get_ref(ref)
if exp_rev:
assert exec_result.exp_hash
logger.debug("Collected experiment '%s'.", exp_rev[:7])
results[exp_rev] = exec_result.exp_hash
refs = get_remote_executor_refs(exp.scm, executor.git_url)

with get_exp_rwlock(exp.repo, writes=refs):
for ref in executor.fetch_exps(
exp.scm,
refs,
force=exec_result.force,
on_diverged=on_diverged,
):
exp_rev = exp.scm.get_ref(ref)
if exp_rev:
assert exec_result.exp_hash
logger.debug("Collected experiment '%s'.", exp_rev[:7])
results[exp_rev] = exec_result.exp_hash

return results

Expand Down
6 changes: 4 additions & 2 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

from ..exceptions import UnresolvedQueueExpNamesError
from ..executor.base import ExecutorInfo, ExecutorResult
from ..utils import EXEC_TMP_DIR
from ..refs import CELERY_STASH
from ..utils import EXEC_TMP_DIR, get_exp_rwlock
from .base import BaseStashQueue, QueueDoneResult, QueueEntry, QueueGetResult
from .exceptions import CannotKillTasksError
from .tasks import run_exp
Expand Down Expand Up @@ -165,7 +166,8 @@ def start_workers(self, count: int) -> int:

def put(self, *args, **kwargs) -> QueueEntry:
"""Stash an experiment and add it to the queue."""
entry = self._stash_exp(*args, **kwargs)
with get_exp_rwlock(self.repo, writes=["workspace", CELERY_STASH]):
entry = self._stash_exp(*args, **kwargs)
self.celery.signature(run_exp.s(entry.asdict())).delay()
return entry

Expand Down
8 changes: 6 additions & 2 deletions dvc/repo/experiments/queue/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict

from celery import shared_task
from celery.utils.log import get_task_logger
Expand All @@ -7,11 +7,15 @@
from ..executor.local import TempDirExecutor
from .base import BaseStashQueue, QueueEntry

if TYPE_CHECKING:
from ..executor.base import BaseExecutor


logger = get_task_logger(__name__)


@shared_task
def setup_exp(entry_dict: Dict[str, Any]) -> TempDirExecutor:
def setup_exp(entry_dict: Dict[str, Any]) -> "BaseExecutor":
"""Setup an experiment.
Arguments:
Expand Down
Loading

0 comments on commit 6bfce5a

Please sign in to comment.