Skip to content

Commit

Permalink
delete noused code
Browse files Browse the repository at this point in the history
  • Loading branch information
920232796 committed Mar 27, 2022
1 parent c2d24f9 commit 643ac53
Showing 1 changed file with 0 additions and 13 deletions.
13 changes: 0 additions & 13 deletions bert_seq2seq/seq2seq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,6 @@ def generate(self, text, out_max_length=40, beam_size=1, is_poem=False, max_leng
out_puts_ids = self.beam_search(token_ids, token_type_ids, self.word2ix, beam_size=beam_size, device=self.device)

return self.tokenizer.decode(out_puts_ids.cpu().numpy())

def generate_random(self, text, out_max_length=40, beam_size=3, max_length=256):
# 对 一个 句子生成相应的结果
## 通过输出最大长度得到输入的最大长度,这里问题不大,如果超过最大长度会进行截断
self.out_max_length = out_max_length
input_max_length = max_length - out_max_length
token_ids, token_type_ids = self.tokenizer.encode(text, max_length=input_max_length)
token_ids = torch.tensor(token_ids, device=self.device).view(1, -1)
token_type_ids = torch.tensor(token_type_ids, device=self.device).view(1, -1)

out_puts_ids_list = self.beam_search_list(token_ids, token_type_ids, self.word2ix, beam_size=beam_size, device=self.device)
random_int = random.randint(0, len(out_puts_ids_list) - 1)
return self.tokenizer.decode(out_puts_ids_list[random_int].cpu().numpy())

def sample_generate(self, text, out_max_length=40, top_k=30,
top_p=0.0, max_length=256, repetition_penalty=1.0,
Expand Down

0 comments on commit 643ac53

Please sign in to comment.