-
Notifications
You must be signed in to change notification settings - Fork 681
/
Copy pathconftest.py
44 lines (35 loc) · 1.04 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gc
import random
import numpy as np
import pytest
import torch
def _set_seed():
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.mps.manual_seed(0)
np.random.seed(0)
random.seed(0)
def pytest_runtest_call(item):
    """Run ``item`` after reseeding, skipping it when CUDA is unusable.

    The RNGs are reseeded via ``_set_seed()`` and then ``item.runtest()`` is
    invoked. Two failure shapes are converted into a skip instead of an error:

    * an ``AssertionError`` whose message is exactly
      ``"Torch not compiled with CUDA enabled"``;
    * a ``RuntimeError`` whose message contains
      ``"Found no NVIDIA driver on your system"``.

    Any other ``AssertionError`` / ``RuntimeError`` is re-raised unchanged.

    NOTE(review): pytest's default implementation of this hook also calls
    ``item.runtest()`` -- confirm whether this causes the test body to be
    executed twice when it passes here.
    """
    try:
        _set_seed()
        item.runtest()
    except (AssertionError, RuntimeError) as exc:
        message = str(exc)
        if isinstance(exc, AssertionError):
            if message == "Torch not compiled with CUDA enabled":
                pytest.skip("Torch not compiled with CUDA enabled")
        elif "Found no NVIDIA driver on your system" in message:
            # CUDA-enabled Torch build, but no CUDA-capable device found
            pytest.skip("No NVIDIA driver found")
        # Not one of the recognised "no CUDA" failures: propagate as-is.
        raise
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(item, nextitem):
    """Run a garbage collection and, if CUDA is available, empty its cache.

    Registered with ``trylast=True``. Neither ``item`` nor ``nextitem`` is
    used by the body.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
@pytest.fixture(scope="session")
def requires_cuda() -> bool:
    """Session-scoped fixture that skips when CUDA is not available.

    Returns:
        The result of ``torch.cuda.is_available()`` when it is truthy.
        Otherwise ``pytest.skip("CUDA is required")`` is called.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        return has_cuda
    pytest.skip("CUDA is required")
    return has_cuda