From 4415b736c276e97ddbd3f858451f23aa5676ac40 Mon Sep 17 00:00:00 2001 From: XiangWang Date: Sun, 10 Jul 2022 12:35:39 +0800 Subject: [PATCH] =?UTF-8?q?feature(model=20and=20test):=20=E5=9C=A8model?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E4=B8=AD=E6=96=B0=E5=A2=9Eprompt=20bert?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E4=B8=94=E5=9C=A8test=E4=B8=AD=E6=96=B0?= =?UTF-8?q?=E5=A2=9Eprompt=20bert=E7=9A=84=E4=BE=8B=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ark_nlp/dataset/__init__.py | 2 + ark_nlp/factory/predictor/__init__.py | 1 + ...del.py => prompt_masked_language_model.py} | 0 ark_nlp/factory/task/__init__.py | 2 +- ...del.py => prompt_masked_language_model.py} | 0 ark_nlp/model/prompt/__init__.py | 0 ark_nlp/model/prompt/prompt_bert/__init__.py | 21 + ark_nlp/nn/__init__.py | 2 + ark_nlp/nn/prompt_bert.py | 8 +- test/prompt_bert.ipynb | 406 ++++++++++++++++++ 10 files changed, 437 insertions(+), 5 deletions(-) rename ark_nlp/factory/predictor/{prompt_mask_language_model.py => prompt_masked_language_model.py} (100%) rename ark_nlp/factory/task/{prompt_mask_language_model.py => prompt_masked_language_model.py} (100%) create mode 100644 ark_nlp/model/prompt/__init__.py create mode 100644 ark_nlp/model/prompt/prompt_bert/__init__.py create mode 100644 test/prompt_bert.ipynb diff --git a/ark_nlp/dataset/__init__.py b/ark_nlp/dataset/__init__.py index c13ec1a..2a6374a 100644 --- a/ark_nlp/dataset/__init__.py +++ b/ark_nlp/dataset/__init__.py @@ -15,3 +15,5 @@ from ark_nlp.dataset.biaffine_named_entity_recognition_dataset import BiaffineNERDataset from ark_nlp.dataset.span_named_entity_recognition_dataset import SpanNERDataset from ark_nlp.dataset.global_pointer_named_entity_recognition_dataset import GlobalPointerNERDataset + +from ark_nlp.dataset.prompt_dataset import PromptDataset diff --git a/ark_nlp/factory/predictor/__init__.py b/ark_nlp/factory/predictor/__init__.py index ef6e2b3..70851cb 100644 --- a/ark_nlp/factory/predictor/__init__.py +++ b/ark_nlp/factory/predictor/__init__.py @@ -9,3 +9,4 @@ from ark_nlp.factory.predictor.biaffine_named_entity_recognition import BiaffineNERPredictor from ark_nlp.factory.predictor.span_named_entity_recognition import SpanNERPredictor from ark_nlp.factory.predictor.global_pointer_named_entity_recognition import GlobalPointerNERPredictor +from ark_nlp.factory.predictor.prompt_masked_language_model import PromptMLMPredictor diff --git a/ark_nlp/factory/predictor/prompt_mask_language_model.py b/ark_nlp/factory/predictor/prompt_masked_language_model.py similarity index 100% rename from ark_nlp/factory/predictor/prompt_mask_language_model.py rename to ark_nlp/factory/predictor/prompt_masked_language_model.py diff --git a/ark_nlp/factory/task/__init__.py b/ark_nlp/factory/task/__init__.py index 6fed5e6..fb3d1e8 100644 --- a/ark_nlp/factory/task/__init__.py +++ b/ark_nlp/factory/task/__init__.py @@ -9,4 +9,4 @@ from ark_nlp.factory.task.named_entity_recognition import BiaffineNERTask from ark_nlp.factory.task.named_entity_recognition import GlobalPointerNERTask from ark_nlp.factory.task.named_entity_recognition import SpanNERTask -from ark_nlp.factory.task.prompt_mask_language_model import PromptMLMTask +from ark_nlp.factory.task.prompt_masked_language_model import PromptMLMTask diff --git a/ark_nlp/factory/task/prompt_mask_language_model.py b/ark_nlp/factory/task/prompt_masked_language_model.py similarity index 100% rename from ark_nlp/factory/task/prompt_mask_language_model.py rename to 
ark_nlp/factory/task/prompt_masked_language_model.py diff --git a/ark_nlp/model/prompt/__init__.py b/ark_nlp/model/prompt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ark_nlp/model/prompt/prompt_bert/__init__.py b/ark_nlp/model/prompt/prompt_bert/__init__.py new file mode 100644 index 0000000..a7cf7b6 --- /dev/null +++ b/ark_nlp/model/prompt/prompt_bert/__init__.py @@ -0,0 +1,21 @@ +from ark_nlp.dataset import PromptDataset as Dataset +from ark_nlp.dataset import PromptDataset as PromptBertDataset + +from ark_nlp.processor.tokenizer.transfomer import PromptMLMTransformerTokenizer as Tokenizer +from ark_nlp.processor.tokenizer.transfomer import PromptMLMTransformerTokenizer as PromptBertTokenizer +from ark_nlp.processor.tokenizer.transfomer import PromptMLMTransformerTokenizer + +from ark_nlp.nn import BertConfig as PromptBertConfig +from ark_nlp.nn import BertConfig as ModuleConfig + +from ark_nlp.nn import BertForPromptMaskedLM as PromptBert +from ark_nlp.nn import BertForPromptMaskedLM as Module + +from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_model_optimizer +from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_prompt_bert_optimizer + +from ark_nlp.factory.task import PromptMLMTask as Task +from ark_nlp.factory.task import PromptMLMTask as PromptBertMLMTask + +from ark_nlp.factory.predictor import PromptMLMPredictor as Predictor +from ark_nlp.factory.predictor import PromptMLMPredictor as PromptBertMLMPredictor diff --git a/ark_nlp/nn/__init__.py b/ark_nlp/nn/__init__.py index 234dd58..10b9b68 100644 --- a/ark_nlp/nn/__init__.py +++ b/ark_nlp/nn/__init__.py @@ -13,6 +13,8 @@ from ark_nlp.nn.global_pointer_bert import GlobalPointerBert from ark_nlp.nn.crf_bert import CrfBert +from ark_nlp.nn.prompt_bert import BertForPromptMaskedLM + from transformers import BertConfig from ark_nlp.nn.configuration import ErnieConfig from ark_nlp.nn.configuration.configuration_nezha import NeZhaConfig diff --git a/ark_nlp/nn/prompt_bert.py b/ark_nlp/nn/prompt_bert.py index bc637ad..7c5f8ca 100644 --- a/ark_nlp/nn/prompt_bert.py +++ b/ark_nlp/nn/prompt_bert.py @@ -42,9 +42,9 @@ def forward(self, hidden_states): return hidden_states -class BertForMaskedLM(Bert): +class BertForPromptMaskedLM(Bert): """ - 基于BERT的mlm任务 + 针对prompt的基于BERT的mlm任务 :param config: (obejct) 模型的配置对象 :param bert_trained: (bool) bert参数是否可训练, 默认可训练 @@ -60,7 +60,7 @@ def __init__( config, encoder_trained=True ): - super(BertForMaskedLM, self).__init__(config) + super(BertForPromptMaskedLM, self).__init__(config) self.bert = BertModel(config, add_pooling_layer=False) @@ -104,7 +104,7 @@ def forward( sequence_output = outputs[0] - sequence_output = BertForMaskedLM._batch_gather(sequence_output, mask_position) + sequence_output = BertForPromptMaskedLM._batch_gather(sequence_output, mask_position) batch_size, _, hidden_size = sequence_output.shape diff --git a/test/prompt_bert.ipynb b/test/prompt_bert.ipynb new file mode 100644 index 0000000..f431dba --- /dev/null +++ b/test/prompt_bert.ipynb @@ -0,0 +1,406 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "db5e7bf8-7b15-46af-8418-541ab861fbc4", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import os\n", + "import json\n", + "import jieba\n", + "import torch\n", + "import pickle\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import pandas as pd\n", + "\n", + "from 
ark_nlp.model.prompt.prompt_bert import Module\n", + "from ark_nlp.model.prompt.prompt_bert import ModuleConfig\n", + "from ark_nlp.model.prompt.prompt_bert import Dataset\n", + "from ark_nlp.model.prompt.prompt_bert import Task\n", + "from ark_nlp.model.prompt.prompt_bert import get_default_model_optimizer\n", + "from ark_nlp.model.prompt.prompt_bert import Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17d9f5bd-ac77-41fb-804a-1df52dec4ed5", + "metadata": {}, + "outputs": [], + "source": [ + "# Data directory\n", + "# Dataset download: https://bj.bcebos.com/paddlenlp/paddlenlp/datasets/nptag_dataset.tar.gz\n", + "\n", + "train_data_path = '../data/source_datasets/nptag_dataset/train.txt'\n", + "dev_data_path = '../data/source_datasets/nptag_dataset/dev.txt'\n", + "name_category_map_path = '../data/source_datasets/nptag_dataset/name_category_map.json'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f01b9695-284f-4710-bfd1-85a968435c6c", + "metadata": {}, + "outputs": [], + "source": [ + "# Pretrained model name or path\n", + "module_path = 'nghuyong/ernie-1.0'" + ] + }, + { + "cell_type": "markdown", + "id": "9a133533-cacf-4f29-bbd8-01b117f4ac0d", + "metadata": {}, + "source": [ + "### I. Data loading and preprocessing" + ] + }, + { + "cell_type": "markdown", + "id": "0f44393d-67ac-4402-874d-5e8c6481cb0a", + "metadata": { + "tags": [] + }, + "source": [ + "#### 1. Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2cf3cd4-754d-4360-adc7-58c70f678fcb", + "metadata": {}, + "outputs": [], + "source": [ + "train_data_df = pd.read_csv(train_data_path, sep='\\t', names=['text', 'label'])\n", + "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', names=['text', 'label'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d031a73d-c8ae-4b90-9f40-4c67d9d0729a", + "metadata": {}, + "outputs": [], + "source": [ + "name_category_map = json.load(open(name_category_map_path, 'r', encoding='utf-8'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "780193a2-8c05-4f6f-89fe-4826d232790e", + "metadata": {}, + "outputs": [], + "source": [ + "# Define the prompt: '是' (\"is a\") followed by five [MASK] slots for the label\n", + "mask_tokens = [\"[MASK]\"] * 5\n", + "prompt = ['是'] + mask_tokens" + ] + }, + { + "cell_type": "markdown", + "id": "b3b2dffc-19c9-4c03-94fd-46851d8999c6", + "metadata": {}, + "source": [ + "#### 2. Vocabulary creation and tokenizer setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d06d213c-f177-4c34-b901-38dfd24b8e5d", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = Tokenizer(module_path, 100)" + ] + }, + { + "cell_type": "markdown", + "id": "0fea731c-211d-46b6-8d7e-74cfb2435f5e", + "metadata": {}, + "source": [ + "#### 3. 
Align the labels" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be715b41-3169-4bae-bc00-7a849d486681", + "metadata": {}, + "outputs": [], + "source": [ + "# The prompt has a fixed number of [MASK] slots, so every label is padded with [PAD] to the same length (five tokens)\n", + "\n", + "label2newlabel = dict()\n", + "\n", + "for _k, _ in name_category_map.items():\n", + " _term = _k\n", + " label2newlabel[_k] = ''.join(tokenizer.tokenize(_term) + ['[PAD]'] * (5 - len(tokenizer.tokenize(_term))))\n", + " \n", + "label2newlabel['海绵蛋糕'] = '海绵蛋糕[PAD]'\n", + " \n", + "train_data_df['label'] = train_data_df['label'].apply(lambda x: label2newlabel[x])\n", + "dev_data_df['label'] = dev_data_df['label'].apply(lambda x: label2newlabel[x])\n", + "\n", + "categories = [_v for _, _v in label2newlabel.items()]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "275fb9bd-7d24-42f7-aef9-f9c3909d4b88", + "metadata": {}, + "outputs": [], + "source": [ + "prompt_train_dataset = Dataset(train_data_df, prompt=prompt, categories=categories)\n", + "prompt_dev_dataset = Dataset(dev_data_df, prompt=prompt, categories=categories)" + ] + }, + { + "cell_type": "markdown", + "id": "bba86db9-2d80-4d71-9a29-666d8bfe6c5b", + "metadata": {}, + "source": [ + "#### 4. Convert to IDs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39cc52c4-6dd8-4fc2-9117-96151c1ec943", + "metadata": {}, + "outputs": [], + "source": [ + "prompt_train_dataset.convert_to_ids(tokenizer)\n", + "prompt_dev_dataset.convert_to_ids(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "4a71e024-2e54-4dcd-857b-a13fa50c7879", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### II. Model construction" + ] + }, + { + "cell_type": "markdown", + "id": "55a19c77-7b1b-4ba8-a714-8558ca21ba57", + "metadata": {}, + "source": [ + "#### 1. Model configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "899b7861-e977-4c5e-aad8-be9a6a7160e1", + "metadata": {}, + "outputs": [], + "source": [ + "config = ModuleConfig.from_pretrained(\n", + " module_path,\n", + " num_labels=tokenizer.vocab.vocab_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "54efa151-58bd-4ffc-ab09-e84fde34ee58", + "metadata": { + "tags": [] + }, + "source": [ + "#### 2. Model creation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d71c4208-c24d-4d80-bf06-30f5b21f1d0c", + "metadata": {}, + "outputs": [], + "source": [ + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c4487fe-a5b0-44b7-8e05-31bdb38fbb6f", + "metadata": {}, + "outputs": [], + "source": [ + "dl_module = Module.from_pretrained(\n", + " module_path,\n", + " config=config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6cd1bca5-5422-4ac7-a8e6-46611408be12", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### III. Task construction" + ] + }, + { + "cell_type": "markdown", + "id": "d05cd18e-d76a-4f2c-a990-9ed5b9fdedff", + "metadata": {}, + "source": [ + "#### 1. Task parameters and required components" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ddd75bc-a146-4eb8-9ec1-b94dc37fed4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Training hyperparameters\n", + "num_epochs = 10\n", + "batch_size = 32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4be0754-f92c-4ebb-8ca3-462e3d0f6b7f", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = get_default_model_optimizer(dl_module)" + ] + }, + { + "cell_type": "markdown", + "id": "16a4ff54-601f-4a9b-a564-2821fb4edbe5", + "metadata": {}, + "source": [ + "#### 2. Task creation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13f0046f-007d-40c1-93c0-efb8e541a079", + "metadata": {}, + "outputs": [], + "source": [ + "model = Task(dl_module, optimizer, 'ce', cuda_device=0, tokenizer=tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "563f7718-8f83-4a0c-a44d-13cba7616b2b", + "metadata": { + "tags": [] + }, + "source": [ + "#### 3. Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee3108d3-2648-43d8-8378-2e8937aa0775", + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(\n", + " prompt_train_dataset,\n", + " prompt_dev_dataset,\n", + " lr=2e-5,\n", + " epochs=num_epochs,\n", + " batch_size=batch_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7c10f553-0bec-45a3-98ce-22e6a9724161", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### IV. Model evaluation and saving" + ] + }, + { + "cell_type": "markdown", + "id": "a9287f8b-c25c-47fb-9eaf-b044db96d68c", + "metadata": { + "tags": [] + }, + "source": [ + "#### 1. Model evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40de0353-cc4a-44d1-bc84-4ba20fad2a4e", + "metadata": {}, + "outputs": [], + "source": [ + "from ark_nlp.model.prompt.prompt_bert import Predictor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22def42a-ffec-4907-b358-51a65b11d998", + "metadata": {}, + "outputs": [], + "source": [ + "prompt_instance = Predictor(model.module, tokenizer, prompt_train_dataset.cat2id, prompt=prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2fc3162e-3117-4b31-ab99-08f86871fda1", + "metadata": {}, + "outputs": [], + "source": [ + "prompt_instance.predict_one_sample('美国队长3', topk=15, return_proba=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
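
Note on the label-alignment cell in test/prompt_bert.ipynb: the same rule is easier to see as a small standalone function. pad_label below is an illustrative helper, not part of ark_nlp; it only assumes a tokenizer object with a tokenize() method.

# Minimal sketch of the label-alignment rule used in the notebook: every label
# must fill exactly the five [MASK] slots of the prompt, so it is tokenized,
# truncated if necessary, and padded with [PAD]. pad_label is hypothetical.
def pad_label(label, tokenizer, num_mask=5):
    tokens = tokenizer.tokenize(label)[:num_mask]
    return ''.join(tokens + ['[PAD]'] * (num_mask - len(tokens)))

For example, the notebook maps the label '海绵蛋糕' (four tokens) to '海绵蛋糕[PAD]', so every entry in label2newlabel has the same length as the masked span of the prompt.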
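Note on the ark_nlp/nn/prompt_bert.py hunk: the renamed BertForPromptMaskedLM keeps the _batch_gather call that pulls the hidden states at the [MASK] positions out of sequence_output before they are scored against the vocabulary. The sketch below shows what such a batch gather typically looks like with torch.gather; it assumes mask_position holds the per-sample [MASK] indices and is not the exact ark_nlp implementation.

import torch

def batch_gather(sequence_output, mask_position):
    # sequence_output: (batch, seq_len, hidden); mask_position: (batch, num_mask) long tensor
    hidden_size = sequence_output.size(-1)
    # Expand the indices to (batch, num_mask, hidden) so gather selects whole hidden vectors.
    index = mask_position.unsqueeze(-1).expand(-1, -1, hidden_size)
    # Result: (batch, num_mask, hidden), the [MASK] representations for each sample.
    return torch.gather(sequence_output, 1, index)

Scoring each of these num_mask vectors over the whole vocabulary is presumably why the notebook builds the config with num_labels=tokenizer.vocab.vocab_size.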