forked from 920232796/bert_seq2seq
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 7 changed files with 158 additions and 16 deletions.
@@ -17,6 +17,7 @@ poems
./test.py
nouse
test.py
TextRank
develop-eggs/
dist/
downloads/
@@ -0,0 +1,144 @@
## A multi-label classification model based on seq2seq.
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
import time
from bert_seq2seq.config import yayun_list
import os
from bert_seq2seq.basic_bert import BasicBert
import numpy as np
from bert_seq2seq.helper import RepetitionPenaltyLogitsProcessor, TemperatureLogitsProcessor, TopKLogitsProcessor, \
    TopPLogitsProcessor, ListProcessor

class ClsMultiSeq2SeqModel(BasicBert):
    """
    Multi-label classification model built on the BERT seq2seq backbone.
    """
    def __init__(self, word2idx, target, model_name="roberta"):
        super(ClsMultiSeq2SeqModel, self).__init__(word2ix=word2idx, model_name=model_name)
        self.target = target
        self.final_dense = nn.Linear(self.config.hidden_size, len(self.target))

    def compute_loss(self, predictions, labels, target_mask):
        """
        target_mask: 0 for the sentence-a part and padding, 1 for the sentence-b part
        """
        predictions = predictions.view(-1, len(self.target))
        labels = labels.view(-1)
        target_mask = target_mask.view(-1).float()
        loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none")
        return (loss(predictions, labels) * target_mask).sum() / target_mask.sum()  ## mask out the loss contribution of padding and the sentence-a part

    def forward(self, input_tensor, token_type_id, position_enc=None, labels=None):
        input_tensor = input_tensor.to(self.device)
        token_type_id = token_type_id.to(self.device)
        if position_enc is not None:
            position_enc = position_enc.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)
        input_shape = input_tensor.shape
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        ## build the special (UniLM-style) seq2seq attention mask
        ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
        a_mask = ones.tril()
        s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
        s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
        a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask

        enc_layers, _ = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id,
                                  attention_mask=a_mask, output_all_encoded_layers=True)
        sequence_out = enc_layers[-1]  ## output of the last encoder layer, shape (batch, seq_len, hidden_size)

        tokens_hidden_state, _ = self.cls(sequence_out)
        predictions = self.final_dense(tokens_hidden_state)

        if labels is not None:
            ## each position predicts the next token, so drop the prediction at the last position
            predictions = predictions[:, :-1].contiguous()
            target_mask = token_type_id[:, 1:].contiguous()
            loss = self.compute_loss(predictions, labels, target_mask)
            return predictions, loss
        else:
            return predictions

    def generate(self, text, out_max_length=40, beam_size=1, is_poem=False, max_length=256):

        self.out_max_length = out_max_length
        input_max_length = max_length - out_max_length
        # print(text)
        try:
            token_ids, token_type_ids = self.tokenizer.encode(text, max_length=input_max_length)
        except:
            # the tokenizer may be a huggingface transformers tokenizer
            tokenizer_out = self.tokenizer.encode_plus(text, max_length=input_max_length, truncation=True)
            token_ids = tokenizer_out["input_ids"]
            token_type_ids = tokenizer_out["token_type_ids"]
        token_ids = torch.tensor(token_ids, device=self.device).view(1, -1)
        token_type_ids = torch.tensor(token_type_ids, device=self.device).view(1, -1)

        output_ids = self.beam_search(token_ids, token_type_ids, self.word2ix, beam_size=beam_size, device=self.device)

        return self.tokenizer.decode(output_ids.cpu().numpy())

    def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, device="cpu"):
        """
        Beam-search decoding.
        """
        sep_id = word2ix["[SEP]"]

        # holds the output sequences of the beams
        output_ids = torch.empty(1, 0, device=device, dtype=torch.long)
        # output_scores (below) holds the accumulated score of each beam

        with torch.no_grad():
            output_scores = torch.zeros(token_ids.shape[0], device=device)
            for step in range(self.out_max_length):
                if step == 0:
                    scores = self.forward(token_ids, token_type_ids)
                    # repeat the input ids beam_size times
                    token_ids = token_ids.view(1, -1).repeat(beam_size, 1)
                    token_type_ids = token_type_ids.view(1, -1).repeat(beam_size, 1)
                else:
                    scores = self.forward(new_input_ids, new_token_type_ids)

                logit_score = torch.log_softmax(scores[:, -1], dim=-1)

                logit_score = output_scores.view(-1, 1) + logit_score  # accumulated score
                ## flatten before taking the topk
                logit_score = logit_score.view(-1)
                hype_score, hype_pos = torch.topk(logit_score, beam_size)
                indice1 = (hype_pos // scores.shape[-1])  # row index: which beam
                indice2 = (hype_pos % scores.shape[-1]).long().reshape(-1, 1)  # column index: which token

                # update the scores
                output_scores = hype_score
                output_ids = torch.cat([output_ids[indice1], indice2], dim=1).long()
                new_input_ids = torch.cat([token_ids, output_ids], dim=1)
                new_token_type_ids = torch.cat([token_type_ids, torch.ones_like(output_ids)], dim=1)

                end_counts = (output_ids == sep_id).sum(1)  # count the end ([SEP]) markers
                best_one = output_scores.argmax()
                if end_counts[best_one] == 1:
                    # the best beam has ended
                    return output_ids[best_one][:-1]
                else:
                    # keep only the unfinished beams
                    flag = (end_counts < 1)  # mark unfinished sequences
                    if not flag.all():  # some beams have finished
                        token_ids = token_ids[flag]
                        token_type_ids = token_type_ids[flag]
                        new_input_ids = new_input_ids[flag]
                        new_token_type_ids = new_token_type_ids[flag]
                        output_ids = output_ids[flag]  # drop finished sequences
                        output_scores = output_scores[flag]  # drop finished scores
                        end_counts = end_counts[flag]  # drop finished end counts
                        beam_size = flag.sum()  # shrink the beam accordingly

            return output_ids[output_scores.argmax()]
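The a_mask built in forward above is the UniLM-style seq2seq attention mask: sentence-a tokens attend to all of sentence a, while sentence-b tokens attend to sentence a plus the sentence-b tokens before them. A minimal standalone sketch of just that computation, using a toy token_type_id (not part of this commit):

import torch

# toy token type ids: 0 = sentence a (the input text), 1 = sentence b (the generated target)
token_type_id = torch.tensor([[0, 0, 0, 1, 1]])
seq_len = token_type_id.shape[1]

ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32)
a_mask = ones.tril()                                        # causal (lower-triangular) part
s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()   # shape (batch, 1, 1, seq_len)
s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()   # shape (batch, 1, seq_len, 1)
a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask

print(a_mask[0, 0])
# rows 0-2 (sentence a): [1, 1, 1, 0, 0]  -> attend to sentence a only
# row 3   (sentence b): [1, 1, 1, 1, 0]  -> attend to sentence a plus itself
# row 4   (sentence b): [1, 1, 1, 1, 1]  -> attend to everything up to and including itself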
@@ -23,6 +23,7 @@
for text in test_data:
    with torch.no_grad():
        print(bert_model.generate(text, beam_size=3))
        print("\n")