Script to convert Meta checkpoints to ours (Lightning-AI#12)
Co-authored-by: Carlos Mocholí <[email protected]>
awaelchli and carmocca authored Mar 27, 2023
1 parent 28ece37 commit 68a2480
Showing 4 changed files with 70 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
__pycache__
.idea
.DS_Store
*.egg-info

# data
data
29 changes: 9 additions & 20 deletions generate.py
@@ -20,21 +20,9 @@ def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities
"""
# create an empty tensor of the expected final shape and fill in the current tokens
B, T = idx.shape
T_new = T + max_new_tokens
empty = torch.empty(B, T_new, dtype=idx.dtype, device=idx.device)
empty[:, :T] = idx
idx = empty

# generate max_new_tokens tokens
for t in range(T, T_new):
# ignore the not-filled-yet tokens
idx_cond = idx[:, :t]
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at max_seq_length
idx_cond = idx_cond if T <= max_seq_length else idx_cond[:, -max_seq_length:]

# forward
idx_cond = idx if idx.size(1) <= max_seq_length else idx[:, -max_seq_length:]
logits = model(idx_cond)
logits = logits[:, -1, :] / temperature
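The visible part of this hunk stops at the temperature scaling; a minimal sketch of how the rewritten loop might continue, assuming the usual top-k filtering and multinomial sampling steps that are not shown here:

```python
import torch

@torch.no_grad()
def sample_tokens(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=None):
    # hypothetical standalone version of the rewritten loop, for illustration only
    for _ in range(max_new_tokens):
        # crop the growing context so it never exceeds max_seq_length
        idx_cond = idx if idx.size(1) <= max_seq_length else idx[:, -max_seq_length:]
        logits = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # keep only the k largest logits and mask the rest (assumed step)
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        # append the sampled token and feed the extended sequence back in
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
```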

@@ -80,8 +68,8 @@ def main(
compile: bool = False,
accelerator: str = "auto",
precision: str = "32-true",
checkpoint_path: str = "/srv/data/checkpoints/llama/converted_meta/7B/state_dict.pt",
tokenizer_path: str = "/srv/data/checkpoints/llama/converted_meta/tokenizer.model",
checkpoint_path: str = "/srv/data/checkpoints/llama/converted_nano/7B/state_dict.pth",
tokenizer_path: str = "/srv/data/checkpoints/llama/converted_nano/tokenizer.model",
original_model: bool = False,
):
"""
@@ -106,15 +94,14 @@ def main(
assert os.path.isfile(checkpoint_path)
assert os.path.isfile(tokenizer_path)

L.seed_everything(1234)
fabric = L.Fabric(accelerator=accelerator, precision=precision, devices=1)

# initialize the model directly on the device
with fabric.device:
model, max_seq_length = get_model(original_model)
# TODO: checkpoint loading is currently broken
# checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(checkpoint)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint, strict=(not original_model))

model.eval()
if compile:
model = torch.compile(model)
@@ -123,6 +110,8 @@ def main(
tokenizer = Tokenizer(tokenizer_path)
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False).to(fabric.device)
encoded_prompt = encoded_prompt[None, :]

L.seed_everything(1234)
for _ in range(num_samples):
y = generate(
model, encoded_prompt, max_new_tokens, max_seq_length, temperature=temperature, top_k=top_k
68 changes: 49 additions & 19 deletions scripts/convert_checkpoint.py
@@ -1,46 +1,76 @@
from contextlib import contextmanager
from pathlib import Path

import torch
from tqdm import tqdm
import os
import shutil

"""
Sample usage:
```bash
python -m models.llama.convert_checkpoint -h
python -m scripts.convert_checkpoint -h
python -m models.llama.convert_checkpoint meta_weights_for_meta_model converted
python -m scripts.convert_checkpoint converted
```
"""

def convert_state_dict(state_dict):
converted = {}
converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"]
converted["lm_head.weight"] = state_dict["output.weight"]
converted["transformer.ln_f.scale"] = state_dict["norm.weight"]

for key in [k for k in state_dict if k.startswith("layers")]:
layer_idx = key.split(".")[1]

@contextmanager
def on_dtype(dtype):
original = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(original)
# attention
# the wq, wk, wv from the FB model are stacked in our model as c_attn
converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat((
state_dict[f"layers.{layer_idx}.attention.wq.weight"],
state_dict[f"layers.{layer_idx}.attention.wk.weight"],
state_dict[f"layers.{layer_idx}.attention.wv.weight"],
))
converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[f"layers.{layer_idx}.attention.wo.weight"]
# mlp
converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[f"layers.{layer_idx}.feed_forward.w1.weight"]
converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[f"layers.{layer_idx}.feed_forward.w2.weight"]
converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[f"layers.{layer_idx}.feed_forward.w3.weight"]
# rms norm
converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"]
converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"]
return converted
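As a sanity check, the stacking of wq, wk, and wv into a single c_attn tensor can be reversed with a plain split along the first dimension; a minimal sketch using a made-up embedding size:

```python
import torch

n_embd = 8  # toy size, for illustration only
wq, wk, wv = (torch.randn(n_embd, n_embd) for _ in range(3))

# stacked exactly as in convert_state_dict: rows of wq, then wk, then wv
c_attn = torch.cat((wq, wk, wv))

# splitting back into three equal chunks recovers the original Meta weights
q, k, v = c_attn.split(n_embd)
assert torch.equal(q, wq) and torch.equal(k, wk) and torch.equal(v, wv)
```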


def meta_weights_for_meta_model(
def meta_weights_for_nano_model(
*,
output_dir: Path,
ckpt_dir: Path = Path("/srv/data/checkpoints/llama/raw"),
tokenizer_path: Path = Path("/srv/data/checkpoints/llama/raw/tokenizer.model"),
model_size: str = "7B",
):
...


def meta_weights_for_nano_model():
...
output_dir = output_dir / model_size
ckpt_dir = ckpt_dir / model_size
os.makedirs(output_dir, exist_ok=True)

# the tokenizer is the same for all model sizes, so we store it in the parent dir
if "tokenizer.model" not in os.listdir(output_dir.parent):
shutil.copy(tokenizer_path, output_dir.parent)

def lightning_weights_for_nano_model():
...
checkpoint_files = sorted(ckpt_dir.glob("*.pth"))


# for the bigger models, there are multiple model-parallel checkpoints
# and we combine them into one single file
combined = {}
for file in tqdm(checkpoint_files, total=len(checkpoint_files)):
checkpoint = torch.load(file, map_location="cpu")
converted = convert_state_dict(checkpoint)
combined.update(converted)

torch.save(combined, Path(output_dir, "state_dict.pth"))


if __name__ == "__main__":
from jsonargparse import CLI

CLI([meta_weights_for_meta_model, meta_weights_for_nano_model, lightning_weights_for_nano_model])
CLI(meta_weights_for_nano_model)
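The same entry point can also be called from Python rather than through the jsonargparse CLI; a hedged sketch using the default paths shown above, assuming the scripts directory is importable as a package from the repository root:

```python
from pathlib import Path

from scripts.convert_checkpoint import meta_weights_for_nano_model

# convert the raw Meta 7B shards into a single converted/7B/state_dict.pth
meta_weights_for_nano_model(
    output_dir=Path("converted"),
    ckpt_dir=Path("/srv/data/checkpoints/llama/raw"),
    tokenizer_path=Path("/srv/data/checkpoints/llama/raw/tokenizer.model"),
    model_size="7B",
)
```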
11 changes: 11 additions & 0 deletions setup.py
@@ -0,0 +1,11 @@
from setuptools import find_packages, setup

setup(
name="lightning-llama",
version="0.0.1",
description="",
author="Lightning AI",
url="https://lightning.ai",
packages=find_packages(where="."),
python_requires=">=3.8",
)
