Skip to content

Commit

Permalink
GPU UT - enable for test_device (pytorch#523)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#523

Adjust `test_device` - mock `torch.cuda.is_available` and methods in `test_get_gpu_stats`

Reviewed By: gunchu

Differential Revision: D48754274

fbshipit-source-id: 93345f3cbecf55451cb5942e7888e9abf8522045
  • Loading branch information
galrotem authored and facebook-github-bot committed Aug 30, 2023
1 parent 8be052d commit cc0fb2c
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions tests/utils/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass
from typing import Any, Dict
from unittest import mock
from unittest.mock import patch

import torch
from torchtnt.utils.device import (
Expand All @@ -25,16 +26,10 @@

class DeviceTest(unittest.TestCase):

# pyre-fixme[4]: Attribute must be annotated.
cuda_available = torch.cuda.is_available()
cuda_available: bool = torch.cuda.is_available()

# pyre-fixme[56]: Pyre was not able to infer the type of argument `not
# torchtnt.tests.utils.test_device.DeviceTest.cuda_available` to decorator factory
# `unittest.skipUnless`.
@unittest.skipUnless(
condition=(not cuda_available), reason="This test shouldn't run on a GPU host."
)
def test_get_cpu_device(self) -> None:
@patch("torch.cuda.is_available", return_value=False)
def test_get_cpu_device(self, _) -> None:
device = get_device_from_env()
self.assertEqual(device.type, "cpu")
self.assertEqual(device.index, None)
Expand Down Expand Up @@ -279,13 +274,16 @@ def test_get_cpu_stats(self) -> None:
self.assertGreaterEqual(cpu_stats["cpu_swap_percent"], 0)
self.assertLessEqual(cpu_stats["cpu_swap_percent"], 100)

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_get_gpu_stats(self) -> None:
"""Get Nvidia GPU stats, check that values are populated."""
device = torch.device("cuda:0")
gpu_stats = get_nvidia_smi_gpu_stats(device)

with mock.patch("shutil.which"), mock.patch(
"torchtnt.utils.device.subprocess.run"
) as subprocess_run_mock:
subprocess_run_mock.return_value.stdout = "0, 0, 0, 2, 16273, 38, 15"
gpu_stats = get_nvidia_smi_gpu_stats(device)

# Check that percentages are between 0 and 100
self.assertGreaterEqual(gpu_stats["utilization_gpu_percent"], 0)
self.assertLessEqual(gpu_stats["utilization_gpu_percent"], 100)
Expand Down

0 comments on commit cc0fb2c

Please sign in to comment.