Skip to content

Commit

Permalink
feat(launch): Support template variables when queueing launch runs (w…
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleGoyette authored Nov 16, 2023
1 parent 57d16d8 commit 5528355
Show file tree
Hide file tree
Showing 9 changed files with 459 additions and 149 deletions.
16 changes: 0 additions & 16 deletions tests/pytest_tests/system_tests/test_core/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,6 @@ def test_run_metadata(wandb_init):
assert len(metadata)


def test_run_queue(user):
api = Api()
queue = api.create_run_queue(
name="test-queue",
entity=user,
access="project",
type="local-container",
)
try:
assert queue.name == "test-queue"
assert queue.access == "PROJECT"
assert queue.type == "local-container"
finally:
queue.delete()


@pytest.fixture(scope="function")
def inject_run(user, inject_graphql_response):
def helper(
Expand Down
1 change: 0 additions & 1 deletion tests/pytest_tests/system_tests/test_launch/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_job_call(relay_server, user, wandb_init, test_settings):
assert queued_run.state == "pending"
assert queued_run.entity == user
assert queued_run.project == proj
assert queued_run.container_job is True

rqi = internal_api.pop_from_run_queue(queue, user, proj)

Expand Down
198 changes: 190 additions & 8 deletions tests/pytest_tests/system_tests/test_launch/test_launch_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import wandb
from wandb.apis.public import Api as PublicApi
from wandb.sdk.internal.internal_api import Api as InternalApi
from wandb.sdk.internal.internal_api import UnsupportedError
from wandb.sdk.launch._launch_add import launch_add
from wandb.sdk.launch.utils import LAUNCH_DEFAULT_PROJECT
from wandb.sdk.launch.utils import LAUNCH_DEFAULT_PROJECT, LaunchError


class MockBranch:
Expand Down Expand Up @@ -210,7 +211,6 @@ async def patched_validate_docker_installation():
assert queued_run.state == "pending"
assert queued_run.entity == user
assert queued_run.project == proj
assert queued_run.container_job is True
assert queued_run.project_queue == LAUNCH_DEFAULT_PROJECT

rqi = internal_api.pop_from_run_queue(queue, user, LAUNCH_DEFAULT_PROJECT)
Expand Down Expand Up @@ -329,7 +329,7 @@ def test_push_to_runqueue_exists(
entity=user, project=LAUNCH_DEFAULT_PROJECT, queue_name=queue, access="USER"
)

result = api.push_to_run_queue(queue, args, LAUNCH_DEFAULT_PROJECT)
result = api.push_to_run_queue(queue, args, None, LAUNCH_DEFAULT_PROJECT)

assert result["runQueueItemId"]

Expand Down Expand Up @@ -364,7 +364,7 @@ def test_push_to_default_runqueue_notexist(
with relay_server():
run = wandb_init(settings=settings)
res = api.push_to_run_queue(
"nonexistent-queue", launch_spec, LAUNCH_DEFAULT_PROJECT
"nonexistent-queue", launch_spec, None, LAUNCH_DEFAULT_PROJECT
)
run.finish()

Expand Down Expand Up @@ -406,7 +406,7 @@ def test_push_to_runqueue_old_server(
entity=user, project=LAUNCH_DEFAULT_PROJECT, queue_name=queue, access="USER"
)

result = api.push_to_run_queue(queue, args, LAUNCH_DEFAULT_PROJECT)
result = api.push_to_run_queue(queue, args, None, LAUNCH_DEFAULT_PROJECT)
run.finish()

assert result["runQueueItemId"]
Expand All @@ -433,7 +433,7 @@ def test_push_with_repository(
with relay_server():
run = wandb_init(settings=settings)
res = api.push_to_run_queue(
"nonexistent-queue", launch_spec, LAUNCH_DEFAULT_PROJECT
"nonexistent-queue", launch_spec, None, LAUNCH_DEFAULT_PROJECT
)
run.finish()

Expand Down Expand Up @@ -475,6 +475,153 @@ def test_launch_add_repository(
run.finish()


def test_launch_add_template_variables(runner, relay_server, user):
queue_name = "tvqueue"
proj = "test1"
queue_config = {"e": ["{{var1}}"]}
queue_template_variables = {
"var1": {"schema": {"type": "string", "enum": ["a", "b"]}}
}
template_variables = {"var1": "a"}
base_config = {"template_variables": {"var1": "b"}}
with relay_server() as relay, runner.isolated_filesystem():
api = PublicApi(api_key=user)
api.create_run_queue(
entity=user,
name=queue_name,
type="local-container",
config=queue_config,
template_variables=queue_template_variables,
)
_ = launch_add(
template_variables=template_variables,
project=proj,
entity=user,
queue_name=queue_name,
docker_image="abc:latest",
config=base_config,
)
for comm in relay.context.raw_data:
q = comm["request"].get("query")
vars = comm["request"].get("variables")
if q and "mutation pushToRunQueueByName(" in str(q):
assert comm["response"].get("data") is not None
assert vars["templateVariableValues"] == '{"var1": "a"}'
elif q and "mutation pushToRunQueue(" in str(q):
raise Exception("should not be falling back to legacy here")


def test_launch_add_template_variables_legacy_push(
runner, relay_server, user, monkeypatch
):
queue_name = "tvqueue"
proj = "test1"
queue_config = {"e": ["{{var1}}"]}
queue_template_variables = {
"var1": {"schema": {"type": "string", "enum": ["a", "b"]}}
}
template_variables = {"var1": "a"}
monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_by_name",
lambda *args, **kwargs: None,
)
with relay_server() as relay, runner.isolated_filesystem():
api = PublicApi(api_key=user)
api.create_run_queue(
entity=user,
name=queue_name,
type="local-container",
config=queue_config,
template_variables=queue_template_variables,
)
_ = launch_add(
template_variables=template_variables,
project=proj,
entity=user,
queue_name=queue_name,
docker_image="abc:latest",
)
for comm in relay.context.raw_data:
q = comm["request"].get("query")
if q and "mutation pushToRunQueue(" in str(q):
assert comm["response"].get("data") is not None
elif q and "mutation pushToRunQueueByName(" in str(q):
raise Exception("should not be using non legacy here")


def test_launch_add_template_variables_not_supported(user, monkeypatch):
queue_name = "tvqueue"
proj = "test1"
queue_config = {"e": ["{{var1}}"]}
template_variables = {"var1": "a"}

def patched_push_to_run_queue_introspection(*args, **kwargs):
args[0].server_supports_template_varaibles = False
return False

monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_introspection",
patched_push_to_run_queue_introspection,
)
api = PublicApi(api_key=user)
api.create_run_queue(
entity=user,
name=queue_name,
type="local-container",
config=queue_config,
)
with pytest.raises(UnsupportedError):
_ = launch_add(
template_variables=template_variables,
project=proj,
entity=user,
queue_name=queue_name,
docker_image="abc:latest",
)


def test_launch_add_template_variables_not_supported_legacy_push(
runner, user, monkeypatch
):
queue_name = "tvqueue"
proj = "test1"
queue_config = {"e": ["{{var1}}"]}
template_variables = {"var1": "a"}

def patched_push_to_run_queue_introspection(*args, **kwargs):
args[0].server_supports_template_varaibles = False
return False

monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_introspection",
patched_push_to_run_queue_introspection,
)
monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_by_name",
lambda *args, **kwargs: None,
)
with runner.isolated_filesystem():
api = PublicApi(api_key=user)
api.create_run_queue(
entity=user,
name=queue_name,
type="local-container",
config=queue_config,
)
with pytest.raises(UnsupportedError):
_ = launch_add(
template_variables=template_variables,
project=proj,
entity=user,
queue_name=queue_name,
docker_image="abc:latest",
)


def test_display_updated_runspec(
relay_server, user, test_settings, wandb_init, monkeypatch
):
Expand All @@ -485,9 +632,11 @@ def test_display_updated_runspec(
settings = test_settings({"project": proj})
api = InternalApi()

def push_with_drc(api, queue_name, launch_spec, project_queue):
def push_with_drc(api, queue_name, launch_spec, template_variables, project_queue):
# mock having a DRC
res = api.push_to_run_queue(queue_name, launch_spec, project_queue)
res = api.push_to_run_queue(
queue_name, launch_spec, template_variables, project_queue
)
res["runSpec"] = launch_spec
res["runSpec"]["resource_args"] = {"kubernetes": {"volume": "x/awda/xxx"}}
return res
Expand Down Expand Up @@ -515,3 +664,36 @@ def push_with_drc(api, queue_name, launch_spec, project_queue):
)

run.finish()


def test_container_queued_run(monkeypatch, user):
def patched_push_to_run_queue_by_name(*args, **kwargs):
return {"runQueueItemId": "1"}

monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_by_name",
lambda *arg, **kwargs: patched_push_to_run_queue_by_name(*arg, **kwargs),
)
monkeypatch.setattr(
wandb.PublicApi,
"artifact",
lambda *arg, **kwargs: "artifact",
)

