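"""Tests for the helper functions in ``lit_gpt.utils`` (Lightning-AI/lit-gpt)."""
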
import os
import sys
from contextlib import redirect_stderr
from io import StringIO

import pytest
import torch
import torch.nn.functional as F
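

# find_multiple(n, k) should round n up to the nearest multiple of k,
# returning n unchanged when it is already one.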
def test_find_multiple():
    from lit_gpt.utils import find_multiple

    assert find_multiple(17, 5) == 20
    assert find_multiple(30, 7) == 35
    assert find_multiple(10, 2) == 10
    assert find_multiple(5, 10) == 10
    assert find_multiple(50254, 128) == 50304
    assert find_multiple(50254, 256) == 50432
    assert find_multiple(50254, 512) == 50688
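

# check_valid_checkpoint_dir should print a helpful error to stderr and exit
# when pointed at an incomplete checkpoint directory, a nonexistent one, or a
# nonexistent one while another checkpoint is available locally; the test
# captures stderr and compares the full message verbatim.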
@pytest.mark.skipif(sys.platform == "win32", reason="match fails on windows. why did they have to use backslashes?")
def test_check_valid_checkpoint_dir(tmp_path):
    from lit_gpt.utils import check_valid_checkpoint_dir

    os.chdir(tmp_path)

    out = StringIO()
    with pytest.raises(SystemExit), redirect_stderr(out):
        check_valid_checkpoint_dir(tmp_path)
    out = out.getvalue().strip()
    expected = f"""
--checkpoint_dir '{str(tmp_path.absolute())}' is missing the files: ['lit_model.pth', 'lit_config.json', 'tokenizer.json OR tokenizer.model', 'tokenizer_config.json'].
Find download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials

See all download options by running:
 python scripts/download.py
""".strip()
    assert out == expected

    out = StringIO()
    checkpoint_dir = tmp_path / "checkpoints" / "stabilityai" / "stablelm-base-alpha-3b"
    with pytest.raises(SystemExit), redirect_stderr(out):
        check_valid_checkpoint_dir(checkpoint_dir)
    out = out.getvalue().strip()
    expected = f"""
--checkpoint_dir '{str(checkpoint_dir.absolute())}' is not a checkpoint directory.
Find download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials

See all download options by running:
 python scripts/download.py
""".strip()
    assert out == expected

    out = StringIO()
    checkpoint_dir.mkdir(parents=True)
    foo_checkpoint_dir = tmp_path / "foo"
    with pytest.raises(SystemExit), redirect_stderr(out):
        check_valid_checkpoint_dir(foo_checkpoint_dir)
    out = out.getvalue().strip()
    expected = f"""
--checkpoint_dir '{str(foo_checkpoint_dir.absolute())}' is not a checkpoint directory.
Find download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials

You have downloaded locally:
 --checkpoint_dir '{str(checkpoint_dir.absolute())}'

See all download options by running:
 python scripts/download.py
""".strip()
    assert out == expected
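

# incremental_save should produce a file that round-trips through torch.load:
# every tensor must come back with equal values, including the two registered
# via store_early before the final save(), and the tensor attribute `someattr`
# must survive serialization (PyTorch 2.0+).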
def test_incremental_write(tmp_path):
    from lit_gpt.utils import incremental_save

    sd = {str(k): torch.randn(5, 10) for k in range(3)}
    sd["0"].someattr = 1
    sd_expected = {k: v.clone() for k, v in sd.items()}

    fn = str(tmp_path / "test.pt")
    with incremental_save(fn) as f:
        sd["0"] = f.store_early(sd["0"])
        sd["2"] = f.store_early(sd["2"])
        f.save(sd)

    sd_actual = torch.load(fn)
    assert sd_actual.keys() == sd_expected.keys()
    assert sd_actual["0"].someattr == 1  # requires PyTorch 2.0+
    for k, v_expected in sd_expected.items():
        v_actual = sd_actual[k]
        torch.testing.assert_close(v_expected, v_actual)
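

# chunked_cross_entropy with chunk_size=0 should match F.cross_entropy exactly,
# and the chunked paths (chunk_size > 0, or logits pre-split into a list) should
# agree within floating-point tolerance. A chunk size that does not divide T
# exercises the leftover-chunk handling.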
@pytest.mark.parametrize("B", (1, 2))
def test_chunked_cross_entropy(B):
    from lit_gpt.utils import chunked_cross_entropy

    V = 50
    T = 25
    regular_logits = torch.randn(B, T, V)
    targets = torch.randint(0, V, (B, T))

    baseline_loss = F.cross_entropy(regular_logits.reshape(-1, regular_logits.size(-1)), targets.reshape(-1))
    regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0)
    assert torch.equal(baseline_loss, regular_loss)
    assert regular_loss.numel() == 1

    chunked_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=10)
    torch.testing.assert_close(chunked_loss, regular_loss)
    torch.testing.assert_close(chunked_loss, baseline_loss)

    logit_chunk_size = 6
    assert T % logit_chunk_size != 0  # ensure leftover
    chunked_logits = list(regular_logits.split(logit_chunk_size, dim=1))
    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=0)
    torch.testing.assert_close(chunked_loss, regular_loss)
    torch.testing.assert_close(chunked_loss, baseline_loss)

    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=10)
    torch.testing.assert_close(chunked_loss, regular_loss)
    torch.testing.assert_close(chunked_loss, baseline_loss)
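

# num_parameters counts a module's parameters; with the requires_grad argument
# it should count only trainable (True) or only frozen (False) parameters.
# A Linear(2, 2) has 4 weights + 2 biases = 6 parameters.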
def test_num_parameters():
    from lit_gpt.utils import num_parameters

    model = torch.nn.Linear(2, 2)
    assert num_parameters(model) == 6
    assert num_parameters(model, requires_grad=True) == 6
    assert num_parameters(model, requires_grad=False) == 0

    model = torch.nn.Linear(2, 2)
    model.bias.requires_grad = False
    assert num_parameters(model) == 6
    assert num_parameters(model, requires_grad=True) == 4
    assert num_parameters(model, requires_grad=False) == 2