From 7d68863ba21855bb5c238af7e1597e9a3607c0fe Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 29 Nov 2021 19:30:38 +0800 Subject: [PATCH] proc: implement process signaling(#7062) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix: #7062 1. Add `send_signal`, `kill`, `terminate` to processManager. 2. Add tests for each of them. Co-authored-by: Peter Rowlands (변기호) --- dvc/proc/exceptions.py | 5 ++ dvc/proc/manager.py | 54 ++++++++++++++++++--- dvc/proc/process.py | 5 +- tests/unit/proc/test_manager.py | 84 +++++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 8 deletions(-) create mode 100644 tests/unit/proc/test_manager.py diff --git a/dvc/proc/exceptions.py b/dvc/proc/exceptions.py index 9149e797f2..5db4ee8b41 100644 --- a/dvc/proc/exceptions.py +++ b/dvc/proc/exceptions.py @@ -18,3 +18,8 @@ def __init__(self, cmd, timeout): ) self.cmd = cmd self.timeout = timeout + + +class UnsupportedSignalError(DvcException): + def __init__(self, sig): + super().__init__(f"Unsupported signal: {sig}") diff --git a/dvc/proc/manager.py b/dvc/proc/manager.py index 4f93b4d4ba..1d4391eb7d 100644 --- a/dvc/proc/manager.py +++ b/dvc/proc/manager.py @@ -3,10 +3,14 @@ import json import logging import os +import signal +import sys from typing import Generator, List, Optional, Union +from funcy.flow import reraise from shortuuid import uuid +from .exceptions import UnsupportedSignalError from .process import ManagedProcess, ProcessInfo logger = logging.getLogger(__name__) @@ -34,6 +38,12 @@ def __getitem__(self, key: str) -> "ProcessInfo": except FileNotFoundError: raise KeyError + @reraise(FileNotFoundError, KeyError) + def __setitem__(self, key: str, value: "ProcessInfo"): + info_path = os.path.join(self.wdir, key, f"{key}.json") + with open(info_path, "w", encoding="utf-8") as fobj: + return json.dump(value.asdict(), fobj) + def get(self, key: str, default=None): try: return self[key] @@ -63,17 +73,47 @@ def spawn(self, args: Union[str, List[str]], name: Optional[str] = None): pid, ) - def send_signal(self, name: str, signal: int): + def send_signal(self, name: str, sig: int): """Send `signal` to the specified named process.""" - raise NotImplementedError - - def kill(self, name: str): - """Kill the specified named process.""" - raise NotImplementedError + process_info = self[name] + if sys.platform == "win32": + if sig not in ( + signal.SIGTERM, + signal.CTRL_C_EVENT, + signal.CTRL_BREAK_EVENT, + ): + raise UnsupportedSignalError(sig) + + def handle_closed_process(): + logging.warning( + f"Process {name} had already aborted unexpectedly." + ) + process_info.returncode = -1 + self[name] = process_info + + if process_info.returncode is None: + try: + os.kill(process_info.pid, sig) + except ProcessLookupError: + handle_closed_process() + raise + except OSError as exc: + if sys.platform == "win32": + if exc.winerror == 87: + handle_closed_process() + raise ProcessLookupError from exc + raise def terminate(self, name: str): """Terminate the specified named process.""" - raise NotImplementedError + self.send_signal(name, signal.SIGTERM) + + def kill(self, name: str): + """Kill the specified named process.""" + if sys.platform == "win32": + self.send_signal(name, signal.SIGTERM) + else: + self.send_signal(name, signal.SIGKILL) def remove(self, name: str, force: bool = False): """Remove the specified named process from this manager. diff --git a/dvc/proc/process.py b/dvc/proc/process.py index a381c56ae9..edbce65262 100644 --- a/dvc/proc/process.py +++ b/dvc/proc/process.py @@ -29,6 +29,9 @@ class ProcessInfo: def from_dict(cls, d): return cls(**d) + def asdict(self): + return asdict(self) + class ManagedProcess(AbstractContextManager): """Run the specified command with redirected output. @@ -105,7 +108,7 @@ def _make_wdir(self): def _dump(self): self._make_wdir() with open(self.info_path, "w", encoding="utf-8") as fobj: - json.dump(asdict(self.info), fobj) + json.dump(self.info.asdict(), fobj) with open(self.pidfile_path, "w", encoding="utf-8") as fobj: fobj.write(str(self.pid)) diff --git a/tests/unit/proc/test_manager.py b/tests/unit/proc/test_manager.py new file mode 100644 index 0000000000..f97e3bfc58 --- /dev/null +++ b/tests/unit/proc/test_manager.py @@ -0,0 +1,84 @@ +import json +import os +import signal +import sys + +import pytest + +from dvc.proc.exceptions import UnsupportedSignalError +from dvc.proc.manager import ProcessManager +from dvc.proc.process import ProcessInfo + +PID_FINISHED = 1234 +PID_RUNNING = 5678 + + +def create_process(root: str, name: str, pid: int, returncode=None): + info_path = os.path.join(root, name, f"{name}.json") + os.makedirs(os.path.join(root, name)) + process_info = ProcessInfo( + pid=pid, stdin=None, stdout=None, stderr=None, returncode=returncode + ) + with open(info_path, "w", encoding="utf-8") as fobj: + json.dump(process_info.asdict(), fobj) + + +@pytest.fixture +def finished_process(tmp_dir): + key = "finished" + create_process(tmp_dir, key, PID_FINISHED, 0) + return key + + +@pytest.fixture +def running_process(tmp_dir): + key = "running" + create_process(tmp_dir, key, PID_RUNNING) + return key + + +def test_send_signal(tmp_dir, mocker, finished_process, running_process): + m = mocker.patch("os.kill") + process_manager = ProcessManager(tmp_dir) + process_manager.send_signal(running_process, signal.SIGTERM) + m.assert_called_once_with(PID_RUNNING, signal.SIGTERM) + + m = mocker.patch("os.kill") + process_manager.send_signal(finished_process, signal.SIGTERM) + m.assert_not_called() + + if sys.platform == "win32": + with pytest.raises(UnsupportedSignalError): + process_manager.send_signal(finished_process, signal.SIGABRT) + + +def test_dead_process(tmp_dir, mocker, running_process): + process_manager = ProcessManager(tmp_dir) + with pytest.raises(ProcessLookupError): + process_manager.send_signal(running_process, signal.SIGTERM) + assert process_manager[running_process].returncode == -1 + + +def test_kill(tmp_dir, mocker, finished_process, running_process): + m = mocker.patch("os.kill") + process_manager = ProcessManager(tmp_dir) + process_manager.kill(running_process) + if sys.platform == "win32": + m.assert_called_once_with(PID_RUNNING, signal.SIGTERM) + else: + m.assert_called_once_with(PID_RUNNING, signal.SIGKILL) + + m = mocker.patch("os.kill") + process_manager.kill(finished_process) + m.assert_not_called() + + +def test_terminate(tmp_dir, mocker, running_process, finished_process): + m = mocker.patch("os.kill") + process_manager = ProcessManager(tmp_dir) + process_manager.terminate(running_process) + m.assert_called_once_with(PID_RUNNING, signal.SIGTERM) + + m.reset_mock() + process_manager.terminate(finished_process) + m.assert_not_called()