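"""Tests for the lit_gpt.adapter model variant and the finetune/adapter.py finetuning script."""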
import sys
from contextlib import redirect_stdout
from dataclasses import asdict
from io import StringIO
from unittest.mock import Mock

import pytest
import torch
from lightning import Fabric
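

# Apart from the adapter-specific fields (adapter_prompt_length, adapter_start_layer),
# the adapter Config should be identical to the base model Config, and the adapter GPT
# should expose the same lm_head shape as the base GPT.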
def test_config_identical():
    import lit_gpt.adapter as gpt_adapter
    import lit_gpt.model as gpt

    name = "pythia-70m"
    base_config = asdict(gpt.Config.from_name(name))
    adapter_config = asdict(gpt_adapter.Config.from_name(name))
    del adapter_config["adapter_prompt_length"]
    del adapter_config["adapter_start_layer"]
    assert adapter_config == base_config

    with Fabric(accelerator="cpu").init_module(empty_init=True):
        base_model = gpt.GPT.from_name(name)
        adapter_model = gpt_adapter.GPT.from_name(name)
    assert adapter_model.lm_head.weight.shape == base_model.lm_head.weight.shape
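

# Saving through fabric.save with adapter_filter should keep only the adapter-specific
# tensors (adapter_wte weights and gating_factor), here from the last two of the four
# transformer layers, rather than the full state dict.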
def test_adapter_filter(tmp_path):
    from lit_gpt.adapter import GPT, adapter_filter

    fabric = Fabric(devices=1)
    model = GPT.from_name("pythia-70m", n_layer=4)
    save_path = tmp_path / "model.pth"
    fabric.save(save_path, {"model": model}, filter={"model": adapter_filter})
    saved = torch.load(save_path)["model"]

    expected = {
        "transformer.h.2.attn.adapter_wte.weight",
        "transformer.h.2.attn.gating_factor",
        "transformer.h.3.attn.adapter_wte.weight",
        "transformer.h.3.attn.gating_factor",
    }
    assert set(saved) == expected
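

# End-to-end run of finetune/adapter.py setup() on a tiny model with mocked tokenizer and
# checkpoint loading: 6 iterations, saving and validating every 2 iterations, then checking
# the emitted checkpoints, the metrics file, and the log output. The fake_checkpoint_dir
# fixture is assumed to be provided by the test suite's conftest.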
def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch):
    import finetune.adapter as module

    module.gradient_accumulation_iters = 1
    module.save_interval = 2
    module.eval_interval = 2
    module.eval_iters = 2
    module.eval_max_new_tokens = 1
    module.max_iters = 6

    data = [
        {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([1, 2, 3])},
        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([2, 3, 4])},
    ]
    torch.save(data, tmp_path / "train.pt")
    torch.save(data, tmp_path / "test.pt")

    from lit_gpt.config import name_to_config

    model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
    monkeypatch.setitem(name_to_config, "tmp", model_config)

    load_mock = Mock()
    load_mock.return_value = load_mock
    load_mock.__enter__ = Mock()
    load_mock.__exit__ = Mock()
    monkeypatch.setattr(module, "lazy_load", load_mock)

    tokenizer_mock = Mock()
    tokenizer_mock.return_value = tokenizer_mock
    tokenizer_mock.encode = lambda *_, **kwargs: torch.tensor([3, 2, 1], **kwargs)
    monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

    stdout = StringIO()
    with redirect_stdout(stdout):
        module.setup(data_dir=tmp_path, checkpoint_dir=fake_checkpoint_dir, out_dir=tmp_path, precision="32-true")

    assert {p.name for p in tmp_path.glob("*.pth")} == {
        "iter-000001-ckpt.pth",
        "iter-000003-ckpt.pth",
        "iter-000005-ckpt.pth",
        "lit_model_adapter_finetuned.pth",
    }
    assert (tmp_path / "version_0" / "metrics.csv").is_file()

    logs = stdout.getvalue()
    assert logs.count("optimizer.step") == module.max_iters
    assert logs.count("val loss") == module.max_iters // module.eval_interval
    assert "of trainable parameters: 168" in logs
def test_adapter_gpt_init_weights():
    from lit_gpt.adapter import GPT, Config

    config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0)
    model = GPT(config)
    param = model.transformer.h[0].attn.gating_factor

    assert (param == 0).all()
    torch.nn.init.constant_(param, 1.23)
    assert (param != 0).any()
    model.apply(model._init_weights)
    assert (param == 0).all()
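

# The adapter model should trace to a single graph under torch._dynamo with no graph
# breaks, both in the plain forward pass and with a KV cache plus input_pos.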
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="torch.compile not supported on this platform")
@torch.inference_mode()
def test_adapter_compile():
    from lit_gpt.adapter import GPT

    model = GPT.from_name("pythia-70m", n_layer=3)
    x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)

    from torch._dynamo.backends import debugging

    explanation = torch._dynamo.explain(model, x)
    assert isinstance(explanation, debugging.ExplainOutput)
    assert explanation.graph_count == 1
    assert explanation.graph_break_count == 0

    model = GPT(model.config)
    model.set_kv_cache(2)
    input_pos = torch.arange(model.config.block_size)
    explanation = torch._dynamo.explain(model, x, input_pos)
    assert isinstance(explanation, debugging.ExplainOutput)
    assert explanation.graph_count == 1
    assert explanation.graph_break_count == 0
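

# Assumed usage: these tests are collected by pytest from the repository's tests/
# directory (path assumed), e.g. `pytest tests/test_adapter.py`.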