
Commit

feature(model and test): add prompt bert to the model module and add a prompt bert example to test
xiangking committed Jul 10, 2022
1 parent 7923782 commit 4415b73
Showing 10 changed files with 437 additions and 5 deletions.
2 changes: 2 additions & 0 deletions ark_nlp/dataset/__init__.py
@@ -15,3 +15,5 @@
from ark_nlp.dataset.biaffine_named_entity_recognition_dataset import BiaffineNERDataset
from ark_nlp.dataset.span_named_entity_recognition_dataset import SpanNERDataset
from ark_nlp.dataset.global_pointer_named_entity_recognition_dataset import GlobalPointerNERDataset

from ark_nlp.dataset.prompt_dataset import PromptDataset
1 change: 1 addition & 0 deletions ark_nlp/factory/predictor/__init__.py
@@ -9,3 +9,4 @@
from ark_nlp.factory.predictor.biaffine_named_entity_recognition import BiaffineNERPredictor
from ark_nlp.factory.predictor.span_named_entity_recognition import SpanNERPredictor
from ark_nlp.factory.predictor.global_pointer_named_entity_recognition import GlobalPointerNERPredictor
from ark_nlp.factory.predictor.prompt_masked_language_model import PromptMLMPredictor
2 changes: 1 addition & 1 deletion ark_nlp/factory/task/__init__.py
@@ -9,4 +9,4 @@
from ark_nlp.factory.task.named_entity_recognition import BiaffineNERTask
from ark_nlp.factory.task.named_entity_recognition import GlobalPointerNERTask
from ark_nlp.factory.task.named_entity_recognition import SpanNERTask
from ark_nlp.factory.task.prompt_mask_language_model import PromptMLMTask
from ark_nlp.factory.task.prompt_masked_language_model import PromptMLMTask
Empty file.
21 changes: 21 additions & 0 deletions ark_nlp/model/prompt/prompt_bert/__init__.py
@@ -0,0 +1,21 @@
from ark_nlp.dataset import PromptDataset as Dataset
from ark_nlp.dataset import PromptDataset as PromptBertDataset

from ark_nlp.processor.tokenizer.transfomer import PromptMLMTransformerTokenizer as Tokenizer
from ark_nlp.processor.tokenizer.transfomer import PromptMLMTransformerTokenizer as PromptBertTokenizer
from ark_nlp.processor.tokenizer.transfomer import PromptMLMTransformerTokenizer

from ark_nlp.nn import BertConfig as PromptBertConfig
from ark_nlp.nn import BertConfig as ModuleConfig

from ark_nlp.nn import BertForPromptMaskedLM as PromptBert
from ark_nlp.nn import BertForPromptMaskedLM as Module

from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_model_optimizer
from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_prompt_bert_optimizer

from ark_nlp.factory.task import PromptMLMTask as Task
from ark_nlp.factory.task import PromptMLMTask as PromptBertMLMTask

from ark_nlp.factory.predictor import PromptMLMPredictor as Predictor
from ark_nlp.factory.predictor import PromptMLMPredictor as PromptBertMLMPredictor
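
For context, ark_nlp/model/prompt/prompt_bert/__init__.py above only re-exports existing components under the repository's usual per-model alias names (Dataset, Tokenizer, Module, Task, Predictor). A minimal sketch of importing the aliases, together with a made-up prompt-style input of the kind this model predicts over; the template and label words are illustrative assumptions, not taken from the test example added in this commit:

from ark_nlp.model.prompt.prompt_bert import (
    Dataset, Tokenizer, ModuleConfig, Module,
    get_default_model_optimizer, Task, Predictor
)

# Illustrative prompt for a sentiment-style task; the template and label words are assumptions.
text = "这家餐厅的菜很好吃"
prompt_text = text + "。总之味道很[MASK][MASK]。"
# The MLM head of BertForPromptMaskedLM predicts label words (e.g. "不错") at the [MASK] slots.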
2 changes: 2 additions & 0 deletions ark_nlp/nn/__init__.py
@@ -13,6 +13,8 @@
from ark_nlp.nn.global_pointer_bert import GlobalPointerBert
from ark_nlp.nn.crf_bert import CrfBert

from ark_nlp.nn.prompt_bert import BertForPromptMaskedLM

from transformers import BertConfig
from ark_nlp.nn.configuration import ErnieConfig
from ark_nlp.nn.configuration.configuration_nezha import NeZhaConfig
8 changes: 4 additions & 4 deletions ark_nlp/nn/prompt_bert.py
@@ -42,9 +42,9 @@ def forward(self, hidden_states):
return hidden_states


class BertForMaskedLM(Bert):
class BertForPromptMaskedLM(Bert):
"""
BERT-based MLM task
BERT-based MLM task for prompt learning
:param config: (object) the model configuration object
:param bert_trained: (bool) whether the BERT parameters are trainable; trainable by default
@@ -60,7 +60,7 @@ def __init__(
config,
encoder_trained=True
):
super(BertForMaskedLM, self).__init__(config)
super(BertForPromptMaskedLM, self).__init__(config)

self.bert = BertModel(config, add_pooling_layer=False)

@@ -104,7 +104,7 @@ def forward(

sequence_output = outputs[0]

sequence_output = BertForMaskedLM._batch_gather(sequence_output, mask_position)
sequence_output = BertForPromptMaskedLM._batch_gather(sequence_output, mask_position)

batch_size, _, hidden_size = sequence_output.shape

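
The forward pass above relies on BertForPromptMaskedLM._batch_gather(sequence_output, mask_position) to keep only the hidden states at the [MASK] positions; the helper's body is outside this hunk. A sketch of the standard torch.gather formulation such a helper typically uses (an assumption about the implementation, not code from this commit):

import torch

def batch_gather(sequence_output, mask_position):
    # sequence_output: (batch_size, seq_len, hidden_size) token representations.
    # mask_position:   (batch_size, num_mask) LongTensor of [MASK] token indices per example.
    index = mask_position.unsqueeze(-1).expand(-1, -1, sequence_output.size(-1))
    # Gather along the sequence dimension so each example keeps only its masked positions.
    return torch.gather(sequence_output, dim=1, index=index)  # (batch_size, num_mask, hidden_size)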