Fix conversion for sharded models (Lightning-AI#97)
lantiga authored Apr 5, 2023
1 parent 75d20f5 commit ca7efd8
Showing 1 changed file with 52 additions and 3 deletions.
55 changes: 52 additions & 3 deletions scripts/convert_checkpoint.py
@@ -1,3 +1,4 @@
+import gc
 import os
 import shutil
 from pathlib import Path
@@ -54,6 +55,17 @@ def convert_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     return converted


+shard_dims = {
+    "lm_head.weight": 0,
+    "wte.weight": 1,
+    "attn.c_attn.weight": 0,
+    "attn.c_proj.weight": 1,
+    "mlp.c_fc1.weight": 0,
+    "mlp.c_fc2.weight": 0,
+    "mlp.c_proj.weight": 1
+}


 def meta_weights_for_nano_model(
     *,
     output_dir: Path = Path("checkpoints/lit-llama"),
@@ -70,14 +82,51 @@ def meta_weights_for_nano_model(
     shutil.copy(tokenizer_path, output_dir.parent)

     checkpoint_files = sorted(ckpt_dir.glob("*.pth"))
+    checkpoint_files.sort()
+    n_checkpoints = len(checkpoint_files)

     # for the bigger models, there are multiple model-parallel checkpoints
     # and we combine them into one single file
-    combined = {}
-    for file in tqdm(checkpoint_files, total=len(checkpoint_files)):
+    combined = None
+    for file in tqdm(checkpoint_files, total=n_checkpoints):
         checkpoint = torch.load(file, map_location="cpu")
         converted = convert_state_dict(checkpoint)
-        combined.update(converted)
+        if combined is None:
+            combined = converted
+            continue
+        for name, param in converted.items():
+            dim = None
+            for k, d in shard_dims.items():
+                if k in name:
+                    dim = d
+                    break
+            if dim is None:
+                # Extra check: assert that tensors are the same if not sharded
+                # assert torch.allclose(combined[name], param)
+                continue
+            combined[name] = torch.cat((combined[name], param), dim=dim)

+        del checkpoint
+        del converted
+        gc.collect()

+    for name, param in combined.items():
+        if "c_attn" not in name:
+            continue

+        # Turn [Q1, K1, V1, Q2, K2, V2, ...] into [Q1, Q2, ..., K1, K2, .., V1, V2, ...]

+        src_chunk_len = param.shape[0] // n_checkpoints
+        mat_len = src_chunk_len // 3
+        dst_chunk_len = mat_len * n_checkpoints
+        attn = torch.clone(param)
+        for i in range(n_checkpoints):
+            for j in range(3):
+                param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i+1) * mat_len] = \
+                    attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j+1) * mat_len]

+        del attn
+        gc.collect()

     torch.save(combined, Path(output_dir, "state_dict.pth"))
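
Note: the merge loop above concatenates each sharded tensor along the dimension listed in shard_dims (dim 0 for weights like lm_head.weight, dim 1 for weights like wte.weight), while tensors not listed there are simply taken from the first checkpoint. A minimal sketch of that merge on made-up tensors follows; it is not part of the commit, and the single weight name and shapes used are illustrative assumptions.

import torch

# Toy illustration of the dim-wise shard merge (not from the commit; shapes are made up).
shard_dims = {"lm_head.weight": 0}                   # each checkpoint holds a slice along dim 0

combined = {"lm_head.weight": torch.randn(8, 16)}    # slice from checkpoint 0
converted = {"lm_head.weight": torch.randn(8, 16)}   # slice from checkpoint 1

for name, param in converted.items():
    dim = next((d for k, d in shard_dims.items() if k in name), None)
    if dim is None:
        continue  # tensors not listed in shard_dims are replicated, not sharded
    combined[name] = torch.cat((combined[name], param), dim=dim)

print(combined["lm_head.weight"].shape)  # torch.Size([16, 16]) -- the full, unsharded weight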

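The second loop handles c_attn: each model-parallel checkpoint stores its own Q, K and V slices stacked along dim 0, so the shard merge produces rows ordered [Q1, K1, V1, Q2, K2, V2, ...], and the index arithmetic regroups them into [Q1, Q2, ..., K1, K2, ..., V1, V2, ...]. The following self-contained sketch replays that arithmetic on toy tensors; it is not part of the commit and the sizes are illustrative assumptions.

import torch

# Toy check of the c_attn row regrouping (illustrative sizes, not from the commit).
n_checkpoints = 2
mat_len = 4        # rows of each Q/K/V block per shard
hidden = 3

# Each shard i contributes [Q_i, K_i, V_i]; fill value 10*i + j identifies shard i, matrix j.
shards = [
    torch.cat([torch.full((mat_len, hidden), 10.0 * i + j) for j in range(3)], dim=0)
    for i in range(n_checkpoints)
]
param = torch.cat(shards, dim=0)                   # rows ordered [Q1, K1, V1, Q2, K2, V2]

src_chunk_len = param.shape[0] // n_checkpoints    # 3 * mat_len rows contributed per shard
dst_chunk_len = mat_len * n_checkpoints            # rows per Q/K/V group after regrouping
attn = torch.clone(param)
for i in range(n_checkpoints):
    for j in range(3):
        param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i + 1) * mat_len] = \
            attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j + 1) * mat_len]

print(param[::mat_len, 0])  # tensor([ 0., 10.,  1., 11.,  2., 12.]) -> [Q1, Q2, K1, K2, V1, V2]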
