Skip to content

Commit

Permalink
chore(launch): Improve run stopping robustness for launch runs (wandb…
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleGoyette authored Sep 29, 2023
1 parent e583691 commit 50d613f
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import base64
from unittest import mock
from unittest.mock import MagicMock

import wandb
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker


def test_check_stop_run_not_exist(wandb_init):
job_tracker = JobAndRunStatusTracker(
"run_queue_item_id", "test-queue", MagicMock(), MagicMock()
)
run = wandb_init(id="testrun")
api = wandb.InternalApi()
mock_launch_project = MagicMock()
mock_launch_project.target_entity = run._entity
mock_launch_project.target_project = run._project
mock_launch_project.run_id = run._run_id + "a"
job_tracker.update_run_info(mock_launch_project)

res = job_tracker.check_wandb_run_stopped(api)
assert not res
run.finish()


def test_check_stop_run_exist_stopped(user, wandb_init):
mock.patch("wandb.sdk.wandb_run.thread.interrupt_main", lambda x: None)
job_tracker = JobAndRunStatusTracker(
"run_queue_item_id", "test-queue", MagicMock(), MagicMock()
)
run = wandb_init(id="testrun", entity=user)
print(run._entity)
api = wandb.InternalApi()
encoded_run_id = base64.standard_b64encode(
f"Run:v1:testrun:{run._project}:{run._entity}".encode()
).decode("utf-8")
mock_launch_project = MagicMock()
mock_launch_project.target_entity = run._entity
mock_launch_project.target_project = run._project
mock_launch_project.run_id = run._run_id
job_tracker.update_run_info(mock_launch_project)
assert api.stop_run(encoded_run_id)
assert job_tracker.check_wandb_run_stopped(api)
run.finish()
50 changes: 50 additions & 0 deletions tests/pytest_tests/unit_tests/test_launch/test_agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def _setup(mocker):
mocker.patch("wandb.termerror", mocker.termerror)
mocker.patch("wandb.init", mocker.wandb_init)

mocker.status = MagicMock()
mocker.status.state = "running"
mocker.run = MagicMock()
mocker.run.get_status = MagicMock(return_value=mocker.status)
mocker.runner = MagicMock()
mocker.runner.run = MagicMock(return_value=mocker.run)
mocker.patch(
"wandb.sdk.launch.agent.agent.loader.runner_from_config",
return_value=mocker.runner,
)


def test_loop_capture_stack_trace(mocker):
_setup(mocker)
Expand Down Expand Up @@ -497,3 +508,42 @@ def mock_thread_run_job(*args, **kwargs):
mock_finish_thread_id.assert_called_once_with(
threading.current_thread().ident, exception
)


def test_inner_thread_run_job(mocker):
_setup(mocker)
mocker.patch("wandb.sdk.launch.agent.agent.MAX_WAIT_RUN_STOPPED", new=0)
mocker.patch("wandb.sdk.launch.agent.agent.AGENT_POLLING_INTERVAL", new=0)
mock_config = {
"entity": "test-entity",
"project": "test-project",
}
mock_saver = MagicMock()
job = JobAndRunStatusTracker(
"run_queue_item_id", "test-queue", mock_saver, run=MagicMock()
)
agent = LaunchAgent(api=mocker.api, config=mock_config)
mock_spec = {
"docker": {"docker_image": "blah-blah:latest"},
"entity": "user",
"project": "test",
}

mocker.api.check_stop_requested = True
cancel = MagicMock()
mocker.run.cancel = cancel

def side_effect_func():
job.completed_status = True

cancel.side_effect = side_effect_func

agent._thread_run_job(
mock_spec,
{"runQueueItemId": "blah"},
{},
mocker.api,
threading.current_thread().ident,
job,
)
cancel.assert_called_once()
11 changes: 11 additions & 0 deletions wandb/sdk/launch/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

RUN_INFO_GRACE_PERIOD = 60

MAX_WAIT_RUN_STOPPED = 60

_env_timeout = os.environ.get("WANDB_LAUNCH_START_TIMEOUT")
if _env_timeout:
try:
Expand Down Expand Up @@ -629,6 +631,7 @@ def _thread_run_job(
with self._jobs_lock:
job_tracker.run = run
start_time = time.time()
stopped_time: Optional[float] = None
while self._jobs_event.is_set():
# If run has failed to start before timeout, kill it
state = run.get_status().state
Expand All @@ -642,6 +645,13 @@ def _thread_run_job(
)
if self._check_run_finished(job_tracker, launch_spec):
return
if job_tracker.check_wandb_run_stopped(self._api):
if stopped_time is None:
stopped_time = time.time()
else:
if time.time() - stopped_time > MAX_WAIT_RUN_STOPPED:
run.cancel()

time.sleep(AGENT_POLLING_INTERVAL)
# temp: for local, kill all jobs. we don't yet have good handling for different
# types of runners in general
Expand Down Expand Up @@ -714,6 +724,7 @@ def _check_run_finished(
with self._jobs_lock:
job_tracker.completed_status = status
return True

return False
except LaunchError as e:
wandb.termerror(
Expand Down
18 changes: 18 additions & 0 deletions wandb/sdk/launch/agent/job_status_tracker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import logging
from dataclasses import dataclass
from typing import Optional

from wandb.apis.internal import Api
from wandb.errors import CommError
from wandb.sdk.launch._project_spec import LaunchProject

from ..runner.abstract import AbstractRun
from .run_queue_item_file_saver import RunQueueItemFileSaver

_logger = logging.getLogger(__name__)


@dataclass
class JobAndRunStatusTracker:
Expand All @@ -32,3 +37,16 @@ def update_run_info(self, launch_project: LaunchProject) -> None:

def set_err_stage(self, stage: str) -> None:
self.err_stage = stage

def check_wandb_run_stopped(self, api: Api) -> bool:
assert (
self.run_id is not None
and self.project is not None
and self.entity is not None
), "Job tracker does not contain run info. Update with run info before checking if run stopped"

try:
return api.api.check_stop_requested(self.project, self.entity, self.run_id)
except CommError as e:
_logger.error(f"CommError when checking if wandb run stopped: {e}")
return False

0 comments on commit 50d613f

Please sign in to comment.