forked from Lightning-AI/litgpt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_convert_lit_checkpoint.py
503 lines (448 loc) · 18.3 KB
/
test_convert_lit_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from dataclasses import asdict
from unittest.mock import ANY
import pytest
import torch
import yaml
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.falcon import FalconConfig, FalconForCausalLM
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from litgpt import GPT, Config
from litgpt.scripts.convert_lit_checkpoint import (
check_conversion_supported,
convert_lit_checkpoint,
copy_weights_falcon,
copy_weights_gpt_neox,
copy_weights_llama,
copy_weights_phi,
qkv_split,
)
from tests.conftest import RunIf
def test_convert_lit_checkpoint(tmp_path):
    """Convert a tiny Llama-2 LitGPT checkpoint directory and check the produced files.

    The conversion runs twice: first on a plain state dict, then on a state dict
    wrapped as ``{"model": state_dict}``. The wrapped form must come out unwrapped.
    """
    lit_config = Config.from_name("Llama-2-7b-hf", block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128)
    lit_model = GPT(lit_config)
    lit_checkpoint = tmp_path / "lit_model.pth"
    lit_config_file = tmp_path / "model_config.yaml"

    torch.save(lit_model.state_dict(), lit_checkpoint)
    with open(lit_config_file, "w", encoding="utf-8") as config_fp:
        yaml.dump(asdict(lit_config), config_fp)

    out_dir = tmp_path / "out_dir"
    convert_lit_checkpoint(lit_checkpoint.parent, out_dir)
    # only the output directory should appear next to the inputs
    assert set(os.listdir(tmp_path)) == {"lit_model.pth", "model_config.yaml", "out_dir"}
    assert os.path.isfile(out_dir / "model.pth")

    # a checkpoint wrapped under a "model" key must be unwrapped on conversion
    torch.save({"model": lit_model.state_dict()}, lit_checkpoint)
    convert_lit_checkpoint(lit_checkpoint.parent, out_dir)
    assert "model" not in torch.load(out_dir / "model.pth")
