Skip to content

Commit

Permalink
live integration: get rid of dvc from live (iterative#5466)
Browse files Browse the repository at this point in the history
* live integration: get rid of dvc from live

* live: remove from api, add test for html generation during run

* stage: run: convert monitors to class context managers

* run: make monitors run in single thread

* run: move monitor logic out of run
  • Loading branch information
pared authored Mar 2, 2021
1 parent e385e59 commit 4a8cb80
Show file tree
Hide file tree
Showing 13 changed files with 248 additions and 200 deletions.
4 changes: 2 additions & 2 deletions dvc/api/__init__.py → dvc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def make_checkpoint():
from time import sleep

from dvc.env import DVC_CHECKPOINT, DVC_ROOT
from dvc.stage.run import CHECKPOINT_SIGNAL_FILE
from dvc.stage.monitor import CheckpointTask

if os.getenv(DVC_CHECKPOINT) is None:
return

root_dir = os.getenv(DVC_ROOT, Repo.find_root())
signal_file = os.path.join(
root_dir, Repo.DVC_DIR, "tmp", CHECKPOINT_SIGNAL_FILE
root_dir, Repo.DVC_DIR, "tmp", CheckpointTask.SIGNAL_FILE
)

with builtins.open(signal_file, "w") as fobj:
Expand Down
24 changes: 0 additions & 24 deletions dvc/api/live.py

This file was deleted.

2 changes: 1 addition & 1 deletion dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dvc.env import DVCLIVE_RESUME
from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.stage.run import CheckpointKilledError
from dvc.stage.monitor import CheckpointKilledError
from dvc.utils import relpath

from .base import (
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from dvc.scm import SCM
from dvc.stage import PipelineStage
from dvc.stage.run import CheckpointKilledError
from dvc.stage.monitor import CheckpointKilledError
from dvc.stage.serialize import to_lockfile
from dvc.utils import dict_sha256
from dvc.utils.fs import remove
Expand Down
15 changes: 15 additions & 0 deletions dvc/repo/live.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
import contextlib
import logging
import os
from typing import TYPE_CHECKING, List, Optional

from dvc.exceptions import MetricDoesNotExistError, MetricsError

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from dvc.output import BaseOutput
from dvc.path_info import PathInfo
from dvc.repo import Repo


def create_summary(out):
from dvc.utils.html import write

assert out.live and out.live["html"]

metrics, plots = out.repo.live.show(str(out.path_info))

html_path = out.path_info.with_suffix(".html")
write(html_path, plots, metrics)
logger.info(f"\nfile://{os.path.abspath(html_path)}")


def summary_path_info(out: "BaseOutput") -> Optional["PathInfo"]:
from dvc.output import BaseOutput

Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _reproduce_stages(
_repro_callback, checkpoint_func, unchanged
)

from dvc.stage.run import CheckpointKilledError
from dvc.stage.monitor import CheckpointKilledError

try:
ret = _reproduce_stage(stage, **kwargs)
Expand Down
150 changes: 150 additions & 0 deletions dvc/stage/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import functools
import logging
import os
import subprocess
import threading
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, List

from dvc.repo.live import create_summary
from dvc.stage.decorators import relock_repo
from dvc.stage.exceptions import StageCmdFailedError

if TYPE_CHECKING:
from dvc.output import BaseOutput
from dvc.stage import Stage


logger = logging.getLogger(__name__)


class CheckpointKilledError(StageCmdFailedError):
pass


class LiveKilledError(StageCmdFailedError):
pass


@dataclass
class MonitorTask:
stage: "Stage"
execute: Callable
proc: subprocess.Popen
done: threading.Event = threading.Event()
killed: threading.Event = threading.Event()

@property
def name(self) -> str:
raise NotImplementedError

@property
def SIGNAL_FILE(self) -> str:
raise NotImplementedError

@property
def error_cls(self) -> type:
raise NotImplementedError

@property
def signal_path(self) -> str:
return os.path.join(self.stage.repo.tmp_dir, self.SIGNAL_FILE)

def after_run(self):
pass


class CheckpointTask(MonitorTask):
name = "checkpoint"
SIGNAL_FILE = "DVC_CHECKPOINT"
error_cls = CheckpointKilledError

def __init__(
self, stage: "Stage", callback_func: Callable, proc: subprocess.Popen
):
super().__init__(
stage,
functools.partial(
CheckpointTask._run_callback, stage, callback_func
),
proc,
)

@staticmethod
@relock_repo
def _run_callback(stage, callback_func):
stage.save(allow_missing=True)
stage.commit(allow_missing=True)
logger.debug("Running checkpoint callback for stage '%s'", stage)
callback_func()


class LiveTask(MonitorTask):
name = "live"
SIGNAL_FILE = "DVC_LIVE"
error_cls = LiveKilledError

def __init__(
self, stage: "Stage", out: "BaseOutput", proc: subprocess.Popen
):
super().__init__(stage, functools.partial(create_summary, out), proc)

def after_run(self):
# make sure summary is prepared for all the data
self.execute()


class Monitor:
AWAIT: float = 1.0

def __init__(self, tasks: List[MonitorTask]):
self.done = threading.Event()
self.tasks = tasks
self.monitor_thread = threading.Thread(
target=Monitor._loop, args=(self.tasks, self.done,),
)

def __enter__(self):
self.monitor_thread.start()

def __exit__(self, exc_type, exc_val, exc_tb):
self.done.set()
self.monitor_thread.join()
for t in self.tasks:
t.after_run()

@staticmethod
def kill(proc):
if os.name == "nt":
return Monitor._kill_nt(proc)
proc.terminate()
proc.wait()

@staticmethod
def _kill_nt(proc):
# windows stages are spawned with shell=True, proc is the shell process
# and not the actual stage process - we have to kill the entire tree
subprocess.call(["taskkill", "/F", "/T", "/PID", str(proc.pid)])

@staticmethod
def _loop(tasks: List[MonitorTask], done: threading.Event):
while True:
for task in tasks:
if os.path.exists(task.signal_path):
try:
task.execute()
except Exception: # pylint: disable=broad-except
logger.exception(
"Error running '%s' task, '%s' will be aborted",
task.name,
task.stage,
)
Monitor.kill(task.proc)
task.killed.set()
finally:
logger.debug(
"Removing signal file for '%s' task", task.name
)
os.remove(task.signal_path)
if done.wait(Monitor.AWAIT):
return
Loading

0 comments on commit 4a8cb80

Please sign in to comment.