
Commit

优化代码 (Optimize code)
920232796 committed Dec 9, 2021
1 parent 3ad0380 commit aad46c7
Showing 4 changed files with 130 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -167,4 +167,4 @@
2020.04.01: Refactored the code; starting to train a new task now takes less time.

python setup.py sdist
twine upload dist/bert_seq2seq-2.3.1.tar.gz
twine upload dist/bert_seq2seq-2.3.2.tar.gz
2 changes: 1 addition & 1 deletion bert_seq2seq/model/nezha_model.py
@@ -100,7 +100,7 @@ def __init__(
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=2048,
max_position_embeddings=1024,
max_relative_position=64,
type_vocab_size=2,
initializer_range=0.02,
141 changes: 127 additions & 14 deletions bert_seq2seq/simbert_model.py
@@ -8,6 +8,10 @@
import os
from bert_seq2seq.basic_bert import BasicBert
import numpy as np
from bert_seq2seq.helper import RepetitionPenaltyLogitsProcessor, TemperatureLogitsProcessor, TopKLogitsProcessor, \
TopPLogitsProcessor, ListProcessor



def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
@@ -53,9 +57,9 @@ def __init__(self, word2ix, model_name="roberta", tokenizer=None):
self.hidden_dim = self.config.hidden_size
self.vocab_size = len(word2ix)

def compute_loss(self, predictions, labels, target_mask):
def compute_loss(self, cls_token_state, predictions, labels, target_mask):
loss1 = self.compute_loss_of_seq2seq(predictions, labels, target_mask)
loss2 = self.compute_loss_of_similarity(predictions[:, 0]) ## take the cls vector
loss2 = self.compute_loss_of_similarity(cls_token_state) ## take the cls vector
return loss1 + loss2
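The new cls_token_state argument lets compute_loss_of_similarity work on the encoder's [CLS] hidden state instead of the vocabulary logits. That helper is collapsed in this diff; the sketch below shows a typical SimBERT-style batch similarity loss over [CLS] vectors, as an illustration of the general technique rather than this repository's exact code, assuming the batch is arranged as consecutive (text, similar text) pairs.

    # Hedged sketch of a SimBERT-style similarity loss over [CLS] vectors.
    # Assumes batch_size is even and rows 2k and 2k+1 form a similar pair.
    import torch
    import torch.nn.functional as F

    def similarity_loss_sketch(cls_vectors, scale=20.0):
        batch_size = cls_vectors.size(0)
        normed = F.normalize(cls_vectors, p=2, dim=-1)                   # unit-length vectors
        sims = normed @ normed.t() * scale                               # scaled cosine similarities
        sims = sims - torch.eye(batch_size, device=sims.device) * 1e12   # mask self-similarity
        labels = torch.arange(batch_size, device=sims.device) ^ 1        # 2k <-> 2k+1 are positives
        return F.cross_entropy(sims, labels)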

def compute_loss_of_seq2seq(self, predictions, labels, target_mask):
@@ -108,15 +112,16 @@ def forward(self, input_tensor, token_type_id, position_enc=None, labels=None):
output_all_encoded_layers=True)
squence_out = enc_layers[-1] ## take the output of the last layer

_, predictions = self.cls(squence_out)
sequence_hidden, predictions = self.cls(squence_out)


if labels is not None:
## compute the loss
## a special output mask has to be built to compute the loss correctly
# the predictions do not need the final sep position, hence the slice up to -1
predictions = predictions[:, :-1].contiguous()
target_mask = token_type_id[:, 1:].contiguous()
loss = self.compute_loss(predictions, labels, target_mask)
loss = self.compute_loss(sequence_hidden[0], predictions, labels, target_mask)
return predictions, loss
else :
return predictions
@@ -143,30 +148,138 @@ def generate(self, text, out_max_length=40, beam_size=1, max_length=256):

return self.tokenizer.decode(out_puts_ids.cpu().numpy())

def random_sample(
self,
inputs,
n,
topk=None,
topp=None,
states=None,
temperature=1,
min_ends=1
):
"""随机采样n个结果
说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp
表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。
返回:n个解码序列组成的list。
"""
inputs = [np.array([i]) for i in inputs]
output_ids = self.first_output_ids
results = []
for step in range(self.maxlen):
probas, states = self.predict(
inputs, output_ids, states, temperature, 'probas'
) # compute the probabilities for the current step
probas /= probas.sum(axis=1, keepdims=True) # make sure they are normalized
if step == 0: # after the first step, repeat the results n times
probas = np.repeat(probas, n, axis=0)
inputs = [np.repeat(i, n, axis=0) for i in inputs]
output_ids = np.repeat(output_ids, n, axis=0)
if topk is not None:
k_indices = probas.argpartition(-topk,
axis=1)[:, -topk:] # keep only the topk
probas = np.take_along_axis(probas, k_indices, axis=1) # topk probabilities
probas /= probas.sum(axis=1, keepdims=True) # renormalize
if topp is not None:
p_indices = probas.argsort(axis=1)[:, ::-1] # sort from high to low
probas = np.take_along_axis(probas, p_indices, axis=1) # sorted probabilities
cumsum_probas = np.cumsum(probas, axis=1) # cumulative probabilities
flag = np.roll(cumsum_probas >= topp, 1, axis=1) # mark the part that exceeds topp
flag[:, 0] = False # together with np.roll above, this shifts the mask by one position
probas[flag] = 0 # zero out everything after the cutoff
probas /= probas.sum(axis=1, keepdims=True) # renormalize
sample_func = lambda p: np.random.choice(len(p), p=p) # sample according to the probabilities
sample_ids = np.apply_along_axis(sample_func, 1, probas) # perform the sampling
sample_ids = sample_ids.reshape((-1, 1)) # align the shape
if topp is not None:
sample_ids = np.take_along_axis(
p_indices, sample_ids, axis=1
) # map back to the original ids
if topk is not None:
sample_ids = np.take_along_axis(
k_indices, sample_ids, axis=1
) # map back to the original ids
output_ids = np.concatenate([output_ids, sample_ids], 1) # update the outputs
is_end = output_ids[:, -1] == self.end_id # mark sequences that end with the end token
end_counts = (output_ids == self.end_id).sum(1) # count the end tokens seen so far
if output_ids.shape[1] >= self.minlen: # minimum-length check
flag = is_end & (end_counts >= min_ends) # mark finished sequences
if flag.any(): # if any sequence has finished
for ids in output_ids[flag]: # store the finished sequences
results.append(ids)
flag = (flag == False) # mark the unfinished sequences
inputs = [i[flag] for i in inputs] # keep only the unfinished inputs
output_ids = output_ids[flag] # keep only the unfinished candidates
end_counts = end_counts[flag] # keep only the unfinished end counts
if len(output_ids) == 0:
break
# if there are still unfinished sequences, put them into the results directly
for ids in output_ids:
results.append(ids)
# return the results
return results
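The topp branch above keeps the smallest set of tokens whose cumulative probability reaches topp; the np.roll / flag[:, 0] = False trick shifts the cutoff mask one position to the right so that the token which first crosses the threshold is still kept. A tiny standalone illustration with made-up numbers:

    # Standalone illustration of the nucleus (top-p) masking trick used in random_sample.
    import numpy as np

    probas = np.array([[0.5, 0.3, 0.15, 0.05]])     # already sorted from high to low
    topp = 0.7
    cumsum_probas = np.cumsum(probas, axis=1)       # [[0.5, 0.8, 0.95, 1.0]]
    flag = np.roll(cumsum_probas >= topp, 1, axis=1)
    flag[:, 0] = False                              # never drop the most probable token
    probas[flag] = 0                                # -> [[0.5, 0.3, 0.0, 0.0]]
    probas /= probas.sum(axis=1, keepdims=True)     # renormalize before sampling
    print(probas)                                   # [[0.625, 0.375, 0.0, 0.0]]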

def sample_generate(self, text, out_max_length=40, top_k=30,
top_p=0.0, max_length=256, repetition_penalty=1.0,
temperature=1.0, sample_num=1):

def sample_generate(self, text, out_max_length=40, top_k=30, top_p=0.0, max_length=256):
input_max_length = max_length - out_max_length
token_ids, token_type_ids = self.tokenizer.encode(text, max_length=input_max_length)

result_list = []
lp = [RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
TemperatureLogitsProcessor(temperature=temperature),
TopKLogitsProcessor(top_k=top_k),
TopPLogitsProcessor(top_p=top_p),
]
list_processor = ListProcessor(lp)

token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
token_type_ids = torch.tensor(token_type_ids, device=self.device, dtype=torch.long).view(1, -1)
device = self.device
output_ids = []
sep_id = self.word2ix["[SEP]"]
with torch.no_grad():
for step in range(out_max_length):
if step == 0:
token_ids = token_ids.repeat((sample_num, 1))
token_type_ids = token_type_ids.repeat((sample_num, 1))

scores = self.forward(token_ids, token_type_ids)
logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
logit_score[self.word2ix["[UNK]"]] = -float('Inf')
filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
logit_score = torch.log_softmax(scores[:, -1], dim=-1)
logit_score[:, self.word2ix["[UNK]"]] = -float('Inf')

filtered_logits = list_processor(token_ids, logit_score)

# filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
if sep_id == next_token.item():
break
output_ids.append(next_token.item())
token_ids = torch.cat((token_ids, next_token.long().unsqueeze(0)), dim=1)
token_type_ids = torch.cat([token_type_ids, torch.ones((1, 1), device=device, dtype=torch.long)], dim=1)
if step == 0:
output_ids = next_token.view((sample_num, 1))

else :
output_ids = torch.cat([output_ids, next_token.view((sample_num, 1))], dim=1)


token_ids = torch.cat([token_ids, next_token.view((sample_num, 1)).long()], dim=1)
# token_ids = torch.cat((token_ids, next_token.long()), dim=1)
token_type_ids = torch.cat([token_type_ids, torch.ones((sample_num, 1), device=device, dtype=torch.long)], dim=1)

is_end = (output_ids[:, -1] == sep_id)

if is_end.any():
for ids in output_ids[is_end]:
# save the decoded result
sample_num -= 1
result_list.append(self.tokenizer.decode(ids.cpu().numpy()[:-1]))

is_end = (is_end == False) # mark the unfinished sequences
token_ids = token_ids[is_end] # keep the unfinished inputs
output_ids = output_ids[is_end] # keep only the unfinished candidates
if len(output_ids) == 0:
break
token_type_ids = token_type_ids[is_end] # keep the unfinished inputs

return self.tokenizer.decode(np.array(output_ids))
return result_list
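A short usage sketch for the reworked sample_generate: it now draws sample_num candidates in one batched decoding loop and returns a list of decoded strings. The model construction is assumed, and the input text and parameter values below are only examples.

    # Hedged usage sketch; simbert_model is assumed to be a loaded SimBertModel.
    outputs = simbert_model.sample_generate(
        "今天天气不错",          # example input text
        out_max_length=40,
        top_k=30,
        top_p=0.9,
        repetition_penalty=1.5,
        temperature=1.0,
        sample_num=3,            # return up to 3 sampled continuations
    )
    for text in outputs:
        print(text)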

def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, device="cpu"):
"""
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='bert_seq2seq',
version='2.3.1',
version='2.3.2',
description='use torch to do bert_seq2seq task',
long_description='bert_seq2seq: https://github.com/920232796/bert_seq2seq',
license='Apache License 2.0',
