add a weight merging tool file
Signed-off-by: zhaohu xing <[email protected]>
920232796 committed Mar 7, 2023
1 parent 32c4fce commit 254226a
Showing 2 changed files with 70 additions and 6 deletions.
51 changes: 45 additions & 6 deletions flagai/mp_tools.py
@@ -7,15 +7,24 @@
 import copy
 
 from_1_to_n_models = {
-    "gpt": {
+    "gpt2": {
         "wte.weight": 0,
         "attn.c_attn.weight": 30,
         "attn.c_attn.bias": 30,
-        "attn.c_proj.weight": 1,
-        "mlp.c_fc.weight": 0,
+        "attn.c_proj.weight": 0,
+        "mlp.c_fc.weight": 1,
         "mlp.c_fc.bias": 0,
-        "mlp.c_proj.weight": 1,
+        "mlp.c_proj.weight": 0,
     },
+    # "gpt2": {
+    #     "wte.weight": 0,
+    #     "attn.c_attn.weight": 30,
+    #     "attn.c_attn.bias": 30,
+    #     "attn.c_proj.weight": 1,
+    #     "mlp.c_fc.weight": 0,
+    #     "mlp.c_fc.bias": 0,
+    #     "mlp.c_proj.weight": 1,
+    # },
     "opt": {
         "decoder.embed_tokens.weight": 0,
         "self_attn.k_proj.weight": 0,
@@ -30,6 +39,20 @@
"fc1.bias": 0,
"fc2.weight": 1,
},
"galactica": {
"decoder.embed_tokens.weight": 0,
"self_attn.k_proj.weight": 0,
"self_attn.k_proj.bias": 0,
"self_attn.q_proj.weight": 0,
"self_attn.q_proj.bias": 0,
"self_attn.v_proj.weight": 0,
"self_attn.v_proj.bias": 0,

"self_attn.out_proj.weight": 1,
"fc1.weight": 0,
"fc1.bias": 0,
"fc2.weight": 1,
},
"glm": {
"word_embeddings.weight": 0,
"attention.query_key_value.weight": 30,
@@ -238,7 +261,8 @@ def change_pytorch_model_mp_from_1_to_n_new(model_name_brief, checkpoint: str, t
             d = d["module"]
 
         for k, v in d.items():
-            assert len(v.shape) < 3
+            if len(v.shape) > 2:
+                continue
             flag = 0
             for keys in trans_keys:
                 if keys in k:
@@ -261,6 +285,21 @@
                         ], 0)
                         break
 
+                    elif dim == 31:
+                        v = v.permute(1, 0)
+                        part = v.shape[0] // ratio // 3
+                        v = torch.cat([
+                            v[shift * part:(shift + 1) * part, :].clone(),
+                            v[(shift + ratio) * part:(shift + 1 + ratio) * part, :].clone(),
+                            v[(shift + 2 * ratio) * part:(shift + 1 + 2 * ratio) * part, :].clone()
+                        ], 0)
+                        v = v.permute(1, 0)
+
                     elif dim == 0:
                         part = v.shape[dim] // ratio
                         d_new['module'][k] = v[shift *
@@ -412,4 +451,4 @@ def change_pytorch_model_mp_from_n_to_1(model_name_brief, checkpoint):

 if __name__ == "__main__":
     change_pytorch_model_mp_from_1_to_n(
-        '/mnt/test_10b_models/state_dict/GLM-10b-en', 2)
+        '/mnt/test_10b_models/state_dict/GLM-10b-en', 2)
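
For reference, the integer codes in from_1_to_n_models describe how each parameter is partitioned across model-parallel ranks: 0 splits along dim 0, 1 splits along dim 1, 30 splits a fused QKV tensor along dim 0, and the newly added 31 applies the same fused-QKV split along dim 1 by transposing first. The sketch below re-implements just that splitting step, assuming the fused tensor is laid out as [Q; K; V] along the fused axis; the helper name split_fused_qkv and the toy sizes are illustrative and not part of flagai.

import torch

def split_fused_qkv(v: torch.Tensor, ratio: int, shift: int, dim: int) -> torch.Tensor:
    """Illustrative version of split codes 30 (dim=0) and 31 (dim=1).

    Assumes the fused tensor is [Q; K; V] along `dim`; model-parallel rank
    `shift` keeps its 1/ratio slice of each of Q, K and V.
    """
    if dim == 1:  # code 31: fused axis is dim 1 (e.g. a GPT-2 Conv1D-style c_attn weight)
        v = v.permute(1, 0)
    part = v.shape[0] // ratio // 3
    v = torch.cat([
        v[shift * part:(shift + 1) * part],
        v[(shift + ratio) * part:(shift + 1 + ratio) * part],
        v[(shift + 2 * ratio) * part:(shift + 1 + 2 * ratio) * part],
    ], 0)
    if dim == 1:
        v = v.permute(1, 0)
    return v

if __name__ == "__main__":
    hidden, ratio = 8, 2
    c_attn = torch.randn(hidden, 3 * hidden)  # toy [in, 3*out] fused QKV weight
    shard0 = split_fused_qkv(c_attn, ratio, shift=0, dim=1)
    print(shard0.shape)  # torch.Size([8, 12]): each rank keeps 3*hidden/ratio columns
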
25 changes: 25 additions & 0 deletions flagai/tools/merge_huggingface_weight.py
@@ -0,0 +1,25 @@

import os
import torch

def merge_weight(model_dir):
    model_files = os.listdir(model_dir)
    checkpoint_merge = {}
    print(f"merging the model weight....")
    # multi weights files
    for file_to_load in model_files:
        if "pytorch_model-0" in file_to_load:
            checkpoint_to_load = torch.load(os.path.join(model_dir, file_to_load), map_location="cpu")
            for k, v in checkpoint_to_load.items():
                checkpoint_merge[k] = v
            print(f"{file_to_load} is merged successfully.")
    # save all parameters
    torch.save(
        checkpoint_merge,
        os.path.join(model_dir, "pytorch_model.bin"))
    print(f"models are merged successfully.")


if __name__ == "__main__":
    # merge_weight(model_dir="/share/projset/baaishare/baai-mrnd/xingzhaohu/galactica-6.7b-en/")
    merge_weight(model_dir="./state_dict/opt-6.7b-en")
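
A quick way to sanity-check the result (a sketch, not part of the commit; the directory path is a placeholder matching the example above) is to reload the written pytorch_model.bin and compare its keys against the shard files:

import os
import torch

model_dir = "./state_dict/opt-6.7b-en"  # placeholder path

# Reload the merged checkpoint written by merge_weight().
merged = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")

# Collect every key from the shard files, using the same naming heuristic as the tool.
shard_keys = set()
for name in os.listdir(model_dir):
    if "pytorch_model-0" in name:
        shard_keys.update(torch.load(os.path.join(model_dir, name), map_location="cpu").keys())

assert shard_keys == set(merged.keys()), "merged checkpoint is missing parameters"
print(f"{len(merged)} tensors merged")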
