diff --git a/model.py b/model.py index 4cb9824..ae5d0e4 100644 --- a/model.py +++ b/model.py @@ -87,6 +87,8 @@ def forward(self, input_sequence, length): if self.word_dropout_rate > 0: # randomly replace decoder input with prob = torch.rand(input_sequence.size()) + if torch.cuda.is_available(): + prob=prob.cuda() prob[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1 decoder_input_sequence = input_sequence.clone() decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx