Skip to content

Commit

Permalink
Update seq2seq.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengyanzhao1997 authored Jan 22, 2021
1 parent c679166 commit 95c1150
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion model/train/NEZHA/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def predict(self, inputs, output_ids, states):
ides_temp[i][get_len:end_] = output_ids[i]
seg_id_temp[i][get_len:end_] = np.ones_like(output_ids[i])
mask_att_temp[i] = unilm_mask_single(seg_id_temp[i])
prediction = self.last_token(end_-1).predict([ides_temp,seg_id_temp,mask_att_temp])
prediction = self.model.predict([ides_temp,seg_id_temp,mask_att_temp])[:,end_-1]
'''
假设现在的topK = 2 所以每次只predict 二组的可能输出 len(ides_temp) = 2
那我们初始化[0,0] 代表每一组输出组目前的ngram情况
Expand Down

0 comments on commit 95c1150

Please sign in to comment.