Commit

fix bugs
920232796 committed Sep 30, 2021
1 parent 08c776b commit b248e41
Showing 4 changed files with 444 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@ state_dict/bert_seq2seq_save
# Distribution / packaging
.Python
build/
./test.py
nouse/*
develop-eggs/
dist/
23 changes: 23 additions & 0 deletions bert_seq2seq/gpt2_generate_model.py
@@ -42,6 +42,29 @@ def sample_generate(self, text, input_max_length=256, out_max_length=200, top_k=

        return self.tokenizer.decode(np.array(output_ids))

    def sample_generate_once(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, sep="。"):

        token_ids, _ = self.tokenizer.encode(text, max_length=input_max_length)
        # No start or end symbols are added; the input is just a single sentence,
        # so drop the first and last tokens the tokenizer inserts.
        token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long)[1:-1].view(1, -1)

        output_ids = []
        sep_id = self.word2ix[sep]  # stop once the separator (default: "。") is generated
        with torch.no_grad():
            for step in range(out_max_length):
                _, scores = self.model(token_ids)
                logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
                logit_score[self.word2ix["[UNK]"]] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                if sep_id == next_token.item():
                    break
                output_ids.append(next_token.item())
                token_ids = torch.cat((token_ids, next_token.long().unsqueeze(0)), dim=1)

        return self.tokenizer.decode(np.array(output_ids))

    def sample_generate_english(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, add_eos=False):

        token_ids = self.tokenizer.encode(text, max_length=input_max_length, truncation=True)
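The sampling loop above relies on top_k_top_p_filtering, which is defined elsewhere in bert_seq2seq and is not part of this diff. For reference, below is a minimal sketch of the standard top-k / nucleus (top-p) filtering technique that a function with this name conventionally implements, written for a 1-D logits tensor like the logit_score used in sample_generate_once; it is an illustrative sketch under that assumption, not necessarily the repository's exact code.

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    # Mask logits so that sampling only considers the top-k highest-probability
    # tokens and/or the smallest nucleus whose cumulative probability exceeds top_p.
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mark tokens outside the nucleus, always keeping the most likely token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

With the default top_p=0.0 used by sample_generate_once, only the top-k cutoff is active; a top_p in (0, 1) additionally restricts sampling to the nucleus of most likely tokens. A caller would invoke the new method as model.sample_generate_once(text), where model is an instance of the class these methods belong to.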