update infer.ipynb
Signed-off-by: mymusise <[email protected]>
mymusise committed Mar 23, 2023
1 parent 3ab874f · commit 19b626b
Showing 1 changed file with 93 additions and 56 deletions.
149 changes: 93 additions & 56 deletions infer.ipynb
@@ -2,45 +2,67 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
"Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00, 1.23it/s]\n"
]
}
],
"source": [
"from transformers import HfArgumentParser\n",
"from modeling_chatglm import ChatGLMForConditionalGeneration\n",
"import torch\n",
"import transformers\n",
"from peft import get_peft_model, LoraConfig, TaskType\n",
"from dataclasses import dataclass, field"
"\n",
"\n",
"torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
"model = ChatGLMForConditionalGeneration.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True, device_map='auto')"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
]
}
],
"source": [
"from transformers import AutoTokenizer, AutoModel, TrainingArguments, AutoConfig\n",
"from modeling_chatglm import ChatGLMForConditionalGeneration\n",
"import torch\n",
"import torch.nn as nn\n",
"from peft import get_peft_model, LoraConfig, TaskType\n",
"from transformers import AutoTokenizer\n",
"\n",
"\n",
"torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
"model = ChatGLMForConditionalGeneration.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True, device_map='auto')"
"tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mymusise/pro/ChatGLM-Tuning/peft/tuners/lora.py:175: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from peft import get_peft_model, LoraConfig, TaskType\n",
"\n",
"peft_path = \"output/chatglm-lora.pt\"\n",
"\n",
"peft_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM, inference_mode=False,\n",
" task_type=TaskType.CAUSAL_LM, inference_mode=True,\n",
" r=8,\n",
" lora_alpha=32, lora_dropout=0.1\n",
")\n",
@@ -52,57 +74,72 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from tqdm import tqdm\n",
"\n",
"instructions = json.load(open(\"data/alpaca_data.json\"))\n",
"\n",
"instructions = [\n",
" {\n",
" 'instruction': \"下面哪个产品与其他不同\",\n",
" \"input\": \"知乎, 百度, 微博, 拼多多\",\n",
" \"output\": \"拼多多与其他产品不同,因为它是一个电子商务平台,而知乎、百度和微博都是社交平台/搜索引擎。\",\n",
" },\n",
" {\n",
" 'instruction': \"下面哪个产品与其他不同\",\n",
" \"input\": \"chatgpt, 文心, CPM, 拼多多\",\n",
" \"output\": \"拼多多与其他产品不同,因为它是一家电子商务公司,而ChatGPT、文心和CPM都不是电子商务公司。\",\n",
" }\n",
"]"
"instructions = json.load(open(\"data/alpaca_data.json\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mymusise/pro/ChatGLM-Tuning/transformers/generation/utils.py:1374: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Instruction: Give three tips for staying healthy.\n",
"Answer: 1. Eat a balanced diet.\n",
"2. Get regular exercise.\n",
"3. Stay hydrated.\n",
"### 1.Answer:\n",
" 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n",
"2. Exercise regularly to keep your body active and strong. \n",
"3. Get enough sleep and maintain a consistent sleep schedule. \n",
"\n",
"\n",
"Instruction: What are the three primary colors?\n",
"Answer: The three primary colors are red, blue, and yellow.\n",
"### 2.Answer:\n",
" The three primary colors are red, blue, and yellow. \n",
"\n",
"\n",
"Instruction: Describe the structure of an atom.\n",
"Answer: An atom is a small particle of an element, containing a small number of particles of the element. Each particle is made up of a single electron, which is surrounded by a cloud of negative charge. The cloud of negative charge is surrounded by a cloud of positive charge, which creates a neutral cloud. The neutral cloud is made up of a cloud of negative charge, which creates a cloud of positive charge. This process continues until all the atoms have been neutralized.\n",
"### 3.Answer:\n",
" An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom. \n",
"\n",
"\n"
]
}
],
"source": [
"answers = []\n",
"from cover_alpaca2jsonl import format_example\n",
"\n",
"\n",
"with torch.no_grad():\n",
" for idx, item in enumerate(instructions[:5]):\n",
" input_text = f\"### {idx+1}.Instruction:\\n{item['instruction']}\\n\\n\"\n",
" if item.get('input'):\n",
" input_text += f\"### {idx+1}.Input:\\n{item['input']}\\n\\n\"\n",
" input_text += f\"### {idx+1}.Response:\"\n",
" batch = tokenizer(input_text, return_tensors=\"pt\")\n",
" for idx, item in enumerate(instructions[:3]):\n",
" feature = format_example(item)\n",
" input_text = feature['context']\n",
" ids = tokenizer.encode(input_text)\n",
" input_ids = torch.LongTensor([ids])\n",
" out = model.generate(\n",
" input_ids=batch[\"input_ids\"],\n",
" attention_mask=torch.ones_like(batch[\"input_ids\"]).bool(),\n",
" max_length=512,\n",
" input_ids=input_ids,\n",
" max_length=150,\n",
" do_sample=False,\n",
" temperature=0\n",
" )\n",
" out_text = tokenizer.decode(out[0])\n",
@@ -130,7 +167,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
"version": "3.8.12"
},
"orig_nbformat": 4,
"vscode": {
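For reference, the updated notebook condenses to the standalone inference script below. This is a minimal sketch under a few assumptions, not the repository's exact code: it presumes the ChatGLM-Tuning layout (modeling_chatglm.py, cover_alpaca2jsonl.py, data/alpaca_data.json, and LoRA weights at output/chatglm-lora.pt), and the load_state_dict call is a guess at what the collapsed cell between the hunks above does. It also moves input_ids onto the model's device, as the UserWarning captured in the notebook output recommends.

import json

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoTokenizer

from cover_alpaca2jsonl import format_example  # repo helper that builds the prompt
from modeling_chatglm import ChatGLMForConditionalGeneration

# Base model in fp16 on the GPU, as in the first cell of the notebook.
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = ChatGLMForConditionalGeneration.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# Same LoRA shape as training; inference_mode=True is the change in this commit
# and tells peft the adapter is loaded for generation, not further training.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=True,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model = get_peft_model(model, peft_config)

# Assumption: the collapsed cell loads the fine-tuned adapter roughly like this.
model.load_state_dict(torch.load("output/chatglm-lora.pt"), strict=False)
model.eval()

instructions = json.load(open("data/alpaca_data.json"))

with torch.no_grad():
    for item in instructions[:3]:
        feature = format_example(item)
        ids = tokenizer.encode(feature["context"])
        # Keep input_ids on the model's device, as the warning above suggests.
        input_ids = torch.LongTensor([ids]).to(model.device)
        out = model.generate(input_ids=input_ids, max_length=150, do_sample=False)
        print(tokenizer.decode(out[0]))

With do_sample=False the decoding is greedy and deterministic, so the temperature=0 argument the notebook passes is redundant and is omitted in this sketch.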