queued_run = launch_add(job="test/test/test-job:v0")
assert queued_run


def test_job_dne(monkeypatch, user):
def patched_push_to_run_queue_by_name(*args, **kwargs):
return {"runQueueItemId": "1"}

monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_by_name",
lambda *arg, **kwargs: patched_push_to_run_queue_by_name(*arg, **kwargs),
)

with pytest.raises(LaunchError):
launch_add(job="test/test/test-job:v0")
49 changes: 49 additions & 0 deletions tests/pytest_tests/system_tests/tests_launch/test_public_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import wandb
import wandb.apis.public
import wandb.util
from wandb import Api
from wandb.sdk.internal.internal_api import UnsupportedError


def test_create_run_queue_template_variables_not_supported(runner, user, monkeypatch):
queue_name = "tvqueue"
queue_config = {"e": ["{{var1}}"]}
queue_template_variables = {
"var1": {"schema": {"type": "string", "enum": ["a", "b"]}}
}

def patched_push_to_run_queue_introspection(*args, **kwargs):
args[0].server_supports_template_varaibles = False
return False

monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_introspection",
patched_push_to_run_queue_introspection,
)
with runner.isolated_filesystem():
api = Api(api_key=user)
with pytest.raises(UnsupportedError):
api.create_run_queue(
entity=user,
name=queue_name,
type="local-container",
config=queue_config,
template_variables=queue_template_variables,
)


