-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathutils.py
60 lines (43 loc) · 1.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import re
import timm
import torch
import unittest
from git import Repo
from typing import List
from packaging.version import Version
# True when the installed timm is at least 1.0.12 (presumably the release that
# provides the timm test models this flag gates — confirm against callers).
has_timm_test_models = Version(timm.__version__) >= Version("1.0.12")

# Use the GPU if torch can see one, otherwise fall back to the CPU.
default_device = "cuda" if torch.cuda.is_available() else "cpu"

# Lower-case spellings that count as "enabled" for the boolean env flags below.
YES_LIST = ["true", "1", "y", "yes"]


def _env_flag(name):
    """Return True if environment variable ``name`` is set to a YES_LIST value.

    The comparison is case-insensitive; an unset variable is treated as "false".
    """
    raw_value = os.getenv(name, "false")
    return raw_value.lower() in YES_LIST


# Flags are evaluated once, at import time.
RUN_ALL_ENCODERS = _env_flag("RUN_ALL_ENCODERS")
RUN_SLOW = _env_flag("RUN_SLOW")
RUN_ALL = _env_flag("RUN_ALL")
def slow_test(test_case):
    """
    Mark ``test_case`` as slow so it is skipped unless RUN_SLOW is enabled.

    RUN_SLOW is computed from the environment when this module is imported; set
    the RUN_SLOW environment variable to "true", "1", "y" or "yes" (any case) to
    run slow tests.
    """
    skip_unless_slow_enabled = unittest.skipUnless(RUN_SLOW, "test is slow")
    return skip_unless_slow_enabled(test_case)
def requires_torch_greater_or_equal(version: str):
    """
    Return a decorator that skips a test if the installed torch is older than ``version``.

    Both the installed ``torch.__version__`` and ``version`` are parsed with
    ``packaging.version.Version``, so the comparison follows PEP 440 ordering
    (for example, pre-release builds such as "2.5.0a0" sort below "2.5.0").

    Args:
        version: Minimum required torch version string, e.g. "2.1.0".
    """
    installed = Version(torch.__version__)
    required = Version(version)
    new_enough = installed >= required
    reason = f"torch version {installed} is less than {required}"
    return unittest.skipUnless(new_enough, reason)
def check_run_test_on_diff_or_main(filepath_patterns: List[str]) -> bool:
    """
    Decide whether a test should run, given the files changed relative to ``main``.

    Returns True when any of the following holds:
      * RUN_ALL is enabled;
      * git state cannot be determined (not a repository, detached HEAD, no local
        ``main`` branch, git missing, ...) — the check fails open so tests still run;
      * the active branch is ``main``;
      * some path reported by ``git diff main --name-only`` matches (``re.search``)
        any regex in ``filepath_patterns``.
    Otherwise returns False.

    Args:
        filepath_patterns: Regular expressions matched against changed file paths.
    """
    if RUN_ALL:
        return True

    try:
        # Use the repo as a context manager so GitPython releases its resources.
        with Repo(".") as repo:
            # On main every test runs, so skip the diff subprocess entirely.
            if repo.active_branch.name == "main":
                return True
            diff_output = repo.git.diff("main", name_only=True)
    except Exception:
        # Deliberate best-effort: if git state is unavailable, run the test.
        return True

    changed_files = diff_output.splitlines()
    # Pattern-major order and lazy evaluation match the original nested loops.
    return any(
        re.search(pattern, file_path)
        for pattern in filepath_patterns
        for file_path in changed_files
    )