forked from modelscope/modelscope
Commit 4040320 (1 parent: dd16c11)
Showing 13 changed files with 406 additions and 35 deletions.
40 changes: 40 additions & 0 deletions
examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
@@ -0,0 +1,40 @@
import os

from modelscope.metainfo import Trainers
from modelscope.msdatasets.audio.asr_dataset import ASRDataset
from modelscope.trainers import build_trainer


def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
    kwargs = dict(
        model=params.model,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()


if __name__ == '__main__':
    from funasr.utils.modelscope_param import modelscope_args

    params = modelscope_args(
        model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
    params.output_dir = './checkpoint'  # directory where the model is saved
    params.data_path = 'speech_asr_aishell1_trainsets'  # data path: a dataset hosted on modelscope or local data
    params.dataset_type = 'small'  # use 'small' for small datasets; for more than 1000 hours of data, use 'large'
    params.batch_bins = 2000  # batch size; with dataset_type='small', batch_bins counts fbank feature frames,
    # with dataset_type='large', batch_bins counts milliseconds
    params.max_epoch = 50  # maximum number of training epochs
    params.lr = 0.00005  # learning rate

    modelscope_finetune(params)
@@ -0,0 +1 @@
PYTHONPATH=. python examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
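For context, a minimal inference sketch with the standard modelscope pipeline API; the audio filename is a placeholder, and pointing model= at the local ./checkpoint directory instead of the base model id is an assumption:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build an ASR pipeline from the same base model id used for finetuning above.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
rec_result = inference_pipeline(audio_in='example.wav')  # example.wav: placeholder 16 kHz audio
print(rec_result)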
64 changes: 64 additions & 0 deletions
examples/pytorch/human_detection/finetune_human_detection.py
@@ -0,0 +1,64 @@
import os.path as osp
from argparse import ArgumentParser

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode

parser = ArgumentParser()
parser.add_argument('--dataset_name', type=str, help='The dataset name')
parser.add_argument('--namespace', type=str, help='The dataset namespace')
parser.add_argument('--model', type=str, help='The model id or model dir')
parser.add_argument(
    '--num_classes', type=int, help='The num_classes in the dataset')
parser.add_argument('--batch_size', type=int, help='The training batch size')
parser.add_argument('--max_epochs', type=int, help='The training max epochs')
parser.add_argument(
    '--base_lr_per_img',
    type=float,
    help='The base learning rate per image')

args = parser.parse_args()
print(args)

# Step 1: Prepare the dataset. You can use a dataset hosted on modelscope,
# or build your own COCO-format dataset locally.
train_dataset = MsDataset.load(
    args.dataset_name,
    namespace=args.namespace,
    split='train',
    download_mode=DownloadMode.FORCE_REDOWNLOAD)
val_dataset = MsDataset.load(
    args.dataset_name,
    namespace=args.namespace,
    split='validation',
    download_mode=DownloadMode.FORCE_REDOWNLOAD)

# Step 2: Set the training parameters
train_root_dir = train_dataset.config_kwargs['split_config']['train']
val_root_dir = val_dataset.config_kwargs['split_config']['validation']
train_img_dir = osp.join(train_root_dir, 'images')
val_img_dir = osp.join(val_root_dir, 'images')
train_anno_path = osp.join(train_root_dir, 'train.json')
val_anno_path = osp.join(val_root_dir, 'val.json')
kwargs = dict(
    model=args.model,  # use the DAMO-YOLO-S model
    gpu_ids=[  # GPUs used for training
        0,
    ],
    batch_size=args.batch_size,  # images per GPU = batch_size // len(gpu_ids)
    max_epochs=args.max_epochs,  # total number of training epochs
    num_classes=args.num_classes,  # number of classes in the custom dataset
    load_pretrain=True,  # whether to load the pretrained model; if False, train from scratch
    base_lr_per_img=args.base_lr_per_img,  # learning rate per image; lr = base_lr_per_img * batch_size
    train_image_dir=train_img_dir,  # training image directory
    val_image_dir=val_img_dir,  # validation image directory
    train_ann=train_anno_path,  # training annotation file path
    val_ann=val_anno_path,  # validation annotation file path
)

# Step 3: Launch the training job
trainer = build_trainer(name=Trainers.tinynas_damoyolo, default_args=kwargs)
trainer.train()
@@ -0,0 +1,8 @@
PYTHONPATH=. python examples/pytorch/human_detection/finetune_human_detection.py \
    --dataset_name "person_detection_for_train" \
    --namespace "modelscope" \
    --model "damo/cv_tinynas_human-detection_damoyolo" \
    --num_classes 1 \
    --batch_size 2 \
    --max_epochs 3 \
    --base_lr_per_img 0.001
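As a usage note, a minimal sketch of running the trained detector through the modelscope pipeline API; the task constant follows the public model card for damo/cv_tinynas_human-detection_damoyolo rather than this commit, and the image path is a placeholder:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Human detection is exposed as a domain-specific object detection task.
human_detection = pipeline(
    Tasks.domain_specific_object_detection,
    model='damo/cv_tinynas_human-detection_damoyolo')
result = human_detection('test_image.jpg')  # placeholder image path
print(result)  # boxes, scores, and labels for detected persons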
@@ -1,5 +1,5 @@
 PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 \
-    examples/pytorch/finetune_image_classification.py \
+    examples/pytorch/image_classification/finetune_image_classification.py \
     --num_classes 2 \
     --train_dataset_name 'tany0699/cats_and_dogs' \
     --val_dataset_name 'tany0699/cats_and_dogs'
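As an aside, torch.distributed.launch is deprecated in recent PyTorch releases; an equivalent invocation with torchrun, the launcher already used by the keyword-spotting example below, would be:

PYTHONPATH=. torchrun --standalone --nnodes=1 --nproc_per_node=2 \
    examples/pytorch/image_classification/finetune_image_classification.py \
    --num_classes 2 \
    --train_dataset_name 'tany0699/cats_and_dogs' \
    --val_dataset_name 'tany0699/cats_and_dogs'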
68 changes: 68 additions & 0 deletions
examples/pytorch/keyword_spotting/finetune_keyword_spotting.py
@@ -0,0 +1,68 @@
# coding=utf-8

import os
from argparse import ArgumentParser

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.hub import read_config


def main():
    parser = ArgumentParser()
    parser.add_argument('--model', type=str, help='The model id or model dir')
    parser.add_argument('--train_scp', type=str, help='The train scp file')
    parser.add_argument('--cv_scp', type=str, help='The cv scp file')
    parser.add_argument('--merge_trans', type=str, help='The merge trans file')
    parser.add_argument('--keywords', type=str, help='The key words')
    parser.add_argument('--work_dir', type=str, help='The work dir')
    parser.add_argument('--test_scp', type=str, help='The test scp file')
    parser.add_argument('--test_trans', type=str, help='The test trans file')
    # args.batch_size was referenced below but never declared; the default is an assumption.
    parser.add_argument(
        '--batch_size', type=int, default=256,
        help='The evaluation batch size')
    args = parser.parse_args()
    print(args)

    # s1: the working directory for checkpoints and logs
    work_dir = args.work_dir

    # s2: dump the model configuration into the working directory
    model_id = args.model
    configs = read_config(model_id)
    config_file = os.path.join(work_dir, 'config.json')
    configs.dump(config_file)

    # s3: build the trainer
    kwargs = dict(
        model=model_id,
        work_dir=work_dir,
        cfg_file=config_file,
    )
    trainer = build_trainer(
        Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)

    # s4: train with the scp lists and the merged transcription file
    train_scp = args.train_scp
    cv_scp = args.cv_scp
    trans_file = args.merge_trans
    kwargs = dict(train_data=train_scp, cv_data=cv_scp, trans_data=trans_file)
    trainer.train(**kwargs)

    # s5: evaluate on rank 0 only (the script is launched with torchrun)
    keywords = args.keywords
    test_dir = os.path.join(work_dir, 'test_dir')
    test_scp = args.test_scp
    trans_file = args.test_trans
    rank = int(os.environ['RANK'])
    if rank == 0:
        kwargs = dict(
            test_dir=test_dir,
            test_data=test_scp,
            trans_data=trans_file,
            gpu=0,
            keywords=keywords,
            batch_size=args.batch_size,
        )
        trainer.evaluate(None, None, **kwargs)


if __name__ == '__main__':
    main()
@@ -0,0 +1,9 @@
PYTHONPATH=. torchrun --standalone --nnodes=1 --nproc_per_node=2 examples/pytorch/keyword_spotting/finetune_keyword_spotting.py \
    --work_dir './test_kws_training' \
    --model 'damo/speech_charctc_kws_phone-xiaoyun' \
    --train_scp './example_kws/train_wav.scp' \
    --cv_scp './example_kws/cv_wav.scp' \
    --merge_trans './example_kws/merge_trans.txt' \
    --keywords '小云小云' \
    --test_scp './example_kws/test_wav.scp' \
    --test_trans './example_kws/test_trans.txt'
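For reference, the scp and trans files passed above follow the usual Kaldi-style two-column layout: an utterance id, then a wav path (scp) or a transcription (trans). The ids, paths, and the non-keyword transcript below are illustrative, not from the commit:

# train_wav.scp — "<utt_id> <wav_path>", one utterance per line (illustrative)
kws_utt_001 ./example_kws/wav/kws_utt_001.wav
kws_utt_002 ./example_kws/wav/kws_utt_002.wav

# merge_trans.txt — "<utt_id> <transcription>", one utterance per line (illustrative)
kws_utt_001 小云小云
kws_utt_002 打开空调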
@@ -0,0 +1,80 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import concurrent.futures
import os

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import NotExistError
from modelscope.utils.logger import get_logger

logger = get_logger()

_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)


def _api_push_to_hub(repo_name,
                     output_dir,
                     token,
                     private=True,
                     commit_message='',
                     source_repo=''):
    api = HubApi()
    api.login(token)
    api.push_model(
        repo_name,
        output_dir,
        visibility=ModelVisibility.PUBLIC
        if not private else ModelVisibility.PRIVATE,
        chinese_name=repo_name,
        commit_message=commit_message,
        original_model_id=source_repo)
    commit_message = commit_message or 'No commit message'
    logger.info(
        f'Successfully uploaded the model to {repo_name} with message: {commit_message}'
    )


def push_to_hub(repo_name,
                output_dir,
                token=None,
                private=True,
                retry=3,
                commit_message='',
                source_repo='',
                async_upload=False):
    """
    Args:
        repo_name: The repo name for the modelhub repo
        output_dir: The local output_dir for the checkpoint
        token: The user api token; the function checks the `MODELSCOPE_API_TOKEN`
            environment variable if this argument is None
        private: Whether the repo is private, default True
        retry: Number of retries if an error occurs during uploading, default 3
        commit_message: The commit message
        source_repo: The source repo (model id) this model comes from
        async_upload: Upload asynchronously; if True, the `retry` parameter has no effect
    Returns:
        A Future handle if async_upload is True, else None
    """
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    assert token is not None, 'Either pass in a token or set `MODELSCOPE_API_TOKEN` in the environment variables.'
    assert os.path.isdir(output_dir)
    assert 'configuration.json' in os.listdir(output_dir) or 'configuration.yaml' in os.listdir(output_dir) \
        or 'configuration.yml' in os.listdir(output_dir)

    logger.info(
        f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    if not async_upload:
        for i in range(retry):
            try:
                _api_push_to_hub(repo_name, output_dir, token, private,
                                 commit_message, source_repo)
            except Exception as e:
                logger.info(f'Error happens when uploading model: {e}')
                continue
            break
    else:
        return _executor.submit(_api_push_to_hub, repo_name, output_dir, token,
                                private, commit_message, source_repo)
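Finally, a minimal usage sketch of the new helper; the import path is an assumption (the diff header above omits the filename), and the repo id and directory are placeholders:

# Hypothetical import path for the helper added in this commit.
from modelscope.hub.push_to_hub import push_to_hub

# Synchronous upload with up to `retry` attempts; token=None falls back to
# the MODELSCOPE_API_TOKEN environment variable.
push_to_hub(
    repo_name='my_namespace/my_finetuned_model',  # placeholder repo id
    output_dir='./checkpoint',  # must contain configuration.json/.yaml/.yml
    private=True,
    commit_message='upload finetuned checkpoint')

# Asynchronous upload returns a concurrent.futures.Future from the process pool.
future = push_to_hub(
    'my_namespace/my_finetuned_model', './checkpoint', async_upload=True)
future.result()  # block until the upload finishes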