@torch.inference_mode()
def test_against_falcon_40b():
    """A tiny falcon-40b LitGPT model must reproduce HF ``FalconForCausalLM`` logits after conversion."""
    lit_config = Config.from_name("falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)
    hf_config = FalconConfig(
        vocab_size=lit_config.padded_vocab_size,
        hidden_size=lit_config.n_embd,
        num_hidden_layers=lit_config.n_layer,
        num_attention_heads=lit_config.n_head,
        num_kv_heads=lit_config.n_query_groups,
        new_decoder_architecture=True,
        parallel_attn=lit_config.parallel_residual,
        bias=lit_config.bias,
    )

    lit_model = GPT(lit_config)
    converted = {}
    copy_weights_falcon("40b", converted, lit_model.state_dict())

    hf_model = FalconForCausalLM(hf_config)
    # assign must be set to True for torch.testing.assert_close to pass
    hf_model.load_state_dict(converted, assign=True)

    # end-to-end comparison on a single 5-token sequence
    tokens = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    lit_logits = lit_model(tokens)
    hf_logits = hf_model(tokens)["logits"]
    torch.testing.assert_close(lit_logits, hf_logits)
@torch.inference_mode()
def test_against_original_gpt_neox():
ours_config = Config(block_size=64, vocab_size=100, n_layer=4, n_head=8, n_embd=16)
assert ours_config.padded_vocab_size == 512
theirs_config = GPTNeoXConfig(
hidden_act="gelu",
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
initializer_range=0.02,
intermediate_size=ours_config.intermediate_size,
layer_norm_eps=1e-05,
max_position_embeddings=ours_config.block_size,
rotary_emb_base=10000,
rotary_pct=ours_config.rotary_percentage,
vocab_size=ours_config.padded_vocab_size,
use_parallel_residual=ours_config.parallel_residual,
)
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_gpt_neox(theirs_state_dict, ours_state_dict)
theirs_model = GPTNeoXForCausalLM(theirs_config)
# strict=False because we don't save the rotary embeddings inv frequency
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
assert all("inv_freq" in k for k in keys.missing_keys)
# test end to end
x = torch.randint(0, ours_config.padded_vocab_size, size=(2, ours_config.block_size), dtype=torch.int64)
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
    "ours_kwargs", [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf"}]
)
def test_against_hf_llama2(ours_kwargs):
    """Converted LitGPT Llama-2 weights must reproduce HF ``LlamaForCausalLM`` logits.

    Args:
        ours_kwargs: keyword arguments (here only ``name``) forwarded to
            ``Config.from_name`` to select the LitGPT model variant.
    """
    ours_config = Config.from_name(
        padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
    )
    T = 5
    theirs_config = LlamaConfig(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=T,
        rms_norm_eps=ours_config.norm_eps,
        # HF names this `num_key_value_heads`. A misspelled kwarg is silently kept as an extra
        # config attribute, and HF then defaults to one KV head per attention head.
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size
    # make sure the grouped-query setting actually reached the HF config
    assert theirs_config.num_key_value_heads == ours_config.n_query_groups

    ours_model = GPT(ours_config)
    ours_state_dict = ours_model.state_dict()
    theirs_state_dict = {}
    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)
    theirs_model = LlamaForCausalLM(theirs_config)
    theirs_model.load_state_dict(theirs_state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"]
    torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
def test_against_mixtral():
    """A tiny Mixtral (mixture-of-experts) LitGPT model must reproduce HF ``MixtralForCausalLM`` logits."""
    lit_config = Config.from_name(
        "Mixtral-8x7B-Instruct-v0.1",
        padded_vocab_size=10000,
        n_layer=2,
        n_embd=32,
        n_head=8,
        n_query_groups=2,
        intermediate_size=86,
        n_expert=4,
    )
    seq_len = 5
    hf_config = MixtralConfig(
        vocab_size=lit_config.padded_vocab_size,
        hidden_size=lit_config.n_embd,
        num_attention_heads=lit_config.n_head,
        num_hidden_layers=lit_config.n_layer,
        intermediate_size=lit_config.intermediate_size,
        max_position_embeddings=seq_len,
        rms_norm_eps=lit_config.norm_eps,
        num_key_value_heads=lit_config.n_query_groups,
        rope_theta=lit_config.rope_base,
        num_local_experts=lit_config.n_expert,
    )
    assert lit_config.intermediate_size == hf_config.intermediate_size

    lit_model = GPT(lit_config)
    converted = {}
    # Mixtral uses the Llama weight-copy routine
    copy_weights_llama(lit_config, converted, lit_model.state_dict())

    hf_model = MixtralForCausalLM(hf_config)
    hf_model.load_state_dict(converted)

    # end-to-end comparison on a batch of two 5-token sequences
    tokens = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32)
    lit_logits = lit_model(tokens)
    hf_logits = hf_model(tokens)["logits"]
    torch.testing.assert_close(lit_logits, hf_logits)
@torch.inference_mode()
def test_against_original_open_llama_3b():
    """A tiny open_llama_3b LitGPT model must reproduce HF ``LlamaForCausalLM`` logits after conversion."""
    lit_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
    seq_len = 5
    # vocab_size and rms_norm_eps are left at the LlamaConfig defaults here
    hf_config = LlamaConfig(
        hidden_size=lit_config.n_embd,
        num_attention_heads=lit_config.n_head,
        num_hidden_layers=lit_config.n_layer,
        intermediate_size=lit_config.intermediate_size,
        max_position_embeddings=seq_len,
    )
    assert lit_config.intermediate_size == hf_config.intermediate_size

    lit_model = GPT(lit_config)
    converted = {}
    copy_weights_llama(lit_config, converted, lit_model.state_dict())

    hf_model = LlamaForCausalLM(hf_config)
    hf_model.load_state_dict(converted)

    # end-to-end comparison on a single sequence of exactly seq_len tokens
    tokens = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    assert tokens.size(1) == seq_len
    lit_logits = lit_model(tokens)
    hf_logits = hf_model(tokens)["logits"]
    torch.testing.assert_close(lit_logits, hf_logits)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2"))
def test_against_hf_phi(model_name):
    """Converted LitGPT Phi-1.5 / Phi-2 weights must reproduce HF ``PhiForCausalLM`` logits."""
    from transformers.models.phi.configuration_phi import PhiConfig
    from transformers.models.phi.modeling_phi import PhiForCausalLM

    lit_config = Config.from_name(
        model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
    )
    seq_len = 5
    hf_config = PhiConfig(
        vocab_size=lit_config.padded_vocab_size,
        max_position_embeddings=lit_config.block_size,
        hidden_size=lit_config.n_embd,
        intermediate_size=lit_config.intermediate_size,
        num_attention_heads=lit_config.n_head,
        num_hidden_layers=lit_config.n_layer,
        partial_rotary_factor=lit_config.rotary_percentage,
    )

    lit_model = GPT(lit_config)
    converted = {}
    copy_weights_phi(lit_config, converted, lit_model.state_dict())

    hf_model = PhiForCausalLM(hf_config)
    # strict=False because the converted weights omit the rotary-embedding inv_freq buffers
    load_result = hf_model.load_state_dict(converted, strict=False)
    assert not load_result.unexpected_keys
    assert all("inv_freq" in key for key in load_result.missing_keys)

    # end-to-end comparison on a single sequence of exactly seq_len tokens
    tokens = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    assert tokens.size(1) == seq_len
    lit_logits = lit_model(tokens)
    hf_logits = hf_model(tokens)["logits"]
    torch.testing.assert_close(lit_logits, hf_logits)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Phi-3-mini-4k-instruct",))
def test_against_hf_phi_3(model_name):
    """Converted LitGPT Phi-3 weights must reproduce HF ``Phi3ForCausalLM`` logits."""
    from transformers.models.phi3.configuration_phi3 import Phi3Config
    from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM

    lit_config = Config.from_name(model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256)
    seq_len = 5
    hf_config = Phi3Config(
        attention_bias=lit_config.bias,
        head_dim=lit_config.head_size,
        hidden_size=lit_config.n_embd,
        intermediate_size=lit_config.intermediate_size,
        max_position_embeddings=seq_len,
        num_attention_heads=lit_config.n_head,
        num_hidden_layers=lit_config.n_layer,
        num_key_value_heads=lit_config.n_query_groups,
        pad_token_id=lit_config.padded_vocab_size - 1,
        partial_rotary_factor=lit_config.rotary_percentage,
        rms_norm_eps=lit_config.norm_eps,
        rope_theta=lit_config.rope_base,
        vocab_size=lit_config.padded_vocab_size,
    )

    lit_model = GPT(lit_config)
    converted = {}
    copy_weights_phi(lit_config, converted, lit_model.state_dict())

    hf_model = Phi3ForCausalLM(hf_config)
    # strict=False because the converted weights omit the rotary-embedding inv_freq buffers
    load_result = hf_model.load_state_dict(converted, strict=False)
    assert not load_result.unexpected_keys
    assert all("inv_freq" in key for key in load_result.missing_keys)

    # end-to-end comparison on a single sequence of exactly seq_len tokens
    tokens = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    assert tokens.size(1) == seq_len
    lit_logits = lit_model(tokens)
    hf_logits = hf_model(tokens)["logits"]
    torch.testing.assert_close(lit_logits, hf_logits)
@torch.inference_mode()
def test_against_original_stablelm_zephyr_3b():
    """Converted LitGPT stablelm-zephyr-3b weights must reproduce the HF remote-code model's logits."""
    seq_len = 5
    lit_config = Config.from_name("stablelm-zephyr-3b", n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
    # the HF config is fetched from the Hub and shrunk to match the tiny LitGPT config
    hf_config = AutoConfig.from_pretrained(
        "stabilityai/stablelm-zephyr-3b",
        trust_remote_code=True,
        num_hidden_layers=lit_config.n_layer,
        num_attention_heads=lit_config.n_head,
        num_key_value_heads=lit_config.n_head,
        hidden_size=lit_config.n_embd,
        intermediate_size=lit_config.intermediate_size,
        max_position_embeddings=seq_len,
    )
    assert lit_config.intermediate_size == hf_config.intermediate_size

    lit_model = GPT(lit_config)
    converted = {}
    copy_weights_llama(lit_config, converted, lit_model.state_dict())

    hf_model = AutoModelForCausalLM.from_config(hf_config, trust_remote_code=True)
    hf_model.load_state_dict(converted)

    # end-to-end comparison on a single sequence of exactly seq_len tokens
    tokens = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    assert tokens.size(1) == seq_len
    lit_logits = lit_model(tokens)
    hf_logits = hf_model(tokens)["logits"]
    torch.testing.assert_close(lit_logits, hf_logits)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
        pytest.param(
            torch.device("cuda"),
            torch.float16,
            marks=[
                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
                # is slightly different
                pytest.mark.xfail(raises=AssertionError, strict=False),
                RunIf(min_cuda_gpus=1),
            ],
        ),
    ],
)
def test_against_original_gemma(model_name, device, dtype):
    """Converted LitGPT Gemma weights (tied embeddings) must reproduce HF ``GemmaForCausalLM`` logits.

    Args:
        model_name: LitGPT config name of the Gemma variant.
        device: device on which both models run.
        dtype: default floating-point dtype used while building and running the models.
    """
    # `torch.set_default_dtype` is process-global: remember the previous value and restore
    # it afterwards so a float16 run cannot leak into tests that execute later.
    previous_default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        T = 5
        ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
        theirs_config = GemmaConfig(
            vocab_size=ours_config.padded_vocab_size,
            hidden_size=ours_config.n_embd,
            head_dim=ours_config.head_size,
            num_attention_heads=ours_config.n_head,
            num_hidden_layers=ours_config.n_layer,
            intermediate_size=ours_config.intermediate_size,
            max_position_embeddings=T,
            rms_norm_eps=ours_config.norm_eps,
            num_key_value_heads=ours_config.n_query_groups,
            rope_theta=ours_config.rope_base,
            attention_bias=ours_config.bias,
            tie_word_embeddings=True,
            hidden_act="gelu_pytorch_tanh",
        )
        assert ours_config.intermediate_size == theirs_config.intermediate_size

        ours_model = GPT(ours_config).to(device)
        # tie weights: the LM head shares the token-embedding matrix, as on the HF side
        ours_model.lm_head.weight = ours_model.transformer.wte.weight
        ours_state_dict = ours_model.state_dict()
        theirs_state_dict = {}
        copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)
        theirs_model = GemmaForCausalLM(theirs_config).to(device)
        theirs_model.load_state_dict(theirs_state_dict, strict=False)

        # test end to end
        x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
        assert x.size(1) == T
        ours_y = ours_model(x)
        theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
        torch.testing.assert_close(ours_y, theirs_y)
    finally:
        torch.set_default_dtype(previous_default_dtype)
