Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
920232796 committed Nov 14, 2021
1 parent 577851b commit c7220f6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,4 @@
2020.04.01: 重构了代码,开始训练一个新的任务花费时间更少。

python setup.py sdist
twine upload dist/bert_seq2seq-2.1.0.tar.gz
twine upload dist/bert_seq2seq-2.2.0.tar.gz
9 changes: 2 additions & 7 deletions bert_seq2seq/t5_ch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,9 @@ def sample_generate_encoder_decoder(self, text, input_max_length=256, out_max_le
filterd_logits_prob = F.softmax(filtered_logits, dim=-1)

next_token = torch.multinomial(filterd_logits_prob, num_samples=1)
_, max_prob_tokens = filtered_logits.max(dim=-1)
if self.eos_id == next_token.item() or self.eos_id == max_prob_tokens.item():
if self.eos_id == next_token.item():
break
if next_token.item() not in repeat_list:
repeat_list[next_token.item()] = 1
else :
repeat_list[next_token.item()] += 1


output_ids.append(next_token.item())
input_decoder_ids = torch.cat((input_decoder_ids, next_token.long().unsqueeze(0)), dim=1)

Expand Down

0 comments on commit c7220f6

Please sign in to comment.