# convert_megatron_ckpt.py
# (forked from modelscope/modelscope)
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
from modelscope.models import Model
from modelscope.utils.megatron_utils import convert_megatron_checkpoint
def unwrap_model(model):
    """Strip wrapper layers from ``model`` and return the innermost object.

    The attribute names ``'model'``, ``'module'`` and ``'dist_model'`` are
    processed in that order. For each name, the current object is replaced
    by the value of that attribute for as long as the attribute exists, so
    repeated nesting under the same name is fully peeled before moving on
    to the next name.

    Names are visited once each: a ``'model'`` attribute that only becomes
    reachable after unwrapping ``'module'`` or ``'dist_model'`` is left
    untouched.

    Args:
        model: The (possibly wrapped) object to unwrap.

    Returns:
        The innermost object reached; ``model`` itself when it carries none
        of the wrapper attributes.
    """
    for name in ('model', 'module', 'dist_model'):
        # Keep descending while the current object still exposes ``name``.
        # NOTE: an attribute that refers back to its own owner would make
        # this loop spin forever.
        while hasattr(model, name):
            model = getattr(model, name)
    return model
def _required_env_int(name):
    """Return the environment variable ``name`` parsed as an ``int``.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The integer value of the variable.

    Raises:
        RuntimeError: If the variable is unset or is not a valid integer.
    """
    raw = os.getenv(name)
    if raw is None:
        # Without this check ``int(None)`` fails with an opaque TypeError.
        raise RuntimeError(
            f'Environment variable {name} is not set. Launch this script '
            'with a distributed launcher (e.g. torchrun) so that RANK and '
            'WORLD_SIZE are defined.')
    try:
        return int(raw)
    except ValueError as err:
        raise RuntimeError(
            f'Environment variable {name} must be an integer, '
            f'got {raw!r}.') from err


# Command-line interface: both the source checkpoint and the output
# location are mandatory.
parser = argparse.ArgumentParser(
    description='Split or merge your megatron_based checkpoint.')
parser.add_argument(
    '--model_dir', type=str, required=True, help='Checkpoint to be converted.')
parser.add_argument(
    '--target_dir', type=str, required=True, help='Target save path.')
args = parser.parse_args()

# Validate the process environment before the model load so that a
# mis-launched run fails immediately with a readable message.
rank = _required_env_int('RANK')
world_size = _required_env_int('WORLD_SIZE')

# The tensor-model-parallel size is taken from WORLD_SIZE.
model = Model.from_pretrained(
    args.model_dir,
    rank=rank,
    megatron_cfg={'tensor_model_parallel_size': world_size})

# Convert the innermost (unwrapped) model. The source directory passed on is
# the one recorded on the loaded model (``model.model_dir``), not the raw
# command-line value.
unwrapped_model = unwrap_model(model)
convert_megatron_checkpoint(unwrapped_model, model.model_dir, args.target_dir)