def test_check_conversion_supported_adapter():
    """Weights carrying adapter parameters (gating_factor / adapter_bias keys) must be rejected."""
    for adapter_key in ("error.key.gating_factor", "error.key.adapter_bias"):
        with pytest.raises(NotImplementedError, match="Converting adapter"):
            check_conversion_supported(lit_weights={"some.key.name": ANY, adapter_key: ANY})
def test_check_conversion_supported_lora():
    """Weights carrying LoRA parameters must be rejected with a ValueError."""
    weights_with_lora = {"some.key.name": ANY, "error.key.lora": ANY}
    with pytest.raises(ValueError, match=r"LoRA.*cannot be converted"):
        check_conversion_supported(lit_weights=weights_with_lora)
def test_qkv_split():
    """``qkv_split`` must de-interleave a fused QKV weight into contiguous query, key and value rows.

    Each fixture row is numbered so that the correctly split layout (all queries, then
    all keys, then all values) is simply the rows in ascending order. The expected
    result is therefore ``arange(rows * 4)`` reshaped to width 4.
    """
    cases = [
        # MHA: one key and one value per query head
        (
            {"n_embd": 4, "n_head": 4},
            [
                [0, 1, 2, 3],  # query
                [16, 17, 18, 19],  # key
                [32, 33, 34, 35],  # value
                [4, 5, 6, 7],  # query
                [20, 21, 22, 23],  # key
                [36, 37, 38, 39],  # value
                [8, 9, 10, 11],  # query
                [24, 25, 26, 27],  # key
                [40, 41, 42, 43],  # value
                [12, 13, 14, 15],  # query
                [28, 29, 30, 31],  # key
                [44, 45, 46, 47],  # value
            ],
        ),
        # GQA: two query heads share each key/value pair
        (
            {"n_embd": 4, "n_head": 4, "n_query_groups": 2},
            [
                [0, 1, 2, 3],  # query
                [4, 5, 6, 7],  # query
                [16, 17, 18, 19],  # key
                [24, 25, 26, 27],  # value
                [8, 9, 10, 11],  # query
                [12, 13, 14, 15],  # query
                [20, 21, 22, 23],  # key
                [28, 29, 30, 31],  # value
            ],
        ),
        # MQA: all query heads share a single key/value pair
        (
            {"n_embd": 4, "n_head": 4, "n_query_groups": 1},
            [
                [0, 1, 2, 3],  # query
                [4, 5, 6, 7],  # query
                [8, 9, 10, 11],  # query
                [12, 13, 14, 15],  # query
                [16, 17, 18, 19],  # key
                [20, 21, 22, 23],  # value
            ],
        ),
    ]
    for config_kwargs, interleaved_rows in cases:
        config = Config(**config_kwargs)
        split = torch.cat(qkv_split(torch.tensor(interleaved_rows), config))
        expected = torch.arange(len(interleaved_rows) * 4).reshape(-1, 4)
        torch.testing.assert_close(split, expected)