Skip to content

Commit

Permalink
Merge pull request timbmg#12 from dhanajitb/master
Browse files Browse the repository at this point in the history
converting prob into a cuda variable
  • Loading branch information
timbmg authored Nov 20, 2018
2 parents 2b9b94f + f1a7991 commit 305cba9
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def forward(self, input_sequence, length):
if self.word_dropout_rate > 0:
# randomly replace decoder input with <unk>
prob = torch.rand(input_sequence.size())
if torch.cuda.is_available():
prob=prob.cuda()
prob[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1
decoder_input_sequence = input_sequence.clone()
decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx
Expand Down

0 comments on commit 305cba9

Please sign in to comment.