diff --git a/.gitignore b/.gitignore index 41da0ad..0353852 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,19 @@ -./data +__pycache__/ +*.npy +*.npz +*.pyc +*.pyd +*.so +*.ipynb +.ipynb_checkpoints +models/base_models/* +!models/base_models/.gitkeep +models/lora_weights/* +!models/lora_weights/.gitkeep +outputs/* +!outputs/.gitkeep +data/* +!data/.gitkeep +wandb/ +flagged/ +.DS_Store diff --git a/README.md b/README.md index 1896590..e8f366a 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ LaWGPT 是一系列基于中文法律知识的开源大语言模型。 本项目持续开展,法律领域数据集及系列模型后续相继开源,敬请关注。 ## 更新 +- 🛠️ 2023/05/22:项目主分支结构调整,详见[项目结构](https://github.com/pengxiao-song/LaWGPT#项目结构) - 🪴 2023/05/15:发布 [中文法律数据源汇总(Awesome Chinese Legal Resources)](https://github.com/pengxiao-song/awesome-chinese-legal-resources) 和 [法律领域词表](https://github.com/pengxiao-song/LaWGPT/blob/main/resources/legal_vocab.txt) @@ -44,13 +45,25 @@ LaWGPT 是一系列基于中文法律知识的开源大语言模型。 1. 准备代码,创建环境 ```bash + # 下载代码 git clone git@github.com:pengxiao-song/LaWGPT.git cd LaWGPT + + # 创建环境 + conda create -n lawgpt python=3.10 -y conda activate lawgpt pip install -r requirements.txt + + # 启动可视化脚本(自动下载预训练模型约15GB) + bash ./scripts/webui.sh ``` -2. 合并模型权重(可选) +2. 访问 http://127.0.0.1:7860 : +

   *(screenshot of the web UI omitted)*
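
   The web UI in step 1 pulls the base model and LoRA weights from the Hugging Face Hub on first launch (roughly 15 GB). If you prefer to fetch them ahead of time into the `models/` directories introduced by this restructuring, a minimal sketch along the following lines should work; the repo IDs are the ones referenced in `scripts/webui.sh`, and `local_dir` assumes a reasonably recent `huggingface_hub`:

   ```python
   # Hypothetical pre-download helper; not a file in this repository.
   # Repo IDs taken from scripts/webui.sh; target folders follow the new models/ layout.
   from huggingface_hub import snapshot_download

   snapshot_download(
       repo_id="minlik/chinese-llama-7b-merged",      # merged Chinese-LLaMA base model
       local_dir="./models/base_models/chinese-llama-7b-merged",
   )
   snapshot_download(
       repo_id="entity303/lawgpt-lora-7b",            # LaWGPT LoRA weights
       local_dir="./models/lora_weights/lawgpt-lora-7b",
   )
   ```

   If you download the weights this way, point `--base_model` and `--lora_weights` in `scripts/webui.sh` at these local paths.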

+
+3. 合并模型权重(可选)

   **如果您想使用 LaWGPT-7B-alpha 模型,可跳过该步,直接进入步骤3.**

   本项目给出[合并方式](https://github.com/pengxiao-song/LaWGPT/wiki/%E6%A8%A1%E5%9E%8B%E5%90%88%E5%B9%B6),请各位获取原版权重后自行重构模型。

-3. 启动示例
-
-   启动本地服务:
-
-   ```bash
-   conda activate lawgpt
-   cd LaWGPT
-   sh src/scripts/generate.sh
-   ```
-
-   接入服务:
-
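
For the optional merge in step 3, the linked wiki describes the full procedure. As a rough illustration of what folding LoRA weights into a base model looks like with `peft` (a generic sketch, not the project's `utils/merge.py`; the directory names are placeholders):

```python
# Generic LoRA-merge sketch (assumes peft's merge_and_unload; paths are placeholders).
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base = LlamaForCausalLM.from_pretrained(
    "./models/base_models/your_base_model_dir", torch_dtype=torch.float16
)
lora = PeftModel.from_pretrained(base, "./models/lora_weights/your_lora_dir")

merged = lora.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.save_pretrained("./models/base_models/your_merged_model_dir")

tokenizer = LlamaTokenizer.from_pretrained("./models/base_models/your_base_model_dir")
tokenizer.save_pretrained("./models/base_models/your_merged_model_dir")
```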


- - ## 项目结构 -```bash +```bash LaWGPT -├── assets # 项目静态资源 -├── data # 语料及精调数据 -├── tools # 数据清洗等工具 +├── assets # 静态资源 +├── resources # 项目资源 +├── models # 基座模型及 lora 权重 +│ ├── base_models +│ └── lora_weights +├── outputs # 指令微调的输出权重 +├── data # 实验数据 +├── scripts # 脚本目录 +│ ├── finetune.sh # 指令微调脚本 +│ └── webui.sh # 启动服务脚本 +├── templates # prompt 模板 +├── tools # 工具包 +├── utils +├── train_clm.py # 二次训练 +├── finetune.py # 指令微调 +├── webui.py # 启动服务 ├── README.md -├── requirements.txt -└── src # 源码 - ├── finetune.py - ├── generate.py - ├── models # 基座模型及 Lora 权重 - │ ├── base_models - │ └── lora_weights - ├── outputs - ├── scripts # 脚本文件 - │ ├── finetune.sh # 指令微调 - │ └── generate.sh # 服务创建 - ├── templates - └── utils +└── requirements.txt ``` @@ -119,13 +116,13 @@ LawGPT 系列模型的训练过程分为两个阶段: ### 二次训练流程 -1. 参考 `src/data/example_instruction_train.json` 构造二次训练数据集 -2. 运行 `src/scripts/train_lora.sh` +1. 参考 `resources/example_instruction_train.json` 构造二次训练数据集 +2. 运行 `scripts/train_clm.sh` ### 指令精调步骤 -1. 参考 `src/data/example_instruction_tune.json` 构造指令微调数据集 -2. 运行 `src/scripts/finetune.sh` +1. 参考 `resources/example_instruction_tune.json` 构造指令微调数据集 +2. 运行 `scripts/finetune.sh` ### 计算资源 @@ -222,4 +219,4 @@ LawGPT 系列模型的训练过程分为两个阶段: ## 引用 -如果您觉得我们的工作对您有所帮助,请考虑引用该项目 \ No newline at end of file +如果您觉得我们的工作对您有所帮助,请考虑引用该项目 diff --git a/src/models/base_models/.gitkeep b/data/.gitkeep similarity index 100% rename from src/models/base_models/.gitkeep rename to data/.gitkeep diff --git a/src/finetune.py b/finetune.py similarity index 85% rename from src/finetune.py rename to finetune.py index ff7c0b3..4059fc7 100644 --- a/src/finetune.py +++ b/finetune.py @@ -7,6 +7,12 @@ import transformers from datasets import load_dataset +""" +Unused imports: +import torch.nn as nn +import bitsandbytes as bnb +""" + from peft import ( LoraConfig, get_peft_model, @@ -15,45 +21,41 @@ set_peft_model_state_dict, ) from transformers import LlamaForCausalLM, LlamaTokenizer + from utils.prompter import Prompter def train( # model/data params - base_model: str = "./models/base_models/your_base_model_dir", - data_path: str = "./data/your_data.json", - output_dir: str = "./outputs/your_version_dir", - + base_model: str = "", # the only required argument + data_path: str = "yahma/alpaca-cleaned", + output_dir: str = "./lora-alpaca", # training hyperparams batch_size: int = 128, micro_batch_size: int = 4, - num_epochs: int = 10, + num_epochs: int = 3, learning_rate: float = 3e-4, - cutoff_len: int = 512, + cutoff_len: int = 256, val_set_size: int = 2000, - # lora hyperparams lora_r: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.05, - lora_target_modules: List[str] = ["q_proj", "v_proj",], - + lora_target_modules: List[str] = [ + "q_proj", + "v_proj", + ], # llm hyperparams train_on_inputs: bool = True, # if False, masks out inputs in loss add_eos_token: bool = True, group_by_length: bool = False, # faster, but produces an odd training loss curve - # wandb params wandb_project: str = "", wandb_run_name: str = "", wandb_watch: str = "", # options: false | gradients | all wandb_log_model: str = "", # options: false | true - - # either training checkpoint or final adapter - resume_from_checkpoint: str = None, - - # The prompt template to use, will default to alpaca. - prompt_template_name: str = "alpaca", + resume_from_checkpoint: str = None, # either training checkpoint or final adapter + prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca. 
): if int(os.environ.get("LOCAL_RANK", 0)) == 0: print( @@ -81,11 +83,13 @@ def train( f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" f"prompt template: {prompt_template_name}\n" ) + assert ( + base_model + ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'" gradient_accumulation_steps = batch_size // micro_batch_size prompter = Prompter(prompt_template_name) - # Configure device and distributed training device_map = "auto" world_size = int(os.environ.get("WORLD_SIZE", 1)) ddp = world_size != 1 @@ -95,8 +99,8 @@ def train( # Check if parameter passed or if set within environ use_wandb = len(wandb_project) > 0 or ( - "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0) - + "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 + ) # Only overwrite environ if wandb param passed if len(wandb_project) > 0: os.environ["WANDB_PROJECT"] = wandb_project @@ -113,21 +117,13 @@ def train( ) tokenizer = LlamaTokenizer.from_pretrained(base_model) - tokenizer.bos_token_id = 1 - tokenizer.eos_token_id = 2 - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - - print("pre-trained model's BOS EOS and PAD token id:", - bos, eos, pad, " => It should be 1,2,none") tokenizer.pad_token_id = ( 0 # unk. we want this to be different from the eos token ) tokenizer.padding_side = "left" # Allow batched inference - def tokenize(prompt, add_eos_token=True): + def tokenize(prompt): # there's probably a way to do this with the tokenizer settings # but again, gotta move fast result = tokenizer( @@ -212,13 +208,18 @@ def generate_and_tokenize_prompt(data_point): else: print(f"Checkpoint {checkpoint_name} not found") - # Be more transparent about the % of trainable params. - model.print_trainable_parameters() + model.print_trainable_parameters() # Be more transparent about the % of trainable params. 
    if val_set_size > 0:
-        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
-        train_data = (train_val["train"].shuffle().map(generate_and_tokenize_prompt))
-        val_data = (train_val["test"].shuffle().map(generate_and_tokenize_prompt))
+        train_val = data["train"].train_test_split(
+            test_size=val_set_size, shuffle=True, seed=42
+        )
+        train_data = (
+            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+        )
+        val_data = (
+            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        )
     else:
         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
         val_data = None
@@ -235,7 +236,7 @@ def generate_and_tokenize_prompt(data_point):
         args=transformers.TrainingArguments(
             per_device_train_batch_size=micro_batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
-            warmup_steps=100,
+            warmup_ratio=0.1,
             num_train_epochs=num_epochs,
             learning_rate=learning_rate,
             fp16=True,
@@ -243,10 +244,10 @@ def generate_and_tokenize_prompt(data_point):
             optim="adamw_torch",
             evaluation_strategy="steps" if val_set_size > 0 else "no",
             save_strategy="steps",
-            eval_steps=100 if val_set_size > 0 else None,
-            save_steps=100,
+            eval_steps=50 if val_set_size > 0 else None,
+            save_steps=50,
             output_dir=output_dir,
-            save_total_limit=3,
+            save_total_limit=5,
             load_best_model_at_end=True if val_set_size > 0 else False,
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
@@ -273,7 +274,9 @@ def generate_and_tokenize_prompt(data_point):
     model.save_pretrained(output_dir)
-    print("\n If there's a warning about missing keys above, please disregard :)")
+    print(
+        "\n If there's a warning about missing keys above, please disregard :)"
+    )

 if __name__ == "__main__":
diff --git a/src/models/lora_weights/.gitkeep b/models/base_models/.gitkeep
similarity index 100%
rename from src/models/lora_weights/.gitkeep
rename to models/base_models/.gitkeep
diff --git a/src/outputs/.gitkeep b/models/lora_weights/.gitkeep
similarity index 100%
rename from src/outputs/.gitkeep
rename to models/lora_weights/.gitkeep
diff --git a/src/utils/__init__.py b/outputs/.gitkeep
similarity index 100%
rename from src/utils/__init__.py
rename to outputs/.gitkeep
diff --git a/src/data/example_instruction_train.json b/resources/example_instruction_train.json
similarity index 100%
rename from src/data/example_instruction_train.json
rename to resources/example_instruction_train.json
diff --git a/resources/example_instruction_tune.json b/resources/example_instruction_tune.json
new file mode 100644
index 0000000..496d4fd
--- /dev/null
+++ b/resources/example_instruction_tune.json
@@ -0,0 +1,12 @@
+[
+    {
+        "instruction": "酒驾撞人要判多久?",
+        "input": "",
+        "output": "《刑法》第一百三十三条规定:违反交通运输管理法规,因而发生重大事故,致人重伤、死亡或者使公私财产遭受重大损失的,处三年以下有期徒刑或者拘役;交通运输肇事后逃逸,致人重伤的,处三年以上七年以下有期徒刑。交通肇事后逃逸致人死亡的,处七年以上有期徒刑。"
+    },
+    {
+        "instruction": "偷狗被派出所抓到,会怎么处理?",
+        "input": "",
+        "output": "取决于偷盗的狗的价值,可能按盗窃罪的罪名处理。如果价值达到犯罪标准,就构成盗窃罪,要承担刑事责任;如果不到犯罪标准,就是治安处罚、罚款或者拘留治安处罚这会涉嫌构成盗窃。如果不到一千元,则不会构成犯罪。如果超过一千元,则可能会是构成犯罪的。"
+    }
+]
\ No newline at end of file
diff --git a/scripts/finetune.sh b/scripts/finetune.sh
new file mode 100644
index 0000000..e8b5614
--- /dev/null
+++ b/scripts/finetune.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+export WANDB_MODE=disabled # 禁用wandb
+
+# 使用chinese-alpaca-plus-7b-merged模型在law_data.json数据集上finetune
+experiment_name="chinese-alpaca-plus-7b-law-e1"
+
+# 单卡或者模型并行
+python finetune.py \
+    --base_model "minlik/chinese-alpaca-plus-7b-merged" \
+    --data_path "./data/finetune_law_data.json" \
+    --output_dir "./outputs/"${experiment_name} \
+    --batch_size 64 \
+    --micro_batch_size 8 \
+    --num_epochs 20 \
+    --learning_rate 3e-4 \
+    --cutoff_len 256 \
+    --val_set_size 0 \
+    --lora_r 8 \
+    --lora_alpha 16 \
+    --lora_dropout 0.05 \
+    --lora_target_modules "[q_proj,v_proj]" \
+    --train_on_inputs True \
+    --add_eos_token True \
+    --group_by_length False \
+    --wandb_project \
+    --wandb_run_name \
+    --wandb_watch \
+    --wandb_log_model \
+    --resume_from_checkpoint "./outputs/"${experiment_name} \
+    --prompt_template_name "alpaca" \
+
+
+# 多卡数据并行
+# WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \
+#     --base_model "minlik/chinese-alpaca-plus-7b-merged" \
+#     --data_path "./data/finetune_law_data.json" \
+#     --output_dir "./outputs/"${experiment_name} \
+#     --batch_size 64 \
+#     --micro_batch_size 8 \
+#     --num_epochs 20 \
+#     --learning_rate 3e-4 \
+#     --cutoff_len 256 \
+#     --val_set_size 0 \
+#     --lora_r 8 \
+#     --lora_alpha 16 \
+#     --lora_dropout 0.05 \
+#     --lora_target_modules "[q_proj,v_proj]" \
+#     --train_on_inputs True \
+#     --add_eos_token True \
+#     --group_by_length False \
+#     --wandb_project \
+#     --wandb_run_name \
+#     --wandb_watch \
+#     --wandb_log_model \
+#     --resume_from_checkpoint "./outputs/"${experiment_name} \
+#     --prompt_template_name "alpaca" \
\ No newline at end of file
diff --git a/src/scripts/train.sh b/scripts/train_clm.sh
similarity index 64%
rename from src/scripts/train.sh
rename to scripts/train_clm.sh
index 56532a2..cb45cb6 100644
--- a/src/scripts/train.sh
+++ b/scripts/train_clm.sh
@@ -1,9 +1,9 @@
 #!/bin/bash
-WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train_lora.py \
-    --base_model '../models/base_models/chinese_llama_7b' \
-    --data_path '' \
-    --output_dir '../models/lora_weights' \
+WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1235 train_clm.py \
+    --base_model './models/base_models/chinese_llama_7b' \
+    --data_path './data/train_clm_data.json' \
+    --output_dir './outputs/train-clm' \
     --batch_size 128 \
     --micro_batch_size 8 \
     --num_epochs 1 \
@@ -17,4 +17,4 @@ WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --
     --train_on_inputs True \
     --add_eos_token True \
     --group_by_length True \
-    --resume_from_checkpoint '../models/lora_weights'
+    --resume_from_checkpoint './outputs/train-clm'
\ No newline at end of file
diff --git a/scripts/webui.sh b/scripts/webui.sh
new file mode 100644
index 0000000..097cd3e
--- /dev/null
+++ b/scripts/webui.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+
+# 使用huggingface上已经训练好的模型
+python webui.py \
+    --load_8bit True \
+    --base_model 'minlik/chinese-llama-7b-merged' \
+    --lora_weights 'entity303/lawgpt-lora-7b' \
+    --prompt_template "law_template" \
+    --server_name "0.0.0.0" \
+    --share_gradio True \
+
+
+# 使用自己finetune的lora, 把自己的模型放到对应目录即可
+# python webui.py \
+#     --load_8bit True \
+#     --base_model 'minlik/chinese-alpaca-plus-7b-merged' \
+#     --lora_weights './outputs/chinese-alpaca-plus-7b-law-e1' \
+#     --prompt_template "alpaca" \
+#     --server_name "0.0.0.0" \
+#     --share_gradio True \
\ No newline at end of file
diff --git a/src/data/example_instruction_tune.json b/src/data/example_instruction_tune.json
deleted file mode 100644
index f6db36f..0000000
--- a/src/data/example_instruction_tune.json
+++ /dev/null
@@ -1,10 +0,0 @@
-[
-    {
-        "instruction": "酒驾撞人要判多久?",
-        "input": "",
-        "output":
"《刑法》第一百三十三条规定:违反交通运输管理法规,因而发生重大事故,致人重伤、死广或者使公私财产遭受重大损失的,处三年以下有期徒刑或者拘役;交通运输肇事后逃逸,致人重伤的,处三年以上七年以下有期徒刑。交通肇事后逃逸致人死亡的,处七年以上有期徒刑。" - }, - {"instruction":"偷狗被派出所抓到,会怎么处理?", - "input":"", - "output":"取决于偷盗的狗的价值,可能按盗窃罪的罪名处理。如果价值达到犯罪标准,就构成盗窃罪,要承担刑事责任;如果不到犯罪标准,就是治安处罚、罚款或者拘留治安处罚这会涉嫌构成盗窃。如果不到一千元,则不会构成犯罪。如果超过一千元,则可能会是构成犯罪的。"} -] \ No newline at end of file diff --git a/src/scripts/finetune.sh b/src/scripts/finetune.sh deleted file mode 100644 index 14bf7a1..0000000 --- a/src/scripts/finetune.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -WORLD_SIZE=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=1234 finetune.py \ - --base_model 'minlik/chinese-llama-7b-merged' \ - --data_path '' \ - --output_dir './outputs/LawGPT' \ - --prompt_template_name 'law_template' \ - --micro_batch_size 16 \ - --batch_size 128 \ - --num_epochs 3 \ - --val_set_size 10000 \ - --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \ - --lora_r 16 \ - --lora_alpha 32 \ - --learning_rate 3e-4 \ - --cutoff_len 512 \ - --resume_from_checkpoint './outputs/LawGPT' \ \ No newline at end of file diff --git a/src/scripts/generate.sh b/src/scripts/generate.sh deleted file mode 100644 index 283007e..0000000 --- a/src/scripts/generate.sh +++ /dev/null @@ -1,7 +0,0 @@ - -CUDA_VISIBLE_DEVICES=1 python generate.py \ - --load_8bit \ - --base_model 'minlik/chinese-llama-7b-merged' \ - --lora_weights 'entity303/lawgpt-lora-7b' \ - --prompt_template 'law_template' \ - --share_gradio diff --git a/templates/alpaca.json b/templates/alpaca.json new file mode 100644 index 0000000..e486439 --- /dev/null +++ b/templates/alpaca.json @@ -0,0 +1,6 @@ +{ + "description": "Template used by Alpaca-LoRA.", + "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", + "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", + "response_split": "### Response:" +} diff --git a/src/templates/law_template.json b/templates/law_template.json similarity index 100% rename from src/templates/law_template.json rename to templates/law_template.json diff --git a/src/train_lora.py b/train_clm.py similarity index 100% rename from src/train_lora.py rename to train_clm.py diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/callbacks.py b/utils/callbacks.py similarity index 100% rename from src/utils/callbacks.py rename to utils/callbacks.py diff --git a/src/utils/evaluate.py b/utils/evaluate.py similarity index 100% rename from src/utils/evaluate.py rename to utils/evaluate.py diff --git a/src/utils/merge.py b/utils/merge.py similarity index 100% rename from src/utils/merge.py rename to utils/merge.py diff --git a/src/utils/prompter.py b/utils/prompter.py similarity index 100% rename from src/utils/prompter.py rename to utils/prompter.py diff --git a/src/generate.py b/webui.py similarity index 77% rename from src/generate.py rename to webui.py index bbe5513..453043a 100644 --- a/src/generate.py +++ b/webui.py @@ -6,7 +6,7 @@ import torch import transformers from peft import PeftModel -from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer +from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModel, AutoTokenizer, AutoModelForCausalLM from utils.callbacks import Iteratorize, Stream from utils.prompter import Prompter @@ -19,14 +19,14 @@ try: if torch.backends.mps.is_available(): device = "mps" -except: # noqa: E722 +except: pass def main( load_8bit: bool = False, base_model: str = "", - lora_weights: str = "tloen/alpaca-lora-7b", + lora_weights: str = "", prompt_template: str = "", # The prompt template to use, will default to alpaca. server_name: str = "0.0.0.0", # Allows to listen on all interfaces by providing '0. share_gradio: bool = False, @@ -45,33 +45,41 @@ def main( torch_dtype=torch.float16, device_map="auto", ) - model = PeftModel.from_pretrained( - model, - lora_weights, - torch_dtype=torch.float16, - ) + try: + model = PeftModel.from_pretrained( + model, + lora_weights, + torch_dtype=torch.float16, + ) + except: + print("*"*50, "\n Attention! No Lora Weights \n", "*"*50) elif device == "mps": model = LlamaForCausalLM.from_pretrained( base_model, device_map={"": device}, torch_dtype=torch.float16, ) - model = PeftModel.from_pretrained( - model, - lora_weights, - device_map={"": device}, - torch_dtype=torch.float16, - ) + try: + model = PeftModel.from_pretrained( + model, + lora_weights, + device_map={"": device}, + torch_dtype=torch.float16, + ) + except: + print("*"*50, "\n Attention! No Lora Weights \n", "*"*50) else: model = LlamaForCausalLM.from_pretrained( - base_model, - device_map={"": device}, low_cpu_mem_usage=True - ) - model = PeftModel.from_pretrained( - model, - lora_weights, - device_map={"": device}, + base_model, device_map={"": device}, low_cpu_mem_usage=True ) + try: + model = PeftModel.from_pretrained( + model, + lora_weights, + device_map={"": device}, + ) + except: + print("*"*50, "\n Attention! 
No Lora Weights \n", "*"*50) # unwind broken decapoda-research config model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk @@ -87,15 +95,16 @@ def main( def evaluate( instruction, - input=None, + # input=None, temperature=0.1, top_p=0.75, top_k=40, - num_beams=1, - max_new_tokens=256, - stream_output=True, + num_beams=4, + max_new_tokens=128, + stream_output=False, **kwargs, ): + input=None prompt = prompter.generate_prompt(instruction, input) inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to(device) @@ -144,6 +153,7 @@ def generate_with_streaming(**kwargs): break yield prompter.get_response(decoded_output) + print(decoded_output) return # early return for stream_output # Without streaming @@ -157,6 +167,7 @@ def generate_with_streaming(**kwargs): ) s = generation_output.sequences[0] output = tokenizer.decode(s) + print(output) yield prompter.get_response(output) gr.Interface( @@ -165,11 +176,11 @@ def generate_with_streaming(**kwargs): gr.components.Textbox( lines=2, label="Instruction", - placeholder="Tell me about alpacas.", + placeholder="此处输入法律相关问题", ), - gr.components.Textbox(lines=2, label="Input", placeholder="none"), + # gr.components.Textbox(lines=2, label="Input", placeholder="none"), gr.components.Slider( - minimum=0, maximum=1, value=1.0, label="Temperature" + minimum=0, maximum=1, value=0.1, label="Temperature" ), gr.components.Slider( minimum=0, maximum=1, value=0.75, label="Top p" @@ -178,23 +189,22 @@ def generate_with_streaming(**kwargs): minimum=0, maximum=100, step=1, value=40, label="Top k" ), gr.components.Slider( - minimum=1, maximum=4, step=1, value=4, label="Beams" + minimum=1, maximum=4, step=1, value=1, label="Beams" ), gr.components.Slider( minimum=1, maximum=2000, step=1, value=256, label="Max tokens" ), - gr.components.Checkbox(label="Stream output", value=True), + gr.components.Checkbox(label="Stream output", value=True), ], outputs=[ gr.inputs.Textbox( - lines=5, + lines=8, label="Output", ) ], title="🦙🌲 LaWGPT", - description="", # noqa: E501 + description="", ).queue().launch(server_name="0.0.0.0", share=share_gradio) - # Old testing code follows. if __name__ == "__main__":
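
For readers who want to exercise the relocated code without Gradio, a minimal command-line inference sketch distilled from `webui.py` might look like the following; it is illustrative rather than a file shipped in this PR. The model and LoRA IDs are the ones used in `scripts/webui.sh`, the decoding settings are indicative only, and a GPU with enough memory for fp16 LLaMA-7B is assumed:

```python
# Minimal non-Gradio inference sketch (illustrative; not part of the repository).
import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter

base_model = "minlik/chinese-llama-7b-merged"   # as in scripts/webui.sh
lora_weights = "entity303/lawgpt-lora-7b"       # as in scripts/webui.sh

tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.float16, device_map="auto"
)
model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
model.config.pad_token_id = tokenizer.pad_token_id = 0  # same unk workaround as webui.py
model.eval()

prompter = Prompter("law_template")             # template referenced by scripts/webui.sh
prompt = prompter.generate_prompt("酒驾撞人要判多久?", None)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        generation_config=GenerationConfig(
            do_sample=True, temperature=0.1, top_p=0.75, top_k=40, num_beams=1
        ),
        max_new_tokens=256,
    )
print(prompter.get_response(tokenizer.decode(output_ids[0], skip_special_tokens=True)))
```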