Tidied up the examples; they are now much clearer.
920232796 committed Nov 15, 2021
1 parent 401be3c commit e9235e6
Showing 14 changed files with 274 additions and 299 deletions.
85 changes: 23 additions & 62 deletions bert_seq2seq/seq2seq_model.py
@@ -7,7 +7,10 @@
 from bert_seq2seq.config import yayun_list
 import os
 from bert_seq2seq.basic_bert import BasicBert
-import numpy as np
+import numpy as np
+from bert_seq2seq.helper import RepetitionPenaltyLogitsProcessor, TemperatureLogitsProcessor, TopKLogitsProcessor, \
+    TopPLogitsProcessor, ListProcessor


 def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
@@ -131,10 +134,21 @@ def generate_random(self, text, out_max_length=40, beam_size=3, max_length=256):
         random_int = random.randint(0, len(out_puts_ids_list) - 1)
         return self.tokenizer.decode(out_puts_ids_list[random_int].cpu().numpy())

-    def sample_generate(self, text, out_max_length=40, top_k=30, top_p=0.0, max_length=256):
+    def sample_generate(self, text, out_max_length=40, top_k=30,
+                        top_p=0.0, max_length=256, repetition_penalty=1.0,
+                        temperature=1.0):
 
         input_max_length = max_length - out_max_length
         token_ids, token_type_ids = self.tokenizer.encode(text, max_length=input_max_length)
 
+        lp = [RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
+              TemperatureLogitsProcessor(temperature=temperature),
+              TopKLogitsProcessor(top_k=top_k),
+              TopPLogitsProcessor(top_p=top_p),
+              ]
+        list_processor = ListProcessor(lp)
+
         token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
         token_type_ids = torch.tensor(token_type_ids, device=self.device, dtype=torch.long).view(1, -1)
         device = self.device
@@ -143,74 +157,21 @@ def sample_generate(self, text, out_max_length=40, top_k=30, top_p=0.0, max_leng
         with torch.no_grad():
             for step in range(out_max_length):
                 scores = self.forward(token_ids, token_type_ids)
-                logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
-                logit_score[self.word2ix["[UNK]"]] = -float('Inf')
-                filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
+                logit_score = torch.log_softmax(scores[:, -1], dim=-1)
+                logit_score[:, self.word2ix["[UNK]"]] = -float('Inf')
+
+                filtered_logits = list_processor(token_ids, logit_score)
+
+                # filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
                 next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                 if sep_id == next_token.item():
                     break
                 output_ids.append(next_token.item())
-                token_ids = torch.cat((token_ids, next_token.long().unsqueeze(0)), dim=1)
+                token_ids = torch.cat((token_ids, next_token.long()), dim=1)
                 token_type_ids = torch.cat([token_type_ids, torch.ones((1, 1), device=device, dtype=torch.long)], dim=1)
 
         return self.tokenizer.decode(np.array(output_ids))

-    def beam_search_list(self, token_ids, token_type_ids, word2ix, beam_size=3, device="cpu"):
-        """
-        beam-search procedure
-        """
-        sep_id = word2ix["[SEP]"]
-        output_list = []
-        # holds the generated output sequences
-        output_ids = torch.empty(1, 0, device=device, dtype=torch.long)
-        # holds the cumulative scores
-        with torch.no_grad():
-            output_scores = torch.zeros(token_ids.shape[0], device=device)
-            for step in range(self.out_max_length):
-                if step == 0:
-                    scores = self.forward(token_ids, token_type_ids)
-                    # repeat the input ids beam_size times
-                    token_ids = token_ids.view(1, -1).repeat(beam_size, 1)
-                    token_type_ids = token_type_ids.view(1, -1).repeat(beam_size, 1)
-                else:
-                    scores = self.forward(new_input_ids, new_token_type_ids)
-
-                logit_score = torch.log_softmax(scores[:, -1], dim=-1)
-
-                logit_score = output_scores.view(-1, 1) + logit_score  # cumulative score
-                ## when taking topk we flatten first, then call the topk function
-                # flatten
-                logit_score = logit_score.view(-1)
-                hype_score, hype_pos = torch.topk(logit_score, beam_size)
-                indice1 = (hype_pos // scores.shape[-1])  # row (beam) index
-                indice2 = (hype_pos % scores.shape[-1]).long().reshape(-1, 1)  # column (token) index
-
-                # update the scores
-                output_scores = hype_score
-                output_ids = torch.cat([output_ids[indice1], indice2], dim=1).long()
-                new_input_ids = torch.cat([token_ids, output_ids], dim=1)
-                new_token_type_ids = torch.cat([token_type_ids, torch.ones_like(output_ids)], dim=1)
-
-                end_counts = (output_ids == sep_id).sum(1)  # count end ([SEP]) markers per beam
-
-                flag = (end_counts < 1)  # mark unfinished sequences
-                if not flag.all():  # if any sequence has finished
-                    for i, f in enumerate(flag):
-                        output_list.append(output_ids[i][:-1])
-                    token_ids = token_ids[flag]
-                    token_type_ids = token_type_ids[flag]
-                    new_input_ids = new_input_ids[flag]
-                    new_token_type_ids = new_token_type_ids[flag]
-                    output_ids = output_ids[flag]  # drop finished sequences
-                    output_scores = output_scores[flag]  # drop finished scores
-                    end_counts = end_counts[flag]  # drop finished end counts
-                    beam_size = flag.sum()  # shrink topk accordingly
-
-                if beam_size < 1:
-                    # print(output_list)
-                    return output_list
-            # print(output_list)
-            return output_list
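The deleted beam_search_list above relied on a standard indexing trick: cumulative scores for all beams are flattened into one vector before torch.topk, and the beam (row) and token (column) indices are then recovered with integer division and modulo. A tiny self-contained example of just that arithmetic, using made-up scores:

```python
# Standalone illustration of the flatten-then-topk indexing from the
# deleted beam_search_list; vocab size and scores are toy values.
import torch

beam_size, vocab_size = 3, 5
# Cumulative log-probabilities, one row per beam.
logit_score = torch.tensor([[-1.0, -2.0, -0.5, -3.0, -4.0],
                            [-0.2, -5.0, -1.5, -2.5, -0.9],
                            [-2.2, -0.1, -3.3, -1.1, -6.0]])

flat = logit_score.view(-1)                        # shape (beam_size * vocab_size,)
hype_score, hype_pos = torch.topk(flat, beam_size)

indice1 = hype_pos // vocab_size                   # row index: which beam to extend
indice2 = (hype_pos % vocab_size).reshape(-1, 1)   # column index: which token to append

print(hype_score)  # tensor([-0.1000, -0.2000, -0.5000])
print(indice1)     # tensor([2, 1, 0])
print(indice2)     # tensor([[1], [0], [2]])
```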

     # def poem_beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, device="cpu"):
     #     """
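After this commit, callers of sample_generate can pass the sampling controls directly. A minimal usage sketch follows; the `model` object and prompt are illustrative placeholders, and only the method signature comes from this diff:

```python
# `model` stands for a loaded seq2seq BERT model from this repository
# (e.g. built the way the examples/ scripts do); it is not defined here.
out_text = model.sample_generate(
    "今天天气不错",               # prompt (illustrative)
    out_max_length=40,
    top_k=30,
    top_p=0.9,
    repetition_penalty=1.2,      # new in this commit: penalize repeated tokens
    temperature=0.8,             # new in this commit: rescale logits before sampling
)
print(out_text)
```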
189 changes: 0 additions & 189 deletions examples/auto_chat_train.py

This file was deleted.

