conftest.py
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import shutil
from pathlib import Path
from typing import List, Optional
import pytest
import torch
from lightning.fabric.utilities.testing import _runif_reasons
from lightning_utilities.core.imports import RequirementCache
@pytest.fixture()
def fake_checkpoint_dir(tmp_path):
os.chdir(tmp_path)
checkpoint_dir = tmp_path / "checkpoints" / "tmp"
checkpoint_dir.mkdir(parents=True)
(checkpoint_dir / "lit_model.pth").touch()
(checkpoint_dir / "model_config.yaml").touch()
(checkpoint_dir / "tokenizer.json").touch()
(checkpoint_dir / "tokenizer_config.json").touch()
return checkpoint_dir
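
# Illustrative sketch (hypothetical test, not part of this file): the fixture above
# chdirs into tmp_path and returns "checkpoints/tmp" populated with empty placeholder
# files, so a test can request it by name, e.g.
#
#   def test_checkpoint_dir_contents(fake_checkpoint_dir):
#       assert (fake_checkpoint_dir / "lit_model.pth").is_file()
#       assert (fake_checkpoint_dir / "model_config.yaml").is_file()
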
class TensorLike:
def __eq__(self, other):
return isinstance(other, torch.Tensor)
@pytest.fixture()
def tensor_like():
return TensorLike()
class FloatLike:
def __eq__(self, other):
return not isinstance(other, int) and isinstance(other, float)
@pytest.fixture()
def float_like():
return FloatLike()
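
# Illustrative sketch (hypothetical usage): TensorLike and FloatLike compare equal to
# any torch.Tensor / any non-int float, which makes the fixtures handy wildcards when
# asserting on arguments captured by a mock, e.g.
#
#   some_mock.assert_called_once_with(tensor_like, learning_rate=float_like)
#
# where some_mock is an assumed unittest.mock.Mock created inside a test.
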
@pytest.fixture(autouse=True)
def restore_default_dtype():
# reset the default dtype in case a previous test changed it
torch.set_default_dtype(torch.float32)
@pytest.fixture(autouse=True)
def destroy_process_group():
yield
import torch.distributed
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
class MockTokenizer:
"""A dummy tokenizer that encodes each character as its ASCII code."""
bos_id = 0
eos_id = 1
def encode(self, text: str, bos: Optional[bool] = None, eos: bool = False, max_length: int = -1) -> torch.Tensor:
output = []
if bos:
output.append(self.bos_id)
output.extend([ord(c) for c in text])
if eos:
output.append(self.eos_id)
output = output[:max_length] if max_length > 0 else output
return torch.tensor(output)
def decode(self, tokens: torch.Tensor) -> str:
return "".join(chr(int(t)) for t in tokens.tolist())
@pytest.fixture()
def mock_tokenizer():
return MockTokenizer()
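
# Illustrative sketch (hypothetical test): MockTokenizer maps each character to its
# ASCII code, prepends bos_id=0 when bos is truthy, appends eos_id=1 when eos=True,
# and truncates to max_length when it is positive, e.g.
#
#   def test_mock_tokenizer_roundtrip(mock_tokenizer):
#       ids = mock_tokenizer.encode("hi", bos=True, eos=True)
#       assert ids.tolist() == [0, 104, 105, 1]
#       assert mock_tokenizer.decode(ids[1:-1]) == "hi"
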
@pytest.fixture()
def alpaca_path(tmp_path):
file = Path(__file__).parent / "data" / "fixtures" / "alpaca.json"
shutil.copyfile(file, tmp_path / "alpaca.json")
return tmp_path / "alpaca.json"
@pytest.fixture()
def dolly_path(tmp_path):
file = Path(__file__).parent / "data" / "fixtures" / "dolly.json"
shutil.copyfile(file, tmp_path / "dolly.json")
return tmp_path / "dolly.json"
@pytest.fixture()
def longform_path(tmp_path):
path = tmp_path / "longform"
path.mkdir()
for split in ("train", "val"):
file = Path(__file__).parent / "data" / "fixtures" / f"longform_{split}.json"
shutil.copyfile(file, path / f"{split}.json")
return path
def RunIf(thunder: Optional[bool] = None, **kwargs):
reasons, marker_kwargs = _runif_reasons(**kwargs)
if thunder is not None:
thunder_available = bool(RequirementCache("lightning-thunder", "thunder"))
if thunder and not thunder_available:
reasons.append("Thunder")
elif not thunder and thunder_available:
reasons.append("not Thunder")
return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs)
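
# Illustrative sketch (hypothetical tests): RunIf returns a pytest.mark.skipif marker,
# so it is applied as a decorator; keyword arguments other than `thunder` are forwarded
# to Lightning's _runif_reasons (e.g. min_cuda_gpus, standalone), e.g.
#
#   @RunIf(min_cuda_gpus=2, standalone=True)
#   def test_needs_two_gpus():
#       ...
#
#   @RunIf(thunder=True)
#   def test_needs_thunder():
#       ...
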
# https://github.com/Lightning-AI/lightning/blob/6e517bd55b50166138ce6ab915abd4547702994b/tests/tests_fabric/conftest.py#L140
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
initial_size = len(items)
conditions = []
filtered, skipped = 0, 0
options = {"standalone": "PL_RUN_STANDALONE_TESTS", "min_cuda_gpus": "PL_RUN_CUDA_TESTS"}
if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1":
# special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.
# by deleting the key, we avoid filtering out the CPU tests
del options["min_cuda_gpus"]
for kwarg, env_var in options.items():
# this will compute the intersection of all tests selected per environment variable
if os.getenv(env_var, "0") == "1":
conditions.append(env_var)
for i, test in reversed(list(enumerate(items))): # loop in reverse, since we are going to pop items
already_skipped = any(marker.name == "skip" for marker in test.own_markers)
if already_skipped:
# the test was going to be skipped anyway, filter it out
items.pop(i)
skipped += 1
continue
has_runif_with_kwarg = any(
marker.name == "skipif" and marker.kwargs.get(kwarg) for marker in test.own_markers
)
if not has_runif_with_kwarg:
# the test does not have `@RunIf(kwarg=True)` for this environment, filter it out
items.pop(i)
filtered += 1
if config.option.verbose >= 0 and (filtered or skipped):
writer = config.get_terminal_writer()
writer.write(
f"\nThe number of tests has been filtered from {initial_size} to {initial_size - filtered} after the"
f" filters {conditions}.\n{skipped} tests are marked as unconditional skips.\nIn total,"
f" {len(items)} tests will run.\n",
flush=True,
bold=True,
purple=True, # oh yeah, branded pytest messages
)
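
# Illustrative invocations (assumed CI usage): with the filtering above, setting the
# environment variables restricts collection to the matching @RunIf-marked tests, e.g.
#
#   PL_RUN_CUDA_TESTS=1 pytest            # keep only tests marked @RunIf(min_cuda_gpus=...)
#   PL_RUN_STANDALONE_TESTS=1 pytest      # keep only tests marked @RunIf(standalone=True)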