Skip to content

Commit

Permalink
Add tests for get_jobs (#18)
Browse files Browse the repository at this point in the history
* Add tests for get_jobs

* Add test for job_state

* Fix timezones

* blackify

* Handle wrong job_id type

* Calm down pylint

* Add tzlocal fixture

* Handle job_id is None

* Fix node tests with new test config

* Blackify

* username -> user

* Using | instead of Union

* Validate json with txt instead of yml files

* job_id=[] returns no jobs

* Have 2 separate DBs in tests

One DB for read-only tests, which can use a session scope to avoid
wasting time on DB rebuilds, and one DB for read-and-write tests, which
must stay at test scope so that tests do not affect one another.

* blackify
  • Loading branch information
bouthilx authored Mar 1, 2023
1 parent a5829a9 commit 1736c24
Show file tree
Hide file tree
Showing 27 changed files with 3,335 additions and 74 deletions.
1 change: 1 addition & 0 deletions sarc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

MTL = zoneinfo.ZoneInfo("America/Montreal")
UTC = zoneinfo.ZoneInfo("UTC")
TZLOCAL = zoneinfo.ZoneInfo(str(datetime.now().astimezone().tzinfo))


class ConfigurationError(Exception):
Expand Down
41 changes: 26 additions & 15 deletions sarc/jobs/job.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# pylint: disable=dangerous-default-value
from __future__ import annotations

from datetime import datetime, time, timedelta
from enum import Enum
from typing import Optional, Union
from typing import Optional

from pydantic import validator
from pydantic_mongo import AbstractRepository, ObjectIdField

from ..config import MTL, UTC, BaseModel, ClusterConfig, config
from ..config import MTL, TZLOCAL, UTC, BaseModel, ClusterConfig, config


class SlurmState(str, Enum):
Expand Down Expand Up @@ -147,14 +147,15 @@ def jobs_collection():
return SlurmJobRepository(database=db)


# pylint: disable=too-many-branches,dangerous-default-value
def get_jobs(
*,
cluster: Union[str, ClusterConfig] = None,
job_id: Union[int, list[int]] = None,
job_state: Union[str, SlurmState] = None,
username: str = None,
start: Union[str, datetime] = None,
end: Union[str, datetime] = None,
cluster: str | ClusterConfig | None = None,
job_id: int | list[int] | None = None,
job_state: str | SlurmState | None = None,
user: str | None = None,
start: str | datetime | None = None,
end: str | datetime | None = None,
query_options: dict = {},
) -> list[SlurmJob]:
"""Get jobs that match the query.
Expand All @@ -170,11 +171,18 @@ def get_jobs(
cluster = config().clusters[cluster]

if isinstance(start, str):
start = datetime.combine(datetime.strptime(start, "%Y-%m-%d"), time.min)
start = datetime.combine(
datetime.strptime(start, "%Y-%m-%d"), time.min
).replace(tzinfo=TZLOCAL)
if isinstance(end, str):
end = datetime.combine(
datetime.strptime(end, "%Y-%m-%d"), time.min
) + timedelta(days=1)
end = (datetime.combine(datetime.strptime(end, "%Y-%m-%d"), time.min)).replace(
tzinfo=TZLOCAL
)

if start is not None:
start = start.astimezone(UTC)
if end is not None:
end = end.astimezone(UTC)

query = {}
if isinstance(cluster, ClusterConfig):
Expand All @@ -184,13 +192,15 @@ def get_jobs(
query["job_id"] = job_id
elif isinstance(job_id, list):
query["job_id"] = {"$in": job_id}
elif job_id is not None:
raise TypeError(f"job_id must be an int or a list of ints: {job_id}")

if end:
# Select any job that had a status before the given end time.
query["submit_time"] = {"$lt": end}

if username:
query["username"] = username
if user:
query["user"] = user

if job_state:
query["job_state"] = job_state
Expand All @@ -211,6 +221,7 @@ def get_jobs(
return coll.find_by(query, **query_options)


# pylint: disable=dangerous-default-value
def get_job(*, query_options={}, **kwargs):
"""Get a single job that matches the query, or None if nothing is found.
Expand Down
21 changes: 8 additions & 13 deletions sarc/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,14 @@ def curate_label_argument(label_name: str, label_values: None | str | list[str])

def generate_label_configs(
node_id: None | str | list[str],
cluster_name: None | str | list[str],
cluster: None | str | list[str],
):
node_configs = curate_label_argument("instance", node_id)
cluster_configs = curate_label_argument("cluster", cluster_name)

if any(
config.get("cluster", "mila-cluster") != "mila-cluster"
for config in cluster_configs
):
raise NotImplementedError("Only mila-cluster is supported for now.")
cluster_configs = curate_label_argument("cluster", cluster)

# Create list of label_configs based on node_id and cluster_name
for node_config, cluster_config in itertools.product(node_configs, cluster_configs):
query_config = copy.deepcopy(node_config)
if cluster_config:
logger.warning("Cluster name is not supported for now.")
query_config.update(cluster_config)
# yield node_config | cluster_config
yield query_config
Expand Down Expand Up @@ -109,6 +101,7 @@ def generate_custom_query(


def query_prom(
cluster: str,
metric_name: str,
label_config: dict,
start: datetime,
Expand All @@ -117,15 +110,15 @@ def query_prom(
):
query = generate_custom_query(metric_name, label_config, start, end, running_window)

return config().clusters["mila"].prometheus.custom_query(query)
return config().clusters[cluster].prometheus.custom_query(query)


def get_nodes_time_series(
metrics: str | list[str],
cluster: str | list[str],
start: datetime,
end: None | datetime = None,
node_id: None | str | list[str] = None,
cluster_name: None | str | list[str] = None,
running_window: timedelta = timedelta(days=1),
) -> pd.DataFrame:
"""Fetch node metrics
Expand Down Expand Up @@ -156,9 +149,11 @@ def get_nodes_time_series(

df = None
for metric_name, label_config in itertools.product(
metrics, generate_label_configs(node_id, cluster_name)
metrics, generate_label_configs(node_id, cluster)
):
label_config = copy.deepcopy(label_config)
rval = query_prom(
label_config.pop("cluster"),
metric_name,
label_config=label_config,
start=start,
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import zoneinfo

import pytest


@pytest.fixture
def tzlocal_is_mtl(monkeypatch):
    """Pin the local timezone seen by sarc to America/Montreal for this test.

    Both ``sarc.config`` and ``sarc.jobs.job`` hold their own ``TZLOCAL``
    reference, so each module attribute is patched independently; monkeypatch
    restores the originals automatically on test teardown.
    """
    mtl = zoneinfo.ZoneInfo("America/Montreal")
    for target in ("sarc.config.TZLOCAL", "sarc.jobs.job.TZLOCAL"):
        monkeypatch.setattr(target, mtl)
4 changes: 2 additions & 2 deletions tests/functional/allocations/test_func_allocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
}


@pytest.mark.usefixtures("init_db_with_allocations")
@pytest.mark.usefixtures("read_only_db")
@pytest.mark.parametrize("params,", parameters.values(), ids=parameters.keys())
def test_get_allocations(params, data_regression):
data = get_allocations(**params)
Expand All @@ -32,7 +32,7 @@ def test_get_allocations(params, data_regression):
)


@pytest.mark.usefixtures("init_db_with_allocations")
@pytest.mark.usefixtures("read_only_db")
@pytest.mark.parametrize("params,", parameters.values(), ids=parameters.keys())
def test_get_allocations_summaries(params, dataframe_regression):
data = get_allocation_summaries(**params)
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/allocations/test_update_allocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@pytest.mark.freeze_time("2023-02-15")
@pytest.mark.usefixtures("init_empty_db")
@pytest.mark.usefixtures("empty_read_write_db")
def test_update_allocations(data_regression):
assert get_allocations(cluster_name=["fromage", "patate"]) == []
main(["acquire", "allocations", "--file", os.path.join(FOLDER, "allocations.csv")])
Expand All @@ -21,7 +21,7 @@ def test_update_allocations(data_regression):


@pytest.mark.freeze_time("2023-02-15")
@pytest.mark.usefixtures("init_empty_db")
@pytest.mark.usefixtures("empty_read_write_db")
def test_update_allocations_no_duplicates(data_regression):
assert get_allocations(cluster_name=["fromage", "patate"]) == []
main(["acquire", "allocations", "--file", os.path.join(FOLDER, "allocations.csv")])
Expand All @@ -36,7 +36,7 @@ def test_update_allocations_no_duplicates(data_regression):


@pytest.mark.freeze_time("2023-02-15")
@pytest.mark.usefixtures("init_empty_db")
@pytest.mark.usefixtures("empty_read_write_db")
def test_update_allocations_invalid_with_some_valid(data_regression):
assert get_allocations(cluster_name=["fromage", "patate"]) == []
main(
Expand All @@ -55,7 +55,7 @@ def test_update_allocations_invalid_with_some_valid(data_regression):


@pytest.mark.freeze_time("2023-02-14")
@pytest.mark.usefixtures("init_empty_db")
@pytest.mark.usefixtures("empty_read_write_db")
def test_update_allocations_invalid_error_msg(data_regression, capsys):
assert get_allocations(cluster_name=["fromage", "patate"]) == []
main(
Expand Down
Loading

0 comments on commit 1736c24

Please sign in to comment.