Merge pull request pengxiao-song#31 from pengxiao-song/dev
Update source code
herobrine19 authored May 22, 2023
2 parents f329433 + 5807e73 commit 87d0281
Showing 24 changed files with 240 additions and 151 deletions.
20 changes: 19 additions & 1 deletion .gitignore
@@ -1 +1,19 @@
__pycache__/
*.npy
*.npz
*.pyc
*.pyd
*.so
*.ipynb
.ipynb_checkpoints
models/base_models/*
!models/base_models/.gitkeep
models/lora_weights/*
!models/lora_weights/.gitkeep
outputs/*
!outputs/.gitkeep
data/*
!data/.gitkeep
wandb/
flagged/
.DS_Store
77 changes: 37 additions & 40 deletions README.md
@@ -24,6 +24,7 @@ LaWGPT is a series of open-source large language models built on Chinese legal knowledge.
This project is under active development; legal-domain datasets and further models in the series will be open-sourced over time. Stay tuned.

## Updates
- 🛠️ 2023/05/22: Restructured the project's main branch; see [Project Structure](https://github.com/pengxiao-song/LaWGPT#项目结构) for details

- 🪴 2023/05/15: Released [Awesome Chinese Legal Resources](https://github.com/pengxiao-song/awesome-chinese-legal-resources) and a [legal-domain vocabulary](https://github.com/pengxiao-song/LaWGPT/blob/main/resources/legal_vocab.txt)

@@ -44,13 +45,25 @@ LaWGPT is a series of open-source large language models built on Chinese legal knowledge.
1. Prepare the code and create the environment

```bash
# Download the code
git clone [email protected]:pengxiao-song/LaWGPT.git
cd LaWGPT

# Create the environment
conda create -n lawgpt python=3.10 -y
conda activate lawgpt
pip install -r requirements.txt

# Launch the web UI script (automatically downloads the pretrained model, ~15 GB)
bash ./scripts/webui.sh
```

2. Visit http://127.0.0.1:7860
<p align="center">
<img src="./assets/demo/demo.png" width="80%" >
</p>

3. Merge model weights (optional)

**If you want to use the LaWGPT-7B-alpha model, you can skip this step.**

@@ -61,44 +74,28 @@ LaWGPT is a series of open-source large language models built on Chinese legal knowledge.
This project documents a [merge procedure](https://github.com/pengxiao-song/LaWGPT/wiki/%E6%A8%A1%E5%9E%8B%E5%90%88%E5%B9%B6); please obtain the original base weights and reconstruct the model yourself (a sketch follows).
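For orientation, merging amounts to loading the base model, applying the LoRA weights, and folding them into the dense layers. A minimal sketch, assuming the standard 🤗 peft API (the directory names are placeholders; the wiki page above remains the authoritative procedure):

```python
# Hedged sketch: fold LoRA deltas into the base weights with peft.
# All paths below are placeholders -- substitute your own directories.
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base = LlamaForCausalLM.from_pretrained("./models/base_models/your_base_model_dir")
lora = PeftModel.from_pretrained(base, "./models/lora_weights/your_lora_dir")
merged = lora.merge_and_unload()  # bake the low-rank updates into the dense weights

merged.save_pretrained("./models/base_models/merged")
LlamaTokenizer.from_pretrained("./models/base_models/your_base_model_dir").save_pretrained(
    "./models/base_models/merged"
)
```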




## Project Structure

```bash
LaWGPT
├── assets               # static assets
├── resources            # project resources
├── models               # base models and LoRA weights
│   ├── base_models
│   └── lora_weights
├── outputs              # weights produced by instruction tuning
├── data                 # experimental data
├── scripts              # shell scripts
│   ├── finetune.sh      # instruction tuning
│   └── webui.sh         # web UI launcher
├── templates            # prompt templates
├── tools                # utilities
├── utils
├── train_clm.py         # continued pretraining
├── finetune.py          # instruction tuning
├── webui.py             # web UI entry point
├── README.md
└── requirements.txt
```


Expand All @@ -119,13 +116,13 @@ LawGPT 系列模型的训练过程分为两个阶段:

### Continued Pretraining Workflow

1. Construct a continued-pretraining dataset following `resources/example_instruction_train.json`
2. Run `scripts/train_clm.sh`

### Instruction Fine-tuning Steps

1. Construct an instruction fine-tuning dataset following `resources/example_instruction_tune.json` (a sketch of the record format follows below)
2. Run `scripts/finetune.sh`
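The expected record shape can be read off `resources/example_instruction_tune.json` (reproduced further below); a hedged sketch that writes a tiny dataset in that shape, with the output path taken from `scripts/finetune.sh`:

```python
import json

# One record per example, mirroring resources/example_instruction_tune.json.
# "input" stays empty when the instruction needs no extra context.
records = [
    {
        "instruction": "How long is the prison sentence for injuring someone while drunk driving?",
        "input": "",
        "output": "Article 133 of the Criminal Law provides that ...",
    },
]

with open("./data/finetune_law_data.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)
```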

### Compute Resources

@@ -222,4 +219,4 @@

## Citation

If you find our work helpful, please consider citing this project.
File renamed without changes.
79 changes: 41 additions & 38 deletions src/finetune.py → finetune.py
@@ -7,6 +7,12 @@
import transformers
from datasets import load_dataset

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
LoraConfig,
get_peft_model,
@@ -15,45 +21,41 @@
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter


def train(
# model/data params
base_model: str = "", # the only required argument
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "./lora-alpaca",
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 4,
num_epochs: int = 10,
num_epochs: int = 3,
learning_rate: float = 3e-4,
cutoff_len: int = 512,
cutoff_len: int = 256,
val_set_size: int = 2000,

# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = True,
group_by_length: bool = False, # faster, but produces an odd training loss curve

# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true

resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
):
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
Expand Down Expand Up @@ -81,11 +83,13 @@ def train(
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
# Reach the effective batch_size by accumulating gradients over
# batch_size // micro_batch_size micro-batches per optimizer step.
gradient_accumulation_steps = batch_size // micro_batch_size

prompter = Prompter(prompt_template_name)

# Configure device and distributed training
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
@@ -95,8 +99,8 @@

# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
@@ -113,21 +117,13 @@
)

tokenizer = LlamaTokenizer.from_pretrained(base_model)
bos = tokenizer.bos_token_id
eos = tokenizer.eos_token_id
pad = tokenizer.pad_token_id

print("pre-trained model's BOS EOS and PAD token id:",
bos, eos, pad, " => It should be 1,2,none")

tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference

def tokenize(prompt):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
Expand Down Expand Up @@ -212,13 +208,18 @@ def generate_and_tokenize_prompt(data_point):
else:
print(f"Checkpoint {checkpoint_name} not found")

model.print_trainable_parameters() # Be more transparent about the % of trainable params.

if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
@@ -235,18 +236,18 @@ def generate_and_tokenize_prompt(data_point):
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_ratio=0.1,  # warm up over the first 10% of total training steps
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=50 if val_set_size > 0 else None,
save_steps=50,
output_dir=output_dir,
save_total_limit=5,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
@@ -273,7 +274,9 @@ def generate_and_tokenize_prompt(data_point):

model.save_pretrained(output_dir)  # saves only the LoRA adapter weights, not the full base model

print(
"\n If there's a warning about missing keys above, please disregard :)"
)


if __name__ == "__main__":
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
12 changes: 12 additions & 0 deletions resources/example_instruction_tune.json
@@ -0,0 +1,12 @@
[
  {
    "instruction": "How long is the prison sentence for injuring someone while drunk driving?",
    "input": "",
    "output": "Article 133 of the Criminal Law provides: whoever violates traffic and transportation regulations and thereby causes a serious accident resulting in serious injury or death, or in major losses to public or private property, shall be sentenced to fixed-term imprisonment of not more than three years or criminal detention; whoever flees after causing a traffic accident, where serious injury results, shall be sentenced to fixed-term imprisonment of not less than three years and not more than seven years; where fleeing results in death, to fixed-term imprisonment of not less than seven years."
  },
  {
    "instruction": "I was caught by the police for stealing a dog; how will it be handled?",
    "input": "",
    "output": "It depends on the value of the stolen dog, and the case may be treated as theft. If the value reaches the criminal threshold, it constitutes theft and carries criminal liability; if it does not, it is handled as a public-security violation with a fine or administrative detention. If the value is under 1,000 yuan it generally does not constitute a crime; above 1,000 yuan it may."
  }
]
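At training time each record is rendered into a single text prompt by `utils/prompter.py`. A hedged sketch of that path, assuming the alpaca-lora-style `Prompter` interface that `finetune.py` imports (template files live under `templates/`):

```python
from utils.prompter import Prompter

prompter = Prompter("alpaca")  # loads templates/alpaca.json

# generate_prompt(instruction, input, label): the label is appended so the
# model learns to produce the answer; an empty input selects the no-input template.
full_prompt = prompter.generate_prompt(
    "How long is the prison sentence for injuring someone while drunk driving?",
    "",
    "Article 133 of the Criminal Law provides that ...",
)
print(full_prompt)
```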
56 changes: 56 additions & 0 deletions scripts/finetune.sh
@@ -0,0 +1,56 @@
#!/bin/bash
export WANDB_MODE=disabled # disable wandb

# Fine-tune the chinese-alpaca-plus-7b-merged model on the law_data.json dataset
experiment_name="chinese-alpaca-plus-7b-law-e1"

# Single GPU, or model parallelism
python finetune.py \
--base_model "minlik/chinese-alpaca-plus-7b-merged" \
--data_path "./data/finetune_law_data.json" \
--output_dir "./outputs/"${experiment_name} \
--batch_size 64 \
--micro_batch_size 8 \
--num_epochs 20 \
--learning_rate 3e-4 \
--cutoff_len 256 \
--val_set_size 0 \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--lora_target_modules "[q_proj,v_proj]" \
--train_on_inputs True \
--add_eos_token True \
--group_by_length False \
--wandb_project "" \
--wandb_run_name "" \
--wandb_watch "" \
--wandb_log_model "" \
--resume_from_checkpoint "./outputs/"${experiment_name} \
--prompt_template_name "alpaca"


# Multi-GPU data parallelism
# WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
# --base_model "minlik/chinese-alpaca-plus-7b-merged" \
# --data_path "./data/finetune_law_data.json" \
# --output_dir "./outputs/"${experiment_name} \
# --batch_size 64 \
# --micro_batch_size 8 \
# --num_epochs 20 \
# --learning_rate 3e-4 \
# --cutoff_len 256 \
# --val_set_size 0 \
# --lora_r 8 \
# --lora_alpha 16 \
# --lora_dropout 0.05 \
# --lora_target_modules "[q_proj,v_proj]" \
# --train_on_inputs True \
# --add_eos_token True \
# --group_by_length False \
# --wandb_project "" \
# --wandb_run_name "" \
# --wandb_watch "" \
# --wandb_log_model "" \
# --resume_from_checkpoint "./outputs/"${experiment_name} \
# --prompt_template_name "alpaca"
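As a sanity check on the hyperparameters above (values from this script; the arithmetic mirrors the `gradient_accumulation_steps` computation in `finetune.py`):

```python
# Effective optimizer-step batch for the single-GPU invocation above.
batch_size = 64        # --batch_size
micro_batch_size = 8   # --micro_batch_size
gradient_accumulation_steps = batch_size // micro_batch_size
print(gradient_accumulation_steps)  # 8 micro-batches accumulated per optimizer step
```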