Commit e40647f
Remove Additional Dict Causing TypeError for Llama Models (Lightning-AI#350)

Co-authored-by: Adrian Wälchli <[email protected]>
jxtngx and awaelchli authored Aug 4, 2023
1 parent cf5542a commit e40647f
Showing 2 changed files with 40 additions and 29 deletions.
3 changes: 1 addition & 2 deletions scripts/convert_lit_checkpoint.py
@@ -210,8 +210,7 @@ def convert_lit_checkpoint(
     if "falcon" in model_name:
         copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b")
     elif config._mlp_class == "LLaMAMLP":
-        qkv_weights = {}
-        copy_fn = partial(copy_weights_llama, config, qkv_weights)
+        copy_fn = partial(copy_weights_llama, config)
     else:
         copy_fn = copy_weights_gpt_neox

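For context on the fix above: functools.partial happily binds an extra positional argument, and the TypeError only surfaces later, when the partial is called with the remaining arguments. Below is a minimal, self-contained sketch of that failure mode; copy_weights is a hypothetical stand-in, not the actual copy_weights_llama signature.

# Sketch of the TypeError this commit removes. `copy_weights` is a
# hypothetical stand-in for copy_weights_llama, not the real signature.
from functools import partial

def copy_weights(config, state_dict, lit_weights):
    # Pretend conversion: copy every tensor across unchanged.
    for name, param in lit_weights.items():
        state_dict[name] = param

config = {"name": "Llama-2-7b-hf"}
converted = {}

# Binding a stray positional argument succeeds silently here...
qkv_weights = {}
broken_fn = partial(copy_weights, config, qkv_weights)

# ...and only fails when the partial is finally invoked:
# TypeError: copy_weights() takes 3 positional arguments but 4 were given
try:
    broken_fn(converted, {"transformer.wte.weight": "tensor"})
except TypeError as err:
    print(err)

# The fix: bind only the arguments the callee actually accepts.
fixed_fn = partial(copy_weights, config)
fixed_fn(converted, {"transformer.wte.weight": "tensor"})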
66 changes: 39 additions & 27 deletions tests/test_convert_lit_checkpoint.py
@@ -14,24 +14,46 @@ def test_convert_lit_checkpoint(tmp_path):
 
     ckpt_name = "lit_model.pth"
 
-    with pytest.raises(
-        RuntimeError, match="open file failed because of errno 2 on fopen"
-    ):
-        convert_lit_checkpoint(
-            checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name="falcon-7b"
-        )
+    with pytest.raises(RuntimeError, match="open file failed because of errno 2 on fopen"):
+        convert_lit_checkpoint(checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name="falcon-7b")
 
     ckpt_path = tmp_path / "lit_model.pth"
     ckpt_path.touch()
     with mock.patch("scripts.convert_lit_checkpoint.lazy_load") as load:
-        convert_lit_checkpoint(
-            checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name="falcon-7b"
-        )
+        convert_lit_checkpoint(checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name="falcon-7b")
         load.assert_called_with(ckpt_path)
 
     assert {p.name for p in tmp_path.glob("*")} == {"lit_model.pth", "lit_model.bin"}
 
 
+def test_convert_lit_checkpoint_llama2(tmp_path):
+    from lit_gpt import Config, GPT
+    from scripts.convert_lit_checkpoint import convert_lit_checkpoint
+    from finetune.full import save_checkpoint
+
+    # fabric is needed for finetune.full::save_checkpoint
+    fabric = L.Fabric(devices=1)
+
+    ckpt_path: Path = tmp_path / "lit_model_finetune.pth"
+    ckpt_name = ckpt_path.name
+
+    model_name = "Llama-2-7b-hf"
+    ours_config = Config.from_name(
+        model_name,
+        block_size=8,
+        n_layer=2,
+        n_embd=32,
+        n_head=2,
+        padding_multiple=128,
+    )
+    ours_model = GPT(ours_config)
+
+    # save a checkpoint to avoid a RuntimeError from PytorchStreamReader
+    save_checkpoint(fabric, ours_model, ckpt_path)
+    # this should not cause a TypeError
+    convert_lit_checkpoint(checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name=model_name)
+
+
 @torch.inference_mode()
 def test_against_original_falcon_40b():
     file_path = wd / "tests" / "original_falcon_40b.py"
@@ -43,9 +65,7 @@ def test_against_original_falcon_40b():
     from lit_gpt import Config, GPT
     from scripts.convert_lit_checkpoint import copy_weights_falcon as copy_to_theirs
 
-    ours_config = Config.from_name(
-        "falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32
-    )
+    ours_config = Config.from_name("falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)
     theirs_config = RWConfig(
         hidden_size=32,
         n_head=8,
@@ -127,9 +147,7 @@ def test_against_original_llama2(size):
         ours_kwargs = {"name": "Llama-2-70b-chat-hf", "n_query_groups": 2}
         theirs_kwargs = {"num_key_value_heads": 2}
 
-    ours_config = Config.from_name(
-        n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
-    )
+    ours_config = Config.from_name(n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs)
     T = 5
     theirs_config = LlamaConfig(
         hidden_size=ours_config.n_embd,
@@ -172,10 +190,10 @@ def test_maybe_unwrap_state_dict(tmp_path):
     model_name = "pythia-70m"
     ours_config = Config.from_name(
         model_name,
-        block_size=2048,
+        block_size=8,
         n_layer=2,
-        n_embd=2048,
-        n_head=8,
+        n_embd=32,
+        n_head=2,
         padding_multiple=128,
     )
     ours_model = GPT(ours_config)
@@ -188,19 +206,13 @@
 
     # convert and check that model key does not exist
     # and that a known key for pythia exists
-    convert_lit_checkpoint(
-        checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name=model_name
-    )
+    convert_lit_checkpoint(checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name=model_name)
     bin_file = ckpt_path.with_suffix(".bin")
     ckpt_from_unwrapped = torch.load(bin_file)
     assert ckpt_from_unwrapped.get("model") is None
     assert ckpt_from_unwrapped.get("embed_out.weight") is not None
 
     # assert maybe_unwrap_state_dict is called
-    with mock.patch(
-        "scripts.convert_lit_checkpoint.maybe_unwrap_state_dict"
-    ) as maybe_unwrap:
-        convert_lit_checkpoint(
-            checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name=model_name
-        )
+    with mock.patch("scripts.convert_lit_checkpoint.maybe_unwrap_state_dict") as maybe_unwrap:
+        convert_lit_checkpoint(checkpoint_name=ckpt_name, checkpoint_dir=tmp_path, model_name=model_name)
         maybe_unwrap.assert_called()
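The assertions in test_maybe_unwrap_state_dict imply that finetuning checkpoints nest their weights under a "model" key and that the converter flattens this wrapper before writing the .bin file. A minimal sketch of that unwrapping under this assumed nesting convention; it is not the actual lit-gpt implementation:

# Sketch of the unwrapping behavior the test above exercises; assumes
# checkpoints may nest weights under a "model" key. Not the real lit-gpt code.
from typing import Any, Dict

def maybe_unwrap_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Return the inner dict when the checkpoint wraps it under "model".
    inner = state_dict.get("model")
    return inner if isinstance(inner, dict) else state_dict

wrapped = {"model": {"embed_out.weight": "tensor"}}  # e.g. saved by a finetune script
plain = {"embed_out.weight": "tensor"}               # e.g. a pretrained export

assert maybe_unwrap_state_dict(wrapped) == plain
assert maybe_unwrap_state_dict(plain) == plain
assert maybe_unwrap_state_dict(wrapped).get("model") is None  # mirrors the test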
