-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_new.py
105 lines (95 loc) · 3.66 KB
/
train_new.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from dataclasses import dataclass, field
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from transformers import HfArgumentParser, TrainingArguments, AutoModelForCausalLM, Trainer, AutoTokenizer
from trl.trainer.utils import PeftSavingCallback
from llm_train_util import prepare_model_for_training, SFTDataCollector
@dataclass
class FinetuneArguments:
    """
    Fine-tuning arguments, parsed from the command line via HfArgumentParser
    alongside the standard transformers TrainingArguments.
    """
    pretrained_model_path: str = field()  # path (or hub id) of the base causal-LM checkpoint; required
    train_dataset_path: str = field()  # JSON training file with 'src'/'tgt' fields; required
    eval_dataset_path: str = field()  # JSON evaluation file with 'src'/'tgt' fields; required
    pad_token_id: int = field(default=0)  # NOTE(review): declared but not read in train(), which hard-codes pad ids — verify intent
    lora_rank: int = field(default=16)  # LoRA rank r
    lora_alpha: float = field(default=32.0)  # LoRA scaling factor alpha
    lora_dropout: float = field(default=0.1)  # dropout applied inside LoRA layers
    lora_target: str = field(default="W_pack")  # comma-separated module names to wrap with LoRA adapters
    ft_type: str = field(default="lora")  # "lora" enables PEFT adapters; any other value trains the full model
def create_and_prepare_dataset(data_path):
    """
    Load a JSON text-correction dataset and convert every record into a
    prompt/response pair for supervised fine-tuning.

    :param data_path: path to a JSON file whose records carry a 'src'
        (input text) and a 'tgt' (corrected text) field
    :return: the mapped 'train' split
    """
    raw = load_dataset("json", data_files=data_path)

    def to_prompt_response(record):
        """Build the instruction prompt and the gold response for one record."""
        return {
            'prompt': f"任务: 纠错文本\n输入: {record['src']}\n输出: ",
            'response': record['tgt'],
        }

    # batched=False: the mapper receives one record at a time
    return raw.map(to_prompt_response, batched=False)['train']
def train():
    """
    Entry point: parse CLI arguments, load the base causal LM (optionally
    wrapping it with LoRA adapters), configure tokenizer padding per model
    family, build the train/eval datasets, and run the HuggingFace Trainer.
    """
    finetune_args, training_args = HfArgumentParser(
        (FinetuneArguments, TrainingArguments)).parse_args_into_dataclasses()
    # load model
    model = AutoModelForCausalLM.from_pretrained(finetune_args.pretrained_model_path,
                                                 trust_remote_code=True)
    if finetune_args.ft_type == 'lora':
        # prepare the frozen base model for adapter training, then attach LoRA
        model = prepare_model_for_training(model)
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=finetune_args.lora_rank,
            lora_alpha=finetune_args.lora_alpha,
            lora_dropout=finetune_args.lora_dropout,
            target_modules=[target.strip() for target in finetune_args.lora_target.split(",")]
        )
        model = get_peft_model(model, lora_config)
    # load tokenizer
    if 'Lang16' in finetune_args.pretrained_model_path:
        # Lang16 checkpoints ship a custom tokenizer module inside the
        # checkpoint directory, so it must be importable from there.
        import sys
        sys.path.append(finetune_args.pretrained_model_path)
        from tokenization_hackt5 import HackT5TokenizerFast
        tokenizer = HackT5TokenizerFast.from_pretrained(finetune_args.pretrained_model_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(finetune_args.pretrained_model_path, trust_remote_code=True, use_fast=False)
    # Per-model-family padding setup. Bug fix: honour the --pad_token_id CLI
    # flag for the families that previously hard-coded 0 (the default of 0
    # preserves the old behaviour exactly).
    if 'Baichuan' in finetune_args.pretrained_model_path:
        tokenizer.pad_token_id = finetune_args.pad_token_id
    elif 'Qwen' in finetune_args.pretrained_model_path:
        # NOTE(review): 135269 looks like a Qwen special-token id used for both
        # padding and EOS — confirm against the checkpoint's tokenizer config.
        tokenizer.pad_token_id = 135269
        tokenizer.eos_token_id = 135269
    elif 'internlm' in finetune_args.pretrained_model_path:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    elif 'Skywork' in finetune_args.pretrained_model_path:
        tokenizer.pad_token_id = finetune_args.pad_token_id
    # load dataset
    train_dataset = create_and_prepare_dataset(finetune_args.train_dataset_path)
    eval_dataset = create_and_prepare_dataset(finetune_args.eval_dataset_path)
    # start train
    # With LoRA most base parameters are frozen and never receive gradients;
    # DDP must not look for unused parameters or it raises at runtime.
    training_args.ddp_find_unused_parameters = False
    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        data_collator=SFTDataCollector(tokenizer),
        callbacks=[PeftSavingCallback()] if finetune_args.ft_type == 'lora' else None,
    )
    trainer.train()
if __name__ == '__main__':
    # Script entry point: run the full fine-tuning pipeline.
    train()