Enable specifying precision during conversion (Lightning-AI#111)
lantiga authored Apr 11, 2023
1 parent 98d7115 commit bae421f
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions scripts/convert_checkpoint.py
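
The diff threads an optional precision through the conversion: meta_weights_for_nano_model gains a dtype string argument, resolves it to a torch.dtype, and convert_state_dict casts every converted tensor with .to(dtype). A minimal sketch of that resolution step, using a hypothetical resolve_dtype helper that mirrors the validation added below:

import torch

def resolve_dtype(name: str) -> torch.dtype:
    # Look the name up on the torch module and only accept actual dtype objects,
    # so e.g. "float16" and "bfloat16" pass while "float17" raises.
    dt = getattr(torch, name, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{name} is not a valid dtype.")
    return dt

print(resolve_dtype("bfloat16"))  # torch.bfloat16
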
@@ -18,11 +18,11 @@
 """


-def convert_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
     converted = {}
-    converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"]
-    converted["lm_head.weight"] = state_dict["output.weight"]
-    converted["transformer.ln_f.scale"] = state_dict["norm.weight"]
+    converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"].to(dtype)
+    converted["lm_head.weight"] = state_dict["output.weight"].to(dtype)
+    converted["transformer.ln_f.scale"] = state_dict["norm.weight"].to(dtype)

     for key in [k for k in state_dict if k.startswith("layers")]:
         layer_idx = key.split(".")[1]
@@ -31,27 +31,27 @@ def convert_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         # the wq, wk, wv from the FB model are stacked in our model as c_attn
         converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(
             (
-                state_dict[f"layers.{layer_idx}.attention.wq.weight"],
-                state_dict[f"layers.{layer_idx}.attention.wk.weight"],
-                state_dict[f"layers.{layer_idx}.attention.wv.weight"],
+                state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype),
+                state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype),
+                state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype),
             )
         )
         converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[
             f"layers.{layer_idx}.attention.wo.weight"
-        ]
+        ].to(dtype)
         # mlp
         converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[
             f"layers.{layer_idx}.feed_forward.w1.weight"
-        ]
+        ].to(dtype)
         converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[
             f"layers.{layer_idx}.feed_forward.w2.weight"
-        ]
+        ].to(dtype)
         converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[
             f"layers.{layer_idx}.feed_forward.w3.weight"
-        ]
+        ].to(dtype)
         # rms norm
-        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"]
-        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"]
+        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"].to(dtype)
+        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"].to(dtype)
     return converted


@@ -72,11 +72,18 @@ def meta_weights_for_nano_model(
     ckpt_dir: Path = Path("checkpoints/llama/"),
     tokenizer_path: Path = Path("checkpoints/llama/tokenizer.model"),
     model_size: str = "7B",
+    dtype: str = None,
 ) -> None:
     output_dir = output_dir / model_size
     ckpt_dir = ckpt_dir / model_size
     os.makedirs(output_dir, exist_ok=True)

+    if dtype is not None:
+        dt = getattr(torch, dtype, None)
+        if not isinstance(dt, torch.dtype):
+            raise ValueError(f"{dtype} is not a valid dtype.")
+        dtype = dt
+
     # the tokenizer is the same for all model sizes, so we store it in the parent dir
     if "tokenizer.model" not in os.listdir(output_dir.parent):
         shutil.copy(tokenizer_path, output_dir.parent)
@@ -90,7 +97,7 @@ def meta_weights_for_nano_model(
     combined = None
     for file in tqdm(checkpoint_files, total=n_checkpoints):
         checkpoint = torch.load(file, map_location="cpu")
-        converted = convert_state_dict(checkpoint)
+        converted = convert_state_dict(checkpoint, dtype=dtype)
         if combined is None:
             combined = converted
             continue
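With this change, the converted tensors land in whatever precision was requested. A hypothetical usage sketch (not part of the commit; the shard filename and the import path are assumptions):

import torch
from scripts.convert_checkpoint import convert_state_dict  # assumes the repo root is on PYTHONPATH

# Convert a single Meta checkpoint shard to half precision and inspect the result.
checkpoint = torch.load("checkpoints/llama/7B/consolidated.00.pth", map_location="cpu")
converted = convert_state_dict(checkpoint, dtype=torch.float16)
print(converted["lm_head.weight"].dtype)  # torch.float16

Through the script's command-line entry point (not shown in this diff), the same cast would presumably be requested by passing the new dtype argument as a string, e.g. dtype="bfloat16".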
