Skip to content

Commit

Permalink
unify training pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
yangjianxin1 committed Jan 28, 2024
1 parent 6554f61 commit 49d9ae7
Show file tree
Hide file tree
Showing 13 changed files with 605 additions and 1,529 deletions.
127 changes: 0 additions & 127 deletions ChatGLM3.md

This file was deleted.

60 changes: 4 additions & 56 deletions component/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,63 +10,11 @@ class CustomizedArguments:
max_seq_length: int = field(metadata={"help": "输入最大长度"})
train_file: str = field(metadata={"help": "训练集。如果task_type=pretrain,请指定文件夹,将扫描其下面的所有jsonl文件"})
model_name_or_path: str = field(metadata={"help": "预训练权重路径"})
# min_seq_length: int = field(default=1024, metadata={"help": "输小最大长度"})
# window_step_size: int = field(default=1024, metadata={"help": "滑动窗口步长"})
template_name: str = field(default="", metadata={"help": "sft时的数据格式"})
eval_file: Optional[str] = field(default="", metadata={"help": "验证集"})
task_type: str = field(default="sft", metadata={"help": "预训练任务:[sft, pretrain]"})
tokenize_num_workers: int = field(default=1, metadata={"help": ""})


@dataclass
class QLoRAArguments:
    """Command-line arguments for QLoRA fine-tuning.

    Required fields have no default; the remaining fields fall back to the
    usual QLoRA hyper-parameter defaults (rank 64, alpha 16, dropout 0.05).
    """
    # Required arguments (no defaults).
    max_seq_length: int = field(metadata=dict(help="输入最大长度"))
    train_file: str = field(metadata=dict(help="训练集"))
    model_name_or_path: str = field(metadata=dict(help="预训练权重路径"))
    # Optional arguments.
    task_type: str = field(default="", metadata=dict(help="预训练任务:[sft, pretrain]"))
    eval_file: Optional[str] = field(default="", metadata=dict(help="验证集"))
    # LoRA hyper-parameters.
    lora_rank: Optional[int] = field(default=64, metadata=dict(help="lora rank"))
    lora_alpha: Optional[int] = field(default=16, metadata=dict(help="lora alpha"))
    lora_dropout: Optional[float] = field(default=0.05, metadata=dict(help="lora dropout"))


@dataclass
class DPOArguments:
    """Command-line arguments for DPO (Direct Preference Optimization) training.

    Covers data paths, prompt-template pieces, DPO loss settings, and the
    LoRA/QLoRA hyper-parameters used when ``train_mode`` is ``qlora``.
    """
    # Required arguments (no defaults).
    max_seq_length: int = field(metadata={"help": "输入最大长度"})
    max_prompt_length: Optional[int] = field(metadata={"help": "max length of prompt"})

    train_file: str = field(metadata={"help": "训练集"})
    model_name_or_path: str = field(metadata={"help": "预训练权重路径"})
    # FIX: help text previously said "the file of training data", but this is the eval set.
    eval_file: Optional[str] = field(default="", metadata={"help": "验证集"})

    # Prompt-template pieces; a single-turn prompt is assembled as:
    # {system}{conv_begin}{human_begin}你好{human_end}{assistant_begin}
    # FIX: these were annotated `int` but default to '' and hold template
    # strings — annotate them as `str`.
    system: str = field(default='', metadata={"help": ""})
    conv_begin: str = field(default='', metadata={"help": ""})
    human_begin: str = field(default='', metadata={"help": ""})
    human_end: str = field(default='', metadata={"help": ""})
    assistant_begin: str = field(default='', metadata={"help": ""})
    assistant_end: str = field(default='', metadata={"help": ""})

    # FIX: help text was a copy-paste of task_type's ("预训练任务:[sft, pretrain]").
    use_lora: bool = field(default=False, metadata={"help": "是否使用lora训练"})
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    tokenize_num_workers: int = field(default=10, metadata={"help": "预训练时tokenize的线程数量"})
    task_type: str = field(default="sft", metadata={"help": "预训练任务:[pretrain, sft]"})
    train_mode: str = field(default="qlora", metadata={"help": "训练方式:[full, qlora]"})
    # LoRA hyper-parameters (only used when train_mode == "qlora").
    lora_rank: Optional[int] = field(default=64, metadata={"help": "lora rank"})
    lora_alpha: Optional[int] = field(default=16, metadata={"help": "lora alpha"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "lora dropout"})


@dataclass
class LOMOArguments:
"""
LOMO训练的自定义参数
"""
max_seq_length: int = field(metadata={"help": "输入最大长度"})
train_file: str = field(metadata={"help": "训练集"})
model_name_or_path: str = field(metadata={"help": "预训练权重路径"})
clip_grad_norm: float = field(metadata={"help": "Maximum gradient normalized value (for gradient clipping)."})
clip_grad_value: float = field(default=None, metadata={"help": "Maximum gradient value (for gradient clipping)."})
eval_file: Optional[str] = field(default="", metadata={"help": "the file of training data"})
Loading

0 comments on commit 49d9ae7

Please sign in to comment.