def test_run_queue(user):
api = Api()
queue = api.create_run_queue(
name="test-queue",
entity=user,
type="local-container",
)
try:
assert queue.name == "test-queue"
assert queue.access == "PROJECT"
assert queue.type == "local-container"
finally:
queue.delete()
37 changes: 0 additions & 37 deletions tests/pytest_tests/unit_tests_old/test_launch/test_launch_jobs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import json
import os
import tempfile
from unittest import mock

import pytest
import wandb
import wandb.sdk.launch._launch as _launch
import wandb.sdk.launch._project_spec as _project_spec
from wandb.sdk.data_types._dtypes import TypeRegistry
from wandb.sdk.launch._launch_add import launch_add
from wandb.sdk.launch.errors import LaunchError

from tests.pytest_tests.unit_tests_old import utils
Expand Down Expand Up @@ -209,38 +207,3 @@ def job_download_func(root):
}
mock_with_run_info = _launch.launch(**kwargs)
check_mock_run_info(mock_with_run_info, EMPTY_BACKEND_CONFIG, kwargs)


def test_launch_add_container_queued_run(
live_mock_server, mocked_public_artifact, monkeypatch
):
def job_download_func(root=None):
if root is None:
root = tempfile.mkdtemp()
with open(os.path.join(root, "wandb-job.json"), "w") as f:
source = {
"_version": "v0",
"source_type": "image",
"source": {"image": "my-test-image:latest"},
"input_types": INPUT_TYPES,
"output_types": OUTPUT_TYPES,
}
f.write(json.dumps(source))
with open(os.path.join(root, "requirements.frozen.txt"), "w") as f:
f.write(utils.fixture_open("requirements.txt").read())

return root

mocked_public_artifact(job_download_func)

def patched_push_to_run_queue_by_name(*args, **kwargs):
return {"runQueueItemId": "1"}

monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_by_name",
lambda *arg, **kwargs: patched_push_to_run_queue_by_name(*arg, **kwargs),
)

queued_run = launch_add(job="test/test/test-job:v0")
assert queued_run
Loading

0 comments on commit 5528355

Please sign in to comment.