Commit 3ad0380: fix bugs
920232796 committed Dec 7, 2021
1 parent 38139d3 commit 3ad0380
Showing 7 changed files with 158 additions and 16 deletions.
Binary file modified .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions .gitignore
@@ -17,6 +17,7 @@ poems
./test.py
nouse
test.py
+TextRank
develop-eggs/
dist/
downloads/
144 changes: 144 additions & 0 deletions bert_seq2seq/bert_cls_multi_seq2seq.py
@@ -0,0 +1,144 @@
## Multi-label classification model based on seq2seq.
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
import time
from bert_seq2seq.config import yayun_list
import os
from bert_seq2seq.basic_bert import BasicBert
import numpy as np
from bert_seq2seq.helper import RepetitionPenaltyLogitsProcessor, TemperatureLogitsProcessor, TopKLogitsProcessor, \
    TopPLogitsProcessor, ListProcessor

class ClsMultiSeq2SeqModel(BasicBert):
    """
    Multi-label classification framed as seq2seq: the model decodes the target
    labels as a token sequence over the label vocabulary given by `target`.
    """
    def __init__(self, word2idx, target, model_name="roberta"):
        super(ClsMultiSeq2SeqModel, self).__init__(word2ix=word2idx, model_name=model_name)
        self.target = target
        self.final_dense = nn.Linear(self.config.hidden_size, len(self.target))


    def compute_loss(self, predictions, labels, target_mask):
        """
        target_mask: 0 for sentence-a tokens and padding, 1 for sentence-b tokens
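        e.g. (hypothetical) tokens [CLS] x1 x2 [SEP] y1 y2 [SEP] give token_type_ids 0 0 0 0 1 1 1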
"""
predictions = predictions.view(-1, len(self.target))
labels = labels.view(-1)
target_mask = target_mask.view(-1).float()
loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none")
return (loss(predictions, labels) * target_mask).sum() / target_mask.sum() ## 通过mask 取消 pad 和句子a部分预测的影响

    def forward(self, input_tensor, token_type_id, position_enc=None, labels=None):
        input_tensor = input_tensor.to(self.device)
        token_type_id = token_type_id.to(self.device)
        if position_enc is not None:
            position_enc = position_enc.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)
        input_shape = input_tensor.shape
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        ## build the special seq2seq attention mask
        ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
        a_mask = ones.tril()
        s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
        s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
        a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask
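        # The mask above yields the UniLM-style pattern: segment-a tokens (token_type 0)
        # attend bidirectionally within segment a, while segment-b tokens (token_type 1)
        # attend to all of segment a plus earlier segment-b tokens, so the label
        # sequence can be decoded left to right.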

        enc_layers, _ = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id,
                                  attention_mask=a_mask, output_all_encoded_layers=True)
        squence_out = enc_layers[-1]  ## last-layer output, shape (batch, seq_len, 768)

        tokens_hidden_state, _ = self.cls(squence_out)
        predictions = self.final_dense(tokens_hidden_state)

        if labels is not None:
            predictions = predictions[:, :-1].contiguous()
            target_mask = token_type_id[:, 1:].contiguous()
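            # shift by one: the prediction at position i is scored against the token
            # at position i + 1, and target_mask keeps only the decoder (segment-b) positions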
            loss = self.compute_loss(predictions, labels, target_mask)
            return predictions, loss
        else:
            return predictions

    def generate(self, text, out_max_length=40, beam_size=1, is_poem=False, max_length=256):
        self.out_max_length = out_max_length
        input_max_length = max_length - out_max_length
        try:
            token_ids, token_type_ids = self.tokenizer.encode(text, max_length=input_max_length)
        except Exception:
            # the tokenizer may be a transformers-style tokenizer
            tokenizer_out = self.tokenizer.encode_plus(text, max_length=input_max_length, truncation=True)
            token_ids = tokenizer_out["input_ids"]
            token_type_ids = tokenizer_out["token_type_ids"]
        token_ids = torch.tensor(token_ids, device=self.device).view(1, -1)
        token_type_ids = torch.tensor(token_type_ids, device=self.device).view(1, -1)

        out_puts_ids = self.beam_search(token_ids, token_type_ids, self.word2ix, beam_size=beam_size, device=self.device)

        return self.tokenizer.decode(out_puts_ids.cpu().numpy())


    def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, device="cpu"):
        """
        Beam-search decoding.
        """
        sep_id = word2ix["[SEP]"]

        # holds the generated output sequences
        output_ids = torch.empty(1, 0, device=device, dtype=torch.long)

        with torch.no_grad():
            # holds the accumulated scores
            output_scores = torch.zeros(token_ids.shape[0], device=device)
            for step in range(self.out_max_length):
                if step == 0:
                    scores = self.forward(token_ids, token_type_ids)
                    # repeat the input ids beam_size times
                    token_ids = token_ids.view(1, -1).repeat(beam_size, 1)
                    token_type_ids = token_type_ids.view(1, -1).repeat(beam_size, 1)
                else:
                    scores = self.forward(new_input_ids, new_token_type_ids)

                logit_score = torch.log_softmax(scores[:, -1], dim=-1)
                logit_score = output_scores.view(-1, 1) + logit_score  # accumulated score
                ## flatten before calling topk so candidates from all beams compete
                logit_score = logit_score.view(-1)
                hype_score, hype_pos = torch.topk(logit_score, beam_size)
                indice1 = (hype_pos // scores.shape[-1])  # beam (row) index
                indice2 = (hype_pos % scores.shape[-1]).long().reshape(-1, 1)  # token (column) index

                # update scores and extend the sequences
                output_scores = hype_score
                output_ids = torch.cat([output_ids[indice1], indice2], dim=1).long()
                new_input_ids = torch.cat([token_ids, output_ids], dim=1)
                new_token_type_ids = torch.cat([token_type_ids, torch.ones_like(output_ids)], dim=1)

                end_counts = (output_ids == sep_id).sum(1)  # count end ([SEP]) markers
                best_one = output_scores.argmax()
                if end_counts[best_one] == 1:
                    # the best-scoring sequence has terminated
                    return output_ids[best_one][:-1]
                else:
                    # keep only the unfinished sequences
                    flag = (end_counts < 1)  # marks unfinished sequences
                    if not flag.all():  # some sequences have finished
                        token_ids = token_ids[flag]
                        token_type_ids = token_type_ids[flag]
                        new_input_ids = new_input_ids[flag]
                        new_token_type_ids = new_token_type_ids[flag]
                        output_ids = output_ids[flag]  # drop the finished sequences
                        output_scores = output_scores[flag]
                        end_counts = end_counts[flag]
                        beam_size = flag.sum()  # shrink topk accordingly

        return output_ids[output_scores.argmax()]
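A minimal, self-contained illustration of the flatten-then-topk bookkeeping used above (toy sizes, not part of the commit):

import torch
beam_size, vocab_size = 2, 5
logit_score = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=-1).view(-1)
hype_score, hype_pos = torch.topk(logit_score, beam_size)
beam_idx = hype_pos // vocab_size   # which beam each top-k candidate extends
token_idx = hype_pos % vocab_size   # which token id extends it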


2 changes: 1 addition & 1 deletion bert_seq2seq/model/nezha_model.py
@@ -100,7 +100,7 @@ def __init__(
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
max_position_embeddings=2048,
max_relative_position=64,
type_vocab_size=2,
initializer_range=0.02,
20 changes: 6 additions & 14 deletions bert_seq2seq/simbert_model.py
@@ -70,18 +70,12 @@ def compute_loss_of_similarity(self, y_pred):
        y_true = self.get_labels_of_similarity(y_pred)  # build the labels
        y_true = y_true.to(self.device)
        norm_a = torch.nn.functional.normalize(y_pred, dim=-1, p=2)
-        # y_pred = K.l2_normalize(y_pred, axis=1) # normalize the sentence vectors
        similarities = norm_a.matmul(norm_a.t())

-        # similarities = K.dot(y_pred, K.transpose(y_pred)) # similarity matrix
        similarities = similarities - (torch.eye(y_pred.shape[0]) * 1e12).to(self.device)  # mask out the diagonal
-        similarities = similarities * 30  # scale
-        similarities = similarities
+        similarities = similarities * 20  # scale
        loss_f = nn.CrossEntropyLoss()
        loss = loss_f(similarities, y_true)
-        # loss = K.categorical_crossentropy(
-        #     y_true, similarities, from_logits=True
-        # )
        return loss
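For context, a sketch of how SimBERT-style in-batch labels are typically built (the body of get_labels_of_similarity sits outside this hunk, so this is an assumption about its behavior, not part of the diff):

import torch
def labels_of_similarity(batch_size):
    # with a batch laid out as consecutive (text, paraphrase) pairs,
    # row 2i's positive example is row 2i + 1 and vice versa
    idxs = torch.arange(batch_size)
    return idxs + 1 - (idxs % 2) * 2

# labels_of_similarity(4) -> tensor([1, 0, 3, 2])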

    def get_labels_of_similarity(self, y_pred):
@@ -114,7 +108,7 @@ def forward(self, input_tensor, token_type_id, position_enc=None, labels=None):
                         output_all_encoded_layers=True)
        squence_out = enc_layers[-1]  ## last-layer output

-        predictions = self.decoder(squence_out)
+        _, predictions = self.cls(squence_out)

        if labels is not None:
            ## compute the loss
@@ -128,7 +122,7 @@ def forward(self, input_tensor, token_type_id, position_enc=None, labels=None):
            return predictions


-    def generate(self, text, out_max_length=40, beam_size=1, is_poem=False, max_length=256):
+    def generate(self, text, out_max_length=40, beam_size=1, max_length=256):
        # generate a result for a single input sentence
        ## the input max length is derived from the output max length; longer inputs are truncated
        self.out_max_length = out_max_length
@@ -141,13 +135,11 @@ def generate(self, text, out_max_length=40, beam_size=1, is_poem=False, max_length=256):
            tokenizer_out = self.tokenizer.encode_plus(text, max_length=input_max_length, truncation=True)
            token_ids = tokenizer_out["input_ids"]
            token_type_ids = tokenizer_out["token_type_ids"]

        token_ids = torch.tensor(token_ids, device=self.device).view(1, -1)
        token_type_ids = torch.tensor(token_type_ids, device=self.device).view(1, -1)
-        if is_poem:  ## beam search for classical poetry differs slightly
-
-            out_puts_ids = self.beam_search_poem(text, token_ids, token_type_ids, self.word2ix, beam_size=beam_size, device=self.device)
-        else:
-            out_puts_ids = self.beam_search(token_ids, token_type_ids, self.word2ix, beam_size=beam_size, device=self.device)
+
+        out_puts_ids = self.beam_search(token_ids, token_type_ids, self.word2ix, beam_size=beam_size, device=self.device)
+
        return self.tokenizer.decode(out_puts_ids.cpu().numpy())

6 changes: 5 additions & 1 deletion bert_seq2seq/utils.py
@@ -6,11 +6,12 @@
from bert_seq2seq.bert_relation_extraction import BertRelationExtrac
from bert_seq2seq.bert_cls_multi_classifier import BertClsMultiClassifier
import torch.nn.functional as F
+from bert_seq2seq.bert_cls_multi_seq2seq import ClsMultiSeq2SeqModel
from bert_seq2seq.simbert_model import SimBertModel
from bert_seq2seq.gpt2_generate_model import GPT2


-def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2seq", target_size=0):
+def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2seq", target_size=0, target=None):
"""
model_path: 模型位置
这是个统一的接口,用来加载模型的
Expand Down Expand Up @@ -48,6 +49,9 @@ def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2se
elif model_class == "multi_label_cls":
bert_model = BertClsMultiClassifier(word2ix, target_size, model_name=model_name)
return bert_model
elif model_class == "multi_label_cls_seq2seq":
bert_model = ClsMultiSeq2SeqModel(word2ix, target, model_name=model_name)
return bert_model
else :
raise Exception("model_name_err")
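A minimal usage sketch for the new multi_label_cls_seq2seq branch (the vocab path and label set below are placeholders, not part of this commit):

from bert_seq2seq.tokenizer import load_chinese_base_vocab
from bert_seq2seq.utils import load_bert

word2idx = load_chinese_base_vocab("./state_dict/roberta_wwm_vocab.txt")  # hypothetical path
target = ["label_a", "label_b", "label_c"]  # hypothetical label vocabulary
bert_model = load_bert(word2idx, model_name="roberta",
                       model_class="multi_label_cls_seq2seq", target=target)
# after loading pretrained weights, decode labels with beam search:
# print(bert_model.generate("some input text", beam_size=3))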

1 change: 1 addition & 0 deletions test/auto_title_test.py
@@ -23,6 +23,7 @@
for text in test_data:
    with torch.no_grad():
        print(bert_model.generate(text, beam_size=3))
+        print("\n")


